use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::Usage;
/// Cache behaviour requested for one call through the `mode` key of
/// `tt_extras.cache` (see `CacheControlConfig` / `parse_cache_control`).
///
/// Wire names are lowercase (`"normal"`, `"bypass"`, `"refresh"`), except
/// [`CacheMode::ReadOnly`], which is explicitly renamed to `"read-only"`.
/// Any other string fails to deserialize.
///
/// NOTE(review): the runtime meaning of each mode is implemented by the cache
/// layer, not in this file — document it there.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CacheMode {
/// Wire name `"normal"`; the `Default`, used when `mode` is omitted.
#[default]
Normal,
/// Wire name `"bypass"`.
Bypass,
/// Wire name `"refresh"`.
Refresh,
/// Wire name `"read-only"` — overrides `rename_all`, which would yield `"readonly"`.
#[serde(rename = "read-only")]
ReadOnly,
}
/// Per-request cache directive, deserialized from the `cache` entry of a
/// request's `tt_extras` map by `parse_cache_control`.
///
/// Both fields are optional on the wire, so `{}` yields the `Default`
/// (`CacheMode::Normal`, no TTL).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheControlConfig {
/// Requested cache mode; `CacheMode::Normal` when the key is absent.
#[serde(default)]
pub mode: CacheMode,
/// Optional time-to-live (seconds, per the field name); omitted from
/// serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub ttl_secs: Option<u64>,
}
/// Reads the per-request cache directive stored under `extras["cache"]`.
///
/// Returns `None` when there is no `cache` key at all. When the key is
/// present but does not deserialize into [`CacheControlConfig`] (not an
/// object, unknown `mode` string, negative `ttl_secs`, …), a warning is logged
/// and `Some(CacheControlConfig::default())` — normal mode, no TTL — is
/// returned. A malformed directive is deliberately downgraded rather than
/// reported as an error.
pub fn parse_cache_control(
    extras: &HashMap<String, serde_json::Value>,
) -> Option<CacheControlConfig> {
    let raw = extras.get("cache")?;
    // `&serde_json::Value` implements `Deserializer`, so read the directive in
    // place instead of `from_value(raw.clone())`, which would deep-copy an
    // arbitrary caller-supplied JSON tree on every request only to discard it.
    let cfg = CacheControlConfig::deserialize(raw).unwrap_or_else(|e| {
        tracing::warn!(
            error = %e,
            "tt_extras.cache deserialization failed — treating as normal"
        );
        CacheControlConfig::default()
    });
    Some(cfg)
}
#[cfg(test)]
mod cache_control_tests {
    //! Behaviour of `parse_cache_control` across present, absent and malformed
    //! `cache` directives.
    use super::*;

    /// Parses `json` as a `tt_extras` map and feeds it to `parse_cache_control`.
    fn parse(json: &str) -> Option<CacheControlConfig> {
        let map: HashMap<String, serde_json::Value> =
            serde_json::from_str(json).expect("fixture must be a JSON object");
        parse_cache_control(&map)
    }

    /// Same as [`parse`], for fixtures that do contain a `cache` key.
    fn parse_present(json: &str) -> CacheControlConfig {
        parse(json).expect("fixture has a `cache` key")
    }

    #[test]
    fn no_cache_key_returns_none() {
        assert!(parse("{}").is_none());
    }

    #[test]
    fn bypass_mode_parsed() {
        let cfg = parse_present(r#"{"cache":{"mode":"bypass"}}"#);
        assert_eq!((cfg.mode, cfg.ttl_secs), (CacheMode::Bypass, None));
    }

    #[test]
    fn refresh_mode_with_ttl() {
        let cfg = parse_present(r#"{"cache":{"mode":"refresh","ttl_secs":3600}}"#);
        assert_eq!((cfg.mode, cfg.ttl_secs), (CacheMode::Refresh, Some(3600)));
    }

    #[test]
    fn read_only_mode() {
        assert_eq!(
            parse_present(r#"{"cache":{"mode":"read-only"}}"#).mode,
            CacheMode::ReadOnly
        );
    }

    #[test]
    fn absent_mode_defaults_to_normal() {
        assert_eq!(parse_present(r#"{"cache":{}}"#).mode, CacheMode::Normal);
    }

    #[test]
    fn malformed_value_falls_back_to_default() {
        assert_eq!(
            parse_present(r#"{"cache":"not-an-object"}"#).mode,
            CacheMode::Normal
        );
    }
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<Message>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u32>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stream_options: Option<serde_json::Value>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<Tool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop: Vec<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<String>,
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub tt_extras: HashMap<String, serde_json::Value>,
#[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")]
pub extra: HashMap<String, serde_json::Value>,
}
/// One chat message, discriminated by its `role` field (the lowercased
/// variant name).
///
/// Keys not declared on a variant are not captured and are dropped on a
/// round-trip. For example, a `name` sent on a system message is lost.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "system"`.
    ///
    /// Also accepts `role: "developer"`, the newer OpenAI name for the same
    /// instruction message. Without the alias, such requests fail with an
    /// unknown-variant error. These messages serialize back as `"system"`.
    #[serde(alias = "developer")]
    System { content: MessageContent },
    /// `role: "user"`, with an optional participant `name` (omitted when `None`).
    User {
        content: MessageContent,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "assistant"`. `content` may be absent or `null`; `tool_calls` and
    /// `name` are omitted from output when empty / `None`.
    Assistant {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "tool"`: a tool result, linked to the originating call by `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
/// Message body: a plain string or a list of typed content parts.
///
/// `#[serde(untagged)]`: a JSON string maps to [`MessageContent::Text`] and a
/// JSON array to [`MessageContent::Parts`]; each serializes back to the same
/// JSON shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
/// Plain-text body (`"content": "..."`).
Text(String),
/// Multi-part body (`"content": [ {...}, ... ]`).
Parts(Vec<ContentPart>),
}
/// One element of a multi-part message body, internally tagged by `type`
/// (`"text"`, `"image_url"`, `"input_audio"` via `rename_all = "snake_case"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// `{"type": "text", "text": ...}`.
Text { text: String },
/// `{"type": "image_url", "image_url": {...}}`.
ImageUrl { image_url: ImageUrl },
/// `{"type": "input_audio", "input_audio": {...}}`.
InputAudio { input_audio: InputAudio },
}
/// Image reference carried by [`ContentPart::ImageUrl`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
/// Image location, stored verbatim. Presumably either a remote URL or a
/// base64 `data:` URL; the latter can be split with [`parse_data_url`].
pub url: String,
/// Optional detail hint, forwarded as an opaque string; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
/// Inline audio carried by [`ContentPart::InputAudio`]; both fields are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputAudio {
/// Audio payload as a string (presumably base64-encoded — TODO confirm).
pub data: String,
/// Audio format name, forwarded as-is.
pub format: String,
}
/// A tool the model may call, as listed in [`ChatCompletionRequest::tools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
/// Tool kind, serialized as `type`. Any string is accepted; the
/// [`ToolChoice::function`] constructor uses `"function"`.
#[serde(rename = "type")]
pub r#type: String,
/// The callable function's definition.
pub function: ToolFunction,
}
/// Function definition inside a [`Tool`].
///
/// NOTE(review): there is no `flatten` catch-all here, so unmodelled keys
/// (e.g. a `strict` flag) are silently dropped on a round-trip.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
    /// Name the model uses to call the function.
    pub name: String,
    /// Optional description; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// JSON Schema for the arguments.
    ///
    /// Optional on the wire, since a parameterless function may leave it out.
    /// A missing key deserializes to `Value::Null`, and a null value is
    /// omitted from output rather than sent as `"parameters": null`.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub parameters: serde_json::Value,
}
/// Value of the request's `tool_choice` field.
///
/// `#[serde(untagged)]`: a bare JSON string maps to [`ToolChoice::Auto`], and an
/// object `{"type": ..., "function": {"name": ...}}` maps to
/// [`ToolChoice::Specific`]. The string is not validated, so any value is
/// accepted. The `auto()` / `none()` / `required()` / `function()`
/// constructors produce the canonical forms.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
/// Bare-string mode, e.g. `"auto"`, `"none"` or `"required"`.
Auto(String),
/// Object form naming a specific function to call.
Specific {
#[serde(rename = "type")]
r#type: String,
function: ToolChoiceFunction,
},
}
impl ToolChoice {
    /// Serializes as the bare string `"auto"`.
    #[must_use]
    pub fn auto() -> Self {
        Self::Auto(String::from("auto"))
    }

    /// Serializes as the bare string `"none"`.
    #[must_use]
    pub fn none() -> Self {
        Self::Auto(String::from("none"))
    }

    /// Serializes as the bare string `"required"`.
    #[must_use]
    pub fn required() -> Self {
        Self::Auto(String::from("required"))
    }

    /// Names one function; serializes as
    /// `{"type": "function", "function": {"name": <name>}}`.
    #[must_use]
    pub fn function(name: impl Into<String>) -> Self {
        let function = ToolChoiceFunction { name: name.into() };
        Self::Specific {
            r#type: String::from("function"),
            function,
        }
    }
}
/// The `function` object of [`ToolChoice::Specific`]: just the function name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
/// Name of the function being selected.
pub name: String,
}
/// A tool invocation, used in [`Message::Assistant`] and [`ChunkDelta`].
/// All three fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Call identifier (presumably what a [`Message::Tool`] reply echoes as
/// `tool_call_id`).
pub id: String,
/// Call kind, serialized as `type`.
#[serde(rename = "type")]
pub r#type: String,
/// Function name and arguments.
pub function: ToolCallFunction,
}
/// Function part of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallFunction {
/// Name of the called function.
pub name: String,
/// Arguments as a raw string, not a parsed JSON value. Presumably
/// JSON-encoded text that the consumer must parse.
pub arguments: String,
}
/// Value of the request's `response_format` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseFormat {
/// Format kind, serialized as `type`, kept as an opaque string.
#[serde(rename = "type")]
pub r#type: String,
/// Optional raw JSON schema object, passed through untyped; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub json_schema: Option<serde_json::Value>,
}
/// Non-streaming chat-completions response body.
///
/// Unlike [`ChatCompletionChunk`], there is no `flatten` catch-all, so
/// unmodelled top-level keys (e.g. `system_fingerprint`) are dropped on a
/// round-trip.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
/// Response identifier.
pub id: String,
/// Object-type string, forwarded as-is.
pub object: String,
/// Creation time (presumably Unix seconds — TODO confirm).
pub created: i64,
/// Model that produced the response.
pub model: String,
/// Completion alternatives.
pub choices: Vec<Choice>,
/// Token accounting (`crate::Usage`). Required: a body without `usage`
/// fails to deserialize.
pub usage: Usage,
}
/// One completion alternative in a [`ChatCompletionResponse`].
///
/// Unlike [`ChunkChoice`], unmodelled keys (e.g. `logprobs`) are not preserved.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
/// Position of this alternative in `choices`.
pub index: u32,
/// The generated message; any [`Message`] variant is accepted here.
pub message: Message,
/// Finish reason as an opaque string; `None` when `null` or absent, and
/// serialized as `null` (not omitted) when `None`.
pub finish_reason: Option<String>,
}
/// One chunk of a streamed chat completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionChunk {
/// Completion identifier.
pub id: String,
/// Object-type string, forwarded as-is.
pub object: String,
/// Creation time (presumably Unix seconds — TODO confirm).
pub created: i64,
/// Model producing the stream.
pub model: String,
/// Per-alternative deltas carried by this chunk.
pub choices: Vec<ChunkChoice>,
/// Token accounting, present only on chunks that carry it; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
/// Unmodelled top-level keys (e.g. `system_fingerprint`), re-emitted verbatim.
#[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")]
pub extra: HashMap<String, serde_json::Value>,
}
/// One alternative within a [`ChatCompletionChunk`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkChoice {
/// Position of this alternative in `choices`.
pub index: u32,
/// Incremental message fragment.
pub delta: ChunkDelta,
/// Finish reason; `None` when `null` or absent, serialized as `null` when `None`.
pub finish_reason: Option<String>,
/// Unmodelled per-choice keys (e.g. `logprobs`), re-emitted verbatim.
#[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")]
pub extra: HashMap<String, serde_json::Value>,
}
/// Incremental message fragment carried by a [`ChunkChoice`].
///
/// NOTE(review): `tool_calls` reuses [`ToolCall`], whose `id`, `type` and
/// `function.name` are all required and which has no `index` field.
/// OpenAI-style streams typically send those identifying fields only on a
/// call's first fragment, and send an `index` on every fragment. Later
/// fragments would then fail to deserialize, and `index` would be lost on
/// re-serialization. Confirm against real upstream streams.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChunkDelta {
/// Role of the streamed message, when present; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
/// Content fragment; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
/// Tool-call fragments; omitted when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tool_calls: Vec<ToolCall>,
/// Unmodelled delta keys (e.g. `refusal`, explicit `null`s included), re-emitted verbatim.
#[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")]
pub extra: HashMap<String, serde_json::Value>,
}
/// Embeddings request body.
///
/// There is no `flatten` catch-all, so unmodelled keys are dropped on a round-trip.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingsRequest {
/// Model identifier.
pub model: String,
/// Text or texts to embed.
pub input: EmbeddingInput,
/// Requested output dimensionality; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u32>,
/// Output encoding, forwarded as an opaque string; omitted when `None`.
/// NOTE(review): [`EmbeddingData::embedding`] is `Vec<f32>`, so a response in
/// any non-float encoding (presumably e.g. `"base64"`) would not deserialize.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<String>,
}
/// Embedding input: one string or a batch of strings (`#[serde(untagged)]`).
///
/// NOTE(review): token-id inputs (arrays of integers) match neither variant and
/// fail to deserialize. Confirm whether they need to be supported.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
/// `"input": "..."`.
Single(String),
/// `"input": ["...", ...]`.
Batch(Vec<String>),
}
/// Embeddings response body; every field is required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingsResponse {
/// Object-type string, forwarded as-is.
pub object: String,
/// One entry per embedded input.
pub data: Vec<EmbeddingData>,
/// Model that produced the embeddings.
pub model: String,
/// Token accounting (`crate::Usage`).
pub usage: Usage,
}
/// A single embedding vector within an [`EmbeddingsResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingData {
/// Object-type string, forwarded as-is.
pub object: String,
/// Position of the corresponding input (presumably its index in the request batch).
pub index: u32,
/// The vector as JSON numbers; a non-array encoding fails to deserialize.
pub embedding: Vec<f32>,
}
/// Splits a base64 `data:` URL into `(media_type, payload)`.
///
/// Accepts `data:<media-type>[;<param>=<value>...];base64,<payload>`. Media-type
/// parameters (e.g. `charset`) are dropped. The media type and payload are
/// returned exactly as written: not lowercased and not base64-decoded.
///
/// The `data:` scheme and the `;base64` marker are matched
/// ASCII-case-insensitively, as URL schemes (RFC 3986 §3.1) and the WHATWG
/// data-URL processor require, so `DATA:image/png;BASE64,...` is accepted.
///
/// Returns `None` in these cases:
/// - the URL is not a `data:` URL;
/// - the URL is not base64-encoded;
/// - the media type is empty;
/// - the payload is empty.
///
/// Never panics, even on non-ASCII input.
#[must_use]
pub fn parse_data_url(url: &str) -> Option<(String, String)> {
    // `str::strip_prefix`, ASCII-case-insensitive. `get` returns `None` instead of
    // panicking when the split point is not a char boundary.
    fn strip_prefix_ci<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
        let (head, tail) = (s.get(..prefix.len())?, s.get(prefix.len()..)?);
        head.eq_ignore_ascii_case(prefix).then_some(tail)
    }

    // `str::strip_suffix`, ASCII-case-insensitive, with the same boundary safety.
    fn strip_suffix_ci<'a>(s: &'a str, suffix: &str) -> Option<&'a str> {
        let split = s.len().checked_sub(suffix.len())?;
        let (head, tail) = (s.get(..split)?, s.get(split..)?);
        tail.eq_ignore_ascii_case(suffix).then_some(head)
    }

    let rest = strip_prefix_ci(url, "data:")?;
    // The header ends at the first comma; base64 payloads never contain one.
    let (meta, data) = rest.split_once(',')?;
    let media_with_params = strip_suffix_ci(meta, ";base64")?;
    let media_type = media_with_params.split(';').next().unwrap_or("");
    if media_type.is_empty() || data.is_empty() {
        return None;
    }
    Some((media_type.to_owned(), data.to_owned()))
}
#[cfg(test)]
mod embeddings_default_tests {
    //! Defaults, typed-field round-trips and `flatten` passthrough for the wire
    //! types, plus `parse_data_url` and the `ToolChoice` constructors.
    use super::*;
    use serde_json::{json, Value};

    /// Expected `Some((media_type, payload))` result of `parse_data_url`.
    fn media(kind: &str, payload: &str) -> Option<(String, String)> {
        Some((kind.to_owned(), payload.to_owned()))
    }

    #[test]
    fn chat_request_default_is_empty() {
        let ChatCompletionRequest {
            model,
            messages,
            stream,
            tools,
            max_tokens,
            ..
        } = ChatCompletionRequest::default();
        assert!(model.is_empty());
        assert!(messages.is_empty());
        assert!(!stream);
        assert!(tools.is_empty());
        assert_eq!(max_tokens, None);
    }

    #[test]
    fn typed_compat_fields_roundtrip() {
        let include_usage = json!({ "include_usage": true });
        let input = json!({
            "model": "o3",
            "messages": [{ "role": "user", "content": "hi" }],
            "max_completion_tokens": 4096,
            "stream_options": { "include_usage": true },
            "parallel_tool_calls": false,
            "reasoning_effort": "high",
        });

        let req: ChatCompletionRequest = serde_json::from_value(input).unwrap();
        assert_eq!(req.max_completion_tokens, Some(4096));
        assert_eq!(req.parallel_tool_calls, Some(false));
        assert_eq!(req.reasoning_effort.as_deref(), Some("high"));
        assert_eq!(req.stream_options.as_ref(), Some(&include_usage));
        assert!(req.extra.is_empty(), "typed fields must not leak into `extra`");

        let wire = serde_json::to_value(&req).unwrap();
        assert_eq!(wire["max_completion_tokens"], 4096);
        assert_eq!(wire["stream_options"], include_usage);
        assert_eq!(wire["parallel_tool_calls"], false);
        assert_eq!(wire["reasoning_effort"], "high");
    }

    #[test]
    fn unknown_fields_passthrough_via_flatten() {
        let unknown = [
            ("logprobs", json!(true)),
            ("top_logprobs", json!(5)),
            ("service_tier", json!("auto")),
        ];
        let input = json!({
            "model": "gpt-4o",
            "messages": [{ "role": "user", "content": "hi" }],
            "logprobs": true,
            "top_logprobs": 5,
            "service_tier": "auto",
        });

        let req: ChatCompletionRequest = serde_json::from_value(input).unwrap();
        let wire = serde_json::to_value(&req).unwrap();
        for (key, value) in &unknown {
            assert_eq!(req.extra.get(*key), Some(value), "extra[{key}]");
            assert_eq!(&wire[*key], value, "re-serialized {key}");
        }
    }

    #[test]
    fn streaming_chunk_unknown_fields_passthrough() {
        let logprobs = json!({ "content": [] });
        let input = json!({
            "id": "chatcmpl-1",
            "object": "chat.completion.chunk",
            "created": 1716598234,
            "model": "gpt-4o",
            "system_fingerprint": "fp_abc123",
            "choices": [{
                "index": 0,
                "delta": { "content": "hi", "refusal": null },
                "finish_reason": null,
                "logprobs": { "content": [] }
            }]
        });

        let chunk: ChatCompletionChunk = serde_json::from_value(input).unwrap();
        let choice = &chunk.choices[0];
        assert_eq!(chunk.extra.get("system_fingerprint"), Some(&json!("fp_abc123")));
        assert_eq!(choice.extra.get("logprobs"), Some(&logprobs));
        assert_eq!(choice.delta.extra.get("refusal"), Some(&Value::Null));

        let wire = serde_json::to_value(&chunk).unwrap();
        let wire_choice = &wire["choices"][0];
        assert_eq!(wire["system_fingerprint"], "fp_abc123");
        assert_eq!(wire_choice["logprobs"], logprobs);
        assert_eq!(wire_choice["delta"]["refusal"], Value::Null);
    }

    #[test]
    fn parse_data_url_extracts_media_type_and_payload() {
        let cases = [
            ("data:image/png;base64,iVBORw0KGgo=", media("image/png", "iVBORw0KGgo=")),
            ("https://example.com/cat.png", None),
            ("data:image/png,notbase64", None),
            ("data:;base64,abc", None),
            ("data:image/png;base64,", None),
            (
                "data:image/png;charset=utf-8;base64,iVBORw0KGgo=",
                media("image/png", "iVBORw0KGgo="),
            ),
        ];
        for (url, expected) in cases {
            assert_eq!(parse_data_url(url), expected, "{url}");
        }
    }

    #[test]
    fn tool_choice_constructors_serialize_to_the_wire_form() {
        let cases = [
            (ToolChoice::auto(), json!("auto")),
            (ToolChoice::none(), json!("none")),
            (ToolChoice::required(), json!("required")),
            (
                ToolChoice::function("get_weather"),
                json!({ "type": "function", "function": { "name": "get_weather" } }),
            ),
        ];
        for (choice, expected) in cases {
            assert_eq!(serde_json::to_value(choice).unwrap(), expected);
        }
    }
}