1use crate::attachment;
2use crate::retry::{RetryConfig, execute_with_retry, is_retryable_model_error};
3use adk_core::{
4 CacheCapable, CitationMetadata, CitationSource, Content, FinishReason, Llm, LlmRequest,
5 LlmResponse, LlmResponseStream, Part, Result, UsageMetadata,
6};
7use adk_gemini::Gemini;
8use async_trait::async_trait;
9use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
10use futures::TryStreamExt;
11
/// [`Llm`] implementation backed by the Google Gemini API through the
/// `adk_gemini` client crate.
pub struct GeminiModel {
    // Underlying Gemini HTTP client used for all requests.
    client: Gemini,
    // Model identifier supplied at construction; reported back via `Llm::name`.
    model_name: String,
    // Retry policy applied around `generate_content` calls.
    retry_config: RetryConfig,
}
17
18impl GeminiModel {
    /// Creates a model that talks to the public Gemini API using an API key.
    ///
    /// The default [`RetryConfig`] is applied; use
    /// [`with_retry_config`](Self::with_retry_config) to customize it.
    ///
    /// # Errors
    ///
    /// Returns [`adk_core::AdkError::Model`] if the underlying client cannot
    /// be constructed.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
        let model_name = model.into();
        // Keep a copy of the name so `Llm::name` can report it later.
        let client = Gemini::with_model(api_key.into(), model_name.clone())
            .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;

        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
    }
26
    /// Creates a model that talks to Gemini on Google Cloud (Vertex AI) using
    /// an API key, scoped to the given project and location.
    ///
    /// # Errors
    ///
    /// Returns [`adk_core::AdkError::Model`] if the underlying client cannot
    /// be constructed.
    #[cfg(feature = "gemini-vertex")]
    pub fn new_google_cloud(
        api_key: impl Into<String>,
        project_id: impl AsRef<str>,
        location: impl AsRef<str>,
        model: impl Into<String>,
    ) -> Result<Self> {
        let model_name = model.into();
        let client = Gemini::with_google_cloud_model(
            api_key.into(),
            project_id,
            location,
            model_name.clone(),
        )
        .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;

        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
    }
48
    /// Creates a model that authenticates to Vertex AI with a service-account
    /// key provided as a JSON string.
    ///
    /// # Errors
    ///
    /// Returns [`adk_core::AdkError::Model`] if the JSON is invalid or the
    /// client cannot be constructed.
    #[cfg(feature = "gemini-vertex")]
    pub fn new_google_cloud_service_account(
        service_account_json: &str,
        project_id: impl AsRef<str>,
        location: impl AsRef<str>,
        model: impl Into<String>,
    ) -> Result<Self> {
        let model_name = model.into();
        let client = Gemini::with_google_cloud_service_account_json(
            service_account_json,
            project_id.as_ref(),
            location.as_ref(),
            model_name.clone(),
        )
        .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;

        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
    }
70
    /// Creates a model that authenticates to Vertex AI using Application
    /// Default Credentials (ADC) from the environment.
    ///
    /// # Errors
    ///
    /// Returns [`adk_core::AdkError::Model`] if credentials cannot be resolved
    /// or the client cannot be constructed.
    #[cfg(feature = "gemini-vertex")]
    pub fn new_google_cloud_adc(
        project_id: impl AsRef<str>,
        location: impl AsRef<str>,
        model: impl Into<String>,
    ) -> Result<Self> {
        let model_name = model.into();
        let client = Gemini::with_google_cloud_adc_model(
            project_id.as_ref(),
            location.as_ref(),
            model_name.clone(),
        )
        .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;

        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
    }
90
    /// Creates a model that authenticates to Vertex AI via Workload Identity
    /// Federation, configured from a WIF credential JSON string.
    ///
    /// # Errors
    ///
    /// Returns [`adk_core::AdkError::Model`] if the JSON is invalid or the
    /// client cannot be constructed.
    #[cfg(feature = "gemini-vertex")]
    pub fn new_google_cloud_wif(
        wif_json: &str,
        project_id: impl AsRef<str>,
        location: impl AsRef<str>,
        model: impl Into<String>,
    ) -> Result<Self> {
        let model_name = model.into();
        let client = Gemini::with_google_cloud_wif_json(
            wif_json,
            project_id.as_ref(),
            location.as_ref(),
            model_name.clone(),
        )
        .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;

        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
    }
112
    /// Builder-style setter: returns `self` with the retry policy replaced.
    #[must_use]
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }
118
    /// Replaces the retry policy in place.
    pub fn set_retry_config(&mut self, retry_config: RetryConfig) {
        self.retry_config = retry_config;
    }
122
    /// Returns the retry policy currently in effect.
    pub fn retry_config(&self) -> &RetryConfig {
        &self.retry_config
    }
126
    /// Maps a raw `adk_gemini` generation response onto the ADK-neutral
    /// [`LlmResponse`] type.
    ///
    /// Only the **first** candidate is inspected. Its parts, grounding
    /// metadata, usage counts, finish reason, and citation metadata are
    /// converted; grounding (web search) metadata is surfaced as extra text
    /// parts appended after the model's own parts.
    ///
    /// # Errors
    ///
    /// Returns [`adk_core::AdkError::Model`] if inline data in the response is
    /// not valid base64.
    fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
        let mut converted_parts: Vec<Part> = Vec::new();

        if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
            for p in parts {
                match p {
                    adk_gemini::Part::Text { text, thought, thought_signature } => {
                        // Gemini flags "thinking" output with `thought: Some(true)`;
                        // everything else is plain text.
                        if thought == &Some(true) {
                            converted_parts.push(Part::Thinking {
                                thinking: text.clone(),
                                signature: thought_signature.clone(),
                            });
                        } else {
                            converted_parts.push(Part::Text { text: text.clone() });
                        }
                    }
                    adk_gemini::Part::InlineData { inline_data } => {
                        // Inline data arrives base64-encoded; decode to raw bytes
                        // for the ADK part representation.
                        let decoded =
                            BASE64_STANDARD.decode(&inline_data.data).map_err(|error| {
                                adk_core::AdkError::Model(format!(
                                    "failed to decode inline data from gemini response: {error}"
                                ))
                            })?;
                        converted_parts.push(Part::InlineData {
                            mime_type: inline_data.mime_type.clone(),
                            data: decoded,
                        });
                    }
                    adk_gemini::Part::FunctionCall { function_call, thought_signature } => {
                        // ADK tracks its own call ids; Gemini provides none, so `id: None`.
                        converted_parts.push(Part::FunctionCall {
                            name: function_call.name.clone(),
                            args: function_call.args.clone(),
                            id: None,
                            thought_signature: thought_signature.clone(),
                        });
                    }
                    adk_gemini::Part::FunctionResponse { function_response } => {
                        converted_parts.push(Part::FunctionResponse {
                            function_response: adk_core::FunctionResponseData {
                                name: function_response.name.clone(),
                                // A missing payload becomes JSON null rather than an error.
                                response: function_response
                                    .response
                                    .clone()
                                    .unwrap_or(serde_json::Value::Null),
                            },
                            id: None,
                        });
                    }
                }
            }
        }

        // Append grounding (web search) info as human-readable text parts so
        // consumers that only render text still see what was searched/cited.
        if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
        {
            if let Some(queries) = &grounding.web_search_queries {
                if !queries.is_empty() {
                    // NOTE(review): the "š" below looks like a mojibake'd glyph
                    // (likely an emoji lost in an encoding round-trip) — confirm
                    // the intended character before changing this runtime string.
                    let search_info = format!("\n\nš **Searched:** {}", queries.join(", "));
                    converted_parts.push(Part::Text { text: search_info });
                }
            }
            if let Some(chunks) = &grounding.grounding_chunks {
                // Render each web chunk as a markdown link when both title and
                // URI are present, otherwise fall back to whichever exists.
                let sources: Vec<String> = chunks
                    .iter()
                    .filter_map(|c| {
                        c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
                            (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
                            (Some(title), None) => Some(title.clone()),
                            (None, Some(uri)) => Some(uri.to_string()),
                            (None, None) => None,
                        })
                    })
                    .collect();
                if !sources.is_empty() {
                    // NOTE(review): same suspected mojibake as above.
                    let sources_info = format!("\nš **Sources:** {}", sources.join(" | "));
                    converted_parts.push(Part::Text { text: sources_info });
                }
            }
        }

        // An empty part list maps to `content: None` rather than an empty Content.
        let content = if converted_parts.is_empty() {
            None
        } else {
            Some(Content { role: "model".to_string(), parts: converted_parts })
        };

        // Missing counts default to 0; optional counters pass through as Options.
        let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
            prompt_token_count: u.prompt_token_count.unwrap_or(0),
            candidates_token_count: u.candidates_token_count.unwrap_or(0),
            total_token_count: u.total_token_count.unwrap_or(0),
            thinking_token_count: u.thoughts_token_count,
            cache_read_input_token_count: u.cached_content_token_count,
            ..Default::default()
        });

        // Any finish reason not explicitly mapped collapses to `Other`.
        let finish_reason =
            resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
                adk_gemini::FinishReason::Stop => FinishReason::Stop,
                adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
                adk_gemini::FinishReason::Safety => FinishReason::Safety,
                adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
                _ => FinishReason::Other,
            });

        // Citation sources are converted field-by-field; the publication date
        // is stringified via its Display impl.
        let citation_metadata =
            resp.candidates.first().and_then(|c| c.citation_metadata.as_ref()).map(|meta| {
                CitationMetadata {
                    citation_sources: meta
                        .citation_sources
                        .iter()
                        .map(|source| CitationSource {
                            uri: source.uri.clone(),
                            title: source.title.clone(),
                            start_index: source.start_index,
                            end_index: source.end_index,
                            license: source.license.clone(),
                            publication_date: source.publication_date.map(|d| d.to_string()),
                        })
                        .collect(),
                }
            });

        // A converted response is always a complete, non-partial turn; streaming
        // adjustments happen later in `stream_chunks_from_response`.
        Ok(LlmResponse {
            content,
            usage_metadata,
            finish_reason,
            citation_metadata,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        })
    }
262
263 fn gemini_function_response_payload(response: serde_json::Value) -> serde_json::Value {
264 match response {
265 serde_json::Value::Object(_) => response,
267 other => serde_json::json!({ "result": other }),
268 }
269 }
270
271 fn stream_chunks_from_response(
272 mut response: LlmResponse,
273 saw_partial_chunk: bool,
274 ) -> (Vec<LlmResponse>, bool) {
275 let is_final = response.finish_reason.is_some();
276
277 if !is_final {
278 response.partial = true;
279 response.turn_complete = false;
280 return (vec![response], true);
281 }
282
283 response.partial = false;
284 response.turn_complete = true;
285
286 if saw_partial_chunk {
287 return (vec![response], true);
288 }
289
290 let synthetic_partial = LlmResponse {
291 content: None,
292 usage_metadata: None,
293 finish_reason: None,
294 citation_metadata: None,
295 partial: true,
296 turn_complete: false,
297 interrupted: false,
298 error_code: None,
299 error_message: None,
300 };
301
302 (vec![synthetic_partial, response], true)
303 }
304
    /// Builds and executes a single Gemini request (one retry attempt).
    ///
    /// Converts ADK `contents` by role ("user" / "model" / "function") into
    /// Gemini messages, applies generation config, cached content, and tools,
    /// then executes either a streaming or a blocking call. Both paths return
    /// an [`LlmResponseStream`]; the blocking path wraps its single response
    /// in a one-item stream.
    ///
    /// # Errors
    ///
    /// Returns [`adk_core::AdkError::Model`] on request failure or if a
    /// response cannot be converted; stream-level errors are yielded through
    /// the returned stream.
    async fn generate_content_internal(
        &self,
        req: LlmRequest,
        stream: bool,
    ) -> Result<LlmResponseStream> {
        // Flattens an error plus its full `source()` chain into one message,
        // so nested transport errors remain visible in logs.
        fn format_error_chain(e: &dyn std::error::Error) -> String {
            let mut msg = e.to_string();
            let mut source = e.source();
            while let Some(s) = source {
                msg.push_str(": ");
                msg.push_str(&s.to_string());
                source = s.source();
            }
            msg
        }

        let mut builder = self.client.generate_content();

        for content in &req.contents {
            match content.role.as_str() {
                "user" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            Part::InlineData { data, mime_type } => {
                                // Gemini expects inline bytes base64-encoded.
                                let encoded = attachment::encode_base64(data);
                                gemini_parts.push(adk_gemini::Part::InlineData {
                                    inline_data: adk_gemini::Blob {
                                        mime_type: mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            Part::FileData { mime_type, file_uri } => {
                                // File references are flattened to a textual
                                // description rather than sent as FileData.
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: attachment::file_attachment_to_text(mime_type, file_uri),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            // Other part kinds in user content are dropped silently.
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let user_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: user_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                "model" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            Part::FunctionCall { name, args, thought_signature, .. } => {
                                // The signature is attached to the outer part; the
                                // inner FunctionCall carries None.
                                gemini_parts.push(adk_gemini::Part::FunctionCall {
                                    function_call: adk_gemini::FunctionCall {
                                        name: name.clone(),
                                        args: args.clone(),
                                        thought_signature: None,
                                    },
                                    thought_signature: thought_signature.clone(),
                                });
                            }
                            // Other part kinds in model content are dropped silently.
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let model_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::Model),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: model_content,
                            role: adk_gemini::Role::Model,
                        });
                    }
                }
                "function" => {
                    // Function results are attached via the dedicated builder
                    // hook, normalized to an object payload.
                    for part in &content.parts {
                        if let Part::FunctionResponse { function_response, .. } = part {
                            builder = builder
                                .with_function_response(
                                    &function_response.name,
                                    Self::gemini_function_response_payload(
                                        function_response.response.clone(),
                                    ),
                                )
                                .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
                        }
                    }
                }
                // Unknown roles are ignored.
                _ => {}
            }
        }

        if let Some(config) = req.config {
            // A response schema implies JSON output, so force the mime type.
            let has_schema = config.response_schema.is_some();
            let gen_config = adk_gemini::GenerationConfig {
                temperature: config.temperature,
                top_p: config.top_p,
                top_k: config.top_k,
                max_output_tokens: config.max_output_tokens,
                response_schema: config.response_schema,
                response_mime_type: if has_schema {
                    Some("application/json".to_string())
                } else {
                    None
                },
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);

            // Attach a previously created context cache by name, if configured.
            if let Some(ref name) = config.cached_content {
                let handle = self.client.get_cached_content(name);
                builder = builder.with_cached_content(&handle);
            }
        }

        if !req.tools.is_empty() {
            let mut function_declarations = Vec::new();
            let mut has_google_search = false;

            for (name, tool_decl) in &req.tools {
                // "google_search" is a built-in Gemini tool, not a function
                // declaration — handled separately below.
                if name == "google_search" {
                    has_google_search = true;
                    continue;
                }

                // Declarations that fail to deserialize are skipped silently.
                if let Ok(func_decl) =
                    serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
                {
                    function_declarations.push(func_decl);
                }
            }

            if !function_declarations.is_empty() {
                let tool = adk_gemini::Tool::with_functions(function_declarations);
                builder = builder.with_tool(tool);
            }

            if has_google_search {
                let tool = adk_gemini::Tool::google_search();
                builder = builder.with_tool(tool);
            }
        }

        if stream {
            adk_telemetry::debug!("Executing streaming request");
            let response_stream = builder.execute_stream().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                adk_core::AdkError::Model(format_error_chain(&e))
            })?;

            // Re-yield each upstream chunk after conversion, applying the
            // partial/final chunk shaping from `stream_chunks_from_response`.
            let mapped_stream = async_stream::stream! {
                let mut stream = response_stream;
                let mut saw_partial_chunk = false;
                while let Some(result) = stream.try_next().await.transpose() {
                    match result {
                        Ok(resp) => {
                            match Self::convert_response(&resp) {
                                Ok(llm_resp) => {
                                    let (chunks, next_saw_partial) =
                                        Self::stream_chunks_from_response(llm_resp, saw_partial_chunk);
                                    saw_partial_chunk = next_saw_partial;
                                    for chunk in chunks {
                                        yield Ok(chunk);
                                    }
                                }
                                Err(e) => {
                                    adk_telemetry::error!(error = %e, "Failed to convert response");
                                    yield Err(e);
                                }
                            }
                        }
                        Err(e) => {
                            adk_telemetry::error!(error = %e, "Stream error");
                            yield Err(adk_core::AdkError::Model(format_error_chain(&e)));
                        }
                    }
                }
            };

            Ok(Box::pin(mapped_stream))
        } else {
            adk_telemetry::debug!("Executing blocking request");
            let response = builder.execute().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                adk_core::AdkError::Model(format_error_chain(&e))
            })?;

            let llm_response = Self::convert_response(&response)?;

            // Blocking calls are normalized to a one-item stream so both
            // paths share the same return type.
            let stream = async_stream::stream! {
                yield Ok(llm_response);
            };

            Ok(Box::pin(stream))
        }
    }
547
548 pub async fn create_cached_content(
553 &self,
554 system_instruction: &str,
555 tools: &std::collections::HashMap<String, serde_json::Value>,
556 ttl_seconds: u32,
557 ) -> Result<String> {
558 let mut cache_builder = self
559 .client
560 .create_cache()
561 .with_system_instruction(system_instruction)
562 .with_ttl(std::time::Duration::from_secs(u64::from(ttl_seconds)));
563
564 let mut function_declarations = Vec::new();
566 for (name, tool_decl) in tools {
567 if name == "google_search" {
568 continue;
569 }
570 if let Ok(func_decl) =
571 serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
572 {
573 function_declarations.push(func_decl);
574 }
575 }
576 if !function_declarations.is_empty() {
577 cache_builder = cache_builder
578 .with_tools(vec![adk_gemini::Tool::with_functions(function_declarations)]);
579 }
580
581 let handle = cache_builder
582 .execute()
583 .await
584 .map_err(|e| adk_core::AdkError::Model(format!("cache creation failed: {e}")))?;
585
586 Ok(handle.name().to_string())
587 }
588
589 pub async fn delete_cached_content(&self, name: &str) -> Result<()> {
591 let handle = self.client.get_cached_content(name);
592 handle
593 .delete()
594 .await
595 .map_err(|(_, e)| adk_core::AdkError::Model(format!("cache deletion failed: {e}")))?;
596 Ok(())
597 }
598}
599
#[async_trait]
impl Llm for GeminiModel {
    /// Returns the model identifier supplied at construction.
    fn name(&self) -> &str {
        &self.model_name
    }

    /// Generates content, retrying transient failures according to the
    /// configured [`RetryConfig`].
    ///
    /// Each attempt clones the request and delegates to
    /// `generate_content_internal`; retryability is decided by
    /// `is_retryable_model_error`.
    #[adk_telemetry::instrument(
        name = "call_llm",
        skip(self, req),
        fields(
            model.name = %self.model_name,
            stream = %stream,
            request.contents_count = %req.contents.len(),
            request.tools_count = %req.tools.len()
        )
    )]
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
        adk_telemetry::info!("Generating content");
        execute_with_retry(&self.retry_config, is_retryable_model_error, || {
            self.generate_content_internal(req.clone(), stream)
        })
        .await
    }
}
626
// Thin trait adapter: both methods delegate directly to the inherent
// `create_cached_content` / `delete_cached_content` implementations above.
#[async_trait]
impl CacheCapable for GeminiModel {
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &std::collections::HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String> {
        self.create_cached_content(system_instruction, tools, ttl_seconds).await
    }

    async fn delete_cache(&self, name: &str) -> Result<()> {
        self.delete_cached_content(name).await
    }
}
642
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::AdkError;
    use std::{
        sync::{
            Arc,
            atomic::{AtomicU32, Ordering},
        },
        time::Duration,
    };

    // Compile-time check: `GeminiModel::new` stays a synchronous, two-argument
    // constructor (guards against accidental signature changes).
    #[test]
    fn constructor_is_backward_compatible_and_sync() {
        fn accepts_sync_constructor<F>(_f: F)
        where
            F: Fn(&str, &str) -> Result<GeminiModel>,
        {
        }

        accepts_sync_constructor(|api_key, model| GeminiModel::new(api_key, model));
    }

    // A lone final chunk must be preceded by a synthesized empty partial chunk.
    #[test]
    fn stream_chunks_from_response_injects_partial_before_lone_final_chunk() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        };

        let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, false);
        assert!(saw_partial);
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].partial);
        assert!(!chunks[0].turn_complete);
        assert!(chunks[0].content.is_none());
        assert!(!chunks[1].partial);
        assert!(chunks[1].turn_complete);
    }

    // If a partial chunk was already emitted, the final chunk is passed through alone.
    #[test]
    fn stream_chunks_from_response_keeps_final_only_when_partial_already_seen() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("done")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        };

        let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, true);
        assert!(saw_partial);
        assert_eq!(chunks.len(), 1);
        assert!(!chunks[0].partial);
        assert!(chunks[0].turn_complete);
    }

    // Retryable errors (HTTP 429-style) are retried until success within max_retries.
    #[tokio::test]
    async fn execute_with_retry_retries_retryable_errors() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::from_millis(0))
            .with_max_delay(Duration::from_millis(0));
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    return Err(AdkError::Model("code 429 RESOURCE_EXHAUSTED".to_string()));
                }
                Ok("ok")
            }
        })
        .await
        .expect("retry should eventually succeed");

        assert_eq!(result, "ok");
        // 2 failures + 1 success = 3 attempts.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    // Non-retryable errors (HTTP 400-style) must fail after a single attempt.
    #[tokio::test]
    async fn execute_with_retry_does_not_retry_non_retryable_errors() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(0))
            .with_max_delay(Duration::from_millis(0));
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::Model("code 400 invalid request".to_string()))
            }
        })
        .await
        .expect_err("non-retryable error should return immediately");

        assert!(matches!(error, AdkError::Model(_)));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    // A disabled config overrides max_retries: only one attempt is made even
    // for otherwise-retryable errors.
    #[tokio::test]
    async fn execute_with_retry_respects_disabled_config() {
        let retry_config = RetryConfig::disabled().with_max_retries(10);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::Model("code 429 RESOURCE_EXHAUSTED".to_string()))
            }
        })
        .await
        .expect_err("disabled retries should return first error");

        assert!(matches!(error, AdkError::Model(_)));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    // Citation metadata on the first candidate survives conversion field-by-field.
    #[test]
    fn convert_response_preserves_citation_metadata() {
        let response = adk_gemini::GenerationResponse {
            candidates: vec![adk_gemini::Candidate {
                content: adk_gemini::Content {
                    role: Some(adk_gemini::Role::Model),
                    parts: Some(vec![adk_gemini::Part::Text {
                        text: "hello world".to_string(),
                        thought: None,
                        thought_signature: None,
                    }]),
                },
                safety_ratings: None,
                citation_metadata: Some(adk_gemini::CitationMetadata {
                    citation_sources: vec![adk_gemini::CitationSource {
                        uri: Some("https://example.com".to_string()),
                        title: Some("Example".to_string()),
                        start_index: Some(0),
                        end_index: Some(5),
                        license: Some("CC-BY".to_string()),
                        publication_date: None,
                    }],
                }),
                grounding_metadata: None,
                finish_reason: Some(adk_gemini::FinishReason::Stop),
                index: Some(0),
            }],
            prompt_feedback: None,
            usage_metadata: None,
            model_version: None,
            response_id: None,
        };

        let converted =
            GeminiModel::convert_response(&response).expect("conversion should succeed");
        let metadata = converted.citation_metadata.expect("citation metadata should be mapped");
        assert_eq!(metadata.citation_sources.len(), 1);
        assert_eq!(metadata.citation_sources[0].uri.as_deref(), Some("https://example.com"));
        assert_eq!(metadata.citation_sources[0].start_index, Some(0));
        assert_eq!(metadata.citation_sources[0].end_index, Some(5));
    }

    // Base64 inline data from the model is decoded back to the original bytes.
    #[test]
    fn convert_response_handles_inline_data_from_model() {
        // PNG magic-number prefix bytes as a stand-in payload.
        let image_bytes = vec![0x89, 0x50, 0x4E, 0x47];
        let encoded = crate::attachment::encode_base64(&image_bytes);

        let response = adk_gemini::GenerationResponse {
            candidates: vec![adk_gemini::Candidate {
                content: adk_gemini::Content {
                    role: Some(adk_gemini::Role::Model),
                    parts: Some(vec![
                        adk_gemini::Part::Text {
                            text: "Here is the image".to_string(),
                            thought: None,
                            thought_signature: None,
                        },
                        adk_gemini::Part::InlineData {
                            inline_data: adk_gemini::Blob {
                                mime_type: "image/png".to_string(),
                                data: encoded,
                            },
                        },
                    ]),
                },
                safety_ratings: None,
                citation_metadata: None,
                grounding_metadata: None,
                finish_reason: Some(adk_gemini::FinishReason::Stop),
                index: Some(0),
            }],
            prompt_feedback: None,
            usage_metadata: None,
            model_version: None,
            response_id: None,
        };

        let converted =
            GeminiModel::convert_response(&response).expect("conversion should succeed");
        let content = converted.content.expect("should have content");
        assert!(
            content
                .parts
                .iter()
                .any(|part| matches!(part, Part::Text { text } if text == "Here is the image"))
        );
        assert!(content.parts.iter().any(|part| {
            matches!(
                part,
                Part::InlineData { mime_type, data }
                    if mime_type == "image/png" && data.as_slice() == image_bytes.as_slice()
            )
        }));
    }

    // JSON objects are passed through to the functionResponse payload unchanged.
    #[test]
    fn gemini_function_response_payload_preserves_objects() {
        let value = serde_json::json!({
            "documents": [
                { "id": "pricing", "score": 0.91 }
            ]
        });

        let payload = GeminiModel::gemini_function_response_payload(value.clone());

        assert_eq!(payload, value);
    }

    // Non-object JSON (here, an array) gets wrapped as { "result": ... }.
    #[test]
    fn gemini_function_response_payload_wraps_arrays() {
        let payload =
            GeminiModel::gemini_function_response_payload(serde_json::json!([{ "id": "pricing" }]));

        assert_eq!(payload, serde_json::json!({ "result": [{ "id": "pricing" }] }));
    }
}