use serde::{Deserialize, Serialize};
use super::chat_completions::{StopSequence, Usage};
/// The `prompt` field of a completion request, which accepts several JSON
/// shapes via `#[serde(untagged)]`.
///
/// NOTE: variant order is load-bearing. Untagged deserialization tries the
/// variants in declaration order and picks the first that fits, so a JSON
/// string becomes `Single`, an array of strings `Multiple`, an array of
/// integers `Tokens`, and an array of integer arrays `TokenArrays`.
/// Mixed or non-integer arrays (e.g. `[1, "x"]`, `[1.5]`) match nothing
/// and fail to deserialize — see the tests below.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CompletionPrompt {
/// One prompt given as plain text.
Single(String),
/// Several prompts, each given as plain text.
Multiple(Vec<String>),
/// One prompt given as a sequence of token ids.
Tokens(Vec<u32>),
/// Several prompts, each given as a sequence of token ids.
TokenArrays(Vec<Vec<u32>>),
}
/// Request body for the (legacy) text-completions endpoint.
///
/// Every field except `model` is optional; `skip_serializing_if` keeps
/// unset fields out of the serialized JSON entirely rather than emitting
/// `"field": null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
/// ID of the model to complete with (the only required field).
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
/// Text or token prompt(s); see `CompletionPrompt` for accepted shapes.
pub prompt: Option<CompletionPrompt>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Text appended after the generated completion.
pub suffix: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Maximum number of tokens to generate.
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Sampling temperature.
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Nucleus-sampling probability mass.
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Number of completions to generate per prompt.
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Whether to stream results back as chunks (see `CompletionChunk`).
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Number of log-probabilities to return per token.
pub logprobs: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Whether to echo the prompt back alongside the completion.
pub echo: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Stop sequence(s); shape shared with chat completions.
pub stop: Option<StopSequence>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Penalty based on whether tokens already appear in the text.
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Penalty based on token frequency in the text so far.
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Generate this many server-side and return the best `n`.
pub best_of: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Per-token bias map; kept as raw JSON since keys are stringified
/// token ids (e.g. {"50256": -100}).
pub logit_bias: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
/// End-user identifier for abuse monitoring.
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Seed for (best-effort) deterministic sampling.
pub seed: Option<i64>,
}
/// Non-streaming response returned for a `CompletionRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
/// Unique identifier for this completion.
pub id: String,
/// Object type discriminator (e.g. "text_completion" — TODO confirm
/// against the serving side; not constrained here).
pub object: String,
/// Unix timestamp (seconds) of creation.
pub created: u64,
/// Model that produced the completion.
pub model: String,
/// One entry per requested completion (`n` / `best_of`).
pub choices: Vec<CompletionChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Token accounting; type shared with chat completions.
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Backend configuration fingerprint, when the server reports one.
pub system_fingerprint: Option<String>,
}
/// A single generated completion inside a `CompletionResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionChoice {
/// The generated text.
pub text: String,
/// Position of this choice in the response's `choices` array.
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
/// Log-probability payload; kept as raw JSON (schema not modeled here).
pub logprobs: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Why generation stopped (e.g. "stop", "length"); free-form string.
pub finish_reason: Option<String>,
}
/// One server-sent event in a streaming completion (`stream: true`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionChunk {
/// Identifier shared by all chunks of the same completion.
pub id: String,
/// Object type discriminator for the chunk.
pub object: String,
/// Unix timestamp (seconds) of creation.
pub created: u64,
/// Model producing the stream.
pub model: String,
/// Incremental text for each in-flight choice.
pub choices: Vec<CompletionChunkChoice>,
}
/// Per-choice delta carried by a `CompletionChunk`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionChunkChoice {
/// Text fragment appended by this chunk.
pub text: String,
/// Which choice this fragment belongs to.
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
/// Log-probability payload for the fragment; raw JSON, schema not modeled.
pub logprobs: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
/// Set on the final chunk of a choice (e.g. "stop", "length").
pub finish_reason: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare JSON string must land on the `Single` variant.
    #[test]
    fn test_deserialize_string_prompt() {
        let json = r#"{"model": "gpt-3.5-turbo-instruct", "prompt": "Say hello"}"#;
        let req: CompletionRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.prompt, Some(CompletionPrompt::Single(_))));
    }

    /// An array of strings must land on `Multiple` (tried before `Tokens`).
    #[test]
    fn test_deserialize_array_of_strings_prompt() {
        let json = r#"{"model": "gpt-3.5-turbo-instruct", "prompt": ["Hello", "World"]}"#;
        let req: CompletionRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.prompt, Some(CompletionPrompt::Multiple(_))));
    }

    /// An array of integers must land on `Tokens`.
    #[test]
    fn test_deserialize_token_array_prompt() {
        let json = r#"{"model": "gpt-3.5-turbo-instruct", "prompt": [1, 2, 3]}"#;
        let req: CompletionRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.prompt, Some(CompletionPrompt::Tokens(_))));
    }

    /// A nested integer array must land on `TokenArrays`.
    #[test]
    fn test_deserialize_token_array_of_arrays_prompt() {
        let json = r#"{"model": "gpt-3.5-turbo-instruct", "prompt": [[1, 2], [3, 4]]}"#;
        let req: CompletionRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.prompt, Some(CompletionPrompt::TokenArrays(_))));
    }

    /// Mixed-type arrays fit no untagged variant and must be rejected.
    #[test]
    fn test_reject_mixed_token_array_prompt() {
        let json = r#"{"model": "gpt-3.5-turbo-instruct", "prompt": [1, "hello"]}"#;
        assert!(serde_json::from_str::<CompletionRequest>(json).is_err());
    }

    /// Floats are not valid token ids (u32) and must be rejected.
    #[test]
    fn test_reject_float_token_array_prompt() {
        let json = r#"{"model": "gpt-3.5-turbo-instruct", "prompt": [1.5, 2.5]}"#;
        assert!(serde_json::from_str::<CompletionRequest>(json).is_err());
    }

    /// Exercises every optional request field. Previously `logit_bias` was
    /// absent from the fixture (the only field with zero coverage) and only
    /// 7 of the 16 set fields were asserted; now the fixture sets all of
    /// them and every value is checked.
    #[test]
    fn test_deserialize_with_all_fields() {
        let json = r#"{
            "model": "gpt-3.5-turbo-instruct",
            "prompt": "Complete this",
            "suffix": "end",
            "max_tokens": 100,
            "temperature": 0.7,
            "top_p": 0.9,
            "n": 1,
            "stream": false,
            "logprobs": 3,
            "echo": true,
            "stop": "\n",
            "presence_penalty": 0.1,
            "frequency_penalty": 0.2,
            "best_of": 3,
            "logit_bias": {"50256": -100},
            "user": "user-123",
            "seed": 42
        }"#;
        let req: CompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.model, "gpt-3.5-turbo-instruct");
        assert!(matches!(req.prompt, Some(CompletionPrompt::Single(_))));
        assert_eq!(req.suffix.as_deref(), Some("end"));
        assert_eq!(req.max_tokens, Some(100));
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(req.top_p, Some(0.9));
        assert_eq!(req.n, Some(1));
        assert_eq!(req.stream, Some(false));
        assert_eq!(req.logprobs, Some(3));
        assert_eq!(req.echo, Some(true));
        // StopSequence's variants live in chat_completions; presence is
        // enough here without coupling to that type's shape.
        assert!(req.stop.is_some());
        assert_eq!(req.presence_penalty, Some(0.1));
        assert_eq!(req.frequency_penalty, Some(0.2));
        assert_eq!(req.best_of, Some(3));
        assert!(req.logit_bias.is_some());
        assert_eq!(req.user.as_deref(), Some("user-123"));
        assert_eq!(req.seed, Some(42));
    }

    /// `skip_serializing_if = "Option::is_none"` must drop unset fields
    /// entirely rather than emitting `"field": null`.
    #[test]
    fn test_none_fields_are_skipped_on_serialize() {
        let json = r#"{"model": "gpt-3.5-turbo-instruct"}"#;
        let req: CompletionRequest = serde_json::from_str(json).unwrap();
        let out = serde_json::to_string(&req).unwrap();
        assert_eq!(out, r#"{"model":"gpt-3.5-turbo-instruct"}"#);
    }
}