1use super::{
7 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
8 Role, StreamChunk, ToolDefinition, Usage,
9};
10use anyhow::{Context, Result};
11use async_trait::async_trait;
12use reqwest::Client;
13use serde::Deserialize;
14use serde_json::{Value, json};
15
/// Base URL of Google's OpenAI-compatible Gemini API surface.
const GOOGLE_OPENAI_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/openai";
17
/// Provider for Google Gemini models, accessed through Google's
/// OpenAI-compatible endpoint (`/v1beta/openai`).
pub struct GoogleProvider {
    client: Client,
    api_key: String,
}
22
// Manual Debug impl so the raw API key can never leak into logs or error
// output; only its length is exposed for troubleshooting.
impl std::fmt::Debug for GoogleProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GoogleProvider")
            .field("api_key", &"<REDACTED>")
            .field("api_key_len", &self.api_key.len())
            .finish()
    }
}
31
32impl GoogleProvider {
33 pub fn new(api_key: String) -> Result<Self> {
34 tracing::debug!(
35 provider = "google",
36 api_key_len = api_key.len(),
37 "Creating Google Gemini provider"
38 );
39 Ok(Self {
40 client: Client::new(),
41 api_key,
42 })
43 }
44
45 fn validate_api_key(&self) -> Result<()> {
46 if self.api_key.is_empty() {
47 anyhow::bail!("Google API key is empty");
48 }
49 Ok(())
50 }
51
52 fn convert_messages(messages: &[Message]) -> Vec<Value> {
53 messages
54 .iter()
55 .map(|msg| {
56 let role = match msg.role {
57 Role::System => "system",
58 Role::User => "user",
59 Role::Assistant => "assistant",
60 Role::Tool => "tool",
61 };
62
63 if msg.role == Role::Tool {
65 let mut content_parts: Vec<Value> = Vec::new();
66 let mut tool_call_id = None;
67 for part in &msg.content {
68 match part {
69 ContentPart::ToolResult {
70 tool_call_id: id,
71 content,
72 } => {
73 tool_call_id = Some(id.clone());
74 content_parts.push(json!(content));
75 }
76 ContentPart::Text { text } => {
77 content_parts.push(json!(text));
78 }
79 _ => {}
80 }
81 }
82 let content_str = content_parts
83 .iter()
84 .filter_map(|v| v.as_str())
85 .collect::<Vec<_>>()
86 .join("\n");
87 let mut m = json!({
88 "role": "tool",
89 "content": content_str,
90 });
91 if let Some(id) = tool_call_id {
92 m["tool_call_id"] = json!(id);
93 }
94 return m;
95 }
96
97 if msg.role == Role::Assistant {
99 let mut text_parts = Vec::new();
100 let mut tool_calls = Vec::new();
101 for part in &msg.content {
102 match part {
103 ContentPart::Text { text } => {
104 if !text.is_empty() {
105 text_parts.push(text.clone());
106 }
107 }
108 ContentPart::ToolCall {
109 id,
110 name,
111 arguments,
112 thought_signature,
113 } => {
114 let mut tc = json!({
115 "id": id,
116 "type": "function",
117 "function": {
118 "name": name,
119 "arguments": arguments
120 }
121 });
122 if let Some(sig) = thought_signature {
124 tc["extra_content"] = json!({
125 "google": {
126 "thought_signature": sig
127 }
128 });
129 }
130 tool_calls.push(tc);
131 }
132 _ => {}
133 }
134 }
135 let content = text_parts.join("\n");
136 let mut m = json!({"role": "assistant"});
137 if !content.is_empty() || tool_calls.is_empty() {
138 m["content"] = json!(content);
139 }
140 if !tool_calls.is_empty() {
141 m["tool_calls"] = json!(tool_calls);
142 }
143 return m;
144 }
145
146 let text: String = msg
147 .content
148 .iter()
149 .filter_map(|p| match p {
150 ContentPart::Text { text } => Some(text.clone()),
151 _ => None,
152 })
153 .collect::<Vec<_>>()
154 .join("\n");
155
156 json!({
157 "role": role,
158 "content": text
159 })
160 })
161 .collect()
162 }
163
164 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
165 tools
166 .iter()
167 .map(|t| {
168 json!({
169 "type": "function",
170 "function": {
171 "name": t.name,
172 "description": t.description,
173 "parameters": t.parameters
174 }
175 })
176 })
177 .collect()
178 }
179}
180
/// Minimal view of an OpenAI-style `chat/completions` response body.
#[derive(Debug, Deserialize)]
struct ChatCompletion {
    #[allow(dead_code)]
    id: Option<String>,
    choices: Vec<Choice>,
    #[serde(default)]
    usage: Option<ApiUsage>,
}
191
/// One candidate completion; only the first is consumed by `complete`.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ChoiceMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}
198
/// Assistant message payload inside a choice: optional text and/or tool calls.
#[derive(Debug, Deserialize)]
struct ChoiceMessage {
    #[allow(dead_code)]
    role: Option<String>,
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<ToolCall>>,
}
208
/// A tool invocation requested by the model, with Google's vendor extension.
#[derive(Debug, Deserialize)]
struct ToolCall {
    id: String,
    function: FunctionCall,
    // Google-specific envelope carrying the thought signature (if any).
    #[serde(default)]
    extra_content: Option<ExtraContent>,
}
217
/// Vendor-extension wrapper on a tool call (`extra_content.google`).
#[derive(Debug, Deserialize)]
struct ExtraContent {
    google: Option<GoogleExtra>,
}
222
/// Google-specific tool-call metadata; the thought signature must be echoed
/// back on subsequent turns (see `convert_messages`).
#[derive(Debug, Deserialize)]
struct GoogleExtra {
    thought_signature: Option<String>,
}
227
/// Function name plus its JSON-encoded arguments string.
#[derive(Debug, Deserialize)]
struct FunctionCall {
    name: String,
    arguments: String,
}
233
/// Token accounting block; fields default to 0 when absent.
#[derive(Debug, Deserialize)]
struct ApiUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}
243
/// Top-level error envelope returned on non-2xx responses.
#[derive(Debug, Deserialize)]
struct ApiError {
    error: ApiErrorDetail,
}
248
/// Human-readable error message inside the error envelope.
#[derive(Debug, Deserialize)]
struct ApiErrorDetail {
    message: String,
}
253
254#[async_trait]
255impl Provider for GoogleProvider {
256 fn name(&self) -> &str {
257 "google"
258 }
259
260 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
261 self.validate_api_key()?;
262
263 Ok(vec![
264 ModelInfo {
266 id: "gemini-3.1-pro-preview".to_string(),
267 name: "Gemini 3.1 Pro Preview".to_string(),
268 provider: "google".to_string(),
269 context_window: 1_048_576,
270 max_output_tokens: Some(65_536),
271 supports_vision: true,
272 supports_tools: true,
273 supports_streaming: true,
274 input_cost_per_million: Some(2.0),
275 output_cost_per_million: Some(12.0),
276 },
277 ModelInfo {
278 id: "gemini-3.1-pro-preview-customtools".to_string(),
279 name: "Gemini 3.1 Pro Preview (Custom Tools)".to_string(),
280 provider: "google".to_string(),
281 context_window: 1_048_576,
282 max_output_tokens: Some(65_536),
283 supports_vision: true,
284 supports_tools: true,
285 supports_streaming: true,
286 input_cost_per_million: Some(2.0),
287 output_cost_per_million: Some(12.0),
288 },
289 ModelInfo {
290 id: "gemini-3-pro-preview".to_string(),
291 name: "Gemini 3 Pro Preview".to_string(),
292 provider: "google".to_string(),
293 context_window: 1_048_576,
294 max_output_tokens: Some(65_536),
295 supports_vision: true,
296 supports_tools: true,
297 supports_streaming: true,
298 input_cost_per_million: Some(2.0),
299 output_cost_per_million: Some(12.0),
300 },
301 ModelInfo {
302 id: "gemini-3-flash-preview".to_string(),
303 name: "Gemini 3 Flash Preview".to_string(),
304 provider: "google".to_string(),
305 context_window: 1_048_576,
306 max_output_tokens: Some(65_536),
307 supports_vision: true,
308 supports_tools: true,
309 supports_streaming: true,
310 input_cost_per_million: Some(0.50),
311 output_cost_per_million: Some(3.0),
312 },
313 ModelInfo {
314 id: "gemini-3-pro-image-preview".to_string(),
315 name: "Gemini 3 Pro Image Preview".to_string(),
316 provider: "google".to_string(),
317 context_window: 65_536,
318 max_output_tokens: Some(32_768),
319 supports_vision: true,
320 supports_tools: false,
321 supports_streaming: false,
322 input_cost_per_million: Some(2.0),
323 output_cost_per_million: Some(134.0),
324 },
325 ModelInfo {
327 id: "gemini-2.5-pro".to_string(),
328 name: "Gemini 2.5 Pro".to_string(),
329 provider: "google".to_string(),
330 context_window: 1_048_576,
331 max_output_tokens: Some(65_536),
332 supports_vision: true,
333 supports_tools: true,
334 supports_streaming: true,
335 input_cost_per_million: Some(1.25),
336 output_cost_per_million: Some(10.0),
337 },
338 ModelInfo {
339 id: "gemini-2.5-flash".to_string(),
340 name: "Gemini 2.5 Flash".to_string(),
341 provider: "google".to_string(),
342 context_window: 1_048_576,
343 max_output_tokens: Some(65_536),
344 supports_vision: true,
345 supports_tools: true,
346 supports_streaming: true,
347 input_cost_per_million: Some(0.15),
348 output_cost_per_million: Some(0.60),
349 },
350 ModelInfo {
351 id: "gemini-2.0-flash".to_string(),
352 name: "Gemini 2.0 Flash".to_string(),
353 provider: "google".to_string(),
354 context_window: 1_048_576,
355 max_output_tokens: Some(8_192),
356 supports_vision: true,
357 supports_tools: true,
358 supports_streaming: true,
359 input_cost_per_million: Some(0.10),
360 output_cost_per_million: Some(0.40),
361 },
362 ])
363 }
364
365 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
366 tracing::debug!(
367 provider = "google",
368 model = %request.model,
369 message_count = request.messages.len(),
370 tool_count = request.tools.len(),
371 "Starting Google Gemini completion request"
372 );
373
374 self.validate_api_key()?;
375
376 let messages = Self::convert_messages(&request.messages);
377 let tools = Self::convert_tools(&request.tools);
378
379 let mut body = json!({
380 "model": request.model,
381 "messages": messages,
382 });
383
384 if let Some(max_tokens) = request.max_tokens {
385 body["max_tokens"] = json!(max_tokens);
386 }
387 if !tools.is_empty() {
388 body["tools"] = json!(tools);
389 }
390 if let Some(temp) = request.temperature {
391 body["temperature"] = json!(temp);
392 }
393 if let Some(top_p) = request.top_p {
394 body["top_p"] = json!(top_p);
395 }
396
397 tracing::debug!("Google Gemini request to model {}", request.model);
398
399 let url = format!("{}/chat/completions", GOOGLE_OPENAI_BASE);
401 let response = self
402 .client
403 .post(&url)
404 .header("content-type", "application/json")
405 .header("Authorization", format!("Bearer {}", self.api_key))
406 .json(&body)
407 .send()
408 .await
409 .context("Failed to send request to Google Gemini")?;
410
411 let status = response.status();
412 let text = response
413 .text()
414 .await
415 .context("Failed to read Google Gemini response")?;
416
417 if !status.is_success() {
418 if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
419 anyhow::bail!("Google Gemini API error: {}", err.error.message);
420 }
421 anyhow::bail!("Google Gemini API error: {} {}", status, text);
422 }
423
424 let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
425 "Failed to parse Google Gemini response: {}",
426 &text[..text.len().min(200)]
427 ))?;
428
429 let choice = completion
430 .choices
431 .into_iter()
432 .next()
433 .context("No choices in Google Gemini response")?;
434
435 let mut content_parts = Vec::new();
436 let mut has_tool_calls = false;
437
438 if let Some(text) = choice.message.content {
439 if !text.is_empty() {
440 content_parts.push(ContentPart::Text { text });
441 }
442 }
443
444 if let Some(tool_calls) = choice.message.tool_calls {
445 has_tool_calls = !tool_calls.is_empty();
446 for tc in tool_calls {
447 let thought_signature = tc
449 .extra_content
450 .as_ref()
451 .and_then(|ec| ec.google.as_ref())
452 .and_then(|g| g.thought_signature.clone());
453
454 content_parts.push(ContentPart::ToolCall {
455 id: tc.id,
456 name: tc.function.name,
457 arguments: tc.function.arguments,
458 thought_signature,
459 });
460 }
461 }
462
463 let finish_reason = if has_tool_calls {
464 FinishReason::ToolCalls
465 } else {
466 match choice.finish_reason.as_deref() {
467 Some("stop") => FinishReason::Stop,
468 Some("length") => FinishReason::Length,
469 Some("tool_calls") => FinishReason::ToolCalls,
470 Some("content_filter") => FinishReason::ContentFilter,
471 _ => FinishReason::Stop,
472 }
473 };
474
475 let usage = completion.usage.as_ref();
476
477 Ok(CompletionResponse {
478 message: Message {
479 role: Role::Assistant,
480 content: content_parts,
481 },
482 usage: Usage {
483 prompt_tokens: usage.map(|u| u.prompt_tokens).unwrap_or(0),
484 completion_tokens: usage.map(|u| u.completion_tokens).unwrap_or(0),
485 total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
486 cache_read_tokens: None,
487 cache_write_tokens: None,
488 },
489 finish_reason,
490 })
491 }
492
493 async fn complete_stream(
494 &self,
495 request: CompletionRequest,
496 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
497 let response = self.complete(request).await?;
499 let text = response
500 .message
501 .content
502 .iter()
503 .filter_map(|p| match p {
504 ContentPart::Text { text } => Some(text.clone()),
505 _ => None,
506 })
507 .collect::<Vec<_>>()
508 .join("");
509
510 Ok(Box::pin(futures::stream::once(async move {
511 StreamChunk::Text(text)
512 })))
513 }
514}