codetether_agent/provider/
google.rs1use super::{
7 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
8 Role, StreamChunk, ToolDefinition, Usage,
9};
10use anyhow::{Context, Result};
11use async_trait::async_trait;
12use reqwest::Client;
13use serde::Deserialize;
14use serde_json::{Value, json};
15
/// Base URL of the Gemini API's OpenAI-compatible surface; request paths such
/// as `/chat/completions` are appended to it.
const GOOGLE_OPENAI_BASE: &str = concat!(
    "https://generativelanguage.googleapis.com",
    "/v1beta/openai"
);

18pub struct GoogleProvider {
19 client: Client,
20 api_key: String,
21}
22
23impl std::fmt::Debug for GoogleProvider {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 f.debug_struct("GoogleProvider")
26 .field("api_key", &"<REDACTED>")
27 .field("api_key_len", &self.api_key.len())
28 .finish()
29 }
30}
31
32impl GoogleProvider {
33 pub fn new(api_key: String) -> Result<Self> {
34 tracing::debug!(
35 provider = "google",
36 api_key_len = api_key.len(),
37 "Creating Google Gemini provider"
38 );
39 Ok(Self {
40 client: Client::new(),
41 api_key,
42 })
43 }
44
45 fn validate_api_key(&self) -> Result<()> {
46 if self.api_key.is_empty() {
47 anyhow::bail!("Google API key is empty");
48 }
49 Ok(())
50 }
51
52 fn convert_messages(messages: &[Message]) -> Vec<Value> {
53 messages
54 .iter()
55 .map(|msg| {
56 let role = match msg.role {
57 Role::System => "system",
58 Role::User => "user",
59 Role::Assistant => "assistant",
60 Role::Tool => "tool",
61 };
62
63 if msg.role == Role::Tool {
65 let mut content_parts: Vec<Value> = Vec::new();
66 let mut tool_call_id = None;
67 for part in &msg.content {
68 match part {
69 ContentPart::ToolResult {
70 tool_call_id: id,
71 content,
72 } => {
73 tool_call_id = Some(id.clone());
74 content_parts.push(json!(content));
75 }
76 ContentPart::Text { text } => {
77 content_parts.push(json!(text));
78 }
79 _ => {}
80 }
81 }
82 let content_str = content_parts
83 .iter()
84 .filter_map(|v| v.as_str())
85 .collect::<Vec<_>>()
86 .join("\n");
87 let mut m = json!({
88 "role": "tool",
89 "content": content_str,
90 });
91 if let Some(id) = tool_call_id {
92 m["tool_call_id"] = json!(id);
93 }
94 return m;
95 }
96
97 if msg.role == Role::Assistant {
99 let mut text_parts = Vec::new();
100 let mut tool_calls = Vec::new();
101 for part in &msg.content {
102 match part {
103 ContentPart::Text { text } => {
104 if !text.is_empty() {
105 text_parts.push(text.clone());
106 }
107 }
108 ContentPart::ToolCall {
109 id,
110 name,
111 arguments,
112 } => {
113 tool_calls.push(json!({
114 "id": id,
115 "type": "function",
116 "function": {
117 "name": name,
118 "arguments": arguments
119 }
120 }));
121 }
122 _ => {}
123 }
124 }
125 let content = text_parts.join("\n");
126 let mut m = json!({"role": "assistant"});
127 if !content.is_empty() || tool_calls.is_empty() {
128 m["content"] = json!(content);
129 }
130 if !tool_calls.is_empty() {
131 m["tool_calls"] = json!(tool_calls);
132 }
133 return m;
134 }
135
136 let text: String = msg
137 .content
138 .iter()
139 .filter_map(|p| match p {
140 ContentPart::Text { text } => Some(text.clone()),
141 _ => None,
142 })
143 .collect::<Vec<_>>()
144 .join("\n");
145
146 json!({
147 "role": role,
148 "content": text
149 })
150 })
151 .collect()
152 }
153
154 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
155 tools
156 .iter()
157 .map(|t| {
158 json!({
159 "type": "function",
160 "function": {
161 "name": t.name,
162 "description": t.description,
163 "parameters": t.parameters
164 }
165 })
166 })
167 .collect()
168 }
169}
170
171#[derive(Debug, Deserialize)]
174struct ChatCompletion {
175 #[allow(dead_code)]
176 id: Option<String>,
177 choices: Vec<Choice>,
178 #[serde(default)]
179 usage: Option<ApiUsage>,
180}
181
182#[derive(Debug, Deserialize)]
183struct Choice {
184 message: ChoiceMessage,
185 #[serde(default)]
186 finish_reason: Option<String>,
187}
188
189#[derive(Debug, Deserialize)]
190struct ChoiceMessage {
191 #[allow(dead_code)]
192 role: Option<String>,
193 #[serde(default)]
194 content: Option<String>,
195 #[serde(default)]
196 tool_calls: Option<Vec<ToolCall>>,
197}
198
199#[derive(Debug, Deserialize)]
200struct ToolCall {
201 id: String,
202 function: FunctionCall,
203}
204
205#[derive(Debug, Deserialize)]
206struct FunctionCall {
207 name: String,
208 arguments: String,
209}
210
211#[derive(Debug, Deserialize)]
212struct ApiUsage {
213 #[serde(default)]
214 prompt_tokens: usize,
215 #[serde(default)]
216 completion_tokens: usize,
217 #[serde(default)]
218 total_tokens: usize,
219}
220
221#[derive(Debug, Deserialize)]
222struct ApiError {
223 error: ApiErrorDetail,
224}
225
226#[derive(Debug, Deserialize)]
227struct ApiErrorDetail {
228 message: String,
229}
230
231#[async_trait]
232impl Provider for GoogleProvider {
233 fn name(&self) -> &str {
234 "google"
235 }
236
237 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
238 self.validate_api_key()?;
239
240 Ok(vec![
241 ModelInfo {
242 id: "gemini-2.5-pro".to_string(),
243 name: "Gemini 2.5 Pro".to_string(),
244 provider: "google".to_string(),
245 context_window: 1_048_576,
246 max_output_tokens: Some(65_536),
247 supports_vision: true,
248 supports_tools: true,
249 supports_streaming: true,
250 input_cost_per_million: Some(1.25),
251 output_cost_per_million: Some(10.0),
252 },
253 ModelInfo {
254 id: "gemini-2.5-flash".to_string(),
255 name: "Gemini 2.5 Flash".to_string(),
256 provider: "google".to_string(),
257 context_window: 1_048_576,
258 max_output_tokens: Some(65_536),
259 supports_vision: true,
260 supports_tools: true,
261 supports_streaming: true,
262 input_cost_per_million: Some(0.15),
263 output_cost_per_million: Some(0.60),
264 },
265 ModelInfo {
266 id: "gemini-2.0-flash".to_string(),
267 name: "Gemini 2.0 Flash".to_string(),
268 provider: "google".to_string(),
269 context_window: 1_048_576,
270 max_output_tokens: Some(8_192),
271 supports_vision: true,
272 supports_tools: true,
273 supports_streaming: true,
274 input_cost_per_million: Some(0.10),
275 output_cost_per_million: Some(0.40),
276 },
277 ])
278 }
279
280 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
281 tracing::debug!(
282 provider = "google",
283 model = %request.model,
284 message_count = request.messages.len(),
285 tool_count = request.tools.len(),
286 "Starting Google Gemini completion request"
287 );
288
289 self.validate_api_key()?;
290
291 let messages = Self::convert_messages(&request.messages);
292 let tools = Self::convert_tools(&request.tools);
293
294 let mut body = json!({
295 "model": request.model,
296 "messages": messages,
297 });
298
299 if let Some(max_tokens) = request.max_tokens {
300 body["max_tokens"] = json!(max_tokens);
301 }
302 if !tools.is_empty() {
303 body["tools"] = json!(tools);
304 }
305 if let Some(temp) = request.temperature {
306 body["temperature"] = json!(temp);
307 }
308 if let Some(top_p) = request.top_p {
309 body["top_p"] = json!(top_p);
310 }
311
312 tracing::debug!("Google Gemini request to model {}", request.model);
313
314 let url = format!(
316 "{}/chat/completions?key={}",
317 GOOGLE_OPENAI_BASE, self.api_key
318 );
319 let response = self
320 .client
321 .post(&url)
322 .header("content-type", "application/json")
323 .json(&body)
324 .send()
325 .await
326 .context("Failed to send request to Google Gemini")?;
327
328 let status = response.status();
329 let text = response
330 .text()
331 .await
332 .context("Failed to read Google Gemini response")?;
333
334 if !status.is_success() {
335 if let Ok(err) = serde_json::from_str::<ApiError>(&text) {
336 anyhow::bail!("Google Gemini API error: {}", err.error.message);
337 }
338 anyhow::bail!("Google Gemini API error: {} {}", status, text);
339 }
340
341 let completion: ChatCompletion = serde_json::from_str(&text).context(format!(
342 "Failed to parse Google Gemini response: {}",
343 &text[..text.len().min(200)]
344 ))?;
345
346 let choice = completion
347 .choices
348 .into_iter()
349 .next()
350 .context("No choices in Google Gemini response")?;
351
352 let mut content_parts = Vec::new();
353 let mut has_tool_calls = false;
354
355 if let Some(text) = choice.message.content {
356 if !text.is_empty() {
357 content_parts.push(ContentPart::Text { text });
358 }
359 }
360
361 if let Some(tool_calls) = choice.message.tool_calls {
362 has_tool_calls = !tool_calls.is_empty();
363 for tc in tool_calls {
364 content_parts.push(ContentPart::ToolCall {
365 id: tc.id,
366 name: tc.function.name,
367 arguments: tc.function.arguments,
368 });
369 }
370 }
371
372 let finish_reason = if has_tool_calls {
373 FinishReason::ToolCalls
374 } else {
375 match choice.finish_reason.as_deref() {
376 Some("stop") => FinishReason::Stop,
377 Some("length") => FinishReason::Length,
378 Some("tool_calls") => FinishReason::ToolCalls,
379 Some("content_filter") => FinishReason::ContentFilter,
380 _ => FinishReason::Stop,
381 }
382 };
383
384 let usage = completion.usage.as_ref();
385
386 Ok(CompletionResponse {
387 message: Message {
388 role: Role::Assistant,
389 content: content_parts,
390 },
391 usage: Usage {
392 prompt_tokens: usage.map(|u| u.prompt_tokens).unwrap_or(0),
393 completion_tokens: usage.map(|u| u.completion_tokens).unwrap_or(0),
394 total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
395 cache_read_tokens: None,
396 cache_write_tokens: None,
397 },
398 finish_reason,
399 })
400 }
401
402 async fn complete_stream(
403 &self,
404 request: CompletionRequest,
405 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
406 let response = self.complete(request).await?;
408 let text = response
409 .message
410 .content
411 .iter()
412 .filter_map(|p| match p {
413 ContentPart::Text { text } => Some(text.clone()),
414 _ => None,
415 })
416 .collect::<Vec<_>>()
417 .join("");
418
419 Ok(Box::pin(futures::stream::once(async move {
420 StreamChunk::Text(text)
421 })))
422 }
423}