use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesMessage {
pub role: String,
pub content: serde_json::Value,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
/// Function metadata for a tool that can be offered to the model.
///
/// `#[skip_serializing_none]` must appear *before* `#[derive(Serialize)]`.
/// If it is placed after the derive, the generated `Serialize` impl does not
/// see the `skip_serializing_if` attributes the macro injects, and a `None`
/// `description` would be emitted as `"description": null`.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponsesToolFunction {
    /// Function name.
    pub name: String,
    /// Optional human-readable description; omitted from output when `None`.
    #[serde(default)]
    pub description: Option<String>,
    /// Parameter specification as arbitrary JSON, passed through unchanged
    /// (presumably a JSON Schema object — not validated here).
    pub parameters: serde_json::Value,
}
/// A tool definition, discriminated on the wire by a `"type"` field.
///
/// The enum is internally tagged, so the `Function` variant serializes as
/// `{"type": "function", "function": { ... }}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponsesToolDefinition {
    /// A callable function tool.
    #[serde(rename = "function")]
    Function {
        /// Name, optional description and parameter spec of the function.
        function: ResponsesToolFunction,
    },
}
/// Request body for a "responses"-style model call.
///
/// Deserialization is derived. Serialization is implemented by hand in
/// `impl Serialize for ResponsesRequest`, which omits every `None` field and
/// fails on float values that cannot be represented as JSON numbers.
/// Every field except `model` and `messages` is optional and is `None` when
/// absent from the input.
#[derive(Debug, Clone, Deserialize)]
pub struct ResponsesRequest {
/// Model identifier.
pub model: String,
/// Input messages, in the order given.
pub messages: Vec<ResponsesMessage>,
/// Sampling temperature; must be finite (not NaN/inf) to serialize.
#[serde(default)]
pub temperature: Option<f64>,
/// `top_p` sampling parameter; must be finite to serialize.
#[serde(default)]
pub top_p: Option<f64>,
/// Output token limit, forwarded as-is.
#[serde(default)]
pub max_output_tokens: Option<u32>,
/// Stop sequence(s) as raw JSON, passed through unchanged (shape not validated here).
#[serde(default)]
pub stop: Option<serde_json::Value>,
/// Presence penalty; must be finite to serialize.
#[serde(default)]
pub presence_penalty: Option<f64>,
/// Frequency penalty; must be finite to serialize.
#[serde(default)]
pub frequency_penalty: Option<f64>,
/// Bias per string key (presumably token ids — confirm); every value must be finite to serialize.
#[serde(default)]
pub logit_bias: Option<HashMap<String, f64>>,
/// Optional end-user identifier string.
#[serde(default)]
pub user: Option<String>,
/// `n` parameter, forwarded as-is (presumably the number of choices — confirm the target API accepts it).
#[serde(default)]
pub n: Option<u32>,
/// Tool definitions offered to the model.
#[serde(default)]
pub tools: Option<Vec<ResponsesToolDefinition>>,
/// Tool-choice directive as raw JSON, passed through unchanged.
#[serde(default)]
pub tool_choice: Option<serde_json::Value>,
/// Response-format directive as raw JSON, passed through unchanged.
#[serde(default)]
pub response_format: Option<serde_json::Value>,
/// Streaming flag.
#[serde(default)]
pub stream: Option<bool>,
/// Optional conversation identifier (presumably references server-side state — confirm).
#[serde(default)]
pub conversation: Option<String>,
}
impl Serialize for ResponsesRequest {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde_json::{Map, Number, Value};
let mut root = Map::new();
root.insert("model".to_string(), Value::String(self.model.clone()));
let messages_val =
serde_json::to_value(&self.messages).map_err(serde::ser::Error::custom)?;
root.insert("messages".to_string(), messages_val);
let to_num = |f: f64, label: &str| {
Number::from_f64(f).ok_or_else(|| serde::ser::Error::custom(format!("invalid {label}")))
};
if let Some(v) = self.temperature {
root.insert(
"temperature".into(),
Value::Number(to_num(v, "temperature")?),
);
}
if let Some(v) = self.top_p {
root.insert("top_p".into(), Value::Number(to_num(v, "top_p")?));
}
if let Some(v) = self.max_output_tokens {
root.insert("max_output_tokens".into(), Value::Number(v.into()));
}
if let Some(v) = self.stop.clone() {
root.insert("stop".into(), v);
}
if let Some(v) = self.presence_penalty {
root.insert(
"presence_penalty".into(),
Value::Number(to_num(v, "presence_penalty")?),
);
}
if let Some(v) = self.frequency_penalty {
root.insert(
"frequency_penalty".into(),
Value::Number(to_num(v, "frequency_penalty")?),
);
}
if let Some(map) = self.logit_bias.as_ref() {
let mut obj = Map::new();
for (k, v) in map {
let num = Number::from_f64(*v)
.ok_or_else(|| serde::ser::Error::custom("invalid logit_bias value"))?;
obj.insert(k.clone(), Value::Number(num));
}
root.insert("logit_bias".into(), Value::Object(obj));
}
if let Some(u) = self.user.as_ref() {
root.insert("user".into(), Value::String(u.clone()));
}
if let Some(n) = self.n {
root.insert("n".into(), Value::Number(n.into()));
}
if let Some(tools) = self.tools.as_ref() {
root.insert(
"tools".into(),
serde_json::to_value(tools).map_err(serde::ser::Error::custom)?,
);
}
if let Some(tc) = self.tool_choice.as_ref() {
root.insert("tool_choice".into(), tc.clone());
}
if let Some(rf) = self.response_format.as_ref() {
root.insert("response_format".into(), rf.clone());
}
if let Some(s) = self.stream {
root.insert("stream".into(), Value::Bool(s));
}
if let Some(conv) = self.conversation.as_ref() {
root.insert("conversation".into(), Value::String(conv.clone()));
}
Value::Object(root).serialize(serializer)
}
}