use crate::chat::response::Response;
use crate::common::{
auth::{AuthProvider, OpenAIAuth},
client::create_http_client,
errors::{ErrorResponse, OpenAIToolError, Result},
message::{Content, Message},
models::{ChatModel, ParameterRestriction},
structured_output::Schema,
tool::Tool,
};
use core::str;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// Wire representation of the `response_format` request field.
///
/// Serializes as `{"type": ..., "json_schema": ...}` to request structured
/// output from the API.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub(crate) struct Format {
    /// Format discriminator (e.g. "json_schema"); emitted as `type` on the wire.
    #[serde(rename = "type")]
    type_name: String,
    /// Schema the model's output must conform to.
    json_schema: Schema,
}

impl Format {
    /// Builds a `Format` from any string-like type name and a schema.
    pub fn new<T: AsRef<str>>(type_name: T, json_schema: Schema) -> Self {
        let type_name = type_name.as_ref().to_owned();
        Self { type_name, json_schema }
    }
}
/// Serializes a [`Content`] part into the Chat Completions wire format,
/// mapping Responses-style type names (`input_text` / `input_image`) to the
/// Chat Completions equivalents (`text` / `image_url`).
struct ChatContentRef<'a>(&'a Content);

impl Serialize for ChatContentRef<'_> {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;
        match self.0.type_name.as_str() {
            "input_text" => {
                let mut state = serializer.serialize_struct("Content", 2)?;
                state.serialize_field("type", "text")?;
                // Fix: only emit `text` when present. The previous code
                // serialized the Option directly, producing `"text": null`
                // for a text part with no text (mirrors the image branch).
                if let Some(ref text) = self.0.text {
                    state.serialize_field("text", text)?;
                }
                state.end()
            }
            "input_image" => {
                // Chat Completions nests the URL: {"image_url": {"url": ...}}.
                #[derive(Serialize)]
                struct ImageUrl<'b> {
                    url: &'b str,
                }
                let mut state = serializer.serialize_struct("Content", 2)?;
                state.serialize_field("type", "image_url")?;
                if let Some(ref url) = self.0.image_url {
                    state.serialize_field("image_url", &ImageUrl { url })?;
                }
                state.end()
            }
            // Unknown content types pass through unchanged, emitting whatever
            // optional fields are set.
            other => {
                let mut state = serializer.serialize_struct("Content", 3)?;
                state.serialize_field("type", other)?;
                if let Some(ref text) = self.0.text {
                    state.serialize_field("text", text)?;
                }
                if let Some(ref url) = self.0.image_url {
                    state.serialize_field("image_url", url)?;
                }
                state.end()
            }
        }
    }
}
/// Serializes a [`Message`] into the Chat Completions wire format.
///
/// A single-text message serializes `content` as a plain string; a
/// multimodal message serializes it as an array of content parts. Optional
/// tool fields are emitted only when present.
struct ChatMessageRef<'a>(&'a Message);

impl Serialize for ChatMessageRef<'_> {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;
        let msg = self.0;
        // Fix: report an accurate field count. The previous code hard-coded 3,
        // which under-counts when content, tool_call_id and tool_calls are all
        // present (serde_json ignores the hint, but other serializers may not).
        let mut field_count = 1; // role is always emitted
        if msg.content.is_some() || msg.content_list.is_some() {
            field_count += 1;
        }
        if msg.tool_call_id.is_some() {
            field_count += 1;
        }
        if msg.tool_calls.is_some() {
            field_count += 1;
        }
        let mut state = serializer.serialize_struct("Message", field_count)?;
        state.serialize_field("role", &msg.role)?;
        if let Some(ref content) = msg.content {
            // Single text content flattens to a plain string.
            state.serialize_field("content", &content.text)?;
        } else if let Some(ref contents) = msg.content_list {
            let chat_contents: Vec<ChatContentRef<'_>> = contents.iter().map(ChatContentRef).collect();
            state.serialize_field("content", &chat_contents)?;
        }
        if let Some(ref tool_call_id) = msg.tool_call_id {
            state.serialize_field("tool_call_id", tool_call_id)?;
        }
        if let Some(ref tool_calls) = msg.tool_calls {
            state.serialize_field("tool_calls", tool_calls)?;
        }
        state.end()
    }
}
fn serialize_chat_messages<S>(messages: &Vec<Message>, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeSeq;
let mut seq = serializer.serialize_seq(Some(messages.len()))?;
for msg in messages {
seq.serialize_element(&ChatMessageRef(msg))?;
}
seq.end()
}
/// JSON request body for the Chat Completions endpoint.
///
/// Every optional field is omitted from the payload entirely when unset
/// (`skip_serializing_if = "Option::is_none"`).
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub(crate) struct Body {
/// Target model for the request.
pub(crate) model: ChatModel,
/// Conversation history; serialized via `serialize_chat_messages` so
/// multimodal content takes the Chat Completions wire shape.
#[serde(serialize_with = "serialize_chat_messages")]
pub(crate) messages: Vec<Message>,
/// Whether the completion should be stored by the API.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) store: Option<bool>,
/// Penalizes tokens by their frequency so far.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) frequency_penalty: Option<f32>,
/// Per-token logit adjustments, keyed by token id (as a string).
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) logit_bias: Option<HashMap<String, i32>>,
/// Request log probabilities for output tokens.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) logprobs: Option<bool>,
/// Number of most-likely alternatives to return per position.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) top_logprobs: Option<u8>,
/// Upper bound on generated tokens (including reasoning tokens).
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) max_completion_tokens: Option<u64>,
/// Number of choices to generate.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) n: Option<u32>,
/// Requested output modalities (e.g. text, audio).
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) modalities: Option<Vec<String>>,
/// Penalizes tokens that already appeared at all.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) presence_penalty: Option<f32>,
/// Sampling temperature.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) temperature: Option<f32>,
/// Structured-output format (see `Format`).
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) response_format: Option<Format>,
/// Tools the model may call.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) tools: Option<Vec<Tool>>,
/// Stable end-user identifier for abuse monitoring.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) safety_identifier: Option<String>,
}
/// Endpoint path appended to the provider's base URL when building requests.
const CHAT_COMPLETIONS_PATH: &str = "chat/completions";
/// Builder-style client for the Chat Completions API.
///
/// Configure via the chained setters, then call `chat()` to send the request.
#[derive(Debug, Clone)]
pub struct ChatCompletion {
/// Credential and endpoint provider (OpenAI, Azure, or custom URL).
auth: AuthProvider,
/// Request body accumulated by the setter methods.
pub(crate) request_body: Body,
/// Optional per-request HTTP timeout; provider/client default when `None`.
timeout: Option<Duration>,
}
impl Default for ChatCompletion {
/// Delegates to [`ChatCompletion::new`].
///
/// NOTE(review): this inherits `new()`'s panic when OpenAI credentials are
/// missing from the environment — surprising for a `Default` impl; confirm
/// that is intended.
fn default() -> Self {
Self::new()
}
}
impl ChatCompletion {
/// Creates a client with OpenAI credentials read from the environment.
///
/// # Panics
/// Panics when the credentials cannot be loaded. Prefer
/// [`ChatCompletion::detect_provider`] for a fallible alternative.
pub fn new() -> Self {
    // Panic directly instead of wrapping the error in OpenAIToolError only
    // to unwrap it immediately (which printed the Debug form of the wrapper).
    let auth = AuthProvider::openai_from_env().unwrap_or_else(|e| panic!("Failed to load OpenAI auth: {}", e));
    Self { auth, request_body: Body::default(), timeout: None }
}

/// Like [`ChatCompletion::new`], but starts with the given model selected.
///
/// # Panics
/// Panics when the OpenAI credentials cannot be loaded from the environment.
pub fn with_model(model: ChatModel) -> Self {
    let auth = AuthProvider::openai_from_env().unwrap_or_else(|e| panic!("Failed to load OpenAI auth: {}", e));
    Self { auth, request_body: Body { model, ..Default::default() }, timeout: None }
}
pub fn with_auth(auth: AuthProvider) -> Self {
Self { auth, request_body: Body::default(), timeout: None }
}
pub fn azure() -> Result<Self> {
let auth = AuthProvider::azure_from_env()?;
Ok(Self { auth, request_body: Body::default(), timeout: None })
}
pub fn detect_provider() -> Result<Self> {
let auth = AuthProvider::from_env()?;
Ok(Self { auth, request_body: Body::default(), timeout: None })
}
pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
let auth = AuthProvider::from_url_with_key(base_url, api_key);
Self { auth, request_body: Body::default(), timeout: None }
}
pub fn from_url<S: Into<String>>(base_url: S) -> Result<Self> {
let auth = AuthProvider::from_url(base_url)?;
Ok(Self { auth, request_body: Body::default(), timeout: None })
}
/// Returns the configured authentication provider.
pub fn auth(&self) -> &AuthProvider {
    &self.auth
}

/// Overrides the API base URL (OpenAI provider only).
///
/// For any other provider this is a warning no-op; use `azure()` or
/// `with_auth()` instead.
pub fn base_url<T: AsRef<str>>(&mut self, url: T) -> &mut Self {
    match self.auth {
        AuthProvider::OpenAI(ref openai_auth) => {
            let rebased = OpenAIAuth::new(openai_auth.api_key()).with_base_url(url.as_ref());
            self.auth = AuthProvider::OpenAI(rebased);
        }
        _ => {
            tracing::warn!("base_url() is only supported for OpenAI provider. Use azure() or with_auth() for Azure.");
        }
    }
    self
}
/// Selects the model to use.
pub fn model(&mut self, model: ChatModel) -> &mut Self {
    self.request_body.model = model;
    self
}

/// Selects the model by raw identifier string.
#[deprecated(since = "0.2.0", note = "Use `model(ChatModel)` instead for type safety")]
pub fn model_id<T: AsRef<str>>(&mut self, model_id: T) -> &mut Self {
    self.request_body.model = model_id.as_ref().into();
    self
}

/// Sets the per-request HTTP timeout.
pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
    self.timeout = Some(timeout);
    self
}

/// Replaces the whole message history.
pub fn messages(&mut self, messages: Vec<Message>) -> &mut Self {
    self.request_body.messages = messages;
    self
}

/// Appends one message to the history.
pub fn add_message(&mut self, message: Message) -> &mut Self {
    self.request_body.messages.push(message);
    self
}

/// Sets whether the API should store the completion.
pub fn store(&mut self, store: bool) -> &mut Self {
    self.request_body.store = Some(store);
    self
}
/// Sets `frequency_penalty`, silently dropping it (with a warning) when the
/// configured model restricts or forbids the parameter.
pub fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
    let support = self.request_body.model.parameter_support();
    let accepted = match support.frequency_penalty {
        ParameterRestriction::Any => true,
        ParameterRestriction::FixedValue(fixed) => {
            let matches_fixed = (frequency_penalty as f64 - fixed).abs() <= f64::EPSILON;
            if !matches_fixed {
                tracing::warn!(
                    "Model '{}' only supports frequency_penalty={}. Ignoring frequency_penalty={}.",
                    self.request_body.model,
                    fixed,
                    frequency_penalty
                );
            }
            matches_fixed
        }
        ParameterRestriction::NotSupported => {
            tracing::warn!("Model '{}' does not support frequency_penalty parameter. Ignoring.", self.request_body.model);
            false
        }
    };
    if accepted {
        self.request_body.frequency_penalty = Some(frequency_penalty);
    }
    self
}
/// Sets `logit_bias`, converting keys to owned strings; ignored (with a
/// warning) when the model does not support it.
pub fn logit_bias<T: AsRef<str>>(&mut self, logit_bias: HashMap<T, i32>) -> &mut Self {
    if self.request_body.model.parameter_support().logit_bias {
        let owned: HashMap<String, i32> = logit_bias.into_iter().map(|(k, v)| (k.as_ref().to_string(), v)).collect();
        self.request_body.logit_bias = Some(owned);
    } else {
        tracing::warn!("Model '{}' does not support logit_bias parameter. Ignoring.", self.request_body.model);
    }
    self
}

/// Sets `logprobs`; ignored (with a warning) when unsupported by the model.
pub fn logprobs(&mut self, logprobs: bool) -> &mut Self {
    if self.request_body.model.parameter_support().logprobs {
        self.request_body.logprobs = Some(logprobs);
    } else {
        tracing::warn!("Model '{}' does not support logprobs parameter. Ignoring.", self.request_body.model);
    }
    self
}

/// Sets `top_logprobs`; ignored (with a warning) when unsupported by the model.
pub fn top_logprobs(&mut self, top_logprobs: u8) -> &mut Self {
    if self.request_body.model.parameter_support().top_logprobs {
        self.request_body.top_logprobs = Some(top_logprobs);
    } else {
        tracing::warn!("Model '{}' does not support top_logprobs parameter. Ignoring.", self.request_body.model);
    }
    self
}
/// Sets `max_completion_tokens`; accepted by every model.
pub fn max_completion_tokens(&mut self, max_completion_tokens: u64) -> &mut Self {
    self.request_body.max_completion_tokens = Some(max_completion_tokens);
    self
}

/// Sets `n` (number of choices); when the model only allows a single choice,
/// any other value is dropped with a warning.
pub fn n(&mut self, n: u32) -> &mut Self {
    let support = self.request_body.model.parameter_support();
    if support.n_multiple || n == 1 {
        self.request_body.n = Some(n);
    } else {
        tracing::warn!("Model '{}' only supports n=1. Ignoring n={}.", self.request_body.model, n);
    }
    self
}

/// Sets the requested output modalities.
pub fn modalities<T: AsRef<str>>(&mut self, modalities: Vec<T>) -> &mut Self {
    let owned: Vec<String> = modalities.iter().map(|m| m.as_ref().to_string()).collect();
    self.request_body.modalities = Some(owned);
    self
}
/// Sets `presence_penalty`, silently dropping it (with a warning) when the
/// configured model restricts or forbids the parameter.
pub fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
    let support = self.request_body.model.parameter_support();
    let accepted = match support.presence_penalty {
        ParameterRestriction::Any => true,
        ParameterRestriction::FixedValue(fixed) => {
            let matches_fixed = (presence_penalty as f64 - fixed).abs() <= f64::EPSILON;
            if !matches_fixed {
                tracing::warn!(
                    "Model '{}' only supports presence_penalty={}. Ignoring presence_penalty={}.",
                    self.request_body.model,
                    fixed,
                    presence_penalty
                );
            }
            matches_fixed
        }
        ParameterRestriction::NotSupported => {
            tracing::warn!("Model '{}' does not support presence_penalty parameter. Ignoring.", self.request_body.model);
            false
        }
    };
    if accepted {
        self.request_body.presence_penalty = Some(presence_penalty);
    }
    self
}

/// Sets `temperature`, silently dropping it (with a warning) when the
/// configured model restricts or forbids the parameter.
pub fn temperature(&mut self, temperature: f32) -> &mut Self {
    let support = self.request_body.model.parameter_support();
    let accepted = match support.temperature {
        ParameterRestriction::Any => true,
        ParameterRestriction::FixedValue(fixed) => {
            let matches_fixed = (temperature as f64 - fixed).abs() <= f64::EPSILON;
            if !matches_fixed {
                tracing::warn!("Model '{}' only supports temperature={}. Ignoring temperature={}.", self.request_body.model, fixed, temperature);
            }
            matches_fixed
        }
        ParameterRestriction::NotSupported => {
            tracing::warn!("Model '{}' does not support temperature parameter. Ignoring.", self.request_body.model);
            false
        }
    };
    if accepted {
        self.request_body.temperature = Some(temperature);
    }
    self
}
/// Requests structured output conforming to the given JSON schema.
pub fn json_schema(&mut self, json_schema: Schema) -> &mut Self {
    self.request_body.response_format = Some(Format::new("json_schema", json_schema));
    self
}

/// Sets the tools the model may call.
pub fn tools(&mut self, tools: Vec<Tool>) -> &mut Self {
    self.request_body.tools = Some(tools);
    self
}

/// Sets the stable end-user identifier sent for abuse monitoring.
pub fn safety_identifier<T: AsRef<str>>(&mut self, safety_id: T) -> &mut Self {
    self.request_body.safety_identifier = Some(safety_id.as_ref().to_owned());
    self
}

/// Returns a copy of the accumulated message history.
pub fn get_message_history(&self) -> Vec<Message> {
    self.request_body.messages.clone()
}

/// True when the selected model is a reasoning model (restricted parameters).
fn is_reasoning_model(&self) -> bool {
    self.request_body.model.is_reasoning_model()
}
/// Sends the accumulated request body to the Chat Completions endpoint and
/// parses the response.
///
/// For reasoning models, sampling parameters the model does not support are
/// cleared from the request body (with a warning) before sending, so the
/// API does not reject the call. Note this mutates `self.request_body`.
///
/// # Errors
/// Fails when no messages are set, when the HTTP request errors, when the
/// API returns a non-success status, or when the response body cannot be
/// deserialized.
pub async fn chat(&mut self) -> Result<Response> {
if self.request_body.messages.is_empty() {
return Err(OpenAIToolError::Error("Messages are not set.".into()));
}
// Reasoning models only accept the default sampling settings; drop any
// non-default values instead of letting the API reject the request.
if self.is_reasoning_model() {
let model = &self.request_body.model;
if let Some(temp) = self.request_body.temperature {
if (temp - 1.0).abs() > f32::EPSILON {
tracing::warn!(
"Reasoning model '{}' does not support custom temperature. \
Ignoring temperature={} and using default (1.0).",
model,
temp
);
self.request_body.temperature = None;
}
}
if let Some(fp) = self.request_body.frequency_penalty {
if fp.abs() > f32::EPSILON {
tracing::warn!(
"Reasoning model '{}' does not support frequency_penalty. \
Ignoring frequency_penalty={} and using default (0).",
model,
fp
);
self.request_body.frequency_penalty = None;
}
}
if let Some(pp) = self.request_body.presence_penalty {
if pp.abs() > f32::EPSILON {
tracing::warn!(
"Reasoning model '{}' does not support presence_penalty. \
Ignoring presence_penalty={} and using default (0).",
model,
pp
);
self.request_body.presence_penalty = None;
}
}
if self.request_body.logprobs.is_some() {
tracing::warn!("Reasoning model '{}' does not support logprobs. Ignoring logprobs parameter.", model);
self.request_body.logprobs = None;
}
if self.request_body.top_logprobs.is_some() {
tracing::warn!("Reasoning model '{}' does not support top_logprobs. Ignoring top_logprobs parameter.", model);
self.request_body.top_logprobs = None;
}
if self.request_body.logit_bias.is_some() {
tracing::warn!("Reasoning model '{}' does not support logit_bias. Ignoring logit_bias parameter.", model);
self.request_body.logit_bias = None;
}
if let Some(n) = self.request_body.n {
if n != 1 {
tracing::warn!(
"Reasoning model '{}' does not support n != 1. \
Ignoring n={} and using default (1).",
model,
n
);
self.request_body.n = None;
}
}
}
// Serialize once for the request; auth headers come from the provider.
let body = serde_json::to_string(&self.request_body)?;
let client = create_http_client(self.timeout)?;
// NOTE(review): `request::header::*` — presumably `request` is a Cargo
// rename of the `reqwest` crate; confirm, otherwise this should read
// `reqwest::header::*`.
let mut headers = request::header::HeaderMap::new();
headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
self.auth.apply_headers(&mut headers)?;
// Debug builds log the outgoing body with the API key masked.
// NOTE(review): the `unwrap()` can only fail if the body is unserializable,
// which the `to_string` above would already have caught; and masking via
// `replace` misbehaves if the API key is the empty string — confirm keys
// are always non-empty.
if cfg!(debug_assertions) {
let body_for_debug = serde_json::to_string_pretty(&self.request_body).unwrap().replace(self.auth.api_key(), "*************");
tracing::info!("Request body: {}", body_for_debug);
}
let endpoint = self.auth.endpoint(CHAT_COMPLETIONS_PATH);
let response = client.post(&endpoint).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
// Read status before consuming the body so both are available for errors.
let status = response.status();
let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
if cfg!(debug_assertions) {
tracing::info!("Response content: {}", content);
}
if !status.is_success() {
// Prefer the API's structured error message when it parses.
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
return Err(OpenAIToolError::Error(error_resp.error.message.unwrap_or_default()));
}
return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
}
serde_json::from_str::<Response>(&content).map_err(OpenAIToolError::SerdeJsonError)
}
/// Test-only constructor with a dummy API key; never reads the environment,
/// so unit tests do not depend on credentials being configured.
#[cfg(test)]
pub(crate) fn test_new_with_model(model: ChatModel) -> Self {
    // The function-local `use crate::common::auth::OpenAIAuth;` was removed:
    // OpenAIAuth is already imported at the top of the file.
    Self { auth: AuthProvider::OpenAI(OpenAIAuth::new("test-key")), request_body: Body { model, ..Default::default() }, timeout: None }
}
}
#[cfg(test)]
mod tests {
//! Unit tests covering per-model parameter gating (standard vs. reasoning
//! vs. gpt-5 models) and the Chat Completions wire-format serialization.
use super::*;
use crate::common::models::ChatModel;
use std::collections::HashMap;
// --- Standard models accept all sampling parameters ---
#[test]
fn test_standard_model_accepts_all_parameters() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt4oMini);
chat.temperature(0.7);
chat.frequency_penalty(0.5);
chat.presence_penalty(0.5);
chat.logprobs(true);
chat.top_logprobs(5);
chat.n(3);
let logit_bias: HashMap<&str, i32> = [("1234", 10)].iter().cloned().collect();
chat.logit_bias(logit_bias);
assert_eq!(chat.request_body.temperature, Some(0.7));
assert_eq!(chat.request_body.frequency_penalty, Some(0.5));
assert_eq!(chat.request_body.presence_penalty, Some(0.5));
assert_eq!(chat.request_body.logprobs, Some(true));
assert_eq!(chat.request_body.top_logprobs, Some(5));
assert_eq!(chat.request_body.n, Some(3));
assert!(chat.request_body.logit_bias.is_some());
}
#[test]
fn test_gpt4o_accepts_all_parameters() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt4o);
chat.temperature(0.3);
chat.frequency_penalty(-1.0);
chat.presence_penalty(1.5);
assert_eq!(chat.request_body.temperature, Some(0.3));
assert_eq!(chat.request_body.frequency_penalty, Some(-1.0));
assert_eq!(chat.request_body.presence_penalty, Some(1.5));
}
#[test]
fn test_gpt4_1_accepts_all_parameters() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt4_1);
chat.temperature(1.5);
chat.frequency_penalty(0.8);
chat.n(2);
assert_eq!(chat.request_body.temperature, Some(1.5));
assert_eq!(chat.request_body.frequency_penalty, Some(0.8));
assert_eq!(chat.request_body.n, Some(2));
}
// --- Reasoning models (o-series) silently drop restricted parameters ---
#[test]
fn test_o1_ignores_non_default_temperature() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::O1);
chat.temperature(0.5);
assert_eq!(chat.request_body.temperature, None);
// The fixed value (1.0) is still accepted.
chat.temperature(1.0);
assert_eq!(chat.request_body.temperature, Some(1.0));
}
#[test]
fn test_o3_mini_ignores_non_default_temperature() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::O3Mini);
chat.temperature(0.3);
assert_eq!(chat.request_body.temperature, None);
}
#[test]
fn test_o4_mini_ignores_non_default_temperature() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::O4Mini);
chat.temperature(0.7);
assert_eq!(chat.request_body.temperature, None);
}
#[test]
fn test_o1_ignores_frequency_penalty() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::O1);
chat.frequency_penalty(0.5);
assert_eq!(chat.request_body.frequency_penalty, None);
// The fixed value (0.0) is still accepted.
chat.frequency_penalty(0.0);
assert_eq!(chat.request_body.frequency_penalty, Some(0.0));
}
#[test]
fn test_o3_ignores_presence_penalty() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::O3);
chat.presence_penalty(0.5);
assert_eq!(chat.request_body.presence_penalty, None);
chat.presence_penalty(0.0);
assert_eq!(chat.request_body.presence_penalty, Some(0.0));
}
#[test]
fn test_o1_ignores_logprobs() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::O1);
chat.logprobs(true);
assert_eq!(chat.request_body.logprobs, None);
}
#[test]
fn test_o3_mini_ignores_top_logprobs() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::O3Mini);
chat.top_logprobs(5);
assert_eq!(chat.request_body.top_logprobs, None);
}
#[test]
fn test_o1_ignores_logit_bias() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::O1);
let logit_bias: HashMap<&str, i32> = [("1234", 10)].iter().cloned().collect();
chat.logit_bias(logit_bias);
assert_eq!(chat.request_body.logit_bias, None);
}
#[test]
fn test_o1_ignores_n_greater_than_1() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::O1);
chat.n(3);
assert_eq!(chat.request_body.n, None);
chat.n(1);
assert_eq!(chat.request_body.n, Some(1));
}
// --- gpt-5 family behaves like reasoning models ---
#[test]
fn test_gpt5_2_ignores_non_default_temperature() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt5_2);
chat.temperature(0.5);
assert_eq!(chat.request_body.temperature, None);
chat.temperature(1.0);
assert_eq!(chat.request_body.temperature, Some(1.0));
}
#[test]
fn test_gpt5_1_ignores_non_default_temperature() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt5_1);
chat.temperature(0.3);
assert_eq!(chat.request_body.temperature, None);
}
#[test]
fn test_gpt5_mini_ignores_frequency_penalty() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt5Mini);
chat.frequency_penalty(0.5);
assert_eq!(chat.request_body.frequency_penalty, None);
}
#[test]
fn test_gpt5_2_pro_ignores_presence_penalty() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt5_2Pro);
chat.presence_penalty(0.8);
assert_eq!(chat.request_body.presence_penalty, None);
}
#[test]
fn test_gpt5_1_codex_max_ignores_logprobs() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt5_1CodexMax);
chat.logprobs(true);
assert_eq!(chat.request_body.logprobs, None);
}
#[test]
fn test_gpt5_2_chat_latest_ignores_n_greater_than_1() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt5_2ChatLatest);
chat.n(5);
assert_eq!(chat.request_body.n, None);
}
// --- All restricted parameters dropped together ---
#[test]
fn test_o1_ignores_all_restricted_parameters_at_once() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::O1);
chat.temperature(0.5);
chat.frequency_penalty(0.5);
chat.presence_penalty(0.5);
chat.logprobs(true);
chat.top_logprobs(5);
chat.n(3);
let logit_bias: HashMap<&str, i32> = [("1234", 10)].iter().cloned().collect();
chat.logit_bias(logit_bias);
assert_eq!(chat.request_body.temperature, None);
assert_eq!(chat.request_body.frequency_penalty, None);
assert_eq!(chat.request_body.presence_penalty, None);
assert_eq!(chat.request_body.logprobs, None);
assert_eq!(chat.request_body.top_logprobs, None);
assert_eq!(chat.request_body.n, None);
assert_eq!(chat.request_body.logit_bias, None);
}
#[test]
fn test_gpt5_2_ignores_all_restricted_parameters_at_once() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt5_2);
chat.temperature(0.5);
chat.frequency_penalty(0.5);
chat.presence_penalty(0.5);
chat.logprobs(true);
chat.top_logprobs(5);
chat.n(3);
let logit_bias: HashMap<&str, i32> = [("1234", 10)].iter().cloned().collect();
chat.logit_bias(logit_bias);
assert_eq!(chat.request_body.temperature, None);
assert_eq!(chat.request_body.frequency_penalty, None);
assert_eq!(chat.request_body.presence_penalty, None);
assert_eq!(chat.request_body.logprobs, None);
assert_eq!(chat.request_body.top_logprobs, None);
assert_eq!(chat.request_body.n, None);
assert_eq!(chat.request_body.logit_bias, None);
}
// --- Custom model ids are classified by name pattern ---
#[test]
fn test_custom_gpt5_model_detected_as_reasoning() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::custom("gpt-5.3-preview"));
chat.temperature(0.5);
assert_eq!(chat.request_body.temperature, None);
}
#[test]
fn test_custom_o1_model_detected_as_reasoning() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::custom("o1-pro-2025-01-15"));
chat.temperature(0.5);
assert_eq!(chat.request_body.temperature, None);
}
#[test]
fn test_custom_o3_model_detected_as_reasoning() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::custom("o3-high"));
chat.temperature(0.5);
assert_eq!(chat.request_body.temperature, None);
}
#[test]
fn test_custom_o4_model_detected_as_reasoning() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::custom("o4-mini-preview"));
chat.temperature(0.5);
assert_eq!(chat.request_body.temperature, None);
}
#[test]
fn test_custom_standard_model_accepts_all_parameters() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::custom("ft:gpt-4o-mini:org::123"));
chat.temperature(0.7);
chat.frequency_penalty(0.5);
chat.n(2);
assert_eq!(chat.request_body.temperature, Some(0.7));
assert_eq!(chat.request_body.frequency_penalty, Some(0.5));
assert_eq!(chat.request_body.n, Some(2));
}
// --- Boundary values for unrestricted models ---
#[test]
fn test_temperature_boundary_values() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt4oMini);
chat.temperature(0.0);
assert_eq!(chat.request_body.temperature, Some(0.0));
chat.temperature(2.0);
assert_eq!(chat.request_body.temperature, Some(2.0));
}
#[test]
fn test_frequency_penalty_boundary_values() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt4oMini);
chat.frequency_penalty(-2.0);
assert_eq!(chat.request_body.frequency_penalty, Some(-2.0));
chat.frequency_penalty(2.0);
assert_eq!(chat.request_body.frequency_penalty, Some(2.0));
}
#[test]
fn test_presence_penalty_boundary_values() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt4oMini);
chat.presence_penalty(-2.0);
assert_eq!(chat.request_body.presence_penalty, Some(-2.0));
chat.presence_penalty(2.0);
assert_eq!(chat.request_body.presence_penalty, Some(2.0));
}
// --- Parameters accepted by every model ---
#[test]
fn test_max_completion_tokens_accepted_by_all_models() {
let mut chat_standard = ChatCompletion::test_new_with_model(ChatModel::Gpt4oMini);
chat_standard.max_completion_tokens(1000);
assert_eq!(chat_standard.request_body.max_completion_tokens, Some(1000));
let mut chat_reasoning = ChatCompletion::test_new_with_model(ChatModel::O1);
chat_reasoning.max_completion_tokens(2000);
assert_eq!(chat_reasoning.request_body.max_completion_tokens, Some(2000));
let mut chat_gpt5 = ChatCompletion::test_new_with_model(ChatModel::Gpt5_2);
chat_gpt5.max_completion_tokens(3000);
assert_eq!(chat_gpt5.request_body.max_completion_tokens, Some(3000));
}
#[test]
fn test_store_accepted_by_all_models() {
let mut chat_standard = ChatCompletion::test_new_with_model(ChatModel::Gpt4oMini);
chat_standard.store(true);
assert_eq!(chat_standard.request_body.store, Some(true));
let mut chat_reasoning = ChatCompletion::test_new_with_model(ChatModel::O1);
chat_reasoning.store(false);
assert_eq!(chat_reasoning.request_body.store, Some(false));
}
// --- Wire-format serialization ---
#[test]
fn test_chat_text_content_serialization() {
use crate::common::message::Content;
let content = Content::from_text("Hello, world!");
let wrapper = ChatContentRef(&content);
let json = serde_json::to_value(&wrapper).unwrap();
assert_eq!(json["type"], "text");
assert_eq!(json["text"], "Hello, world!");
assert!(json.get("image_url").is_none());
}
#[test]
fn test_chat_image_content_serialization() {
use crate::common::message::Content;
let content = Content::from_image_url("https://example.com/image.png");
let wrapper = ChatContentRef(&content);
let json = serde_json::to_value(&wrapper).unwrap();
assert_eq!(json["type"], "image_url");
assert_eq!(json["image_url"]["url"], "https://example.com/image.png");
}
#[test]
fn test_chat_multimodal_message_serialization() {
use crate::common::message::{Content, Message};
use crate::common::role::Role;
let contents = vec![Content::from_text("What's in this image?"), Content::from_image_url("https://example.com/image.png")];
let message = Message::from_message_array(Role::User, contents);
let wrapper = ChatMessageRef(&message);
let json = serde_json::to_value(&wrapper).unwrap();
assert_eq!(json["role"], "user");
let content_arr = json["content"].as_array().unwrap();
assert_eq!(content_arr.len(), 2);
assert_eq!(content_arr[0]["type"], "text");
assert_eq!(content_arr[0]["text"], "What's in this image?");
assert_eq!(content_arr[1]["type"], "image_url");
assert_eq!(content_arr[1]["image_url"]["url"], "https://example.com/image.png");
}
#[test]
fn test_chat_single_text_message_serialization() {
use crate::common::message::Message;
use crate::common::role::Role;
let message = Message::from_string(Role::User, "Hello!");
let wrapper = ChatMessageRef(&message);
let json = serde_json::to_value(&wrapper).unwrap();
assert_eq!(json["role"], "user");
// Single text content flattens to a plain string, not an array.
assert_eq!(json["content"], "Hello!");
}
#[test]
fn test_chat_body_messages_serialization() {
use crate::common::message::{Content, Message};
use crate::common::role::Role;
let messages = vec![
Message::from_string(Role::System, "You are a helpful assistant."),
Message::from_message_array(
Role::User,
vec![Content::from_text("Describe this image"), Content::from_image_url("https://example.com/photo.jpg")],
),
];
let body = Body { model: ChatModel::Gpt4oMini, messages, ..Default::default() };
let json = serde_json::to_value(&body).unwrap();
let msgs = json["messages"].as_array().unwrap();
assert_eq!(msgs[0]["role"], "system");
assert_eq!(msgs[0]["content"], "You are a helpful assistant.");
assert_eq!(msgs[1]["role"], "user");
let content_arr = msgs[1]["content"].as_array().unwrap();
assert_eq!(content_arr[0]["type"], "text");
assert_eq!(content_arr[1]["type"], "image_url");
assert_eq!(content_arr[1]["image_url"]["url"], "https://example.com/photo.jpg");
}
// --- safety_identifier plumbing ---
#[test]
fn test_safety_identifier() {
let mut chat = ChatCompletion::test_new_with_model(ChatModel::Gpt4oMini);
chat.safety_identifier("user_abc123");
assert_eq!(chat.request_body.safety_identifier, Some("user_abc123".to_string()));
let json = serde_json::to_value(&chat.request_body).unwrap();
assert_eq!(json["safety_identifier"], "user_abc123");
}
#[test]
fn test_safety_identifier_not_serialized_when_none() {
let chat = ChatCompletion::test_new_with_model(ChatModel::Gpt4oMini);
let json = serde_json::to_value(&chat.request_body).unwrap();
assert!(json.get("safety_identifier").is_none());
}
}