Skip to main content

adk_model/gemini/
client.rs

1use crate::attachment;
2use crate::retry::{RetryConfig, execute_with_retry, is_retryable_model_error};
3use adk_core::{
4    CacheCapable, CitationMetadata, CitationSource, Content, ErrorCategory, ErrorComponent,
5    FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result, UsageMetadata,
6};
7use adk_gemini::Gemini;
8use async_trait::async_trait;
9use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
10use futures::TryStreamExt;
11
12pub struct GeminiModel {
13    client: Gemini,
14    model_name: String,
15    retry_config: RetryConfig,
16}
17
18/// Convert a Gemini client error to a structured `AdkError` with proper category and retry hints.
19fn gemini_error_to_adk(e: &adk_gemini::ClientError) -> adk_core::AdkError {
20    fn format_error_chain(e: &dyn std::error::Error) -> String {
21        let mut msg = e.to_string();
22        let mut source = e.source();
23        while let Some(s) = source {
24            msg.push_str(": ");
25            msg.push_str(&s.to_string());
26            source = s.source();
27        }
28        msg
29    }
30
31    let message = format_error_chain(e);
32
33    // Extract status code from BadResponse variant via Display output
34    // BadResponse format: "bad response from server; code {code}; description: ..."
35    let (category, code, status_code) = if message.contains("code 429")
36        || message.contains("RESOURCE_EXHAUSTED")
37        || message.contains("rate limit")
38    {
39        (ErrorCategory::RateLimited, "model.gemini.rate_limited", Some(429u16))
40    } else if message.contains("code 503") || message.contains("UNAVAILABLE") {
41        (ErrorCategory::Unavailable, "model.gemini.unavailable", Some(503))
42    } else if message.contains("code 529") || message.contains("OVERLOADED") {
43        (ErrorCategory::Unavailable, "model.gemini.overloaded", Some(529))
44    } else if message.contains("code 408")
45        || message.contains("DEADLINE_EXCEEDED")
46        || message.contains("TIMEOUT")
47    {
48        (ErrorCategory::Timeout, "model.gemini.timeout", Some(408))
49    } else if message.contains("code 401") || message.contains("Invalid API key") {
50        (ErrorCategory::Unauthorized, "model.gemini.unauthorized", Some(401))
51    } else if message.contains("code 400") {
52        (ErrorCategory::InvalidInput, "model.gemini.bad_request", Some(400))
53    } else if message.contains("code 404") {
54        (ErrorCategory::NotFound, "model.gemini.not_found", Some(404))
55    } else if message.contains("invalid generation config") {
56        (ErrorCategory::InvalidInput, "model.gemini.invalid_config", None)
57    } else {
58        (ErrorCategory::Internal, "model.gemini.internal", None)
59    };
60
61    let mut err = adk_core::AdkError::new(ErrorComponent::Model, category, code, message)
62        .with_provider("gemini");
63    if let Some(sc) = status_code {
64        err = err.with_upstream_status(sc);
65    }
66    err
67}
68
69impl GeminiModel {
70    fn gemini_part_thought_signature(value: &serde_json::Value) -> Option<String> {
71        value.get("thoughtSignature").and_then(serde_json::Value::as_str).map(str::to_string)
72    }
73
74    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
75        let model_name = model.into();
76        let client = Gemini::with_model(api_key.into(), model_name.clone())
77            .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
78
79        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
80    }
81
82    /// Create a Gemini model via Vertex AI with API key auth.
83    ///
84    /// Requires `gemini-vertex` feature.
85    #[cfg(feature = "gemini-vertex")]
86    pub fn new_google_cloud(
87        api_key: impl Into<String>,
88        project_id: impl AsRef<str>,
89        location: impl AsRef<str>,
90        model: impl Into<String>,
91    ) -> Result<Self> {
92        let model_name = model.into();
93        let client = Gemini::with_google_cloud_model(
94            api_key.into(),
95            project_id,
96            location,
97            model_name.clone(),
98        )
99        .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
100
101        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
102    }
103
104    /// Create a Gemini model via Vertex AI with service account JSON.
105    ///
106    /// Requires `gemini-vertex` feature.
107    #[cfg(feature = "gemini-vertex")]
108    pub fn new_google_cloud_service_account(
109        service_account_json: &str,
110        project_id: impl AsRef<str>,
111        location: impl AsRef<str>,
112        model: impl Into<String>,
113    ) -> Result<Self> {
114        let model_name = model.into();
115        let client = Gemini::with_google_cloud_service_account_json(
116            service_account_json,
117            project_id.as_ref(),
118            location.as_ref(),
119            model_name.clone(),
120        )
121        .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
122
123        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
124    }
125
126    /// Create a Gemini model via Vertex AI with Application Default Credentials.
127    ///
128    /// Requires `gemini-vertex` feature.
129    #[cfg(feature = "gemini-vertex")]
130    pub fn new_google_cloud_adc(
131        project_id: impl AsRef<str>,
132        location: impl AsRef<str>,
133        model: impl Into<String>,
134    ) -> Result<Self> {
135        let model_name = model.into();
136        let client = Gemini::with_google_cloud_adc_model(
137            project_id.as_ref(),
138            location.as_ref(),
139            model_name.clone(),
140        )
141        .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
142
143        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
144    }
145
146    /// Create a Gemini model via Vertex AI with Workload Identity Federation.
147    ///
148    /// Requires `gemini-vertex` feature.
149    #[cfg(feature = "gemini-vertex")]
150    pub fn new_google_cloud_wif(
151        wif_json: &str,
152        project_id: impl AsRef<str>,
153        location: impl AsRef<str>,
154        model: impl Into<String>,
155    ) -> Result<Self> {
156        let model_name = model.into();
157        let client = Gemini::with_google_cloud_wif_json(
158            wif_json,
159            project_id.as_ref(),
160            location.as_ref(),
161            model_name.clone(),
162        )
163        .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
164
165        Ok(Self { client, model_name, retry_config: RetryConfig::default() })
166    }
167
168    #[must_use]
169    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
170        self.retry_config = retry_config;
171        self
172    }
173
174    pub fn set_retry_config(&mut self, retry_config: RetryConfig) {
175        self.retry_config = retry_config;
176    }
177
178    pub fn retry_config(&self) -> &RetryConfig {
179        &self.retry_config
180    }
181
182    fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
183        let mut converted_parts: Vec<Part> = Vec::new();
184
185        // Convert content parts
186        if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
187            for p in parts {
188                match p {
189                    adk_gemini::Part::Text { text, thought, thought_signature } => {
190                        if thought == &Some(true) {
191                            converted_parts.push(Part::Thinking {
192                                thinking: text.clone(),
193                                signature: thought_signature.clone(),
194                            });
195                        } else {
196                            converted_parts.push(Part::Text { text: text.clone() });
197                        }
198                    }
199                    adk_gemini::Part::InlineData { inline_data } => {
200                        let decoded =
201                            BASE64_STANDARD.decode(&inline_data.data).map_err(|error| {
202                                adk_core::AdkError::model(format!(
203                                    "failed to decode inline data from gemini response: {error}"
204                                ))
205                            })?;
206                        converted_parts.push(Part::InlineData {
207                            mime_type: inline_data.mime_type.clone(),
208                            data: decoded,
209                        });
210                    }
211                    adk_gemini::Part::FunctionCall { function_call, thought_signature } => {
212                        converted_parts.push(Part::FunctionCall {
213                            name: function_call.name.clone(),
214                            args: function_call.args.clone(),
215                            id: function_call.id.clone(),
216                            thought_signature: thought_signature.clone(),
217                        });
218                    }
219                    adk_gemini::Part::FunctionResponse { function_response, .. } => {
220                        converted_parts.push(Part::FunctionResponse {
221                            function_response: adk_core::FunctionResponseData::new(
222                                function_response.name.clone(),
223                                function_response
224                                    .response
225                                    .clone()
226                                    .unwrap_or(serde_json::Value::Null),
227                            ),
228                            id: None,
229                        });
230                    }
231                    adk_gemini::Part::ToolCall { .. } | adk_gemini::Part::ExecutableCode { .. } => {
232                        if let Ok(value) = serde_json::to_value(p) {
233                            converted_parts.push(Part::ServerToolCall { server_tool_call: value });
234                        }
235                    }
236                    adk_gemini::Part::ToolResponse { .. }
237                    | adk_gemini::Part::CodeExecutionResult { .. } => {
238                        let value = serde_json::to_value(p).unwrap_or(serde_json::Value::Null);
239                        converted_parts
240                            .push(Part::ServerToolResponse { server_tool_response: value });
241                    }
242                    adk_gemini::Part::FileData { file_data } => {
243                        converted_parts.push(Part::FileData {
244                            mime_type: file_data.mime_type.clone(),
245                            file_uri: file_data.file_uri.clone(),
246                        });
247                    }
248                }
249            }
250        }
251
252        // Add grounding metadata as text if present (required for Google Search grounding compliance)
253        if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
254        {
255            if let Some(queries) = &grounding.web_search_queries {
256                if !queries.is_empty() {
257                    let search_info = format!("\n\nšŸ” **Searched:** {}", queries.join(", "));
258                    converted_parts.push(Part::Text { text: search_info });
259                }
260            }
261            if let Some(chunks) = &grounding.grounding_chunks {
262                let sources: Vec<String> = chunks
263                    .iter()
264                    .filter_map(|c| {
265                        c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
266                            (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
267                            (Some(title), None) => Some(title.clone()),
268                            (None, Some(uri)) => Some(uri.to_string()),
269                            (None, None) => None,
270                        })
271                    })
272                    .collect();
273                if !sources.is_empty() {
274                    let sources_info = format!("\nšŸ“š **Sources:** {}", sources.join(" | "));
275                    converted_parts.push(Part::Text { text: sources_info });
276                }
277            }
278        }
279
280        let content = if converted_parts.is_empty() {
281            None
282        } else {
283            Some(Content { role: "model".to_string(), parts: converted_parts })
284        };
285
286        let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
287            prompt_token_count: u.prompt_token_count.unwrap_or(0),
288            candidates_token_count: u.candidates_token_count.unwrap_or(0),
289            total_token_count: u.total_token_count.unwrap_or(0),
290            thinking_token_count: u.thoughts_token_count,
291            cache_read_input_token_count: u.cached_content_token_count,
292            ..Default::default()
293        });
294
295        let finish_reason =
296            resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
297                adk_gemini::FinishReason::Stop => FinishReason::Stop,
298                adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
299                adk_gemini::FinishReason::Safety => FinishReason::Safety,
300                adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
301                _ => FinishReason::Other,
302            });
303
304        let citation_metadata =
305            resp.candidates.first().and_then(|c| c.citation_metadata.as_ref()).map(|meta| {
306                CitationMetadata {
307                    citation_sources: meta
308                        .citation_sources
309                        .iter()
310                        .map(|source| CitationSource {
311                            uri: source.uri.clone(),
312                            title: source.title.clone(),
313                            start_index: source.start_index,
314                            end_index: source.end_index,
315                            license: source.license.clone(),
316                            publication_date: source.publication_date.map(|d| d.to_string()),
317                        })
318                        .collect(),
319                }
320            });
321
322        // Serialize grounding metadata into provider_metadata so consumers
323        // can access structured grounding data (search queries, sources, supports).
324        let provider_metadata = resp
325            .candidates
326            .first()
327            .and_then(|c| c.grounding_metadata.as_ref())
328            .and_then(|g| serde_json::to_value(g).ok());
329
330        Ok(LlmResponse {
331            content,
332            usage_metadata,
333            finish_reason,
334            citation_metadata,
335            partial: false,
336            turn_complete: true,
337            interrupted: false,
338            error_code: None,
339            error_message: None,
340            provider_metadata,
341        })
342    }
343
344    fn gemini_function_response_payload(response: serde_json::Value) -> serde_json::Value {
345        match response {
346            // Gemini functionResponse.response must be a JSON object.
347            serde_json::Value::Object(_) => response,
348            other => serde_json::json!({ "result": other }),
349        }
350    }
351
352    fn merge_object_value(
353        target: &mut serde_json::Map<String, serde_json::Value>,
354        value: serde_json::Value,
355    ) {
356        if let serde_json::Value::Object(object) = value {
357            for (key, value) in object {
358                target.insert(key, value);
359            }
360        }
361    }
362
363    fn build_gemini_tools(
364        tools: &std::collections::HashMap<String, serde_json::Value>,
365    ) -> Result<(Vec<adk_gemini::Tool>, adk_gemini::ToolConfig)> {
366        let mut gemini_tools = Vec::new();
367        let mut function_declarations = Vec::new();
368        let mut has_provider_native_tools = false;
369        let mut tool_config_json = serde_json::Map::new();
370
371        for (name, tool_decl) in tools {
372            if let Some(provider_tool) = tool_decl.get("x-adk-gemini-tool") {
373                let tool = serde_json::from_value::<adk_gemini::Tool>(provider_tool.clone())
374                    .map_err(|error| {
375                        adk_core::AdkError::model(format!(
376                            "failed to deserialize Gemini native tool '{name}': {error}"
377                        ))
378                    })?;
379                has_provider_native_tools = true;
380                gemini_tools.push(tool);
381            } else if let Ok(func_decl) =
382                serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
383            {
384                function_declarations.push(func_decl);
385            } else {
386                return Err(adk_core::AdkError::model(format!(
387                    "failed to deserialize Gemini tool '{name}' as a function declaration"
388                )));
389            }
390
391            if let Some(tool_config) = tool_decl.get("x-adk-gemini-tool-config") {
392                Self::merge_object_value(&mut tool_config_json, tool_config.clone());
393            }
394        }
395
396        let has_function_declarations = !function_declarations.is_empty();
397        if has_function_declarations {
398            gemini_tools.push(adk_gemini::Tool::with_functions(function_declarations));
399        }
400
401        if has_provider_native_tools {
402            tool_config_json.insert(
403                "includeServerSideToolInvocations".to_string(),
404                serde_json::Value::Bool(true),
405            );
406        }
407
408        let tool_config = if tool_config_json.is_empty() {
409            adk_gemini::ToolConfig::default()
410        } else {
411            serde_json::from_value::<adk_gemini::ToolConfig>(serde_json::Value::Object(
412                tool_config_json,
413            ))
414            .map_err(|error| {
415                adk_core::AdkError::model(format!(
416                    "failed to deserialize Gemini tool configuration: {error}"
417                ))
418            })?
419        };
420
421        Ok((gemini_tools, tool_config))
422    }
423
424    fn stream_chunks_from_response(
425        mut response: LlmResponse,
426        saw_partial_chunk: bool,
427    ) -> (Vec<LlmResponse>, bool) {
428        let is_final = response.finish_reason.is_some();
429
430        if !is_final {
431            response.partial = true;
432            response.turn_complete = false;
433            return (vec![response], true);
434        }
435
436        response.partial = false;
437        response.turn_complete = true;
438
439        if saw_partial_chunk {
440            return (vec![response], true);
441        }
442
443        let synthetic_partial = LlmResponse {
444            content: None,
445            usage_metadata: None,
446            finish_reason: None,
447            citation_metadata: None,
448            partial: true,
449            turn_complete: false,
450            interrupted: false,
451            error_code: None,
452            error_message: None,
453            provider_metadata: None,
454        };
455
456        (vec![synthetic_partial, response], true)
457    }
458
459    async fn generate_content_internal(
460        &self,
461        req: LlmRequest,
462        stream: bool,
463    ) -> Result<LlmResponseStream> {
464        let mut builder = self.client.generate_content();
465
466        // Build a map of function_name → thought_signature from FunctionCall parts
467        // in model content. Gemini 3.x requires thought_signature on FunctionResponse
468        // parts when thinking is active, but adk_core::Part::FunctionResponse doesn't
469        // carry it (it's Gemini-specific). We recover it here at the provider boundary.
470        let mut fn_call_signatures: std::collections::HashMap<String, String> =
471            std::collections::HashMap::new();
472        for content in &req.contents {
473            if content.role == "model" {
474                for part in &content.parts {
475                    if let Part::FunctionCall { name, thought_signature: Some(sig), .. } = part {
476                        fn_call_signatures.insert(name.clone(), sig.clone());
477                    }
478                }
479            }
480        }
481
482        // Add contents using proper builder methods
483        for content in &req.contents {
484            match content.role.as_str() {
485                "user" => {
486                    // For user messages, build gemini Content with potentially multiple parts
487                    let mut gemini_parts = Vec::new();
488                    for part in &content.parts {
489                        match part {
490                            Part::Text { text } => {
491                                gemini_parts.push(adk_gemini::Part::Text {
492                                    text: text.clone(),
493                                    thought: None,
494                                    thought_signature: None,
495                                });
496                            }
497                            Part::Thinking { thinking, signature } => {
498                                gemini_parts.push(adk_gemini::Part::Text {
499                                    text: thinking.clone(),
500                                    thought: Some(true),
501                                    thought_signature: signature.clone(),
502                                });
503                            }
504                            Part::InlineData { data, mime_type } => {
505                                let encoded = attachment::encode_base64(data);
506                                gemini_parts.push(adk_gemini::Part::InlineData {
507                                    inline_data: adk_gemini::Blob {
508                                        mime_type: mime_type.clone(),
509                                        data: encoded,
510                                    },
511                                });
512                            }
513                            Part::FileData { mime_type, file_uri } => {
514                                gemini_parts.push(adk_gemini::Part::Text {
515                                    text: attachment::file_attachment_to_text(mime_type, file_uri),
516                                    thought: None,
517                                    thought_signature: None,
518                                });
519                            }
520                            _ => {}
521                        }
522                    }
523                    if !gemini_parts.is_empty() {
524                        let user_content = adk_gemini::Content {
525                            role: Some(adk_gemini::Role::User),
526                            parts: Some(gemini_parts),
527                        };
528                        builder = builder.with_message(adk_gemini::Message {
529                            content: user_content,
530                            role: adk_gemini::Role::User,
531                        });
532                    }
533                }
534                "model" => {
535                    // For model messages, build gemini Content
536                    let mut gemini_parts = Vec::new();
537                    for part in &content.parts {
538                        match part {
539                            Part::Text { text } => {
540                                gemini_parts.push(adk_gemini::Part::Text {
541                                    text: text.clone(),
542                                    thought: None,
543                                    thought_signature: None,
544                                });
545                            }
546                            Part::Thinking { thinking, signature } => {
547                                gemini_parts.push(adk_gemini::Part::Text {
548                                    text: thinking.clone(),
549                                    thought: Some(true),
550                                    thought_signature: signature.clone(),
551                                });
552                            }
553                            Part::FunctionCall { name, args, thought_signature, id } => {
554                                gemini_parts.push(adk_gemini::Part::FunctionCall {
555                                    function_call: adk_gemini::FunctionCall {
556                                        name: name.clone(),
557                                        args: args.clone(),
558                                        id: id.clone(),
559                                        thought_signature: None,
560                                    },
561                                    thought_signature: thought_signature.clone(),
562                                });
563                            }
564                            Part::ServerToolCall { server_tool_call } => {
565                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
566                                    server_tool_call.clone(),
567                                ) {
568                                    match native_part {
569                                        adk_gemini::Part::ToolCall { .. }
570                                        | adk_gemini::Part::ExecutableCode { .. } => {
571                                            gemini_parts.push(native_part);
572                                            continue;
573                                        }
574                                        _ => {}
575                                    }
576                                }
577
578                                gemini_parts.push(adk_gemini::Part::ToolCall {
579                                    tool_call: server_tool_call.clone(),
580                                    thought_signature: Self::gemini_part_thought_signature(
581                                        server_tool_call,
582                                    ),
583                                });
584                            }
585                            Part::ServerToolResponse { server_tool_response } => {
586                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
587                                    server_tool_response.clone(),
588                                ) {
589                                    match native_part {
590                                        adk_gemini::Part::ToolResponse { .. }
591                                        | adk_gemini::Part::CodeExecutionResult { .. } => {
592                                            gemini_parts.push(native_part);
593                                            continue;
594                                        }
595                                        _ => {}
596                                    }
597                                }
598
599                                gemini_parts.push(adk_gemini::Part::ToolResponse {
600                                    tool_response: server_tool_response.clone(),
601                                    thought_signature: Self::gemini_part_thought_signature(
602                                        server_tool_response,
603                                    ),
604                                });
605                            }
606                            _ => {}
607                        }
608                    }
609                    if !gemini_parts.is_empty() {
610                        let model_content = adk_gemini::Content {
611                            role: Some(adk_gemini::Role::Model),
612                            parts: Some(gemini_parts),
613                        };
614                        builder = builder.with_message(adk_gemini::Message {
615                            content: model_content,
616                            role: adk_gemini::Role::Model,
617                        });
618                    }
619                }
620                "function" => {
621                    // For function responses, build content directly to attach thought_signature
622                    // recovered from the preceding FunctionCall (Gemini 3.x requirement)
623                    let mut gemini_parts = Vec::new();
624                    for part in &content.parts {
625                        if let Part::FunctionResponse { function_response, .. } = part {
626                            let sig = fn_call_signatures.get(&function_response.name).cloned();
627
628                            // Build nested FunctionResponsePart entries for multimodal data
629                            let mut fr_parts = Vec::new();
630                            for inline in &function_response.inline_data {
631                                let encoded = attachment::encode_base64(&inline.data);
632                                fr_parts.push(adk_gemini::FunctionResponsePart::InlineData {
633                                    inline_data: adk_gemini::Blob {
634                                        mime_type: inline.mime_type.clone(),
635                                        data: encoded,
636                                    },
637                                });
638                            }
639                            for file in &function_response.file_data {
640                                fr_parts.push(adk_gemini::FunctionResponsePart::FileData {
641                                    file_data: adk_gemini::FileDataRef {
642                                        mime_type: file.mime_type.clone(),
643                                        file_uri: file.file_uri.clone(),
644                                    },
645                                });
646                            }
647
648                            let mut gemini_fr = adk_gemini::tools::FunctionResponse::new(
649                                &function_response.name,
650                                Self::gemini_function_response_payload(
651                                    function_response.response.clone(),
652                                ),
653                            );
654                            gemini_fr.parts = fr_parts;
655
656                            gemini_parts.push(adk_gemini::Part::FunctionResponse {
657                                function_response: gemini_fr,
658                                thought_signature: sig,
659                            });
660                        }
661                    }
662                    if !gemini_parts.is_empty() {
663                        let fn_content = adk_gemini::Content {
664                            role: Some(adk_gemini::Role::User),
665                            parts: Some(gemini_parts),
666                        };
667                        builder = builder.with_message(adk_gemini::Message {
668                            content: fn_content,
669                            role: adk_gemini::Role::User,
670                        });
671                    }
672                }
673                _ => {}
674            }
675        }
676
677        // Add generation config
678        if let Some(config) = req.config {
679            let has_schema = config.response_schema.is_some();
680            let gen_config = adk_gemini::GenerationConfig {
681                temperature: config.temperature,
682                top_p: config.top_p,
683                top_k: config.top_k,
684                max_output_tokens: config.max_output_tokens,
685                response_schema: config.response_schema,
686                response_mime_type: if has_schema {
687                    Some("application/json".to_string())
688                } else {
689                    None
690                },
691                ..Default::default()
692            };
693            builder = builder.with_generation_config(gen_config);
694
695            // Attach cached content reference if provided
696            if let Some(ref name) = config.cached_content {
697                let handle = self.client.get_cached_content(name);
698                builder = builder.with_cached_content(&handle);
699            }
700        }
701
702        // Add tools
703        if !req.tools.is_empty() {
704            let (gemini_tools, tool_config) = Self::build_gemini_tools(&req.tools)?;
705            for tool in gemini_tools {
706                builder = builder.with_tool(tool);
707            }
708            if tool_config != adk_gemini::ToolConfig::default() {
709                builder = builder.with_tool_config(tool_config);
710            }
711        }
712
713        if stream {
714            adk_telemetry::debug!("Executing streaming request");
715            let response_stream = builder.execute_stream().await.map_err(|e| {
716                adk_telemetry::error!(error = %e, "Model request failed");
717                gemini_error_to_adk(&e)
718            })?;
719
720            let mapped_stream = async_stream::stream! {
721                let mut stream = response_stream;
722                let mut saw_partial_chunk = false;
723                while let Some(result) = stream.try_next().await.transpose() {
724                    match result {
725                        Ok(resp) => {
726                            match Self::convert_response(&resp) {
727                                Ok(llm_resp) => {
728                                    let (chunks, next_saw_partial) =
729                                        Self::stream_chunks_from_response(llm_resp, saw_partial_chunk);
730                                    saw_partial_chunk = next_saw_partial;
731                                    for chunk in chunks {
732                                        yield Ok(chunk);
733                                    }
734                                }
735                                Err(e) => {
736                                    adk_telemetry::error!(error = %e, "Failed to convert response");
737                                    yield Err(e);
738                                }
739                            }
740                        }
741                        Err(e) => {
742                            adk_telemetry::error!(error = %e, "Stream error");
743                            yield Err(gemini_error_to_adk(&e));
744                        }
745                    }
746                }
747            };
748
749            Ok(Box::pin(mapped_stream))
750        } else {
751            adk_telemetry::debug!("Executing blocking request");
752            let response = builder.execute().await.map_err(|e| {
753                adk_telemetry::error!(error = %e, "Model request failed");
754                gemini_error_to_adk(&e)
755            })?;
756
757            let llm_response = Self::convert_response(&response)?;
758
759            let stream = async_stream::stream! {
760                yield Ok(llm_response);
761            };
762
763            Ok(Box::pin(stream))
764        }
765    }
766
767    /// Create a cached content resource with the given system instruction, tools, and TTL.
768    ///
769    /// Returns the cache name (e.g., "cachedContents/abc123") on success.
770    /// The cache is created using the model configured on this `GeminiModel` instance.
771    pub async fn create_cached_content(
772        &self,
773        system_instruction: &str,
774        tools: &std::collections::HashMap<String, serde_json::Value>,
775        ttl_seconds: u32,
776    ) -> Result<String> {
777        let mut cache_builder = self
778            .client
779            .create_cache()
780            .with_system_instruction(system_instruction)
781            .with_ttl(std::time::Duration::from_secs(u64::from(ttl_seconds)));
782
783        let (gemini_tools, tool_config) = Self::build_gemini_tools(tools)?;
784        if !gemini_tools.is_empty() {
785            cache_builder = cache_builder.with_tools(gemini_tools);
786        }
787        if tool_config != adk_gemini::ToolConfig::default() {
788            cache_builder = cache_builder.with_tool_config(tool_config);
789        }
790
791        let handle = cache_builder
792            .execute()
793            .await
794            .map_err(|e| adk_core::AdkError::model(format!("cache creation failed: {e}")))?;
795
796        Ok(handle.name().to_string())
797    }
798
799    /// Delete a cached content resource by name.
800    pub async fn delete_cached_content(&self, name: &str) -> Result<()> {
801        let handle = self.client.get_cached_content(name);
802        handle
803            .delete()
804            .await
805            .map_err(|(_, e)| adk_core::AdkError::model(format!("cache deletion failed: {e}")))?;
806        Ok(())
807    }
808}
809
810#[async_trait]
811impl Llm for GeminiModel {
812    fn name(&self) -> &str {
813        &self.model_name
814    }
815
816    #[adk_telemetry::instrument(
817        name = "call_llm",
818        skip(self, req),
819        fields(
820            model.name = %self.model_name,
821            stream = %stream,
822            request.contents_count = %req.contents.len(),
823            request.tools_count = %req.tools.len()
824        )
825    )]
826    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
827        adk_telemetry::info!("Generating content");
828        let usage_span = adk_telemetry::llm_generate_span("gemini", &self.model_name, stream);
829        // Retries only cover request setup/execution. Stream failures after the stream starts
830        // are yielded to the caller and are not replayed automatically.
831        let result = execute_with_retry(&self.retry_config, is_retryable_model_error, || {
832            self.generate_content_internal(req.clone(), stream)
833        })
834        .await?;
835        Ok(crate::usage_tracking::with_usage_tracking(result, usage_span))
836    }
837}
838
#[cfg(test)]
mod native_tool_tests {
    use super::*;

    /// Declaration requesting Gemini's built-in Google Search tool through the
    /// `x-adk-gemini-tool` metadata key.
    fn google_search_declaration() -> serde_json::Value {
        serde_json::json!({
            "x-adk-gemini-tool": {
                "google_search": {}
            }
        })
    }

    /// Ordinary function declaration with a single string `city` parameter.
    fn lookup_weather_declaration() -> serde_json::Value {
        serde_json::json!({
            "name": "lookup_weather",
            "description": "lookup weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": { "type": "string" }
                }
            }
        })
    }

    #[test]
    fn test_build_gemini_tools_supports_native_tool_metadata() {
        let tools = std::collections::HashMap::from([
            ("google_search".to_string(), google_search_declaration()),
            ("lookup_weather".to_string(), lookup_weather_declaration()),
        ]);

        let (gemini_tools, tool_config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        // Two input entries are expected to yield two converted tools.
        assert_eq!(gemini_tools.len(), 2);
        assert_eq!(tool_config.include_server_side_tool_invocations, Some(true));
    }

    #[test]
    fn test_build_gemini_tools_sets_flag_for_builtin_only() {
        let tools = std::collections::HashMap::from([(
            "google_search".to_string(),
            google_search_declaration(),
        )]);

        let (_gemini_tools, tool_config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        assert_eq!(
            tool_config.include_server_side_tool_invocations,
            Some(true),
            "includeServerSideToolInvocations should be set even with only built-in tools"
        );
    }

    #[test]
    fn test_build_gemini_tools_no_flag_for_function_only() {
        let tools = std::collections::HashMap::from([(
            "lookup_weather".to_string(),
            lookup_weather_declaration(),
        )]);

        let (_gemini_tools, tool_config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        assert_eq!(
            tool_config.include_server_side_tool_invocations, None,
            "includeServerSideToolInvocations should NOT be set for function-only tools"
        );
    }

    #[test]
    fn test_build_gemini_tools_merges_native_tool_config() {
        // A native tool may carry extra tool-config fragments under `x-adk-gemini-tool-config`;
        // the `retrievalConfig` fragment must surface on the merged `ToolConfig`.
        let google_maps = serde_json::json!({
            "x-adk-gemini-tool": {
                "google_maps": {
                    "enable_widget": true
                }
            },
            "x-adk-gemini-tool-config": {
                "retrievalConfig": {
                    "latLng": {
                        "latitude": 1.23,
                        "longitude": 4.56
                    }
                }
            }
        });
        let tools = std::collections::HashMap::from([("google_maps".to_string(), google_maps)]);

        let (_gemini_tools, tool_config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        let expected_retrieval = serde_json::json!({
            "latLng": {
                "latitude": 1.23,
                "longitude": 4.56
            }
        });
        assert_eq!(tool_config.retrieval_config, Some(expected_retrieval));
    }
}
959
960#[async_trait]
961impl CacheCapable for GeminiModel {
962    async fn create_cache(
963        &self,
964        system_instruction: &str,
965        tools: &std::collections::HashMap<String, serde_json::Value>,
966        ttl_seconds: u32,
967    ) -> Result<String> {
968        self.create_cached_content(system_instruction, tools, ttl_seconds).await
969    }
970
971    async fn delete_cache(&self, name: &str) -> Result<()> {
972        self.delete_cached_content(name).await
973    }
974}
975
// Unit tests covering stream chunking, retry behaviour, response conversion, and
// function-response payload/part handling for `GeminiModel`.
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::AdkError;
    use std::{
        sync::{
            Arc,
            atomic::{AtomicU32, Ordering},
        },
        time::Duration,
    };

    // Compile-time check: `GeminiModel::new` must remain a plain synchronous constructor taking
    // `(&str, &str)` and returning `Result<GeminiModel>`. The test passes as long as it compiles.
    #[test]
    fn constructor_is_backward_compatible_and_sync() {
        fn accepts_sync_constructor<F>(_f: F)
        where
            F: Fn(&str, &str) -> Result<GeminiModel>,
        {
        }

        accepts_sync_constructor(|api_key, model| GeminiModel::new(api_key, model));
    }

    // When no partial chunk has been seen yet, a lone final response is preceded by an injected
    // partial chunk that carries no content; the returned flag then reports a partial was seen.
    #[test]
    fn stream_chunks_from_response_injects_partial_before_lone_final_chunk() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        };

        let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, false);
        assert!(saw_partial);
        assert_eq!(chunks.len(), 2);
        // First chunk: the injected, content-less partial.
        assert!(chunks[0].partial);
        assert!(!chunks[0].turn_complete);
        assert!(chunks[0].content.is_none());
        // Second chunk: the original final response.
        assert!(!chunks[1].partial);
        assert!(chunks[1].turn_complete);
    }

    // Once a partial chunk has already been emitted, the final response passes through alone.
    #[test]
    fn stream_chunks_from_response_keeps_final_only_when_partial_already_seen() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("done")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        };

        let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, true);
        assert!(saw_partial);
        assert_eq!(chunks.len(), 1);
        assert!(!chunks[0].partial);
        assert!(chunks[0].turn_complete);
    }

    // The first two attempts fail with a 429 / RESOURCE_EXHAUSTED message and the third
    // succeeds, so with `max_retries = 2` we expect 3 attempts in total (1 initial + 2 retries).
    #[tokio::test]
    async fn execute_with_retry_retries_retryable_errors() {
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::from_millis(0))
            .with_max_delay(Duration::from_millis(0));
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                // `fetch_add` returns the previous value, i.e. the 0-based attempt index.
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    return Err(AdkError::model("code 429 RESOURCE_EXHAUSTED"));
                }
                Ok("ok")
            }
        })
        .await
        .expect("retry should eventually succeed");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    // A 400-style error is expected to be classified as non-retryable, so exactly one attempt
    // is made even though retries are configured.
    #[tokio::test]
    async fn execute_with_retry_does_not_retry_non_retryable_errors() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(0))
            .with_max_delay(Duration::from_millis(0));
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("code 400 invalid request"))
            }
        })
        .await
        .expect_err("non-retryable error should return immediately");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    // `RetryConfig::disabled()` must stay disabled even after `with_max_retries(10)`: a
    // retryable error still results in a single attempt.
    #[tokio::test]
    async fn execute_with_retry_respects_disabled_config() {
        let retry_config = RetryConfig::disabled().with_max_retries(10);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("code 429 RESOURCE_EXHAUSTED"))
            }
        })
        .await
        .expect_err("disabled retries should return first error");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    // Gemini candidate citation sources must be carried over to the ADK `citation_metadata`.
    #[test]
    fn convert_response_preserves_citation_metadata() {
        let response = adk_gemini::GenerationResponse {
            candidates: vec![adk_gemini::Candidate {
                content: adk_gemini::Content {
                    role: Some(adk_gemini::Role::Model),
                    parts: Some(vec![adk_gemini::Part::Text {
                        text: "hello world".to_string(),
                        thought: None,
                        thought_signature: None,
                    }]),
                },
                safety_ratings: None,
                citation_metadata: Some(adk_gemini::CitationMetadata {
                    citation_sources: vec![adk_gemini::CitationSource {
                        uri: Some("https://example.com".to_string()),
                        title: Some("Example".to_string()),
                        start_index: Some(0),
                        end_index: Some(5),
                        license: Some("CC-BY".to_string()),
                        publication_date: None,
                    }],
                }),
                grounding_metadata: None,
                finish_reason: Some(adk_gemini::FinishReason::Stop),
                index: Some(0),
            }],
            prompt_feedback: None,
            usage_metadata: None,
            model_version: None,
            response_id: None,
        };

        let converted =
            GeminiModel::convert_response(&response).expect("conversion should succeed");
        let metadata = converted.citation_metadata.expect("citation metadata should be mapped");
        assert_eq!(metadata.citation_sources.len(), 1);
        assert_eq!(metadata.citation_sources[0].uri.as_deref(), Some("https://example.com"));
        assert_eq!(metadata.citation_sources[0].start_index, Some(0));
        assert_eq!(metadata.citation_sources[0].end_index, Some(5));
    }

    // Model-emitted inline data arrives base64-encoded from Gemini; after conversion the ADK
    // `Part::InlineData` must hold the original raw bytes, alongside the text part.
    #[test]
    fn convert_response_handles_inline_data_from_model() {
        // First four bytes of the PNG signature.
        let image_bytes = vec![0x89, 0x50, 0x4E, 0x47];
        let encoded = crate::attachment::encode_base64(&image_bytes);

        let response = adk_gemini::GenerationResponse {
            candidates: vec![adk_gemini::Candidate {
                content: adk_gemini::Content {
                    role: Some(adk_gemini::Role::Model),
                    parts: Some(vec![
                        adk_gemini::Part::Text {
                            text: "Here is the image".to_string(),
                            thought: None,
                            thought_signature: None,
                        },
                        adk_gemini::Part::InlineData {
                            inline_data: adk_gemini::Blob {
                                mime_type: "image/png".to_string(),
                                data: encoded,
                            },
                        },
                    ]),
                },
                safety_ratings: None,
                citation_metadata: None,
                grounding_metadata: None,
                finish_reason: Some(adk_gemini::FinishReason::Stop),
                index: Some(0),
            }],
            prompt_feedback: None,
            usage_metadata: None,
            model_version: None,
            response_id: None,
        };

        let converted =
            GeminiModel::convert_response(&response).expect("conversion should succeed");
        let content = converted.content.expect("should have content");
        assert!(
            content
                .parts
                .iter()
                .any(|part| matches!(part, Part::Text { text } if text == "Here is the image"))
        );
        assert!(content.parts.iter().any(|part| {
            matches!(
                part,
                Part::InlineData { mime_type, data }
                    if mime_type == "image/png" && data.as_slice() == image_bytes.as_slice()
            )
        }));
    }

    // JSON objects are passed through to Gemini unchanged.
    #[test]
    fn gemini_function_response_payload_preserves_objects() {
        let value = serde_json::json!({
            "documents": [
                { "id": "pricing", "score": 0.91 }
            ]
        });

        let payload = GeminiModel::gemini_function_response_payload(value.clone());

        assert_eq!(payload, value);
    }

    // Top-level arrays are wrapped in an object under the "result" key.
    #[test]
    fn gemini_function_response_payload_wraps_arrays() {
        let payload =
            GeminiModel::gemini_function_response_payload(serde_json::json!([{ "id": "pricing" }]));

        assert_eq!(payload, serde_json::json!({ "result": [{ "id": "pricing" }] }));
    }

    // ===== Multimodal function response conversion tests =====

    /// Helper to build a FunctionResponse with nested multimodal parts
    /// simulating the conversion logic from generate_content_internal.
    ///
    /// Inline data parts are base64-encoded and emitted first, followed by file-data parts.
    ///
    /// NOTE(review): this is a hand-maintained copy of the production conversion, not a call
    /// into it, so the tests below cannot catch drift in `generate_content_internal`. Consider
    /// extracting that conversion into a shared helper and calling it from both places.
    fn convert_function_response_to_gemini_fr(
        frd: &adk_core::FunctionResponseData,
    ) -> adk_gemini::tools::FunctionResponse {
        let mut fr_parts = Vec::new();

        for inline in &frd.inline_data {
            let encoded = crate::attachment::encode_base64(&inline.data);
            fr_parts.push(adk_gemini::FunctionResponsePart::InlineData {
                inline_data: adk_gemini::Blob {
                    mime_type: inline.mime_type.clone(),
                    data: encoded,
                },
            });
        }

        for file in &frd.file_data {
            fr_parts.push(adk_gemini::FunctionResponsePart::FileData {
                file_data: adk_gemini::FileDataRef {
                    mime_type: file.mime_type.clone(),
                    file_uri: file.file_uri.clone(),
                },
            });
        }

        let mut gemini_fr = adk_gemini::tools::FunctionResponse::new(
            &frd.name,
            GeminiModel::gemini_function_response_payload(frd.response.clone()),
        );
        gemini_fr.parts = fr_parts;
        gemini_fr
    }

    // A JSON-only response produces no nested parts, and the serialized form omits the
    // `parts` key entirely.
    #[test]
    fn json_only_function_response_has_no_nested_parts() {
        let frd = adk_core::FunctionResponseData::new("tool", serde_json::json!({"ok": true}));
        let gemini_fr = convert_function_response_to_gemini_fr(&frd);
        assert!(gemini_fr.parts.is_empty());
        // Serialized JSON should have name and response but no parts key
        let json = serde_json::to_string(&gemini_fr).unwrap();
        assert!(!json.contains("\"parts\""));
    }

    // Inline bytes become a nested `InlineData` part whose data decodes back to the input.
    #[test]
    fn function_response_with_inline_data_has_nested_parts() {
        let frd = adk_core::FunctionResponseData::with_inline_data(
            "chart",
            serde_json::json!({"status": "ok"}),
            vec![adk_core::InlineDataPart {
                mime_type: "image/png".to_string(),
                data: vec![0x89, 0x50, 0x4E, 0x47],
            }],
        );
        let gemini_fr = convert_function_response_to_gemini_fr(&frd);
        assert_eq!(gemini_fr.parts.len(), 1);
        match &gemini_fr.parts[0] {
            adk_gemini::FunctionResponsePart::InlineData { inline_data } => {
                assert_eq!(inline_data.mime_type, "image/png");
                let decoded = BASE64_STANDARD.decode(&inline_data.data).unwrap();
                assert_eq!(decoded, vec![0x89, 0x50, 0x4E, 0x47]);
            }
            other => panic!("expected InlineData, got {other:?}"),
        }
    }

    // File references become a nested `FileData` part with mime type and URI preserved.
    #[test]
    fn function_response_with_file_data_has_nested_parts() {
        let frd = adk_core::FunctionResponseData::with_file_data(
            "doc",
            serde_json::json!({"ok": true}),
            vec![adk_core::FileDataPart {
                mime_type: "application/pdf".to_string(),
                file_uri: "gs://bucket/report.pdf".to_string(),
            }],
        );
        let gemini_fr = convert_function_response_to_gemini_fr(&frd);
        assert_eq!(gemini_fr.parts.len(), 1);
        match &gemini_fr.parts[0] {
            adk_gemini::FunctionResponsePart::FileData { file_data } => {
                assert_eq!(file_data.mime_type, "application/pdf");
                assert_eq!(file_data.file_uri, "gs://bucket/report.pdf");
            }
            other => panic!("expected FileData, got {other:?}"),
        }
    }

    // With both kinds present, all inline parts come first (in input order), then file parts.
    #[test]
    fn function_response_with_both_inline_and_file_data_ordering() {
        let frd = adk_core::FunctionResponseData::with_multimodal(
            "multi",
            serde_json::json!({}),
            vec![
                adk_core::InlineDataPart { mime_type: "image/png".to_string(), data: vec![1, 2] },
                adk_core::InlineDataPart { mime_type: "image/jpeg".to_string(), data: vec![3, 4] },
            ],
            vec![adk_core::FileDataPart {
                mime_type: "application/pdf".to_string(),
                file_uri: "gs://b/f.pdf".to_string(),
            }],
        );
        let gemini_fr = convert_function_response_to_gemini_fr(&frd);
        // 2 inline + 1 file = 3 nested parts
        assert_eq!(gemini_fr.parts.len(), 3);
        assert!(matches!(&gemini_fr.parts[0], adk_gemini::FunctionResponsePart::InlineData { .. }));
        assert!(matches!(&gemini_fr.parts[1], adk_gemini::FunctionResponsePart::InlineData { .. }));
        assert!(matches!(&gemini_fr.parts[2], adk_gemini::FunctionResponsePart::FileData { .. }));
    }
}