1use crate::agent::core::tools::{FunctionCall, FunctionSchema, ToolCall, ToolSchema};
26use crate::agent::core::{Message, Role};
27use crate::agent::llm::protocol::{FromProvider, ProtocolError, ProtocolResult, ToProvider};
28use serde::{Deserialize, Serialize};
29use serde_json::Value;
30
/// Unit type naming the Gemini wire protocol.
///
/// NOTE(review): no impls for this type are visible in this section; presumably it is used
/// elsewhere to select the Gemini conversions defined below — confirm against callers.
pub struct GeminiProtocol;
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct GeminiRequest {
41 pub contents: Vec<GeminiContent>,
43 #[serde(skip_serializing_if = "Option::is_none")]
45 pub system_instruction: Option<GeminiContent>,
46 #[serde(skip_serializing_if = "Option::is_none")]
48 pub tools: Option<Vec<GeminiTool>>,
49 #[serde(skip_serializing_if = "Option::is_none")]
51 pub generation_config: Option<Value>,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct GeminiContent {
57 pub role: String,
59 pub parts: Vec<GeminiPart>,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct GeminiPart {
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub text: Option<String>,
69 #[serde(skip_serializing_if = "Option::is_none")]
71 pub function_call: Option<GeminiFunctionCall>,
72 #[serde(skip_serializing_if = "Option::is_none")]
74 pub function_response: Option<GeminiFunctionResponse>,
75}
76
/// A function call emitted by the model inside a [`GeminiPart`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiFunctionCall {
    /// Name of the function the model wants to invoke.
    pub name: String,
    /// Call arguments as a raw JSON value.
    pub args: Value,
}
83
/// The result of a function call, sent back to the model inside a [`GeminiPart`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiFunctionResponse {
    /// Function name this response belongs to. When converting from a [`Message`],
    /// it is taken from the message's `tool_call_id`.
    pub name: String,
    /// Function output as a raw JSON value.
    pub response: Value,
}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct GeminiTool {
94 pub function_declarations: Vec<GeminiFunctionDeclaration>,
95}
96
/// Declaration of a callable function advertised to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiFunctionDeclaration {
    /// Function name.
    pub name: String,
    /// Optional human-readable description; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter schema as raw JSON, copied verbatim to and from `FunctionSchema::parameters`.
    pub parameters: Value,
}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct GeminiResponse {
109 pub candidates: Vec<GeminiCandidate>,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct GeminiCandidate {
115 pub content: GeminiContent,
116 #[serde(skip_serializing_if = "Option::is_none")]
117 pub finish_reason: Option<String>,
118}
119
120impl FromProvider<GeminiContent> for Message {
125 fn from_provider(content: GeminiContent) -> ProtocolResult<Self> {
126 let role = match content.role.as_str() {
127 "user" => Role::User,
128 "model" => Role::Assistant,
129 "system" => Role::System,
130 _ => return Err(ProtocolError::InvalidRole(content.role)),
131 };
132
133 let mut text_parts = Vec::new();
135 let mut tool_calls = Vec::new();
136
137 for part in content.parts {
138 if let Some(text) = part.text {
139 text_parts.push(text);
140 }
141
142 if let Some(func_call) = part.function_call {
143 tool_calls.push(ToolCall {
144 id: format!("gemini_{}", uuid::Uuid::new_v4()), tool_type: "function".to_string(),
146 function: FunctionCall {
147 name: func_call.name,
148 arguments: serde_json::to_string(&func_call.args).unwrap_or_default(),
149 },
150 });
151 }
152
153 if let Some(func_response) = part.function_response {
154 return Ok(Message::tool_result(
156 format!("gemini_tool_{}", func_response.name),
157 serde_json::to_string(&func_response.response).unwrap_or_default(),
158 ));
159 }
160 }
161
162 let content_text = text_parts.join("");
163
164 Ok(Message {
165 id: String::new(),
166 role,
167 content: content_text,
168 tool_calls: if tool_calls.is_empty() {
169 None
170 } else {
171 Some(tool_calls)
172 },
173 tool_call_id: None,
174 created_at: chrono::Utc::now(),
175 })
176 }
177}
178
179impl FromProvider<GeminiTool> for ToolSchema {
180 fn from_provider(tool: GeminiTool) -> ProtocolResult<Self> {
181 let func = tool
184 .function_declarations
185 .into_iter()
186 .next()
187 .ok_or_else(|| ProtocolError::InvalidToolCall("Empty tool declarations".to_string()))?;
188
189 Ok(ToolSchema {
190 schema_type: "function".to_string(),
191 function: FunctionSchema {
192 name: func.name,
193 description: func.description.unwrap_or_default(),
194 parameters: func.parameters,
195 },
196 })
197 }
198}
199
/// Unit placeholder type for building Gemini requests.
///
/// NOTE(review): no impls are visible in this section. Requests here are built via
/// `ToProvider<GeminiRequest> for Vec<Message>`. Confirm whether this type is used elsewhere
/// or can be removed.
pub struct GeminiRequestBuilder;
208
209impl ToProvider<GeminiRequest> for Vec<Message> {
210 fn to_provider(&self) -> ProtocolResult<GeminiRequest> {
211 let mut system_instruction = None;
212 let mut contents = Vec::new();
213
214 for msg in self {
215 match msg.role {
216 Role::System => {
217 system_instruction = Some(GeminiContent {
219 role: "system".to_string(),
220 parts: vec![GeminiPart {
221 text: Some(msg.content.clone()),
222 function_call: None,
223 function_response: None,
224 }],
225 });
226 }
227 _ => {
228 contents.push(msg.to_provider()?);
229 }
230 }
231 }
232
233 Ok(GeminiRequest {
234 contents,
235 system_instruction,
236 tools: None,
237 generation_config: None,
238 })
239 }
240}
241
242impl ToProvider<GeminiContent> for Message {
243 fn to_provider(&self) -> ProtocolResult<GeminiContent> {
244 if self.role == Role::Tool {
246 let tool_name = self
247 .tool_call_id
248 .clone()
249 .ok_or_else(|| ProtocolError::MissingField("tool_call_id".to_string()))?;
250
251 return Ok(GeminiContent {
252 role: "user".to_string(),
253 parts: vec![GeminiPart {
254 text: None,
255 function_call: None,
256 function_response: Some(GeminiFunctionResponse {
257 name: tool_name,
258 response: serde_json::from_str(&self.content)
259 .unwrap_or_else(|_| Value::String(self.content.clone())),
260 }),
261 }],
262 });
263 }
264
265 let role = match self.role {
266 Role::User => "user",
267 Role::Assistant => "model",
268 Role::System => "system",
269 Role::Tool => "user", };
271
272 let mut parts = Vec::new();
273
274 if !self.content.is_empty() {
276 parts.push(GeminiPart {
277 text: Some(self.content.clone()),
278 function_call: None,
279 function_response: None,
280 });
281 }
282
283 if let Some(tool_calls) = &self.tool_calls {
285 for tc in tool_calls {
286 let args: Value = serde_json::from_str(&tc.function.arguments)
287 .unwrap_or_else(|_| Value::Object(serde_json::Map::new()));
288
289 parts.push(GeminiPart {
290 text: None,
291 function_call: Some(GeminiFunctionCall {
292 name: tc.function.name.clone(),
293 args,
294 }),
295 function_response: None,
296 });
297 }
298 }
299
300 if parts.is_empty() {
302 parts.push(GeminiPart {
303 text: Some(String::new()),
304 function_call: None,
305 function_response: None,
306 });
307 }
308
309 Ok(GeminiContent {
310 role: role.to_string(),
311 parts,
312 })
313 }
314}
315
316impl ToProvider<GeminiTool> for ToolSchema {
317 fn to_provider(&self) -> ProtocolResult<GeminiTool> {
318 Ok(GeminiTool {
319 function_declarations: vec![GeminiFunctionDeclaration {
320 name: self.function.name.clone(),
321 description: Some(self.function.description.clone()),
322 parameters: self.function.parameters.clone(),
323 }],
324 })
325 }
326}
327
328impl ToProvider<Vec<GeminiTool>> for Vec<ToolSchema> {
333 fn to_provider(&self) -> ProtocolResult<Vec<GeminiTool>> {
334 let declarations: Vec<GeminiFunctionDeclaration> = self
336 .iter()
337 .map(|schema| GeminiFunctionDeclaration {
338 name: schema.function.name.clone(),
339 description: Some(schema.function.description.clone()),
340 parameters: schema.function.parameters.clone(),
341 })
342 .collect();
343
344 if declarations.is_empty() {
345 Ok(vec![])
346 } else {
347 Ok(vec![GeminiTool {
348 function_declarations: declarations,
349 }])
350 }
351 }
352}
353
/// Convenience conversions between Gemini content and internal [`Message`]s.
///
/// Implemented for both [`GeminiContent`] and [`Message`], so either side can be converted
/// to the other with the same method names.
pub trait GeminiExt: Sized {
    /// Converts `self` into an internal [`Message`], consuming it.
    fn into_internal(self) -> ProtocolResult<Message>;
    /// Converts `self` into a [`GeminiContent`].
    fn to_gemini(&self) -> ProtocolResult<GeminiContent>;
}
363
364impl GeminiExt for GeminiContent {
365 fn into_internal(self) -> ProtocolResult<Message> {
366 Message::from_provider(self)
367 }
368
369 fn to_gemini(&self) -> ProtocolResult<GeminiContent> {
370 Ok(self.clone())
371 }
372}
373
374impl GeminiExt for Message {
375 fn into_internal(self) -> ProtocolResult<Message> {
376 Ok(self)
377 }
378
379 fn to_gemini(&self) -> ProtocolResult<GeminiContent> {
380 self.to_provider()
381 }
382}
383
#[cfg(test)]
mod tests {
    //! Unit tests for the Gemini <-> internal message and tool-schema conversions.
    use super::*;

    // A Gemini "user" text turn maps to an internal User message with no tool calls.
    #[test]
    fn test_gemini_to_internal_user_message() {
        let gemini = GeminiContent {
            role: "user".to_string(),
            parts: vec![GeminiPart {
                text: Some("Hello".to_string()),
                function_call: None,
                function_response: None,
            }],
        };

        let internal: Message = Message::from_provider(gemini).unwrap();

        assert_eq!(internal.role, Role::User);
        assert_eq!(internal.content, "Hello");
        assert!(internal.tool_calls.is_none());
    }

    // An internal User message maps to a "user" turn with a single text part.
    #[test]
    fn test_internal_to_gemini_user_message() {
        let internal = Message::user("Hello");

        let gemini: GeminiContent = internal.to_provider().unwrap();

        assert_eq!(gemini.role, "user");
        assert_eq!(gemini.parts.len(), 1);
        assert_eq!(gemini.parts[0].text, Some("Hello".to_string()));
    }

    // Gemini's "model" role maps to the internal Assistant role.
    #[test]
    fn test_gemini_to_internal_model_message() {
        let gemini = GeminiContent {
            role: "model".to_string(),
            parts: vec![GeminiPart {
                text: Some("Hello there!".to_string()),
                function_call: None,
                function_response: None,
            }],
        };

        let internal: Message = Message::from_provider(gemini).unwrap();

        assert_eq!(internal.role, Role::Assistant);
        assert_eq!(internal.content, "Hello there!");
    }

    // Assistant text plus a tool call becomes a text part followed by a functionCall part,
    // with the JSON-string arguments parsed back into a JSON value.
    #[test]
    fn test_internal_to_gemini_with_tool_call() {
        let tool_call = ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"q":"test"}"#.to_string(),
            },
        };

        let internal = Message::assistant("Let me search", Some(vec![tool_call]));

        let gemini: GeminiContent = internal.to_provider().unwrap();

        assert_eq!(gemini.role, "model");
        assert_eq!(gemini.parts.len(), 2);
        assert_eq!(gemini.parts[0].text, Some("Let me search".to_string()));
        assert!(gemini.parts[1].function_call.is_some());

        let func_call = gemini.parts[1].function_call.as_ref().unwrap();
        assert_eq!(func_call.name, "search");
        assert_eq!(func_call.args, serde_json::json!({"q": "test"}));
    }

    // A Gemini functionCall part becomes an internal tool call carrying the function name.
    #[test]
    fn test_gemini_to_internal_with_tool_call() {
        let gemini = GeminiContent {
            role: "model".to_string(),
            parts: vec![GeminiPart {
                text: None,
                function_call: Some(GeminiFunctionCall {
                    name: "search".to_string(),
                    args: serde_json::json!({"q": "test"}),
                }),
                function_response: None,
            }],
        };

        let internal: Message = Message::from_provider(gemini).unwrap();

        assert_eq!(internal.role, Role::Assistant);
        assert!(internal.tool_calls.is_some());

        let tool_calls = internal.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "search");
    }

    // System messages go to `system_instruction`, not `contents`.
    #[test]
    fn test_system_message_extraction() {
        let messages = vec![Message::system("You are helpful"), Message::user("Hello")];

        let request: GeminiRequest = messages.to_provider().unwrap();

        assert!(request.system_instruction.is_some());
        let sys = request.system_instruction.unwrap();
        assert_eq!(sys.role, "system");
        assert_eq!(sys.parts[0].text, Some("You are helpful".to_string()));

        assert_eq!(request.contents.len(), 1);
        assert_eq!(request.contents[0].role, "user");
    }

    // A tool-result message becomes a "user" turn with a functionResponse part whose name
    // is the message's tool_call_id.
    #[test]
    fn test_tool_response_conversion() {
        let internal = Message::tool_result("search_tool", r#"{"result": "ok"}"#);

        let gemini: GeminiContent = internal.to_provider().unwrap();

        assert_eq!(gemini.role, "user");
        assert!(gemini.parts[0].function_response.is_some());

        let func_resp = gemini.parts[0].function_response.as_ref().unwrap();
        assert_eq!(func_resp.name, "search_tool");
    }

    // A Gemini tool round-trips through ToolSchema and keeps its single declaration.
    #[test]
    fn test_tool_schema_conversion() {
        let gemini_tool = GeminiTool {
            function_declarations: vec![GeminiFunctionDeclaration {
                name: "search".to_string(),
                description: Some("Search the web".to_string()),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "q": { "type": "string" }
                    }
                }),
            }],
        };

        let internal_schema: ToolSchema = ToolSchema::from_provider(gemini_tool.clone()).unwrap();
        assert_eq!(internal_schema.function.name, "search");

        let roundtrip: GeminiTool = internal_schema.to_provider().unwrap();
        assert_eq!(roundtrip.function_declarations.len(), 1);
        assert_eq!(roundtrip.function_declarations[0].name, "search");
    }

    // Several schemas are grouped into ONE Gemini tool, preserving declaration order.
    #[test]
    fn test_multiple_tools_grouped() {
        let tools = vec![
            ToolSchema {
                schema_type: "function".to_string(),
                function: FunctionSchema {
                    name: "search".to_string(),
                    description: "Search".to_string(),
                    parameters: serde_json::json!({"type": "object"}),
                },
            },
            ToolSchema {
                schema_type: "function".to_string(),
                function: FunctionSchema {
                    name: "read".to_string(),
                    description: "Read file".to_string(),
                    parameters: serde_json::json!({"type": "object"}),
                },
            },
        ];

        let gemini_tools: Vec<GeminiTool> = tools.to_provider().unwrap();

        assert_eq!(gemini_tools.len(), 1);
        assert_eq!(gemini_tools[0].function_declarations.len(), 2);
        assert_eq!(gemini_tools[0].function_declarations[0].name, "search");
        assert_eq!(gemini_tools[0].function_declarations[1].name, "read");
    }

    // Internal -> Gemini -> internal preserves role and text content.
    #[test]
    fn test_roundtrip_conversion() {
        let original = Message::user("Hello, world!");

        let gemini: GeminiContent = original.to_provider().unwrap();

        let roundtrip: Message = Message::from_provider(gemini).unwrap();

        assert_eq!(roundtrip.role, original.role);
        assert_eq!(roundtrip.content, original.content);
    }

    // An unknown Gemini role is rejected with ProtocolError::InvalidRole.
    #[test]
    fn test_invalid_role_error() {
        let gemini = GeminiContent {
            role: "invalid_role".to_string(),
            parts: vec![GeminiPart {
                text: Some("test".to_string()),
                function_call: None,
                function_response: None,
            }],
        };

        let result: ProtocolResult<Message> = Message::from_provider(gemini);
        assert!(matches!(result, Err(ProtocolError::InvalidRole(_))));
    }

    // A message with no text and no tool calls still yields one (empty) text part.
    #[test]
    fn test_empty_parts_has_default() {
        let internal = Message::assistant("", None);

        let gemini: GeminiContent = internal.to_provider().unwrap();

        assert_eq!(gemini.parts.len(), 1);
        assert_eq!(gemini.parts[0].text, Some(String::new()));
    }
}