adk_model/gemini/
client.rs

1use adk_core::{
2    Content, FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result,
3    UsageMetadata,
4};
5use adk_gemini::Gemini;
6use async_trait::async_trait;
7
8pub struct GeminiModel {
9    client: Gemini,
10    model_name: String,
11}
12
13impl GeminiModel {
14    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
15        let client =
16            Gemini::new(api_key.into()).map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
17
18        Ok(Self { client, model_name: model.into() })
19    }
20
21    fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
22        let mut converted_parts: Vec<Part> = Vec::new();
23
24        // Convert content parts
25        if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
26            for p in parts {
27                match p {
28                    adk_gemini::Part::Text { text, .. } => {
29                        converted_parts.push(Part::Text { text: text.clone() });
30                    }
31                    adk_gemini::Part::FunctionCall { function_call, .. } => {
32                        converted_parts.push(Part::FunctionCall {
33                            name: function_call.name.clone(),
34                            args: function_call.args.clone(),
35                            id: None,
36                        });
37                    }
38                    adk_gemini::Part::FunctionResponse { function_response } => {
39                        converted_parts.push(Part::FunctionResponse {
40                            function_response: adk_core::FunctionResponseData {
41                                name: function_response.name.clone(),
42                                response: function_response
43                                    .response
44                                    .clone()
45                                    .unwrap_or(serde_json::Value::Null),
46                            },
47                            id: None,
48                        });
49                    }
50                    _ => {}
51                }
52            }
53        }
54
55        // Add grounding metadata as text if present (required for Google Search grounding compliance)
56        if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
57        {
58            if let Some(queries) = &grounding.web_search_queries {
59                if !queries.is_empty() {
60                    let search_info = format!("\n\nšŸ” **Searched:** {}", queries.join(", "));
61                    converted_parts.push(Part::Text { text: search_info });
62                }
63            }
64            if let Some(chunks) = &grounding.grounding_chunks {
65                let sources: Vec<String> = chunks
66                    .iter()
67                    .filter_map(|c| {
68                        c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
69                            (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
70                            (Some(title), None) => Some(title.clone()),
71                            (None, Some(uri)) => Some(uri.to_string()),
72                            (None, None) => None,
73                        })
74                    })
75                    .collect();
76                if !sources.is_empty() {
77                    let sources_info = format!("\nšŸ“š **Sources:** {}", sources.join(" | "));
78                    converted_parts.push(Part::Text { text: sources_info });
79                }
80            }
81        }
82
83        let content = if converted_parts.is_empty() {
84            None
85        } else {
86            Some(Content { role: "model".to_string(), parts: converted_parts })
87        };
88
89        let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
90            prompt_token_count: u.prompt_token_count.unwrap_or(0),
91            candidates_token_count: u.candidates_token_count.unwrap_or(0),
92            total_token_count: u.total_token_count.unwrap_or(0),
93        });
94
95        let finish_reason =
96            resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
97                adk_gemini::FinishReason::Stop => FinishReason::Stop,
98                adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
99                adk_gemini::FinishReason::Safety => FinishReason::Safety,
100                adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
101                _ => FinishReason::Other,
102            });
103
104        Ok(LlmResponse {
105            content,
106            usage_metadata,
107            finish_reason,
108            partial: false,
109            turn_complete: true,
110            interrupted: false,
111            error_code: None,
112            error_message: None,
113        })
114    }
115}
116
117#[async_trait]
118impl Llm for GeminiModel {
119    fn name(&self) -> &str {
120        &self.model_name
121    }
122
123    #[adk_telemetry::instrument(
124        name = "call_llm",
125        skip(self, req),
126        fields(
127            model.name = %self.model_name,
128            stream = %stream,
129            request.contents_count = %req.contents.len(),
130            request.tools_count = %req.tools.len()
131        )
132    )]
133    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
134        adk_telemetry::info!("Generating content");
135
136        let mut builder = self.client.generate_content();
137
138        // Add contents using proper builder methods
139        for content in &req.contents {
140            match content.role.as_str() {
141                "user" => {
142                    // For user messages, build gemini Content with potentially multiple parts
143                    let mut gemini_parts = Vec::new();
144                    for part in &content.parts {
145                        match part {
146                            Part::Text { text } => {
147                                gemini_parts.push(adk_gemini::Part::Text {
148                                    text: text.clone(),
149                                    thought: None,
150                                    thought_signature: None,
151                                });
152                            }
153                            Part::InlineData { data, mime_type } => {
154                                use base64::{Engine as _, engine::general_purpose::STANDARD};
155                                let encoded = STANDARD.encode(data);
156                                gemini_parts.push(adk_gemini::Part::InlineData {
157                                    inline_data: adk_gemini::Blob {
158                                        mime_type: mime_type.clone(),
159                                        data: encoded,
160                                    },
161                                });
162                            }
163                            _ => {}
164                        }
165                    }
166                    if !gemini_parts.is_empty() {
167                        let user_content = adk_gemini::Content {
168                            role: Some(adk_gemini::Role::User),
169                            parts: Some(gemini_parts),
170                        };
171                        builder = builder.with_message(adk_gemini::Message {
172                            content: user_content,
173                            role: adk_gemini::Role::User,
174                        });
175                    }
176                }
177                "model" => {
178                    // For model messages, build gemini Content
179                    let mut gemini_parts = Vec::new();
180                    for part in &content.parts {
181                        match part {
182                            Part::Text { text } => {
183                                gemini_parts.push(adk_gemini::Part::Text {
184                                    text: text.clone(),
185                                    thought: None,
186                                    thought_signature: None,
187                                });
188                            }
189                            Part::FunctionCall { name, args, .. } => {
190                                gemini_parts.push(adk_gemini::Part::FunctionCall {
191                                    function_call: adk_gemini::FunctionCall {
192                                        name: name.clone(),
193                                        args: args.clone(),
194                                        thought_signature: None,
195                                    },
196                                    thought_signature: None,
197                                });
198                            }
199                            _ => {}
200                        }
201                    }
202                    if !gemini_parts.is_empty() {
203                        let model_content = adk_gemini::Content {
204                            role: Some(adk_gemini::Role::Model),
205                            parts: Some(gemini_parts),
206                        };
207                        builder = builder.with_message(adk_gemini::Message {
208                            content: model_content,
209                            role: adk_gemini::Role::Model,
210                        });
211                    }
212                }
213                "function" => {
214                    // For function responses
215                    for part in &content.parts {
216                        if let Part::FunctionResponse { function_response, .. } = part {
217                            builder = builder
218                                .with_function_response(
219                                    &function_response.name,
220                                    function_response.response.clone(),
221                                )
222                                .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
223                        }
224                    }
225                }
226                _ => {}
227            }
228        }
229
230        // Add generation config
231        if let Some(config) = req.config {
232            let has_schema = config.response_schema.is_some();
233            let gen_config = adk_gemini::GenerationConfig {
234                temperature: config.temperature,
235                top_p: config.top_p,
236                top_k: config.top_k,
237                max_output_tokens: config.max_output_tokens,
238                response_schema: config.response_schema,
239                response_mime_type: if has_schema {
240                    Some("application/json".to_string())
241                } else {
242                    None
243                },
244                ..Default::default()
245            };
246            builder = builder.with_generation_config(gen_config);
247        }
248
249        // Add tools
250        if !req.tools.is_empty() {
251            let mut function_declarations = Vec::new();
252            let mut has_google_search = false;
253
254            for (name, tool_decl) in &req.tools {
255                if name == "google_search" {
256                    has_google_search = true;
257                    continue;
258                }
259
260                // Deserialize our tool declaration into adk_gemini::FunctionDeclaration
261                if let Ok(func_decl) =
262                    serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
263                {
264                    function_declarations.push(func_decl);
265                }
266            }
267
268            if !function_declarations.is_empty() {
269                let tool = adk_gemini::Tool::with_functions(function_declarations);
270                builder = builder.with_tool(tool);
271            }
272
273            if has_google_search {
274                // Enable built-in Google Search
275                let tool = adk_gemini::Tool::google_search();
276                builder = builder.with_tool(tool);
277            }
278        }
279
280        if stream {
281            adk_telemetry::debug!("Executing streaming request");
282            let response_stream = builder.execute_stream().await.map_err(|e| {
283                adk_telemetry::error!(error = %e, "Model request failed");
284                adk_core::AdkError::Model(e.to_string())
285            })?;
286
287            let mapped_stream = async_stream::stream! {
288                use futures::TryStreamExt;
289                let mut stream = response_stream;
290                while let Some(result) = stream.try_next().await.transpose() {
291                    match result {
292                        Ok(resp) => {
293                            match Self::convert_response(&resp) {
294                                Ok(mut llm_resp) => {
295                                    // Check if this is the final chunk (has finish_reason)
296                                    let is_final = llm_resp.finish_reason.is_some();
297                                    llm_resp.partial = !is_final;
298                                    llm_resp.turn_complete = is_final;
299                                    yield Ok(llm_resp);
300                                }
301                                Err(e) => {
302                                    adk_telemetry::error!(error = %e, "Failed to convert response");
303                                    yield Err(e);
304                                }
305                            }
306                        }
307                        Err(e) => {
308                            adk_telemetry::error!(error = %e, "Stream error");
309                            yield Err(adk_core::AdkError::Model(e.to_string()));
310                        }
311                    }
312                }
313            };
314
315            Ok(Box::pin(mapped_stream))
316        } else {
317            adk_telemetry::debug!("Executing blocking request");
318            let response = builder.execute().await.map_err(|e| {
319                adk_telemetry::error!(error = %e, "Model request failed");
320                adk_core::AdkError::Model(e.to_string())
321            })?;
322
323            let llm_response = Self::convert_response(&response)?;
324
325            let stream = async_stream::stream! {
326                yield Ok(llm_response);
327            };
328
329            Ok(Box::pin(stream))
330        }
331    }
332}