use serde::{Deserialize, Serialize};
use serde_json::Value;
use tt_shared::{
messages::{Message, ResponseFormat, Tool, ToolChoice},
usage::Usage,
ProviderError,
};
/// Returns `true` when `model` names an OpenAI o-series reasoning model.
///
/// The translator uses this to drop `temperature` and to send the token cap
/// as `max_completion_tokens` instead of `max_tokens`.
///
/// The check matches the family shape (`o` followed by a digit) rather than
/// a fixed list. That way variants (`o1`, `o3-mini`, `o4-mini`) and dated
/// snapshots such as `o3-2025-04-16` are recognised too. With an exact
/// match, those ids would be sent `max_tokens`/`temperature`, which reasoning
/// models do not accept. Ids like `gpt-4o` or `omni-moderation-latest` do not
/// match.
pub fn is_reasoning_model(model: &str) -> bool {
    matches!(model.as_bytes(), [b'o', digit, ..] if digit.is_ascii_digit())
}
/// Names the request parameters that translation will not forward.
///
/// This mirrors the rule in `translate_request`: a `temperature` set on a
/// request for a reasoning model (see [`is_reasoning_model`]) is dropped and
/// reported here. Returns an empty list when nothing is dropped.
pub fn dropped_params(req: &tt_shared::ChatCompletionRequest) -> Vec<String> {
    let drops_temperature = req.temperature.is_some() && is_reasoning_model(&req.model);
    if !drops_temperature {
        return Vec::new();
    }
    vec![String::from("temperature")]
}
/// JSON body sent upstream for a chat completion, built by
/// [`translate_request`] from the canonical request.
///
/// serde serializes the fields in declaration order. Every field except
/// `model` and `messages` is left out of the JSON when it is `None`, `false`,
/// or an empty `Vec`, so only values the caller actually set are sent.
/// There is no counterpart to the canonical request's `tt_extras`, so those
/// gateway-side extras never reach the wire.
#[derive(Debug, Serialize)]
pub struct OpenAiRequestBody {
    /// Model id, copied from the canonical request.
    pub model: String,
    /// Conversation messages, copied from the canonical request.
    pub messages: Vec<Message>,
    /// Copied from the canonical request, except always `None` for
    /// reasoning models; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Copied from the canonical request; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Token cap for non-reasoning models; always `None` for reasoning
    /// models, whose cap goes in `max_completion_tokens`. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Token cap for reasoning models (the canonical `max_tokens` value);
    /// always `None` for other models. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u32>,
    /// Streaming flag; omitted from the JSON when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    /// Tool definitions; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Copied from the canonical request; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Copied from the canonical request; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Stop sequences; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Copied from the canonical request; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Copied from the canonical request; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Copied from the canonical request; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    /// Copied from the canonical request; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// Copied from the canonical request; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
/// Converts a canonical chat request into the OpenAI wire body.
///
/// For reasoning models (see [`is_reasoning_model`]):
/// - the canonical `max_tokens` is sent as `max_completion_tokens`;
/// - any `temperature` is dropped, with a `tracing` warning.
///
/// Every other model gets `max_tokens` and `temperature` unchanged. The
/// remaining fields are copied across as-is. `tt_extras` is deliberately
/// left behind, so it is never serialized.
///
/// # Errors
///
/// Never returns `Err` at present; the `Result` is presumably kept for
/// signature parity with other translators.
pub fn translate_request(
    req: tt_shared::ChatCompletionRequest,
) -> Result<OpenAiRequestBody, ProviderError> {
    let tt_shared::ChatCompletionRequest {
        model,
        messages,
        temperature: requested_temperature,
        top_p,
        max_tokens: requested_max_tokens,
        stream,
        tools,
        tool_choice,
        response_format,
        stop,
        presence_penalty,
        frequency_penalty,
        n,
        seed,
        user,
        ..
    } = req;

    let reasoning = is_reasoning_model(&model);
    if reasoning && requested_temperature.is_some() {
        tracing::warn!(
            model = %model,
            "reasoning models do not support temperature; dropping the field"
        );
    }

    // Reasoning models take the cap under a different key and get no temperature.
    let (max_tokens, max_completion_tokens, temperature) = if !reasoning {
        (requested_max_tokens, None, requested_temperature)
    } else {
        (None, requested_max_tokens, None)
    };

    Ok(OpenAiRequestBody {
        model,
        messages,
        temperature,
        top_p,
        max_tokens,
        max_completion_tokens,
        stream,
        tools,
        tool_choice,
        response_format,
        stop,
        presence_penalty,
        frequency_penalty,
        n,
        seed,
        user,
    })
}
/// The `usage` object of an OpenAI chat completion response, as read by
/// [`extract_usage`] and converted into the canonical `Usage`.
///
/// Field names are the JSON keys. Keys not listed here are ignored, since
/// there is no `deny_unknown_fields`.
#[derive(Debug, Deserialize)]
pub struct OpenAiUsage {
    /// Prompt token count; required (deserialization fails if absent).
    pub prompt_tokens: u64,
    /// Completion token count; required.
    pub completion_tokens: u64,
    /// Total token count as reported upstream; required, not recomputed.
    pub total_tokens: u64,
    /// Optional prompt-token breakdown; `None` when the key is absent or
    /// `null`.
    #[serde(default)]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
/// The `prompt_tokens_details` sub-object of OpenAI's usage block; only the
/// cached-token count is read.
#[derive(Debug, Deserialize)]
pub struct PromptTokensDetails {
    /// Cached prompt-token count; defaults to 0 when the key is absent.
    // NOTE(review): `#[serde(default)]` only covers a missing key. A literal
    // `"cached_tokens": null` would fail to deserialize into `u64`, so
    // confirm upstream never sends it.
    #[serde(default)]
    pub cached_tokens: u64,
}
impl From<OpenAiUsage> for Usage {
fn from(u: OpenAiUsage) -> Self {
let cached_tokens = u
.prompt_tokens_details
.map(|d| d.cached_tokens)
.unwrap_or(0);
Usage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
cached_tokens,
cache_creation_input_tokens: None,
}
}
}
pub fn extract_usage(raw: &Value) -> Result<Usage, ProviderError> {
let usage_val = raw
.get("usage")
.ok_or_else(|| ProviderError::Deserialize("missing 'usage' field".to_string()))?;
let openai_usage: OpenAiUsage = serde_json::from_value(usage_val.clone())
.map_err(|e| ProviderError::Deserialize(e.to_string()))?;
Ok(openai_usage.into())
}
pub fn deserialize_response(
body: &str,
) -> Result<tt_shared::ChatCompletionResponse, ProviderError> {
let raw: Value =
serde_json::from_str(body).map_err(|e| ProviderError::Deserialize(e.to_string()))?;
let canonical_usage = extract_usage(&raw)?;
let mut resp: tt_shared::ChatCompletionResponse =
serde_json::from_value(raw).map_err(|e| ProviderError::Deserialize(e.to_string()))?;
resp.usage = canonical_usage;
Ok(resp)
}
/// Prepares an embeddings request for the upstream call.
///
/// The canonical request is returned unchanged, and this never fails.
///
/// NOTE(review): unlike `translate_request`, nothing here drops gateway-only
/// fields. Confirm `EmbeddingsRequest` has none that would be serialized
/// upstream.
pub fn translate_embeddings_request(
    req: tt_shared::EmbeddingsRequest,
) -> Result<tt_shared::EmbeddingsRequest, ProviderError> {
    Result::Ok(req)
}
/// Parses an embeddings response body directly into the canonical
/// `EmbeddingsResponse`; no field is rewritten.
///
/// # Errors
///
/// Returns [`ProviderError::Deserialize`], carrying the `serde_json` error
/// text, when the body is not valid JSON or does not match the canonical
/// shape.
pub fn deserialize_embeddings_response(
    body: &str,
) -> Result<tt_shared::EmbeddingsResponse, ProviderError> {
    match serde_json::from_str::<tt_shared::EmbeddingsResponse>(body) {
        Ok(parsed) => Ok(parsed),
        Err(err) => Err(ProviderError::Deserialize(err.to_string())),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tt_shared::{messages::MessageContent, ChatCompletionRequest};

    /// Minimal canonical request for `model`: one user message, temperature
    /// 0.7, a 512-token cap, and every other optional field unset or empty.
    fn base_request(model: &str) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: model.to_string(),
            messages: vec![Message::User {
                content: MessageContent::Text("Hello".to_string()),
                name: None,
            }],
            temperature: Some(0.7),
            top_p: None,
            max_tokens: Some(512),
            stream: false,
            tools: vec![],
            tool_choice: None,
            response_format: None,
            stop: vec![],
            presence_penalty: None,
            frequency_penalty: None,
            n: None,
            seed: None,
            user: None,
            tt_extras: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn dropped_params_temperature_only_for_reasoning_models() {
        let req = base_request("o3");
        assert_eq!(dropped_params(&req), vec!["temperature".to_string()]);
        let req2 = base_request("gpt-4o");
        assert!(dropped_params(&req2).is_empty());
        let mut req3 = base_request("o4-mini");
        req3.temperature = None;
        assert!(dropped_params(&req3).is_empty());
    }

    #[test]
    fn non_reasoning_passes_through() {
        let req = base_request("gpt-4o");
        let body = translate_request(req).expect("translate ok");
        assert_eq!(body.temperature, Some(0.7));
        assert_eq!(body.max_tokens, Some(512));
        assert!(body.max_completion_tokens.is_none());
    }

    #[test]
    fn reasoning_model_renames_max_tokens() {
        let req = base_request("o3");
        let body = translate_request(req).expect("translate ok");
        assert!(body.max_tokens.is_none());
        assert_eq!(body.max_completion_tokens, Some(512));
        assert!(body.temperature.is_none());
    }

    /// The wire JSON itself, not just the struct, must lack the keys that
    /// reasoning models reject.
    #[test]
    fn reasoning_model_wire_body_omits_max_tokens_and_temperature() {
        let body = translate_request(base_request("o3")).expect("translate ok");
        let wire = serde_json::to_value(&body).expect("serialize ok");
        assert!(wire.get("max_tokens").is_none());
        assert!(wire.get("temperature").is_none());
        assert_eq!(wire["max_completion_tokens"], serde_json::json!(512));
    }

    /// With every optional field unset or empty, only `model` and `messages`
    /// are serialized (checks the `skip_serializing_if` rules, including
    /// `stream: false`).
    #[test]
    fn unset_and_empty_fields_are_omitted_from_wire_body() {
        let mut req = base_request("gpt-4o");
        req.temperature = None;
        req.max_tokens = None;
        let body = translate_request(req).expect("translate ok");
        let wire = serde_json::to_value(&body).expect("serialize ok");
        let obj = wire.as_object().expect("body serializes to a JSON object");
        let mut keys: Vec<&str> = obj.keys().map(String::as_str).collect();
        keys.sort_unstable();
        assert_eq!(keys, vec!["messages", "model"]);
    }

    #[test]
    fn tt_extras_not_serialized() {
        let mut req = base_request("gpt-4o");
        req.tt_extras
            .insert("route_hint".to_string(), serde_json::json!("us-east-1"));
        let body = translate_request(req).expect("translate ok");
        let serialized = serde_json::to_string(&body).expect("serialize ok");
        assert!(!serialized.contains("tt_extras"));
        assert!(!serialized.contains("route_hint"));
    }

    #[test]
    fn usage_cached_tokens_populated() {
        let raw = serde_json::json!({
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150,
                "prompt_tokens_details": { "cached_tokens": 80 }
            }
        });
        let usage = extract_usage(&raw).expect("extract ok");
        assert_eq!(usage.cached_tokens, 80);
        assert_eq!(usage.prompt_tokens, 100);
    }

    #[test]
    fn usage_cached_tokens_absent_defaults_zero() {
        let raw = serde_json::json!({
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        });
        let usage = extract_usage(&raw).expect("extract ok");
        assert_eq!(usage.cached_tokens, 0);
    }

    /// An explicit `null` details object is treated like an absent one.
    #[test]
    fn usage_null_prompt_tokens_details_defaults_zero() {
        let raw = serde_json::json!({
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15,
                "prompt_tokens_details": null
            }
        });
        let usage = extract_usage(&raw).expect("extract ok");
        assert_eq!(usage.cached_tokens, 0);
        assert_eq!(usage.total_tokens, 15);
    }

    #[test]
    fn missing_usage_is_a_deserialize_error() {
        let raw = serde_json::json!({ "id": "chatcmpl-123" });
        assert!(matches!(
            extract_usage(&raw),
            Err(ProviderError::Deserialize(_))
        ));
    }
}