Skip to main content

bamboo_agent/agent/llm/protocol/
gemini.rs

1//! Google Gemini protocol conversion implementation.
2//!
3//! Gemini API has a unique format:
4//! - Messages are called "contents"
5//! - Role is "user" or "model" (not "assistant")
6//! - Content is an array of "parts"
7//! - System instructions are separate from messages
8//!
9//! # Example Gemini Request
10//! ```json
11//! {
12//!   "contents": [
13//!     {
14//!       "role": "user",
15//!       "parts": [{"text": "Hello"}]
16//!     }
17//!   ],
18//!   "systemInstruction": {
19//!     "parts": [{"text": "You are helpful"}]
20//!   },
21//!   "tools": [...]
22//! }
23//! ```
24
25use crate::agent::core::tools::{FunctionCall, FunctionSchema, ToolCall, ToolSchema};
26use crate::agent::core::{Message, Role};
27use crate::agent::llm::protocol::{FromProvider, ProtocolError, ProtocolResult, ToProvider};
28use serde::{Deserialize, Serialize};
29use serde_json::Value;
30
/// Gemini protocol converter.
///
/// Unit struct without fields. The actual conversions in this module are
/// expressed as `FromProvider` / `ToProvider` implementations on the message
/// and tool types.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct GeminiProtocol;
33
34// ============================================================================
35// Gemini API Types
36// ============================================================================
37
38/// Gemini request format
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct GeminiRequest {
41    /// Conversation history
42    pub contents: Vec<GeminiContent>,
43    /// System instructions (extracted from system messages)
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub system_instruction: Option<GeminiContent>,
46    /// Available tools
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub tools: Option<Vec<GeminiTool>>,
49    /// Generation config (temperature, max_tokens, etc.)
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub generation_config: Option<Value>,
52}
53
54/// Gemini message/content format
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct GeminiContent {
57    /// "user" or "model" (not "assistant")
58    pub role: String,
59    /// Array of content parts
60    pub parts: Vec<GeminiPart>,
61}
62
63/// Gemini content part
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct GeminiPart {
66    /// Text content
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub text: Option<String>,
69    /// Function call (for model responses)
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub function_call: Option<GeminiFunctionCall>,
72    /// Function response (for user/tool messages)
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub function_response: Option<GeminiFunctionResponse>,
75}
76
77/// Gemini function call
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct GeminiFunctionCall {
80    pub name: String,
81    pub args: Value,
82}
83
84/// Gemini function response
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct GeminiFunctionResponse {
87    pub name: String,
88    pub response: Value,
89}
90
91/// Gemini tool definition
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct GeminiTool {
94    pub function_declarations: Vec<GeminiFunctionDeclaration>,
95}
96
97/// Gemini function declaration (tool schema)
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct GeminiFunctionDeclaration {
100    pub name: String,
101    #[serde(skip_serializing_if = "Option::is_none")]
102    pub description: Option<String>,
103    pub parameters: Value,
104}
105
106/// Gemini response format
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct GeminiResponse {
109    pub candidates: Vec<GeminiCandidate>,
110}
111
112/// Gemini response candidate
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct GeminiCandidate {
115    pub content: GeminiContent,
116    #[serde(skip_serializing_if = "Option::is_none")]
117    pub finish_reason: Option<String>,
118}
119
120// ============================================================================
121// Gemini → Internal (FromProvider)
122// ============================================================================
123
124impl FromProvider<GeminiContent> for Message {
125    fn from_provider(content: GeminiContent) -> ProtocolResult<Self> {
126        let role = match content.role.as_str() {
127            "user" => Role::User,
128            "model" => Role::Assistant,
129            "system" => Role::System,
130            _ => return Err(ProtocolError::InvalidRole(content.role)),
131        };
132
133        // Extract text and tool calls from parts
134        let mut text_parts = Vec::new();
135        let mut tool_calls = Vec::new();
136
137        for part in content.parts {
138            if let Some(text) = part.text {
139                text_parts.push(text);
140            }
141
142            if let Some(func_call) = part.function_call {
143                tool_calls.push(ToolCall {
144                    id: format!("gemini_{}", uuid::Uuid::new_v4()), // Gemini doesn't have IDs
145                    tool_type: "function".to_string(),
146                    function: FunctionCall {
147                        name: func_call.name,
148                        arguments: serde_json::to_string(&func_call.args).unwrap_or_default(),
149                    },
150                });
151            }
152
153            if let Some(func_response) = part.function_response {
154                // Tool response becomes a tool message
155                return Ok(Message::tool_result(
156                    format!("gemini_tool_{}", func_response.name),
157                    serde_json::to_string(&func_response.response).unwrap_or_default(),
158                ));
159            }
160        }
161
162        let content_text = text_parts.join("");
163
164        Ok(Message {
165            id: String::new(),
166            role,
167            content: content_text,
168            tool_calls: if tool_calls.is_empty() {
169                None
170            } else {
171                Some(tool_calls)
172            },
173            tool_call_id: None,
174            created_at: chrono::Utc::now(),
175        })
176    }
177}
178
179impl FromProvider<GeminiTool> for ToolSchema {
180    fn from_provider(tool: GeminiTool) -> ProtocolResult<Self> {
181        // Gemini tools can have multiple function declarations
182        // We'll convert the first one
183        let func = tool
184            .function_declarations
185            .into_iter()
186            .next()
187            .ok_or_else(|| ProtocolError::InvalidToolCall("Empty tool declarations".to_string()))?;
188
189        Ok(ToolSchema {
190            schema_type: "function".to_string(),
191            function: FunctionSchema {
192                name: func.name,
193                description: func.description.unwrap_or_default(),
194                parameters: func.parameters,
195            },
196        })
197    }
198}
199
200// ============================================================================
201// Internal → Gemini (ToProvider)
202// ============================================================================
203
/// Convert internal messages to Gemini request format.
///
/// Unit struct without fields; no methods are defined on it in this file.
/// The conversion itself is the `ToProvider<GeminiRequest>` implementation
/// for `Vec<Message>`.
///
/// Note: Gemini extracts system messages to `system_instruction` field.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct GeminiRequestBuilder;
208
209impl ToProvider<GeminiRequest> for Vec<Message> {
210    fn to_provider(&self) -> ProtocolResult<GeminiRequest> {
211        let mut system_instruction = None;
212        let mut contents = Vec::new();
213
214        for msg in self {
215            match msg.role {
216                Role::System => {
217                    // System messages become system_instruction
218                    system_instruction = Some(GeminiContent {
219                        role: "system".to_string(),
220                        parts: vec![GeminiPart {
221                            text: Some(msg.content.clone()),
222                            function_call: None,
223                            function_response: None,
224                        }],
225                    });
226                }
227                _ => {
228                    contents.push(msg.to_provider()?);
229                }
230            }
231        }
232
233        Ok(GeminiRequest {
234            contents,
235            system_instruction,
236            tools: None,
237            generation_config: None,
238        })
239    }
240}
241
242impl ToProvider<GeminiContent> for Message {
243    fn to_provider(&self) -> ProtocolResult<GeminiContent> {
244        // Handle tool messages specially
245        if self.role == Role::Tool {
246            let tool_name = self
247                .tool_call_id
248                .clone()
249                .ok_or_else(|| ProtocolError::MissingField("tool_call_id".to_string()))?;
250
251            return Ok(GeminiContent {
252                role: "user".to_string(),
253                parts: vec![GeminiPart {
254                    text: None,
255                    function_call: None,
256                    function_response: Some(GeminiFunctionResponse {
257                        name: tool_name,
258                        response: serde_json::from_str(&self.content)
259                            .unwrap_or_else(|_| Value::String(self.content.clone())),
260                    }),
261                }],
262            });
263        }
264
265        let role = match self.role {
266            Role::User => "user",
267            Role::Assistant => "model",
268            Role::System => "system",
269            Role::Tool => "user", // Already handled above, but kept for completeness
270        };
271
272        let mut parts = Vec::new();
273
274        // Add text content
275        if !self.content.is_empty() {
276            parts.push(GeminiPart {
277                text: Some(self.content.clone()),
278                function_call: None,
279                function_response: None,
280            });
281        }
282
283        // Add tool calls as function_call parts
284        if let Some(tool_calls) = &self.tool_calls {
285            for tc in tool_calls {
286                let args: Value = serde_json::from_str(&tc.function.arguments)
287                    .unwrap_or_else(|_| Value::Object(serde_json::Map::new()));
288
289                parts.push(GeminiPart {
290                    text: None,
291                    function_call: Some(GeminiFunctionCall {
292                        name: tc.function.name.clone(),
293                        args,
294                    }),
295                    function_response: None,
296                });
297            }
298        }
299
300        // Ensure at least one part
301        if parts.is_empty() {
302            parts.push(GeminiPart {
303                text: Some(String::new()),
304                function_call: None,
305                function_response: None,
306            });
307        }
308
309        Ok(GeminiContent {
310            role: role.to_string(),
311            parts,
312        })
313    }
314}
315
316impl ToProvider<GeminiTool> for ToolSchema {
317    fn to_provider(&self) -> ProtocolResult<GeminiTool> {
318        Ok(GeminiTool {
319            function_declarations: vec![GeminiFunctionDeclaration {
320                name: self.function.name.clone(),
321                description: Some(self.function.description.clone()),
322                parameters: self.function.parameters.clone(),
323            }],
324        })
325    }
326}
327
328// ============================================================================
329// Batch conversion for tools
330// ============================================================================
331
332impl ToProvider<Vec<GeminiTool>> for Vec<ToolSchema> {
333    fn to_provider(&self) -> ProtocolResult<Vec<GeminiTool>> {
334        // Gemini groups all function declarations into a single tool
335        let declarations: Vec<GeminiFunctionDeclaration> = self
336            .iter()
337            .map(|schema| GeminiFunctionDeclaration {
338                name: schema.function.name.clone(),
339                description: Some(schema.function.description.clone()),
340                parameters: schema.function.parameters.clone(),
341            })
342            .collect();
343
344        if declarations.is_empty() {
345            Ok(vec![])
346        } else {
347            Ok(vec![GeminiTool {
348                function_declarations: declarations,
349            }])
350        }
351    }
352}
353
354// ============================================================================
355// Extension trait for ergonomic conversion
356// ============================================================================
357
358/// Extension trait for Gemini conversion
359pub trait GeminiExt: Sized {
360    fn into_internal(self) -> ProtocolResult<Message>;
361    fn to_gemini(&self) -> ProtocolResult<GeminiContent>;
362}
363
364impl GeminiExt for GeminiContent {
365    fn into_internal(self) -> ProtocolResult<Message> {
366        Message::from_provider(self)
367    }
368
369    fn to_gemini(&self) -> ProtocolResult<GeminiContent> {
370        Ok(self.clone())
371    }
372}
373
374impl GeminiExt for Message {
375    fn into_internal(self) -> ProtocolResult<Message> {
376        Ok(self)
377    }
378
379    fn to_gemini(&self) -> ProtocolResult<GeminiContent> {
380        self.to_provider()
381    }
382}
383
384// ============================================================================
385// Tests
386// ============================================================================
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a Gemini content entry holding a single text part.
    fn text_content(role: &str, text: &str) -> GeminiContent {
        GeminiContent {
            role: role.to_string(),
            parts: vec![GeminiPart {
                text: Some(text.to_string()),
                function_call: None,
                function_response: None,
            }],
        }
    }

    /// Builds an internal function schema whose parameters are a bare object.
    fn object_schema(name: &str, description: &str) -> ToolSchema {
        ToolSchema {
            schema_type: "function".to_string(),
            function: FunctionSchema {
                name: name.to_string(),
                description: description.to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        }
    }

    #[test]
    fn test_gemini_to_internal_user_message() {
        let message: Message = Message::from_provider(text_content("user", "Hello")).unwrap();

        assert_eq!(message.role, Role::User);
        assert_eq!(message.content, "Hello");
        assert!(message.tool_calls.is_none());
    }

    #[test]
    fn test_internal_to_gemini_user_message() {
        let converted: GeminiContent = Message::user("Hello").to_provider().unwrap();

        assert_eq!(converted.role, "user");
        assert_eq!(converted.parts.len(), 1);
        assert_eq!(converted.parts[0].text.as_deref(), Some("Hello"));
    }

    #[test]
    fn test_gemini_to_internal_model_message() {
        let message: Message =
            Message::from_provider(text_content("model", "Hello there!")).unwrap();

        assert_eq!(message.role, Role::Assistant);
        assert_eq!(message.content, "Hello there!");
    }

    #[test]
    fn test_internal_to_gemini_with_tool_call() {
        let search_call = ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"q":"test"}"#.to_string(),
            },
        };
        let assistant = Message::assistant("Let me search", Some(vec![search_call]));

        let converted: GeminiContent = assistant.to_provider().unwrap();

        // Text comes first, followed by the function call part.
        assert_eq!(converted.role, "model");
        assert_eq!(converted.parts.len(), 2);
        assert_eq!(converted.parts[0].text.as_deref(), Some("Let me search"));
        assert!(converted.parts[1].function_call.is_some());

        let call = converted.parts[1].function_call.as_ref().unwrap();
        assert_eq!(call.name, "search");
        assert_eq!(call.args, serde_json::json!({"q": "test"}));
    }

    #[test]
    fn test_gemini_to_internal_with_tool_call() {
        let call_part = GeminiPart {
            text: None,
            function_call: Some(GeminiFunctionCall {
                name: "search".to_string(),
                args: serde_json::json!({"q": "test"}),
            }),
            function_response: None,
        };
        let content = GeminiContent {
            role: "model".to_string(),
            parts: vec![call_part],
        };

        let message: Message = Message::from_provider(content).unwrap();

        assert_eq!(message.role, Role::Assistant);
        assert!(message.tool_calls.is_some());

        let calls = message.tool_calls.unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "search");
    }

    #[test]
    fn test_system_message_extraction() {
        let conversation = vec![Message::system("You are helpful"), Message::user("Hello")];

        let request: GeminiRequest = conversation.to_provider().unwrap();

        // The system message is lifted out of the conversation history.
        assert!(request.system_instruction.is_some());
        let instruction = request.system_instruction.unwrap();
        assert_eq!(instruction.role, "system");
        assert_eq!(instruction.parts[0].text.as_deref(), Some("You are helpful"));

        assert_eq!(request.contents.len(), 1);
        assert_eq!(request.contents[0].role, "user");
    }

    #[test]
    fn test_tool_response_conversion() {
        let tool_message = Message::tool_result("search_tool", r#"{"result": "ok"}"#);

        let converted: GeminiContent = tool_message.to_provider().unwrap();

        assert_eq!(converted.role, "user");
        assert!(converted.parts[0].function_response.is_some());

        let response = converted.parts[0].function_response.as_ref().unwrap();
        assert_eq!(response.name, "search_tool");
    }

    #[test]
    fn test_tool_schema_conversion() {
        let declaration = GeminiFunctionDeclaration {
            name: "search".to_string(),
            description: Some("Search the web".to_string()),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "q": { "type": "string" }
                }
            }),
        };
        let gemini_tool = GeminiTool {
            function_declarations: vec![declaration],
        };

        // Gemini → Internal
        let schema: ToolSchema = ToolSchema::from_provider(gemini_tool).unwrap();
        assert_eq!(schema.function.name, "search");

        // Internal → Gemini
        let back: GeminiTool = schema.to_provider().unwrap();
        assert_eq!(back.function_declarations.len(), 1);
        assert_eq!(back.function_declarations[0].name, "search");
    }

    #[test]
    fn test_multiple_tools_grouped() {
        let schemas = vec![
            object_schema("search", "Search"),
            object_schema("read", "Read file"),
        ];

        let grouped: Vec<GeminiTool> = schemas.to_provider().unwrap();

        // Gemini groups all tools into one
        assert_eq!(grouped.len(), 1);
        let declarations = &grouped[0].function_declarations;
        assert_eq!(declarations.len(), 2);
        assert_eq!(declarations[0].name, "search");
        assert_eq!(declarations[1].name, "read");
    }

    #[test]
    fn test_roundtrip_conversion() {
        let original = Message::user("Hello, world!");

        // Internal → Gemini → Internal
        let outbound: GeminiContent = original.to_provider().unwrap();
        let restored: Message = Message::from_provider(outbound).unwrap();

        assert_eq!(restored.role, original.role);
        assert_eq!(restored.content, original.content);
    }

    #[test]
    fn test_invalid_role_error() {
        let outcome: ProtocolResult<Message> =
            Message::from_provider(text_content("invalid_role", "test"));

        assert!(matches!(outcome, Err(ProtocolError::InvalidRole(_))));
    }

    #[test]
    fn test_empty_parts_has_default() {
        let blank = Message::assistant("", None);

        let converted: GeminiContent = blank.to_provider().unwrap();

        // Should have at least one part with empty text
        assert_eq!(converted.parts.len(), 1);
        assert_eq!(converted.parts[0].text, Some(String::new()));
    }
}
609}