1use crate::attachment;
2use crate::retry::{RetryConfig, execute_with_retry, is_retryable_model_error};
3use adk_core::{
4 CacheCapable, CitationMetadata, CitationSource, Content, ErrorCategory, ErrorComponent,
5 FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result, UsageMetadata,
6};
7use adk_gemini::Gemini;
8use async_trait::async_trait;
9use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
10use futures::TryStreamExt;
11
/// LLM backend that issues requests through an [`adk_gemini::Gemini`] client.
pub struct GeminiModel {
    /// Gemini API client used to build and execute every request.
    client: Gemini,
    /// Model identifier supplied at construction; this is what `Llm::name` returns.
    model_name: String,
    /// Retry policy handed to `execute_with_retry` on each generation call.
    retry_config: RetryConfig,
}
17
18fn gemini_error_to_adk(e: &adk_gemini::ClientError) -> adk_core::AdkError {
20 fn format_error_chain(e: &dyn std::error::Error) -> String {
21 let mut msg = e.to_string();
22 let mut source = e.source();
23 while let Some(s) = source {
24 msg.push_str(": ");
25 msg.push_str(&s.to_string());
26 source = s.source();
27 }
28 msg
29 }
30
31 let message = format_error_chain(e);
32
33 let (category, code, status_code) = if message.contains("code 429")
36 || message.contains("RESOURCE_EXHAUSTED")
37 || message.contains("rate limit")
38 {
39 (ErrorCategory::RateLimited, "model.gemini.rate_limited", Some(429u16))
40 } else if message.contains("code 503") || message.contains("UNAVAILABLE") {
41 (ErrorCategory::Unavailable, "model.gemini.unavailable", Some(503))
42 } else if message.contains("code 529") || message.contains("OVERLOADED") {
43 (ErrorCategory::Unavailable, "model.gemini.overloaded", Some(529))
44 } else if message.contains("code 408")
45 || message.contains("DEADLINE_EXCEEDED")
46 || message.contains("TIMEOUT")
47 {
48 (ErrorCategory::Timeout, "model.gemini.timeout", Some(408))
49 } else if message.contains("code 401") || message.contains("Invalid API key") {
50 (ErrorCategory::Unauthorized, "model.gemini.unauthorized", Some(401))
51 } else if message.contains("code 400") {
52 (ErrorCategory::InvalidInput, "model.gemini.bad_request", Some(400))
53 } else if message.contains("code 404") {
54 (ErrorCategory::NotFound, "model.gemini.not_found", Some(404))
55 } else if message.contains("invalid generation config") {
56 (ErrorCategory::InvalidInput, "model.gemini.invalid_config", None)
57 } else {
58 (ErrorCategory::Internal, "model.gemini.internal", None)
59 };
60
61 let mut err = adk_core::AdkError::new(ErrorComponent::Model, category, code, message)
62 .with_provider("gemini");
63 if let Some(sc) = status_code {
64 err = err.with_upstream_status(sc);
65 }
66 err
67}
68
69impl GeminiModel {
70 fn gemini_part_thought_signature(value: &serde_json::Value) -> Option<String> {
71 value.get("thoughtSignature").and_then(serde_json::Value::as_str).map(str::to_string)
72 }
73
74 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
75 let model_name = model.into();
76 let client = Gemini::with_model(api_key.into(), model_name.clone())
77 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
78
79 Ok(Self { client, model_name, retry_config: RetryConfig::default() })
80 }
81
82 #[cfg(feature = "gemini-vertex")]
86 pub fn new_google_cloud(
87 api_key: impl Into<String>,
88 project_id: impl AsRef<str>,
89 location: impl AsRef<str>,
90 model: impl Into<String>,
91 ) -> Result<Self> {
92 let model_name = model.into();
93 let client = Gemini::with_google_cloud_model(
94 api_key.into(),
95 project_id,
96 location,
97 model_name.clone(),
98 )
99 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
100
101 Ok(Self { client, model_name, retry_config: RetryConfig::default() })
102 }
103
104 #[cfg(feature = "gemini-vertex")]
108 pub fn new_google_cloud_service_account(
109 service_account_json: &str,
110 project_id: impl AsRef<str>,
111 location: impl AsRef<str>,
112 model: impl Into<String>,
113 ) -> Result<Self> {
114 let model_name = model.into();
115 let client = Gemini::with_google_cloud_service_account_json(
116 service_account_json,
117 project_id.as_ref(),
118 location.as_ref(),
119 model_name.clone(),
120 )
121 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
122
123 Ok(Self { client, model_name, retry_config: RetryConfig::default() })
124 }
125
126 #[cfg(feature = "gemini-vertex")]
130 pub fn new_google_cloud_adc(
131 project_id: impl AsRef<str>,
132 location: impl AsRef<str>,
133 model: impl Into<String>,
134 ) -> Result<Self> {
135 let model_name = model.into();
136 let client = Gemini::with_google_cloud_adc_model(
137 project_id.as_ref(),
138 location.as_ref(),
139 model_name.clone(),
140 )
141 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
142
143 Ok(Self { client, model_name, retry_config: RetryConfig::default() })
144 }
145
146 #[cfg(feature = "gemini-vertex")]
150 pub fn new_google_cloud_wif(
151 wif_json: &str,
152 project_id: impl AsRef<str>,
153 location: impl AsRef<str>,
154 model: impl Into<String>,
155 ) -> Result<Self> {
156 let model_name = model.into();
157 let client = Gemini::with_google_cloud_wif_json(
158 wif_json,
159 project_id.as_ref(),
160 location.as_ref(),
161 model_name.clone(),
162 )
163 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
164
165 Ok(Self { client, model_name, retry_config: RetryConfig::default() })
166 }
167
168 #[must_use]
169 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
170 self.retry_config = retry_config;
171 self
172 }
173
174 pub fn set_retry_config(&mut self, retry_config: RetryConfig) {
175 self.retry_config = retry_config;
176 }
177
178 pub fn retry_config(&self) -> &RetryConfig {
179 &self.retry_config
180 }
181
182 fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
183 let mut converted_parts: Vec<Part> = Vec::new();
184
185 if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
187 for p in parts {
188 match p {
189 adk_gemini::Part::Text { text, thought, thought_signature } => {
190 if thought == &Some(true) {
191 converted_parts.push(Part::Thinking {
192 thinking: text.clone(),
193 signature: thought_signature.clone(),
194 });
195 } else {
196 converted_parts.push(Part::Text { text: text.clone() });
197 }
198 }
199 adk_gemini::Part::InlineData { inline_data } => {
200 let decoded =
201 BASE64_STANDARD.decode(&inline_data.data).map_err(|error| {
202 adk_core::AdkError::model(format!(
203 "failed to decode inline data from gemini response: {error}"
204 ))
205 })?;
206 converted_parts.push(Part::InlineData {
207 mime_type: inline_data.mime_type.clone(),
208 data: decoded,
209 });
210 }
211 adk_gemini::Part::FunctionCall { function_call, thought_signature } => {
212 converted_parts.push(Part::FunctionCall {
213 name: function_call.name.clone(),
214 args: function_call.args.clone(),
215 id: None,
216 thought_signature: thought_signature.clone(),
217 });
218 }
219 adk_gemini::Part::FunctionResponse { function_response, .. } => {
220 converted_parts.push(Part::FunctionResponse {
221 function_response: adk_core::FunctionResponseData {
222 name: function_response.name.clone(),
223 response: function_response
224 .response
225 .clone()
226 .unwrap_or(serde_json::Value::Null),
227 },
228 id: None,
229 });
230 }
231 adk_gemini::Part::ToolCall { .. } | adk_gemini::Part::ExecutableCode { .. } => {
232 if let Ok(value) = serde_json::to_value(p) {
233 converted_parts.push(Part::ServerToolCall { server_tool_call: value });
234 }
235 }
236 adk_gemini::Part::ToolResponse { .. }
237 | adk_gemini::Part::CodeExecutionResult { .. } => {
238 let value = serde_json::to_value(p).unwrap_or(serde_json::Value::Null);
239 converted_parts
240 .push(Part::ServerToolResponse { server_tool_response: value });
241 }
242 }
243 }
244 }
245
246 if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
248 {
249 if let Some(queries) = &grounding.web_search_queries {
250 if !queries.is_empty() {
251 let search_info = format!("\n\nš **Searched:** {}", queries.join(", "));
252 converted_parts.push(Part::Text { text: search_info });
253 }
254 }
255 if let Some(chunks) = &grounding.grounding_chunks {
256 let sources: Vec<String> = chunks
257 .iter()
258 .filter_map(|c| {
259 c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
260 (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
261 (Some(title), None) => Some(title.clone()),
262 (None, Some(uri)) => Some(uri.to_string()),
263 (None, None) => None,
264 })
265 })
266 .collect();
267 if !sources.is_empty() {
268 let sources_info = format!("\nš **Sources:** {}", sources.join(" | "));
269 converted_parts.push(Part::Text { text: sources_info });
270 }
271 }
272 }
273
274 let content = if converted_parts.is_empty() {
275 None
276 } else {
277 Some(Content { role: "model".to_string(), parts: converted_parts })
278 };
279
280 let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
281 prompt_token_count: u.prompt_token_count.unwrap_or(0),
282 candidates_token_count: u.candidates_token_count.unwrap_or(0),
283 total_token_count: u.total_token_count.unwrap_or(0),
284 thinking_token_count: u.thoughts_token_count,
285 cache_read_input_token_count: u.cached_content_token_count,
286 ..Default::default()
287 });
288
289 let finish_reason =
290 resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
291 adk_gemini::FinishReason::Stop => FinishReason::Stop,
292 adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
293 adk_gemini::FinishReason::Safety => FinishReason::Safety,
294 adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
295 _ => FinishReason::Other,
296 });
297
298 let citation_metadata =
299 resp.candidates.first().and_then(|c| c.citation_metadata.as_ref()).map(|meta| {
300 CitationMetadata {
301 citation_sources: meta
302 .citation_sources
303 .iter()
304 .map(|source| CitationSource {
305 uri: source.uri.clone(),
306 title: source.title.clone(),
307 start_index: source.start_index,
308 end_index: source.end_index,
309 license: source.license.clone(),
310 publication_date: source.publication_date.map(|d| d.to_string()),
311 })
312 .collect(),
313 }
314 });
315
316 let provider_metadata = resp
319 .candidates
320 .first()
321 .and_then(|c| c.grounding_metadata.as_ref())
322 .and_then(|g| serde_json::to_value(g).ok());
323
324 Ok(LlmResponse {
325 content,
326 usage_metadata,
327 finish_reason,
328 citation_metadata,
329 partial: false,
330 turn_complete: true,
331 interrupted: false,
332 error_code: None,
333 error_message: None,
334 provider_metadata,
335 })
336 }
337
338 fn gemini_function_response_payload(response: serde_json::Value) -> serde_json::Value {
339 match response {
340 serde_json::Value::Object(_) => response,
342 other => serde_json::json!({ "result": other }),
343 }
344 }
345
346 fn merge_object_value(
347 target: &mut serde_json::Map<String, serde_json::Value>,
348 value: serde_json::Value,
349 ) {
350 if let serde_json::Value::Object(object) = value {
351 for (key, value) in object {
352 target.insert(key, value);
353 }
354 }
355 }
356
357 fn build_gemini_tools(
358 tools: &std::collections::HashMap<String, serde_json::Value>,
359 ) -> Result<(Vec<adk_gemini::Tool>, adk_gemini::ToolConfig)> {
360 let mut gemini_tools = Vec::new();
361 let mut function_declarations = Vec::new();
362 let mut has_provider_native_tools = false;
363 let mut tool_config_json = serde_json::Map::new();
364
365 for (name, tool_decl) in tools {
366 if let Some(provider_tool) = tool_decl.get("x-adk-gemini-tool") {
367 let tool = serde_json::from_value::<adk_gemini::Tool>(provider_tool.clone())
368 .map_err(|error| {
369 adk_core::AdkError::model(format!(
370 "failed to deserialize Gemini native tool '{name}': {error}"
371 ))
372 })?;
373 has_provider_native_tools = true;
374 gemini_tools.push(tool);
375 } else if let Ok(func_decl) =
376 serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
377 {
378 function_declarations.push(func_decl);
379 } else {
380 return Err(adk_core::AdkError::model(format!(
381 "failed to deserialize Gemini tool '{name}' as a function declaration"
382 )));
383 }
384
385 if let Some(tool_config) = tool_decl.get("x-adk-gemini-tool-config") {
386 Self::merge_object_value(&mut tool_config_json, tool_config.clone());
387 }
388 }
389
390 let has_function_declarations = !function_declarations.is_empty();
391 if has_function_declarations {
392 gemini_tools.push(adk_gemini::Tool::with_functions(function_declarations));
393 }
394
395 if has_provider_native_tools && has_function_declarations {
396 tool_config_json.insert(
397 "includeServerSideToolInvocations".to_string(),
398 serde_json::Value::Bool(true),
399 );
400 }
401
402 let tool_config = if tool_config_json.is_empty() {
403 adk_gemini::ToolConfig::default()
404 } else {
405 serde_json::from_value::<adk_gemini::ToolConfig>(serde_json::Value::Object(
406 tool_config_json,
407 ))
408 .map_err(|error| {
409 adk_core::AdkError::model(format!(
410 "failed to deserialize Gemini tool configuration: {error}"
411 ))
412 })?
413 };
414
415 Ok((gemini_tools, tool_config))
416 }
417
418 fn stream_chunks_from_response(
419 mut response: LlmResponse,
420 saw_partial_chunk: bool,
421 ) -> (Vec<LlmResponse>, bool) {
422 let is_final = response.finish_reason.is_some();
423
424 if !is_final {
425 response.partial = true;
426 response.turn_complete = false;
427 return (vec![response], true);
428 }
429
430 response.partial = false;
431 response.turn_complete = true;
432
433 if saw_partial_chunk {
434 return (vec![response], true);
435 }
436
437 let synthetic_partial = LlmResponse {
438 content: None,
439 usage_metadata: None,
440 finish_reason: None,
441 citation_metadata: None,
442 partial: true,
443 turn_complete: false,
444 interrupted: false,
445 error_code: None,
446 error_message: None,
447 provider_metadata: None,
448 };
449
450 (vec![synthetic_partial, response], true)
451 }
452
    /// Builds a Gemini request from `req`, executes it once (no retry at this level) and
    /// returns the result as a response stream.
    ///
    /// With `stream == true` the upstream stream is consumed and each item is converted
    /// and passed through `stream_chunks_from_response`; otherwise a single blocking
    /// call is made and wrapped in a one-item stream.
    ///
    /// # Errors
    ///
    /// Returns an error when the tool declarations cannot be converted, when the
    /// request fails to execute (mapped through `gemini_error_to_adk`), or — in the
    /// non-streaming case — when the response cannot be converted. In the streaming
    /// case, per-item failures are yielded as `Err` items of the stream instead.
    async fn generate_content_internal(
        &self,
        req: LlmRequest,
        stream: bool,
    ) -> Result<LlmResponseStream> {
        let mut builder = self.client.generate_content();

        // First pass: remember the thought signature attached to each model function
        // call, keyed by function name, so it can be echoed on the matching function
        // response below. A later call to the same function name overwrites the
        // earlier entry.
        let mut fn_call_signatures: std::collections::HashMap<String, String> =
            std::collections::HashMap::new();
        for content in &req.contents {
            if content.role == "model" {
                for part in &content.parts {
                    if let Part::FunctionCall { name, thought_signature: Some(sig), .. } = part {
                        fn_call_signatures.insert(name.clone(), sig.clone());
                    }
                }
            }
        }

        // Second pass: translate each content entry into a Gemini message by role.
        for content in &req.contents {
            match content.role.as_str() {
                "user" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            Part::InlineData { data, mime_type } => {
                                // Raw bytes are base64-encoded for the wire format.
                                let encoded = attachment::encode_base64(data);
                                gemini_parts.push(adk_gemini::Part::InlineData {
                                    inline_data: adk_gemini::Blob {
                                        mime_type: mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            Part::FileData { mime_type, file_uri } => {
                                // File references are sent as a text part produced by
                                // `attachment::file_attachment_to_text`.
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: attachment::file_attachment_to_text(mime_type, file_uri),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            // Any other part kind in a user message is dropped.
                            _ => {}
                        }
                    }
                    // Messages that end up with no parts are not sent at all.
                    if !gemini_parts.is_empty() {
                        let user_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: user_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                "model" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            Part::FunctionCall { name, args, thought_signature, .. } => {
                                // The signature is carried on the part, not on the
                                // inner function call.
                                gemini_parts.push(adk_gemini::Part::FunctionCall {
                                    function_call: adk_gemini::FunctionCall {
                                        name: name.clone(),
                                        args: args.clone(),
                                        thought_signature: None,
                                    },
                                    thought_signature: thought_signature.clone(),
                                });
                            }
                            Part::ServerToolCall { server_tool_call } => {
                                // Prefer round-tripping the stored JSON back into the
                                // native part it was serialized from.
                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
                                    server_tool_call.clone(),
                                ) {
                                    match native_part {
                                        adk_gemini::Part::ToolCall { .. }
                                        | adk_gemini::Part::ExecutableCode { .. } => {
                                            gemini_parts.push(native_part);
                                            continue;
                                        }
                                        _ => {}
                                    }
                                }

                                // Fallback: wrap the raw JSON as a tool call, lifting
                                // any `thoughtSignature` field onto the part.
                                gemini_parts.push(adk_gemini::Part::ToolCall {
                                    tool_call: server_tool_call.clone(),
                                    thought_signature: Self::gemini_part_thought_signature(
                                        server_tool_call,
                                    ),
                                });
                            }
                            Part::ServerToolResponse { server_tool_response } => {
                                // Same strategy as for server tool calls: native
                                // round-trip first, raw wrapper otherwise.
                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
                                    server_tool_response.clone(),
                                ) {
                                    match native_part {
                                        adk_gemini::Part::ToolResponse { .. }
                                        | adk_gemini::Part::CodeExecutionResult { .. } => {
                                            gemini_parts.push(native_part);
                                            continue;
                                        }
                                        _ => {}
                                    }
                                }

                                gemini_parts.push(adk_gemini::Part::ToolResponse {
                                    tool_response: server_tool_response.clone(),
                                    thought_signature: Self::gemini_part_thought_signature(
                                        server_tool_response,
                                    ),
                                });
                            }
                            // Any other part kind in a model message is dropped.
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let model_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::Model),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: model_content,
                            role: adk_gemini::Role::Model,
                        });
                    }
                }
                "function" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        if let Part::FunctionResponse { function_response, .. } = part {
                            // Echo the signature recorded for the call with the same name.
                            let sig = fn_call_signatures.get(&function_response.name).cloned();
                            gemini_parts.push(adk_gemini::Part::FunctionResponse {
                                function_response: adk_gemini::tools::FunctionResponse::new(
                                    &function_response.name,
                                    Self::gemini_function_response_payload(
                                        function_response.response.clone(),
                                    ),
                                ),
                                thought_signature: sig,
                            });
                        }
                    }
                    // Function results are sent under the user role.
                    if !gemini_parts.is_empty() {
                        let fn_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: fn_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                // Content with any other role is not forwarded by this loop.
                _ => {}
            }
        }

        if let Some(config) = req.config {
            // A response schema forces the JSON response MIME type.
            let has_schema = config.response_schema.is_some();
            let gen_config = adk_gemini::GenerationConfig {
                temperature: config.temperature,
                top_p: config.top_p,
                top_k: config.top_k,
                max_output_tokens: config.max_output_tokens,
                response_schema: config.response_schema,
                response_mime_type: if has_schema {
                    Some("application/json".to_string())
                } else {
                    None
                },
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);

            // Attach previously created cached content by name, if requested.
            if let Some(ref name) = config.cached_content {
                let handle = self.client.get_cached_content(name);
                builder = builder.with_cached_content(&handle);
            }
        }

        if !req.tools.is_empty() {
            let (gemini_tools, tool_config) = Self::build_gemini_tools(&req.tools)?;
            for tool in gemini_tools {
                builder = builder.with_tool(tool);
            }
            // Only send a tool config when it differs from the default.
            if tool_config != adk_gemini::ToolConfig::default() {
                builder = builder.with_tool_config(tool_config);
            }
        }

        if stream {
            adk_telemetry::debug!("Executing streaming request");
            let response_stream = builder.execute_stream().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                gemini_error_to_adk(&e)
            })?;

            let mapped_stream = async_stream::stream! {
                let mut stream = response_stream;
                // Tracks whether a partial chunk was already emitted, so a lone final
                // chunk can be preceded by a synthetic partial one.
                let mut saw_partial_chunk = false;
                // Errors are yielded to the consumer; the loop then keeps polling the
                // upstream stream until it reports the end.
                while let Some(result) = stream.try_next().await.transpose() {
                    match result {
                        Ok(resp) => {
                            match Self::convert_response(&resp) {
                                Ok(llm_resp) => {
                                    let (chunks, next_saw_partial) =
                                        Self::stream_chunks_from_response(llm_resp, saw_partial_chunk);
                                    saw_partial_chunk = next_saw_partial;
                                    for chunk in chunks {
                                        yield Ok(chunk);
                                    }
                                }
                                Err(e) => {
                                    adk_telemetry::error!(error = %e, "Failed to convert response");
                                    yield Err(e);
                                }
                            }
                        }
                        Err(e) => {
                            adk_telemetry::error!(error = %e, "Stream error");
                            yield Err(gemini_error_to_adk(&e));
                        }
                    }
                }
            };

            Ok(Box::pin(mapped_stream))
        } else {
            adk_telemetry::debug!("Executing blocking request");
            let response = builder.execute().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                gemini_error_to_adk(&e)
            })?;

            let llm_response = Self::convert_response(&response)?;

            // Present the single response through the same stream interface.
            let stream = async_stream::stream! {
                yield Ok(llm_response);
            };

            Ok(Box::pin(stream))
        }
    }
735
736 pub async fn create_cached_content(
741 &self,
742 system_instruction: &str,
743 tools: &std::collections::HashMap<String, serde_json::Value>,
744 ttl_seconds: u32,
745 ) -> Result<String> {
746 let mut cache_builder = self
747 .client
748 .create_cache()
749 .with_system_instruction(system_instruction)
750 .with_ttl(std::time::Duration::from_secs(u64::from(ttl_seconds)));
751
752 let (gemini_tools, tool_config) = Self::build_gemini_tools(tools)?;
753 if !gemini_tools.is_empty() {
754 cache_builder = cache_builder.with_tools(gemini_tools);
755 }
756 if tool_config != adk_gemini::ToolConfig::default() {
757 cache_builder = cache_builder.with_tool_config(tool_config);
758 }
759
760 let handle = cache_builder
761 .execute()
762 .await
763 .map_err(|e| adk_core::AdkError::model(format!("cache creation failed: {e}")))?;
764
765 Ok(handle.name().to_string())
766 }
767
768 pub async fn delete_cached_content(&self, name: &str) -> Result<()> {
770 let handle = self.client.get_cached_content(name);
771 handle
772 .delete()
773 .await
774 .map_err(|(_, e)| adk_core::AdkError::model(format!("cache deletion failed: {e}")))?;
775 Ok(())
776 }
777}
778
779#[async_trait]
780impl Llm for GeminiModel {
781 fn name(&self) -> &str {
782 &self.model_name
783 }
784
785 #[adk_telemetry::instrument(
786 name = "call_llm",
787 skip(self, req),
788 fields(
789 model.name = %self.model_name,
790 stream = %stream,
791 request.contents_count = %req.contents.len(),
792 request.tools_count = %req.tools.len()
793 )
794 )]
795 async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
796 adk_telemetry::info!("Generating content");
797 let usage_span = adk_telemetry::llm_generate_span("gemini", &self.model_name, stream);
798 let result = execute_with_retry(&self.retry_config, is_retryable_model_error, || {
801 self.generate_content_internal(req.clone(), stream)
802 })
803 .await?;
804 Ok(crate::usage_tracking::with_usage_tracking(result, usage_span))
805 }
806}
807
#[cfg(test)]
mod native_tool_tests {
    use super::*;

    /// A native tool plus a plain function declaration yields two Gemini tools and
    /// enables server-side tool invocations in the tool config.
    #[test]
    fn test_build_gemini_tools_supports_native_tool_metadata() {
        let native_search = serde_json::json!({
            "x-adk-gemini-tool": {
                "google_search": {}
            }
        });
        let weather_function = serde_json::json!({
            "name": "lookup_weather",
            "description": "lookup weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": { "type": "string" }
                }
            }
        });
        let tools = std::collections::HashMap::from([
            ("google_search".to_string(), native_search),
            ("lookup_weather".to_string(), weather_function),
        ]);

        let (converted, config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        assert_eq!(converted.len(), 2);
        assert_eq!(config.include_server_side_tool_invocations, Some(true));
    }

    /// The `x-adk-gemini-tool-config` object of a native tool ends up in the tool config.
    #[test]
    fn test_build_gemini_tools_merges_native_tool_config() {
        let maps_tool = serde_json::json!({
            "x-adk-gemini-tool": {
                "google_maps": {
                    "enable_widget": true
                }
            },
            "x-adk-gemini-tool-config": {
                "retrievalConfig": {
                    "latLng": {
                        "latitude": 1.23,
                        "longitude": 4.56
                    }
                }
            }
        });
        let tools = std::collections::HashMap::from([("google_maps".to_string(), maps_tool)]);

        let (_converted, config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        let expected_retrieval = serde_json::json!({
            "latLng": {
                "latitude": 1.23,
                "longitude": 4.56
            }
        });
        assert_eq!(config.retrieval_config, Some(expected_retrieval));
    }
}
880
881#[async_trait]
882impl CacheCapable for GeminiModel {
883 async fn create_cache(
884 &self,
885 system_instruction: &str,
886 tools: &std::collections::HashMap<String, serde_json::Value>,
887 ttl_seconds: u32,
888 ) -> Result<String> {
889 self.create_cached_content(system_instruction, tools, ttl_seconds).await
890 }
891
892 async fn delete_cache(&self, name: &str) -> Result<()> {
893 self.delete_cached_content(name).await
894 }
895}
896
/// Unit tests for the constructor signature, stream chunking, retry behaviour,
/// response conversion and function-response payload shaping.
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::AdkError;
    use std::{
        sync::{
            Arc,
            atomic::{AtomicU32, Ordering},
        },
        time::Duration,
    };

    /// Compile-time check: `GeminiModel::new` is usable as a plain synchronous
    /// `Fn(&str, &str) -> Result<GeminiModel>`.
    #[test]
    fn constructor_is_backward_compatible_and_sync() {
        fn accepts_sync_constructor<F>(_f: F)
        where
            F: Fn(&str, &str) -> Result<GeminiModel>,
        {
        }

        accepts_sync_constructor(|api_key, model| GeminiModel::new(api_key, model));
    }

    /// A final chunk arriving with no prior partial chunk is preceded by an empty
    /// synthetic partial chunk.
    #[test]
    fn stream_chunks_from_response_injects_partial_before_lone_final_chunk() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("hello")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        };

        let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, false);
        assert!(saw_partial);
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].partial);
        assert!(!chunks[0].turn_complete);
        assert!(chunks[0].content.is_none());
        assert!(!chunks[1].partial);
        assert!(chunks[1].turn_complete);
    }

    /// A final chunk arriving after a partial chunk is emitted alone.
    #[test]
    fn stream_chunks_from_response_keeps_final_only_when_partial_already_seen() {
        let response = LlmResponse {
            content: Some(Content::new("model").with_text("done")),
            usage_metadata: None,
            finish_reason: Some(FinishReason::Stop),
            citation_metadata: None,
            partial: false,
            turn_complete: true,
            interrupted: false,
            error_code: None,
            error_message: None,
            provider_metadata: None,
        };

        let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, true);
        assert!(saw_partial);
        assert_eq!(chunks.len(), 1);
        assert!(!chunks[0].partial);
        assert!(chunks[0].turn_complete);
    }

    /// Two retryable failures followed by a success: three attempts in total.
    #[tokio::test]
    async fn execute_with_retry_retries_retryable_errors() {
        // Zero delays keep the test fast.
        let retry_config = RetryConfig::default()
            .with_max_retries(2)
            .with_initial_delay(Duration::from_millis(0))
            .with_max_delay(Duration::from_millis(0));
        let attempts = Arc::new(AtomicU32::new(0));

        let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                let attempt = attempts.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    return Err(AdkError::model("code 429 RESOURCE_EXHAUSTED"));
                }
                Ok("ok")
            }
        })
        .await
        .expect("retry should eventually succeed");

        assert_eq!(result, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// A non-retryable error is returned after a single attempt.
    #[tokio::test]
    async fn execute_with_retry_does_not_retry_non_retryable_errors() {
        let retry_config = RetryConfig::default()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(0))
            .with_max_delay(Duration::from_millis(0));
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("code 400 invalid request"))
            }
        })
        .await
        .expect_err("non-retryable error should return immediately");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    /// With a disabled retry config, even a retryable error is attempted only once,
    /// regardless of the `max_retries` value set afterwards.
    #[tokio::test]
    async fn execute_with_retry_respects_disabled_config() {
        let retry_config = RetryConfig::disabled().with_max_retries(10);
        let attempts = Arc::new(AtomicU32::new(0));

        let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
            let attempts = Arc::clone(&attempts);
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(AdkError::model("code 429 RESOURCE_EXHAUSTED"))
            }
        })
        .await
        .expect_err("disabled retries should return first error");

        assert!(error.is_model());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    /// Citation sources on the first candidate are carried into the converted response.
    #[test]
    fn convert_response_preserves_citation_metadata() {
        let response = adk_gemini::GenerationResponse {
            candidates: vec![adk_gemini::Candidate {
                content: adk_gemini::Content {
                    role: Some(adk_gemini::Role::Model),
                    parts: Some(vec![adk_gemini::Part::Text {
                        text: "hello world".to_string(),
                        thought: None,
                        thought_signature: None,
                    }]),
                },
                safety_ratings: None,
                citation_metadata: Some(adk_gemini::CitationMetadata {
                    citation_sources: vec![adk_gemini::CitationSource {
                        uri: Some("https://example.com".to_string()),
                        title: Some("Example".to_string()),
                        start_index: Some(0),
                        end_index: Some(5),
                        license: Some("CC-BY".to_string()),
                        publication_date: None,
                    }],
                }),
                grounding_metadata: None,
                finish_reason: Some(adk_gemini::FinishReason::Stop),
                index: Some(0),
            }],
            prompt_feedback: None,
            usage_metadata: None,
            model_version: None,
            response_id: None,
        };

        let converted =
            GeminiModel::convert_response(&response).expect("conversion should succeed");
        let metadata = converted.citation_metadata.expect("citation metadata should be mapped");
        assert_eq!(metadata.citation_sources.len(), 1);
        assert_eq!(metadata.citation_sources[0].uri.as_deref(), Some("https://example.com"));
        assert_eq!(metadata.citation_sources[0].start_index, Some(0));
        assert_eq!(metadata.citation_sources[0].end_index, Some(5));
    }

    /// Inline data in a model response is base64-decoded back to the original bytes,
    /// alongside the accompanying text part.
    #[test]
    fn convert_response_handles_inline_data_from_model() {
        // First four bytes of the PNG signature, used as a small binary payload.
        let image_bytes = vec![0x89, 0x50, 0x4E, 0x47];
        let encoded = crate::attachment::encode_base64(&image_bytes);

        let response = adk_gemini::GenerationResponse {
            candidates: vec![adk_gemini::Candidate {
                content: adk_gemini::Content {
                    role: Some(adk_gemini::Role::Model),
                    parts: Some(vec![
                        adk_gemini::Part::Text {
                            text: "Here is the image".to_string(),
                            thought: None,
                            thought_signature: None,
                        },
                        adk_gemini::Part::InlineData {
                            inline_data: adk_gemini::Blob {
                                mime_type: "image/png".to_string(),
                                data: encoded,
                            },
                        },
                    ]),
                },
                safety_ratings: None,
                citation_metadata: None,
                grounding_metadata: None,
                finish_reason: Some(adk_gemini::FinishReason::Stop),
                index: Some(0),
            }],
            prompt_feedback: None,
            usage_metadata: None,
            model_version: None,
            response_id: None,
        };

        let converted =
            GeminiModel::convert_response(&response).expect("conversion should succeed");
        let content = converted.content.expect("should have content");
        assert!(
            content
                .parts
                .iter()
                .any(|part| matches!(part, Part::Text { text } if text == "Here is the image"))
        );
        assert!(content.parts.iter().any(|part| {
            matches!(
                part,
                Part::InlineData { mime_type, data }
                    if mime_type == "image/png" && data.as_slice() == image_bytes.as_slice()
            )
        }));
    }

    /// JSON objects are passed through unchanged as the function-response payload.
    #[test]
    fn gemini_function_response_payload_preserves_objects() {
        let value = serde_json::json!({
            "documents": [
                { "id": "pricing", "score": 0.91 }
            ]
        });

        let payload = GeminiModel::gemini_function_response_payload(value.clone());

        assert_eq!(payload, value);
    }

    /// Non-object values (here an array) are wrapped under a `"result"` key.
    #[test]
    fn gemini_function_response_payload_wraps_arrays() {
        let payload =
            GeminiModel::gemini_function_response_payload(serde_json::json!([{ "id": "pricing" }]));

        assert_eq!(payload, serde_json::json!({ "result": [{ "id": "pricing" }] }));
    }
}