adk_model/gemini/client.rs

1use adk_core::{
2    Content, FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result,
3    UsageMetadata,
4};
5use async_trait::async_trait;
6use gemini::Gemini;
7
8pub struct GeminiModel {
9    client: Gemini,
10    model_name: String,
11}
12
13impl GeminiModel {
14    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
15        let client =
16            Gemini::new(api_key.into()).map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
17
18        Ok(Self { client, model_name: model.into() })
19    }
20
21    fn convert_response(resp: &gemini::GenerationResponse) -> Result<LlmResponse> {
22        let content = resp.candidates.first().and_then(|c| c.content.parts.as_ref()).map(|parts| {
23            let converted_parts: Vec<Part> = parts
24                .iter()
25                .filter_map(|p| match p {
26                    gemini::Part::Text { text, .. } => Some(Part::Text { text: text.clone() }),
27                    gemini::Part::FunctionCall { function_call, .. } => Some(Part::FunctionCall {
28                        name: function_call.name.clone(),
29                        args: function_call.args.clone(),
30                        id: None, // Gemini doesn't use tool call IDs
31                    }),
32                    gemini::Part::FunctionResponse { function_response } => {
33                        Some(Part::FunctionResponse {
34                            name: function_response.name.clone(),
35                            response: function_response
36                                .response
37                                .clone()
38                                .unwrap_or(serde_json::Value::Null),
39                            id: None, // Gemini doesn't use tool call IDs
40                        })
41                    }
42                    _ => None,
43                })
44                .collect();
45
46            Content { role: "model".to_string(), parts: converted_parts }
47        });
48
49        let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
50            prompt_token_count: u.prompt_token_count.unwrap_or(0),
51            candidates_token_count: u.candidates_token_count.unwrap_or(0),
52            total_token_count: u.total_token_count.unwrap_or(0),
53        });
54
55        let finish_reason =
56            resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
57                gemini::FinishReason::Stop => FinishReason::Stop,
58                gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
59                gemini::FinishReason::Safety => FinishReason::Safety,
60                gemini::FinishReason::Recitation => FinishReason::Recitation,
61                _ => FinishReason::Other,
62            });
63
64        Ok(LlmResponse {
65            content,
66            usage_metadata,
67            finish_reason,
68            partial: false,
69            turn_complete: true,
70            interrupted: false,
71            error_code: None,
72            error_message: None,
73        })
74    }
75}
76
77#[async_trait]
78impl Llm for GeminiModel {
79    fn name(&self) -> &str {
80        &self.model_name
81    }
82
83    #[adk_telemetry::instrument(
84        skip(self, req),
85        fields(
86            model.name = %self.model_name,
87            stream = %stream,
88            request.contents_count = %req.contents.len(),
89            request.tools_count = %req.tools.len()
90        )
91    )]
92    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
93        adk_telemetry::info!("Generating content");
94
95        let mut builder = self.client.generate_content();
96
97        // Add contents using proper builder methods
98        for content in &req.contents {
99            match content.role.as_str() {
100                "user" => {
101                    // For user messages, build gemini Content with potentially multiple parts
102                    let mut gemini_parts = Vec::new();
103                    for part in &content.parts {
104                        match part {
105                            Part::Text { text } => {
106                                gemini_parts.push(gemini::Part::Text {
107                                    text: text.clone(),
108                                    thought: None,
109                                    thought_signature: None,
110                                });
111                            }
112                            Part::InlineData { data, mime_type } => {
113                                use base64::{engine::general_purpose::STANDARD, Engine as _};
114                                let encoded = STANDARD.encode(data);
115                                gemini_parts.push(gemini::Part::InlineData {
116                                    inline_data: gemini::Blob {
117                                        mime_type: mime_type.clone(),
118                                        data: encoded,
119                                    },
120                                });
121                            }
122                            _ => {}
123                        }
124                    }
125                    if !gemini_parts.is_empty() {
126                        let user_content = gemini::Content {
127                            role: Some(gemini::Role::User),
128                            parts: Some(gemini_parts),
129                        };
130                        builder = builder.with_message(gemini::Message {
131                            content: user_content,
132                            role: gemini::Role::User,
133                        });
134                    }
135                }
136                "model" => {
137                    // For model messages, build gemini Content
138                    let mut gemini_parts = Vec::new();
139                    for part in &content.parts {
140                        match part {
141                            Part::Text { text } => {
142                                gemini_parts.push(gemini::Part::Text {
143                                    text: text.clone(),
144                                    thought: None,
145                                    thought_signature: None,
146                                });
147                            }
148                            Part::FunctionCall { name, args, .. } => {
149                                gemini_parts.push(gemini::Part::FunctionCall {
150                                    function_call: gemini::FunctionCall {
151                                        name: name.clone(),
152                                        args: args.clone(),
153                                        thought_signature: None,
154                                    },
155                                    thought_signature: None,
156                                });
157                            }
158                            _ => {}
159                        }
160                    }
161                    if !gemini_parts.is_empty() {
162                        let model_content = gemini::Content {
163                            role: Some(gemini::Role::Model),
164                            parts: Some(gemini_parts),
165                        };
166                        builder = builder.with_message(gemini::Message {
167                            content: model_content,
168                            role: gemini::Role::Model,
169                        });
170                    }
171                }
172                "function" => {
173                    // For function responses
174                    for part in &content.parts {
175                        if let Part::FunctionResponse { name, response, .. } = part {
176                            builder = builder
177                                .with_function_response(name, response.clone())
178                                .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
179                        }
180                    }
181                }
182                _ => {}
183            }
184        }
185
186        // Add generation config
187        if let Some(config) = req.config {
188            let gen_config = gemini::GenerationConfig {
189                temperature: config.temperature,
190                top_p: config.top_p,
191                top_k: config.top_k,
192                max_output_tokens: config.max_output_tokens,
193                ..Default::default()
194            };
195            builder = builder.with_generation_config(gen_config);
196        }
197
198        // Add tools
199        if !req.tools.is_empty() {
200            let mut function_declarations = Vec::new();
201            let mut has_google_search = false;
202
203            for (name, tool_decl) in &req.tools {
204                if name == "google_search" {
205                    has_google_search = true;
206                    continue;
207                }
208
209                // Deserialize our tool declaration into gemini::FunctionDeclaration
210                if let Ok(func_decl) =
211                    serde_json::from_value::<gemini::FunctionDeclaration>(tool_decl.clone())
212                {
213                    function_declarations.push(func_decl);
214                }
215            }
216
217            if !function_declarations.is_empty() {
218                let tool = gemini::Tool::with_functions(function_declarations);
219                builder = builder.with_tool(tool);
220            }
221
222            if has_google_search {
223                // Enable built-in Google Search
224                let tool = gemini::Tool::google_search();
225                builder = builder.with_tool(tool);
226            }
227        }
228
229        if stream {
230            adk_telemetry::debug!("Executing streaming request");
231            let response_stream = builder.execute_stream().await.map_err(|e| {
232                adk_telemetry::error!(error = %e, "Model request failed");
233                adk_core::AdkError::Model(e.to_string())
234            })?;
235
236            let mapped_stream = async_stream::stream! {
237                use futures::TryStreamExt;
238                let mut stream = response_stream;
239                while let Some(result) = stream.try_next().await.transpose() {
240                    match result {
241                        Ok(resp) => {
242                            match Self::convert_response(&resp) {
243                                Ok(mut llm_resp) => {
244                                    llm_resp.partial = true;
245                                    llm_resp.turn_complete = false;
246                                    yield Ok(llm_resp);
247                                }
248                                Err(e) => {
249                                    adk_telemetry::error!(error = %e, "Failed to convert response");
250                                    yield Err(e);
251                                }
252                            }
253                        }
254                        Err(e) => {
255                            adk_telemetry::error!(error = %e, "Stream error");
256                            yield Err(adk_core::AdkError::Model(e.to_string()));
257                        }
258                    }
259                }
260            };
261
262            Ok(Box::pin(mapped_stream))
263        } else {
264            adk_telemetry::debug!("Executing blocking request");
265            let response = builder.execute().await.map_err(|e| {
266                adk_telemetry::error!(error = %e, "Model request failed");
267                adk_core::AdkError::Model(e.to_string())
268            })?;
269
270            let llm_response = Self::convert_response(&response)?;
271
272            let stream = async_stream::stream! {
273                yield Ok(llm_response);
274            };
275
276            Ok(Box::pin(stream))
277        }
278    }
279}