Skip to main content

adk_model/gemini/
client.rs

1use crate::attachment;
2use crate::retry::{RetryConfig, execute_with_retry, is_retryable_model_error};
3use adk_core::{
4    CacheCapable, CitationMetadata, CitationSource, Content, ErrorCategory, ErrorComponent,
5    FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result, UsageMetadata,
6};
7use adk_gemini::Gemini;
8use async_trait::async_trait;
9use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
10use futures::TryStreamExt;
11
12pub struct GeminiModel {
13    client: Gemini,
14    model_name: String,
15    retry_config: RetryConfig,
16}
17
18/// Convert a Gemini client error to a structured `AdkError` with proper category and retry hints.
19fn gemini_error_to_adk(e: &adk_gemini::ClientError) -> adk_core::AdkError {
20    fn format_error_chain(e: &dyn std::error::Error) -> String {
21        let mut msg = e.to_string();
22        let mut source = e.source();
23        while let Some(s) = source {
24            msg.push_str(": ");
25            msg.push_str(&s.to_string());
26            source = s.source();
27        }
28        msg
29    }
30
31    let message = format_error_chain(e);
32
33    // Extract status code from BadResponse variant via Display output
34    // BadResponse format: "bad response from server; code {code}; description: ..."
35    let (category, code, status_code) = if message.contains("code 429")
36        || message.contains("RESOURCE_EXHAUSTED")
37        || message.contains("rate limit")
38    {
39        (ErrorCategory::RateLimited, "model.gemini.rate_limited", Some(429u16))
40    } else if message.contains("code 503") || message.contains("UNAVAILABLE") {
41        (ErrorCategory::Unavailable, "model.gemini.unavailable", Some(503))
42    } else if message.contains("code 529") || message.contains("OVERLOADED") {
43        (ErrorCategory::Unavailable, "model.gemini.overloaded", Some(529))
44    } else if message.contains("code 408")
45        || message.contains("DEADLINE_EXCEEDED")
46        || message.contains("TIMEOUT")
47    {
48        (ErrorCategory::Timeout, "model.gemini.timeout", Some(408))
49    } else if message.contains("code 401") || message.contains("Invalid API key") {
50        (ErrorCategory::Unauthorized, "model.gemini.unauthorized", Some(401))
51    } else if message.contains("code 400") {
52        (ErrorCategory::InvalidInput, "model.gemini.bad_request", Some(400))
53    } else if message.contains("code 404") {
54        (ErrorCategory::NotFound, "model.gemini.not_found", Some(404))
55    } else if message.contains("invalid generation config") {
56        (ErrorCategory::InvalidInput, "model.gemini.invalid_config", None)
57    } else {
58        (ErrorCategory::Internal, "model.gemini.internal", None)
59    };
60
61    let mut err = adk_core::AdkError::new(ErrorComponent::Model, category, code, message)
62        .with_provider("gemini");
63    if let Some(sc) = status_code {
64        err = err.with_upstream_status(sc);
65    }
66    err
67}
68
69impl GeminiModel {
70    fn gemini_part_thought_signature(value: &serde_json::Value) -> Option<String> {
71        value.get("thoughtSignature").and_then(serde_json::Value::as_str).map(str::to_string)
72    }
73
74    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
75        let model_name = model.into();
76        let client = Gemini::with_model(api_key.into(), model_name.clone())
77            .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
78
79        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
80    }
81
82    /// Create a Gemini model via Vertex AI with API key auth.
83    ///
84    /// Requires `gemini-vertex` feature.
85    #[cfg(feature = "gemini-vertex")]
86    pub fn new_google_cloud(
87        api_key: impl Into<String>,
88        project_id: impl AsRef<str>,
89        location: impl AsRef<str>,
90        model: impl Into<String>,
91    ) -> Result<Self> {
92        let model_name = model.into();
93        let client = Gemini::with_google_cloud_model(
94            api_key.into(),
95            project_id,
96            location,
97            model_name.clone(),
98        )
99        .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
100
101        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
102    }
103
104    /// Create a Gemini model via Vertex AI with service account JSON.
105    ///
106    /// Requires `gemini-vertex` feature.
107    #[cfg(feature = "gemini-vertex")]
108    pub fn new_google_cloud_service_account(
109        service_account_json: &str,
110        project_id: impl AsRef<str>,
111        location: impl AsRef<str>,
112        model: impl Into<String>,
113    ) -> Result<Self> {
114        let model_name = model.into();
115        let client = Gemini::with_google_cloud_service_account_json(
116            service_account_json,
117            project_id.as_ref(),
118            location.as_ref(),
119            model_name.clone(),
120        )
121        .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
122
123        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
124    }
125
126    /// Create a Gemini model via Vertex AI with Application Default Credentials.
127    ///
128    /// Requires `gemini-vertex` feature.
129    #[cfg(feature = "gemini-vertex")]
130    pub fn new_google_cloud_adc(
131        project_id: impl AsRef<str>,
132        location: impl AsRef<str>,
133        model: impl Into<String>,
134    ) -> Result<Self> {
135        let model_name = model.into();
136        let client = Gemini::with_google_cloud_adc_model(
137            project_id.as_ref(),
138            location.as_ref(),
139            model_name.clone(),
140        )
141        .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
142
143        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
144    }
145
146    /// Create a Gemini model via Vertex AI with Workload Identity Federation.
147    ///
148    /// Requires `gemini-vertex` feature.
149    #[cfg(feature = "gemini-vertex")]
150    pub fn new_google_cloud_wif(
151        wif_json: &str,
152        project_id: impl AsRef<str>,
153        location: impl AsRef<str>,
154        model: impl Into<String>,
155    ) -> Result<Self> {
156        let model_name = model.into();
157        let client = Gemini::with_google_cloud_wif_json(
158            wif_json,
159            project_id.as_ref(),
160            location.as_ref(),
161            model_name.clone(),
162        )
163        .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
164
165        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
166    }
167
168    #[must_use]
169    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
170        self.retry_config = retry_config;
171        self
172    }
173
174    pub fn set_retry_config(&mut self, retry_config: RetryConfig) {
175        self.retry_config = retry_config;
176    }
177
178    pub fn retry_config(&self) -> &RetryConfig {
179        &self.retry_config
180    }
181
182    fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
183        let mut converted_parts: Vec<Part> = Vec::new();
184
185        // Convert content parts
186        if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
187            for p in parts {
188                match p {
189                    adk_gemini::Part::Text { text, thought, thought_signature } => {
190                        if thought == &Some(true) {
191                            converted_parts.push(Part::Thinking {
192                                thinking: text.clone(),
193                                signature: thought_signature.clone(),
194                            });
195                        } else {
196                            converted_parts.push(Part::Text { text: text.clone() });
197                        }
198                    }
199                    adk_gemini::Part::InlineData { inline_data } => {
200                        let decoded =
201                            BASE64_STANDARD.decode(&inline_data.data).map_err(|error| {
202                                adk_core::AdkError::model(format!(
203                                    "failed to decode inline data from gemini response: {error}"
204                                ))
205                            })?;
206                        converted_parts.push(Part::InlineData {
207                            mime_type: inline_data.mime_type.clone(),
208                            data: decoded,
209                        });
210                    }
211                    adk_gemini::Part::FunctionCall { function_call, thought_signature } => {
212                        converted_parts.push(Part::FunctionCall {
213                            name: function_call.name.clone(),
214                            args: function_call.args.clone(),
215                            id: None,
216                            thought_signature: thought_signature.clone(),
217                        });
218                    }
219                    adk_gemini::Part::FunctionResponse { function_response, .. } => {
220                        converted_parts.push(Part::FunctionResponse {
221                            function_response: adk_core::FunctionResponseData {
222                                name: function_response.name.clone(),
223                                response: function_response
224                                    .response
225                                    .clone()
226                                    .unwrap_or(serde_json::Value::Null),
227                            },
228                            id: None,
229                        });
230                    }
231                    adk_gemini::Part::ToolCall { .. } | adk_gemini::Part::ExecutableCode { .. } => {
232                        if let Ok(value) = serde_json::to_value(p) {
233                            converted_parts.push(Part::ServerToolCall { server_tool_call: value });
234                        }
235                    }
236                    adk_gemini::Part::ToolResponse { .. }
237                    | adk_gemini::Part::CodeExecutionResult { .. } => {
238                        let value = serde_json::to_value(p).unwrap_or(serde_json::Value::Null);
239                        converted_parts
240                            .push(Part::ServerToolResponse { server_tool_response: value });
241                    }
242                }
243            }
244        }
245
246        // Add grounding metadata as text if present (required for Google Search grounding compliance)
247        if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
248        {
249            if let Some(queries) = &grounding.web_search_queries {
250                if !queries.is_empty() {
251                    let search_info = format!("\n\nšŸ” **Searched:** {}", queries.join(", "));
252                    converted_parts.push(Part::Text { text: search_info });
253                }
254            }
255            if let Some(chunks) = &grounding.grounding_chunks {
256                let sources: Vec<String> = chunks
257                    .iter()
258                    .filter_map(|c| {
259                        c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
260                            (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
261                            (Some(title), None) => Some(title.clone()),
262                            (None, Some(uri)) => Some(uri.to_string()),
263                            (None, None) => None,
264                        })
265                    })
266                    .collect();
267                if !sources.is_empty() {
268                    let sources_info = format!("\nšŸ“š **Sources:** {}", sources.join(" | "));
269                    converted_parts.push(Part::Text { text: sources_info });
270                }
271            }
272        }
273
274        let content = if converted_parts.is_empty() {
275            None
276        } else {
277            Some(Content { role: "model".to_string(), parts: converted_parts })
278        };
279
280        let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
281            prompt_token_count: u.prompt_token_count.unwrap_or(0),
282            candidates_token_count: u.candidates_token_count.unwrap_or(0),
283            total_token_count: u.total_token_count.unwrap_or(0),
284            thinking_token_count: u.thoughts_token_count,
285            cache_read_input_token_count: u.cached_content_token_count,
286            ..Default::default()
287        });
288
289        let finish_reason =
290            resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
291                adk_gemini::FinishReason::Stop => FinishReason::Stop,
292                adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
293                adk_gemini::FinishReason::Safety => FinishReason::Safety,
294                adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
295                _ => FinishReason::Other,
296            });
297
298        let citation_metadata =
299            resp.candidates.first().and_then(|c| c.citation_metadata.as_ref()).map(|meta| {
300                CitationMetadata {
301                    citation_sources: meta
302                        .citation_sources
303                        .iter()
304                        .map(|source| CitationSource {
305                            uri: source.uri.clone(),
306                            title: source.title.clone(),
307                            start_index: source.start_index,
308                            end_index: source.end_index,
309                            license: source.license.clone(),
310                            publication_date: source.publication_date.map(|d| d.to_string()),
311                        })
312                        .collect(),
313                }
314            });
315
316        // Serialize grounding metadata into provider_metadata so consumers
317        // can access structured grounding data (search queries, sources, supports).
318        let provider_metadata = resp
319            .candidates
320            .first()
321            .and_then(|c| c.grounding_metadata.as_ref())
322            .and_then(|g| serde_json::to_value(g).ok());
323
324        Ok(LlmResponse {
325            content,
326            usage_metadata,
327            finish_reason,
328            citation_metadata,
329            partial: false,
330            turn_complete: true,
331            interrupted: false,
332            error_code: None,
333            error_message: None,
334            provider_metadata,
335        })
336    }
337
338    fn gemini_function_response_payload(response: serde_json::Value) -> serde_json::Value {
339        match response {
340            // Gemini functionResponse.response must be a JSON object.
341            serde_json::Value::Object(_) => response,
342            other => serde_json::json!({ "result": other }),
343        }
344    }
345
346    fn merge_object_value(
347        target: &mut serde_json::Map<String, serde_json::Value>,
348        value: serde_json::Value,
349    ) {
350        if let serde_json::Value::Object(object) = value {
351            for (key, value) in object {
352                target.insert(key, value);
353            }
354        }
355    }
356
357    fn build_gemini_tools(
358        tools: &std::collections::HashMap<String, serde_json::Value>,
359    ) -> Result<(Vec<adk_gemini::Tool>, adk_gemini::ToolConfig)> {
360        let mut gemini_tools = Vec::new();
361        let mut function_declarations = Vec::new();
362        let mut has_provider_native_tools = false;
363        let mut tool_config_json = serde_json::Map::new();
364
365        for (name, tool_decl) in tools {
366            if let Some(provider_tool) = tool_decl.get("x-adk-gemini-tool") {
367                let tool = serde_json::from_value::<adk_gemini::Tool>(provider_tool.clone())
368                    .map_err(|error| {
369                        adk_core::AdkError::model(format!(
370                            "failed to deserialize Gemini native tool '{name}': {error}"
371                        ))
372                    })?;
373                has_provider_native_tools = true;
374                gemini_tools.push(tool);
375            } else if let Ok(func_decl) =
376                serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
377            {
378                function_declarations.push(func_decl);
379            } else {
380                return Err(adk_core::AdkError::model(format!(
381                    "failed to deserialize Gemini tool '{name}' as a function declaration"
382                )));
383            }
384
385            if let Some(tool_config) = tool_decl.get("x-adk-gemini-tool-config") {
386                Self::merge_object_value(&mut tool_config_json, tool_config.clone());
387            }
388        }
389
390        let has_function_declarations = !function_declarations.is_empty();
391        if has_function_declarations {
392            gemini_tools.push(adk_gemini::Tool::with_functions(function_declarations));
393        }
394
395        if has_provider_native_tools && has_function_declarations {
396            tool_config_json.insert(
397                "includeServerSideToolInvocations".to_string(),
398                serde_json::Value::Bool(true),
399            );
400        }
401
402        let tool_config = if tool_config_json.is_empty() {
403            adk_gemini::ToolConfig::default()
404        } else {
405            serde_json::from_value::<adk_gemini::ToolConfig>(serde_json::Value::Object(
406                tool_config_json,
407            ))
408            .map_err(|error| {
409                adk_core::AdkError::model(format!(
410                    "failed to deserialize Gemini tool configuration: {error}"
411                ))
412            })?
413        };
414
415        Ok((gemini_tools, tool_config))
416    }
417
418    fn stream_chunks_from_response(
419        mut response: LlmResponse,
420        saw_partial_chunk: bool,
421    ) -> (Vec<LlmResponse>, bool) {
422        let is_final = response.finish_reason.is_some();
423
424        if !is_final {
425            response.partial = true;
426            response.turn_complete = false;
427            return (vec![response], true);
428        }
429
430        response.partial = false;
431        response.turn_complete = true;
432
433        if saw_partial_chunk {
434            return (vec![response], true);
435        }
436
437        let synthetic_partial = LlmResponse {
438            content: None,
439            usage_metadata: None,
440            finish_reason: None,
441            citation_metadata: None,
442            partial: true,
443            turn_complete: false,
444            interrupted: false,
445            error_code: None,
446            error_message: None,
447            provider_metadata: None,
448        };
449
450        (vec![synthetic_partial, response], true)
451    }
452
453    async fn generate_content_internal(
454        &self,
455        req: LlmRequest,
456        stream: bool,
457    ) -> Result<LlmResponseStream> {
458        let mut builder = self.client.generate_content();
459
460        // Build a map of function_name → thought_signature from FunctionCall parts
461        // in model content. Gemini 3.x requires thought_signature on FunctionResponse
462        // parts when thinking is active, but adk_core::Part::FunctionResponse doesn't
463        // carry it (it's Gemini-specific). We recover it here at the provider boundary.
464        let mut fn_call_signatures: std::collections::HashMap<String, String> =
465            std::collections::HashMap::new();
466        for content in &req.contents {
467            if content.role == "model" {
468                for part in &content.parts {
469                    if let Part::FunctionCall { name, thought_signature: Some(sig), .. } = part {
470                        fn_call_signatures.insert(name.clone(), sig.clone());
471                    }
472                }
473            }
474        }
475
476        // Add contents using proper builder methods
477        for content in &req.contents {
478            match content.role.as_str() {
479                "user" => {
480                    // For user messages, build gemini Content with potentially multiple parts
481                    let mut gemini_parts = Vec::new();
482                    for part in &content.parts {
483                        match part {
484                            Part::Text { text } => {
485                                gemini_parts.push(adk_gemini::Part::Text {
486                                    text: text.clone(),
487                                    thought: None,
488                                    thought_signature: None,
489                                });
490                            }
491                            Part::Thinking { thinking, signature } => {
492                                gemini_parts.push(adk_gemini::Part::Text {
493                                    text: thinking.clone(),
494                                    thought: Some(true),
495                                    thought_signature: signature.clone(),
496                                });
497                            }
498                            Part::InlineData { data, mime_type } => {
499                                let encoded = attachment::encode_base64(data);
500                                gemini_parts.push(adk_gemini::Part::InlineData {
501                                    inline_data: adk_gemini::Blob {
502                                        mime_type: mime_type.clone(),
503                                        data: encoded,
504                                    },
505                                });
506                            }
507                            Part::FileData { mime_type, file_uri } => {
508                                gemini_parts.push(adk_gemini::Part::Text {
509                                    text: attachment::file_attachment_to_text(mime_type, file_uri),
510                                    thought: None,
511                                    thought_signature: None,
512                                });
513                            }
514                            _ => {}
515                        }
516                    }
517                    if !gemini_parts.is_empty() {
518                        let user_content = adk_gemini::Content {
519                            role: Some(adk_gemini::Role::User),
520                            parts: Some(gemini_parts),
521                        };
522                        builder = builder.with_message(adk_gemini::Message {
523                            content: user_content,
524                            role: adk_gemini::Role::User,
525                        });
526                    }
527                }
528                "model" => {
529                    // For model messages, build gemini Content
530                    let mut gemini_parts = Vec::new();
531                    for part in &content.parts {
532                        match part {
533                            Part::Text { text } => {
534                                gemini_parts.push(adk_gemini::Part::Text {
535                                    text: text.clone(),
536                                    thought: None,
537                                    thought_signature: None,
538                                });
539                            }
540                            Part::Thinking { thinking, signature } => {
541                                gemini_parts.push(adk_gemini::Part::Text {
542                                    text: thinking.clone(),
543                                    thought: Some(true),
544                                    thought_signature: signature.clone(),
545                                });
546                            }
547                            Part::FunctionCall { name, args, thought_signature, .. } => {
548                                gemini_parts.push(adk_gemini::Part::FunctionCall {
549                                    function_call: adk_gemini::FunctionCall {
550                                        name: name.clone(),
551                                        args: args.clone(),
552                                        thought_signature: None,
553                                    },
554                                    thought_signature: thought_signature.clone(),
555                                });
556                            }
557                            Part::ServerToolCall { server_tool_call } => {
558                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
559                                    server_tool_call.clone(),
560                                ) {
561                                    match native_part {
562                                        adk_gemini::Part::ToolCall { .. }
563                                        | adk_gemini::Part::ExecutableCode { .. } => {
564                                            gemini_parts.push(native_part);
565                                            continue;
566                                        }
567                                        _ => {}
568                                    }
569                                }
570
571                                gemini_parts.push(adk_gemini::Part::ToolCall {
572                                    tool_call: server_tool_call.clone(),
573                                    thought_signature: Self::gemini_part_thought_signature(
574                                        server_tool_call,
575                                    ),
576                                });
577                            }
578                            Part::ServerToolResponse { server_tool_response } => {
579                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
580                                    server_tool_response.clone(),
581                                ) {
582                                    match native_part {
583                                        adk_gemini::Part::ToolResponse { .. }
584                                        | adk_gemini::Part::CodeExecutionResult { .. } => {
585                                            gemini_parts.push(native_part);
586                                            continue;
587                                        }
588                                        _ => {}
589                                    }
590                                }
591
592                                gemini_parts.push(adk_gemini::Part::ToolResponse {
593                                    tool_response: server_tool_response.clone(),
594                                    thought_signature: Self::gemini_part_thought_signature(
595                                        server_tool_response,
596                                    ),
597                                });
598                            }
599                            _ => {}
600                        }
601                    }
602                    if !gemini_parts.is_empty() {
603                        let model_content = adk_gemini::Content {
604                            role: Some(adk_gemini::Role::Model),
605                            parts: Some(gemini_parts),
606                        };
607                        builder = builder.with_message(adk_gemini::Message {
608                            content: model_content,
609                            role: adk_gemini::Role::Model,
610                        });
611                    }
612                }
613                "function" => {
614                    // For function responses, build content directly to attach thought_signature
615                    // recovered from the preceding FunctionCall (Gemini 3.x requirement)
616                    let mut gemini_parts = Vec::new();
617                    for part in &content.parts {
618                        if let Part::FunctionResponse { function_response, .. } = part {
619                            let sig = fn_call_signatures.get(&function_response.name).cloned();
620                            gemini_parts.push(adk_gemini::Part::FunctionResponse {
621                                function_response: adk_gemini::tools::FunctionResponse::new(
622                                    &function_response.name,
623                                    Self::gemini_function_response_payload(
624                                        function_response.response.clone(),
625                                    ),
626                                ),
627                                thought_signature: sig,
628                            });
629                        }
630                    }
631                    if !gemini_parts.is_empty() {
632                        let fn_content = adk_gemini::Content {
633                            role: Some(adk_gemini::Role::User),
634                            parts: Some(gemini_parts),
635                        };
636                        builder = builder.with_message(adk_gemini::Message {
637                            content: fn_content,
638                            role: adk_gemini::Role::User,
639                        });
640                    }
641                }
642                _ => {}
643            }
644        }
645
646        // Add generation config
647        if let Some(config) = req.config {
648            let has_schema = config.response_schema.is_some();
649            let gen_config = adk_gemini::GenerationConfig {
650                temperature: config.temperature,
651                top_p: config.top_p,
652                top_k: config.top_k,
653                max_output_tokens: config.max_output_tokens,
654                response_schema: config.response_schema,
655                response_mime_type: if has_schema {
656                    Some("application/json".to_string())
657                } else {
658                    None
659                },
660                ..Default::default()
661            };
662            builder = builder.with_generation_config(gen_config);
663
664            // Attach cached content reference if provided
665            if let Some(ref name) = config.cached_content {
666                let handle = self.client.get_cached_content(name);
667                builder = builder.with_cached_content(&handle);
668            }
669        }
670
671        // Add tools
672        if !req.tools.is_empty() {
673            let (gemini_tools, tool_config) = Self::build_gemini_tools(&req.tools)?;
674            for tool in gemini_tools {
675                builder = builder.with_tool(tool);
676            }
677            if tool_config != adk_gemini::ToolConfig::default() {
678                builder = builder.with_tool_config(tool_config);
679            }
680        }
681
682        if stream {
683            adk_telemetry::debug!("Executing streaming request");
684            let response_stream = builder.execute_stream().await.map_err(|e| {
685                adk_telemetry::error!(error = %e, "Model request failed");
686                gemini_error_to_adk(&e)
687            })?;
688
689            let mapped_stream = async_stream::stream! {
690                let mut stream = response_stream;
691                let mut saw_partial_chunk = false;
692                while let Some(result) = stream.try_next().await.transpose() {
693                    match result {
694                        Ok(resp) => {
695                            match Self::convert_response(&resp) {
696                                Ok(llm_resp) => {
697                                    let (chunks, next_saw_partial) =
698                                        Self::stream_chunks_from_response(llm_resp, saw_partial_chunk);
699                                    saw_partial_chunk = next_saw_partial;
700                                    for chunk in chunks {
701                                        yield Ok(chunk);
702                                    }
703                                }
704                                Err(e) => {
705                                    adk_telemetry::error!(error = %e, "Failed to convert response");
706                                    yield Err(e);
707                                }
708                            }
709                        }
710                        Err(e) => {
711                            adk_telemetry::error!(error = %e, "Stream error");
712                            yield Err(gemini_error_to_adk(&e));
713                        }
714                    }
715                }
716            };
717
718            Ok(Box::pin(mapped_stream))
719        } else {
720            adk_telemetry::debug!("Executing blocking request");
721            let response = builder.execute().await.map_err(|e| {
722                adk_telemetry::error!(error = %e, "Model request failed");
723                gemini_error_to_adk(&e)
724            })?;
725
726            let llm_response = Self::convert_response(&response)?;
727
728            let stream = async_stream::stream! {
729                yield Ok(llm_response);
730            };
731
732            Ok(Box::pin(stream))
733        }
734    }
735
736    /// Create a cached content resource with the given system instruction, tools, and TTL.
737    ///
738    /// Returns the cache name (e.g., "cachedContents/abc123") on success.
739    /// The cache is created using the model configured on this `GeminiModel` instance.
740    pub async fn create_cached_content(
741        &self,
742        system_instruction: &str,
743        tools: &std::collections::HashMap<String, serde_json::Value>,
744        ttl_seconds: u32,
745    ) -> Result<String> {
746        let mut cache_builder = self
747            .client
748            .create_cache()
749            .with_system_instruction(system_instruction)
750            .with_ttl(std::time::Duration::from_secs(u64::from(ttl_seconds)));
751
752        let (gemini_tools, tool_config) = Self::build_gemini_tools(tools)?;
753        if !gemini_tools.is_empty() {
754            cache_builder = cache_builder.with_tools(gemini_tools);
755        }
756        if tool_config != adk_gemini::ToolConfig::default() {
757            cache_builder = cache_builder.with_tool_config(tool_config);
758        }
759
760        let handle = cache_builder
761            .execute()
762            .await
763            .map_err(|e| adk_core::AdkError::model(format!("cache creation failed: {e}")))?;
764
765        Ok(handle.name().to_string())
766    }
767
768    /// Delete a cached content resource by name.
769    pub async fn delete_cached_content(&self, name: &str) -> Result<()> {
770        let handle = self.client.get_cached_content(name);
771        handle
772            .delete()
773            .await
774            .map_err(|(_, e)| adk_core::AdkError::model(format!("cache deletion failed: {e}")))?;
775        Ok(())
776    }
777}
778
779#[async_trait]
780impl Llm for GeminiModel {
781    fn name(&self) -> &str {
782        &self.model_name
783    }
784
785    #[adk_telemetry::instrument(
786        name = "call_llm",
787        skip(self, req),
788        fields(
789            model.name = %self.model_name,
790            stream = %stream,
791            request.contents_count = %req.contents.len(),
792            request.tools_count = %req.tools.len()
793        )
794    )]
795    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
796        adk_telemetry::info!("Generating content");
797        let usage_span = adk_telemetry::llm_generate_span("gemini", &self.model_name, stream);
798        // Retries only cover request setup/execution. Stream failures after the stream starts
799        // are yielded to the caller and are not replayed automatically.
800        let result = execute_with_retry(&self.retry_config, is_retryable_model_error, || {
801            self.generate_content_internal(req.clone(), stream)
802        })
803        .await?;
804        Ok(crate::usage_tracking::with_usage_tracking(result, usage_span))
805    }
806}
807
#[cfg(test)]
mod native_tool_tests {
    use super::*;
    use std::collections::HashMap;

    /// A native Google Search tool plus a regular function declaration should convert into
    /// two Gemini tools and enable server-side tool invocations in the tool config.
    #[test]
    fn test_build_gemini_tools_supports_native_tool_metadata() {
        let search_tool = serde_json::json!({
            "x-adk-gemini-tool": {
                "google_search": {}
            }
        });
        let weather_tool = serde_json::json!({
            "name": "lookup_weather",
            "description": "lookup weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": { "type": "string" }
                }
            }
        });
        let tools: HashMap<String, serde_json::Value> = HashMap::from([
            ("google_search".to_string(), search_tool),
            ("lookup_weather".to_string(), weather_tool),
        ]);

        let (gemini_tools, tool_config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        assert_eq!(gemini_tools.len(), 2);
        assert_eq!(tool_config.include_server_side_tool_invocations, Some(true));
    }

    /// The `retrievalConfig` supplied under `x-adk-gemini-tool-config` should surface
    /// unchanged as the resulting `ToolConfig::retrieval_config`.
    #[test]
    fn test_build_gemini_tools_merges_native_tool_config() {
        let expected_retrieval = serde_json::json!({
            "latLng": {
                "latitude": 1.23,
                "longitude": 4.56
            }
        });
        let maps_tool = serde_json::json!({
            "x-adk-gemini-tool": {
                "google_maps": {
                    "enable_widget": true
                }
            },
            "x-adk-gemini-tool-config": {
                "retrievalConfig": expected_retrieval.clone()
            }
        });
        let tools: HashMap<String, serde_json::Value> =
            HashMap::from([("google_maps".to_string(), maps_tool)]);

        let (_gemini_tools, tool_config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        assert_eq!(tool_config.retrieval_config, Some(expected_retrieval));
    }
}
880
881#[async_trait]
882impl CacheCapable for GeminiModel {
883    async fn create_cache(
884        &self,
885        system_instruction: &str,
886        tools: &std::collections::HashMap<String, serde_json::Value>,
887        ttl_seconds: u32,
888    ) -> Result<String> {
889        self.create_cached_content(system_instruction, tools, ttl_seconds).await
890    }
891
892    async fn delete_cache(&self, name: &str) -> Result<()> {
893        self.delete_cached_content(name).await
894    }
895}
896
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::AdkError;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Zeroes the initial and maximum backoff so retry tests never sleep.
    fn without_backoff(config: RetryConfig) -> RetryConfig {
        config
            .with_initial_delay(Duration::from_millis(0))
            .with_max_delay(Duration::from_millis(0))
    }

    /// Builds a non-partial, turn-completing model response carrying `text` with a `Stop`
    /// finish reason. This is the shape of a final streaming chunk.
    fn final_stop_response(text: &str) -> LlmResponse {
        LlmResponse {
            content: Some(Content::new("model").with_text(text)),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        }
    }

    /// Wraps `parts` in a single-candidate Gemini response with role `Model`, finish reason
    /// `Stop`, index 0, and the given citation metadata. Every other metadata field is empty.
    fn single_candidate_response(
        parts: Vec<adk_gemini::Part>,
        citation_metadata: Option<adk_gemini::CitationMetadata>,
    ) -> adk_gemini::GenerationResponse {
        adk_gemini::GenerationResponse {
            candidates: vec![adk_gemini::Candidate {
                content: adk_gemini::Content {
                    role: Some(adk_gemini::Role::Model),
                    parts: Some(parts),
                },
                safety_ratings: None,
                citation_metadata,
                grounding_metadata: None,
                finish_reason: Some(adk_gemini::FinishReason::Stop),
                index: Some(0),
            }],
            prompt_feedback: None,
            usage_metadata: None,
            model_version: None,
            response_id: None,
        }
    }

    #[test]
    fn constructor_is_backward_compatible_and_sync() {
        // Compile-time check: `GeminiModel::new` must stay usable as a plain synchronous
        // `Fn(&str, &str) -> Result<GeminiModel>`.
        fn accepts_sync_constructor<F>(_f: F)
        where
            F: Fn(&str, &str) -> Result<GeminiModel>,
        {
        }

        accepts_sync_constructor(|key, name| GeminiModel::new(key, name));
    }

    #[test]
    fn stream_chunks_from_response_injects_partial_before_lone_final_chunk() {
        let (chunks, saw_partial) =
            GeminiModel::stream_chunks_from_response(final_stop_response("hello"), false);

        assert!(saw_partial);
        assert_eq!(chunks.len(), 2);

        let (injected, last) = (&chunks[0], &chunks[1]);
        assert!(injected.partial);
        assert!(!injected.turn_complete);
        assert!(injected.content.is_none());
        assert!(!last.partial);
        assert!(last.turn_complete);
    }

    #[test]
    fn stream_chunks_from_response_keeps_final_only_when_partial_already_seen() {
        let (chunks, saw_partial) =
            GeminiModel::stream_chunks_from_response(final_stop_response("done"), true);

        assert!(saw_partial);
        assert_eq!(chunks.len(), 1);

        let only = &chunks[0];
        assert!(!only.partial);
        assert!(only.turn_complete);
    }

    #[tokio::test]
    async fn execute_with_retry_retries_retryable_errors() {
        let retry_config = without_backoff(RetryConfig::default().with_max_retries(2));
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let counter = Arc::clone(&attempts);
            async move {
                // The first two attempts hit a retryable rate-limit error; the third succeeds.
                let attempt_index = counter.fetch_add(1, Ordering::SeqCst);
                if attempt_index >= 2 {
                    Ok("ok")
                } else {
                    Err(AdkError::model("code 429 RESOURCE_EXHAUSTED"))
                }
            }
        })
        .await
        .expect("retry should eventually succeed");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn execute_with_retry_does_not_retry_non_retryable_errors() {
        let retry_config = without_backoff(RetryConfig::default().with_max_retries(3));
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let counter = Arc::clone(&attempts);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("code 400 invalid request"))
            }
        })
        .await
        .expect_err("non-retryable error should return immediately");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn execute_with_retry_respects_disabled_config() {
        // A disabled config must make exactly one attempt, even with a retry budget set later.
        let retry_config = RetryConfig::disabled().with_max_retries(10);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let counter = Arc::clone(&attempts);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("code 429 RESOURCE_EXHAUSTED"))
            }
        })
        .await
        .expect_err("disabled retries should return first error");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn convert_response_preserves_citation_metadata() {
        let citation = adk_gemini::CitationMetadata {
            citation_sources: vec![adk_gemini::CitationSource {
                uri: Some("https://example.com".to_string()),
                title: Some("Example".to_string()),
                start_index: Some(0),
                end_index: Some(5),
                license: Some("CC-BY".to_string()),
                publication_date: None,
            }],
        };
        let text_part = adk_gemini::Part::Text {
            text: "hello world".to_string(),
            thought: None,
            thought_signature: None,
        };
        let response = single_candidate_response(vec![text_part], Some(citation));

        let converted =
            GeminiModel::convert_response(&response).expect("conversion should succeed");
        let metadata = converted.citation_metadata.expect("citation metadata should be mapped");

        let sources = &metadata.citation_sources;
        assert_eq!(sources.len(), 1);
        assert_eq!(sources[0].uri.as_deref(), Some("https://example.com"));
        assert_eq!(sources[0].start_index, Some(0));
        assert_eq!(sources[0].end_index, Some(5));
    }

    #[test]
    fn convert_response_handles_inline_data_from_model() {
        let image_bytes = vec![0x89, 0x50, 0x4E, 0x47];
        let parts = vec![
            adk_gemini::Part::Text {
                text: "Here is the image".to_string(),
                thought: None,
                thought_signature: None,
            },
            adk_gemini::Part::InlineData {
                inline_data: adk_gemini::Blob {
                    mime_type: "image/png".to_string(),
                    data: crate::attachment::encode_base64(&image_bytes),
                },
            },
        ];
        let response = single_candidate_response(parts, None);

        let converted =
            GeminiModel::convert_response(&response).expect("conversion should succeed");
        let content = converted.content.expect("should have content");

        let has_text = content
            .parts
            .iter()
            .any(|part| matches!(part, Part::Text { text } if text == "Here is the image"));
        assert!(has_text);

        let has_image = content.parts.iter().any(|part| {
            matches!(
                part,
                Part::InlineData { mime_type, data }
                    if mime_type == "image/png" && data.as_slice() == image_bytes.as_slice()
            )
        });
        assert!(has_image);
    }

    #[test]
    fn gemini_function_response_payload_preserves_objects() {
        let value = serde_json::json!({
            "documents": [
                { "id": "pricing", "score": 0.91 }
            ]
        });

        assert_eq!(GeminiModel::gemini_function_response_payload(value.clone()), value);
    }

    #[test]
    fn gemini_function_response_payload_wraps_arrays() {
        let documents = serde_json::json!([{ "id": "pricing" }]);

        let payload = GeminiModel::gemini_function_response_payload(documents);

        assert_eq!(payload, serde_json::json!({ "result": [{ "id": "pricing" }] }));
    }
}