adk_model/gemini/
client.rs

1use adk_core::{
2    Content, FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result,
3    UsageMetadata,
4};
5use async_trait::async_trait;
6use gemini::Gemini;
7
8pub struct GeminiModel {
9    client: Gemini,
10    model_name: String,
11}
12
13impl GeminiModel {
14    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
15        let client =
16            Gemini::new(api_key.into()).map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
17
18        Ok(Self { client, model_name: model.into() })
19    }
20
21    fn convert_response(resp: &gemini::GenerationResponse) -> Result<LlmResponse> {
22        let content = resp.candidates.first().and_then(|c| c.content.parts.as_ref()).map(|parts| {
23            let converted_parts: Vec<Part> = parts
24                .iter()
25                .filter_map(|p| match p {
26                    gemini::Part::Text { text, .. } => Some(Part::Text { text: text.clone() }),
27                    gemini::Part::FunctionCall { function_call, .. } => Some(Part::FunctionCall {
28                        name: function_call.name.clone(),
29                        args: function_call.args.clone(),
30                    }),
31                    gemini::Part::FunctionResponse { function_response } => {
32                        Some(Part::FunctionResponse {
33                            name: function_response.name.clone(),
34                            response: function_response
35                                .response
36                                .clone()
37                                .unwrap_or(serde_json::Value::Null),
38                        })
39                    }
40                    _ => None,
41                })
42                .collect();
43
44            Content { role: "model".to_string(), parts: converted_parts }
45        });
46
47        let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
48            prompt_token_count: u.prompt_token_count.unwrap_or(0),
49            candidates_token_count: u.candidates_token_count.unwrap_or(0),
50            total_token_count: u.total_token_count.unwrap_or(0),
51        });
52
53        let finish_reason =
54            resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
55                gemini::FinishReason::Stop => FinishReason::Stop,
56                gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
57                gemini::FinishReason::Safety => FinishReason::Safety,
58                gemini::FinishReason::Recitation => FinishReason::Recitation,
59                _ => FinishReason::Other,
60            });
61
62        Ok(LlmResponse {
63            content,
64            usage_metadata,
65            finish_reason,
66            partial: false,
67            turn_complete: true,
68            interrupted: false,
69            error_code: None,
70            error_message: None,
71        })
72    }
73}
74
75#[async_trait]
76impl Llm for GeminiModel {
77    fn name(&self) -> &str {
78        &self.model_name
79    }
80
81    #[adk_telemetry::instrument(
82        skip(self, req),
83        fields(
84            model.name = %self.model_name,
85            stream = %stream,
86            request.contents_count = %req.contents.len(),
87            request.tools_count = %req.tools.len()
88        )
89    )]
90    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
91        adk_telemetry::info!("Generating content");
92
93        let mut builder = self.client.generate_content();
94
95        // Add contents using proper builder methods
96        for content in &req.contents {
97            match content.role.as_str() {
98                "user" => {
99                    // For user messages, build gemini Content with potentially multiple parts
100                    let mut gemini_parts = Vec::new();
101                    for part in &content.parts {
102                        match part {
103                            Part::Text { text } => {
104                                gemini_parts.push(gemini::Part::Text {
105                                    text: text.clone(),
106                                    thought: None,
107                                    thought_signature: None,
108                                });
109                            }
110                            Part::InlineData { data, mime_type } => {
111                                use base64::{engine::general_purpose::STANDARD, Engine as _};
112                                let encoded = STANDARD.encode(data);
113                                gemini_parts.push(gemini::Part::InlineData {
114                                    inline_data: gemini::Blob {
115                                        mime_type: mime_type.clone(),
116                                        data: encoded,
117                                    },
118                                });
119                            }
120                            _ => {}
121                        }
122                    }
123                    if !gemini_parts.is_empty() {
124                        let user_content = gemini::Content {
125                            role: Some(gemini::Role::User),
126                            parts: Some(gemini_parts),
127                        };
128                        builder = builder.with_message(gemini::Message {
129                            content: user_content,
130                            role: gemini::Role::User,
131                        });
132                    }
133                }
134                "model" => {
135                    // For model messages, build gemini Content
136                    let mut gemini_parts = Vec::new();
137                    for part in &content.parts {
138                        match part {
139                            Part::Text { text } => {
140                                gemini_parts.push(gemini::Part::Text {
141                                    text: text.clone(),
142                                    thought: None,
143                                    thought_signature: None,
144                                });
145                            }
146                            Part::FunctionCall { name, args } => {
147                                gemini_parts.push(gemini::Part::FunctionCall {
148                                    function_call: gemini::FunctionCall {
149                                        name: name.clone(),
150                                        args: args.clone(),
151                                        thought_signature: None,
152                                    },
153                                    thought_signature: None,
154                                });
155                            }
156                            _ => {}
157                        }
158                    }
159                    if !gemini_parts.is_empty() {
160                        let model_content = gemini::Content {
161                            role: Some(gemini::Role::Model),
162                            parts: Some(gemini_parts),
163                        };
164                        builder = builder.with_message(gemini::Message {
165                            content: model_content,
166                            role: gemini::Role::Model,
167                        });
168                    }
169                }
170                "function" => {
171                    // For function responses
172                    for part in &content.parts {
173                        if let Part::FunctionResponse { name, response } = part {
174                            builder = builder
175                                .with_function_response(name, response.clone())
176                                .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
177                        }
178                    }
179                }
180                _ => {}
181            }
182        }
183
184        // Add generation config
185        if let Some(config) = req.config {
186            let gen_config = gemini::GenerationConfig {
187                temperature: config.temperature,
188                top_p: config.top_p,
189                top_k: config.top_k,
190                max_output_tokens: config.max_output_tokens,
191                ..Default::default()
192            };
193            builder = builder.with_generation_config(gen_config);
194        }
195
196        // Add tools
197        if !req.tools.is_empty() {
198            let mut function_declarations = Vec::new();
199            let mut has_google_search = false;
200
201            for (name, tool_decl) in &req.tools {
202                if name == "google_search" {
203                    has_google_search = true;
204                    continue;
205                }
206
207                // Deserialize our tool declaration into gemini::FunctionDeclaration
208                if let Ok(func_decl) =
209                    serde_json::from_value::<gemini::FunctionDeclaration>(tool_decl.clone())
210                {
211                    function_declarations.push(func_decl);
212                }
213            }
214
215            if !function_declarations.is_empty() {
216                let tool = gemini::Tool::with_functions(function_declarations);
217                builder = builder.with_tool(tool);
218            }
219
220            if has_google_search {
221                // Enable built-in Google Search
222                let tool = gemini::Tool::google_search();
223                builder = builder.with_tool(tool);
224            }
225        }
226
227        if stream {
228            adk_telemetry::debug!("Executing streaming request");
229            let response_stream = builder.execute_stream().await.map_err(|e| {
230                adk_telemetry::error!(error = %e, "Model request failed");
231                adk_core::AdkError::Model(e.to_string())
232            })?;
233
234            let mapped_stream = async_stream::stream! {
235                use futures::TryStreamExt;
236                let mut stream = response_stream;
237                while let Some(result) = stream.try_next().await.transpose() {
238                    match result {
239                        Ok(resp) => {
240                            match Self::convert_response(&resp) {
241                                Ok(mut llm_resp) => {
242                                    llm_resp.partial = true;
243                                    llm_resp.turn_complete = false;
244                                    yield Ok(llm_resp);
245                                }
246                                Err(e) => {
247                                    adk_telemetry::error!(error = %e, "Failed to convert response");
248                                    yield Err(e);
249                                }
250                            }
251                        }
252                        Err(e) => {
253                            adk_telemetry::error!(error = %e, "Stream error");
254                            yield Err(adk_core::AdkError::Model(e.to_string()));
255                        }
256                    }
257                }
258            };
259
260            Ok(Box::pin(mapped_stream))
261        } else {
262            adk_telemetry::debug!("Executing blocking request");
263            let response = builder.execute().await.map_err(|e| {
264                adk_telemetry::error!(error = %e, "Model request failed");
265                adk_core::AdkError::Model(e.to_string())
266            })?;
267
268            let llm_response = Self::convert_response(&response)?;
269
270            let stream = async_stream::stream! {
271                yield Ok(llm_response);
272            };
273
274            Ok(Box::pin(stream))
275        }
276    }
277}