1use super::util;
7use super::{
8 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9 Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::Deserialize;
15use serde_json::{Value, json};
16
17const GOOGLE_OPENAI_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai";
18
/// Provider for Google's Gemini models, reached through the
/// OpenAI-compatible endpoint of the Generative Language API.
pub struct GoogleProvider {
    // Shared, process-wide HTTP client (cheap clone of a pooled reqwest client).
    client: Client,
    // Secret bearer token; never printed (see the manual Debug impl below).
    api_key: String,
}
23
/// Manual `Debug` so the API key can never leak into logs; only its
/// length is exposed, which is enough to spot empty or truncated keys.
impl std::fmt::Debug for GoogleProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GoogleProvider")
            .field("api_key", &"<REDACTED>")
            .field("api_key_len", &self.api_key.len())
            .finish()
    }
}
32
33impl GoogleProvider {
34 pub fn new(api_key: String) -> Result<Self> {
35 tracing::debug!(
36 provider = "google",
37 api_key_len = api_key.len(),
38 "Creating Google Gemini provider"
39 );
40 Ok(Self {
41 client: crate::provider::shared_http::shared_client().clone(),
42 api_key,
43 })
44 }
45
46 fn validate_api_key(&self) -> Result<()> {
47 if self.api_key.is_empty() {
48 anyhow::bail!("Google API key is empty");
49 }
50 Ok(())
51 }
52
53 fn convert_messages(messages: &[Message]) -> Vec<Value> {
54 messages
55 .iter()
56 .map(|msg| {
57 let role = match msg.role {
58 Role::System => "system",
59 Role::User => "user",
60 Role::Assistant => "assistant",
61 Role::Tool => "tool",
62 };
63
64 if msg.role == Role::Tool {
66 let mut content_parts: Vec<Value> = Vec::new();
67 let mut tool_call_id = None;
68 for part in &msg.content {
69 match part {
70 ContentPart::ToolResult {
71 tool_call_id: id,
72 content,
73 } => {
74 tool_call_id = Some(id.clone());
75 content_parts.push(json!(content));
76 }
77 ContentPart::Text { text } => {
78 content_parts.push(json!(text));
79 }
80 _ => {}
81 }
82 }
83 let content_str = content_parts
84 .iter()
85 .filter_map(|v| v.as_str())
86 .collect::<Vec<_>>()
87 .join("\n");
88 let mut m = json!({
89 "role": "tool",
90 "content": content_str,
91 });
92 if let Some(id) = tool_call_id {
93 m["tool_call_id"] = json!(id);
94 }
95 return m;
96 }
97
98 if msg.role == Role::Assistant {
100 let mut text_parts = Vec::new();
101 let mut tool_calls = Vec::new();
102 for part in &msg.content {
103 match part {
104 ContentPart::Text { text } => {
105 if !text.is_empty() {
106 text_parts.push(text.clone());
107 }
108 }
109 ContentPart::ToolCall {
110 id,
111 name,
112 arguments,
113 thought_signature,
114 } => {
115 let mut tc = json!({
116 "id": id,
117 "type": "function",
118 "function": {
119 "name": name,
120 "arguments": arguments
121 }
122 });
123 if let Some(sig) = thought_signature {
125 tc["extra_content"] = json!({
126 "google": {
127 "thought_signature": sig
128 }
129 });
130 }
131 tool_calls.push(tc);
132 }
133 _ => {}
134 }
135 }
136 let content = text_parts.join("\n");
137 let mut m = json!({"role": "assistant"});
138 if !content.is_empty() || tool_calls.is_empty() {
139 m["content"] = json!(content);
140 }
141 if !tool_calls.is_empty() {
142 m["tool_calls"] = json!(tool_calls);
143 }
144 return m;
145 }
146
147 let text: String = msg
148 .content
149 .iter()
150 .filter_map(|p| match p {
151 ContentPart::Text { text } => Some(text.clone()),
152 _ => None,
153 })
154 .collect::<Vec<_>>()
155 .join("\n");
156
157 json!({
158 "role": role,
159 "content": text
160 })
161 })
162 .collect()
163 }
164
165 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
166 tools
167 .iter()
168 .map(|t| {
169 json!({
170 "type": "function",
171 "function": {
172 "name": t.name,
173 "description": t.description,
174 "parameters": t.parameters
175 }
176 })
177 })
178 .collect()
179 }
180}
181
/// Top-level response body of the OpenAI-compatible `POST /chat/completions`.
#[derive(Debug, Deserialize)]
struct ChatCompletion {
    #[allow(dead_code)]
    id: Option<String>,
    // Completion candidates; only the first is consumed downstream.
    choices: Vec<Choice>,
    // Token accounting; may be absent on some responses, hence Option.
    #[serde(default)]
    usage: Option<ApiUsage>,
}
192
/// One completion candidate returned by the API.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ChoiceMessage,
    // e.g. "stop", "length", "tool_calls", "content_filter"; may be absent.
    #[serde(default)]
    finish_reason: Option<String>,
}
199
/// The assistant message inside a [`Choice`].
#[derive(Debug, Deserialize)]
struct ChoiceMessage {
    #[allow(dead_code)]
    role: Option<String>,
    // Textual content; absent/empty when the reply is only tool calls.
    #[serde(default)]
    content: Option<String>,
    // Function invocations requested by the model, if any.
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}
209
/// A single tool (function) invocation requested by the model.
#[derive(Debug, Deserialize)]
struct ToolCall {
    // Id echoed back later in the matching `role: "tool"` result message.
    id: String,
    function: FunctionCall,
    // Vendor extension wrapper carrying Gemini's thought signature.
    #[serde(default)]
    extra_content: Option<ExtraContent>,
}
218
/// `extra_content` wrapper on a tool call; only the `google` branch is used.
#[derive(Debug, Deserialize)]
struct ExtraContent {
    google: Option<GoogleExtra>,
}
223
/// Google-specific extras attached to a tool call.
#[derive(Debug, Deserialize)]
struct GoogleExtra {
    // Opaque signature that must be round-tripped on the next request.
    thought_signature: Option<String>,
}
228
/// Function name plus its arguments as a raw JSON string.
#[derive(Debug, Deserialize)]
struct FunctionCall {
    name: String,
    // JSON-encoded argument object, passed through verbatim.
    arguments: String,
}
234
/// Token-usage block of the response. Cached-token counts have appeared
/// in two shapes across API versions, so both are modeled here and
/// reconciled by [`ApiUsage::cached_input_tokens`].
#[derive(Debug, Deserialize)]
struct ApiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
    // Shape 1: cached count as a top-level field.
    #[serde(default, rename = "cached_tokens")]
    cached_tokens: Option<usize>,
    // Shape 2: cached count nested under prompt token details.
    #[serde(default, rename = "prompt_tokens_details")]
    prompt_tokens_details: Option<PromptTokenDetails>,
}
252
/// Nested `prompt_tokens_details` object; only the cached count is read.
#[derive(Debug, Deserialize, Default)]
struct PromptTokenDetails {
    #[serde(default)]
    cached_tokens: usize,
}
258
259impl ApiUsage {
260 fn cached_input_tokens(&self) -> usize {
261 self.cached_tokens.unwrap_or_else(|| {
262 self.prompt_tokens_details
263 .as_ref()
264 .map(|d| d.cached_tokens)
265 .unwrap_or(0)
266 })
267 }
268}
269
/// Error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ApiError {
    error: ApiErrorDetail,
}
274
/// Human-readable message inside an [`ApiError`].
#[derive(Debug, Deserialize)]
struct ApiErrorDetail {
    message: String,
}
279
280#[async_trait]
281impl Provider for GoogleProvider {
282 fn name(&self) -> &str {
283 "google"
284 }
285
286 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
287 self.validate_api_key()?;
288
289 Ok(vec![
290 ModelInfo {
292 id: "gemini-3.1-pro-preview".to_string(),
293 name: "Gemini 3.1 Pro Preview".to_string(),
294 provider: "google".to_string(),
295 context_window: 1_048_576,
296 max_output_tokens: Some(65_536),
297 supports_vision: true,
298 supports_tools: true,
299 supports_streaming: true,
300 input_cost_per_million: Some(2.0),
301 output_cost_per_million: Some(12.0),
302 },
303 ModelInfo {
304 id: "gemini-3.1-pro-preview-customtools".to_string(),
305 name: "Gemini 3.1 Pro Preview (Custom Tools)".to_string(),
306 provider: "google".to_string(),
307 context_window: 1_048_576,
308 max_output_tokens: Some(65_536),
309 supports_vision: true,
310 supports_tools: true,
311 supports_streaming: true,
312 input_cost_per_million: Some(2.0),
313 output_cost_per_million: Some(12.0),
314 },
315 ModelInfo {
316 id: "gemini-3-pro-preview".to_string(),
317 name: "Gemini 3 Pro Preview".to_string(),
318 provider: "google".to_string(),
319 context_window: 1_048_576,
320 max_output_tokens: Some(65_536),
321 supports_vision: true,
322 supports_tools: true,
323 supports_streaming: true,
324 input_cost_per_million: Some(2.0),
325 output_cost_per_million: Some(12.0),
326 },
327 ModelInfo {
328 id: "gemini-3-flash-preview".to_string(),
329 name: "Gemini 3 Flash Preview".to_string(),
330 provider: "google".to_string(),
331 context_window: 1_048_576,
332 max_output_tokens: Some(65_536),
333 supports_vision: true,
334 supports_tools: true,
335 supports_streaming: true,
336 input_cost_per_million: Some(0.50),
337 output_cost_per_million: Some(3.0),
338 },
339 ModelInfo {
340 id: "gemini-3-pro-image-preview".to_string(),
341 name: "Gemini 3 Pro Image Preview".to_string(),
342 provider: "google".to_string(),
343 context_window: 65_536,
344 max_output_tokens: Some(32_768),
345 supports_vision: true,
346 supports_tools: false,
347 supports_streaming: false,
348 input_cost_per_million: Some(2.0),
349 output_cost_per_million: Some(134.0),
350 },
351 ModelInfo {
353 id: "gemini-2.5-pro".to_string(),
354 name: "Gemini 2.5 Pro".to_string(),
355 provider: "google".to_string(),
356 context_window: 1_048_576,
357 max_output_tokens: Some(65_536),
358 supports_vision: true,
359 supports_tools: true,
360 supports_streaming: true,
361 input_cost_per_million: Some(1.25),
362 output_cost_per_million: Some(10.0),
363 },
364 ModelInfo {
365 id: "gemini-2.5-flash".to_string(),
366 name: "Gemini 2.5 Flash".to_string(),
367 provider: "google".to_string(),
368 context_window: 1_048_576,
369 max_output_tokens: Some(65_536),
370 supports_vision: true,
371 supports_tools: true,
372 supports_streaming: true,
373 input_cost_per_million: Some(0.15),
374 output_cost_per_million: Some(0.60),
375 },
376 ModelInfo {
377 id: "gemini-2.0-flash".to_string(),
378 name: "Gemini 2.0 Flash".to_string(),
379 provider: "google".to_string(),
380 context_window: 1_048_576,
381 max_output_tokens: Some(8_192),
382 supports_vision: true,
383 supports_tools: true,
384 supports_streaming: true,
385 input_cost_per_million: Some(0.10),
386 output_cost_per_million: Some(0.40),
387 },
388 ])
389 }
390
391 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
392 tracing::debug!(
393 provider = "google",
394 model = %request.model,
395 message_count = request.messages.len(),
396 tool_count = request.tools.len(),
397 "Starting Google Gemini completion request"
398 );
399
400 self.validate_api_key()?;
401
402 let messages = Self::convert_messages(&request.messages);
403 let tools = Self::convert_tools(&request.tools);
404
405 let mut body = json!({
406 "model": request.model,
407 "messages": messages,
408 });
409
410 if let Some(max_tokens) = request.max_tokens {
411 body["max_tokens"] = json!(max_tokens);
412 }
413 if !tools.is_empty() {
414 body["tools"] = json!(tools);
415 }
416 if let Some(temp) = request.temperature {
417 body["temperature"] = json!(temp);
418 }
419 if let Some(top_p) = request.top_p {
420 body["top_p"] = json!(top_p);
421 }
422
423 tracing::debug!("Google Gemini request to model {}", request.model);
424
425 let url = format!("{}/chat/completions", GOOGLE_OPENAI_BASE);
427 let response = self
428 .client
429 .post(&url)
430 .header("content-type", "application/json")
431 .header("Authorization", format!("Bearer {}", self.api_key))
432 .json(&body)
433 .send()
434 .await
435 .context("Failed to send request to Google Gemini")?;
436
437 let status = response.status();
438 let text = response
439 .text()
440 .await
441 .context("Failed to read Google Gemini response")?;
442
443 if !status.is_success() {
444 if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
445 anyhow::bail!("Google Gemini API error: {}", err.error.message);
446 }
447 anyhow::bail!("Google Gemini API error: {} {}", status, text);
448 }
449
450 let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
451 "Failed to parse Google Gemini response: {}",
452 util::truncate_bytes_safe(&text, 200)
453 ))?;
454
455 let choice = completion
456 .choices
457 .into_iter()
458 .next()
459 .context("No choices in Google Gemini response")?;
460
461 let mut content_parts = Vec::new();
462 let mut has_tool_calls = false;
463
464 if let Some(text) = choice.message.content
465 && !text.is_empty()
466 {
467 content_parts.push(ContentPart::Text { text });
468 }
469
470 if let Some(tool_calls) = choice.message.tool_calls {
471 has_tool_calls = !tool_calls.is_empty();
472 for tc in tool_calls {
473 let thought_signature = tc
475 .extra_content
476 .as_ref()
477 .and_then(|ec| ec.google.as_ref())
478 .and_then(|g| g.thought_signature.clone());
479
480 content_parts.push(ContentPart::ToolCall {
481 id: tc.id,
482 name: tc.function.name,
483 arguments: tc.function.arguments,
484 thought_signature,
485 });
486 }
487 }
488
489 let finish_reason = if has_tool_calls {
490 FinishReason::ToolCalls
491 } else {
492 match choice.finish_reason.as_deref() {
493 Some("stop") => FinishReason::Stop,
494 Some("length") => FinishReason::Length,
495 Some("tool_calls") => FinishReason::ToolCalls,
496 Some("content_filter") => FinishReason::ContentFilter,
497 _ => FinishReason::Stop,
498 }
499 };
500
501 let usage = completion.usage.as_ref();
502
503 Ok(CompletionResponse {
504 message: Message {
505 role: Role::Assistant,
506 content: content_parts,
507 },
508 usage: Usage {
509 prompt_tokens: usage
510 .map(|u| u.prompt_tokens.saturating_sub(u.cached_input_tokens()))
511 .unwrap_or(0),
512 completion_tokens: usage.map(|u| u.completion_tokens).unwrap_or(0),
513 total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
514 cache_read_tokens: usage.map(ApiUsage::cached_input_tokens).filter(|&n| n > 0),
515 cache_write_tokens: None,
516 },
517 finish_reason,
518 })
519 }
520
521 async fn complete_stream(
522 &self,
523 request: CompletionRequest,
524 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
525 let response = self.complete(request).await?;
527 let text = response
528 .message
529 .content
530 .iter()
531 .filter_map(|p| match p {
532 ContentPart::Text { text } => Some(text.clone()),
533 _ => None,
534 })
535 .collect::<Vec<_>>()
536 .join("");
537
538 Ok(Box::pin(futures::stream::once(async move {
539 StreamChunk::Text(text)
540 })))
541 }
542}