Skip to main content

adk_model/gemini/
client.rs

1use crate::attachment;
2use crate::retry::{RetryConfig, execute_with_retry, is_retryable_model_error};
3use adk_core::{
4    CacheCapable, CitationMetadata, CitationSource, Content, FinishReason, Llm, LlmRequest,
5    LlmResponse, LlmResponseStream, Part, Result, UsageMetadata,
6};
7use adk_gemini::Gemini;
8use async_trait::async_trait;
9use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
10use futures::TryStreamExt;
11
/// Gemini-backed implementation of the ADK [`Llm`] trait.
///
/// Wraps an [`adk_gemini::Gemini`] client together with the model name and a
/// retry policy applied to every `generate_content` call.
pub struct GeminiModel {
    // Underlying client for the Gemini / Vertex AI API.
    client: Gemini,
    // Model identifier; also reported back via `Llm::name`.
    model_name: String,
    // Retry policy used by `generate_content` for retryable model errors.
    retry_config: RetryConfig,
}
17
18impl GeminiModel {
    /// Create a Gemini model using an API key, with the default retry policy.
    ///
    /// # Errors
    /// Returns `AdkError::Model` if the underlying client cannot be built.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
        let model_name = model.into();
        let client = Gemini::with_model(api_key.into(), model_name.clone())
            .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;

        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
    }
26
27    /// Create a Gemini model via Vertex AI with API key auth.
28    ///
29    /// Requires `gemini-vertex` feature.
30    #[cfg(feature = "gemini-vertex")]
31    pub fn new_google_cloud(
32        api_key: impl Into<String>,
33        project_id: impl AsRef<str>,
34        location: impl AsRef<str>,
35        model: impl Into<String>,
36    ) -> Result<Self> {
37        let model_name = model.into();
38        let client = Gemini::with_google_cloud_model(
39            api_key.into(),
40            project_id,
41            location,
42            model_name.clone(),
43        )
44        .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
45
46        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
47    }
48
49    /// Create a Gemini model via Vertex AI with service account JSON.
50    ///
51    /// Requires `gemini-vertex` feature.
52    #[cfg(feature = "gemini-vertex")]
53    pub fn new_google_cloud_service_account(
54        service_account_json: &str,
55        project_id: impl AsRef<str>,
56        location: impl AsRef<str>,
57        model: impl Into<String>,
58    ) -> Result<Self> {
59        let model_name = model.into();
60        let client = Gemini::with_google_cloud_service_account_json(
61            service_account_json,
62            project_id.as_ref(),
63            location.as_ref(),
64            model_name.clone(),
65        )
66        .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
67
68        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
69    }
70
71    /// Create a Gemini model via Vertex AI with Application Default Credentials.
72    ///
73    /// Requires `gemini-vertex` feature.
74    #[cfg(feature = "gemini-vertex")]
75    pub fn new_google_cloud_adc(
76        project_id: impl AsRef<str>,
77        location: impl AsRef<str>,
78        model: impl Into<String>,
79    ) -> Result<Self> {
80        let model_name = model.into();
81        let client = Gemini::with_google_cloud_adc_model(
82            project_id.as_ref(),
83            location.as_ref(),
84            model_name.clone(),
85        )
86        .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
87
88        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
89    }
90
91    /// Create a Gemini model via Vertex AI with Workload Identity Federation.
92    ///
93    /// Requires `gemini-vertex` feature.
94    #[cfg(feature = "gemini-vertex")]
95    pub fn new_google_cloud_wif(
96        wif_json: &str,
97        project_id: impl AsRef<str>,
98        location: impl AsRef<str>,
99        model: impl Into<String>,
100    ) -> Result<Self> {
101        let model_name = model.into();
102        let client = Gemini::with_google_cloud_wif_json(
103            wif_json,
104            project_id.as_ref(),
105            location.as_ref(),
106            model_name.clone(),
107        )
108        .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
109
110        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
111    }
112
    /// Replace the retry policy, builder-style.
    #[must_use]
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    /// Replace the retry policy in place.
    pub fn set_retry_config(&mut self, retry_config: RetryConfig) {
        self.retry_config = retry_config;
    }

    /// Current retry policy used by `generate_content`.
    pub fn retry_config(&self) -> &RetryConfig {
        &self.retry_config
    }
126
    /// Convert a raw Gemini generation response into an ADK [`LlmResponse`].
    ///
    /// Only the first candidate is inspected. Text, inline data, function
    /// calls, and function responses map to their ADK [`Part`] equivalents;
    /// grounding metadata (Google Search) is appended as extra text parts so
    /// attribution survives in the final content.
    ///
    /// # Errors
    /// Returns `AdkError::Model` when inline data is not valid base64.
    fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
        let mut converted_parts: Vec<Part> = Vec::new();

        // Convert content parts
        if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
            for p in parts {
                match p {
                    adk_gemini::Part::Text { text, thought, thought_signature } => {
                        // Gemini marks reasoning text with `thought: true`; keep
                        // it separate from user-visible text.
                        if thought == &Some(true) {
                            converted_parts.push(Part::Thinking {
                                thinking: text.clone(),
                                signature: thought_signature.clone(),
                            });
                        } else {
                            converted_parts.push(Part::Text { text: text.clone() });
                        }
                    }
                    adk_gemini::Part::InlineData { inline_data } => {
                        // Wire format carries inline data base64-encoded; ADK
                        // parts hold raw bytes.
                        let decoded =
                            BASE64_STANDARD.decode(&inline_data.data).map_err(|error| {
                                adk_core::AdkError::Model(format!(
                                    "failed to decode inline data from gemini response: {error}"
                                ))
                            })?;
                        converted_parts.push(Part::InlineData {
                            mime_type: inline_data.mime_type.clone(),
                            data: decoded,
                        });
                    }
                    adk_gemini::Part::FunctionCall { function_call, thought_signature } => {
                        converted_parts.push(Part::FunctionCall {
                            name: function_call.name.clone(),
                            args: function_call.args.clone(),
                            // No call id is supplied by the Gemini response here.
                            id: None,
                            thought_signature: thought_signature.clone(),
                        });
                    }
                    adk_gemini::Part::FunctionResponse { function_response } => {
                        converted_parts.push(Part::FunctionResponse {
                            function_response: adk_core::FunctionResponseData {
                                name: function_response.name.clone(),
                                // Missing payload is normalized to JSON null.
                                response: function_response
                                    .response
                                    .clone()
                                    .unwrap_or(serde_json::Value::Null),
                            },
                            id: None,
                        });
                    }
                }
            }
        }

        // Add grounding metadata as text if present (required for Google Search grounding compliance)
        if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
        {
            if let Some(queries) = &grounding.web_search_queries {
                if !queries.is_empty() {
                    let search_info = format!("\n\nšŸ” **Searched:** {}", queries.join(", "));
                    converted_parts.push(Part::Text { text: search_info });
                }
            }
            if let Some(chunks) = &grounding.grounding_chunks {
                // Render each web source as a markdown link when both title and
                // URI exist, falling back to whichever one is present.
                let sources: Vec<String> = chunks
                    .iter()
                    .filter_map(|c| {
                        c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
                            (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
                            (Some(title), None) => Some(title.clone()),
                            (None, Some(uri)) => Some(uri.to_string()),
                            (None, None) => None,
                        })
                    })
                    .collect();
                if !sources.is_empty() {
                    let sources_info = format!("\nšŸ“š **Sources:** {}", sources.join(" | "));
                    converted_parts.push(Part::Text { text: sources_info });
                }
            }
        }

        // An empty part list becomes `None` rather than empty content.
        let content = if converted_parts.is_empty() {
            None
        } else {
            Some(Content { role: "model".to_string(), parts: converted_parts })
        };

        // Token accounting; missing counts default to zero.
        let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
            prompt_token_count: u.prompt_token_count.unwrap_or(0),
            candidates_token_count: u.candidates_token_count.unwrap_or(0),
            total_token_count: u.total_token_count.unwrap_or(0),
            thinking_token_count: u.thoughts_token_count,
            cache_read_input_token_count: u.cached_content_token_count,
            ..Default::default()
        });

        // Collapse finish reasons ADK does not distinguish into `Other`.
        let finish_reason =
            resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
                adk_gemini::FinishReason::Stop => FinishReason::Stop,
                adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
                adk_gemini::FinishReason::Safety => FinishReason::Safety,
                adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
                _ => FinishReason::Other,
            });

        // Carry citation metadata through field-by-field.
        let citation_metadata =
            resp.candidates.first().and_then(|c| c.citation_metadata.as_ref()).map(|meta| {
                CitationMetadata {
                    citation_sources: meta
                        .citation_sources
                        .iter()
                        .map(|source| CitationSource {
                            uri: source.uri.clone(),
                            title: source.title.clone(),
                            start_index: source.start_index,
                            end_index: source.end_index,
                            license: source.license.clone(),
                            publication_date: source.publication_date.map(|d| d.to_string()),
                        })
                        .collect(),
                }
            });

        // Non-streaming defaults; streaming callers rewrite partial/turn flags.
        Ok(LlmResponse {
            content,
            usage_metadata,
            finish_reason,
            citation_metadata,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        })
    }
262
263    fn gemini_function_response_payload(response: serde_json::Value) -> serde_json::Value {
264        match response {
265            // Gemini functionResponse.response must be a JSON object.
266            serde_json::Value::Object(_) => response,
267            other => serde_json::json!({ "result": other }),
268        }
269    }
270
271    fn stream_chunks_from_response(
272        mut response: LlmResponse,
273        saw_partial_chunk: bool,
274    ) -> (Vec<LlmResponse>, bool) {
275        let is_final = response.finish_reason.is_some();
276
277        if !is_final {
278            response.partial = true;
279            response.turn_complete = false;
280            return (vec![response], true);
281        }
282
283        response.partial = false;
284        response.turn_complete = true;
285
286        if saw_partial_chunk {
287            return (vec![response], true);
288        }
289
290        let synthetic_partial = LlmResponse {
291            content: None,
292            usage_metadata: None,
293            finish_reason: None,
294            citation_metadata: None,
295            partial: true,
296            turn_complete: false,
297            interrupted: false,
298            error_code: None,
299            error_message: None,
300        };
301
302        (vec![synthetic_partial, response], true)
303    }
304
    /// Build and execute one Gemini request from an ADK [`LlmRequest`].
    ///
    /// Translates contents (user / model / function turns), generation config,
    /// cached-content references, and tool declarations into the builder API,
    /// then executes either a streaming or a blocking call. Both paths return
    /// an [`LlmResponseStream`]; the blocking path yields exactly one item.
    ///
    /// # Errors
    /// Returns `AdkError::Model` when the request cannot be built or executed.
    /// Errors after a stream has started are yielded through the stream.
    async fn generate_content_internal(
        &self,
        req: LlmRequest,
        stream: bool,
    ) -> Result<LlmResponseStream> {
        // Helper to format the full error chain (Display + all source errors)
        fn format_error_chain(e: &dyn std::error::Error) -> String {
            let mut msg = e.to_string();
            let mut source = e.source();
            while let Some(s) = source {
                msg.push_str(": ");
                msg.push_str(&s.to_string());
                source = s.source();
            }
            msg
        }

        let mut builder = self.client.generate_content();

        // Add contents using proper builder methods
        for content in &req.contents {
            match content.role.as_str() {
                "user" => {
                    // For user messages, build gemini Content with potentially multiple parts
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                // Thinking parts round-trip as text flagged with
                                // `thought: Some(true)`.
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            Part::InlineData { data, mime_type } => {
                                // Raw bytes must be base64-encoded for the wire.
                                let encoded = attachment::encode_base64(data);
                                gemini_parts.push(adk_gemini::Part::InlineData {
                                    inline_data: adk_gemini::Blob {
                                        mime_type: mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            Part::FileData { mime_type, file_uri } => {
                                // File references are flattened to a textual
                                // description rather than sent as fileData.
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: attachment::file_attachment_to_text(mime_type, file_uri),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            // Other part kinds are not sent in a user turn.
                            _ => {}
                        }
                    }
                    // Skip empty turns entirely.
                    if !gemini_parts.is_empty() {
                        let user_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: user_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                "model" => {
                    // For model messages, build gemini Content
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            Part::FunctionCall { name, args, thought_signature, .. } => {
                                gemini_parts.push(adk_gemini::Part::FunctionCall {
                                    function_call: adk_gemini::FunctionCall {
                                        name: name.clone(),
                                        args: args.clone(),
                                        // NOTE(review): the signature is carried on the
                                        // outer part below, not on the inner call —
                                        // confirm leaving this None is intentional.
                                        thought_signature: None,
                                    },
                                    thought_signature: thought_signature.clone(),
                                });
                            }
                            // Other part kinds are not sent in a model turn.
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let model_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::Model),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: model_content,
                            role: adk_gemini::Role::Model,
                        });
                    }
                }
                "function" => {
                    // For function responses
                    for part in &content.parts {
                        if let Part::FunctionResponse { function_response, .. } = part {
                            builder = builder
                                .with_function_response(
                                    &function_response.name,
                                    // Payload is normalized to a JSON object first.
                                    Self::gemini_function_response_payload(
                                        function_response.response.clone(),
                                    ),
                                )
                                .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
                        }
                    }
                }
                // Unknown roles are silently skipped.
                _ => {}
            }
        }

        // Add generation config
        if let Some(config) = req.config {
            let has_schema = config.response_schema.is_some();
            let gen_config = adk_gemini::GenerationConfig {
                temperature: config.temperature,
                top_p: config.top_p,
                top_k: config.top_k,
                max_output_tokens: config.max_output_tokens,
                response_schema: config.response_schema,
                // A response schema implies structured output, which requires
                // the JSON mime type.
                response_mime_type: if has_schema {
                    Some("application/json".to_string())
                } else {
                    None
                },
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);

            // Attach cached content reference if provided
            if let Some(ref name) = config.cached_content {
                let handle = self.client.get_cached_content(name);
                builder = builder.with_cached_content(&handle);
            }
        }

        // Add tools
        if !req.tools.is_empty() {
            let mut function_declarations = Vec::new();
            let mut has_google_search = false;

            for (name, tool_decl) in &req.tools {
                // "google_search" is a built-in, not a function declaration.
                if name == "google_search" {
                    has_google_search = true;
                    continue;
                }

                // Deserialize our tool declaration into adk_gemini::FunctionDeclaration
                // NOTE(review): declarations that fail to deserialize are
                // dropped silently — confirm this is the intended behavior.
                if let Ok(func_decl) =
                    serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
                {
                    function_declarations.push(func_decl);
                }
            }

            if !function_declarations.is_empty() {
                let tool = adk_gemini::Tool::with_functions(function_declarations);
                builder = builder.with_tool(tool);
            }

            if has_google_search {
                // Enable built-in Google Search
                let tool = adk_gemini::Tool::google_search();
                builder = builder.with_tool(tool);
            }
        }

        if stream {
            adk_telemetry::debug!("Executing streaming request");
            let response_stream = builder.execute_stream().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                adk_core::AdkError::Model(format_error_chain(&e))
            })?;

            // Wrap the raw SSE stream, converting each chunk and normalizing
            // partial/final flags via stream_chunks_from_response.
            let mapped_stream = async_stream::stream! {
                let mut stream = response_stream;
                let mut saw_partial_chunk = false;
                // try_next yields Result<Option<T>>; transpose turns it into
                // Option<Result<T>> so the loop ends cleanly on Ok(None).
                while let Some(result) = stream.try_next().await.transpose() {
                    match result {
                        Ok(resp) => {
                            match Self::convert_response(&resp) {
                                Ok(llm_resp) => {
                                    let (chunks, next_saw_partial) =
                                        Self::stream_chunks_from_response(llm_resp, saw_partial_chunk);
                                    saw_partial_chunk = next_saw_partial;
                                    for chunk in chunks {
                                        yield Ok(chunk);
                                    }
                                }
                                Err(e) => {
                                    adk_telemetry::error!(error = %e, "Failed to convert response");
                                    yield Err(e);
                                }
                            }
                        }
                        Err(e) => {
                            adk_telemetry::error!(error = %e, "Stream error");
                            yield Err(adk_core::AdkError::Model(format_error_chain(&e)));
                        }
                    }
                }
            };

            Ok(Box::pin(mapped_stream))
        } else {
            adk_telemetry::debug!("Executing blocking request");
            let response = builder.execute().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                adk_core::AdkError::Model(format_error_chain(&e))
            })?;

            let llm_response = Self::convert_response(&response)?;

            // Present the single blocking result as a one-item stream so both
            // paths share the same return type.
            let stream = async_stream::stream! {
                yield Ok(llm_response);
            };

            Ok(Box::pin(stream))
        }
    }
547
548    /// Create a cached content resource with the given system instruction, tools, and TTL.
549    ///
550    /// Returns the cache name (e.g., "cachedContents/abc123") on success.
551    /// The cache is created using the model configured on this `GeminiModel` instance.
552    pub async fn create_cached_content(
553        &self,
554        system_instruction: &str,
555        tools: &std::collections::HashMap<String, serde_json::Value>,
556        ttl_seconds: u32,
557    ) -> Result<String> {
558        let mut cache_builder = self
559            .client
560            .create_cache()
561            .with_system_instruction(system_instruction)
562            .with_ttl(std::time::Duration::from_secs(u64::from(ttl_seconds)));
563
564        // Convert ADK tool definitions to Gemini FunctionDeclarations
565        let mut function_declarations = Vec::new();
566        for (name, tool_decl) in tools {
567            if name == "google_search" {
568                continue;
569            }
570            if let Ok(func_decl) =
571                serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
572            {
573                function_declarations.push(func_decl);
574            }
575        }
576        if !function_declarations.is_empty() {
577            cache_builder = cache_builder
578                .with_tools(vec![adk_gemini::Tool::with_functions(function_declarations)]);
579        }
580
581        let handle = cache_builder
582            .execute()
583            .await
584            .map_err(|e| adk_core::AdkError::Model(format!("cache creation failed: {e}")))?;
585
586        Ok(handle.name().to_string())
587    }
588
589    /// Delete a cached content resource by name.
590    pub async fn delete_cached_content(&self, name: &str) -> Result<()> {
591        let handle = self.client.get_cached_content(name);
592        handle
593            .delete()
594            .await
595            .map_err(|(_, e)| adk_core::AdkError::Model(format!("cache deletion failed: {e}")))?;
596        Ok(())
597    }
598}
599
#[async_trait]
impl Llm for GeminiModel {
    /// Model identifier this instance was constructed with.
    fn name(&self) -> &str {
        &self.model_name
    }

    // Spans every call with model name, streaming mode, and request sizes; the
    // request body itself is skipped to keep traces small.
    #[adk_telemetry::instrument(
        name = "call_llm",
        skip(self, req),
        fields(
            model.name = %self.model_name,
            stream = %stream,
            request.contents_count = %req.contents.len(),
            request.tools_count = %req.tools.len()
        )
    )]
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
        adk_telemetry::info!("Generating content");
        // Retries only cover request setup/execution. Stream failures after the stream starts
        // are yielded to the caller and are not replayed automatically.
        execute_with_retry(&self.retry_config, is_retryable_model_error, || {
            self.generate_content_internal(req.clone(), stream)
        })
        .await
    }
}
626
#[async_trait]
impl CacheCapable for GeminiModel {
    /// Delegates to [`GeminiModel::create_cached_content`].
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &std::collections::HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String> {
        self.create_cached_content(system_instruction, tools, ttl_seconds).await
    }

    /// Delegates to [`GeminiModel::delete_cached_content`].
    async fn delete_cache(&self, name: &str) -> Result<()> {
        self.delete_cached_content(name).await
    }
}
642
643#[cfg(test)]
644mod tests {
645    use super::*;
646    use adk_core::AdkError;
647    use std::{
648        sync::{
649            Arc,
650            atomic::{AtomicU32, Ordering},
651        },
652        time::Duration,
653    };
654
655    #[test]
656    fn constructor_is_backward_compatible_and_sync() {
657        fn accepts_sync_constructor<F>(_f: F)
658        where
659            F: Fn(&str, &str) -> Result<GeminiModel>,
660        {
661        }
662
663        accepts_sync_constructor(|api_key, model| GeminiModel::new(api_key, model));
664    }
665
666    #[test]
667    fn stream_chunks_from_response_injects_partial_before_lone_final_chunk() {
668        let response = LlmResponse {
669            content: Some(Content::new("model").with_text("hello")),
670            usage_metadata: None,
671            finish_reason: Some(FinishReason::Stop),
672            citation_metadata: None,
673            partial: false,
674            turn_complete: true,
675            interrupted: false,
676            error_code: None,
677            error_message: None,
678        };
679
680        let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, false);
681        assert!(saw_partial);
682        assert_eq!(chunks.len(), 2);
683        assert!(chunks[0].partial);
684        assert!(!chunks[0].turn_complete);
685        assert!(chunks[0].content.is_none());
686        assert!(!chunks[1].partial);
687        assert!(chunks[1].turn_complete);
688    }
689
690    #[test]
691    fn stream_chunks_from_response_keeps_final_only_when_partial_already_seen() {
692        let response = LlmResponse {
693            content: Some(Content::new("model").with_text("done")),
694            usage_metadata: None,
695            finish_reason: Some(FinishReason::Stop),
696            citation_metadata: None,
697            partial: false,
698            turn_complete: true,
699            interrupted: false,
700            error_code: None,
701            error_message: None,
702        };
703
704        let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, true);
705        assert!(saw_partial);
706        assert_eq!(chunks.len(), 1);
707        assert!(!chunks[0].partial);
708        assert!(chunks[0].turn_complete);
709    }
710
711    #[tokio::test]
712    async fn execute_with_retry_retries_retryable_errors() {
713        let retry_config = RetryConfig::default()
714            .with_max_retries(2)
715            .with_initial_delay(Duration::from_millis(0))
716            .with_max_delay(Duration::from_millis(0));
717        let attempts = Arc::new(AtomicU32::new(0));
718
719        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
720            let attempts = Arc::clone(&attempts);
721            async move {
722                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
723                if attempt < 2 {
724                    return Err(AdkError::Model("code 429 RESOURCE_EXHAUSTED".to_string()));
725                }
726                Ok("ok")
727            }
728        })
729        .await
730        .expect("retry should eventually succeed");
731
732        assert_eq!(result, "ok");
733        assert_eq!(attempts.load(Ordering::SeqCst), 3);
734    }
735
736    #[tokio::test]
737    async fn execute_with_retry_does_not_retry_non_retryable_errors() {
738        let retry_config = RetryConfig::default()
739            .with_max_retries(3)
740            .with_initial_delay(Duration::from_millis(0))
741            .with_max_delay(Duration::from_millis(0));
742        let attempts = Arc::new(AtomicU32::new(0));
743
744        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
745            let attempts = Arc::clone(&attempts);
746            async move {
747                attempts.fetch_add(1, Ordering::SeqCst);
748                Err::<(), _>(AdkError::Model("code 400 invalid request".to_string()))
749            }
750        })
751        .await
752        .expect_err("non-retryable error should return immediately");
753
754        assert!(matches!(error, AdkError::Model(_)));
755        assert_eq!(attempts.load(Ordering::SeqCst), 1);
756    }
757
758    #[tokio::test]
759    async fn execute_with_retry_respects_disabled_config() {
760        let retry_config = RetryConfig::disabled().with_max_retries(10);
761        let attempts = Arc::new(AtomicU32::new(0));
762
763        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
764            let attempts = Arc::clone(&attempts);
765            async move {
766                attempts.fetch_add(1, Ordering::SeqCst);
767                Err::<(), _>(AdkError::Model("code 429 RESOURCE_EXHAUSTED".to_string()))
768            }
769        })
770        .await
771        .expect_err("disabled retries should return first error");
772
773        assert!(matches!(error, AdkError::Model(_)));
774        assert_eq!(attempts.load(Ordering::SeqCst), 1);
775    }
776
777    #[test]
778    fn convert_response_preserves_citation_metadata() {
779        let response = adk_gemini::GenerationResponse {
780            candidates: vec![adk_gemini::Candidate {
781                content: adk_gemini::Content {
782                    role: Some(adk_gemini::Role::Model),
783                    parts: Some(vec![adk_gemini::Part::Text {
784                        text: "hello world".to_string(),
785                        thought: None,
786                        thought_signature: None,
787                    }]),
788                },
789                safety_ratings: None,
790                citation_metadata: Some(adk_gemini::CitationMetadata {
791                    citation_sources: vec![adk_gemini::CitationSource {
792                        uri: Some("https://example.com".to_string()),
793                        title: Some("Example".to_string()),
794                        start_index: Some(0),
795                        end_index: Some(5),
796                        license: Some("CC-BY".to_string()),
797                        publication_date: None,
798                    }],
799                }),
800                grounding_metadata: None,
801                finish_reason: Some(adk_gemini::FinishReason::Stop),
802                index: Some(0),
803            }],
804            prompt_feedback: None,
805            usage_metadata: None,
806            model_version: None,
807            response_id: None,
808        };
809
810        let converted =
811            GeminiModel::convert_response(&response).expect("conversion should succeed");
812        let metadata = converted.citation_metadata.expect("citation metadata should be mapped");
813        assert_eq!(metadata.citation_sources.len(), 1);
814        assert_eq!(metadata.citation_sources[0].uri.as_deref(), Some("https://example.com"));
815        assert_eq!(metadata.citation_sources[0].start_index, Some(0));
816        assert_eq!(metadata.citation_sources[0].end_index, Some(5));
817    }
818
819    #[test]
820    fn convert_response_handles_inline_data_from_model() {
821        let image_bytes = vec![0x89, 0x50, 0x4E, 0x47];
822        let encoded = crate::attachment::encode_base64(&image_bytes);
823
824        let response = adk_gemini::GenerationResponse {
825            candidates: vec![adk_gemini::Candidate {
826                content: adk_gemini::Content {
827                    role: Some(adk_gemini::Role::Model),
828                    parts: Some(vec![
829                        adk_gemini::Part::Text {
830                            text: "Here is the image".to_string(),
831                            thought: None,
832                            thought_signature: None,
833                        },
834                        adk_gemini::Part::InlineData {
835                            inline_data: adk_gemini::Blob {
836                                mime_type: "image/png".to_string(),
837                                data: encoded,
838                            },
839                        },
840                    ]),
841                },
842                safety_ratings: None,
843                citation_metadata: None,
844                grounding_metadata: None,
845                finish_reason: Some(adk_gemini::FinishReason::Stop),
846                index: Some(0),
847            }],
848            prompt_feedback: None,
849            usage_metadata: None,
850            model_version: None,
851            response_id: None,
852        };
853
854        let converted =
855            GeminiModel::convert_response(&response).expect("conversion should succeed");
856        let content = converted.content.expect("should have content");
857        assert!(
858            content
859                .parts
860                .iter()
861                .any(|part| matches!(part, Part::Text { text } if text == "Here is the image"))
862        );
863        assert!(content.parts.iter().any(|part| {
864            matches!(
865                part,
866                Part::InlineData { mime_type, data }
867                    if mime_type == "image/png" && data.as_slice() == image_bytes.as_slice()
868            )
869        }));
870    }
871
872    #[test]
873    fn gemini_function_response_payload_preserves_objects() {
874        let value = serde_json::json!({
875            "documents": [
876                { "id": "pricing", "score": 0.91 }
877            ]
878        });
879
880        let payload = GeminiModel::gemini_function_response_payload(value.clone());
881
882        assert_eq!(payload, value);
883    }
884
885    #[test]
886    fn gemini_function_response_payload_wraps_arrays() {
887        let payload =
888            GeminiModel::gemini_function_response_payload(serde_json::json!([{ "id": "pricing" }]));
889
890        assert_eq!(payload, serde_json::json!({ "result": [{ "id": "pricing" }] }));
891    }
892}