1use adk_core::{
2 Content, FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result,
3 UsageMetadata,
4};
5use adk_gemini::Gemini;
6use async_trait::async_trait;
7
8pub struct GeminiModel {
9 client: Gemini,
10 model_name: String,
11}
12
13impl GeminiModel {
14 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
15 let client =
16 Gemini::new(api_key.into()).map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
17
18 Ok(Self { client, model_name: model.into() })
19 }
20
21 fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
22 let mut converted_parts: Vec<Part> = Vec::new();
23
24 if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
26 for p in parts {
27 match p {
28 adk_gemini::Part::Text { text, .. } => {
29 converted_parts.push(Part::Text { text: text.clone() });
30 }
31 adk_gemini::Part::FunctionCall { function_call, .. } => {
32 converted_parts.push(Part::FunctionCall {
33 name: function_call.name.clone(),
34 args: function_call.args.clone(),
35 id: None,
36 });
37 }
38 adk_gemini::Part::FunctionResponse { function_response } => {
39 converted_parts.push(Part::FunctionResponse {
40 function_response: adk_core::FunctionResponseData {
41 name: function_response.name.clone(),
42 response: function_response
43 .response
44 .clone()
45 .unwrap_or(serde_json::Value::Null),
46 },
47 id: None,
48 });
49 }
50 _ => {}
51 }
52 }
53 }
54
55 if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
57 {
58 if let Some(queries) = &grounding.web_search_queries {
59 if !queries.is_empty() {
60 let search_info = format!("\n\nš **Searched:** {}", queries.join(", "));
61 converted_parts.push(Part::Text { text: search_info });
62 }
63 }
64 if let Some(chunks) = &grounding.grounding_chunks {
65 let sources: Vec<String> = chunks
66 .iter()
67 .filter_map(|c| {
68 c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
69 (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
70 (Some(title), None) => Some(title.clone()),
71 (None, Some(uri)) => Some(uri.to_string()),
72 (None, None) => None,
73 })
74 })
75 .collect();
76 if !sources.is_empty() {
77 let sources_info = format!("\nš **Sources:** {}", sources.join(" | "));
78 converted_parts.push(Part::Text { text: sources_info });
79 }
80 }
81 }
82
83 let content = if converted_parts.is_empty() {
84 None
85 } else {
86 Some(Content { role: "model".to_string(), parts: converted_parts })
87 };
88
89 let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
90 prompt_token_count: u.prompt_token_count.unwrap_or(0),
91 candidates_token_count: u.candidates_token_count.unwrap_or(0),
92 total_token_count: u.total_token_count.unwrap_or(0),
93 });
94
95 let finish_reason =
96 resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
97 adk_gemini::FinishReason::Stop => FinishReason::Stop,
98 adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
99 adk_gemini::FinishReason::Safety => FinishReason::Safety,
100 adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
101 _ => FinishReason::Other,
102 });
103
104 Ok(LlmResponse {
105 content,
106 usage_metadata,
107 finish_reason,
108 partial: false,
109 turn_complete: true,
110 interrupted: false,
111 error_code: None,
112 error_message: None,
113 })
114 }
115}
116
117#[async_trait]
118impl Llm for GeminiModel {
119 fn name(&self) -> &str {
120 &self.model_name
121 }
122
123 #[adk_telemetry::instrument(
124 name = "call_llm",
125 skip(self, req),
126 fields(
127 model.name = %self.model_name,
128 stream = %stream,
129 request.contents_count = %req.contents.len(),
130 request.tools_count = %req.tools.len()
131 )
132 )]
133 async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
134 adk_telemetry::info!("Generating content");
135
136 let mut builder = self.client.generate_content();
137
138 for content in &req.contents {
140 match content.role.as_str() {
141 "user" => {
142 let mut gemini_parts = Vec::new();
144 for part in &content.parts {
145 match part {
146 Part::Text { text } => {
147 gemini_parts.push(adk_gemini::Part::Text {
148 text: text.clone(),
149 thought: None,
150 thought_signature: None,
151 });
152 }
153 Part::InlineData { data, mime_type } => {
154 use base64::{Engine as _, engine::general_purpose::STANDARD};
155 let encoded = STANDARD.encode(data);
156 gemini_parts.push(adk_gemini::Part::InlineData {
157 inline_data: adk_gemini::Blob {
158 mime_type: mime_type.clone(),
159 data: encoded,
160 },
161 });
162 }
163 _ => {}
164 }
165 }
166 if !gemini_parts.is_empty() {
167 let user_content = adk_gemini::Content {
168 role: Some(adk_gemini::Role::User),
169 parts: Some(gemini_parts),
170 };
171 builder = builder.with_message(adk_gemini::Message {
172 content: user_content,
173 role: adk_gemini::Role::User,
174 });
175 }
176 }
177 "model" => {
178 let mut gemini_parts = Vec::new();
180 for part in &content.parts {
181 match part {
182 Part::Text { text } => {
183 gemini_parts.push(adk_gemini::Part::Text {
184 text: text.clone(),
185 thought: None,
186 thought_signature: None,
187 });
188 }
189 Part::FunctionCall { name, args, .. } => {
190 gemini_parts.push(adk_gemini::Part::FunctionCall {
191 function_call: adk_gemini::FunctionCall {
192 name: name.clone(),
193 args: args.clone(),
194 thought_signature: None,
195 },
196 thought_signature: None,
197 });
198 }
199 _ => {}
200 }
201 }
202 if !gemini_parts.is_empty() {
203 let model_content = adk_gemini::Content {
204 role: Some(adk_gemini::Role::Model),
205 parts: Some(gemini_parts),
206 };
207 builder = builder.with_message(adk_gemini::Message {
208 content: model_content,
209 role: adk_gemini::Role::Model,
210 });
211 }
212 }
213 "function" => {
214 for part in &content.parts {
216 if let Part::FunctionResponse { function_response, .. } = part {
217 builder = builder
218 .with_function_response(
219 &function_response.name,
220 function_response.response.clone(),
221 )
222 .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
223 }
224 }
225 }
226 _ => {}
227 }
228 }
229
230 if let Some(config) = req.config {
232 let has_schema = config.response_schema.is_some();
233 let gen_config = adk_gemini::GenerationConfig {
234 temperature: config.temperature,
235 top_p: config.top_p,
236 top_k: config.top_k,
237 max_output_tokens: config.max_output_tokens,
238 response_schema: config.response_schema,
239 response_mime_type: if has_schema {
240 Some("application/json".to_string())
241 } else {
242 None
243 },
244 ..Default::default()
245 };
246 builder = builder.with_generation_config(gen_config);
247 }
248
249 if !req.tools.is_empty() {
251 let mut function_declarations = Vec::new();
252 let mut has_google_search = false;
253
254 for (name, tool_decl) in &req.tools {
255 if name == "google_search" {
256 has_google_search = true;
257 continue;
258 }
259
260 if let Ok(func_decl) =
262 serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
263 {
264 function_declarations.push(func_decl);
265 }
266 }
267
268 if !function_declarations.is_empty() {
269 let tool = adk_gemini::Tool::with_functions(function_declarations);
270 builder = builder.with_tool(tool);
271 }
272
273 if has_google_search {
274 let tool = adk_gemini::Tool::google_search();
276 builder = builder.with_tool(tool);
277 }
278 }
279
280 if stream {
281 adk_telemetry::debug!("Executing streaming request");
282 let response_stream = builder.execute_stream().await.map_err(|e| {
283 adk_telemetry::error!(error = %e, "Model request failed");
284 adk_core::AdkError::Model(e.to_string())
285 })?;
286
287 let mapped_stream = async_stream::stream! {
288 use futures::TryStreamExt;
289 let mut stream = response_stream;
290 while let Some(result) = stream.try_next().await.transpose() {
291 match result {
292 Ok(resp) => {
293 match Self::convert_response(&resp) {
294 Ok(mut llm_resp) => {
295 let is_final = llm_resp.finish_reason.is_some();
297 llm_resp.partial = !is_final;
298 llm_resp.turn_complete = is_final;
299 yield Ok(llm_resp);
300 }
301 Err(e) => {
302 adk_telemetry::error!(error = %e, "Failed to convert response");
303 yield Err(e);
304 }
305 }
306 }
307 Err(e) => {
308 adk_telemetry::error!(error = %e, "Stream error");
309 yield Err(adk_core::AdkError::Model(e.to_string()));
310 }
311 }
312 }
313 };
314
315 Ok(Box::pin(mapped_stream))
316 } else {
317 adk_telemetry::debug!("Executing blocking request");
318 let response = builder.execute().await.map_err(|e| {
319 adk_telemetry::error!(error = %e, "Model request failed");
320 adk_core::AdkError::Model(e.to_string())
321 })?;
322
323 let llm_response = Self::convert_response(&response)?;
324
325 let stream = async_stream::stream! {
326 yield Ok(llm_response);
327 };
328
329 Ok(Box::pin(stream))
330 }
331 }
332}