1use crate::attachment;
2use crate::retry::{RetryConfig, execute_with_retry, is_retryable_model_error};
3use adk_core::{
4 CacheCapable, CitationMetadata, CitationSource, Content, ErrorCategory, ErrorComponent,
5 FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result, SchemaAdapter,
6 SchemaCache, UsageMetadata,
7};
8use adk_gemini::Gemini;
9use adk_gemini::schema_adapter::GeminiSchemaAdapter;
10use async_trait::async_trait;
11use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
12use futures::TryStreamExt;
13
/// Gemini-backed model handle: the API client plus the per-model request settings.
pub struct GeminiModel {
    /// Client used to issue generation and cached-content requests.
    client: Gemini,
    /// Model identifier; this is the value returned by `Llm::name`.
    model_name: String,
    /// Retry policy wrapped around every generation request.
    retry_config: RetryConfig,
    /// Optional thinking configuration copied into each request's generation config.
    thinking_config: Option<adk_gemini::ThinkingConfig>,
}
25
26fn gemini_error_to_adk(e: &adk_gemini::ClientError) -> adk_core::AdkError {
28 fn format_error_chain(e: &dyn std::error::Error) -> String {
29 let mut msg = e.to_string();
30 let mut source = e.source();
31 while let Some(s) = source {
32 msg.push_str(": ");
33 msg.push_str(&s.to_string());
34 source = s.source();
35 }
36 msg
37 }
38
39 let message = format_error_chain(e);
40
41 let (category, code, status_code) = if message.contains("code 429")
44 || message.contains("RESOURCE_EXHAUSTED")
45 || message.contains("rate limit")
46 {
47 (ErrorCategory::RateLimited, "model.gemini.rate_limited", Some(429u16))
48 } else if message.contains("code 503") || message.contains("UNAVAILABLE") {
49 (ErrorCategory::Unavailable, "model.gemini.unavailable", Some(503))
50 } else if message.contains("code 529") || message.contains("OVERLOADED") {
51 (ErrorCategory::Unavailable, "model.gemini.overloaded", Some(529))
52 } else if message.contains("code 408")
53 || message.contains("DEADLINE_EXCEEDED")
54 || message.contains("TIMEOUT")
55 {
56 (ErrorCategory::Timeout, "model.gemini.timeout", Some(408))
57 } else if message.contains("code 401") || message.contains("Invalid API key") {
58 (ErrorCategory::Unauthorized, "model.gemini.unauthorized", Some(401))
59 } else if message.contains("code 400") {
60 (ErrorCategory::InvalidInput, "model.gemini.bad_request", Some(400))
61 } else if message.contains("code 404") {
62 (ErrorCategory::NotFound, "model.gemini.not_found", Some(404))
63 } else if message.contains("invalid generation config") {
64 (ErrorCategory::InvalidInput, "model.gemini.invalid_config", None)
65 } else {
66 (ErrorCategory::Internal, "model.gemini.internal", None)
67 };
68
69 let mut err = adk_core::AdkError::new(ErrorComponent::Model, category, code, message)
70 .with_provider("gemini");
71 if let Some(sc) = status_code {
72 err = err.with_upstream_status(sc);
73 }
74 err
75}
76
77impl GeminiModel {
78 fn gemini_part_thought_signature(value: &serde_json::Value) -> Option<String> {
79 value.get("thoughtSignature").and_then(serde_json::Value::as_str).map(str::to_string)
80 }
81
82 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
83 let model_name = model.into();
84 let client = Gemini::with_model(api_key.into(), model_name.clone())
85 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
86
87 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
88 }
89
/// Creates a model whose client is built with `Gemini::with_google_cloud_model` from an API
/// key, project id and location.
///
/// Compiled only with the `gemini-vertex` feature. The model starts with
/// `RetryConfig::default()` and no thinking configuration.
///
/// # Errors
///
/// Returns `AdkError::model` carrying the client's error text when the client cannot be
/// constructed.
#[cfg(feature = "gemini-vertex")]
pub fn new_google_cloud(
    api_key: impl Into<String>,
    project_id: impl AsRef<str>,
    location: impl AsRef<str>,
    model: impl Into<String>,
) -> Result<Self> {
    let model_name: String = model.into();
    let built = Gemini::with_google_cloud_model(
        api_key.into(),
        project_id,
        location,
        model_name.clone(),
    );
    match built {
        Ok(client) => Ok(Self {
            client,
            model_name,
            retry_config: RetryConfig::default(),
            thinking_config: None,
        }),
        Err(e) => Err(adk_core::AdkError::model(e.to_string())),
    }
}
111
/// Creates a model whose client is built with
/// `Gemini::with_google_cloud_service_account_json` from service-account JSON, a project id
/// and a location.
///
/// Compiled only with the `gemini-vertex` feature. The model starts with
/// `RetryConfig::default()` and no thinking configuration.
///
/// # Errors
///
/// Returns `AdkError::model` carrying the client's error text when the client cannot be
/// constructed.
#[cfg(feature = "gemini-vertex")]
pub fn new_google_cloud_service_account(
    service_account_json: &str,
    project_id: impl AsRef<str>,
    location: impl AsRef<str>,
    model: impl Into<String>,
) -> Result<Self> {
    let model_name: String = model.into();
    let built = Gemini::with_google_cloud_service_account_json(
        service_account_json,
        project_id.as_ref(),
        location.as_ref(),
        model_name.clone(),
    );
    match built {
        Ok(client) => Ok(Self {
            client,
            model_name,
            retry_config: RetryConfig::default(),
            thinking_config: None,
        }),
        Err(e) => Err(adk_core::AdkError::model(e.to_string())),
    }
}
133
/// Creates a model whose client is built with `Gemini::with_google_cloud_adc_model` from a
/// project id and a location (no explicit credential argument is passed).
///
/// Compiled only with the `gemini-vertex` feature. The model starts with
/// `RetryConfig::default()` and no thinking configuration.
///
/// # Errors
///
/// Returns `AdkError::model` carrying the client's error text when the client cannot be
/// constructed.
#[cfg(feature = "gemini-vertex")]
pub fn new_google_cloud_adc(
    project_id: impl AsRef<str>,
    location: impl AsRef<str>,
    model: impl Into<String>,
) -> Result<Self> {
    let model_name: String = model.into();
    let built = Gemini::with_google_cloud_adc_model(
        project_id.as_ref(),
        location.as_ref(),
        model_name.clone(),
    );
    match built {
        Ok(client) => Ok(Self {
            client,
            model_name,
            retry_config: RetryConfig::default(),
            thinking_config: None,
        }),
        Err(e) => Err(adk_core::AdkError::model(e.to_string())),
    }
}
153
/// Creates a model whose client is built with `Gemini::with_google_cloud_wif_json` from WIF
/// JSON, a project id and a location.
///
/// Compiled only with the `gemini-vertex` feature. The model starts with
/// `RetryConfig::default()` and no thinking configuration.
///
/// # Errors
///
/// Returns `AdkError::model` carrying the client's error text when the client cannot be
/// constructed.
#[cfg(feature = "gemini-vertex")]
pub fn new_google_cloud_wif(
    wif_json: &str,
    project_id: impl AsRef<str>,
    location: impl AsRef<str>,
    model: impl Into<String>,
) -> Result<Self> {
    let model_name: String = model.into();
    let built = Gemini::with_google_cloud_wif_json(
        wif_json,
        project_id.as_ref(),
        location.as_ref(),
        model_name.clone(),
    );
    match built {
        Ok(client) => Ok(Self {
            client,
            model_name,
            retry_config: RetryConfig::default(),
            thinking_config: None,
        }),
        Err(e) => Err(adk_core::AdkError::model(e.to_string())),
    }
}
175
176 #[must_use]
177 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
178 self.retry_config = retry_config;
179 self
180 }
181
/// Replaces the retry policy in place.
pub fn set_retry_config(&mut self, retry_config: RetryConfig) {
    self.retry_config = retry_config;
}
185
/// Returns the retry policy currently stored on this model.
pub fn retry_config(&self) -> &RetryConfig {
    &self.retry_config
}
189
190 #[must_use]
214 pub fn with_thinking_config(mut self, thinking_config: adk_gemini::ThinkingConfig) -> Self {
215 self.thinking_config = Some(thinking_config);
216 self
217 }
218
/// Stores `thinking_config` as this model's thinking configuration, replacing any earlier one.
pub fn set_thinking_config(&mut self, thinking_config: adk_gemini::ThinkingConfig) {
    self.thinking_config = Some(thinking_config);
}
223
/// Returns the stored thinking configuration, or `None` when none has been set.
pub fn thinking_config(&self) -> Option<&adk_gemini::ThinkingConfig> {
    self.thinking_config.as_ref()
}
228
/// Converts one Gemini `GenerationResponse` into an `LlmResponse`.
///
/// Only the first candidate is inspected. Its parts are translated one by one, grounding
/// information (search queries and web sources) is appended as extra text parts, and usage,
/// finish reason, citations and raw grounding metadata are copied across.
///
/// The returned response is always marked `partial: false` / `turn_complete: true`; callers
/// that stream adjust those flags afterwards.
///
/// # Errors
///
/// Returns `AdkError::model` when an inline-data part is not valid base64.
fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
    let mut converted_parts: Vec<Part> = Vec::new();

    // Translate every part of the first candidate (if there is one with parts).
    if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
        for p in parts {
            match p {
                adk_gemini::Part::Text { text, thought, thought_signature } => {
                    // Text flagged `thought == Some(true)` becomes a thinking part and keeps
                    // its signature; all other text is plain text.
                    if thought == &Some(true) {
                        converted_parts.push(Part::Thinking {
                            thinking: text.clone(),
                            signature: thought_signature.clone(),
                        });
                    } else {
                        converted_parts.push(Part::Text { text: text.clone() });
                    }
                }
                adk_gemini::Part::InlineData { inline_data } => {
                    // Inline data arrives base64-encoded; a decode failure aborts the whole
                    // conversion.
                    let decoded =
                        BASE64_STANDARD.decode(&inline_data.data).map_err(|error| {
                            adk_core::AdkError::model(format!(
                                "failed to decode inline data from gemini response: {error}"
                            ))
                        })?;
                    converted_parts.push(Part::InlineData {
                        mime_type: inline_data.mime_type.clone(),
                        data: decoded,
                    });
                }
                adk_gemini::Part::FunctionCall { function_call, thought_signature } => {
                    converted_parts.push(Part::FunctionCall {
                        name: function_call.name.clone(),
                        args: function_call.args.clone(),
                        id: function_call.id.clone(),
                        thought_signature: thought_signature.clone(),
                    });
                }
                adk_gemini::Part::FunctionResponse { function_response, .. } => {
                    // A missing response payload is represented as JSON null; no id is
                    // carried over.
                    converted_parts.push(Part::FunctionResponse {
                        function_response: adk_core::FunctionResponseData::new(
                            function_response.name.clone(),
                            function_response
                                .response
                                .clone()
                                .unwrap_or(serde_json::Value::Null),
                        ),
                        id: None,
                    });
                }
                adk_gemini::Part::ToolCall { .. } | adk_gemini::Part::ExecutableCode { .. } => {
                    // Server-side tool calls are kept as their serialized JSON form. A part
                    // that fails to serialize is dropped silently.
                    if let Ok(value) = serde_json::to_value(p) {
                        converted_parts.push(Part::ServerToolCall { server_tool_call: value });
                    }
                }
                adk_gemini::Part::ToolResponse { .. }
                | adk_gemini::Part::CodeExecutionResult { .. } => {
                    // Unlike the call case above, a serialization failure here still emits a
                    // part, with JSON null as its payload.
                    let value = serde_json::to_value(p).unwrap_or(serde_json::Value::Null);
                    converted_parts
                        .push(Part::ServerToolResponse { server_tool_response: value });
                }
                adk_gemini::Part::FileData { file_data } => {
                    converted_parts.push(Part::FileData {
                        mime_type: file_data.mime_type.clone(),
                        file_uri: file_data.file_uri.clone(),
                    });
                }
            }
        }
    }

    // Append grounding information from the first candidate as additional text parts.
    if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
    {
        if let Some(queries) = &grounding.web_search_queries {
            if !queries.is_empty() {
                // NOTE(review): the leading glyph in this literal looks mis-encoded — confirm
                // the intended character against the upstream file.
                let search_info = format!("\n\nš **Searched:** {}", queries.join(", "));
                converted_parts.push(Part::Text { text: search_info });
            }
        }
        if let Some(chunks) = &grounding.grounding_chunks {
            // Render each web chunk as a markdown link when both title and URI exist, fall
            // back to whichever one is present, and skip chunks that have neither.
            let sources: Vec<String> = chunks
                .iter()
                .filter_map(|c| {
                    c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
                        (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
                        (Some(title), None) => Some(title.clone()),
                        (None, Some(uri)) => Some(uri.to_string()),
                        (None, None) => None,
                    })
                })
                .collect();
            if !sources.is_empty() {
                // NOTE(review): same apparently mis-encoded leading glyph as above — confirm.
                let sources_info = format!("\nš **Sources:** {}", sources.join(" | "));
                converted_parts.push(Part::Text { text: sources_info });
            }
        }
    }

    // No parts at all means no content, rather than an empty "model" message.
    let content = if converted_parts.is_empty() {
        None
    } else {
        Some(Content { role: "model".to_string(), parts: converted_parts })
    };

    // Missing token counts default to 0; thinking and cache-read counts stay optional.
    let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
        prompt_token_count: u.prompt_token_count.unwrap_or(0),
        candidates_token_count: u.candidates_token_count.unwrap_or(0),
        total_token_count: u.total_token_count.unwrap_or(0),
        thinking_token_count: u.thoughts_token_count,
        cache_read_input_token_count: u.cached_content_token_count,
        ..Default::default()
    });

    // Any finish reason other than the four mapped below is reported as `Other`.
    let finish_reason =
        resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
            adk_gemini::FinishReason::Stop => FinishReason::Stop,
            adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
            adk_gemini::FinishReason::Safety => FinishReason::Safety,
            adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
            _ => FinishReason::Other,
        });

    let citation_metadata =
        resp.candidates.first().and_then(|c| c.citation_metadata.as_ref()).map(|meta| {
            CitationMetadata {
                citation_sources: meta
                    .citation_sources
                    .iter()
                    .map(|source| CitationSource {
                        uri: source.uri.clone(),
                        title: source.title.clone(),
                        start_index: source.start_index,
                        end_index: source.end_index,
                        license: source.license.clone(),
                        publication_date: source.publication_date.map(|d| d.to_string()),
                    })
                    .collect(),
            }
        });

    // The raw grounding metadata is also exposed as JSON; a serialization failure yields None.
    let provider_metadata = resp
        .candidates
        .first()
        .and_then(|c| c.grounding_metadata.as_ref())
        .and_then(|g| serde_json::to_value(g).ok());

    Ok(LlmResponse {
        content,
        usage_metadata,
        finish_reason,
        citation_metadata,
        partial: false,
        turn_complete: true,
        interrupted: false,
        error_code: None,
        error_message: None,
        provider_metadata,
    })
}
390
391 fn gemini_function_response_payload(response: serde_json::Value) -> serde_json::Value {
392 match response {
393 serde_json::Value::Object(_) => response,
395 other => serde_json::json!({ "result": other }),
396 }
397 }
398
399 fn merge_object_value(
400 target: &mut serde_json::Map<String, serde_json::Value>,
401 value: serde_json::Value,
402 ) {
403 if let serde_json::Value::Object(object) = value {
404 for (key, value) in object {
405 target.insert(key, value);
406 }
407 }
408 }
409
410 fn build_gemini_tools(
411 tools: &std::collections::HashMap<String, serde_json::Value>,
412 adapter: &dyn SchemaAdapter,
413 cache: &SchemaCache,
414 ) -> Result<(Vec<adk_gemini::Tool>, adk_gemini::ToolConfig)> {
415 let mut gemini_tools = Vec::new();
416 let mut function_declarations = Vec::new();
417 let mut has_provider_native_tools = false;
418 let mut tool_config_json = serde_json::Map::new();
419
420 for (name, tool_decl) in tools {
421 if let Some(provider_tool) = tool_decl.get("x-adk-gemini-tool") {
422 let tool = serde_json::from_value::<adk_gemini::Tool>(provider_tool.clone())
423 .map_err(|error| {
424 adk_core::AdkError::model(format!(
425 "failed to deserialize Gemini native tool '{name}': {error}"
426 ))
427 })?;
428 has_provider_native_tools = true;
429 gemini_tools.push(tool);
430 } else {
431 let normalized_name = adapter.normalize_tool_name(name);
433
434 let schema =
437 tool_decl.get("parameters").cloned().unwrap_or_else(|| adapter.empty_schema());
438 let normalized_schema = cache.get_or_normalize(&schema, adapter);
439
440 let description =
442 tool_decl.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string();
443
444 let mut func_decl_json = serde_json::json!({
445 "name": normalized_name.as_ref(),
446 "description": description,
447 "parameters": normalized_schema,
448 });
449
450 if let Some(response) = tool_decl.get("response") {
452 func_decl_json["response"] = response.clone();
453 }
454
455 if let Some(behavior) = tool_decl.get("behavior") {
457 func_decl_json["behavior"] = behavior.clone();
458 }
459
460 let func_decl =
461 serde_json::from_value::<adk_gemini::FunctionDeclaration>(func_decl_json)
462 .map_err(|error| {
463 adk_core::AdkError::model(format!(
464 "failed to build Gemini function declaration for '{name}': {error}"
465 ))
466 })?;
467 function_declarations.push(func_decl);
468 }
469
470 if let Some(tool_config) = tool_decl.get("x-adk-gemini-tool-config") {
471 Self::merge_object_value(&mut tool_config_json, tool_config.clone());
472 }
473 }
474
475 let has_function_declarations = !function_declarations.is_empty();
476 if has_function_declarations {
477 gemini_tools.push(adk_gemini::Tool::with_functions(function_declarations));
478 }
479
480 if has_provider_native_tools {
481 tool_config_json.insert(
482 "includeServerSideToolInvocations".to_string(),
483 serde_json::Value::Bool(true),
484 );
485 }
486
487 let tool_config = if tool_config_json.is_empty() {
488 adk_gemini::ToolConfig::default()
489 } else {
490 serde_json::from_value::<adk_gemini::ToolConfig>(serde_json::Value::Object(
491 tool_config_json,
492 ))
493 .map_err(|error| {
494 adk_core::AdkError::model(format!(
495 "failed to deserialize Gemini tool configuration: {error}"
496 ))
497 })?
498 };
499
500 Ok((gemini_tools, tool_config))
501 }
502
503 fn stream_chunks_from_response(
504 mut response: LlmResponse,
505 saw_partial_chunk: bool,
506 ) -> (Vec<LlmResponse>, bool) {
507 let is_final = response.finish_reason.is_some();
508
509 if !is_final {
510 response.partial = true;
511 response.turn_complete = false;
512 return (vec![response], true);
513 }
514
515 response.partial = false;
516 response.turn_complete = true;
517
518 if saw_partial_chunk {
519 return (vec![response], true);
520 }
521
522 let synthetic_partial = LlmResponse {
523 content: None,
524 usage_metadata: None,
525 finish_reason: None,
526 citation_metadata: None,
527 partial: true,
528 turn_complete: false,
529 interrupted: false,
530 error_code: None,
531 error_message: None,
532 provider_metadata: None,
533 };
534
535 (vec![synthetic_partial, response], true)
536 }
537
/// Builds a Gemini request from `req`, executes it, and adapts the result into an
/// `LlmResponseStream`.
///
/// Steps, in order:
/// 1. Collect thought signatures from function calls in earlier "model" turns, keyed by
///    function name.
/// 2. Convert each content entry into a Gemini message according to its role ("user",
///    "model" or "function"); entries with any other role, or with no convertible parts,
///    are skipped.
/// 3. Apply the generation config (including this model's thinking config) and, when
///    requested, cached content.
/// 4. Attach tools and a non-default tool config.
/// 5. Execute either as a stream or as a single call; the single response is wrapped in a
///    one-item stream.
///
/// # Errors
///
/// Returns an error when tool conversion fails, when the request itself fails (mapped through
/// `gemini_error_to_adk`), or — for the non-streaming path — when the response cannot be
/// converted. In the streaming path, per-chunk failures are yielded as `Err` items instead.
async fn generate_content_internal(
    &self,
    req: LlmRequest,
    stream: bool,
) -> Result<LlmResponseStream> {
    let mut builder = self.client.generate_content();

    // Remember the thought signature of each function call made by the model so it can be
    // echoed on the matching function response. Keyed by function name, so a later call to
    // the same function overwrites an earlier signature.
    let mut fn_call_signatures: std::collections::HashMap<String, String> =
        std::collections::HashMap::new();
    for content in &req.contents {
        if content.role == "model" {
            for part in &content.parts {
                if let Part::FunctionCall { name, thought_signature: Some(sig), .. } = part {
                    fn_call_signatures.insert(name.clone(), sig.clone());
                }
            }
        }
    }

    for content in &req.contents {
        match content.role.as_str() {
            "user" => {
                let mut gemini_parts = Vec::new();
                for part in &content.parts {
                    match part {
                        Part::Text { text } => {
                            gemini_parts.push(adk_gemini::Part::Text {
                                text: text.clone(),
                                thought: None,
                                thought_signature: None,
                            });
                        }
                        Part::Thinking { thinking, signature } => {
                            // Thinking is sent back as text flagged as a thought.
                            gemini_parts.push(adk_gemini::Part::Text {
                                text: thinking.clone(),
                                thought: Some(true),
                                thought_signature: signature.clone(),
                            });
                        }
                        Part::InlineData { data, mime_type } => {
                            let encoded = attachment::encode_base64(data);
                            gemini_parts.push(adk_gemini::Part::InlineData {
                                inline_data: adk_gemini::Blob {
                                    mime_type: mime_type.clone(),
                                    data: encoded,
                                },
                            });
                        }
                        Part::FileData { mime_type, file_uri } => {
                            // File references from the user are sent as a text part produced
                            // by the attachment helper, not as a Gemini file-data part.
                            gemini_parts.push(adk_gemini::Part::Text {
                                text: attachment::file_attachment_to_text(mime_type, file_uri),
                                thought: None,
                                thought_signature: None,
                            });
                        }
                        // Other part kinds are not forwarded in a user turn.
                        _ => {}
                    }
                }
                if !gemini_parts.is_empty() {
                    let user_content = adk_gemini::Content {
                        role: Some(adk_gemini::Role::User),
                        parts: Some(gemini_parts),
                    };
                    builder = builder.with_message(adk_gemini::Message {
                        content: user_content,
                        role: adk_gemini::Role::User,
                    });
                }
            }
            "model" => {
                let mut gemini_parts = Vec::new();
                for part in &content.parts {
                    match part {
                        Part::Text { text } => {
                            gemini_parts.push(adk_gemini::Part::Text {
                                text: text.clone(),
                                thought: None,
                                thought_signature: None,
                            });
                        }
                        Part::Thinking { thinking, signature } => {
                            gemini_parts.push(adk_gemini::Part::Text {
                                text: thinking.clone(),
                                thought: Some(true),
                                thought_signature: signature.clone(),
                            });
                        }
                        Part::FunctionCall { name, args, thought_signature, id } => {
                            // The signature is carried on the part, not on the inner call.
                            gemini_parts.push(adk_gemini::Part::FunctionCall {
                                function_call: adk_gemini::FunctionCall {
                                    name: name.clone(),
                                    args: args.clone(),
                                    id: id.clone(),
                                    thought_signature: None,
                                },
                                thought_signature: thought_signature.clone(),
                            });
                        }
                        Part::ServerToolCall { server_tool_call } => {
                            // Prefer re-hydrating the stored JSON as a native part; only
                            // tool-call / executable-code variants are accepted as-is.
                            if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
                                server_tool_call.clone(),
                            ) {
                                match native_part {
                                    adk_gemini::Part::ToolCall { .. }
                                    | adk_gemini::Part::ExecutableCode { .. } => {
                                        gemini_parts.push(native_part);
                                        continue;
                                    }
                                    _ => {}
                                }
                            }

                            // Fallback: wrap the raw JSON as a tool call, lifting its
                            // "thoughtSignature" entry if it has one.
                            gemini_parts.push(adk_gemini::Part::ToolCall {
                                tool_call: server_tool_call.clone(),
                                thought_signature: Self::gemini_part_thought_signature(
                                    server_tool_call,
                                ),
                            });
                        }
                        Part::ServerToolResponse { server_tool_response } => {
                            // Same strategy as for server tool calls, for the response side.
                            if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
                                server_tool_response.clone(),
                            ) {
                                match native_part {
                                    adk_gemini::Part::ToolResponse { .. }
                                    | adk_gemini::Part::CodeExecutionResult { .. } => {
                                        gemini_parts.push(native_part);
                                        continue;
                                    }
                                    _ => {}
                                }
                            }

                            gemini_parts.push(adk_gemini::Part::ToolResponse {
                                tool_response: server_tool_response.clone(),
                                thought_signature: Self::gemini_part_thought_signature(
                                    server_tool_response,
                                ),
                            });
                        }
                        // Other part kinds are not forwarded in a model turn.
                        _ => {}
                    }
                }
                if !gemini_parts.is_empty() {
                    let model_content = adk_gemini::Content {
                        role: Some(adk_gemini::Role::Model),
                        parts: Some(gemini_parts),
                    };
                    builder = builder.with_message(adk_gemini::Message {
                        content: model_content,
                        role: adk_gemini::Role::Model,
                    });
                }
            }
            "function" => {
                let mut gemini_parts = Vec::new();
                for part in &content.parts {
                    if let Part::FunctionResponse { function_response, .. } = part {
                        // Echo the signature recorded for the call with the same name.
                        let sig = fn_call_signatures.get(&function_response.name).cloned();

                        // Attach any inline or file payloads that accompany the response.
                        let mut fr_parts = Vec::new();
                        for inline in &function_response.inline_data {
                            let encoded = attachment::encode_base64(&inline.data);
                            fr_parts.push(adk_gemini::FunctionResponsePart::InlineData {
                                inline_data: adk_gemini::Blob {
                                    mime_type: inline.mime_type.clone(),
                                    data: encoded,
                                },
                            });
                        }
                        for file in &function_response.file_data {
                            fr_parts.push(adk_gemini::FunctionResponsePart::FileData {
                                file_data: adk_gemini::FileDataRef {
                                    mime_type: file.mime_type.clone(),
                                    file_uri: file.file_uri.clone(),
                                },
                            });
                        }

                        // Non-object results are wrapped as `{ "result": ... }` first.
                        let mut gemini_fr = adk_gemini::tools::FunctionResponse::new(
                            &function_response.name,
                            Self::gemini_function_response_payload(
                                function_response.response.clone(),
                            ),
                        );
                        gemini_fr.parts = fr_parts;

                        gemini_parts.push(adk_gemini::Part::FunctionResponse {
                            function_response: gemini_fr,
                            thought_signature: sig,
                        });
                    }
                }
                if !gemini_parts.is_empty() {
                    // Function results are sent to Gemini under the user role.
                    let fn_content = adk_gemini::Content {
                        role: Some(adk_gemini::Role::User),
                        parts: Some(gemini_parts),
                    };
                    builder = builder.with_message(adk_gemini::Message {
                        content: fn_content,
                        role: adk_gemini::Role::User,
                    });
                }
            }
            // Content with any other role is not forwarded.
            _ => {}
        }
    }

    if let Some(config) = req.config {
        // A response schema implies a JSON response MIME type.
        let has_schema = config.response_schema.is_some();
        let gen_config = adk_gemini::GenerationConfig {
            temperature: config.temperature,
            top_p: config.top_p,
            top_k: config.top_k,
            max_output_tokens: config.max_output_tokens,
            response_schema: config.response_schema,
            response_mime_type: if has_schema {
                Some("application/json".to_string())
            } else {
                None
            },
            thinking_config: self.thinking_config.clone(),
            ..Default::default()
        };
        builder = builder.with_generation_config(gen_config);

        // Cached content is only applied when the request carries a config.
        if let Some(ref name) = config.cached_content {
            let handle = self.client.get_cached_content(name);
            builder = builder.with_cached_content(&handle);
        }
    } else if self.thinking_config.is_some() {
        // No request config: still send a generation config so the thinking settings apply.
        let gen_config = adk_gemini::GenerationConfig {
            thinking_config: self.thinking_config.clone(),
            ..Default::default()
        };
        builder = builder.with_generation_config(gen_config);
    }

    if !req.tools.is_empty() {
        let adapter = self.schema_adapter();
        use std::sync::LazyLock;
        // Schema cache shared by every call to this function (function-local static).
        static SCHEMA_CACHE: LazyLock<SchemaCache> = LazyLock::new(SchemaCache::new);
        let (gemini_tools, tool_config) =
            Self::build_gemini_tools(&req.tools, adapter, &SCHEMA_CACHE)?;
        for tool in gemini_tools {
            builder = builder.with_tool(tool);
        }
        // Only send a tool config when it differs from the default.
        if tool_config != adk_gemini::ToolConfig::default() {
            builder = builder.with_tool_config(tool_config);
        }
    }

    if stream {
        adk_telemetry::debug!("Executing streaming request");
        let response_stream = builder.execute_stream().await.map_err(|e| {
            adk_telemetry::error!(error = %e, "Model request failed");
            gemini_error_to_adk(&e)
        })?;

        let mapped_stream = async_stream::stream! {
            let mut stream = response_stream;
            // Tracks whether any chunk has been emitted yet, so a lone final chunk can be
            // preceded by a synthetic partial one.
            let mut saw_partial_chunk = false;
            while let Some(result) = stream.try_next().await.transpose() {
                match result {
                    Ok(resp) => {
                        match Self::convert_response(&resp) {
                            Ok(llm_resp) => {
                                let (chunks, next_saw_partial) =
                                    Self::stream_chunks_from_response(llm_resp, saw_partial_chunk);
                                saw_partial_chunk = next_saw_partial;
                                for chunk in chunks {
                                    yield Ok(chunk);
                                }
                            }
                            Err(e) => {
                                // Conversion errors are surfaced as stream items; the loop
                                // keeps reading.
                                adk_telemetry::error!(error = %e, "Failed to convert response");
                                yield Err(e);
                            }
                        }
                    }
                    Err(e) => {
                        adk_telemetry::error!(error = %e, "Stream error");
                        yield Err(gemini_error_to_adk(&e));
                    }
                }
            }
        };

        Ok(Box::pin(mapped_stream))
    } else {
        adk_telemetry::debug!("Executing blocking request");
        let response = builder.execute().await.map_err(|e| {
            adk_telemetry::error!(error = %e, "Model request failed");
            gemini_error_to_adk(&e)
        })?;

        let llm_response = Self::convert_response(&response)?;

        // Wrap the single response in a one-item stream so both paths share a return type.
        let stream = async_stream::stream! {
            yield Ok(llm_response);
        };

        Ok(Box::pin(stream))
    }
}
858
859 pub async fn create_cached_content(
864 &self,
865 system_instruction: &str,
866 tools: &std::collections::HashMap<String, serde_json::Value>,
867 ttl_seconds: u32,
868 ) -> Result<String> {
869 let mut cache_builder = self
870 .client
871 .create_cache()
872 .with_system_instruction(system_instruction)
873 .with_ttl(std::time::Duration::from_secs(u64::from(ttl_seconds)));
874
875 let adapter = self.schema_adapter();
876 use std::sync::LazyLock;
877 static SCHEMA_CACHE: LazyLock<SchemaCache> = LazyLock::new(SchemaCache::new);
878 let (gemini_tools, tool_config) = Self::build_gemini_tools(tools, adapter, &SCHEMA_CACHE)?;
879 if !gemini_tools.is_empty() {
880 cache_builder = cache_builder.with_tools(gemini_tools);
881 }
882 if tool_config != adk_gemini::ToolConfig::default() {
883 cache_builder = cache_builder.with_tool_config(tool_config);
884 }
885
886 let handle = cache_builder
887 .execute()
888 .await
889 .map_err(|e| adk_core::AdkError::model(format!("cache creation failed: {e}")))?;
890
891 Ok(handle.name().to_string())
892 }
893
894 pub async fn delete_cached_content(&self, name: &str) -> Result<()> {
896 let handle = self.client.get_cached_content(name);
897 handle
898 .delete()
899 .await
900 .map_err(|(_, e)| adk_core::AdkError::model(format!("cache deletion failed: {e}")))?;
901 Ok(())
902 }
903}
904
/// `Llm` implementation: exposes the model name, the Gemini schema adapter, and content
/// generation wrapped in retry handling and usage tracking.
#[async_trait]
impl Llm for GeminiModel {
    /// Returns the model name this instance was constructed with.
    fn name(&self) -> &str {
        &self.model_name
    }

    /// Returns a process-wide `GeminiSchemaAdapter`, created lazily on first use.
    fn schema_adapter(&self) -> &dyn SchemaAdapter {
        use std::sync::LazyLock;
        static ADAPTER: LazyLock<GeminiSchemaAdapter> = LazyLock::new(GeminiSchemaAdapter::new);
        &*ADAPTER
    }

    /// Generates content for `req`, retrying according to this model's retry config.
    ///
    /// The request is cloned for each invocation of the retry closure. The resulting stream
    /// is wrapped with usage tracking tied to a span created for this call.
    #[adk_telemetry::instrument(
        name = "call_llm",
        skip(self, req),
        fields(
            model.name = %self.model_name,
            stream = %stream,
            request.contents_count = %req.contents.len(),
            request.tools_count = %req.tools.len()
        )
    )]
    async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
        adk_telemetry::info!("Generating content");
        let usage_span = adk_telemetry::llm_generate_span("gemini", &self.model_name, stream);
        let result = execute_with_retry(&self.retry_config, is_retryable_model_error, || {
            self.generate_content_internal(req.clone(), stream)
        })
        .await?;
        Ok(crate::usage_tracking::with_usage_tracking(result, usage_span))
    }
}
939
/// Tests for `GeminiModel::build_gemini_tools`: native tool handling, the server-side
/// invocation flag, and tool-config merging.
#[cfg(test)]
mod native_tool_tests {
    use super::*;

    type ToolDecls = std::collections::HashMap<String, serde_json::Value>;

    fn test_adapter() -> GeminiSchemaAdapter {
        GeminiSchemaAdapter::new()
    }

    fn test_cache() -> SchemaCache {
        SchemaCache::new()
    }

    /// Declaration of the provider-native `google_search` tool.
    fn google_search_decl() -> (String, serde_json::Value) {
        (
            "google_search".to_string(),
            serde_json::json!({
                "x-adk-gemini-tool": {
                    "google_search": {}
                }
            }),
        )
    }

    /// Declaration of a plain function tool with a single string parameter.
    fn lookup_weather_decl() -> (String, serde_json::Value) {
        (
            "lookup_weather".to_string(),
            serde_json::json!({
                "name": "lookup_weather",
                "description": "lookup weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": { "type": "string" }
                    }
                }
            }),
        )
    }

    /// Runs the conversion with a fresh adapter and cache, panicking on failure.
    fn convert(tools: &ToolDecls) -> (Vec<adk_gemini::Tool>, adk_gemini::ToolConfig) {
        let adapter = test_adapter();
        let cache = test_cache();
        GeminiModel::build_gemini_tools(tools, &adapter, &cache)
            .expect("tool conversion should succeed")
    }

    #[test]
    fn test_build_gemini_tools_supports_native_tool_metadata() {
        let tools: ToolDecls = [google_search_decl(), lookup_weather_decl()].into_iter().collect();

        let (gemini_tools, tool_config) = convert(&tools);

        assert_eq!(gemini_tools.len(), 2);
        assert_eq!(tool_config.include_server_side_tool_invocations, Some(true));
    }

    #[test]
    fn test_build_gemini_tools_sets_flag_for_builtin_only() {
        let tools: ToolDecls = [google_search_decl()].into_iter().collect();

        let (_gemini_tools, tool_config) = convert(&tools);

        assert_eq!(
            tool_config.include_server_side_tool_invocations,
            Some(true),
            "includeServerSideToolInvocations should be set even with only built-in tools"
        );
    }

    #[test]
    fn test_build_gemini_tools_no_flag_for_function_only() {
        let tools: ToolDecls = [lookup_weather_decl()].into_iter().collect();

        let (_gemini_tools, tool_config) = convert(&tools);

        assert_eq!(
            tool_config.include_server_side_tool_invocations, None,
            "includeServerSideToolInvocations should NOT be set for function-only tools"
        );
    }

    #[test]
    fn test_build_gemini_tools_merges_native_tool_config() {
        let retrieval_config = serde_json::json!({
            "latLng": {
                "latitude": 1.23,
                "longitude": 4.56
            }
        });

        let mut tools = ToolDecls::new();
        tools.insert(
            "google_maps".to_string(),
            serde_json::json!({
                "x-adk-gemini-tool": {
                    "google_maps": {
                        "enable_widget": true
                    }
                },
                "x-adk-gemini-tool-config": {
                    "retrievalConfig": retrieval_config.clone()
                }
            }),
        );

        let (_gemini_tools, tool_config) = convert(&tools);

        assert_eq!(tool_config.retrieval_config, Some(retrieval_config));
    }
}
1079
/// `CacheCapable` implementation that delegates to the inherent cached-content methods.
#[async_trait]
impl CacheCapable for GeminiModel {
    /// Forwards to `create_cached_content` and returns the created cache's name.
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &std::collections::HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String> {
        self.create_cached_content(system_instruction, tools, ttl_seconds).await
    }

    /// Forwards to `delete_cached_content`.
    async fn delete_cache(&self, name: &str) -> Result<()> {
        self.delete_cached_content(name).await
    }
}
1095
1096#[cfg(test)]
1097mod tests {
1098 use super::*;
1099 use adk_core::AdkError;
1100 use std::{
1101 sync::{
1102 Arc,
1103 atomic::{AtomicU32, Ordering},
1104 },
1105 time::Duration,
1106 };
1107
/// Compile-time check: `GeminiModel::new` must be usable as a synchronous
/// `Fn(&str, &str) -> Result<GeminiModel>`. Nothing is executed at runtime.
#[test]
fn constructor_is_backward_compatible_and_sync() {
    // Accepts any callable with the expected synchronous constructor shape.
    fn accepts_sync_constructor<F>(_f: F)
    where
        F: Fn(&str, &str) -> Result<GeminiModel>,
    {
    }

    accepts_sync_constructor(|api_key, model| GeminiModel::new(api_key, model));
}
1118
/// A final chunk that arrives before any partial chunk must be preceded by an empty synthetic
/// partial chunk.
#[test]
fn stream_chunks_from_response_injects_partial_before_lone_final_chunk() {
    let lone_final = LlmResponse {
        content: Some(Content::new("model").with_text("hello")),
        usage_metadata: None,
        finish_reason: Some(FinishReason::Stop),
        citation_metadata: None,
        partial: false,
        turn_complete: true,
        interrupted: false,
        error_code: None,
        error_message: None,
        provider_metadata: None,
    };

    let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(lone_final, false);

    assert!(saw_partial);
    assert_eq!(chunks.len(), 2);

    let (placeholder, last) = (&chunks[0], &chunks[1]);
    assert!(placeholder.partial);
    assert!(!placeholder.turn_complete);
    assert!(placeholder.content.is_none());
    assert!(!last.partial);
    assert!(last.turn_complete);
}
1143
/// Once a partial chunk has been seen, a final chunk is emitted on its own with no synthetic
/// partial in front of it.
#[test]
fn stream_chunks_from_response_keeps_final_only_when_partial_already_seen() {
    let final_response = LlmResponse {
        content: Some(Content::new("model").with_text("done")),
        usage_metadata: None,
        finish_reason: Some(FinishReason::Stop),
        citation_metadata: None,
        partial: false,
        turn_complete: true,
        interrupted: false,
        error_code: None,
        error_message: None,
        provider_metadata: None,
    };

    let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(final_response, true);

    assert!(saw_partial);
    assert_eq!(chunks.len(), 1);

    let only = &chunks[0];
    assert!(!only.partial);
    assert!(only.turn_complete);
}
1165
1166 #[tokio::test]
1167 async fn execute_with_retry_retries_retryable_errors() {
1168 let retry_config = RetryConfig::default()
1169 .with_max_retries(2)
1170 .with_initial_delay(Duration::from_millis(0))
1171 .with_max_delay(Duration::from_millis(0));
1172 let attempts = Arc::new(AtomicU32::new(0));
1173
1174 let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
1175 let attempts = Arc::clone(&attempts);
1176 async move {
1177 let attempt = attempts.fetch_add(1, Ordering::SeqCst);
1178 if attempt < 2 {
1179 return Err(AdkError::model("code 429 RESOURCE_EXHAUSTED"));
1180 }
1181 Ok("ok")
1182 }
1183 })
1184 .await
1185 .expect("retry should eventually succeed");
1186
1187 assert_eq!(result, "ok");
1188 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1189 }
1190
1191 #[tokio::test]
1192 async fn execute_with_retry_does_not_retry_non_retryable_errors() {
1193 let retry_config = RetryConfig::default()
1194 .with_max_retries(3)
1195 .with_initial_delay(Duration::from_millis(0))
1196 .with_max_delay(Duration::from_millis(0));
1197 let attempts = Arc::new(AtomicU32::new(0));
1198
1199 let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
1200 let attempts = Arc::clone(&attempts);
1201 async move {
1202 attempts.fetch_add(1, Ordering::SeqCst);
1203 Err::<(), _>(AdkError::model("code 400 invalid request"))
1204 }
1205 })
1206 .await
1207 .expect_err("non-retryable error should return immediately");
1208
1209 assert!(error.is_model());
1210 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1211 }
1212
1213 #[tokio::test]
1214 async fn execute_with_retry_respects_disabled_config() {
1215 let retry_config = RetryConfig::disabled().with_max_retries(10);
1216 let attempts = Arc::new(AtomicU32::new(0));
1217
1218 let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
1219 let attempts = Arc::clone(&attempts);
1220 async move {
1221 attempts.fetch_add(1, Ordering::SeqCst);
1222 Err::<(), _>(AdkError::model("code 429 RESOURCE_EXHAUSTED"))
1223 }
1224 })
1225 .await
1226 .expect_err("disabled retries should return first error");
1227
1228 assert!(error.is_model());
1229 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1230 }
1231
/// Citation metadata attached to a Gemini candidate is expected to survive
/// `convert_response`, keeping the source URI and start/end indices.
#[test]
fn convert_response_preserves_citation_metadata() {
    let source = adk_gemini::CitationSource {
        uri: Some(String::from("https://example.com")),
        title: Some(String::from("Example")),
        start_index: Some(0),
        end_index: Some(5),
        license: Some(String::from("CC-BY")),
        publication_date: None,
    };

    let candidate = adk_gemini::Candidate {
        content: adk_gemini::Content {
            role: Some(adk_gemini::Role::Model),
            parts: Some(vec![adk_gemini::Part::Text {
                text: String::from("hello world"),
                thought: None,
                thought_signature: None,
            }]),
        },
        safety_ratings: None,
        citation_metadata: Some(adk_gemini::CitationMetadata { citation_sources: vec![source] }),
        grounding_metadata: None,
        finish_reason: Some(adk_gemini::FinishReason::Stop),
        index: Some(0),
    };

    let response = adk_gemini::GenerationResponse {
        candidates: vec![candidate],
        prompt_feedback: None,
        usage_metadata: None,
        model_version: None,
        response_id: None,
    };

    let converted = GeminiModel::convert_response(&response).expect("conversion should succeed");
    let mapped = converted.citation_metadata.expect("citation metadata should be mapped");
    assert_eq!(mapped.citation_sources.len(), 1);

    let first = &mapped.citation_sources[0];
    assert_eq!(first.uri.as_deref(), Some("https://example.com"));
    assert_eq!(first.start_index, Some(0));
    assert_eq!(first.end_index, Some(5));
}
1273
/// A model response mixing a text part and a base64 inline-data part is expected
/// to convert into content holding the same text and the decoded raw bytes.
#[test]
fn convert_response_handles_inline_data_from_model() {
    let image_bytes = vec![0x89, 0x50, 0x4E, 0x47];
    let encoded = crate::attachment::encode_base64(&image_bytes);

    let text_part = adk_gemini::Part::Text {
        text: String::from("Here is the image"),
        thought: None,
        thought_signature: None,
    };
    let image_part = adk_gemini::Part::InlineData {
        inline_data: adk_gemini::Blob { mime_type: String::from("image/png"), data: encoded },
    };

    let candidate = adk_gemini::Candidate {
        content: adk_gemini::Content {
            role: Some(adk_gemini::Role::Model),
            parts: Some(vec![text_part, image_part]),
        },
        safety_ratings: None,
        citation_metadata: None,
        grounding_metadata: None,
        finish_reason: Some(adk_gemini::FinishReason::Stop),
        index: Some(0),
    };

    let response = adk_gemini::GenerationResponse {
        candidates: vec![candidate],
        prompt_feedback: None,
        usage_metadata: None,
        model_version: None,
        response_id: None,
    };

    let converted = GeminiModel::convert_response(&response).expect("conversion should succeed");
    let content = converted.content.expect("should have content");

    let has_text = content
        .parts
        .iter()
        .any(|part| matches!(part, Part::Text { text } if text == "Here is the image"));
    assert!(has_text);

    // The inline part must carry the original bytes, not the base64 text.
    let has_image = content.parts.iter().any(|part| {
        matches!(
            part,
            Part::InlineData { mime_type, data }
                if mime_type == "image/png" && data.as_slice() == image_bytes.as_slice()
        )
    });
    assert!(has_image);
}
1326
/// A JSON object passed as a function-response payload is expected to come back
/// unchanged.
#[test]
fn gemini_function_response_payload_preserves_objects() {
    let object = serde_json::json!({
        "documents": [
            { "id": "pricing", "score": 0.91 }
        ]
    });
    let expected = object.clone();

    let payload = GeminiModel::gemini_function_response_payload(object);

    assert_eq!(payload, expected);
}
1339
/// A JSON array payload is expected to be wrapped in an object under the
/// `"result"` key.
#[test]
fn gemini_function_response_payload_wraps_arrays() {
    let array = serde_json::json!([{ "id": "pricing" }]);
    let expected = serde_json::json!({ "result": [{ "id": "pricing" }] });

    let payload = GeminiModel::gemini_function_response_payload(array);

    assert_eq!(payload, expected);
}
1347
1348 fn convert_function_response_to_gemini_fr(
1353 frd: &adk_core::FunctionResponseData,
1354 ) -> adk_gemini::tools::FunctionResponse {
1355 let mut fr_parts = Vec::new();
1356
1357 for inline in &frd.inline_data {
1358 let encoded = crate::attachment::encode_base64(&inline.data);
1359 fr_parts.push(adk_gemini::FunctionResponsePart::InlineData {
1360 inline_data: adk_gemini::Blob {
1361 mime_type: inline.mime_type.clone(),
1362 data: encoded,
1363 },
1364 });
1365 }
1366
1367 for file in &frd.file_data {
1368 fr_parts.push(adk_gemini::FunctionResponsePart::FileData {
1369 file_data: adk_gemini::FileDataRef {
1370 mime_type: file.mime_type.clone(),
1371 file_uri: file.file_uri.clone(),
1372 },
1373 });
1374 }
1375
1376 let mut gemini_fr = adk_gemini::tools::FunctionResponse::new(
1377 &frd.name,
1378 GeminiModel::gemini_function_response_payload(frd.response.clone()),
1379 );
1380 gemini_fr.parts = fr_parts;
1381 gemini_fr
1382 }
1383
/// A function response with only a JSON payload is expected to have no nested
/// parts, and its serialized form must not mention a `"parts"` key at all.
#[test]
fn json_only_function_response_has_no_nested_parts() {
    let response_data =
        adk_core::FunctionResponseData::new("tool", serde_json::json!({"ok": true}));
    let converted = convert_function_response_to_gemini_fr(&response_data);
    assert!(converted.parts.is_empty());

    let serialized = serde_json::to_string(&converted).unwrap();
    let mentions_parts = serialized.contains("\"parts\"");
    assert!(!mentions_parts);
}
1393
/// One inline-data attachment is expected to become exactly one nested
/// `InlineData` part whose base64 payload decodes back to the original bytes.
#[test]
fn function_response_with_inline_data_has_nested_parts() {
    let sample_bytes: Vec<u8> = vec![0x89, 0x50, 0x4E, 0x47];
    let attachment = adk_core::InlineDataPart {
        mime_type: String::from("image/png"),
        data: sample_bytes.clone(),
    };
    let frd = adk_core::FunctionResponseData::with_inline_data(
        "chart",
        serde_json::json!({"status": "ok"}),
        vec![attachment],
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 1);

    let inline_data = match &gemini_fr.parts[0] {
        adk_gemini::FunctionResponsePart::InlineData { inline_data } => inline_data,
        other => panic!("expected InlineData, got {other:?}"),
    };
    assert_eq!(inline_data.mime_type, "image/png");

    let decoded = BASE64_STANDARD.decode(&inline_data.data).unwrap();
    assert_eq!(decoded, sample_bytes);
}
1415
/// One file-data attachment is expected to become exactly one nested `FileData`
/// part with the same MIME type and URI.
#[test]
fn function_response_with_file_data_has_nested_parts() {
    let attachment = adk_core::FileDataPart {
        mime_type: String::from("application/pdf"),
        file_uri: String::from("gs://bucket/report.pdf"),
    };
    let frd = adk_core::FunctionResponseData::with_file_data(
        "doc",
        serde_json::json!({"ok": true}),
        vec![attachment],
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 1);

    let file_data = match &gemini_fr.parts[0] {
        adk_gemini::FunctionResponsePart::FileData { file_data } => file_data,
        other => panic!("expected FileData, got {other:?}"),
    };
    assert_eq!(file_data.mime_type, "application/pdf");
    assert_eq!(file_data.file_uri, "gs://bucket/report.pdf");
}
1436
/// With two inline attachments and one file attachment, the nested parts are
/// expected to be ordered inline, inline, file.
#[test]
fn function_response_with_both_inline_and_file_data_ordering() {
    use adk_gemini::FunctionResponsePart as ResponsePart;

    let inline_attachments = vec![
        adk_core::InlineDataPart { mime_type: String::from("image/png"), data: vec![1, 2] },
        adk_core::InlineDataPart { mime_type: String::from("image/jpeg"), data: vec![3, 4] },
    ];
    let file_attachments = vec![adk_core::FileDataPart {
        mime_type: String::from("application/pdf"),
        file_uri: String::from("gs://b/f.pdf"),
    }];
    let frd = adk_core::FunctionResponseData::with_multimodal(
        "multi",
        serde_json::json!({}),
        inline_attachments,
        file_attachments,
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    let parts = &gemini_fr.parts;
    assert_eq!(parts.len(), 3);

    assert!(matches!(&parts[0], ResponsePart::InlineData { .. }));
    assert!(matches!(&parts[1], ResponsePart::InlineData { .. }));
    assert!(matches!(&parts[2], ResponsePart::FileData { .. }));
}
1458}