1use crate::attachment;
2use crate::retry::{RetryConfig, execute_with_retry, is_retryable_model_error};
3use adk_core::{
4 CacheCapable, CitationMetadata, CitationSource, Content, ErrorCategory, ErrorComponent,
5 FinishReason, Llm, LlmRequest, LlmResponse, LlmResponseStream, Part, Result, UsageMetadata,
6};
7use adk_gemini::Gemini;
8use async_trait::async_trait;
9use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
10use futures::TryStreamExt;
11
/// Gemini-backed model handle: an `adk_gemini` client together with the model
/// name it was created for and the retry policy applied to its requests.
pub struct GeminiModel {
    /// Underlying Gemini client used to build and execute requests.
    client: Gemini,
    /// Model identifier supplied at construction; returned by `Llm::name`.
    model_name: String,
    /// Retry policy handed to `execute_with_retry` when generating content.
    retry_config: RetryConfig,
}
17
18fn gemini_error_to_adk(e: &adk_gemini::ClientError) -> adk_core::AdkError {
20 fn format_error_chain(e: &dyn std::error::Error) -> String {
21 let mut msg = e.to_string();
22 let mut source = e.source();
23 while let Some(s) = source {
24 msg.push_str(": ");
25 msg.push_str(&s.to_string());
26 source = s.source();
27 }
28 msg
29 }
30
31 let message = format_error_chain(e);
32
33 let (category, code, status_code) = if message.contains("code 429")
36 || message.contains("RESOURCE_EXHAUSTED")
37 || message.contains("rate limit")
38 {
39 (ErrorCategory::RateLimited, "model.gemini.rate_limited", Some(429u16))
40 } else if message.contains("code 503") || message.contains("UNAVAILABLE") {
41 (ErrorCategory::Unavailable, "model.gemini.unavailable", Some(503))
42 } else if message.contains("code 529") || message.contains("OVERLOADED") {
43 (ErrorCategory::Unavailable, "model.gemini.overloaded", Some(529))
44 } else if message.contains("code 408")
45 || message.contains("DEADLINE_EXCEEDED")
46 || message.contains("TIMEOUT")
47 {
48 (ErrorCategory::Timeout, "model.gemini.timeout", Some(408))
49 } else if message.contains("code 401") || message.contains("Invalid API key") {
50 (ErrorCategory::Unauthorized, "model.gemini.unauthorized", Some(401))
51 } else if message.contains("code 400") {
52 (ErrorCategory::InvalidInput, "model.gemini.bad_request", Some(400))
53 } else if message.contains("code 404") {
54 (ErrorCategory::NotFound, "model.gemini.not_found", Some(404))
55 } else if message.contains("invalid generation config") {
56 (ErrorCategory::InvalidInput, "model.gemini.invalid_config", None)
57 } else {
58 (ErrorCategory::Internal, "model.gemini.internal", None)
59 };
60
61 let mut err = adk_core::AdkError::new(ErrorComponent::Model, category, code, message)
62 .with_provider("gemini");
63 if let Some(sc) = status_code {
64 err = err.with_upstream_status(sc);
65 }
66 err
67}
68
69impl GeminiModel {
70 fn gemini_part_thought_signature(value: &serde_json::Value) -> Option<String> {
71 value.get("thoughtSignature").and_then(serde_json::Value::as_str).map(str::to_string)
72 }
73
74 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
75 let model_name = model.into();
76 let client = Gemini::with_model(api_key.into(), model_name.clone())
77 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
78
79 Ok(Self { client, model_name, retry_config: RetryConfig::default() })
80 }
81
82 #[cfg(feature = "gemini-vertex")]
86 pub fn new_google_cloud(
87 api_key: impl Into<String>,
88 project_id: impl AsRef<str>,
89 location: impl AsRef<str>,
90 model: impl Into<String>,
91 ) -> Result<Self> {
92 let model_name = model.into();
93 let client = Gemini::with_google_cloud_model(
94 api_key.into(),
95 project_id,
96 location,
97 model_name.clone(),
98 )
99 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
100
101 Ok(Self { client, model_name, retry_config: RetryConfig::default() })
102 }
103
104 #[cfg(feature = "gemini-vertex")]
108 pub fn new_google_cloud_service_account(
109 service_account_json: &str,
110 project_id: impl AsRef<str>,
111 location: impl AsRef<str>,
112 model: impl Into<String>,
113 ) -> Result<Self> {
114 let model_name = model.into();
115 let client = Gemini::with_google_cloud_service_account_json(
116 service_account_json,
117 project_id.as_ref(),
118 location.as_ref(),
119 model_name.clone(),
120 )
121 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
122
123 Ok(Self { client, model_name, retry_config: RetryConfig::default() })
124 }
125
126 #[cfg(feature = "gemini-vertex")]
130 pub fn new_google_cloud_adc(
131 project_id: impl AsRef<str>,
132 location: impl AsRef<str>,
133 model: impl Into<String>,
134 ) -> Result<Self> {
135 let model_name = model.into();
136 let client = Gemini::with_google_cloud_adc_model(
137 project_id.as_ref(),
138 location.as_ref(),
139 model_name.clone(),
140 )
141 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
142
143 Ok(Self { client, model_name, retry_config: RetryConfig::default() })
144 }
145
146 #[cfg(feature = "gemini-vertex")]
150 pub fn new_google_cloud_wif(
151 wif_json: &str,
152 project_id: impl AsRef<str>,
153 location: impl AsRef<str>,
154 model: impl Into<String>,
155 ) -> Result<Self> {
156 let model_name = model.into();
157 let client = Gemini::with_google_cloud_wif_json(
158 wif_json,
159 project_id.as_ref(),
160 location.as_ref(),
161 model_name.clone(),
162 )
163 .map_err(|e| adk_core::AdkError::model(e.to_string()))?;
164
165 Ok(Self { client, model_name, retry_config: RetryConfig::default() })
166 }
167
168 #[must_use]
169 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
170 self.retry_config = retry_config;
171 self
172 }
173
    /// Replaces the retry policy of this model in place.
    pub fn set_retry_config(&mut self, retry_config: RetryConfig) {
        self.retry_config = retry_config;
    }
177
    /// Returns a reference to the retry policy currently stored on this model.
    pub fn retry_config(&self) -> &RetryConfig {
        &self.retry_config
    }
181
182 fn convert_response(resp: &adk_gemini::GenerationResponse) -> Result<LlmResponse> {
183 let mut converted_parts: Vec<Part> = Vec::new();
184
185 if let Some(parts) = resp.candidates.first().and_then(|c| c.content.parts.as_ref()) {
187 for p in parts {
188 match p {
189 adk_gemini::Part::Text { text, thought, thought_signature } => {
190 if thought == &Some(true) {
191 converted_parts.push(Part::Thinking {
192 thinking: text.clone(),
193 signature: thought_signature.clone(),
194 });
195 } else {
196 converted_parts.push(Part::Text { text: text.clone() });
197 }
198 }
199 adk_gemini::Part::InlineData { inline_data } => {
200 let decoded =
201 BASE64_STANDARD.decode(&inline_data.data).map_err(|error| {
202 adk_core::AdkError::model(format!(
203 "failed to decode inline data from gemini response: {error}"
204 ))
205 })?;
206 converted_parts.push(Part::InlineData {
207 mime_type: inline_data.mime_type.clone(),
208 data: decoded,
209 });
210 }
211 adk_gemini::Part::FunctionCall { function_call, thought_signature } => {
212 converted_parts.push(Part::FunctionCall {
213 name: function_call.name.clone(),
214 args: function_call.args.clone(),
215 id: function_call.id.clone(),
216 thought_signature: thought_signature.clone(),
217 });
218 }
219 adk_gemini::Part::FunctionResponse { function_response, .. } => {
220 converted_parts.push(Part::FunctionResponse {
221 function_response: adk_core::FunctionResponseData::new(
222 function_response.name.clone(),
223 function_response
224 .response
225 .clone()
226 .unwrap_or(serde_json::Value::Null),
227 ),
228 id: None,
229 });
230 }
231 adk_gemini::Part::ToolCall { .. } | adk_gemini::Part::ExecutableCode { .. } => {
232 if let Ok(value) = serde_json::to_value(p) {
233 converted_parts.push(Part::ServerToolCall { server_tool_call: value });
234 }
235 }
236 adk_gemini::Part::ToolResponse { .. }
237 | adk_gemini::Part::CodeExecutionResult { .. } => {
238 let value = serde_json::to_value(p).unwrap_or(serde_json::Value::Null);
239 converted_parts
240 .push(Part::ServerToolResponse { server_tool_response: value });
241 }
242 adk_gemini::Part::FileData { file_data } => {
243 converted_parts.push(Part::FileData {
244 mime_type: file_data.mime_type.clone(),
245 file_uri: file_data.file_uri.clone(),
246 });
247 }
248 }
249 }
250 }
251
252 if let Some(grounding) = resp.candidates.first().and_then(|c| c.grounding_metadata.as_ref())
254 {
255 if let Some(queries) = &grounding.web_search_queries {
256 if !queries.is_empty() {
257 let search_info = format!("\n\nš **Searched:** {}", queries.join(", "));
258 converted_parts.push(Part::Text { text: search_info });
259 }
260 }
261 if let Some(chunks) = &grounding.grounding_chunks {
262 let sources: Vec<String> = chunks
263 .iter()
264 .filter_map(|c| {
265 c.web.as_ref().and_then(|w| match (&w.title, &w.uri) {
266 (Some(title), Some(uri)) => Some(format!("[{}]({})", title, uri)),
267 (Some(title), None) => Some(title.clone()),
268 (None, Some(uri)) => Some(uri.to_string()),
269 (None, None) => None,
270 })
271 })
272 .collect();
273 if !sources.is_empty() {
274 let sources_info = format!("\nš **Sources:** {}", sources.join(" | "));
275 converted_parts.push(Part::Text { text: sources_info });
276 }
277 }
278 }
279
280 let content = if converted_parts.is_empty() {
281 None
282 } else {
283 Some(Content { role: "model".to_string(), parts: converted_parts })
284 };
285
286 let usage_metadata = resp.usage_metadata.as_ref().map(|u| UsageMetadata {
287 prompt_token_count: u.prompt_token_count.unwrap_or(0),
288 candidates_token_count: u.candidates_token_count.unwrap_or(0),
289 total_token_count: u.total_token_count.unwrap_or(0),
290 thinking_token_count: u.thoughts_token_count,
291 cache_read_input_token_count: u.cached_content_token_count,
292 ..Default::default()
293 });
294
295 let finish_reason =
296 resp.candidates.first().and_then(|c| c.finish_reason.as_ref()).map(|fr| match fr {
297 adk_gemini::FinishReason::Stop => FinishReason::Stop,
298 adk_gemini::FinishReason::MaxTokens => FinishReason::MaxTokens,
299 adk_gemini::FinishReason::Safety => FinishReason::Safety,
300 adk_gemini::FinishReason::Recitation => FinishReason::Recitation,
301 _ => FinishReason::Other,
302 });
303
304 let citation_metadata =
305 resp.candidates.first().and_then(|c| c.citation_metadata.as_ref()).map(|meta| {
306 CitationMetadata {
307 citation_sources: meta
308 .citation_sources
309 .iter()
310 .map(|source| CitationSource {
311 uri: source.uri.clone(),
312 title: source.title.clone(),
313 start_index: source.start_index,
314 end_index: source.end_index,
315 license: source.license.clone(),
316 publication_date: source.publication_date.map(|d| d.to_string()),
317 })
318 .collect(),
319 }
320 });
321
322 let provider_metadata = resp
325 .candidates
326 .first()
327 .and_then(|c| c.grounding_metadata.as_ref())
328 .and_then(|g| serde_json::to_value(g).ok());
329
330 Ok(LlmResponse {
331 content,
332 usage_metadata,
333 finish_reason,
334 citation_metadata,
335 partial: false,
336 turn_complete: true,
337 interrupted: false,
338 error_code: None,
339 error_message: None,
340 provider_metadata,
341 })
342 }
343
344 fn gemini_function_response_payload(response: serde_json::Value) -> serde_json::Value {
345 match response {
346 serde_json::Value::Object(_) => response,
348 other => serde_json::json!({ "result": other }),
349 }
350 }
351
352 fn merge_object_value(
353 target: &mut serde_json::Map<String, serde_json::Value>,
354 value: serde_json::Value,
355 ) {
356 if let serde_json::Value::Object(object) = value {
357 for (key, value) in object {
358 target.insert(key, value);
359 }
360 }
361 }
362
363 fn build_gemini_tools(
364 tools: &std::collections::HashMap<String, serde_json::Value>,
365 ) -> Result<(Vec<adk_gemini::Tool>, adk_gemini::ToolConfig)> {
366 let mut gemini_tools = Vec::new();
367 let mut function_declarations = Vec::new();
368 let mut has_provider_native_tools = false;
369 let mut tool_config_json = serde_json::Map::new();
370
371 for (name, tool_decl) in tools {
372 if let Some(provider_tool) = tool_decl.get("x-adk-gemini-tool") {
373 let tool = serde_json::from_value::<adk_gemini::Tool>(provider_tool.clone())
374 .map_err(|error| {
375 adk_core::AdkError::model(format!(
376 "failed to deserialize Gemini native tool '{name}': {error}"
377 ))
378 })?;
379 has_provider_native_tools = true;
380 gemini_tools.push(tool);
381 } else if let Ok(func_decl) =
382 serde_json::from_value::<adk_gemini::FunctionDeclaration>(tool_decl.clone())
383 {
384 function_declarations.push(func_decl);
385 } else {
386 return Err(adk_core::AdkError::model(format!(
387 "failed to deserialize Gemini tool '{name}' as a function declaration"
388 )));
389 }
390
391 if let Some(tool_config) = tool_decl.get("x-adk-gemini-tool-config") {
392 Self::merge_object_value(&mut tool_config_json, tool_config.clone());
393 }
394 }
395
396 let has_function_declarations = !function_declarations.is_empty();
397 if has_function_declarations {
398 gemini_tools.push(adk_gemini::Tool::with_functions(function_declarations));
399 }
400
401 if has_provider_native_tools {
402 tool_config_json.insert(
403 "includeServerSideToolInvocations".to_string(),
404 serde_json::Value::Bool(true),
405 );
406 }
407
408 let tool_config = if tool_config_json.is_empty() {
409 adk_gemini::ToolConfig::default()
410 } else {
411 serde_json::from_value::<adk_gemini::ToolConfig>(serde_json::Value::Object(
412 tool_config_json,
413 ))
414 .map_err(|error| {
415 adk_core::AdkError::model(format!(
416 "failed to deserialize Gemini tool configuration: {error}"
417 ))
418 })?
419 };
420
421 Ok((gemini_tools, tool_config))
422 }
423
424 fn stream_chunks_from_response(
425 mut response: LlmResponse,
426 saw_partial_chunk: bool,
427 ) -> (Vec<LlmResponse>, bool) {
428 let is_final = response.finish_reason.is_some();
429
430 if !is_final {
431 response.partial = true;
432 response.turn_complete = false;
433 return (vec![response], true);
434 }
435
436 response.partial = false;
437 response.turn_complete = true;
438
439 if saw_partial_chunk {
440 return (vec![response], true);
441 }
442
443 let synthetic_partial = LlmResponse {
444 content: None,
445 usage_metadata: None,
446 finish_reason: None,
447 citation_metadata: None,
448 partial: true,
449 turn_complete: false,
450 interrupted: false,
451 error_code: None,
452 error_message: None,
453 provider_metadata: None,
454 };
455
456 (vec![synthetic_partial, response], true)
457 }
458
    /// Builds a Gemini request from `req`, executes it, and returns the outcome
    /// as a stream of [`LlmResponse`] values.
    ///
    /// * Contents with role `"user"`, `"model"`, and `"function"` are translated
    ///   into Gemini messages; contents with any other role are skipped.
    /// * `req.config`, when present, is mapped to a generation config and an
    ///   optional cached-content handle.
    /// * `req.tools`, when non-empty, is translated with `build_gemini_tools`.
    /// * With `stream == true` the upstream stream is adapted chunk by chunk;
    ///   otherwise a single response is fetched and wrapped in a one-item stream.
    ///
    /// # Errors
    /// Returns an error when tool translation fails, when the request itself
    /// fails (mapped through `gemini_error_to_adk`), or, in the non-streaming
    /// case, when the response cannot be converted. Failures that occur while
    /// streaming are yielded as `Err` items of the returned stream.
    async fn generate_content_internal(
        &self,
        req: LlmRequest,
        stream: bool,
    ) -> Result<LlmResponseStream> {
        let mut builder = self.client.generate_content();

        // Index thought signatures by function name from earlier model turns so
        // they can be re-attached to the matching function responses below. A
        // later call with the same name overwrites an earlier one.
        let mut fn_call_signatures: std::collections::HashMap<String, String> =
            std::collections::HashMap::new();
        for content in &req.contents {
            if content.role == "model" {
                for part in &content.parts {
                    if let Part::FunctionCall { name, thought_signature: Some(sig), .. } = part {
                        fn_call_signatures.insert(name.clone(), sig.clone());
                    }
                }
            }
        }

        for content in &req.contents {
            match content.role.as_str() {
                "user" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            Part::InlineData { data, mime_type } => {
                                // Raw bytes are sent to Gemini as base64 text.
                                let encoded = attachment::encode_base64(data);
                                gemini_parts.push(adk_gemini::Part::InlineData {
                                    inline_data: adk_gemini::Blob {
                                        mime_type: mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            Part::FileData { mime_type, file_uri } => {
                                // File references in user turns are sent as a text part
                                // produced by `attachment::file_attachment_to_text`.
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: attachment::file_attachment_to_text(mime_type, file_uri),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            // Other part kinds are not sent in user turns.
                            _ => {}
                        }
                    }
                    // Skip the message entirely when no part survived translation.
                    if !gemini_parts.is_empty() {
                        let user_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: user_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                "model" => {
                    let mut gemini_parts = Vec::new();
                    for part in &content.parts {
                        match part {
                            Part::Text { text } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: text.clone(),
                                    thought: None,
                                    thought_signature: None,
                                });
                            }
                            Part::Thinking { thinking, signature } => {
                                gemini_parts.push(adk_gemini::Part::Text {
                                    text: thinking.clone(),
                                    thought: Some(true),
                                    thought_signature: signature.clone(),
                                });
                            }
                            Part::FunctionCall { name, args, thought_signature, id } => {
                                // The signature is set on the enclosing part; the
                                // nested call's own signature field is left empty.
                                gemini_parts.push(adk_gemini::Part::FunctionCall {
                                    function_call: adk_gemini::FunctionCall {
                                        name: name.clone(),
                                        args: args.clone(),
                                        id: id.clone(),
                                        thought_signature: None,
                                    },
                                    thought_signature: thought_signature.clone(),
                                });
                            }
                            Part::ServerToolCall { server_tool_call } => {
                                // Prefer replaying the stored JSON as the native part it
                                // was serialized from, but only if it is a call-like variant.
                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
                                    server_tool_call.clone(),
                                ) {
                                    match native_part {
                                        adk_gemini::Part::ToolCall { .. }
                                        | adk_gemini::Part::ExecutableCode { .. } => {
                                            gemini_parts.push(native_part);
                                            continue;
                                        }
                                        _ => {}
                                    }
                                }

                                // Fallback: wrap the raw JSON as a tool call, lifting a
                                // top-level `thoughtSignature` when one is present.
                                gemini_parts.push(adk_gemini::Part::ToolCall {
                                    tool_call: server_tool_call.clone(),
                                    thought_signature: Self::gemini_part_thought_signature(
                                        server_tool_call,
                                    ),
                                });
                            }
                            Part::ServerToolResponse { server_tool_response } => {
                                // Same strategy as for server tool calls, accepting only
                                // response-like native variants.
                                if let Ok(native_part) = serde_json::from_value::<adk_gemini::Part>(
                                    server_tool_response.clone(),
                                ) {
                                    match native_part {
                                        adk_gemini::Part::ToolResponse { .. }
                                        | adk_gemini::Part::CodeExecutionResult { .. } => {
                                            gemini_parts.push(native_part);
                                            continue;
                                        }
                                        _ => {}
                                    }
                                }

                                gemini_parts.push(adk_gemini::Part::ToolResponse {
                                    tool_response: server_tool_response.clone(),
                                    thought_signature: Self::gemini_part_thought_signature(
                                        server_tool_response,
                                    ),
                                });
                            }
                            // Other part kinds are not sent in model turns.
                            _ => {}
                        }
                    }
                    if !gemini_parts.is_empty() {
                        let model_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::Model),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: model_content,
                            role: adk_gemini::Role::Model,
                        });
                    }
                }
                "function" => {
                    let mut gemini_parts = Vec::new();
                    // Only function-response parts are forwarded for this role.
                    for part in &content.parts {
                        if let Part::FunctionResponse { function_response, .. } = part {
                            // Re-attach the signature recorded for a call of the same name.
                            let sig = fn_call_signatures.get(&function_response.name).cloned();

                            // Attachments ride along as nested function-response parts:
                            // inline data first (base64-encoded), then file references.
                            let mut fr_parts = Vec::new();
                            for inline in &function_response.inline_data {
                                let encoded = attachment::encode_base64(&inline.data);
                                fr_parts.push(adk_gemini::FunctionResponsePart::InlineData {
                                    inline_data: adk_gemini::Blob {
                                        mime_type: inline.mime_type.clone(),
                                        data: encoded,
                                    },
                                });
                            }
                            for file in &function_response.file_data {
                                fr_parts.push(adk_gemini::FunctionResponsePart::FileData {
                                    file_data: adk_gemini::FileDataRef {
                                        mime_type: file.mime_type.clone(),
                                        file_uri: file.file_uri.clone(),
                                    },
                                });
                            }

                            // Non-object payloads are wrapped as `{ "result": ... }`.
                            let mut gemini_fr = adk_gemini::tools::FunctionResponse::new(
                                &function_response.name,
                                Self::gemini_function_response_payload(
                                    function_response.response.clone(),
                                ),
                            );
                            gemini_fr.parts = fr_parts;

                            gemini_parts.push(adk_gemini::Part::FunctionResponse {
                                function_response: gemini_fr,
                                thought_signature: sig,
                            });
                        }
                    }
                    if !gemini_parts.is_empty() {
                        // Function results are sent to Gemini under the user role.
                        let fn_content = adk_gemini::Content {
                            role: Some(adk_gemini::Role::User),
                            parts: Some(gemini_parts),
                        };
                        builder = builder.with_message(adk_gemini::Message {
                            content: fn_content,
                            role: adk_gemini::Role::User,
                        });
                    }
                }
                // Contents with any other role are not forwarded.
                _ => {}
            }
        }

        if let Some(config) = req.config {
            // A response schema implies JSON output, so the MIME type is set with it.
            let has_schema = config.response_schema.is_some();
            let gen_config = adk_gemini::GenerationConfig {
                temperature: config.temperature,
                top_p: config.top_p,
                top_k: config.top_k,
                max_output_tokens: config.max_output_tokens,
                response_schema: config.response_schema,
                response_mime_type: if has_schema {
                    Some("application/json".to_string())
                } else {
                    None
                },
                ..Default::default()
            };
            builder = builder.with_generation_config(gen_config);

            if let Some(ref name) = config.cached_content {
                // Resolve the cache name to a handle and attach it to the request.
                let handle = self.client.get_cached_content(name);
                builder = builder.with_cached_content(&handle);
            }
        }

        if !req.tools.is_empty() {
            let (gemini_tools, tool_config) = Self::build_gemini_tools(&req.tools)?;
            for tool in gemini_tools {
                builder = builder.with_tool(tool);
            }
            // Only send a tool config when it differs from the default.
            if tool_config != adk_gemini::ToolConfig::default() {
                builder = builder.with_tool_config(tool_config);
            }
        }

        if stream {
            adk_telemetry::debug!("Executing streaming request");
            let response_stream = builder.execute_stream().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                gemini_error_to_adk(&e)
            })?;

            let mapped_stream = async_stream::stream! {
                let mut stream = response_stream;
                // Tracks whether a partial chunk has been emitted yet; see
                // `stream_chunks_from_response`.
                let mut saw_partial_chunk = false;
                // `transpose` turns `Result<Option<_>, _>` into `Option<Result<_, _>>`,
                // so the loop ends once the upstream stream reports completion.
                while let Some(result) = stream.try_next().await.transpose() {
                    match result {
                        Ok(resp) => {
                            match Self::convert_response(&resp) {
                                Ok(llm_resp) => {
                                    let (chunks, next_saw_partial) =
                                        Self::stream_chunks_from_response(llm_resp, saw_partial_chunk);
                                    saw_partial_chunk = next_saw_partial;
                                    for chunk in chunks {
                                        yield Ok(chunk);
                                    }
                                }
                                Err(e) => {
                                    // Conversion errors are yielded; the loop keeps polling.
                                    adk_telemetry::error!(error = %e, "Failed to convert response");
                                    yield Err(e);
                                }
                            }
                        }
                        Err(e) => {
                            // Transport errors are mapped and yielded; the loop keeps polling.
                            adk_telemetry::error!(error = %e, "Stream error");
                            yield Err(gemini_error_to_adk(&e));
                        }
                    }
                }
            };

            Ok(Box::pin(mapped_stream))
        } else {
            adk_telemetry::debug!("Executing blocking request");
            let response = builder.execute().await.map_err(|e| {
                adk_telemetry::error!(error = %e, "Model request failed");
                gemini_error_to_adk(&e)
            })?;

            let llm_response = Self::convert_response(&response)?;

            // Wrap the single response so both modes share one return type.
            let stream = async_stream::stream! {
                yield Ok(llm_response);
            };

            Ok(Box::pin(stream))
        }
    }
766
767 pub async fn create_cached_content(
772 &self,
773 system_instruction: &str,
774 tools: &std::collections::HashMap<String, serde_json::Value>,
775 ttl_seconds: u32,
776 ) -> Result<String> {
777 let mut cache_builder = self
778 .client
779 .create_cache()
780 .with_system_instruction(system_instruction)
781 .with_ttl(std::time::Duration::from_secs(u64::from(ttl_seconds)));
782
783 let (gemini_tools, tool_config) = Self::build_gemini_tools(tools)?;
784 if !gemini_tools.is_empty() {
785 cache_builder = cache_builder.with_tools(gemini_tools);
786 }
787 if tool_config != adk_gemini::ToolConfig::default() {
788 cache_builder = cache_builder.with_tool_config(tool_config);
789 }
790
791 let handle = cache_builder
792 .execute()
793 .await
794 .map_err(|e| adk_core::AdkError::model(format!("cache creation failed: {e}")))?;
795
796 Ok(handle.name().to_string())
797 }
798
799 pub async fn delete_cached_content(&self, name: &str) -> Result<()> {
801 let handle = self.client.get_cached_content(name);
802 handle
803 .delete()
804 .await
805 .map_err(|(_, e)| adk_core::AdkError::model(format!("cache deletion failed: {e}")))?;
806 Ok(())
807 }
808}
809
810#[async_trait]
811impl Llm for GeminiModel {
812 fn name(&self) -> &str {
813 &self.model_name
814 }
815
816 #[adk_telemetry::instrument(
817 name = "call_llm",
818 skip(self, req),
819 fields(
820 model.name = %self.model_name,
821 stream = %stream,
822 request.contents_count = %req.contents.len(),
823 request.tools_count = %req.tools.len()
824 )
825 )]
826 async fn generate_content(&self, req: LlmRequest, stream: bool) -> Result<LlmResponseStream> {
827 adk_telemetry::info!("Generating content");
828 let usage_span = adk_telemetry::llm_generate_span("gemini", &self.model_name, stream);
829 let result = execute_with_retry(&self.retry_config, is_retryable_model_error, || {
832 self.generate_content_internal(req.clone(), stream)
833 })
834 .await?;
835 Ok(crate::usage_tracking::with_usage_tracking(result, usage_span))
836 }
837}
838
#[cfg(test)]
mod native_tool_tests {
    use super::*;

    type ToolMap = std::collections::HashMap<String, serde_json::Value>;

    /// Builds a tool map from `(name, declaration)` pairs.
    fn tool_map(entries: Vec<(&str, serde_json::Value)>) -> ToolMap {
        entries.into_iter().map(|(name, decl)| (name.to_string(), decl)).collect()
    }

    /// Declaration of the provider-native `google_search` tool.
    fn google_search_decl() -> serde_json::Value {
        serde_json::json!({
            "x-adk-gemini-tool": {
                "google_search": {}
            }
        })
    }

    /// Declaration of a plain `lookup_weather` function tool.
    fn lookup_weather_decl() -> serde_json::Value {
        serde_json::json!({
            "name": "lookup_weather",
            "description": "lookup weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": { "type": "string" }
                }
            }
        })
    }

    /// A native tool and a function tool together yield two Gemini tools and
    /// enable server-side tool invocations.
    #[test]
    fn test_build_gemini_tools_supports_native_tool_metadata() {
        let tools = tool_map(vec![
            ("google_search", google_search_decl()),
            ("lookup_weather", lookup_weather_decl()),
        ]);

        let (gemini_tools, tool_config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        assert_eq!(gemini_tools.len(), 2);
        assert_eq!(tool_config.include_server_side_tool_invocations, Some(true));
    }

    /// The server-side invocation flag is set even without function tools.
    #[test]
    fn test_build_gemini_tools_sets_flag_for_builtin_only() {
        let tools = tool_map(vec![("google_search", google_search_decl())]);

        let (_gemini_tools, tool_config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        assert_eq!(
            tool_config.include_server_side_tool_invocations,
            Some(true),
            "includeServerSideToolInvocations should be set even with only built-in tools"
        );
    }

    /// Function-only tool sets leave the server-side invocation flag unset.
    #[test]
    fn test_build_gemini_tools_no_flag_for_function_only() {
        let tools = tool_map(vec![("lookup_weather", lookup_weather_decl())]);

        let (_gemini_tools, tool_config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        assert_eq!(
            tool_config.include_server_side_tool_invocations, None,
            "includeServerSideToolInvocations should NOT be set for function-only tools"
        );
    }

    /// A per-tool `x-adk-gemini-tool-config` object ends up in the tool config.
    #[test]
    fn test_build_gemini_tools_merges_native_tool_config() {
        let lat_lng = serde_json::json!({
            "latLng": {
                "latitude": 1.23,
                "longitude": 4.56
            }
        });
        let tools = tool_map(vec![(
            "google_maps",
            serde_json::json!({
                "x-adk-gemini-tool": {
                    "google_maps": {
                        "enable_widget": true
                    }
                },
                "x-adk-gemini-tool-config": {
                    "retrievalConfig": lat_lng.clone()
                }
            }),
        )]);

        let (_gemini_tools, tool_config) =
            GeminiModel::build_gemini_tools(&tools).expect("tool conversion should succeed");

        assert_eq!(tool_config.retrieval_config, Some(lat_lng));
    }
}
959
/// `CacheCapable` implementation that forwards to the inherent cached-content
/// methods of [`GeminiModel`].
#[async_trait]
impl CacheCapable for GeminiModel {
    /// Creates a cache entry by delegating to `create_cached_content` and
    /// returns the name it produces.
    async fn create_cache(
        &self,
        system_instruction: &str,
        tools: &std::collections::HashMap<String, serde_json::Value>,
        ttl_seconds: u32,
    ) -> Result<String> {
        self.create_cached_content(system_instruction, tools, ttl_seconds).await
    }

    /// Deletes a cache entry by delegating to `delete_cached_content`.
    async fn delete_cache(&self, name: &str) -> Result<()> {
        self.delete_cached_content(name).await
    }
}
975
976#[cfg(test)]
977mod tests {
978 use super::*;
979 use adk_core::AdkError;
980 use std::{
981 sync::{
982 Arc,
983 atomic::{AtomicU32, Ordering},
984 },
985 time::Duration,
986 };
987
    /// Compile-time check that `GeminiModel::new` is a synchronous function
    /// callable with two `&str` arguments and returning `Result<GeminiModel>`.
    #[test]
    fn constructor_is_backward_compatible_and_sync() {
        // Accepts any callable with the expected synchronous signature; it is
        // never invoked, so no client is actually constructed.
        fn accepts_sync_constructor<F>(_f: F)
        where
            F: Fn(&str, &str) -> Result<GeminiModel>,
        {
        }

        accepts_sync_constructor(|api_key, model| GeminiModel::new(api_key, model));
    }
998
999 #[test]
1000 fn stream_chunks_from_response_injects_partial_before_lone_final_chunk() {
1001 let response = LlmResponse {
1002 content: Some(Content::new("model").with_text("hello")),
1003 usage_metadata: None,
1004 finish_reason: Some(FinishReason::Stop),
1005 citation_metadata: None,
1006 partial: false,
1007 turn_complete: true,
1008 interrupted: false,
1009 error_code: None,
1010 error_message: None,
1011 provider_metadata: None,
1012 };
1013
1014 let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, false);
1015 assert!(saw_partial);
1016 assert_eq!(chunks.len(), 2);
1017 assert!(chunks[0].partial);
1018 assert!(!chunks[0].turn_complete);
1019 assert!(chunks[0].content.is_none());
1020 assert!(!chunks[1].partial);
1021 assert!(chunks[1].turn_complete);
1022 }
1023
1024 #[test]
1025 fn stream_chunks_from_response_keeps_final_only_when_partial_already_seen() {
1026 let response = LlmResponse {
1027 content: Some(Content::new("model").with_text("done")),
1028 usage_metadata: None,
1029 finish_reason: Some(FinishReason::Stop),
1030 citation_metadata: None,
1031 partial: false,
1032 turn_complete: true,
1033 interrupted: false,
1034 error_code: None,
1035 error_message: None,
1036 provider_metadata: None,
1037 };
1038
1039 let (chunks, saw_partial) = GeminiModel::stream_chunks_from_response(response, true);
1040 assert!(saw_partial);
1041 assert_eq!(chunks.len(), 1);
1042 assert!(!chunks[0].partial);
1043 assert!(chunks[0].turn_complete);
1044 }
1045
1046 #[tokio::test]
1047 async fn execute_with_retry_retries_retryable_errors() {
1048 let retry_config = RetryConfig::default()
1049 .with_max_retries(2)
1050 .with_initial_delay(Duration::from_millis(0))
1051 .with_max_delay(Duration::from_millis(0));
1052 let attempts = Arc::new(AtomicU32::new(0));
1053
1054 let result = execute_with_retry(&retry_config, is_retryable_model_error, || {
1055 let attempts = Arc::clone(&attempts);
1056 async move {
1057 let attempt = attempts.fetch_add(1, Ordering::SeqCst);
1058 if attempt < 2 {
1059 return Err(AdkError::model("code 429 RESOURCE_EXHAUSTED"));
1060 }
1061 Ok("ok")
1062 }
1063 })
1064 .await
1065 .expect("retry should eventually succeed");
1066
1067 assert_eq!(result, "ok");
1068 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1069 }
1070
1071 #[tokio::test]
1072 async fn execute_with_retry_does_not_retry_non_retryable_errors() {
1073 let retry_config = RetryConfig::default()
1074 .with_max_retries(3)
1075 .with_initial_delay(Duration::from_millis(0))
1076 .with_max_delay(Duration::from_millis(0));
1077 let attempts = Arc::new(AtomicU32::new(0));
1078
1079 let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
1080 let attempts = Arc::clone(&attempts);
1081 async move {
1082 attempts.fetch_add(1, Ordering::SeqCst);
1083 Err::<(), _>(AdkError::model("code 400 invalid request"))
1084 }
1085 })
1086 .await
1087 .expect_err("non-retryable error should return immediately");
1088
1089 assert!(error.is_model());
1090 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1091 }
1092
1093 #[tokio::test]
1094 async fn execute_with_retry_respects_disabled_config() {
1095 let retry_config = RetryConfig::disabled().with_max_retries(10);
1096 let attempts = Arc::new(AtomicU32::new(0));
1097
1098 let error = execute_with_retry(&retry_config, is_retryable_model_error, || {
1099 let attempts = Arc::clone(&attempts);
1100 async move {
1101 attempts.fetch_add(1, Ordering::SeqCst);
1102 Err::<(), _>(AdkError::model("code 429 RESOURCE_EXHAUSTED"))
1103 }
1104 })
1105 .await
1106 .expect_err("disabled retries should return first error");
1107
1108 assert!(error.is_model());
1109 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1110 }
1111
    /// Citation sources on the first candidate are carried over into the
    /// converted response's citation metadata.
    #[test]
    fn convert_response_preserves_citation_metadata() {
        // Single-candidate response with one text part and one citation source.
        let response = adk_gemini::GenerationResponse {
            candidates: vec![adk_gemini::Candidate {
                content: adk_gemini::Content {
                    role: Some(adk_gemini::Role::Model),
                    parts: Some(vec![adk_gemini::Part::Text {
                        text: "hello world".to_string(),
                        thought: None,
                        thought_signature: None,
                    }]),
                },
                safety_ratings: None,
                citation_metadata: Some(adk_gemini::CitationMetadata {
                    citation_sources: vec![adk_gemini::CitationSource {
                        uri: Some("https://example.com".to_string()),
                        title: Some("Example".to_string()),
                        start_index: Some(0),
                        end_index: Some(5),
                        license: Some("CC-BY".to_string()),
                        publication_date: None,
                    }],
                }),
                grounding_metadata: None,
                finish_reason: Some(adk_gemini::FinishReason::Stop),
                index: Some(0),
            }],
            prompt_feedback: None,
            usage_metadata: None,
            model_version: None,
            response_id: None,
        };

        let converted =
            GeminiModel::convert_response(&response).expect("conversion should succeed");
        let metadata = converted.citation_metadata.expect("citation metadata should be mapped");
        // The URI and index range must survive the conversion unchanged.
        assert_eq!(metadata.citation_sources.len(), 1);
        assert_eq!(metadata.citation_sources[0].uri.as_deref(), Some("https://example.com"));
        assert_eq!(metadata.citation_sources[0].start_index, Some(0));
        assert_eq!(metadata.citation_sources[0].end_index, Some(5));
    }
1153
/// A model response mixing a text part with base64 inline data must convert
/// into a text `Part` plus an inline-data `Part` holding the decoded bytes.
#[test]
fn convert_response_handles_inline_data_from_model() {
    let image_bytes = vec![0x89, 0x50, 0x4E, 0x47];

    let caption_part = adk_gemini::Part::Text {
        text: String::from("Here is the image"),
        thought: None,
        thought_signature: None,
    };
    // The Gemini wire type carries the payload as a base64 string.
    let image_part = adk_gemini::Part::InlineData {
        inline_data: adk_gemini::Blob {
            mime_type: String::from("image/png"),
            data: crate::attachment::encode_base64(&image_bytes),
        },
    };
    let candidate = adk_gemini::Candidate {
        content: adk_gemini::Content {
            role: Some(adk_gemini::Role::Model),
            parts: Some(vec![caption_part, image_part]),
        },
        safety_ratings: None,
        citation_metadata: None,
        grounding_metadata: None,
        finish_reason: Some(adk_gemini::FinishReason::Stop),
        index: Some(0),
    };
    let response = adk_gemini::GenerationResponse {
        candidates: vec![candidate],
        prompt_feedback: None,
        usage_metadata: None,
        model_version: None,
        response_id: None,
    };

    let converted =
        GeminiModel::convert_response(&response).expect("conversion should succeed");
    let content = converted.content.expect("should have content");

    let has_caption = content.parts.iter().any(|part| match part {
        Part::Text { text } => text == "Here is the image",
        _ => false,
    });
    assert!(has_caption);

    // The converted part must expose the raw bytes, not the base64 text.
    let has_image = content.parts.iter().any(|part| match part {
        Part::InlineData { mime_type, data } => {
            mime_type == "image/png" && data.as_slice() == image_bytes.as_slice()
        }
        _ => false,
    });
    assert!(has_image);
}
1206
/// A JSON object passed to `gemini_function_response_payload` must come back
/// unchanged.
#[test]
fn gemini_function_response_payload_preserves_objects() {
    let input = serde_json::json!({
        "documents": [
            { "id": "pricing", "score": 0.91 }
        ]
    });
    // Keep a copy for comparison; the payload function consumes its argument.
    let expected = input.clone();

    let actual = GeminiModel::gemini_function_response_payload(input);

    assert_eq!(actual, expected);
}
1219
/// A top-level JSON array must be wrapped in an object under the `"result"`
/// key by `gemini_function_response_payload`.
#[test]
fn gemini_function_response_payload_wraps_arrays() {
    let input = serde_json::json!([{ "id": "pricing" }]);
    let expected = serde_json::json!({ "result": [{ "id": "pricing" }] });

    let actual = GeminiModel::gemini_function_response_payload(input);

    assert_eq!(actual, expected);
}
1227
1228 fn convert_function_response_to_gemini_fr(
1233 frd: &adk_core::FunctionResponseData,
1234 ) -> adk_gemini::tools::FunctionResponse {
1235 let mut fr_parts = Vec::new();
1236
1237 for inline in &frd.inline_data {
1238 let encoded = crate::attachment::encode_base64(&inline.data);
1239 fr_parts.push(adk_gemini::FunctionResponsePart::InlineData {
1240 inline_data: adk_gemini::Blob {
1241 mime_type: inline.mime_type.clone(),
1242 data: encoded,
1243 },
1244 });
1245 }
1246
1247 for file in &frd.file_data {
1248 fr_parts.push(adk_gemini::FunctionResponsePart::FileData {
1249 file_data: adk_gemini::FileDataRef {
1250 mime_type: file.mime_type.clone(),
1251 file_uri: file.file_uri.clone(),
1252 },
1253 });
1254 }
1255
1256 let mut gemini_fr = adk_gemini::tools::FunctionResponse::new(
1257 &frd.name,
1258 GeminiModel::gemini_function_response_payload(frd.response.clone()),
1259 );
1260 gemini_fr.parts = fr_parts;
1261 gemini_fr
1262 }
1263
/// A function response carrying only JSON yields no nested parts, and its
/// serialized form does not contain a `"parts"` key.
#[test]
fn json_only_function_response_has_no_nested_parts() {
    let payload = serde_json::json!({"ok": true});
    let frd = adk_core::FunctionResponseData::new("tool", payload);

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert!(gemini_fr.parts.is_empty());

    // The serialized JSON must not mention the parts field at all.
    let serialized = serde_json::to_string(&gemini_fr).unwrap();
    let mentions_parts = serialized.contains("\"parts\"");
    assert!(!mentions_parts);
}
1273
/// Inline data on a function response becomes a single nested `InlineData`
/// part whose base64 payload decodes back to the original bytes.
#[test]
fn function_response_with_inline_data_has_nested_parts() {
    let attachment = adk_core::InlineDataPart {
        mime_type: String::from("image/png"),
        data: vec![0x89, 0x50, 0x4E, 0x47],
    };
    let frd = adk_core::FunctionResponseData::with_inline_data(
        "chart",
        serde_json::json!({"status": "ok"}),
        vec![attachment],
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 1);

    let other = &gemini_fr.parts[0];
    let adk_gemini::FunctionResponsePart::InlineData { inline_data } = other else {
        panic!("expected InlineData, got {other:?}");
    };
    assert_eq!(inline_data.mime_type, "image/png");

    // Round-trip the base64 text to confirm the bytes were encoded faithfully.
    let decoded = BASE64_STANDARD.decode(&inline_data.data).unwrap();
    assert_eq!(decoded, vec![0x89, 0x50, 0x4E, 0x47]);
}
1295
/// File data on a function response becomes a single nested `FileData` part
/// with the MIME type and URI carried over unchanged.
#[test]
fn function_response_with_file_data_has_nested_parts() {
    let report = adk_core::FileDataPart {
        mime_type: String::from("application/pdf"),
        file_uri: String::from("gs://bucket/report.pdf"),
    };
    let frd = adk_core::FunctionResponseData::with_file_data(
        "doc",
        serde_json::json!({"ok": true}),
        vec![report],
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 1);

    let other = &gemini_fr.parts[0];
    let adk_gemini::FunctionResponsePart::FileData { file_data } = other else {
        panic!("expected FileData, got {other:?}");
    };
    assert_eq!(file_data.mime_type, "application/pdf");
    assert_eq!(file_data.file_uri, "gs://bucket/report.pdf");
}
1316
/// With both inline and file attachments, the nested parts list all inline
/// entries first and the file entries afterwards.
#[test]
fn function_response_with_both_inline_and_file_data_ordering() {
    let inline_attachments = vec![
        adk_core::InlineDataPart { mime_type: String::from("image/png"), data: vec![1, 2] },
        adk_core::InlineDataPart { mime_type: String::from("image/jpeg"), data: vec![3, 4] },
    ];
    let file_attachments = vec![adk_core::FileDataPart {
        mime_type: String::from("application/pdf"),
        file_uri: String::from("gs://b/f.pdf"),
    }];
    let frd = adk_core::FunctionResponseData::with_multimodal(
        "multi",
        serde_json::json!({}),
        inline_attachments,
        file_attachments,
    );

    let gemini_fr = convert_function_response_to_gemini_fr(&frd);
    assert_eq!(gemini_fr.parts.len(), 3);

    // Expected layout: two inline parts followed by one file part.
    assert!(matches!(
        gemini_fr.parts.as_slice(),
        [
            adk_gemini::FunctionResponsePart::InlineData { .. },
            adk_gemini::FunctionResponsePart::InlineData { .. },
            adk_gemini::FunctionResponsePart::FileData { .. },
        ]
    ));
}
1338}