use crate::error::{Result, ValidationError};
use crate::message::Message;
use crate::tool::{ToolChoice, ToolDefinition};
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
Text,
JsonObject,
#[serde(rename = "json_schema")]
JsonSchema {
json_schema: JsonSchemaFormat,
},
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JsonSchemaFormat {
pub name: String,
pub schema: Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CompletionRequest {
pub messages: Vec<Message>,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ToolDefinition>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub store: Option<bool>,
}
impl CompletionRequest {
pub fn new(model: impl Into<String>) -> Self {
Self {
messages: Vec::new(),
model: model.into(),
max_tokens: None,
temperature: None,
top_p: None,
stream: None,
n: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
user: None,
response_format: None,
tools: None,
tool_choice: None,
instructions: None,
previous_response_id: None,
store: None,
}
}
pub fn messages(mut self, messages: Vec<Message>) -> Self {
self.messages = messages;
self
}
pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
self.instructions = Some(instructions.into());
self
}
pub fn previous_response_id(mut self, id: impl Into<String>) -> Self {
self.previous_response_id = Some(id.into());
self
}
pub fn store(mut self, store: bool) -> Self {
self.store = Some(store);
self
}
pub fn builder() -> CompletionRequestBuilder {
CompletionRequestBuilder::default()
}
pub fn validate(&self) -> Result<()> {
if self.messages.is_empty() {
return Err(ValidationError::Empty {
field: "messages".to_string(),
}
.into());
}
if self.messages.len() > 1000 {
return Err(ValidationError::TooLong {
field: "messages".to_string(),
max: 1000,
}
.into());
}
const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
for (i, msg) in self.messages.iter().enumerate() {
if msg.content.text_len() > MAX_MESSAGE_SIZE {
return Err(ValidationError::TooLong {
field: format!("messages[{}].content", i),
max: MAX_MESSAGE_SIZE,
}
.into());
}
if msg.content.contains_null() {
return Err(ValidationError::InvalidFormat {
field: format!("messages[{}].content", i),
reason: "contains null bytes".to_string(),
}
.into());
}
}
const MAX_TOTAL_REQUEST_SIZE: usize = 10 * 1024 * 1024;
let total_size: usize = self.messages.iter().map(|m| m.content.text_len()).sum();
if total_size > MAX_TOTAL_REQUEST_SIZE {
return Err(ValidationError::TooLong {
field: "total_request_size".to_string(),
max: MAX_TOTAL_REQUEST_SIZE,
}
.into());
}
if self.model.is_empty() {
return Err(ValidationError::Empty {
field: "model".to_string(),
}
.into());
}
if !self
.model
.chars()
.all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/')
{
return Err(ValidationError::InvalidFormat {
field: "model".to_string(),
reason: "must be alphanumeric with -_./ only".to_string(),
}
.into());
}
if let Some(temp) = self.temperature {
if !(0.0..=2.0).contains(&temp) {
return Err(ValidationError::OutOfRange {
field: "temperature".to_string(),
min: 0.0,
max: 2.0,
}
.into());
}
}
if let Some(top_p) = self.top_p {
if !(0.0..=1.0).contains(&top_p) {
return Err(ValidationError::OutOfRange {
field: "top_p".to_string(),
min: 0.0,
max: 1.0,
}
.into());
}
}
if let Some(penalty) = self.presence_penalty {
if !(-2.0..=2.0).contains(&penalty) {
return Err(ValidationError::OutOfRange {
field: "presence_penalty".to_string(),
min: -2.0,
max: 2.0,
}
.into());
}
}
if let Some(penalty) = self.frequency_penalty {
if !(-2.0..=2.0).contains(&penalty) {
return Err(ValidationError::OutOfRange {
field: "frequency_penalty".to_string(),
min: -2.0,
max: 2.0,
}
.into());
}
}
Ok(())
}
}
#[derive(Debug, Default, Clone)]
pub struct CompletionRequestBuilder {
messages: Vec<Message>,
model: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
top_p: Option<f32>,
stream: Option<bool>,
n: Option<u32>,
stop: Option<Vec<String>>,
presence_penalty: Option<f32>,
frequency_penalty: Option<f32>,
user: Option<String>,
response_format: Option<ResponseFormat>,
tools: Option<Vec<ToolDefinition>>,
tool_choice: Option<ToolChoice>,
instructions: Option<String>,
previous_response_id: Option<String>,
store: Option<bool>,
}
impl CompletionRequestBuilder {
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn message(mut self, message: Message) -> Self {
self.messages.push(message);
self
}
pub fn messages(mut self, messages: Vec<Message>) -> Self {
self.messages = messages;
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn stream(mut self, stream: bool) -> Self {
self.stream = Some(stream);
self
}
pub fn n(mut self, n: u32) -> Self {
self.n = Some(n);
self
}
pub fn stop(mut self, stop: Vec<String>) -> Self {
self.stop = Some(stop);
self
}
pub fn presence_penalty(mut self, penalty: f32) -> Self {
self.presence_penalty = Some(penalty);
self
}
pub fn frequency_penalty(mut self, penalty: f32) -> Self {
self.frequency_penalty = Some(penalty);
self
}
pub fn user(mut self, user: impl Into<String>) -> Self {
self.user = Some(user.into());
self
}
pub fn response_format(mut self, format: ResponseFormat) -> Self {
self.response_format = Some(format);
self
}
pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
self.tools = Some(tools);
self
}
pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
self.tool_choice = Some(tool_choice);
self
}
pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
self.instructions = Some(instructions.into());
self
}
pub fn previous_response_id(mut self, id: impl Into<String>) -> Self {
self.previous_response_id = Some(id.into());
self
}
pub fn store(mut self, store: bool) -> Self {
self.store = Some(store);
self
}
pub fn json_mode(mut self) -> Self {
self.response_format = Some(ResponseFormat::JsonObject);
self
}
pub fn json_schema(mut self, name: impl Into<String>, schema: Value) -> Self {
self.response_format = Some(ResponseFormat::JsonSchema {
json_schema: JsonSchemaFormat {
name: name.into(),
schema,
strict: Some(true),
},
});
self
}
pub fn build(self) -> Result<CompletionRequest> {
let model = self.model.ok_or_else(|| ValidationError::Empty {
field: "model".to_string(),
})?;
let request = CompletionRequest {
messages: self.messages,
model,
max_tokens: self.max_tokens,
temperature: self.temperature,
top_p: self.top_p,
stream: self.stream,
n: self.n,
stop: self.stop,
presence_penalty: self.presence_penalty,
frequency_penalty: self.frequency_penalty,
user: self.user,
response_format: self.response_format,
tools: self.tools,
tool_choice: self.tool_choice,
instructions: self.instructions,
previous_response_id: self.previous_response_id,
store: self.store,
};
request.validate()?;
Ok(request)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::MessageContent;
#[test]
fn test_builder_basic() {
let request = CompletionRequest::builder()
.model("gpt-4")
.message(Message::user("Hello"))
.build()
.unwrap();
assert_eq!(request.model, "gpt-4");
assert_eq!(request.messages.len(), 1);
assert_eq!(
request.messages[0].content,
MessageContent::Text("Hello".to_string())
);
}
#[test]
fn test_builder_all_fields() {
let request = CompletionRequest::builder()
.model("gpt-4")
.message(Message::user("Hello"))
.max_tokens(100)
.temperature(0.7)
.top_p(0.9)
.stream(true)
.n(1)
.stop(vec!["END".to_string()])
.presence_penalty(0.5)
.frequency_penalty(0.5)
.user("test-user")
.build()
.unwrap();
assert_eq!(request.max_tokens, Some(100));
assert_eq!(request.temperature, Some(0.7));
assert_eq!(request.top_p, Some(0.9));
assert_eq!(request.stream, Some(true));
assert_eq!(request.n, Some(1));
assert_eq!(request.stop, Some(vec!["END".to_string()]));
assert_eq!(request.presence_penalty, Some(0.5));
assert_eq!(request.frequency_penalty, Some(0.5));
assert_eq!(request.user, Some("test-user".to_string()));
}
#[test]
fn test_builder_missing_model() {
let result = CompletionRequest::builder()
.message(Message::user("Hello"))
.build();
assert!(result.is_err());
}
#[test]
fn test_validation_empty_messages() {
let result = CompletionRequest::builder().model("gpt-4").build();
assert!(result.is_err());
}
#[test]
fn test_validation_invalid_temperature() {
let result = CompletionRequest::builder()
.model("gpt-4")
.message(Message::user("Hello"))
.temperature(3.0)
.build();
assert!(result.is_err());
}
#[test]
fn test_validation_invalid_top_p() {
let result = CompletionRequest::builder()
.model("gpt-4")
.message(Message::user("Hello"))
.top_p(1.5)
.build();
assert!(result.is_err());
}
#[test]
fn test_validation_invalid_model_chars() {
let result = CompletionRequest::builder()
.model("gpt-4!")
.message(Message::user("Hello"))
.build();
assert!(result.is_err());
}
#[test]
fn test_serialization() {
let request = CompletionRequest::builder()
.model("gpt-4")
.message(Message::user("Hello"))
.temperature(0.7)
.build()
.unwrap();
let json = serde_json::to_string(&request).unwrap();
let parsed: CompletionRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, parsed);
}
#[test]
fn test_optional_fields_not_serialized() {
let request = CompletionRequest::builder()
.model("gpt-4")
.message(Message::user("Hello"))
.build()
.unwrap();
let json = serde_json::to_value(&request).unwrap();
assert!(json.get("max_tokens").is_none());
assert!(json.get("temperature").is_none());
}
#[test]
fn test_validation_total_request_size_limit() {
let large_content = "x".repeat(2 * 1024 * 1024); let result = CompletionRequest::builder()
.model("gpt-4")
.message(Message::user(large_content.clone()))
.message(Message::user(large_content.clone()))
.message(Message::user(large_content.clone()))
.message(Message::user(large_content.clone()))
.message(Message::user(large_content.clone()))
.message(Message::user(large_content.clone())) .build();
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
crate::error::SimpleAgentsError::Validation(ValidationError::TooLong { .. })
));
}
#[test]
fn test_responses_api_fields() {
let req = CompletionRequest::new("gpt-4o")
.messages(vec![Message::user("hello")])
.instructions("You are helpful")
.store(true)
.previous_response_id("resp_abc");
assert_eq!(req.instructions.as_deref(), Some("You are helpful"));
assert_eq!(req.store, Some(true));
assert_eq!(req.previous_response_id.as_deref(), Some("resp_abc"));
}
#[test]
fn test_validation_total_request_size_within_limit() {
let content = "x".repeat(1024 * 1024); let result = CompletionRequest::builder()
.model("gpt-4")
.message(Message::user(content.clone()))
.message(Message::user(content.clone()))
.message(Message::user(content.clone())) .build();
assert!(result.is_ok());
}
}