use chrono::Utc;
use serde::{Deserialize, Serialize};
use tracing::warn;
use tt_shared::{
messages::{ContentPart, Message, MessageContent, ToolCall, ToolCallFunction, ToolChoice},
usage::Usage,
ChatCompletionResponse, Choice, ProviderError,
};
use uuid::Uuid;
use crate::pricing::BRACKET_THRESHOLD_TOKENS;
/// Request body sent to Gemini; built by `translate_request`.
/// NOTE(review): presumably the JSON body POSTed to the `generateContent` endpoint — confirm
/// with the HTTP client that sends it.
#[derive(Debug, Serialize)]
pub struct GeminiRequest {
/// Conversation turns in order. `translate_request` emits the roles "user", "model" and "function".
pub contents: Vec<GeminiContent>,
/// Text collected from system messages; the key is omitted from JSON when `None`.
#[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
pub system_instruction: Option<GeminiSystemInstruction>,
/// Sampling/output settings; the key is omitted from JSON when `None`.
#[serde(rename = "generationConfig", skip_serializing_if = "Option::is_none")]
pub generation_config: Option<GeminiGenerationConfig>,
/// Function declarations; the key is omitted from JSON when the vector is empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<GeminiToolBlock>,
/// Function-calling mode derived from the OpenAI `tool_choice`; omitted when `None`.
#[serde(rename = "toolConfig", skip_serializing_if = "Option::is_none")]
pub tool_config: Option<GeminiToolConfig>,
}
/// One conversation turn: a role plus its ordered parts.
/// Used both when building requests and when reading response candidates.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiContent {
pub role: String,
pub parts: Vec<GeminiPart>,
}
/// A single piece of a turn, (de)serialized by the hand-written impls as a one-key JSON object
/// keyed by `text`, `inlineData`, `fileData`, `functionCall` or `functionResponse`.
/// NOTE(review): the `Deserialize` impl errors on any part carrying none of those keys, which
/// fails the whole response if Gemini returns an unlisted part kind.
#[derive(Debug, Clone)]
pub enum GeminiPart {
/// Plain text.
Text(String),
/// Inline media payload with its MIME type (built from `data:` URLs).
InlineData(GeminiInlineData),
/// Media referenced by URI rather than embedded.
FileData(GeminiFileData),
/// A function call emitted by the model (or replayed from an assistant turn).
FunctionCall(GeminiFunctionCall),
/// The result of a function call, sent back to the model.
FunctionResponse(GeminiFunctionResponse),
}
impl Serialize for GeminiPart {
    /// Writes the part as a single-entry JSON map whose key names the variant
    /// (`text`, `inlineData`, `fileData`, `functionCall`, `functionResponse`).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeMap;
        // Every variant produces exactly one entry, so open the map once up front.
        let mut entry = serializer.serialize_map(Some(1))?;
        match self {
            GeminiPart::Text(text) => entry.serialize_entry("text", text)?,
            GeminiPart::InlineData(inline) => entry.serialize_entry("inlineData", inline)?,
            GeminiPart::FileData(file) => entry.serialize_entry("fileData", file)?,
            GeminiPart::FunctionCall(call) => entry.serialize_entry("functionCall", call)?,
            GeminiPart::FunctionResponse(reply) => {
                entry.serialize_entry("functionResponse", reply)?
            }
        }
        entry.end()
    }
}
impl<'de> Deserialize<'de> for GeminiPart {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let map: serde_json::Value = serde_json::Value::deserialize(deserializer)?;
if let Some(text) = map.get("text").and_then(|v| v.as_str()) {
return Ok(GeminiPart::Text(text.to_string()));
}
if let Some(fc) = map.get("functionCall") {
let fc: GeminiFunctionCall =
serde_json::from_value(fc.clone()).map_err(serde::de::Error::custom)?;
return Ok(GeminiPart::FunctionCall(fc));
}
if let Some(fr) = map.get("functionResponse") {
let fr: GeminiFunctionResponse =
serde_json::from_value(fr.clone()).map_err(serde::de::Error::custom)?;
return Ok(GeminiPart::FunctionResponse(fr));
}
if let Some(id) = map.get("inlineData") {
let d: GeminiInlineData =
serde_json::from_value(id.clone()).map_err(serde::de::Error::custom)?;
return Ok(GeminiPart::InlineData(d));
}
if let Some(fd) = map.get("fileData") {
let d: GeminiFileData =
serde_json::from_value(fd.clone()).map_err(serde::de::Error::custom)?;
return Ok(GeminiPart::FileData(d));
}
Err(serde::de::Error::custom("unknown GeminiPart variant"))
}
}
/// Media embedded directly in a part.
/// NOTE(review): `data` comes from `parse_data_url`; presumably base64 text — confirm.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiInlineData {
#[serde(rename = "mimeType")]
pub mime_type: String,
pub data: String,
}
/// Media referenced by URI. `translate_user_content` fills this from non-`data:` image URLs,
/// with a MIME type guessed from the URL's extension.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFileData {
#[serde(rename = "mimeType")]
pub mime_type: String,
#[serde(rename = "fileUri")]
pub file_uri: String,
}
/// A function invocation: the function name plus its arguments as parsed JSON.
/// Request side: parsed from the OpenAI `arguments` string.
/// Response side: re-stringified into `ToolCallFunction::arguments`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFunctionCall {
pub name: String,
pub args: serde_json::Value,
}
/// A function result sent back to the model, addressed by function name (not call id).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFunctionResponse {
pub name: String,
pub response: GeminiFunctionResponseContent,
}
/// Wraps a tool message's text as the JSON object `{"content": "..."}`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GeminiFunctionResponseContent {
pub content: String,
}
/// System prompt: one text part per OpenAI system message.
#[derive(Debug, Serialize)]
pub struct GeminiSystemInstruction {
pub parts: Vec<GeminiTextPart>,
}
/// A text-only part, used inside `GeminiSystemInstruction`.
#[derive(Debug, Serialize)]
pub struct GeminiTextPart {
pub text: String,
}
/// Sampling/output settings mapped from the OpenAI request by `build_generation_config`.
/// Unset options are left out of the serialized JSON.
#[derive(Debug, Serialize)]
pub struct GeminiGenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(rename = "topP", skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Mapped from OpenAI `max_tokens`.
#[serde(rename = "maxOutputTokens", skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u32>,
/// Mapped from OpenAI `stop`.
#[serde(rename = "stopSequences", skip_serializing_if = "Vec::is_empty")]
pub stop_sequences: Vec<String>,
/// Set to "application/json" for OpenAI `json_object` / `json_schema` response formats.
#[serde(rename = "responseMimeType", skip_serializing_if = "Option::is_none")]
pub response_mime_type: Option<String>,
/// Schema forwarded from the OpenAI `response_format`, when one was given.
#[serde(rename = "responseSchema", skip_serializing_if = "Option::is_none")]
pub response_schema: Option<serde_json::Value>,
}
/// Groups all function declarations. `translate_request` emits at most one such block.
#[derive(Debug, Serialize)]
pub struct GeminiToolBlock {
#[serde(rename = "functionDeclarations")]
pub function_declarations: Vec<GeminiFunctionDeclaration>,
}
/// One callable function. `parameters` is copied unchanged from the OpenAI tool definition.
#[derive(Debug, Serialize)]
pub struct GeminiFunctionDeclaration {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub parameters: serde_json::Value,
}
/// Wrapper for the function-calling mode sent as `toolConfig`.
#[derive(Debug, Serialize, Clone)]
pub struct GeminiToolConfig {
#[serde(rename = "functionCallingConfig")]
pub function_calling_config: GeminiFunctionCallingConfig,
}
/// Function-calling mode. `translate_tool_choice` emits "NONE", "ANY" or "AUTO".
/// The allow-list is populated only when a specific function is forced.
#[derive(Debug, Serialize, Clone)]
pub struct GeminiFunctionCallingConfig {
pub mode: String,
#[serde(rename = "allowedFunctionNames", skip_serializing_if = "Vec::is_empty")]
pub allowed_function_names: Vec<String>,
}
/// Parsed Gemini response body, consumed by `translate_response`.
#[derive(Debug, Deserialize)]
pub struct GeminiResponse {
/// Empty when absent; `translate_response` uses only the first candidate.
#[serde(default)]
pub candidates: Vec<GeminiCandidate>,
/// `None` when absent; `translate_response` then reports `Usage::default()`.
#[serde(rename = "usageMetadata", default)]
pub usage_metadata: Option<GeminiUsageMetadata>,
/// When present, it overrides the caller's requested model name in the translated response.
#[serde(rename = "modelVersion")]
pub model_version: Option<String>,
}
/// One generated candidate.
#[derive(Debug, Deserialize)]
pub struct GeminiCandidate {
pub content: Option<GeminiContent>,
/// Raw Gemini reason string (e.g. "STOP", "MAX_TOKENS"), mapped by `map_finish_reason`.
#[serde(rename = "finishReason")]
pub finish_reason: Option<String>,
/// Defaults to 0; not used by `translate_response`, which always reports index 0.
#[serde(default)]
pub index: u32,
}
/// Token accounting reported by Gemini. Missing counts default to 0;
/// `translate_usage` fills in gaps from the other counts.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct GeminiUsageMetadata {
#[serde(rename = "promptTokenCount", default)]
pub prompt_token_count: u64,
#[serde(rename = "candidatesTokenCount", default)]
pub candidates_token_count: u64,
#[serde(rename = "totalTokenCount", default)]
pub total_token_count: u64,
#[serde(rename = "cachedContentTokenCount", default)]
pub cached_content_token_count: u64,
}
/// Accepts `model` only if it is non-empty and made solely of ASCII letters, digits,
/// `.`, `_` and `-`.
///
/// NOTE(review): presumably this guards interpolating the id into a request URL — confirm
/// with the caller.
///
/// # Errors
/// Returns [`ProviderError::InvalidRequest`] when the id is empty or has any other byte.
pub fn validate_model_id(model: &str) -> Result<(), ProviderError> {
    let allowed_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'.' || b == b'_' || b == b'-';
    if model.is_empty() || !model.bytes().all(allowed_byte) {
        return Err(ProviderError::InvalidRequest(format!(
            "invalid Gemini model id {model:?}: only [A-Za-z0-9._-] is allowed"
        )));
    }
    Ok(())
}
pub fn translate_request(
req: tt_shared::ChatCompletionRequest,
) -> Result<GeminiRequest, ProviderError> {
let mut system_parts: Vec<GeminiTextPart> = Vec::new();
let mut contents: Vec<GeminiContent> = Vec::new();
let tool_call_id_to_name = build_tool_call_id_map(&req.messages);
for msg in req.messages {
match msg {
Message::System { content } => {
let text = extract_text_from_content(content)?;
system_parts.push(GeminiTextPart { text });
}
Message::User { content, .. } => {
let parts = translate_user_content(content)?;
contents.push(GeminiContent {
role: "user".to_string(),
parts,
});
}
Message::Assistant {
content,
tool_calls,
..
} => {
let mut parts: Vec<GeminiPart> = Vec::new();
if let Some(c) = content {
parts.extend(translate_user_content(c)?);
}
for tc in tool_calls {
let args: serde_json::Value = serde_json::from_str(&tc.function.arguments)
.map_err(|e| {
ProviderError::Deserialize(format!(
"tool_call arguments not valid JSON: {e}"
))
})?;
parts.push(GeminiPart::FunctionCall(GeminiFunctionCall {
name: tc.function.name,
args,
}));
}
contents.push(GeminiContent {
role: "model".to_string(),
parts,
});
}
Message::Tool {
content,
tool_call_id,
} => {
let text = extract_text_from_content(content)?;
let fn_name = tool_call_id_to_name
.get(&tool_call_id)
.cloned()
.unwrap_or_else(|| {
warn!(
tool_call_id = %tool_call_id,
"could not find function name for tool_call_id; using id as fallback"
);
tool_call_id.clone()
});
contents.push(GeminiContent {
role: "function".to_string(),
parts: vec![GeminiPart::FunctionResponse(GeminiFunctionResponse {
name: fn_name,
response: GeminiFunctionResponseContent { content: text },
})],
});
}
}
}
let system_instruction = if system_parts.is_empty() {
None
} else {
Some(GeminiSystemInstruction {
parts: system_parts,
})
};
let (response_mime_type, response_schema) = translate_response_format(req.response_format);
let generation_config = build_generation_config(
req.temperature,
req.top_p,
req.max_tokens,
req.stop,
response_mime_type,
response_schema,
);
if let Some(cfg) = &generation_config {
if let Some(max_out) = cfg.max_output_tokens {
let _ = max_out;
}
}
let tools = if req.tools.is_empty() {
vec![]
} else {
let decls: Vec<GeminiFunctionDeclaration> = req
.tools
.into_iter()
.map(|t| GeminiFunctionDeclaration {
name: t.function.name,
description: t.function.description,
parameters: t.function.parameters,
})
.collect();
vec![GeminiToolBlock {
function_declarations: decls,
}]
};
let tool_config = req.tool_choice.map(translate_tool_choice);
Ok(GeminiRequest {
contents,
system_instruction,
generation_config,
tools,
tool_config,
})
}
/// Maps every assistant tool-call id in `messages` to its function name.
///
/// Tool-result messages carry only the call id, while `GeminiFunctionResponse` is
/// addressed by function name; this table bridges the two. If an id repeats, the last
/// occurrence wins.
fn build_tool_call_id_map(messages: &[Message]) -> std::collections::HashMap<String, String> {
    messages
        .iter()
        .filter_map(|message| match message {
            Message::Assistant { tool_calls, .. } => Some(tool_calls),
            _ => None,
        })
        .flatten()
        .map(|call| (call.id.clone(), call.function.name.clone()))
        .collect()
}
/// Flattens message content to plain text by concatenating its text parts with no separator.
///
/// Non-text parts (images, audio) are silently dropped. Currently never returns `Err`.
fn extract_text_from_content(content: MessageContent) -> Result<String, ProviderError> {
    let flattened = match content {
        MessageContent::Text(whole) => whole,
        MessageContent::Parts(pieces) => {
            let mut buffer = String::new();
            for piece in pieces {
                if let ContentPart::Text { text } = piece {
                    buffer.push_str(&text);
                }
            }
            buffer
        }
    };
    Ok(flattened)
}
/// Converts OpenAI message content into Gemini parts, preserving order.
///
/// - Text maps to text parts.
/// - Image URLs accepted by `parse_data_url` become `inlineData`.
/// - Any other image URL is passed by reference as `fileData`, with a MIME type guessed
///   from its extension.
///
/// # Errors
/// Returns [`ProviderError::Unsupported`] on the first audio part.
fn translate_user_content(content: MessageContent) -> Result<Vec<GeminiPart>, ProviderError> {
    let pieces = match content {
        MessageContent::Text(whole) => return Ok(vec![GeminiPart::Text(whole)]),
        MessageContent::Parts(pieces) => pieces,
    };
    // Collecting into `Result` stops at the first error, like an early `return Err`.
    pieces
        .into_iter()
        .map(|piece| match piece {
            ContentPart::Text { text } => Ok(GeminiPart::Text(text)),
            ContentPart::ImageUrl { image_url } => {
                let converted = match tt_shared::messages::parse_data_url(&image_url.url) {
                    Some((mime_type, data)) => {
                        GeminiPart::InlineData(GeminiInlineData { mime_type, data })
                    }
                    None => GeminiPart::FileData(GeminiFileData {
                        mime_type: guess_mime_from_url(&image_url.url),
                        file_uri: image_url.url,
                    }),
                };
                Ok(converted)
            }
            ContentPart::InputAudio { .. } => Err(ProviderError::Unsupported(
                "audio input is not supported by the Gemini adapter".to_string(),
            )),
        })
        .collect()
}
/// Guesses an image MIME type from a URL's file extension, defaulting to `image/jpeg`.
///
/// Any query string or fragment is ignored first, so a signed URL such as
/// `https://host/img.png?sig=abc.def` still resolves to `image/png`.
/// Extension matching is ASCII case-insensitive.
fn guess_mime_from_url(url: &str) -> String {
    // `split` always yields at least one item; the fallback only satisfies the type.
    let path = url
        .split(|c: char| c == '?' || c == '#')
        .next()
        .unwrap_or(url);
    let extension = path
        .rsplit_once('.')
        .map(|(_, ext)| ext.to_ascii_lowercase())
        .unwrap_or_default();
    let mime = match extension.as_str() {
        "png" => "image/png",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "heic" => "image/heic",
        "heif" => "image/heif",
        _ => "image/jpeg",
    };
    mime.to_string()
}
/// Builds Gemini's `generationConfig` from the OpenAI sampling/output options.
///
/// Returns `None` when every option is unset (no stop sequences, all `Option`s `None`), so
/// that `GeminiRequest::generation_config` is skipped during serialization.
fn build_generation_config(
    temperature: Option<f32>,
    top_p: Option<f32>,
    max_tokens: Option<u32>,
    stop: Vec<String>,
    response_mime_type: Option<String>,
    response_schema: Option<serde_json::Value>,
) -> Option<GeminiGenerationConfig> {
    let all_unset = temperature.is_none()
        && top_p.is_none()
        && max_tokens.is_none()
        && stop.is_empty()
        && response_mime_type.is_none()
        && response_schema.is_none();
    (!all_unset).then(|| GeminiGenerationConfig {
        temperature,
        top_p,
        max_output_tokens: max_tokens,
        stop_sequences: stop,
        response_mime_type,
        response_schema,
    })
}
/// Maps an OpenAI `response_format` onto Gemini's `(responseMimeType, responseSchema)`.
///
/// `json_object` and `json_schema` both request `application/json`; any other `type`
/// (e.g. `text`) yields `(None, None)`.
///
/// OpenAI's `json_schema` field is a wrapper `{"name": ..., "schema": {...}, "strict": ...}`,
/// while Gemini's `responseSchema` must be the bare schema. When the value carries a `schema`
/// key, only that inner schema is forwarded. Otherwise the value is forwarded unchanged.
/// NOTE(review): assumes `ResponseFormat::json_schema` holds the raw OpenAI object — confirm
/// against `tt_shared`.
fn translate_response_format(
    rf: Option<tt_shared::messages::ResponseFormat>,
) -> (Option<String>, Option<serde_json::Value>) {
    let fmt = match rf {
        Some(fmt) => fmt,
        None => return (None, None),
    };
    if fmt.r#type != "json_schema" && fmt.r#type != "json_object" {
        return (None, None);
    }
    let schema = fmt.json_schema.map(unwrap_openai_json_schema);
    (Some("application/json".to_string()), schema)
}

/// Returns the inner `schema` of an OpenAI `json_schema` wrapper object, or `value`
/// unchanged when it is not an object or has no `schema` key.
fn unwrap_openai_json_schema(value: serde_json::Value) -> serde_json::Value {
    match value {
        serde_json::Value::Object(mut obj) => match obj.remove("schema") {
            Some(inner) => inner,
            None => serde_json::Value::Object(obj),
        },
        other => other,
    }
}
/// Maps an OpenAI `tool_choice` onto Gemini's `functionCallingConfig`:
///
/// - `"none"` → `NONE`
/// - `"required"` → `ANY`
/// - any other string (including `"auto"`) → `AUTO`
/// - a specific function → `ANY`, restricted to that function's name
fn translate_tool_choice(choice: ToolChoice) -> GeminiToolConfig {
    let (mode, allowed_function_names) = match choice {
        ToolChoice::Auto(setting) => {
            let mode = match setting.as_str() {
                "none" => "NONE",
                "required" => "ANY",
                _ => "AUTO",
            };
            (mode, Vec::new())
        }
        ToolChoice::Specific { function, .. } => ("ANY", vec![function.name]),
    };
    GeminiToolConfig {
        function_calling_config: GeminiFunctionCallingConfig {
            mode: mode.to_string(),
            allowed_function_names,
        },
    }
}
/// Parses a raw Gemini JSON response body and translates it with [`translate_response`].
///
/// # Errors
/// Returns [`ProviderError::Deserialize`] (carrying the serde_json message) if `body`
/// does not parse as a `GeminiResponse`.
pub fn deserialize_response(
    body: &str,
    requested_model: &str,
) -> Result<ChatCompletionResponse, ProviderError> {
    match serde_json::from_str::<GeminiResponse>(body) {
        Ok(parsed) => Ok(translate_response(parsed, requested_model)),
        Err(err) => Err(ProviderError::Deserialize(err.to_string())),
    }
}
pub fn translate_response(resp: GeminiResponse, requested_model: &str) -> ChatCompletionResponse {
let id = format!("chatcmpl-gem-{}", Uuid::new_v4());
let created = Utc::now().timestamp();
let model = resp
.model_version
.unwrap_or_else(|| requested_model.to_string());
let usage = resp.usage_metadata.map(translate_usage).unwrap_or_default();
if usage.prompt_tokens > BRACKET_THRESHOLD_TOKENS {
tracing::debug!(
prompt_tokens = usage.prompt_tokens,
threshold = BRACKET_THRESHOLD_TOKENS,
"prompt token count exceeds 200K bracket threshold; higher pricing tier applies"
);
}
let choice = if let Some(candidate) = resp.candidates.into_iter().next() {
let (message, finish_reason) = translate_candidate(candidate);
Choice {
index: 0,
message,
finish_reason,
}
} else {
Choice {
index: 0,
message: Message::Assistant {
content: None,
tool_calls: vec![],
name: None,
},
finish_reason: Some("stop".to_string()),
}
};
ChatCompletionResponse {
id,
object: "chat.completion".to_string(),
created,
model,
choices: vec![choice],
usage,
}
}
fn translate_candidate(candidate: GeminiCandidate) -> (Message, Option<String>) {
let mut text_parts: Vec<String> = Vec::new();
let mut tool_calls: Vec<ToolCall> = Vec::new();
if let Some(content) = candidate.content {
for part in content.parts {
match part {
GeminiPart::Text(t) => text_parts.push(t),
GeminiPart::FunctionCall(fc) => {
tool_calls.push(ToolCall {
id: format!("call_{}", Uuid::new_v4()),
r#type: "function".to_string(),
function: ToolCallFunction {
name: fc.name,
arguments: fc.args.to_string(),
},
});
}
_ => {} }
}
}
let message_content = if text_parts.is_empty() {
None
} else {
Some(MessageContent::Text(text_parts.join("")))
};
let finish_reason = if !tool_calls.is_empty() {
Some("tool_calls".to_string())
} else {
candidate
.finish_reason
.as_deref()
.map(map_finish_reason)
.map(str::to_string)
};
let message = Message::Assistant {
content: message_content,
tool_calls,
name: None,
};
(message, finish_reason)
}
/// Maps a Gemini `finishReason` string onto the OpenAI `finish_reason` vocabulary.
///
/// - `MAX_TOKENS` → `"length"`
/// - Safety/policy stops (`SAFETY`, `RECITATION`, `BLOCKLIST`, `PROHIBITED_CONTENT`, `SPII`,
///   `IMAGE_SAFETY`) → `"content_filter"`, so clients can tell filtered output from a
///   natural end.
/// - Everything else (`STOP`, `OTHER`, unknown or future values) → `"stop"`.
pub fn map_finish_reason(reason: &str) -> &'static str {
    match reason {
        "MAX_TOKENS" => "length",
        "SAFETY" | "RECITATION" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "SPII"
        | "IMAGE_SAFETY" => "content_filter",
        _ => "stop",
    }
}
/// Converts Gemini usage metadata to the shared [`Usage`] type, reconciling missing counts.
///
/// - If `candidatesTokenCount` is 0 and `totalTokenCount` exceeds the prompt count, completion
///   tokens are derived as `total - prompt`.
/// - If `totalTokenCount` is 0, it becomes `prompt + completion`.
/// - A nonzero reported total is otherwise kept, even when it exceeds prompt + completion.
pub fn translate_usage(u: GeminiUsageMetadata) -> Usage {
    let prompt_tokens = u.prompt_token_count;
    let reported_total = u.total_token_count;
    let completion_tokens = match u.candidates_token_count {
        0 if reported_total > prompt_tokens => reported_total - prompt_tokens,
        reported => reported,
    };
    let total_tokens = match reported_total {
        0 => prompt_tokens + completion_tokens,
        reported => reported,
    };
    Usage {
        prompt_tokens,
        completion_tokens,
        total_tokens,
        cached_tokens: u.cached_content_token_count,
        cache_creation_input_tokens: None,
    }
}
#[cfg(test)]
mod tests {
    use super::{translate_usage, GeminiUsageMetadata};

    /// Builds metadata from (prompt, candidates, total, cached) token counts.
    fn meta(prompt: u64, candidates: u64, total: u64, cached: u64) -> GeminiUsageMetadata {
        GeminiUsageMetadata {
            prompt_token_count: prompt,
            candidates_token_count: candidates,
            total_token_count: total,
            cached_content_token_count: cached,
        }
    }

    #[test]
    fn translate_usage_reconciles_partial_metadata() {
        // (input, expected (prompt, completion, total, cached))
        let cases = [
            // Missing completion count is derived as total - prompt.
            (meta(10, 0, 25, 0), (10, 15, 25, 0)),
            // Missing total is derived as prompt + completion.
            (meta(10, 5, 0, 0), (10, 5, 15, 0)),
            // Complete metadata passes through unchanged.
            (meta(10, 5, 15, 2), (10, 5, 15, 2)),
            // A reported total larger than prompt + completion is kept as-is.
            (meta(10, 5, 100, 0), (10, 5, 100, 0)),
        ];
        for (input, expected) in cases {
            let u = translate_usage(input);
            assert_eq!(
                (
                    u.prompt_tokens,
                    u.completion_tokens,
                    u.total_tokens,
                    u.cached_tokens
                ),
                expected
            );
        }
    }
}