1use crate::attachment;
2use crate::retry::{RetryConfig, execute_with_retry, is_retryable_model_error};
3use adk_core::{
4 CacheCapable, CitationMetadata, CitationSource, Content, ErrorCategory, ErrorComponent,
5 FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result, SchemaAdapter,
6 SchemaCache, UsageMetadata,
7};
8use adk_gemini::Gemini;
9use adk_gemini::schema_adapter::GeminiSchemaAdapter;
10use async_trait::async_trait;
11use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
12use futures::TryStreamExt;
13
/// Gemini-backed model that pairs an `adk_gemini` client with per-model request settings.
pub struct GeminiModel {
    /// Underlying Gemini API client used to build and execute requests.
    client: Gemini,
    /// Model identifier supplied at construction time; returned by `Llm::name`.
    model_name: String,
    /// Retry policy handed to `execute_with_retry` on every `generate_content` call.
    retry_config: RetryConfig,
    /// Optional thinking configuration copied into each request's generation config.
    thinking_config: Option<adk_gemini::ThinkingConfig>,
}
26
27fn gemini_error_to_adk(e: &adk_gemini::ClientError) -> adk_core::AdkError {
29 fn format_error_chain(e: &dyn std::error::Error) -> String {
30 let mut msg = e.to_string();
31 let mut source = e.source();
32 while let Some(s) = source {
33 msg.push_str(": ");
34 msg.push_str(&s.to_string());
35 source = s.source();
36 }
37 msg
38 }
39
40 let message = format_error_chain(e);
41
42 let (category, code, status_code) = if message.contains("code 429")
45 || message.contains("RESOURCE_EXHAUSTED")
46 || message.contains("rate limit")
47 {
48 (ErrorCategory::RateLimited, "model.gemini.rate_limited", Some(429u16))
49 } else if message.contains("code 503") || message.contains("UNAVAILABLE") {
50 (ErrorCategory::Unavailable, "model.gemini.unavailable", Some(503))
51 } else if message.contains("code 529") || message.contains("OVERLOADED") {
52 (ErrorCategory::Unavailable, "model.gemini.overloaded", Some(529))
53 } else if message.contains("code 408")
54 || message.contains("DEADLINE_EXCEEDED")
55 || message.contains("TIMEOUT")
56 {
57 (ErrorCategory::Timeout, "model.gemini.timeout", Some(408))
58 } else if message.contains("code 401") || message.contains("Invalid API key") {
59 (ErrorCategory::Unauthorized, "model.gemini.unauthorized", Some(401))
60 } else if message.contains("code 400") {
61 (ErrorCategory::InvalidInput, "model.gemini.bad_request", Some(400))
62 } else if message.contains("code 404") {
63 (ErrorCategory::NotFound, "model.gemini.not_found", Some(404))
64 } else if message.contains("invalid generation config") {
65 (ErrorCategory::InvalidInput, "model.gemini.invalid_config", None)
66 } else {
67 (ErrorCategory::Internal, "model.gemini.internal", None)
68 };
69
70 let mut err = adk_core::AdkError::new(ErrorComponent::Model, category, code, message)
71 .with_provider("gemini");
72 if let Some(sc) = status_code {
73 err = err.with_upstream_status(sc);
74 }
75 err
76}
77
78impl GeminiModel {
79 fn gemini_part_thought_signature(value: &serde_json::Value) -> Option<String> {
80 value.get("thoughtSignature").and_then(serde_json::Value::as_str).map(str::to_string)
81 }
82
83 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
85 let model_name = model.into();
86 let client = Gemini::with_model(api_key.into(), model_name.clone())
87 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
88
89 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
90 }
91
92 #[cfg(feature = "gemini-vertex")]
96 pub fn new_google_cloud(
97 api_key: impl Into<String>,
98 project_id: impl AsRef<str>,
99 location: impl AsRef<str>,
100 model: impl Into<String>,
101 ) -> Result<Self> {
102 let model_name = model.into();
103 let client = Gemini::with_google_cloud_model(
104 api_key.into(),
105 project_id,
106 location,
107 model_name.clone(),
108 )
109 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
110
111 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
112 }
113
114 #[cfg(feature = "gemini-vertex")]
118 pub fn new_google_cloud_service_account(
119 service_account_json: &str,
120 project_id: impl AsRef<str>,
121 location: impl AsRef<str>,
122 model: impl Into<String>,
123 ) -> Result<Self> {
124 let model_name = model.into();
125 let client = Gemini::with_google_cloud_service_account_json(
126 service_account_json,
127 project_id.as_ref(),
128 location.as_ref(),
129 model_name.clone(),
130 )
131 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
132
133 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
134 }
135
136 #[cfg(feature = "gemini-vertex")]
140 pub fn new_google_cloud_adc(
141 project_id: impl AsRef<str>,
142 location: impl AsRef<str>,
143 model: impl Into<String>,
144 ) -> Result<Self> {
145 let model_name = model.into();
146 let client = Gemini::with_google_cloud_adc_model(
147 project_id.as_ref(),
148 location.as_ref(),
149 model_name.clone(),
150 )
151 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
152
153 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
154 }
155
156 #[cfg(feature = "gemini-vertex")]
160 pub fn new_google_cloud_wif(
161 wif_json: &str,
162 project_id: impl AsRef<str>,
163 location: impl AsRef<str>,
164 model: impl Into<String>,
165 ) -> Result<Self> {
166 let model_name = model.into();
167 let client = Gemini::with_google_cloud_wif_json(
168 wif_json,
169 project_id.as_ref(),
170 location.as_ref(),
171 model_name.clone(),
172 )
173 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
174
175 Ok(Self { client, model_name, retry_config: RetryConfig::default(), thinking_config: None })
176 }
177
178 #[must_use]
180 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
181 self.retry_config = retry_config;
182 self
183 }
184
185 pub fn set_retry_config(&mut self, retry_config: RetryConfig) {
187 self.retry_config = retry_config;
188 }
189
190 pub fn retry_config(&self) -> &RetryConfig {
192 &self.retry_config
193 }
194
195 #[must_use]
219 pub fn with_thinking_config(mut self, thinking_config: adk_gemini::ThinkingConfig) -> Self {
220 self.thinking_config = Some(thinking_config);
221 self
222 }
223
224 pub fn set_thinking_config(&mut self, thinking_config: adk_gemini::ThinkingConfig) {
226 self.thinking_config = Some(thinking_config);
227 }
228
229 pub fn thinking_config(&self) -> Option<&adk_gemini::ThinkingConfig> {
231 self.thinking_config.as_ref()
232 }
233
234 fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
235 let mut converted_parts: Vec<Part> = Vec::new();
236
237 if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
239 for p in parts {
240 match p {
241 adk_gemini::Part::Text { text, thought, thought_signature } => {
242 if thought == &Some(true) {
243 converted_parts.push(Part::Thinking {
244 thinking: text.clone(),
245 signature: thought_signature.clone(),
246 });
247 } else {
248 converted_parts.push(Part::Text { text: text.clone() });
249 }
250 }
251 adk_gemini::Part::InlineData { inline_data } => {
252 let decoded =
253 BASE64_STANDARD.decode(&inline_data.data).map_err(|error| {
254 adk_core::AdkError::model(format!(
255 "failed to decode inline data from gemini response: {error}"
256 ))
257 })?;
258 converted_parts.push(Part::InlineData {
259 mime_type: inline_data.mime_type.clone(),
260 data: decoded,
261 });
262 }
263 adk_gemini::Part::FunctionCall { function_call, thought_signature } => {
264 converted_parts.push(Part::FunctionCall {
265 name: function_call.name.clone(),
266 args: function_call.args.clone(),
267 id: function_call.id.clone(),
268 thought_signature: thought_signature.clone(),
269 });
270 }
271 adk_gemini::Part::FunctionResponse { function_response, .. } => {
272 converted_parts.push(Part::FunctionResponse {
273 function_response: adk_core::FunctionResponseData::new(
274 function_response.name.clone(),
275 function_response
276 .response
277 .clone()
278 .unwrap_or(serde_json::Value::Null),
279 ),
280 id: None,
281 });
282 }
283 adk_gemini::Part::ToolCall { .. } | adk_gemini::Part::ExecutableCode { .. } => {
284 if let Ok(value) = serde_json::to_value(p) {
285 converted_parts.push(Part::ServerToolCall { server_tool_call: value });
286 }
287 }
288 adk_gemini::Part::ToolResponse { .. }
289 | adk_gemini::Part::CodeExecutionResult { .. } => {
290 let value = serde_json::to_value(p).unwrap_or(serde_json::Value::Null);
291 converted_parts
292 .push(Part::ServerToolResponse { server_tool_response: value });
293 }
294 adk_gemini::Part::FileData { file_data } => {
295 converted_parts.push(Part::FileData {
296 mime_type: file_data.mime_type.clone(),
297 file_uri: file_data.file_uri.clone(),
298 });
299 }
300 }
301 }
302 }
303
304 if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
306 {
307 if let Some(queries) = &grounding.web_search_queries {
308 if !queries.is_empty() {
309 let search_info = format!("\n\nš **Searched:** {}", queries.join(", "));
310 converted_parts.push(Part::Text { text: search_info });
311 }
312 }
313 if let Some(chunks) = &grounding.grounding_chunks {
314 let sources: Vec<String> = chunks
315 .iter()
316 .filter_map(|c| {
317 c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
318 (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
319 (Some(title), None) => Some(title.clone()),
320 (None, Some(uri)) => Some(uri.to_string()),
321 (None, None) => None,
322 })
323 })
324 .collect();
325 if !sources.is_empty() {
326 let sources_info = format!("\nš **Sources:** {}", sources.join(" | "));
327 converted_parts.push(Part::Text { text: sources_info });
328 }
329 }
330 }
331
332 let content = if converted_parts.is_empty() {
333 None
334 } else {
335 Some(Content { role: "model".to_string(), parts: converted_parts })
336 };
337
338 let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
339 prompt_token_count: u.prompt_token_count.unwrap_or(0),
340 candidates_token_count: u.candidates_token_count.unwrap_or(0),
341 total_token_count: u.total_token_count.unwrap_or(0),
342 thinking_token_count: u.thoughts_token_count,
343 cache_read_input_token_count: u.cached_content_token_count,
344 ..Default::default()
345 });
346
347 let finish_reason =
348 resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
349 adk_gemini::FinishReason::Stop => FinishReason::Stop,
350 adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
351 adk_gemini::FinishReason::Safety => FinishReason::Safety,
352 adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
353 _ => FinishReason::Other,
354 });
355
356 let citation_metadata =
357 resp.candidates.first().and_then(|c| c.citation_metadata.as_ref()).map(|meta| {
358 CitationMetadata {
359 citation_sources: meta
360 .citation_sources
361 .iter()
362 .map(|source| CitationSource {
363 uri: source.uri.clone(),
364 title: source.title.clone(),
365 start_index: source.start_index,
366 end_index: source.end_index,
367 license: source.license.clone(),
368 publication_date: source.publication_date.map(|d| d.to_string()),
369 })
370 .collect(),
371 }
372 });
373
374 let provider_metadata = resp
377 .candidates
378 .first()
379 .and_then(|c| c.grounding_metadata.as_ref())
380 .and_then(|g| serde_json::to_value(g).ok());
381
382 Ok(LlmResponse {
383 content,
384 usage_metadata,
385 finish_reason,
386 citation_metadata,
387 partial: false,
388 turn_complete: true,
389 interrupted: false,
390 error_code: None,
391 error_message: None,
392 provider_metadata,
393 })
394 }
395
396 fn gemini_function_response_payload(response: serde_json::Value) -> serde_json::Value {
397 match response {
398 serde_json::Value::Object(_) => response,
400 other => serde_json::json!({ "result": other }),
401 }
402 }
403
404 fn merge_object_value(
405 target: &mut serde_json::Map<String, serde_json::Value>,
406 value: serde_json::Value,
407 ) {
408 if let serde_json::Value::Object(object) = value {
409 for (key, value) in object {
410 target.insert(key, value);
411 }
412 }
413 }
414
415 fn build_gemini_tools(
416 tools: &std::collections::HashMap<String, serde_json::Value>,
417 adapter: &dyn SchemaAdapter,
418 cache: &SchemaCache,
419 ) -> Result<(Vec<adk_gemini::Tool>, adk_gemini::ToolConfig)> {
420 let mut gemini_tools = Vec::new();
421 let mut function_declarations = Vec::new();
422 let mut has_provider_native_tools = false;
423 let mut tool_config_json = serde_json::Map::new();
424
425 for (name, tool_decl) in tools {
426 if let Some(provider_tool) = tool_decl.get("x-adk-gemini-tool") {
427 let tool = serde_json::from_value::<adk_gemini::Tool>(provider_tool.clone())
428 .map_err(|error| {
429 adk_core::AdkError::model(format!(
430 "failed to deserialize Gemini native tool '{name}': {error}"
431 ))
432 })?;
433 has_provider_native_tools = true;
434 gemini_tools.push(tool);
435 } else {
436 let normalized_name = adapter.normalize_tool_name(name);
438
439 let schema =
442 tool_decl.get("parameters").cloned().unwrap_or_else(|| adapter.empty_schema());
443 let normalized_schema = cache.get_or_normalize(&schema, adapter);
444
445 let description =
447 tool_decl.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string();
448
449 let mut func_decl_json = serde_json::json!({
450 "name": normalized_name.as_ref(),
451 "description": description,
452 "parameters": normalized_schema,
453 });
454
455 if let Some(response) = tool_decl.get("response") {
457 func_decl_json["response"] = response.clone();
458 }
459
460 if let Some(behavior) = tool_decl.get("behavior") {
462 func_decl_json["behavior"] = behavior.clone();
463 }
464
465 let func_decl =
466 serde_json::from_value::<adk_gemini::FunctionDeclaration>(func_decl_json)
467 .map_err(|error| {
468 adk_core::AdkError::model(format!(
469 "failed to build Gemini function declaration for '{name}': {error}"
470 ))
471 })?;
472 function_declarations.push(func_decl);
473 }
474
475 if let Some(tool_config) = tool_decl.get("x-adk-gemini-tool-config") {
476 Self::merge_object_value(&mut tool_config_json, tool_config.clone());
477 }
478 }
479
480 let has_function_declarations = !function_declarations.is_empty();
481 if has_function_declarations {
482 gemini_tools.push(adk_gemini::Tool::with_functions(function_declarations));
483 }
484
485 if has_provider_native_tools {
486 tool_config_json.insert(
487 "includeServerSideToolInvocations".to_string(),
488 serde_json::Value::Bool(true),
489 );
490 }
491
492 let tool_config = if tool_config_json.is_empty() {
493 adk_gemini::ToolConfig::default()
494 } else {
495 serde_json::from_value::<adk_gemini::ToolConfig>(serde_json::Value::Object(
496 tool_config_json,
497 ))
498 .map_err(|error| {
499 adk_core::AdkError::model(format!(
500 "failed to deserialize Gemini tool configuration: {error}"
501 ))
502 })?
503 };
504
505 Ok((gemini_tools, tool_config))
506 }
507
508 fn stream_chunks_from_response(
509 mut response: LlmResponse,
510 saw_partial_chunk: bool,
511 ) -> (Vec<LlmResponse>, bool) {
512 let is_final = response.finish_reason.is_some();
513
514 if !is_final {
515 response.partial = true;
516 response.turn_complete = false;
517 return (vec![response], true);
518 }
519
520 response.partial = false;
521 response.turn_complete = true;
522
523 if saw_partial_chunk {
524 return (vec![response], true);
525 }
526
527 let synthetic_partial = LlmResponse {
528 content: None,
529 usage_metadata: None,
530 finish_reason: None,
531 citation_metadata: None,
532 partial: true,
533 turn_complete: false,
534 interrupted: false,
535 error_code: None,
536 error_message: None,
537 provider_metadata: None,
538 };
539
540 (vec![synthetic_partial, response], true)
541 }
542
    /// Translates an `LlmRequest` into a Gemini request, executes it, and adapts the
    /// result into an `LlmResponseStream`.
    ///
    /// When `stream` is `true` the streaming endpoint is used and every upstream chunk is
    /// passed through `convert_response` and `stream_chunks_from_response`. Otherwise a
    /// single response is awaited and wrapped in a one-item stream.
    ///
    /// # Errors
    ///
    /// Returns an error if tool conversion fails, if starting/executing the request fails
    /// (mapped through `gemini_error_to_adk`), or — in non-streaming mode — if the response
    /// cannot be converted. In streaming mode, per-chunk failures are yielded as `Err`
    /// items of the returned stream instead.
    async fn generate_content_internal(
        &self,
        req: LlmRequest,
        stream: bool,
    ) -> Result<LlmResponseStream> {
        let mut builder = self.client.generate_content();

        // First pass: remember the thought signature attached to each model-issued function
        // call, keyed by function name, so it can be echoed on the matching function
        // response below. A later call with the same name overwrites an earlier entry.
        let mut fn_call_signatures: std::collections::HashMap<String, String> =
            std::collections::HashMap::new();
        for content in &req.contents {
            if content.role == "model" {
                for part in &content.parts {
                    if let Part::FunctionCall { name, thought_signature: Some(sig), .. } = part {
                        fn_call_signatures.insert(name.clone(), sig.clone());
                    }
                }
            }
        }

        // Second pass: convert each content entry into a Gemini message. Entries whose role
        // is not "user", "model", or "function" are skipped.
        for content in &req.contents {
            match content.role.as_str() {
                "user" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            Part::InlineData { data, mime_type } => {
                                let encoded = attachment::encode_base64(data);
                                gemini_parts.push(adk_gemini::Part::InlineData {
                                    inline_data: adk_gemini::Blob {
                                        mime_type: mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            // File references in user messages are sent as a text part
                            // produced by `attachment::file_attachment_to_text`.
                            Part::FileData { mime_type, file_uri } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: attachment::file_attachment_to_text(mime_type, file_uri),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            // Other part kinds are not forwarded in user messages.
                            _ => {}
                        }
                    }
                    // A message is only added when at least one part was produced.
                    if !gemini_parts.is_empty() {
                        let user_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: user_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                "model" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            // The signature is carried on the part, not on the inner
                            // `FunctionCall`, whose own field is left as `None`.
                            Part::FunctionCall { name, args, thought_signature, id } => {
                                gemini_parts.push(adk_gemini::Part::FunctionCall {
                                    function_call: adk_gemini::FunctionCall {
                                        name: name.clone(),
                                        args: args.clone(),
                                        id: id.clone(),
                                        thought_signature: None,
                                    },
                                    thought_signature: thought_signature.clone(),
                                });
                            }
                            Part::ServerToolCall { server_tool_call } => {
                                // Prefer replaying the stored JSON as a native Gemini part
                                // when it deserializes to one of the expected variants.
                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
                                    server_tool_call.clone(),
                                ) {
                                    match native_part {
                                        adk_gemini::Part::ToolCall { .. }
                                        | adk_gemini::Part::ExecutableCode { .. } => {
                                            gemini_parts.push(native_part);
                                            // Moves on to the next part of this content.
                                            continue;
                                        }
                                        _ => {}
                                    }
                                }

                                // Fallback: wrap the raw JSON as a tool call, picking up a
                                // top-level "thoughtSignature" string if present.
                                gemini_parts.push(adk_gemini::Part::ToolCall {
                                    tool_call: server_tool_call.clone(),
                                    thought_signature: Self::gemini_part_thought_signature(
                                        server_tool_call,
                                    ),
                                });
                            }
                            Part::ServerToolResponse { server_tool_response } => {
                                // Same strategy as for server tool calls above.
                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
                                    server_tool_response.clone(),
                                ) {
                                    match native_part {
                                        adk_gemini::Part::ToolResponse { .. }
                                        | adk_gemini::Part::CodeExecutionResult { .. } => {
                                            gemini_parts.push(native_part);
                                            // Moves on to the next part of this content.
                                            continue;
                                        }
                                        _ => {}
                                    }
                                }

                                gemini_parts.push(adk_gemini::Part::ToolResponse {
                                    tool_response: server_tool_response.clone(),
                                    thought_signature: Self::gemini_part_thought_signature(
                                        server_tool_response,
                                    ),
                                });
                            }
                            // Other part kinds are not forwarded in model messages.
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let model_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::Model),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: model_content,
                            role: adk_gemini::Role::Model,
                        });
                    }
                }
                "function" => {
                    let mut gemini_parts = Vec::new();
                    // Only `FunctionResponse` parts are considered for this role.
                    for part in &content.parts {
                        if let Part::FunctionResponse { function_response, .. } = part {
                            // Signature recorded for a model function call of the same name.
                            let sig = fn_call_signatures.get(&function_response.name).cloned();

                            // Attachments on the response: inline data is base64-encoded,
                            // file data is passed by reference.
                            let mut fr_parts = Vec::new();
                            for inline in &function_response.inline_data {
                                let encoded = attachment::encode_base64(&inline.data);
                                fr_parts.push(adk_gemini::FunctionResponsePart::InlineData {
                                    inline_data: adk_gemini::Blob {
                                        mime_type: inline.mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            for file in &function_response.file_data {
                                fr_parts.push(adk_gemini::FunctionResponsePart::FileData {
                                    file_data: adk_gemini::FileDataRef {
                                        mime_type: file.mime_type.clone(),
                                        file_uri: file.file_uri.clone(),
                                    },
                                });
                            }

                            let mut gemini_fr = adk_gemini::tools::FunctionResponse::new(
                                &function_response.name,
                                Self::gemini_function_response_payload(
                                    function_response.response.clone(),
                                ),
                            );
                            gemini_fr.parts = fr_parts;

                            gemini_parts.push(adk_gemini::Part::FunctionResponse {
                                function_response: gemini_fr,
                                thought_signature: sig,
                            });
                        }
                    }
                    // Function results are sent to Gemini under the user role.
                    if !gemini_parts.is_empty() {
                        let fn_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: fn_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                _ => {}
            }
        }

        if let Some(config) = req.config {
            // A response schema implies JSON output, so the MIME type is set alongside it.
            let has_schema = config.response_schema.is_some();
            let gen_config = adk_gemini::GenerationConfig {
                temperature: config.temperature,
                top_p: config.top_p,
                top_k: config.top_k,
                max_output_tokens: config.max_output_tokens,
                response_schema: config.response_schema,
                response_mime_type: if has_schema {
                    Some("application/json".to_string())
                } else {
                    None
                },
                thinking_config: self.thinking_config.clone(),
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);

            // Attach a previously created cache entry when the request names one.
            if let Some(ref name) = config.cached_content {
                let handle = self.client.get_cached_content(name);
                builder = builder.with_cached_content(&handle);
            }
        } else if self.thinking_config.is_some() {
            // No request-level config: still send the model-level thinking configuration.
            let gen_config = adk_gemini::GenerationConfig {
                thinking_config: self.thinking_config.clone(),
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);
        }

        if !req.tools.is_empty() {
            let adapter = self.schema_adapter();
            use std::sync::LazyLock;
            // Schema cache local to this function, lazily created on first use.
            static SCHEMA_CACHE: LazyLock<SchemaCache> = LazyLock::new(SchemaCache::new);
            let (gemini_tools, tool_config) =
                Self::build_gemini_tools(&req.tools, adapter, &SCHEMA_CACHE)?;
            for tool in gemini_tools {
                builder = builder.with_tool(tool);
            }
            // Only send a tool configuration when it differs from the default.
            if tool_config != adk_gemini::ToolConfig::default() {
                builder = builder.with_tool_config(tool_config);
            }
        }

        if stream {
            adk_telemetry::debug!("Executing streaming request");
            let response_stream = builder.execute_stream().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                gemini_error_to_adk(&e)
            })?;

            let mapped_stream = async_stream::stream! {
                let mut stream = response_stream;
                // Tracks whether a partial chunk has already been emitted downstream.
                let mut saw_partial_chunk = false;
                // Errors are yielded as items; the loop keeps polling until the upstream
                // stream reports that it is exhausted.
                while let Some(result) = stream.try_next().await.transpose() {
                    match result {
                        Ok(resp) => {
                            match Self::convert_response(&resp) {
                                Ok(llm_resp) => {
                                    let (chunks, next_saw_partial) =
                                        Self::stream_chunks_from_response(llm_resp, saw_partial_chunk);
                                    saw_partial_chunk = next_saw_partial;
                                    for chunk in chunks {
                                        yield Ok(chunk);
                                    }
                                }
                                Err(e) => {
                                    adk_telemetry::error!(error = %e, "Failed to convert response");
                                    yield Err(e);
                                }
                            }
                        }
                        Err(e) => {
                            adk_telemetry::error!(error = %e, "Stream error");
                            yield Err(gemini_error_to_adk(&e));
                        }
                    }
                }
            };

            Ok(Box::pin(mapped_stream))
        } else {
            adk_telemetry::debug!("Executing blocking request");
            let response = builder.execute().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                gemini_error_to_adk(&e)
            })?;

            let llm_response = Self::convert_response(&response)?;

            // Wrap the single response so both modes return the same stream type.
            let stream = async_stream::stream! {
                yield Ok(llm_response);
            };

            Ok(Box::pin(stream))
        }
    }
863
864 pub async fn create_cached_content(
869 &self,
870 system_instruction: &str,
871 tools: &std::collections::HashMap<String, serde_json::Value>,
872 ttl_seconds: u32,
873 ) -> Result<String> {
874 let mut cache_builder = self
875 .client
876 .create_cache()
877 .with_system_instruction(system_instruction)
878 .with_ttl(std::time::Duration::from_secs(u64::from(ttl_seconds)));
879
880 let adapter = self.schema_adapter();
881 use std::sync::LazyLock;
882 static SCHEMA_CACHE: LazyLock<SchemaCache> = LazyLock::new(SchemaCache::new);
883 let (gemini_tools, tool_config) = Self::build_gemini_tools(tools, adapter, &SCHEMA_CACHE)?;
884 if !gemini_tools.is_empty() {
885 cache_builder = cache_builder.with_tools(gemini_tools);
886 }
887 if tool_config != adk_gemini::ToolConfig::default() {
888 cache_builder = cache_builder.with_tool_config(tool_config);
889 }
890
891 let handle = cache_builder
892 .execute()
893 .await
894 .map_err(|e| adk_core::AdkError::model(format!("cache creation failed: {e}")))?;
895
896 Ok(handle.name().to_string())
897 }
898
899 pub async fn delete_cached_content(&self, name: &str) -> Result<()> {
901 let handle = self.client.get_cached_content(name);
902 handle
903 .delete()
904 .await
905 .map_err(|(_, e)| adk_core::AdkError::model(format!("cache deletion failed: {e}")))?;
906 Ok(())
907 }
908}
909
910#[async_trait]
911impl Llm for GeminiModel {
912 fn name(&self) -> &str {
913 &self.model_name
914 }
915
916 fn schema_adapter(&self) -> &dyn SchemaAdapter {
917 use std::sync::LazyLock;
918 static ADAPTER: LazyLock<GeminiSchemaAdapter> = LazyLock::new(GeminiSchemaAdapter::new);
919 &*ADAPTER
920 }
921
922 #[adk_telemetry::instrument(
923 name = "call_llm",
924 skip(self, req),
925 fields(
926 model.name = %self.model_name,
927 stream = %stream,
928 request.contents_count = %req.contents.len(),
929 request.tools_count = %req.tools.len()
930 )
931 )]
932 async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
933 adk_telemetry::info!("Generating content");
934 let usage_span = adk_telemetry::llm_generate_span("gemini", &self.model_name, stream);
935 let result = execute_with_retry(&self.retry_config, is_retryable_model_error, || {
938 self.generate_content_internal(req.clone(), stream)
939 })
940 .await?;
941 Ok(crate::usage_tracking::with_usage_tracking(result, usage_span))
942 }
943}
944
#[cfg(test)]
mod native_tool_tests {
    use super::*;

    type ToolMap = std::collections::HashMap<String, serde_json::Value>;

    /// Map entry declaring the provider-native `google_search` tool.
    fn google_search_entry() -> (String, serde_json::Value) {
        let declaration = serde_json::json!({
            "x-adk-gemini-tool": {
                "google_search": {}
            }
        });
        ("google_search".to_string(), declaration)
    }

    /// Map entry declaring an ordinary function tool named `lookup_weather`.
    fn lookup_weather_entry() -> (String, serde_json::Value) {
        let declaration = serde_json::json!({
            "name": "lookup_weather",
            "description": "lookup weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": { "type": "string" }
                }
            }
        });
        ("lookup_weather".to_string(), declaration)
    }

    /// Runs `build_gemini_tools` with a fresh adapter and cache, panicking on failure.
    fn convert(tools: &ToolMap) -> (Vec<adk_gemini::Tool>, adk_gemini::ToolConfig) {
        let adapter = GeminiSchemaAdapter::new();
        let cache = SchemaCache::new();
        GeminiModel::build_gemini_tools(tools, &adapter, &cache)
            .expect("tool conversion should succeed")
    }

    #[test]
    fn test_build_gemini_tools_supports_native_tool_metadata() {
        let tools = ToolMap::from([google_search_entry(), lookup_weather_entry()]);

        let (gemini_tools, tool_config) = convert(&tools);

        // One native tool plus one tool wrapping the function declaration.
        assert_eq!(gemini_tools.len(), 2);
        assert_eq!(tool_config.include_server_side_tool_invocations, Some(true));
    }

    #[test]
    fn test_build_gemini_tools_sets_flag_for_builtin_only() {
        let tools = ToolMap::from([google_search_entry()]);

        let (_, tool_config) = convert(&tools);

        assert_eq!(
            tool_config.include_server_side_tool_invocations,
            Some(true),
            "includeServerSideToolInvocations should be set even with only built-in tools"
        );
    }

    #[test]
    fn test_build_gemini_tools_no_flag_for_function_only() {
        let tools = ToolMap::from([lookup_weather_entry()]);

        let (_, tool_config) = convert(&tools);

        assert_eq!(
            tool_config.include_server_side_tool_invocations, None,
            "includeServerSideToolInvocations should NOT be set for function-only tools"
        );
    }

    #[test]
    fn test_build_gemini_tools_merges_native_tool_config() {
        let google_maps = serde_json::json!({
            "x-adk-gemini-tool": {
                "google_maps": {
                    "enable_widget": true
                }
            },
            "x-adk-gemini-tool-config": {
                "retrievalConfig": {
                    "latLng": {
                        "latitude": 1.23,
                        "longitude": 4.56
                    }
                }
            }
        });
        let tools = ToolMap::from([("google_maps".to_string(), google_maps)]);

        let (_, tool_config) = convert(&tools);

        let expected_retrieval = serde_json::json!({
            "latLng": {
                "latitude": 1.23,
                "longitude": 4.56
            }
        });
        assert_eq!(tool_config.retrieval_config, Some(expected_retrieval));
    }
}
1084
1085#[async_trait]
1086impl CacheCapable for GeminiModel {
1087 async fn create_cache(
1088 &self,
1089 system_instruction: &str,
1090 tools: &std::collections::HashMap<String, serde_json::Value>,
1091 ttl_seconds: u32,
1092 ) -> Result<String> {
1093 self.create_cached_content(system_instruction, tools, ttl_seconds).await
1094 }
1095
1096 async fn delete_cache(&self, name: &str) -> Result<()> {
1097 self.delete_cached_content(name).await
1098 }
1099}
1100
1101#[cfg(test)]
1102mod tests {
1103 use super::*;
1104 use adk_core::AdkError;
1105 use std::{
1106 sync::{
1107 Arc,
1108 atomic::{AtomicU32, Ordering},
1109 },
1110 time::Duration,
1111 };
1112
1113 #[test]
1114 fn constructor_is_backward_compatible_and_sync() {
1115 fn accepts_sync_constructor<F>(_f: F)
1116 where
1117 F: Fn(&str, &str) -> Result<GeminiModel>,
1118 {
1119 }
1120
1121 accepts_sync_constructor(|api_key, model| GeminiModel::new(api_key, model));
1122 }
1123
1124 #[test]
1125 fn stream_chunks_from_response_injects_partial_before_lone_final_chunk() {
1126 let response = LlmResponse {
1127 content: Some(Content::new("model").with_text("hello")),
1128 usage_metadata: None,
1129 finish_reason: Some(FinishReason::Stop),
1130 citation_metadata: None,
1131 partial: false,
1132 turn_complete: true,
1133 interrupted: false,
1134 error_code: None,
1135 error_message: None,
1136 provider_metadata: None,
1137 };
1138
1139 let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, false);
1140 assert!(saw_partial);
1141 assert_eq!(chunks.len(), 2);
1142 assert!(chunks[0].partial);
1143 assert!(!chunks[0].turn_complete);
1144 assert!(chunks[0].content.is_none());
1145 assert!(!chunks[1].partial);
1146 assert!(chunks[1].turn_complete);
1147 }
1148
/// When a partial chunk has already been seen, the final response must be passed
/// through alone, with no extra partial injected in front of it.
#[test]
fn stream_chunks_from_response_keeps_final_only_when_partial_already_seen() {
    let final_response = LlmResponse {
        content: Some(Content::new("model").with_text("done")),
        finish_reason: Some(FinishReason::Stop),
        partial: false,
        turn_complete: true,
        interrupted: false,
        usage_metadata: None,
        citation_metadata: None,
        error_code: None,
        error_message: None,
        provider_metadata: None,
    };

    // `true`: a partial chunk was already observed earlier on this stream.
    let (emitted, partial_seen) = GeminiModel::stream_chunks_from_response(final_response, true);
    assert!(partial_seen);
    assert_eq!(emitted.len(), 1);

    let only = &emitted[0];
    assert!(!only.partial);
    assert!(only.turn_complete);
}
1170
1171 #[tokio::test]
1172 async fn execute_with_retry_retries_retryable_errors() {
1173 let retry_config = RetryConfig::default()
1174 .with_max_retries(2)
1175 .with_initial_delay(Duration::from_millis(0))
1176 .with_max_delay(Duration::from_millis(0));
1177 let attempts = Arc::new(AtomicU32::new(0));
1178
1179 let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
1180 let attempts = Arc::clone(&attempts);
1181 async move {
1182 let attempt = attempts.fetch_add(1, Ordering::SeqCst);
1183 if attempt < 2 {
1184 return Err(AdkError::model("code 429 RESOURCE_EXHAUSTED"));
1185 }
1186 Ok("ok")
1187 }
1188 })
1189 .await
1190 .expect("retry should eventually succeed");
1191
1192 assert_eq!(result, "ok");
1193 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1194 }
1195
1196 #[tokio::test]
1197 async fn execute_with_retry_does_not_retry_non_retryable_errors() {
1198 let retry_config = RetryConfig::default()
1199 .with_max_retries(3)
1200 .with_initial_delay(Duration::from_millis(0))
1201 .with_max_delay(Duration::from_millis(0));
1202 let attempts = Arc::new(AtomicU32::new(0));
1203
1204 let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
1205 let attempts = Arc::clone(&attempts);
1206 async move {
1207 attempts.fetch_add(1, Ordering::SeqCst);
1208 Err::<(), _>(AdkError::model("code 400 invalid request"))
1209 }
1210 })
1211 .await
1212 .expect_err("non-retryable error should return immediately");
1213
1214 assert!(error.is_model());
1215 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1216 }
1217
1218 #[tokio::test]
1219 async fn execute_with_retry_respects_disabled_config() {
1220 let retry_config = RetryConfig::disabled().with_max_retries(10);
1221 let attempts = Arc::new(AtomicU32::new(0));
1222
1223 let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
1224 let attempts = Arc::clone(&attempts);
1225 async move {
1226 attempts.fetch_add(1, Ordering::SeqCst);
1227 Err::<(), _>(AdkError::model("code 429 RESOURCE_EXHAUSTED"))
1228 }
1229 })
1230 .await
1231 .expect_err("disabled retries should return first error");
1232
1233 assert!(error.is_model());
1234 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1235 }
1236
/// Citation sources attached to a Gemini candidate must be carried over into the
/// converted response with their URI and index range intact.
#[test]
fn convert_response_preserves_citation_metadata() {
    let source = adk_gemini::CitationSource {
        uri: Some("https://example.com".to_string()),
        title: Some("Example".to_string()),
        start_index: Some(0),
        end_index: Some(5),
        license: Some("CC-BY".to_string()),
        publication_date: None,
    };

    let candidate = adk_gemini::Candidate {
        content: adk_gemini::Content {
            role: Some(adk_gemini::Role::Model),
            parts: Some(vec![adk_gemini::Part::Text {
                text: "hello world".to_string(),
                thought: None,
                thought_signature: None,
            }]),
        },
        safety_ratings: None,
        citation_metadata: Some(adk_gemini::CitationMetadata { citation_sources: vec![source] }),
        grounding_metadata: None,
        finish_reason: Some(adk_gemini::FinishReason::Stop),
        index: Some(0),
    };

    let response = adk_gemini::GenerationResponse {
        candidates: vec![candidate],
        prompt_feedback: None,
        usage_metadata: None,
        model_version: None,
        response_id: None,
    };

    let converted = GeminiModel::convert_response(&response).expect("conversion should succeed");
    let citations = converted.citation_metadata.expect("citation metadata should be mapped");
    assert_eq!(citations.citation_sources.len(), 1);

    let mapped = &citations.citation_sources[0];
    assert_eq!(mapped.uri.as_deref(), Some("https://example.com"));
    assert_eq!(mapped.start_index, Some(0));
    assert_eq!(mapped.end_index, Some(5));
}
1278
/// A model response mixing text and base64 inline data must convert into a text
/// part plus an inline-data part holding the decoded bytes.
#[test]
fn convert_response_handles_inline_data_from_model() {
    // First four bytes of the PNG signature.
    let image_bytes = vec![0x89, 0x50, 0x4E, 0x47];

    let text_part = adk_gemini::Part::Text {
        text: "Here is the image".to_string(),
        thought: None,
        thought_signature: None,
    };
    let image_part = adk_gemini::Part::InlineData {
        inline_data: adk_gemini::Blob {
            mime_type: "image/png".to_string(),
            data: crate::attachment::encode_base64(&image_bytes),
        },
    };

    let response = adk_gemini::GenerationResponse {
        candidates: vec![adk_gemini::Candidate {
            content: adk_gemini::Content {
                role: Some(adk_gemini::Role::Model),
                parts: Some(vec![text_part, image_part]),
            },
            safety_ratings: None,
            citation_metadata: None,
            grounding_metadata: None,
            finish_reason: Some(adk_gemini::FinishReason::Stop),
            index: Some(0),
        }],
        prompt_feedback: None,
        usage_metadata: None,
        model_version: None,
        response_id: None,
    };

    let converted = GeminiModel::convert_response(&response).expect("conversion should succeed");
    let content = converted.content.expect("should have content");

    let has_text = content
        .parts
        .iter()
        .any(|part| matches!(part, Part::Text { text } if text == "Here is the image"));
    assert!(has_text);

    // The inline part must expose the raw bytes, not the base64 text.
    let has_image = content.parts.iter().any(|part| {
        matches!(
            part,
            Part::InlineData { mime_type, data }
                if mime_type == "image/png" && data.as_slice() == image_bytes.as_slice()
        )
    });
    assert!(has_image);
}
1331
/// A JSON object payload must pass through `gemini_function_response_payload` unchanged.
#[test]
fn gemini_function_response_payload_preserves_objects() {
    // Build the same object twice: once as input, once as the expected output.
    let make_object = || {
        serde_json::json!({
            "documents": [
                { "id": "pricing", "score": 0.91 }
            ]
        })
    };

    let produced = GeminiModel::gemini_function_response_payload(make_object());

    assert_eq!(produced, make_object());
}
1344
/// A JSON array payload must come back wrapped in an object under the `"result"` key.
#[test]
fn gemini_function_response_payload_wraps_arrays() {
    let input = serde_json::json!([{ "id": "pricing" }]);
    let expected = serde_json::json!({ "result": [{ "id": "pricing" }] });

    assert_eq!(GeminiModel::gemini_function_response_payload(input), expected);
}
1352
1353 fn convert_function_response_to_gemini_fr(
1358 frd: &adk_core::FunctionResponseData,
1359 ) -> adk_gemini::tools::FunctionResponse {
1360 let mut fr_parts = Vec::new();
1361
1362 for inline in &frd.inline_data {
1363 let encoded = crate::attachment::encode_base64(&inline.data);
1364 fr_parts.push(adk_gemini::FunctionResponsePart::InlineData {
1365 inline_data: adk_gemini::Blob {
1366 mime_type: inline.mime_type.clone(),
1367 data: encoded,
1368 },
1369 });
1370 }
1371
1372 for file in &frd.file_data {
1373 fr_parts.push(adk_gemini::FunctionResponsePart::FileData {
1374 file_data: adk_gemini::FileDataRef {
1375 mime_type: file.mime_type.clone(),
1376 file_uri: file.file_uri.clone(),
1377 },
1378 });
1379 }
1380
1381 let mut gemini_fr = adk_gemini::tools::FunctionResponse::new(
1382 &frd.name,
1383 GeminiModel::gemini_function_response_payload(frd.response.clone()),
1384 );
1385 gemini_fr.parts = fr_parts;
1386 gemini_fr
1387 }
1388
/// A function response carrying only JSON must have no nested parts, and its
/// serialized form must not contain a `"parts"` key at all.
#[test]
fn json_only_function_response_has_no_nested_parts() {
    let json_only = adk_core::FunctionResponseData::new("tool", serde_json::json!({"ok": true}));

    let converted = convert_function_response_to_gemini_fr(&json_only);
    assert!(converted.parts.is_empty());

    let serialized = serde_json::to_string(&converted).unwrap();
    assert!(!serialized.contains("\"parts\""));
}
1398
/// Inline data on a function response must become a single nested inline-data part
/// whose payload base64-decodes back to the original bytes.
#[test]
fn function_response_with_inline_data_has_nested_parts() {
    // First four bytes of the PNG signature.
    const PNG_MAGIC: [u8; 4] = [0x89, 0x50, 0x4E, 0x47];

    let frd = adk_core::FunctionResponseData::with_inline_data(
        "chart",
        serde_json::json!({"status": "ok"}),
        vec![adk_core::InlineDataPart {
            mime_type: "image/png".to_string(),
            data: PNG_MAGIC.to_vec(),
        }],
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 1);

    let part = &gemini_fr.parts[0];
    let adk_gemini::FunctionResponsePart::InlineData { inline_data } = part else {
        panic!("expected InlineData, got {part:?}");
    };
    assert_eq!(inline_data.mime_type, "image/png");

    let decoded = BASE64_STANDARD.decode(&inline_data.data).unwrap();
    assert_eq!(decoded, PNG_MAGIC);
}
1420
/// File data on a function response must become a single nested file-data part
/// that keeps the MIME type and URI unchanged.
#[test]
fn function_response_with_file_data_has_nested_parts() {
    let attachment = adk_core::FileDataPart {
        mime_type: "application/pdf".to_string(),
        file_uri: "gs://bucket/report.pdf".to_string(),
    };
    let frd = adk_core::FunctionResponseData::with_file_data(
        "doc",
        serde_json::json!({"ok": true}),
        vec![attachment],
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 1);

    let part = &gemini_fr.parts[0];
    let adk_gemini::FunctionResponsePart::FileData { file_data } = part else {
        panic!("expected FileData, got {part:?}");
    };
    assert_eq!(file_data.mime_type, "application/pdf");
    assert_eq!(file_data.file_uri, "gs://bucket/report.pdf");
}
1441
/// With both kinds of attachment present, nested parts must list every inline-data
/// entry before any file-data entry.
#[test]
fn function_response_with_both_inline_and_file_data_ordering() {
    use adk_gemini::FunctionResponsePart as Nested;

    let inline = vec![
        adk_core::InlineDataPart { mime_type: "image/png".to_string(), data: vec![1, 2] },
        adk_core::InlineDataPart { mime_type: "image/jpeg".to_string(), data: vec![3, 4] },
    ];
    let files = vec![adk_core::FileDataPart {
        mime_type: "application/pdf".to_string(),
        file_uri: "gs://b/f.pdf".to_string(),
    }];
    let frd = adk_core::FunctionResponseData::with_multimodal(
        "multi",
        serde_json::json!({}),
        inline,
        files,
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 3);

    // Two inline parts, then the file part, in that exact order.
    assert!(matches!(
        gemini_fr.parts.as_slice(),
        [Nested::InlineData { .. }, Nested::InlineData { .. }, Nested::FileData { .. }]
    ));
}
1463}