1use super::util;
7use super::{
8 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9 Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::Deserialize;
15use serde_json::{Value, json};
16
/// Base URL of Google's OpenAI-compatible Gemini API; completion requests are
/// sent to `{GOOGLE_OPENAI_BASE}/chat/completions`.
const GOOGLE_OPENAI_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai";

/// LLM provider for Google Gemini models, spoken to through the
/// OpenAI-compatible endpoint at [`GOOGLE_OPENAI_BASE`].
pub struct GoogleProvider {
    /// HTTP client reused for every request this provider sends.
    client: Client,
    /// Google API key, sent as a bearer token. The `Debug` impl never prints it.
    api_key: String,
}
23
24impl std::fmt::Debug for GoogleProvider {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("GoogleProvider")
27 .field("api_key", &"<REDACTED>")
28 .field("api_key_len", &self.api_key.len())
29 .finish()
30 }
31}
32
33impl GoogleProvider {
34 pub fn new(api_key: String) -> Result<Self> {
35 tracing::debug!(
36 provider = "google",
37 api_key_len = api_key.len(),
38 "Creating Google Gemini provider"
39 );
40 Ok(Self {
41 client: Client::new(),
42 api_key,
43 })
44 }
45
46 fn validate_api_key(&self) -> Result<()> {
47 if self.api_key.is_empty() {
48 anyhow::bail!("Google API key is empty");
49 }
50 Ok(())
51 }
52
53 fn convert_messages(messages: &[Message]) -> Vec<Value> {
54 messages
55 .iter()
56 .map(|msg| {
57 let role = match msg.role {
58 Role::System => "system",
59 Role::User => "user",
60 Role::Assistant => "assistant",
61 Role::Tool => "tool",
62 };
63
64 if msg.role == Role::Tool {
66 let mut content_parts: Vec<Value> = Vec::new();
67 let mut tool_call_id = None;
68 for part in &msg.content {
69 match part {
70 ContentPart::ToolResult {
71 tool_call_id: id,
72 content,
73 } => {
74 tool_call_id = Some(id.clone());
75 content_parts.push(json!(content));
76 }
77 ContentPart::Text { text } => {
78 content_parts.push(json!(text));
79 }
80 _ => {}
81 }
82 }
83 let content_str = content_parts
84 .iter()
85 .filter_map(|v| v.as_str())
86 .collect::<Vec<_>>()
87 .join("\n");
88 let mut m = json!({
89 "role": "tool",
90 "content": content_str,
91 });
92 if let Some(id) = tool_call_id {
93 m["tool_call_id"] = json!(id);
94 }
95 return m;
96 }
97
98 if msg.role == Role::Assistant {
100 let mut text_parts = Vec::new();
101 let mut tool_calls = Vec::new();
102 for part in &msg.content {
103 match part {
104 ContentPart::Text { text } => {
105 if !text.is_empty() {
106 text_parts.push(text.clone());
107 }
108 }
109 ContentPart::ToolCall {
110 id,
111 name,
112 arguments,
113 thought_signature,
114 } => {
115 let mut tc = json!({
116 "id": id,
117 "type": "function",
118 "function": {
119 "name": name,
120 "arguments": arguments
121 }
122 });
123 if let Some(sig) = thought_signature {
125 tc["extra_content"] = json!({
126 "google": {
127 "thought_signature": sig
128 }
129 });
130 }
131 tool_calls.push(tc);
132 }
133 _ => {}
134 }
135 }
136 let content = text_parts.join("\n");
137 let mut m = json!({"role": "assistant"});
138 if !content.is_empty() || tool_calls.is_empty() {
139 m["content"] = json!(content);
140 }
141 if !tool_calls.is_empty() {
142 m["tool_calls"] = json!(tool_calls);
143 }
144 return m;
145 }
146
147 let text: String = msg
148 .content
149 .iter()
150 .filter_map(|p| match p {
151 ContentPart::Text { text } => Some(text.clone()),
152 _ => None,
153 })
154 .collect::<Vec<_>>()
155 .join("\n");
156
157 json!({
158 "role": role,
159 "content": text
160 })
161 })
162 .collect()
163 }
164
165 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
166 tools
167 .iter()
168 .map(|t| {
169 json!({
170 "type": "function",
171 "function": {
172 "name": t.name,
173 "description": t.description,
174 "parameters": t.parameters
175 }
176 })
177 })
178 .collect()
179 }
180}
181
182#[derive(Debug, Deserialize)]
185struct ChatCompletion {
186 #[allow(dead_code)]
187 id: Option<String>,
188 choices: Vec<Choice>,
189 #[serde(default)]
190 usage: Option<ApiUsage>,
191}
192
193#[derive(Debug, Deserialize)]
194struct Choice {
195 message: ChoiceMessage,
196 #[serde(default)]
197 finish_reason: Option<String>,
198}
199
200#[derive(Debug, Deserialize)]
201struct ChoiceMessage {
202 #[allow(dead_code)]
203 role: Option<String>,
204 #[serde(default)]
205 content: Option<String>,
206 #[serde(default)]
207 tool_calls: Option<Vec<ToolCall>>,
208}
209
210#[derive(Debug, Deserialize)]
211struct ToolCall {
212 id: String,
213 function: FunctionCall,
214 #[serde(default)]
216 extra_content: Option<ExtraContent>,
217}
218
219#[derive(Debug, Deserialize)]
220struct ExtraContent {
221 google: Option<GoogleExtra>,
222}
223
224#[derive(Debug, Deserialize)]
225struct GoogleExtra {
226 thought_signature: Option<String>,
227}
228
229#[derive(Debug, Deserialize)]
230struct FunctionCall {
231 name: String,
232 arguments: String,
233}
234
235#[derive(Debug, Deserialize)]
236struct ApiUsage {
237 #[serde(default)]
238 prompt_tokens: usize,
239 #[serde(default)]
240 completion_tokens: usize,
241 #[serde(default)]
242 total_tokens: usize,
243}
244
/// Error body returned with a non-2xx status, shaped `{"error": {"message": ...}}`.
/// When a body does not match this shape, the caller reports the raw body instead.
#[derive(Debug, Deserialize)]
struct ApiError {
    error: ApiErrorDetail,
}

/// Inner object of [`ApiError`]; only the human-readable message is kept.
#[derive(Debug, Deserialize)]
struct ApiErrorDetail {
    message: String,
}
254
255#[async_trait]
256impl Provider for GoogleProvider {
257 fn name(&self) -> &str {
258 "google"
259 }
260
261 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
262 self.validate_api_key()?;
263
264 Ok(vec![
265 ModelInfo {
267 id: "gemini-3.1-pro-preview".to_string(),
268 name: "Gemini 3.1 Pro Preview".to_string(),
269 provider: "google".to_string(),
270 context_window: 1_048_576,
271 max_output_tokens: Some(65_536),
272 supports_vision: true,
273 supports_tools: true,
274 supports_streaming: true,
275 input_cost_per_million: Some(2.0),
276 output_cost_per_million: Some(12.0),
277 },
278 ModelInfo {
279 id: "gemini-3.1-pro-preview-customtools".to_string(),
280 name: "Gemini 3.1 Pro Preview (Custom Tools)".to_string(),
281 provider: "google".to_string(),
282 context_window: 1_048_576,
283 max_output_tokens: Some(65_536),
284 supports_vision: true,
285 supports_tools: true,
286 supports_streaming: true,
287 input_cost_per_million: Some(2.0),
288 output_cost_per_million: Some(12.0),
289 },
290 ModelInfo {
291 id: "gemini-3-pro-preview".to_string(),
292 name: "Gemini 3 Pro Preview".to_string(),
293 provider: "google".to_string(),
294 context_window: 1_048_576,
295 max_output_tokens: Some(65_536),
296 supports_vision: true,
297 supports_tools: true,
298 supports_streaming: true,
299 input_cost_per_million: Some(2.0),
300 output_cost_per_million: Some(12.0),
301 },
302 ModelInfo {
303 id: "gemini-3-flash-preview".to_string(),
304 name: "Gemini 3 Flash Preview".to_string(),
305 provider: "google".to_string(),
306 context_window: 1_048_576,
307 max_output_tokens: Some(65_536),
308 supports_vision: true,
309 supports_tools: true,
310 supports_streaming: true,
311 input_cost_per_million: Some(0.50),
312 output_cost_per_million: Some(3.0),
313 },
314 ModelInfo {
315 id: "gemini-3-pro-image-preview".to_string(),
316 name: "Gemini 3 Pro Image Preview".to_string(),
317 provider: "google".to_string(),
318 context_window: 65_536,
319 max_output_tokens: Some(32_768),
320 supports_vision: true,
321 supports_tools: false,
322 supports_streaming: false,
323 input_cost_per_million: Some(2.0),
324 output_cost_per_million: Some(134.0),
325 },
326 ModelInfo {
328 id: "gemini-2.5-pro".to_string(),
329 name: "Gemini 2.5 Pro".to_string(),
330 provider: "google".to_string(),
331 context_window: 1_048_576,
332 max_output_tokens: Some(65_536),
333 supports_vision: true,
334 supports_tools: true,
335 supports_streaming: true,
336 input_cost_per_million: Some(1.25),
337 output_cost_per_million: Some(10.0),
338 },
339 ModelInfo {
340 id: "gemini-2.5-flash".to_string(),
341 name: "Gemini 2.5 Flash".to_string(),
342 provider: "google".to_string(),
343 context_window: 1_048_576,
344 max_output_tokens: Some(65_536),
345 supports_vision: true,
346 supports_tools: true,
347 supports_streaming: true,
348 input_cost_per_million: Some(0.15),
349 output_cost_per_million: Some(0.60),
350 },
351 ModelInfo {
352 id: "gemini-2.0-flash".to_string(),
353 name: "Gemini 2.0 Flash".to_string(),
354 provider: "google".to_string(),
355 context_window: 1_048_576,
356 max_output_tokens: Some(8_192),
357 supports_vision: true,
358 supports_tools: true,
359 supports_streaming: true,
360 input_cost_per_million: Some(0.10),
361 output_cost_per_million: Some(0.40),
362 },
363 ])
364 }
365
366 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
367 tracing::debug!(
368 provider = "google",
369 model = %request.model,
370 message_count = request.messages.len(),
371 tool_count = request.tools.len(),
372 "Starting Google Gemini completion request"
373 );
374
375 self.validate_api_key()?;
376
377 let messages = Self::convert_messages(&request.messages);
378 let tools = Self::convert_tools(&request.tools);
379
380 let mut body = json!({
381 "model": request.model,
382 "messages": messages,
383 });
384
385 if let Some(max_tokens) = request.max_tokens {
386 body["max_tokens"] = json!(max_tokens);
387 }
388 if !tools.is_empty() {
389 body["tools"] = json!(tools);
390 }
391 if let Some(temp) = request.temperature {
392 body["temperature"] = json!(temp);
393 }
394 if let Some(top_p) = request.top_p {
395 body["top_p"] = json!(top_p);
396 }
397
398 tracing::debug!("Google Gemini request to model {}", request.model);
399
400 let url = format!("{}/chat/completions", GOOGLE_OPENAI_BASE);
402 let response = self
403 .client
404 .post(&url)
405 .header("content-type", "application/json")
406 .header("Authorization", format!("Bearer {}", self.api_key))
407 .json(&body)
408 .send()
409 .await
410 .context("Failed to send request to Google Gemini")?;
411
412 let status = response.status();
413 let text = response
414 .text()
415 .await
416 .context("Failed to read Google Gemini response")?;
417
418 if !status.is_success() {
419 if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
420 anyhow::bail!("Google Gemini API error: {}", err.error.message);
421 }
422 anyhow::bail!("Google Gemini API error: {} {}", status, text);
423 }
424
425 let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
426 "Failed to parse Google Gemini response: {}",
427 util::truncate_bytes_safe(&text, 200)
428 ))?;
429
430 let choice = completion
431 .choices
432 .into_iter()
433 .next()
434 .context("No choices in Google Gemini response")?;
435
436 let mut content_parts = Vec::new();
437 let mut has_tool_calls = false;
438
439 if let Some(text) = choice.message.content
440 && !text.is_empty()
441 {
442 content_parts.push(ContentPart::Text { text });
443 }
444
445 if let Some(tool_calls) = choice.message.tool_calls {
446 has_tool_calls = !tool_calls.is_empty();
447 for tc in tool_calls {
448 let thought_signature = tc
450 .extra_content
451 .as_ref()
452 .and_then(|ec| ec.google.as_ref())
453 .and_then(|g| g.thought_signature.clone());
454
455 content_parts.push(ContentPart::ToolCall {
456 id: tc.id,
457 name: tc.function.name,
458 arguments: tc.function.arguments,
459 thought_signature,
460 });
461 }
462 }
463
464 let finish_reason = if has_tool_calls {
465 FinishReason::ToolCalls
466 } else {
467 match choice.finish_reason.as_deref() {
468 Some("stop") => FinishReason::Stop,
469 Some("length") => FinishReason::Length,
470 Some("tool_calls") => FinishReason::ToolCalls,
471 Some("content_filter") => FinishReason::ContentFilter,
472 _ => FinishReason::Stop,
473 }
474 };
475
476 let usage = completion.usage.as_ref();
477
478 Ok(CompletionResponse {
479 message: Message {
480 role: Role::Assistant,
481 content: content_parts,
482 },
483 usage: Usage {
484 prompt_tokens: usage.map(|u| u.prompt_tokens).unwrap_or(0),
485 completion_tokens: usage.map(|u| u.completion_tokens).unwrap_or(0),
486 total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
487 cache_read_tokens: None,
488 cache_write_tokens: None,
489 },
490 finish_reason,
491 })
492 }
493
494 async fn complete_stream(
495 &self,
496 request: CompletionRequest,
497 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
498 let response = self.complete(request).await?;
500 let text = response
501 .message
502 .content
503 .iter()
504 .filter_map(|p| match p {
505 ContentPart::Text { text } => Some(text.clone()),
506 _ => None,
507 })
508 .collect::<Vec<_>>()
509 .join("");
510
511 Ok(Box::pin(futures::stream::once(async move {
512 StreamChunk::Text(text)
513 })))
514 }
515}