// adk_model/gemini/client.rs
use adk_core::{
2 Content, FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result,
3 UsageMetadata,
4};
5use async_trait::async_trait;
6use gemini::Gemini;
7
/// Gemini-backed implementation of the [`Llm`] trait.
pub struct GeminiModel {
    // Underlying Gemini API client; drives both streaming and blocking requests.
    client: Gemini,
    // Model identifier reported by `Llm::name()` and attached to telemetry spans.
    model_name: String,
}
12
13impl GeminiModel {
14 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
15 let client =
16 Gemini::new(api_key.into()).map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
17
18 Ok(Self { client, model_name: model.into() })
19 }
20
21 fn convert_response(resp: &gemini::GenerationResponse) -> Result<LlmResponse> {
22 let content = resp.candidates.first().and_then(|c| c.content.parts.as_ref()).map(|parts| {
23 let converted_parts: Vec<Part> = parts
24 .iter()
25 .filter_map(|p| match p {
26 gemini::Part::Text { text, .. } => Some(Part::Text { text: text.clone() }),
27 gemini::Part::FunctionCall { function_call, .. } => Some(Part::FunctionCall {
28 name: function_call.name.clone(),
29 args: function_call.args.clone(),
30 }),
31 gemini::Part::FunctionResponse { function_response } => {
32 Some(Part::FunctionResponse {
33 name: function_response.name.clone(),
34 response: function_response
35 .response
36 .clone()
37 .unwrap_or(serde_json::Value::Null),
38 })
39 }
40 _ => None,
41 })
42 .collect();
43
44 Content { role: "model".to_string(), parts: converted_parts }
45 });
46
47 let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
48 prompt_token_count: u.prompt_token_count.unwrap_or(0),
49 candidates_token_count: u.candidates_token_count.unwrap_or(0),
50 total_token_count: u.total_token_count.unwrap_or(0),
51 });
52
53 let finish_reason =
54 resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
55 gemini::FinishReason::Stop => FinishReason::Stop,
56 gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
57 gemini::FinishReason::Safety => FinishReason::Safety,
58 gemini::FinishReason::Recitation => FinishReason::Recitation,
59 _ => FinishReason::Other,
60 });
61
62 Ok(LlmResponse {
63 content,
64 usage_metadata,
65 finish_reason,
66 partial: false,
67 turn_complete: true,
68 interrupted: false,
69 error_code: None,
70 error_message: None,
71 })
72 }
73}
74
#[async_trait]
impl Llm for GeminiModel {
    /// Returns the configured model identifier.
    fn name(&self) -> &str {
        &self.model_name
    }

    /// Sends `req` to the Gemini API and returns the response as a stream.
    ///
    /// The ADK-level request is replayed into the Gemini client's builder
    /// (conversation history, generation config, tools), then executed either
    /// as a streaming request (`stream == true`) or a blocking one. Both
    /// paths return an `LlmResponseStream`; the blocking path yields exactly
    /// one item.
    ///
    /// # Errors
    /// Returns [`adk_core::AdkError::Model`] when the underlying client
    /// rejects a function response or the HTTP request fails.
    #[adk_telemetry::instrument(
        skip(self, req),
        fields(
            model.name = %self.model_name,
            stream = %stream,
            request.contents_count = %req.contents.len(),
            request.tools_count = %req.tools.len()
        )
    )]
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
        adk_telemetry::info!("Generating content");

        let mut builder = self.client.generate_content();

        // Replay the conversation history, mapping each ADK role onto the
        // message shape the Gemini client expects. Roles other than
        // "user"/"model"/"function" are ignored.
        for content in &req.contents {
            match content.role.as_str() {
                "user" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::InlineData { data, mime_type } => {
                                // Binary payloads are base64-encoded for the wire format.
                                use base64::{engine::general_purpose::STANDARD, Engine as _};
                                let encoded = STANDARD.encode(data);
                                gemini_parts.push(gemini::Part::InlineData {
                                    inline_data: gemini::Blob {
                                        mime_type: mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            // Function call/response parts are not forwarded in user turns.
                            _ => {}
                        }
                    }
                    // Skip turns that contained no convertible parts.
                    if !gemini_parts.is_empty() {
                        let user_content = gemini::Content {
                            role: Some(gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(gemini::Message {
                            content: user_content,
                            role: gemini::Role::User,
                        });
                    }
                }
                "model" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::FunctionCall { name, args } => {
                                gemini_parts.push(gemini::Part::FunctionCall {
                                    function_call: gemini::FunctionCall {
                                        name: name.clone(),
                                        args: args.clone(),
                                        thought_signature: None,
                                    },
                                    thought_signature: None,
                                });
                            }
                            // Inline data / function responses are not replayed in model turns.
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let model_content = gemini::Content {
                            role: Some(gemini::Role::Model),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(gemini::Message {
                            content: model_content,
                            role: gemini::Role::Model,
                        });
                    }
                }
                "function" => {
                    // Tool results go through the builder's dedicated API, which can fail.
                    for part in &content.parts {
                        if let Part::FunctionResponse { name, response } = part {
                            builder = builder
                                .with_function_response(name, response.clone())
                                .map_err(|e| adk_core::AdkError::Model(e.to_string()))?;
                        }
                    }
                }
                _ => {}
            }
        }

        // Forward sampling parameters when the caller supplied a config.
        if let Some(config) = req.config {
            let gen_config = gemini::GenerationConfig {
                temperature: config.temperature,
                top_p: config.top_p,
                top_k: config.top_k,
                max_output_tokens: config.max_output_tokens,
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);
        }

        if !req.tools.is_empty() {
            let mut function_declarations = Vec::new();
            let mut has_google_search = false;

            for (name, tool_decl) in &req.tools {
                // "google_search" is a Gemini built-in, not a function declaration.
                if name == "google_search" {
                    has_google_search = true;
                    continue;
                }

                // NOTE(review): declarations that fail to deserialize are
                // silently skipped — consider surfacing an error or a warning.
                if let Ok(func_decl) =
                    serde_json::from_value::<gemini::FunctionDeclaration>(tool_decl.clone())
                {
                    function_declarations.push(func_decl);
                }
            }

            if !function_declarations.is_empty() {
                let tool = gemini::Tool::with_functions(function_declarations);
                builder = builder.with_tool(tool);
            }

            if has_google_search {
                let tool = gemini::Tool::google_search();
                builder = builder.with_tool(tool);
            }
        }

        if stream {
            adk_telemetry::debug!("Executing streaming request");
            let response_stream = builder.execute_stream().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                adk_core::AdkError::Model(e.to_string())
            })?;

            // Re-yield each chunk as a partial LlmResponse; conversion and
            // transport errors are surfaced as stream items, not panics.
            let mapped_stream = async_stream::stream! {
                use futures::TryStreamExt;
                let mut stream = response_stream;
                while let Some(result) = stream.try_next().await.transpose() {
                    match result {
                        Ok(resp) => {
                            match Self::convert_response(&resp) {
                                Ok(mut llm_resp) => {
                                    // Every chunk is flagged partial.
                                    // NOTE(review): no final item with
                                    // turn_complete = true is ever emitted on
                                    // this path — confirm consumers don't rely
                                    // on a terminal marker.
                                    llm_resp.partial = true;
                                    llm_resp.turn_complete = false;
                                    yield Ok(llm_resp);
                                }
                                Err(e) => {
                                    adk_telemetry::error!(error = %e, "Failed to convert response");
                                    yield Err(e);
                                }
                            }
                        }
                        Err(e) => {
                            adk_telemetry::error!(error = %e, "Stream error");
                            yield Err(adk_core::AdkError::Model(e.to_string()));
                        }
                    }
                }
            };

            Ok(Box::pin(mapped_stream))
        } else {
            adk_telemetry::debug!("Executing blocking request");
            let response = builder.execute().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                adk_core::AdkError::Model(e.to_string())
            })?;

            let llm_response = Self::convert_response(&response)?;

            // Wrap the single response in a one-item stream so both code
            // paths share the same return type.
            let stream = async_stream::stream! {
                yield Ok(llm_response);
            };

            Ok(Box::pin(stream))
        }
    }
}