1use adk_core::{
2 Content, FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result,
3 UsageMetadata,
4};
5use async_trait::async_trait;
6use gemini::Gemini;
7
/// Gemini-backed implementation of the ADK [`Llm`] trait.
///
/// Holds a configured `gemini` client plus the model identifier used for
/// every request issued through it.
pub struct GeminiModel {
    // Underlying Gemini API client; constructed once in `new`.
    client: Gemini,
    // Model identifier forwarded to the API; also exposed via `Llm::name`.
    model_name: String,
}
12
13impl GeminiModel {
14 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
15 let client =
16 Gemini::new(api_key.into()).map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
17
18 Ok(Self { client, model_name: model.into() })
19 }
20
21 fn convert_response(resp: &gemini::GenerationResponse) -> Result<LlmResponse> {
22 let content = resp.candidates.first().and_then(|c| c.content.parts.as_ref()).map(|parts| {
23 let converted_parts: Vec<Part> = parts
24 .iter()
25 .filter_map(|p| match p {
26 gemini::Part::Text { text, .. } => Some(Part::Text { text: text.clone() }),
27 gemini::Part::FunctionCall { function_call, .. } => Some(Part::FunctionCall {
28 name: function_call.name.clone(),
29 args: function_call.args.clone(),
30 id: None, }),
32 gemini::Part::FunctionResponse { function_response } => {
33 Some(Part::FunctionResponse {
34 name: function_response.name.clone(),
35 response: function_response
36 .response
37 .clone()
38 .unwrap_or(serde_json::Value::Null),
39 id: None, })
41 }
42 _ => None,
43 })
44 .collect();
45
46 Content { role: "model".to_string(), parts: converted_parts }
47 });
48
49 let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
50 prompt_token_count: u.prompt_token_count.unwrap_or(0),
51 candidates_token_count: u.candidates_token_count.unwrap_or(0),
52 total_token_count: u.total_token_count.unwrap_or(0),
53 });
54
55 let finish_reason =
56 resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
57 gemini::FinishReason::Stop => FinishReason::Stop,
58 gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
59 gemini::FinishReason::Safety => FinishReason::Safety,
60 gemini::FinishReason::Recitation => FinishReason::Recitation,
61 _ => FinishReason::Other,
62 });
63
64 Ok(LlmResponse {
65 content,
66 usage_metadata,
67 finish_reason,
68 partial: false,
69 turn_complete: true,
70 interrupted: false,
71 error_code: None,
72 error_message: None,
73 })
74 }
75}
76
#[async_trait]
impl Llm for GeminiModel {
    /// Returns the configured model identifier.
    fn name(&self) -> &str {
        &self.model_name
    }

    /// Sends `req` to the Gemini API and returns the reply as a stream.
    ///
    /// Request contents are translated role-by-role ("user", "model",
    /// "function") into the `gemini` crate's message types; unknown roles
    /// and unsupported part kinds are silently skipped. When `stream` is
    /// true, chunks are yielded incrementally as partial responses;
    /// otherwise a single complete response is yielded.
    ///
    /// # Errors
    /// Returns `AdkError::Model` when the provider request (or a
    /// function-response attachment) fails.
    #[adk_telemetry::instrument(
        skip(self, req),
        fields(
            model.name = %self.model_name,
            stream = %stream,
            request.contents_count = %req.contents.len(),
            request.tools_count = %req.tools.len()
        )
    )]
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
        adk_telemetry::info!("Generating content");

        let mut builder = self.client.generate_content();

        // Translate each ADK content entry into builder messages, keyed by role.
        for content in &req.contents {
            match content.role.as_str() {
                "user" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::InlineData { data, mime_type } => {
                                // Raw bytes must be base64-encoded for the API.
                                use base64::{engine::general_purpose::STANDARD, Engine as _};
                                let encoded = STANDARD.encode(data);
                                gemini_parts.push(gemini::Part::InlineData {
                                    inline_data: gemini::Blob {
                                        mime_type: mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            // Other part kinds are not valid user input — dropped.
                            _ => {}
                        }
                    }
                    // Skip entries that produced no convertible parts.
                    if !gemini_parts.is_empty() {
                        let user_content = gemini::Content {
                            role: Some(gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(gemini::Message {
                            content: user_content,
                            role: gemini::Role::User,
                        });
                    }
                }
                "model" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            // Prior model turns may replay function calls.
                            Part::FunctionCall { name, args, .. } => {
                                gemini_parts.push(gemini::Part::FunctionCall {
                                    function_call: gemini::FunctionCall {
                                        name: name.clone(),
                                        args: args.clone(),
                                        thought_signature: None,
                                    },
                                    thought_signature: None,
                                });
                            }
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let model_content = gemini::Content {
                            role: Some(gemini::Role::Model),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(gemini::Message {
                            content: model_content,
                            role: gemini::Role::Model,
                        });
                    }
                }
                "function" => {
                    // Tool results go through the builder's dedicated
                    // function-response API, which can itself fail.
                    for part in &content.parts {
                        if let Part::FunctionResponse { name, response, .. } = part {
                            builder = builder
                                .with_function_response(name, response.clone())
                                .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
                        }
                    }
                }
                // Unknown roles are ignored.
                _ => {}
            }
        }

        // Forward sampling parameters when the caller supplied a config.
        if let Some(config) = req.config {
            let gen_config = gemini::GenerationConfig {
                temperature: config.temperature,
                top_p: config.top_p,
                top_k: config.top_k,
                max_output_tokens: config.max_output_tokens,
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);
        }

        if !req.tools.is_empty() {
            let mut function_declarations = Vec::new();
            let mut has_google_search = false;

            for (name, tool_decl) in &req.tools {
                // "google_search" is Gemini's built-in tool, not a
                // user-defined function declaration.
                if name == "google_search" {
                    has_google_search = true;
                    continue;
                }

                // NOTE(review): declarations that fail to deserialize are
                // silently dropped — consider surfacing an error or warning.
                if let Ok(func_decl) =
                    serde_json::from_value::<gemini::FunctionDeclaration>(tool_decl.clone())
                {
                    function_declarations.push(func_decl);
                }
            }

            if !function_declarations.is_empty() {
                let tool = gemini::Tool::with_functions(function_declarations);
                builder = builder.with_tool(tool);
            }

            if has_google_search {
                let tool = gemini::Tool::google_search();
                builder = builder.with_tool(tool);
            }
        }

        if stream {
            adk_telemetry::debug!("Executing streaming request");
            let response_stream = builder.execute_stream().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                adk_core::AdkError::Model(e.to_string())
            })?;

            // Re-yield each provider chunk as a partial ADK response.
            let mapped_stream = async_stream::stream! {
                use futures::TryStreamExt;
                let mut stream = response_stream;
                while let Some(result) = stream.try_next().await.transpose() {
                    match result {
                        Ok(resp) => {
                            match Self::convert_response(&resp) {
                                Ok(mut llm_resp) => {
                                    // NOTE(review): every chunk is marked
                                    // partial/turn_complete = false; no final
                                    // turn_complete = true event is emitted —
                                    // confirm downstream consumers handle this.
                                    llm_resp.partial = true;
                                    llm_resp.turn_complete = false;
                                    yield Ok(llm_resp);
                                }
                                Err(e) => {
                                    adk_telemetry::error!(error = %e, "Failed to convert response");
                                    yield Err(e);
                                }
                            }
                        }
                        Err(e) => {
                            adk_telemetry::error!(error = %e, "Stream error");
                            yield Err(adk_core::AdkError::Model(e.to_string()));
                        }
                    }
                }
            };

            Ok(Box::pin(mapped_stream))
        } else {
            adk_telemetry::debug!("Executing blocking request");
            let response = builder.execute().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                adk_core::AdkError::Model(e.to_string())
            })?;

            let llm_response = Self::convert_response(&response)?;

            // Wrap the single complete response in a one-item stream so the
            // return type matches the streaming path.
            let stream = async_stream::stream! {
                yield Ok(llm_response);
            };

            Ok(Box::pin(stream))
        }
    }
}
279}