1use crate::attachment;
2use crate::retry::{RetryConfig, execute_with_retry, is_retryable_model_error};
3use adk_core::{
4 CacheCapable, CitationMetadata, CitationSource, Content, ErrorCategory, ErrorComponent,
5 FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result, UsageMetadata,
6};
7use adk_gemini::Gemini;
8use async_trait::async_trait;
9use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
10use futures::TryStreamExt;
11
/// Gemini-backed implementation of the `adk_core` [`Llm`] trait.
///
/// Wraps an `adk_gemini` client together with the per-model settings that are applied to
/// every request issued through it.
pub struct GeminiModel {
    /// Underlying `adk_gemini` client used to build content and cache requests.
    client: Gemini,
    /// Model identifier the client was created with; returned by [`Llm::name`].
    model_name: String,
    /// Retry policy handed to `execute_with_retry` for each `generate_content` call.
    retry_config: RetryConfig,
    /// Optional thinking configuration copied into the generation config of each request.
    thinking_config: Option<adk_gemini::ThinkingConfig>,
}
23
24fn gemini_error_to_adk(e: &adk_gemini::ClientError) -> adk_core::AdkError {
26 fn format_error_chain(e: &dyn std::error::Error) -> String {
27 let mut msg = e.to_string();
28 let mut source = e.source();
29 while let Some(s) = source {
30 msg.push_str(": ");
31 msg.push_str(&s.to_string());
32 source = s.source();
33 }
34 msg
35 }
36
37 let message = format_error_chain(e);
38
39 let (category, code, status_code) = if message.contains("code 429")
42 || message.contains("RESOURCE_EXHAUSTED")
43 || message.contains("rate limit")
44 {
45 (ErrorCategory::RateLimited, "model.gemini.rate_limited", Some(429u16))
46 } else if message.contains("code 503") || message.contains("UNAVAILABLE") {
47 (ErrorCategory::Unavailable, "model.gemini.unavailable", Some(503))
48 } else if message.contains("code 529") || message.contains("OVERLOADED") {
49 (ErrorCategory::Unavailable, "model.gemini.overloaded", Some(529))
50 } else if message.contains("code 408")
51 || message.contains("DEADLINE_EXCEEDED")
52 || message.contains("TIMEOUT")
53 {
54 (ErrorCategory::Timeout, "model.gemini.timeout", Some(408))
55 } else if message.contains("code 401") || message.contains("Invalid API key") {
56 (ErrorCategory::Unauthorized, "model.gemini.unauthorized", Some(401))
57 } else if message.contains("code 400") {
58 (ErrorCategory::InvalidInput, "model.gemini.bad_request", Some(400))
59 } else if message.contains("code 404") {
60 (ErrorCategory::NotFound, "model.gemini.not_found", Some(404))
61 } else if message.contains("invalid generation config") {
62 (ErrorCategory::InvalidInput, "model.gemini.invalid_config", None)
63 } else {
64 (ErrorCategory::Internal, "model.gemini.internal", None)
65 };
66
67 let mut err = adk_core::AdkError::new(ErrorComponent::Model, category, code, message)
68 .with_provider("gemini");
69 if let Some(sc) = status_code {
70 err = err.with_upstream_status(sc);
71 }
72 err
73}
74
75impl GeminiModel {
76 fn gemini_part_thought_signature(value: &serde_json::Value) -> Option<String> {
77 value.get("thoughtSignature").and_then(serde_json::Value::as_str).map(str::to_string)
78 }
79
80 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
81 let model_name = model.into();
82 let client = Gemini::with_model(api_key.into(), model_name.clone())
83 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
84
85 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
86 }
87
88 #[cfg(feature = "gemini-vertex")]
92 pub fn new_google_cloud(
93 api_key: impl Into<String>,
94 project_id: impl AsRef<str>,
95 location: impl AsRef<str>,
96 model: impl Into<String>,
97 ) -> Result<Self> {
98 let model_name = model.into();
99 let client = Gemini::with_google_cloud_model(
100 api_key.into(),
101 project_id,
102 location,
103 model_name.clone(),
104 )
105 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
106
107 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
108 }
109
110 #[cfg(feature = "gemini-vertex")]
114 pub fn new_google_cloud_service_account(
115 service_account_json: &str,
116 project_id: impl AsRef<str>,
117 location: impl AsRef<str>,
118 model: impl Into<String>,
119 ) -> Result<Self> {
120 let model_name = model.into();
121 let client = Gemini::with_google_cloud_service_account_json(
122 service_account_json,
123 project_id.as_ref(),
124 location.as_ref(),
125 model_name.clone(),
126 )
127 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
128
129 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
130 }
131
132 #[cfg(feature = "gemini-vertex")]
136 pub fn new_google_cloud_adc(
137 project_id: impl AsRef<str>,
138 location: impl AsRef<str>,
139 model: impl Into<String>,
140 ) -> Result<Self> {
141 let model_name = model.into();
142 let client = Gemini::with_google_cloud_adc_model(
143 project_id.as_ref(),
144 location.as_ref(),
145 model_name.clone(),
146 )
147 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
148
149 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
150 }
151
152 #[cfg(feature = "gemini-vertex")]
156 pub fn new_google_cloud_wif(
157 wif_json: &str,
158 project_id: impl AsRef<str>,
159 location: impl AsRef<str>,
160 model: impl Into<String>,
161 ) -> Result<Self> {
162 let model_name = model.into();
163 let client = Gemini::with_google_cloud_wif_json(
164 wif_json,
165 project_id.as_ref(),
166 location.as_ref(),
167 model_name.clone(),
168 )
169 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
170
171 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
172 }
173
    /// Builder-style setter: replaces the retry policy and returns the updated model.
    #[must_use]
    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }
179
    /// Replaces the retry policy in place.
    pub fn set_retry_config(&mut self, retry_config: RetryConfig) {
        self.retry_config = retry_config;
    }
183
    /// Returns the retry policy currently applied to `generate_content` calls.
    pub fn retry_config(&self) -> &RetryConfig {
        &self.retry_config
    }
187
    /// Builder-style setter: stores `thinking_config` and returns the updated model.
    ///
    /// The stored value is copied into the generation config of every request this model
    /// builds afterwards.
    #[must_use]
    pub fn with_thinking_config(mut self, thinking_config: adk_gemini::ThinkingConfig) -> Self {
        self.thinking_config = Some(thinking_config);
        self
    }
216
    /// Stores `thinking_config` in place, replacing any previously set value.
    pub fn set_thinking_config(&mut self, thinking_config: adk_gemini::ThinkingConfig) {
        self.thinking_config = Some(thinking_config);
    }
221
    /// Returns the thinking configuration, or `None` when none has been set.
    pub fn thinking_config(&self) -> Option<&adk_gemini::ThinkingConfig> {
        self.thinking_config.as_ref()
    }
226
227 fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
228 let mut converted_parts: Vec<Part> = Vec::new();
229
230 if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
232 for p in parts {
233 match p {
234 adk_gemini::Part::Text { text, thought, thought_signature } => {
235 if thought == &Some(true) {
236 converted_parts.push(Part::Thinking {
237 thinking: text.clone(),
238 signature: thought_signature.clone(),
239 });
240 } else {
241 converted_parts.push(Part::Text { text: text.clone() });
242 }
243 }
244 adk_gemini::Part::InlineData { inline_data } => {
245 let decoded =
246 BASE64_STANDARD.decode(&inline_data.data).map_err(|error| {
247 adk_core::AdkError::model(format!(
248 "failed to decode inline data from gemini response: {error}"
249 ))
250 })?;
251 converted_parts.push(Part::InlineData {
252 mime_type: inline_data.mime_type.clone(),
253 data: decoded,
254 });
255 }
256 adk_gemini::Part::FunctionCall { function_call, thought_signature } => {
257 converted_parts.push(Part::FunctionCall {
258 name: function_call.name.clone(),
259 args: function_call.args.clone(),
260 id: function_call.id.clone(),
261 thought_signature: thought_signature.clone(),
262 });
263 }
264 adk_gemini::Part::FunctionResponse { function_response, .. } => {
265 converted_parts.push(Part::FunctionResponse {
266 function_response: adk_core::FunctionResponseData::new(
267 function_response.name.clone(),
268 function_response
269 .response
270 .clone()
271 .unwrap_or(serde_json::Value::Null),
272 ),
273 id: None,
274 });
275 }
276 adk_gemini::Part::ToolCall { .. } | adk_gemini::Part::ExecutableCode { .. } => {
277 if let Ok(value) = serde_json::to_value(p) {
278 converted_parts.push(Part::ServerToolCall { server_tool_call: value });
279 }
280 }
281 adk_gemini::Part::ToolResponse { .. }
282 | adk_gemini::Part::CodeExecutionResult { .. } => {
283 let value = serde_json::to_value(p).unwrap_or(serde_json::Value::Null);
284 converted_parts
285 .push(Part::ServerToolResponse { server_tool_response: value });
286 }
287 adk_gemini::Part::FileData { file_data } => {
288 converted_parts.push(Part::FileData {
289 mime_type: file_data.mime_type.clone(),
290 file_uri: file_data.file_uri.clone(),
291 });
292 }
293 }
294 }
295 }
296
297 if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
299 {
300 if let Some(queries) = &grounding.web_search_queries {
301 if !queries.is_empty() {
302 let search_info = format!("\n\nš **Searched:** {}", queries.join(", "));
303 converted_parts.push(Part::Text { text: search_info });
304 }
305 }
306 if let Some(chunks) = &grounding.grounding_chunks {
307 let sources: Vec<String> = chunks
308 .iter()
309 .filter_map(|c| {
310 c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
311 (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
312 (Some(title), None) => Some(title.clone()),
313 (None, Some(uri)) => Some(uri.to_string()),
314 (None, None) => None,
315 })
316 })
317 .collect();
318 if !sources.is_empty() {
319 let sources_info = format!("\nš **Sources:** {}", sources.join(" | "));
320 converted_parts.push(Part::Text { text: sources_info });
321 }
322 }
323 }
324
325 let content = if converted_parts.is_empty() {
326 None
327 } else {
328 Some(Content { role: "model".to_string(), parts: converted_parts })
329 };
330
331 let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
332 prompt_token_count: u.prompt_token_count.unwrap_or(0),
333 candidates_token_count: u.candidates_token_count.unwrap_or(0),
334 total_token_count: u.total_token_count.unwrap_or(0),
335 thinking_token_count: u.thoughts_token_count,
336 cache_read_input_token_count: u.cached_content_token_count,
337 ..Default::default()
338 });
339
340 let finish_reason =
341 resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
342 adk_gemini::FinishReason::Stop => FinishReason::Stop,
343 adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
344 adk_gemini::FinishReason::Safety => FinishReason::Safety,
345 adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
346 _ => FinishReason::Other,
347 });
348
349 let citation_metadata =
350 resp.candidates.first().and_then(|c| c.citation_metadata.as_ref()).map(|meta| {
351 CitationMetadata {
352 citation_sources: meta
353 .citation_sources
354 .iter()
355 .map(|source| CitationSource {
356 uri: source.uri.clone(),
357 title: source.title.clone(),
358 start_index: source.start_index,
359 end_index: source.end_index,
360 license: source.license.clone(),
361 publication_date: source.publication_date.map(|d| d.to_string()),
362 })
363 .collect(),
364 }
365 });
366
367 let provider_metadata = resp
370 .candidates
371 .first()
372 .and_then(|c| c.grounding_metadata.as_ref())
373 .and_then(|g| serde_json::to_value(g).ok());
374
375 Ok(LlmResponse {
376 content,
377 usage_metadata,
378 finish_reason,
379 citation_metadata,
380 partial: false,
381 turn_complete: true,
382 interrupted: false,
383 error_code: None,
384 error_message: None,
385 provider_metadata,
386 })
387 }
388
389 fn gemini_function_response_payload(response: serde_json::Value) -> serde_json::Value {
390 match response {
391 serde_json::Value::Object(_) => response,
393 other => serde_json::json!({ "result": other }),
394 }
395 }
396
397 fn merge_object_value(
398 target: &mut serde_json::Map<String, serde_json::Value>,
399 value: serde_json::Value,
400 ) {
401 if let serde_json::Value::Object(object) = value {
402 for (key, value) in object {
403 target.insert(key, value);
404 }
405 }
406 }
407
408 fn build_gemini_tools(
409 tools: &std::collections::HashMap<String, serde_json::Value>,
410 ) -> Result<(Vec<adk_gemini::Tool>, adk_gemini::ToolConfig)> {
411 let mut gemini_tools = Vec::new();
412 let mut function_declarations = Vec::new();
413 let mut has_provider_native_tools = false;
414 let mut tool_config_json = serde_json::Map::new();
415
416 for (name, tool_decl) in tools {
417 if let Some(provider_tool) = tool_decl.get("x-adk-gemini-tool") {
418 let tool = serde_json::from_value::<adk_gemini::Tool>(provider_tool.clone())
419 .map_err(|error| {
420 adk_core::AdkError::model(format!(
421 "failed to deserialize Gemini native tool '{name}': {error}"
422 ))
423 })?;
424 has_provider_native_tools = true;
425 gemini_tools.push(tool);
426 } else if let Ok(func_decl) =
427 serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
428 {
429 function_declarations.push(func_decl);
430 } else {
431 return Err(adk_core::AdkError::model(format!(
432 "failed to deserialize Gemini tool '{name}' as a function declaration"
433 )));
434 }
435
436 if let Some(tool_config) = tool_decl.get("x-adk-gemini-tool-config") {
437 Self::merge_object_value(&mut tool_config_json, tool_config.clone());
438 }
439 }
440
441 let has_function_declarations = !function_declarations.is_empty();
442 if has_function_declarations {
443 gemini_tools.push(adk_gemini::Tool::with_functions(function_declarations));
444 }
445
446 if has_provider_native_tools {
447 tool_config_json.insert(
448 "includeServerSideToolInvocations".to_string(),
449 serde_json::Value::Bool(true),
450 );
451 }
452
453 let tool_config = if tool_config_json.is_empty() {
454 adk_gemini::ToolConfig::default()
455 } else {
456 serde_json::from_value::<adk_gemini::ToolConfig>(serde_json::Value::Object(
457 tool_config_json,
458 ))
459 .map_err(|error| {
460 adk_core::AdkError::model(format!(
461 "failed to deserialize Gemini tool configuration: {error}"
462 ))
463 })?
464 };
465
466 Ok((gemini_tools, tool_config))
467 }
468
469 fn stream_chunks_from_response(
470 mut response: LlmResponse,
471 saw_partial_chunk: bool,
472 ) -> (Vec<LlmResponse>, bool) {
473 let is_final = response.finish_reason.is_some();
474
475 if !is_final {
476 response.partial = true;
477 response.turn_complete = false;
478 return (vec![response], true);
479 }
480
481 response.partial = false;
482 response.turn_complete = true;
483
484 if saw_partial_chunk {
485 return (vec![response], true);
486 }
487
488 let synthetic_partial = LlmResponse {
489 content: None,
490 usage_metadata: None,
491 finish_reason: None,
492 citation_metadata: None,
493 partial: true,
494 turn_complete: false,
495 interrupted: false,
496 error_code: None,
497 error_message: None,
498 provider_metadata: None,
499 };
500
501 (vec![synthetic_partial, response], true)
502 }
503
    /// Builds a Gemini request from `req`, executes it once, and adapts the result into an
    /// [`LlmResponseStream`].
    ///
    /// Conversation contents are translated by role (`"user"`, `"model"`, `"function"`; any
    /// other role is skipped), then the generation config, cached-content handle and tools are
    /// attached. With `stream == true` the response chunks are converted and forwarded as they
    /// arrive; otherwise a single response is fetched and wrapped in a one-item stream.
    ///
    /// # Errors
    ///
    /// Returns an error when the tool declarations cannot be converted, when the request
    /// fails (mapped through `gemini_error_to_adk`), or — in the non-streaming path — when the
    /// response cannot be converted. In the streaming path, per-chunk transport and conversion
    /// errors are yielded as `Err` items of the stream instead.
    async fn generate_content_internal(
        &self,
        req: LlmRequest,
        stream: bool,
    ) -> Result<LlmResponseStream> {
        let mut builder = self.client.generate_content();

        // Collect the thought signature of every model-issued function call, keyed by function
        // name, so it can be re-attached to the matching function response below. When the
        // same function name occurs more than once, the last signature seen wins.
        let mut fn_call_signatures: std::collections::HashMap<String, String> =
            std::collections::HashMap::new();
        for content in &req.contents {
            if content.role == "model" {
                for part in &content.parts {
                    if let Part::FunctionCall { name, thought_signature: Some(sig), .. } = part {
                        fn_call_signatures.insert(name.clone(), sig.clone());
                    }
                }
            }
        }

        for content in &req.contents {
            match content.role.as_str() {
                "user" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            Part::InlineData { data, mime_type } => {
                                let encoded = attachment::encode_base64(data);
                                gemini_parts.push(adk_gemini::Part::InlineData {
                                    inline_data: adk_gemini::Blob {
                                        mime_type: mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            // File references in user turns are sent as a text part produced by
                            // `attachment::file_attachment_to_text`, not as a native file part.
                            Part::FileData { mime_type, file_uri } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: attachment::file_attachment_to_text(mime_type, file_uri),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            // Every other part kind is dropped from user turns.
                            _ => {}
                        }
                    }
                    // A turn that produced no parts is not sent at all.
                    if !gemini_parts.is_empty() {
                        let user_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: user_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                "model" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            // The signature is carried on the part; the inner call's own
                            // `thought_signature` field is left empty.
                            Part::FunctionCall { name, args, thought_signature, id } => {
                                gemini_parts.push(adk_gemini::Part::FunctionCall {
                                    function_call: adk_gemini::FunctionCall {
                                        name: name.clone(),
                                        args: args.clone(),
                                        id: id.clone(),
                                        thought_signature: None,
                                    },
                                    thought_signature: thought_signature.clone(),
                                });
                            }
                            Part::ServerToolCall { server_tool_call } => {
                                // First try to read the stored JSON back as a native part and
                                // replay it unchanged when it is a tool call / executable code.
                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
                                    server_tool_call.clone(),
                                ) {
                                    match native_part {
                                        adk_gemini::Part::ToolCall { .. }
                                        | adk_gemini::Part::ExecutableCode { .. } => {
                                            gemini_parts.push(native_part);
                                            continue;
                                        }
                                        _ => {}
                                    }
                                }

                                // Fallback: wrap the raw JSON in a tool-call part, taking the
                                // signature from its `thoughtSignature` field when present.
                                gemini_parts.push(adk_gemini::Part::ToolCall {
                                    tool_call: server_tool_call.clone(),
                                    thought_signature: Self::gemini_part_thought_signature(
                                        server_tool_call,
                                    ),
                                });
                            }
                            Part::ServerToolResponse { server_tool_response } => {
                                // Same strategy as for server tool calls: replay a native
                                // tool response / code execution result when the JSON parses.
                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
                                    server_tool_response.clone(),
                                ) {
                                    match native_part {
                                        adk_gemini::Part::ToolResponse { .. }
                                        | adk_gemini::Part::CodeExecutionResult { .. } => {
                                            gemini_parts.push(native_part);
                                            continue;
                                        }
                                        _ => {}
                                    }
                                }

                                // Fallback: wrap the raw JSON in a tool-response part.
                                gemini_parts.push(adk_gemini::Part::ToolResponse {
                                    tool_response: server_tool_response.clone(),
                                    thought_signature: Self::gemini_part_thought_signature(
                                        server_tool_response,
                                    ),
                                });
                            }
                            // Every other part kind is dropped from model turns.
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let model_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::Model),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: model_content,
                            role: adk_gemini::Role::Model,
                        });
                    }
                }
                "function" => {
                    let mut gemini_parts = Vec::new();
                    // Only function-response parts are forwarded from function turns.
                    for part in &content.parts {
                        if let Part::FunctionResponse { function_response, .. } = part {
                            // Re-attach the signature recorded for the call with the same name.
                            let sig = fn_call_signatures.get(&function_response.name).cloned();

                            // Attachments of the function result: inline blobs are base64
                            // encoded, file references are passed through.
                            let mut fr_parts = Vec::new();
                            for inline in &function_response.inline_data {
                                let encoded = attachment::encode_base64(&inline.data);
                                fr_parts.push(adk_gemini::FunctionResponsePart::InlineData {
                                    inline_data: adk_gemini::Blob {
                                        mime_type: inline.mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            for file in &function_response.file_data {
                                fr_parts.push(adk_gemini::FunctionResponsePart::FileData {
                                    file_data: adk_gemini::FileDataRef {
                                        mime_type: file.mime_type.clone(),
                                        file_uri: file.file_uri.clone(),
                                    },
                                });
                            }

                            // Non-object results are wrapped as `{ "result": ... }` first.
                            let mut gemini_fr = adk_gemini::tools::FunctionResponse::new(
                                &function_response.name,
                                Self::gemini_function_response_payload(
                                    function_response.response.clone(),
                                ),
                            );
                            gemini_fr.parts = fr_parts;

                            gemini_parts.push(adk_gemini::Part::FunctionResponse {
                                function_response: gemini_fr,
                                thought_signature: sig,
                            });
                        }
                    }
                    // Function results are sent to Gemini under the user role.
                    if !gemini_parts.is_empty() {
                        let fn_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: fn_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                // Contents with any other role are not sent.
                _ => {}
            }
        }

        if let Some(config) = req.config {
            // A response schema switches the response MIME type to JSON.
            let has_schema = config.response_schema.is_some();
            let gen_config = adk_gemini::GenerationConfig {
                temperature: config.temperature,
                top_p: config.top_p,
                top_k: config.top_k,
                max_output_tokens: config.max_output_tokens,
                response_schema: config.response_schema,
                response_mime_type: if has_schema {
                    Some("application/json".to_string())
                } else {
                    None
                },
                thinking_config: self.thinking_config.clone(),
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);

            if let Some(ref name) = config.cached_content {
                // Attach the cached content identified by `name` to the request.
                let handle = self.client.get_cached_content(name);
                builder = builder.with_cached_content(&handle);
            }
        } else if self.thinking_config.is_some() {
            // No request-level config: still forward the model-level thinking config.
            let gen_config = adk_gemini::GenerationConfig {
                thinking_config: self.thinking_config.clone(),
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);
        }

        if !req.tools.is_empty() {
            let (gemini_tools, tool_config) = Self::build_gemini_tools(&req.tools)?;
            for tool in gemini_tools {
                builder = builder.with_tool(tool);
            }
            // Only send a tool config when it differs from the default.
            if tool_config != adk_gemini::ToolConfig::default() {
                builder = builder.with_tool_config(tool_config);
            }
        }

        if stream {
            adk_telemetry::debug!("Executing streaming request");
            let response_stream = builder.execute_stream().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                gemini_error_to_adk(&e)
            })?;

            let mapped_stream = async_stream::stream! {
                let mut stream = response_stream;
                // Tracks whether any chunk was emitted yet, so a lone final chunk can be
                // preceded by a synthetic partial one.
                let mut saw_partial_chunk = false;
                while let Some(result) = stream.try_next().await.transpose() {
                    match result {
                        Ok(resp) => {
                            match Self::convert_response(&resp) {
                                Ok(llm_resp) => {
                                    let (chunks, next_saw_partial) =
                                        Self::stream_chunks_from_response(llm_resp, saw_partial_chunk);
                                    saw_partial_chunk = next_saw_partial;
                                    for chunk in chunks {
                                        yield Ok(chunk);
                                    }
                                }
                                Err(e) => {
                                    adk_telemetry::error!(error = %e, "Failed to convert response");
                                    yield Err(e);
                                }
                            }
                        }
                        // Errors are forwarded as items; the loop keeps polling afterwards.
                        Err(e) => {
                            adk_telemetry::error!(error = %e, "Stream error");
                            yield Err(gemini_error_to_adk(&e));
                        }
                    }
                }
            };

            Ok(Box::pin(mapped_stream))
        } else {
            adk_telemetry::debug!("Executing blocking request");
            let response = builder.execute().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                gemini_error_to_adk(&e)
            })?;

            let llm_response = Self::convert_response(&response)?;

            // Wrap the single response so both modes share one return type.
            let stream = async_stream::stream! {
                yield Ok(llm_response);
            };

            Ok(Box::pin(stream))
        }
    }
820
821 pub async fn create_cached_content(
826 &self,
827 system_instruction: &str,
828 tools: &std::collections::HashMap<String, serde_json::Value>,
829 ttl_seconds: u32,
830 ) -> Result<String> {
831 let mut cache_builder = self
832 .client
833 .create_cache()
834 .with_system_instruction(system_instruction)
835 .with_ttl(std::time::Duration::from_secs(u64::from(ttl_seconds)));
836
837 let (gemini_tools, tool_config) = Self::build_gemini_tools(tools)?;
838 if !gemini_tools.is_empty() {
839 cache_builder = cache_builder.with_tools(gemini_tools);
840 }
841 if tool_config != adk_gemini::ToolConfig::default() {
842 cache_builder = cache_builder.with_tool_config(tool_config);
843 }
844
845 let handle = cache_builder
846 .execute()
847 .await
848 .map_err(|e| adk_core::AdkError::model(format!("cache creation failed: {e}")))?;
849
850 Ok(handle.name().to_string())
851 }
852
853 pub async fn delete_cached_content(&self, name: &str) -> Result<()> {
855 let handle = self.client.get_cached_content(name);
856 handle
857 .delete()
858 .await
859 .map_err(|(_, e)| adk_core::AdkError::model(format!("cache deletion failed: {e}")))?;
860 Ok(())
861 }
862}
863
864#[async_trait]
865impl Llm for GeminiModel {
866 fn name(&self) -> &str {
867 &self.model_name
868 }
869
870 #[adk_telemetry::instrument(
871 name = "call_llm",
872 skip(self, req),
873 fields(
874 model.name = %self.model_name,
875 stream = %stream,
876 request.contents_count = %req.contents.len(),
877 request.tools_count = %req.tools.len()
878 )
879 )]
880 async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
881 adk_telemetry::info!("Generating content");
882 let usage_span = adk_telemetry::llm_generate_span("gemini", &self.model_name, stream);
883 let result = execute_with_retry(&self.retry_config, is_retryable_model_error, || {
886 self.generate_content_internal(req.clone(), stream)
887 })
888 .await?;
889 Ok(crate::usage_tracking::with_usage_tracking(result, usage_span))
890 }
891}
892
893#[cfg(test)]
894mod native_tool_tests {
895 use super::*;
896
897 #[test]
898 fn test_build_gemini_tools_supports_native_tool_metadata() {
899 let mut tools = std::collections::HashMap::new();
900 tools.insert(
901 "google_search".to_string(),
902 serde_json::json!({
903 "x-adk-gemini-tool": {
904 "google_search": {}
905 }
906 }),
907 );
908 tools.insert(
909 "lookup_weather".to_string(),
910 serde_json::json!({
911 "name": "lookup_weather",
912 "description": "lookup weather",
913 "parameters": {
914 "type": "object",
915 "properties": {
916 "city": { "type": "string" }
917 }
918 }
919 }),
920 );
921
922 let (gemini_tools, tool_config) =
923 GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");
924
925 assert_eq!(gemini_tools.len(), 2);
926 assert_eq!(tool_config.include_server_side_tool_invocations, Some(true));
927 }
928
929 #[test]
930 fn test_build_gemini_tools_sets_flag_for_builtin_only() {
931 let mut tools = std::collections::HashMap::new();
932 tools.insert(
933 "google_search".to_string(),
934 serde_json::json!({
935 "x-adk-gemini-tool": {
936 "google_search": {}
937 }
938 }),
939 );
940
941 let (_gemini_tools, tool_config) =
942 GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");
943
944 assert_eq!(
945 tool_config.include_server_side_tool_invocations,
946 Some(true),
947 "includeServerSideToolInvocations should be set even with only built-in tools"
948 );
949 }
950
951 #[test]
952 fn test_build_gemini_tools_no_flag_for_function_only() {
953 let mut tools = std::collections::HashMap::new();
954 tools.insert(
955 "lookup_weather".to_string(),
956 serde_json::json!({
957 "name": "lookup_weather",
958 "description": "lookup weather",
959 "parameters": {
960 "type": "object",
961 "properties": {
962 "city": { "type": "string" }
963 }
964 }
965 }),
966 );
967
968 let (_gemini_tools, tool_config) =
969 GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");
970
971 assert_eq!(
972 tool_config.include_server_side_tool_invocations, None,
973 "includeServerSideToolInvocations should NOT be set for function-only tools"
974 );
975 }
976
977 #[test]
978 fn test_build_gemini_tools_merges_native_tool_config() {
979 let mut tools = std::collections::HashMap::new();
980 tools.insert(
981 "google_maps".to_string(),
982 serde_json::json!({
983 "x-adk-gemini-tool": {
984 "google_maps": {
985 "enable_widget": true
986 }
987 },
988 "x-adk-gemini-tool-config": {
989 "retrievalConfig": {
990 "latLng": {
991 "latitude": 1.23,
992 "longitude": 4.56
993 }
994 }
995 }
996 }),
997 );
998
999 let (_gemini_tools, tool_config) =
1000 GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");
1001
1002 assert_eq!(
1003 tool_config.retrieval_config,
1004 Some(serde_json::json!({
1005 "latLng": {
1006 "latitude": 1.23,
1007 "longitude": 4.56
1008 }
1009 }))
1010 );
1011 }
1012}
1013
/// [`CacheCapable`] implementation that forwards to the inherent cached-content methods.
#[async_trait]
impl CacheCapable for GeminiModel {
    /// Creates a cache entry; delegates to [`GeminiModel::create_cached_content`] and returns
    /// the name it produces.
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &std::collections::HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String> {
        self.create_cached_content(system_instruction, tools, ttl_seconds).await
    }

    /// Deletes the cache entry `name`; delegates to [`GeminiModel::delete_cached_content`].
    async fn delete_cache(&self, name: &str) -> Result<()> {
        self.delete_cached_content(name).await
    }
}
1029
1030#[cfg(test)]
1031mod tests {
1032 use super::*;
1033 use adk_core::AdkError;
1034 use std::{
1035 sync::{
1036 Arc,
1037 atomic::{AtomicU32, Ordering},
1038 },
1039 time::Duration,
1040 };
1041
1042 #[test]
1043 fn constructor_is_backward_compatible_and_sync() {
1044 fn accepts_sync_constructor<F>(_f: F)
1045 where
1046 F: Fn(&str, &str) -> Result<GeminiModel>,
1047 {
1048 }
1049
1050 accepts_sync_constructor(|api_key, model| GeminiModel::new(api_key, model));
1051 }
1052
1053 #[test]
1054 fn stream_chunks_from_response_injects_partial_before_lone_final_chunk() {
1055 let response = LlmResponse {
1056 content: Some(Content::new("model").with_text("hello")),
1057 usage_metadata: None,
1058 finish_reason: Some(FinishReason::Stop),
1059 citation_metadata: None,
1060 partial: false,
1061 turn_complete: true,
1062 interrupted: false,
1063 error_code: None,
1064 error_message: None,
1065 provider_metadata: None,
1066 };
1067
1068 let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, false);
1069 assert!(saw_partial);
1070 assert_eq!(chunks.len(), 2);
1071 assert!(chunks[0].partial);
1072 assert!(!chunks[0].turn_complete);
1073 assert!(chunks[0].content.is_none());
1074 assert!(!chunks[1].partial);
1075 assert!(chunks[1].turn_complete);
1076 }
1077
1078 #[test]
1079 fn stream_chunks_from_response_keeps_final_only_when_partial_already_seen() {
1080 let response = LlmResponse {
1081 content: Some(Content::new("model").with_text("done")),
1082 usage_metadata: None,
1083 finish_reason: Some(FinishReason::Stop),
1084 citation_metadata: None,
1085 partial: false,
1086 turn_complete: true,
1087 interrupted: false,
1088 error_code: None,
1089 error_message: None,
1090 provider_metadata: None,
1091 };
1092
1093 let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, true);
1094 assert!(saw_partial);
1095 assert_eq!(chunks.len(), 1);
1096 assert!(!chunks[0].partial);
1097 assert!(chunks[0].turn_complete);
1098 }
1099
1100 #[tokio::test]
1101 async fn execute_with_retry_retries_retryable_errors() {
1102 let retry_config = RetryConfig::default()
1103 .with_max_retries(2)
1104 .with_initial_delay(Duration::from_millis(0))
1105 .with_max_delay(Duration::from_millis(0));
1106 let attempts = Arc::new(AtomicU32::new(0));
1107
1108 let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
1109 let attempts = Arc::clone(&attempts);
1110 async move {
1111 let attempt = attempts.fetch_add(1, Ordering::SeqCst);
1112 if attempt < 2 {
1113 return Err(AdkError::model("code 429 RESOURCE_EXHAUSTED"));
1114 }
1115 Ok("ok")
1116 }
1117 })
1118 .await
1119 .expect("retry should eventually succeed");
1120
1121 assert_eq!(result, "ok");
1122 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1123 }
1124
1125 #[tokio::test]
1126 async fn execute_with_retry_does_not_retry_non_retryable_errors() {
1127 let retry_config = RetryConfig::default()
1128 .with_max_retries(3)
1129 .with_initial_delay(Duration::from_millis(0))
1130 .with_max_delay(Duration::from_millis(0));
1131 let attempts = Arc::new(AtomicU32::new(0));
1132
1133 let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
1134 let attempts = Arc::clone(&attempts);
1135 async move {
1136 attempts.fetch_add(1, Ordering::SeqCst);
1137 Err::<(), _>(AdkError::model("code 400 invalid request"))
1138 }
1139 })
1140 .await
1141 .expect_err("non-retryable error should return immediately");
1142
1143 assert!(error.is_model());
1144 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1145 }
1146
1147 #[tokio::test]
1148 async fn execute_with_retry_respects_disabled_config() {
1149 let retry_config = RetryConfig::disabled().with_max_retries(10);
1150 let attempts = Arc::new(AtomicU32::new(0));
1151
1152 let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
1153 let attempts = Arc::clone(&attempts);
1154 async move {
1155 attempts.fetch_add(1, Ordering::SeqCst);
1156 Err::<(), _>(AdkError::model("code 429 RESOURCE_EXHAUSTED"))
1157 }
1158 })
1159 .await
1160 .expect_err("disabled retries should return first error");
1161
1162 assert!(error.is_model());
1163 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1164 }
1165
/// A candidate carrying one citation source must come out of
/// `GeminiModel::convert_response` with that source's URI and index range intact.
#[test]
fn convert_response_preserves_citation_metadata() {
    // Build the fixture bottom-up: citation source -> candidate -> response.
    let source = adk_gemini::CitationSource {
        uri: Some(String::from("https://example.com")),
        title: Some(String::from("Example")),
        start_index: Some(0),
        end_index: Some(5),
        license: Some(String::from("CC-BY")),
        publication_date: None,
    };
    let text_part = adk_gemini::Part::Text {
        text: String::from("hello world"),
        thought: None,
        thought_signature: None,
    };
    let candidate = adk_gemini::Candidate {
        content: adk_gemini::Content {
            role: Some(adk_gemini::Role::Model),
            parts: Some(vec![text_part]),
        },
        safety_ratings: None,
        citation_metadata: Some(adk_gemini::CitationMetadata { citation_sources: vec![source] }),
        grounding_metadata: None,
        finish_reason: Some(adk_gemini::FinishReason::Stop),
        index: Some(0),
    };
    let response = adk_gemini::GenerationResponse {
        candidates: vec![candidate],
        prompt_feedback: None,
        usage_metadata: None,
        model_version: None,
        response_id: None,
    };

    let converted =
        GeminiModel::convert_response(&response).expect("conversion should succeed");
    let metadata = converted.citation_metadata.expect("citation metadata should be mapped");

    // Check the count first so the index below cannot go out of bounds.
    assert_eq!(metadata.citation_sources.len(), 1);
    let mapped = &metadata.citation_sources[0];
    assert_eq!(mapped.uri.as_deref(), Some("https://example.com"));
    assert_eq!(mapped.start_index, Some(0));
    assert_eq!(mapped.end_index, Some(5));
}
1207
/// A model response mixing a text part with a base64 inline-data part must be
/// converted into content holding the same text and the decoded image bytes.
#[test]
fn convert_response_handles_inline_data_from_model() {
    let image_bytes = vec![0x89, 0x50, 0x4E, 0x47];
    let encoded = crate::attachment::encode_base64(&image_bytes);

    let caption_part = adk_gemini::Part::Text {
        text: String::from("Here is the image"),
        thought: None,
        thought_signature: None,
    };
    let image_part = adk_gemini::Part::InlineData {
        inline_data: adk_gemini::Blob { mime_type: String::from("image/png"), data: encoded },
    };
    let candidate = adk_gemini::Candidate {
        content: adk_gemini::Content {
            role: Some(adk_gemini::Role::Model),
            parts: Some(vec![caption_part, image_part]),
        },
        safety_ratings: None,
        citation_metadata: None,
        grounding_metadata: None,
        finish_reason: Some(adk_gemini::FinishReason::Stop),
        index: Some(0),
    };
    let response = adk_gemini::GenerationResponse {
        candidates: vec![candidate],
        prompt_feedback: None,
        usage_metadata: None,
        model_version: None,
        response_id: None,
    };

    let converted =
        GeminiModel::convert_response(&response).expect("conversion should succeed");
    let content = converted.content.expect("should have content");

    // The text part must survive unchanged.
    let has_caption = content.parts.iter().any(|part| match part {
        Part::Text { text } => text == "Here is the image",
        _ => false,
    });
    assert!(has_caption);

    // The inline part must carry the original (decoded) bytes and MIME type.
    let has_image = content.parts.iter().any(|part| match part {
        Part::InlineData { mime_type, data } => {
            mime_type == "image/png" && data.as_slice() == image_bytes.as_slice()
        }
        _ => false,
    });
    assert!(has_image);
}
1260
/// A JSON object passed to `gemini_function_response_payload` must come back
/// equal to the input.
#[test]
fn gemini_function_response_payload_preserves_objects() {
    let input = serde_json::json!({
        "documents": [
            { "id": "pricing", "score": 0.91 }
        ]
    });
    // Keep a copy for comparison; the function takes its argument by value.
    let expected = input.clone();

    let payload = GeminiModel::gemini_function_response_payload(input);

    assert_eq!(payload, expected);
}
1273
/// A JSON array passed to `gemini_function_response_payload` must come back
/// wrapped in an object under the `"result"` key.
#[test]
fn gemini_function_response_payload_wraps_arrays() {
    let items = serde_json::json!([{ "id": "pricing" }]);
    let wrapped_items = items.clone();
    let expected = serde_json::json!({ "result": wrapped_items });

    let payload = GeminiModel::gemini_function_response_payload(items);

    assert_eq!(payload, expected);
}
1281
1282 fn convert_function_response_to_gemini_fr(
1287 frd: &adk_core::FunctionResponseData,
1288 ) -> adk_gemini::tools::FunctionResponse {
1289 let mut fr_parts = Vec::new();
1290
1291 for inline in &frd.inline_data {
1292 let encoded = crate::attachment::encode_base64(&inline.data);
1293 fr_parts.push(adk_gemini::FunctionResponsePart::InlineData {
1294 inline_data: adk_gemini::Blob {
1295 mime_type: inline.mime_type.clone(),
1296 data: encoded,
1297 },
1298 });
1299 }
1300
1301 for file in &frd.file_data {
1302 fr_parts.push(adk_gemini::FunctionResponsePart::FileData {
1303 file_data: adk_gemini::FileDataRef {
1304 mime_type: file.mime_type.clone(),
1305 file_uri: file.file_uri.clone(),
1306 },
1307 });
1308 }
1309
1310 let mut gemini_fr = adk_gemini::tools::FunctionResponse::new(
1311 &frd.name,
1312 GeminiModel::gemini_function_response_payload(frd.response.clone()),
1313 );
1314 gemini_fr.parts = fr_parts;
1315 gemini_fr
1316 }
1317
/// A function response with only a JSON payload must have no nested parts,
/// and its serialized form must not mention a `"parts"` key at all.
#[test]
fn json_only_function_response_has_no_nested_parts() {
    let payload = serde_json::json!({"ok": true});
    let frd = adk_core::FunctionResponseData::new("tool", payload);

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert!(gemini_fr.parts.is_empty());

    // An empty parts list is expected to be omitted from the JSON entirely.
    let serialized = serde_json::to_string(&gemini_fr).unwrap();
    let mentions_parts = serialized.contains("\"parts\"");
    assert!(!mentions_parts);
}
1327
/// One inline attachment must produce exactly one nested `InlineData` part
/// whose base64 payload decodes back to the original bytes.
#[test]
fn function_response_with_inline_data_has_nested_parts() {
    let png_magic = vec![0x89, 0x50, 0x4E, 0x47];
    let attachment =
        adk_core::InlineDataPart { mime_type: "image/png".to_string(), data: png_magic.clone() };
    let frd = adk_core::FunctionResponseData::with_inline_data(
        "chart",
        serde_json::json!({"status": "ok"}),
        vec![attachment],
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 1);

    let blob = match &gemini_fr.parts[0] {
        adk_gemini::FunctionResponsePart::InlineData { inline_data } => inline_data,
        other => panic!("expected InlineData, got {other:?}"),
    };
    assert_eq!(blob.mime_type, "image/png");
    // Round-trip through base64 to prove the bytes were encoded, not altered.
    let decoded = BASE64_STANDARD.decode(&blob.data).unwrap();
    assert_eq!(decoded, png_magic);
}
1349
/// One file reference must produce exactly one nested `FileData` part that
/// carries the same MIME type and URI.
#[test]
fn function_response_with_file_data_has_nested_parts() {
    let report = adk_core::FileDataPart {
        mime_type: "application/pdf".to_string(),
        file_uri: "gs://bucket/report.pdf".to_string(),
    };
    let frd = adk_core::FunctionResponseData::with_file_data(
        "doc",
        serde_json::json!({"ok": true}),
        vec![report],
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 1);

    let file_ref = match &gemini_fr.parts[0] {
        adk_gemini::FunctionResponsePart::FileData { file_data } => file_data,
        other => panic!("expected FileData, got {other:?}"),
    };
    assert_eq!(file_ref.mime_type, "application/pdf");
    assert_eq!(file_ref.file_uri, "gs://bucket/report.pdf");
}
1370
/// With two inline attachments and one file reference, the nested parts must
/// be emitted as inline, inline, file — in that order.
#[test]
fn function_response_with_both_inline_and_file_data_ordering() {
    use adk_gemini::FunctionResponsePart as Nested;

    let images = vec![
        adk_core::InlineDataPart { mime_type: "image/png".to_string(), data: vec![1, 2] },
        adk_core::InlineDataPart { mime_type: "image/jpeg".to_string(), data: vec![3, 4] },
    ];
    let files = vec![adk_core::FileDataPart {
        mime_type: "application/pdf".to_string(),
        file_uri: "gs://b/f.pdf".to_string(),
    }];
    let frd = adk_core::FunctionResponseData::with_multimodal(
        "multi",
        serde_json::json!({}),
        images,
        files,
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 3);

    // Inline data parts come first, followed by file data parts.
    assert!(matches!(
        gemini_fr.parts.as_slice(),
        [Nested::InlineData { .. }, Nested::InlineData { .. }, Nested::FileData { .. }]
    ));
}
1392}