adk_model/gemini/client.rs
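
//! Gemini backend for the ADK `Llm` abstraction: translates `LlmRequest`s into
//! Gemini API calls and maps the responses (including streaming chunks) back
//! into `LlmResponse`s.
//!
//! A minimal construction sketch; the environment variable and model name below
//! are illustrative, not mandated by this crate:
//!
//! ```ignore
//! let model = GeminiModel::new(std::env::var("GEMINI_API_KEY")?, "gemini-2.0-flash")?;
//! ```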

use adk_core::{
    Content, FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result,
    UsageMetadata,
};
use adk_gemini::Gemini;
use async_trait::async_trait;

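/// Gemini-backed implementation of the `adk_core::Llm` trait, pairing an
/// `adk_gemini::Gemini` client with the name of the model it targets.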
pub struct GeminiModel {
    client: Gemini,
    model_name: String,
}

impl GeminiModel {
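    /// Builds the underlying Gemini client from the API key and records the target model name.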
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
        let client =
            Gemini::new(api_key.into()).map_err(|e| adk_core::AdkError::Model(e.to_string()))?;

        Ok(Self { client, model_name: model.into() })
    }

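    /// Converts an `adk_gemini::GenerationResponse` into the crate-level `LlmResponse`:
    /// text, function-call, and function-response parts of the first candidate are
    /// flattened, and grounding metadata (search queries and sources) is appended as
    /// extra text parts.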
    fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
        let mut converted_parts: Vec<Part> = Vec::new();

        // Convert content parts
        if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
            for p in parts {
                match p {
                    adk_gemini::Part::Text { text, .. } => {
                        converted_parts.push(Part::Text { text: text.clone() });
                    }
                    adk_gemini::Part::FunctionCall { function_call, .. } => {
                        converted_parts.push(Part::FunctionCall {
                            name: function_call.name.clone(),
                            args: function_call.args.clone(),
                            id: None,
                        });
                    }
                    adk_gemini::Part::FunctionResponse { function_response } => {
                        converted_parts.push(Part::FunctionResponse {
                            function_response: adk_core::FunctionResponseData {
                                name: function_response.name.clone(),
                                response: function_response
                                    .response
                                    .clone()
                                    .unwrap_or(serde_json::Value::Null),
                            },
                            id: None,
                        });
                    }
                    _ => {}
                }
            }
        }

        // Add grounding metadata as text if present (required for Google Search grounding compliance)
        if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
        {
            if let Some(queries) = &grounding.web_search_queries {
                if !queries.is_empty() {
                    let search_info = format!("\n\nšŸ” **Searched:** {}", queries.join(", "));
                    converted_parts.push(Part::Text { text: search_info });
                }
            }
            if let Some(chunks) = &grounding.grounding_chunks {
                let sources: Vec<String> = chunks
                    .iter()
                    .filter_map(|c| {
                        c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
                            (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
                            (Some(title), None) => Some(title.clone()),
                            (None, Some(uri)) => Some(uri.to_string()),
                            (None, None) => None,
                        })
                    })
                    .collect();
                if !sources.is_empty() {
                    let sources_info = format!("\nšŸ“š **Sources:** {}", sources.join(" | "));
                    converted_parts.push(Part::Text { text: sources_info });
                }
            }
        }

        let content = if converted_parts.is_empty() {
            None
        } else {
            Some(Content { role: "model".to_string(), parts: converted_parts })
        };

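        // Token usage, with any counts the API omits defaulting to zero.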
        let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
            prompt_token_count: u.prompt_token_count.unwrap_or(0),
            candidates_token_count: u.candidates_token_count.unwrap_or(0),
            total_token_count: u.total_token_count.unwrap_or(0),
        });

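        // Map Gemini finish reasons onto the crate's coarser FinishReason enum.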
        let finish_reason =
            resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
                adk_gemini::FinishReason::Stop => FinishReason::Stop,
                adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
                adk_gemini::FinishReason::Safety => FinishReason::Safety,
                adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
                _ => FinishReason::Other,
            });

        Ok(LlmResponse {
            content,
            usage_metadata,
            finish_reason,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
        })
    }
}

#[async_trait]
impl Llm for GeminiModel {
    fn name(&self) -> &str {
        &self.model_name
    }

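    /// Translates the ADK request into Gemini builder calls (messages, generation
    /// config, and tools), then executes it as either a streaming or a blocking
    /// request and returns the result as an `LlmResponseStream`.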
    #[adk_telemetry::instrument(
        name = "call_llm",
        skip(self, req),
        fields(
            model.name = %self.model_name,
            stream = %stream,
            request.contents_count = %req.contents.len(),
            request.tools_count = %req.tools.len()
        )
    )]
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
        adk_telemetry::info!("Generating content");

        let mut builder = self.client.generate_content();

        // Add contents using proper builder methods
        for content in &req.contents {
            match content.role.as_str() {
                "user" => {
                    // For user messages, build gemini Content with potentially multiple parts
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::InlineData { data, mime_type } => {
                                use base64::{Engine as _, engine::general_purpose::STANDARD};
                                let encoded = STANDARD.encode(data);
                                gemini_parts.push(adk_gemini::Part::InlineData {
                                    inline_data: adk_gemini::Blob {
                                        mime_type: mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let user_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: user_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                "model" => {
                    // For model messages, build gemini Content
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::FunctionCall { name, args, .. } => {
                                gemini_parts.push(adk_gemini::Part::FunctionCall {
                                    function_call: adk_gemini::FunctionCall {
                                        name: name.clone(),
                                        args: args.clone(),
                                        thought_signature: None,
                                    },
                                    thought_signature: None,
                                });
                            }
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let model_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::Model),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: model_content,
                            role: adk_gemini::Role::Model,
                        });
                    }
                }
                "function" => {
                    // For function responses
                    for part in &content.parts {
                        if let Part::FunctionResponse { function_response, .. } = part {
                            builder = builder
                                .with_function_response(
                                    &function_response.name,
                                    function_response.response.clone(),
                                )
                                .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
                        }
                    }
                }
                _ => {}
            }
        }

        // Add generation config
        if let Some(config) = req.config {
            let gen_config = adk_gemini::GenerationConfig {
                temperature: config.temperature,
                top_p: config.top_p,
                top_k: config.top_k,
                max_output_tokens: config.max_output_tokens,
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);
        }

        // Add tools
        if !req.tools.is_empty() {
            let mut function_declarations = Vec::new();
            let mut has_google_search = false;

            for (name, tool_decl) in &req.tools {
                if name == "google_search" {
                    has_google_search = true;
                    continue;
                }

                // Deserialize our tool declaration into adk_gemini::FunctionDeclaration
                if let Ok(func_decl) =
                    serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
                {
                    function_declarations.push(func_decl);
                }
            }

            if !function_declarations.is_empty() {
                let tool = adk_gemini::Tool::with_functions(function_declarations);
                builder = builder.with_tool(tool);
            }

            if has_google_search {
                // Enable built-in Google Search
                let tool = adk_gemini::Tool::google_search();
                builder = builder.with_tool(tool);
            }
        }

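        // Both branches return an LlmResponseStream: streaming maps each chunk as it
        // arrives, while the blocking call is wrapped in a single-item stream.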
        if stream {
            adk_telemetry::debug!("Executing streaming request");
            let response_stream = builder.execute_stream().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                adk_core::AdkError::Model(e.to_string())
            })?;

            let mapped_stream = async_stream::stream! {
                use futures::TryStreamExt;
                let mut stream = response_stream;
                while let Some(result) = stream.try_next().await.transpose() {
                    match result {
                        Ok(resp) => {
                            match Self::convert_response(&resp) {
                                Ok(mut llm_resp) => {
                                    // Check if this is the final chunk (has finish_reason)
                                    let is_final = llm_resp.finish_reason.is_some();
                                    llm_resp.partial = !is_final;
                                    llm_resp.turn_complete = is_final;
                                    yield Ok(llm_resp);
                                }
                                Err(e) => {
                                    adk_telemetry::error!(error = %e, "Failed to convert response");
                                    yield Err(e);
                                }
                            }
                        }
                        Err(e) => {
                            adk_telemetry::error!(error = %e, "Stream error");
                            yield Err(adk_core::AdkError::Model(e.to_string()));
                        }
                    }
                }
            };

            Ok(Box::pin(mapped_stream))
        } else {
            adk_telemetry::debug!("Executing blocking request");
            let response = builder.execute().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                adk_core::AdkError::Model(e.to_string())
            })?;

            let llm_response = Self::convert_response(&response)?;

            let stream = async_stream::stream! {
                yield Ok(llm_response);
            };

            Ok(Box::pin(stream))
        }
    }
}