1use async_trait::async_trait;
7use base64::{Engine as _, engine::general_purpose::STANDARD};
8use rain_engine_core::{
9 AgentAction, AttachmentContent, AttachmentRef, LlmProvider, PlannedSkillCall,
10 ProviderCacheRecord, ProviderContentPart, ProviderDecision, ProviderError, ProviderErrorKind,
11 ProviderRequest, ProviderUsageRecord, SessionRecord,
12};
13use reqwest::{Client, RequestBuilder, StatusCode};
14use serde::{Deserialize, Serialize};
15use serde_json::{Value, json};
16use thiserror::Error;
17
/// Credential used to authenticate requests against the Gemini API.
///
/// `Debug` is implemented by hand so that the secret is never printed: this type is embedded in
/// [`GeminiConfig`], which derives `Debug`, and a derived implementation here would write the raw
/// key or token into any log line or panic message that formats the configuration.
#[derive(Clone)]
pub enum GeminiAuth {
    /// API key, attached to each request as the `key` query parameter.
    ApiKey(String),
    /// Token attached to each request as a bearer `Authorization` header.
    BearerToken(String),
}

impl std::fmt::Debug for GeminiAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the variant name is informative; the payload is always redacted.
        let variant = match self {
            Self::ApiKey(_) => "ApiKey",
            Self::BearerToken(_) => "BearerToken",
        };
        f.debug_tuple(variant)
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
23
24#[derive(Debug, Clone)]
25pub struct GeminiConfig {
26 pub base_url: String,
27 pub auth: GeminiAuth,
28 pub default_request: rain_engine_core::ProviderRequestConfig,
29 pub system_instruction: String,
30 pub provider_name: String,
31 pub embedding_model: String,
32}
33
34impl GeminiConfig {
35 pub fn validated(&self) -> Result<(), GeminiConfigError> {
36 if self.base_url.trim().is_empty() {
37 return Err(GeminiConfigError::Invalid(
38 "base_url must not be empty".to_string(),
39 ));
40 }
41 match &self.auth {
42 GeminiAuth::ApiKey(value) | GeminiAuth::BearerToken(value)
43 if value.trim().is_empty() =>
44 {
45 return Err(GeminiConfigError::Invalid(
46 "auth credential must not be empty".to_string(),
47 ));
48 }
49 _ => {}
50 }
51 Ok(())
52 }
53}
54
55#[derive(Debug, Error)]
56pub enum GeminiConfigError {
57 #[error("{0}")]
58 Invalid(String),
59}
60
61#[derive(Clone)]
62pub struct GeminiProvider {
63 client: Client,
64 config: GeminiConfig,
65}
66
67impl GeminiProvider {
68 pub fn new(config: GeminiConfig) -> Result<Self, GeminiConfigError> {
69 config.validated()?;
70 Ok(Self {
71 client: Client::new(),
72 config,
73 })
74 }
75
76 fn latest_cached_content_id(&self, request: &ProviderRequest) -> Option<String> {
77 request
78 .context
79 .history
80 .iter()
81 .rev()
82 .find_map(|record| match record {
83 SessionRecord::ProviderCache(cache)
84 if cache.provider_name == self.config.provider_name =>
85 {
86 Some(cache.cached_content_id.clone())
87 }
88 _ => None,
89 })
90 }
91
92 async fn count_tokens(
93 &self,
94 model: &str,
95 contents: &[GeminiContent],
96 ) -> Result<usize, ProviderError> {
97 let response = self
98 .authorized(self.client.post(format!(
99 "{}/models/{}:countTokens",
100 self.config.base_url.trim_end_matches('/'),
101 model
102 )))
103 .json(&json!({
104 "contents": contents,
105 }))
106 .send()
107 .await
108 .map_err(|err| {
109 ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
110 })?;
111 if !response.status().is_success() {
112 let status = response.status();
113 let body = response.text().await.unwrap_or_default();
114 return Err(classify_status(status, body));
115 }
116 let body: CountTokensResponse = response.json().await.map_err(|err| {
117 ProviderError::new(ProviderErrorKind::InvalidResponse, err.to_string(), false)
118 })?;
119 Ok(body.total_tokens)
120 }
121
122 async fn create_cached_content(
123 &self,
124 model: &str,
125 tools: &[GeminiToolEnvelope],
126 stable_contents: &[GeminiContent],
127 token_count: usize,
128 ) -> Result<ProviderCacheRecord, ProviderError> {
129 let response = self
130 .authorized(self.client.post(format!(
131 "{}/cachedContents",
132 self.config.base_url.trim_end_matches('/')
133 )))
134 .json(&json!({
135 "model": format!("models/{model}"),
136 "systemInstruction": {
137 "parts": [{ "text": self.config.system_instruction }]
138 },
139 "tools": tools,
140 "contents": stable_contents,
141 }))
142 .send()
143 .await
144 .map_err(|err| {
145 ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
146 })?;
147 if !response.status().is_success() {
148 let status = response.status();
149 let body = response.text().await.unwrap_or_default();
150 return Err(classify_status(status, body));
151 }
152 let body: CreateCachedContentResponse = response.json().await.map_err(|err| {
153 ProviderError::new(ProviderErrorKind::InvalidResponse, err.to_string(), false)
154 })?;
155 Ok(ProviderCacheRecord {
156 provider_name: self.config.provider_name.clone(),
157 cached_content_id: body.name,
158 token_count,
159 cached_at: std::time::SystemTime::now(),
160 })
161 }
162
163 fn authorized(&self, builder: RequestBuilder) -> RequestBuilder {
164 match &self.config.auth {
165 GeminiAuth::ApiKey(key) => builder.query(&[("key", key)]),
166 GeminiAuth::BearerToken(token) => builder.bearer_auth(token),
167 }
168 }
169}
170
171#[async_trait]
172impl LlmProvider for GeminiProvider {
173 async fn generate_action(
174 &self,
175 input: ProviderRequest,
176 ) -> Result<ProviderDecision, ProviderError> {
177 let model = input
178 .config
179 .model
180 .clone()
181 .or_else(|| self.config.default_request.model.clone())
182 .ok_or_else(|| {
183 ProviderError::new(
184 ProviderErrorKind::Configuration,
185 "no model configured for Gemini provider",
186 false,
187 )
188 })?;
189
190 let tools = vec![GeminiToolEnvelope {
191 function_declarations: input
192 .available_skills
193 .iter()
194 .map(|skill| GeminiToolDefinition {
195 name: skill.manifest.name.clone(),
196 description: skill.manifest.description.clone(),
197 parameters: skill.manifest.input_schema.clone(),
198 })
199 .collect(),
200 }];
201 let active_contents = map_provider_contents(&input.contents);
202 let mut cache_record = None;
203 let cached_content = if let Some(existing) = self.latest_cached_content_id(&input) {
204 Some(existing)
205 } else {
206 let token_count = self.count_tokens(&model, &active_contents).await?;
207 if token_count > input.policy.cache_threshold_tokens {
208 let stable_contents = if active_contents.len() > 1 {
210 active_contents[..active_contents.len() - 1].to_vec()
211 } else {
212 active_contents.clone()
213 };
214
215 let created = self
216 .create_cached_content(&model, &tools, &stable_contents, token_count)
217 .await?;
218 let id = created.cached_content_id.clone();
219 cache_record = Some(created);
220 Some(id)
221 } else {
222 None
223 }
224 };
225
226 let request_body = if let Some(cached_content) = &cached_content {
227 let suffix = if !active_contents.is_empty() {
230 vec![active_contents.last().unwrap().clone()]
231 } else {
232 vec![]
233 };
234 json!({
235 "cachedContent": cached_content,
236 "contents": suffix,
237 })
238 } else {
239 json!({
240 "systemInstruction": {
241 "parts": [{ "text": self.config.system_instruction }]
242 },
243 "contents": active_contents,
244 "tools": tools,
245 })
246 };
247
248 let response = self
249 .authorized(self.client.post(format!(
250 "{}/models/{}:generateContent",
251 self.config.base_url.trim_end_matches('/'),
252 model
253 )))
254 .json(&request_body)
255 .send()
256 .await
257 .map_err(|err| {
258 ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
259 })?;
260 if !response.status().is_success() {
261 let status = response.status();
262 let body = response.text().await.unwrap_or_default();
263 return Err(classify_status(status, body));
264 }
265
266 let raw_text = response.text().await.map_err(|err| {
267 ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
268 })?;
269 let body: GenerateContentResponse = serde_json::from_str(&raw_text).map_err(|err| {
270 tracing::error!("Gemini response deserialization failed: {err}\nRaw body: {raw_text}");
271 ProviderError::new(
272 ProviderErrorKind::InvalidResponse,
273 format!("error decoding response body: {err}"),
274 false,
275 )
276 })?;
277 let candidate = body.candidates.into_iter().next().ok_or_else(|| {
278 ProviderError::new(
279 ProviderErrorKind::InvalidResponse,
280 "provider returned no candidates",
281 false,
282 )
283 })?;
284 let content = candidate.content.ok_or_else(|| {
285 let reason = candidate.finish_reason.unwrap_or_else(|| "unknown".into());
286 ProviderError::new(
287 ProviderErrorKind::InvalidResponse,
288 format!("candidate blocked by provider (reason: {reason})"),
289 false,
290 )
291 })?;
292
293 let mut calls = Vec::new();
294 let mut text_parts = Vec::new();
295 for (index, part) in content.parts.into_iter().enumerate() {
296 if part.thought == Some(true) {
298 continue;
299 }
300 if let Some(function_call) = part.function_call {
301 calls.push(PlannedSkillCall {
302 call_id: function_call
303 .id
304 .unwrap_or_else(|| format!("gemini-call-{index}")),
305 name: function_call.name,
306 args: function_call.args.unwrap_or_else(|| json!({})),
307 priority: 0,
308 depends_on: Vec::new(),
309 retry_policy: Default::default(),
310 dry_run: false,
311 });
312 } else if let Some(text) = part.text {
313 text_parts.push(text);
314 }
315 }
316
317 let usage = body.usage_metadata.map(|usage| ProviderUsageRecord {
318 provider_name: self.config.provider_name.clone(),
319 recorded_at: std::time::SystemTime::now(),
320 input_tokens: usage.prompt_token_count,
321 output_tokens: usage.candidates_token_count,
322 estimated_cost_usd: ((usage.prompt_token_count + usage.candidates_token_count) as f64)
323 / 1_000_000.0,
324 cached_content_id: cached_content,
325 });
326
327 let action = if !calls.is_empty() {
328 AgentAction::CallSkills(calls)
329 } else {
330 let joined = text_parts.join("\n");
331 if joined.trim().is_empty() {
332 AgentAction::Yield { reason: None }
333 } else if let Ok(structured) = serde_json::from_str::<StructuredAction>(&joined) {
334 match structured.kind.as_str() {
335 "yield" => AgentAction::Yield {
336 reason: structured.content,
337 },
338 _ => AgentAction::Respond {
339 content: structured.content.unwrap_or_default(),
340 },
341 }
342 } else {
343 AgentAction::Respond { content: joined }
344 }
345 };
346
347 Ok(ProviderDecision {
348 action,
349 usage,
350 cache: cache_record,
351 })
352 }
353}
354
355fn map_provider_contents(contents: &[rain_engine_core::ProviderMessage]) -> Vec<GeminiContent> {
356 contents
357 .iter()
358 .map(|message| GeminiContent {
359 role: match message.role {
360 rain_engine_core::ProviderRole::System => "user".to_string(),
361 rain_engine_core::ProviderRole::User => "user".to_string(),
362 rain_engine_core::ProviderRole::Assistant => "model".to_string(),
363 rain_engine_core::ProviderRole::Tool => "user".to_string(),
364 },
365 parts: message
366 .parts
367 .iter()
368 .flat_map(|part| map_provider_part_with_role(part, &message.role))
369 .collect::<Vec<_>>(),
370 })
371 .collect()
372}
373
374fn map_provider_part_with_role(
375 part: &ProviderContentPart,
376 role: &rain_engine_core::ProviderRole,
377) -> Vec<GeminiPart> {
378 match part {
379 ProviderContentPart::Json(value) if *role == rain_engine_core::ProviderRole::Assistant => {
380 if let Ok(calls) = serde_json::from_value::<Vec<PlannedSkillCall>>(value.clone()) {
381 return calls
382 .into_iter()
383 .map(|c| GeminiPart {
384 text: None,
385 inline_data: None,
386 file_data: None,
387 function_call: Some(FunctionCall {
388 id: Some(c.call_id),
389 name: c.name,
390 args: Some(c.args),
391 }),
392 function_response: None,
393 })
394 .collect();
395 }
396 vec![GeminiPart {
397 text: Some(value.to_string()),
398 inline_data: None,
399 file_data: None,
400 function_call: None,
401 function_response: None,
402 }]
403 }
404 ProviderContentPart::Text(text) => vec![GeminiPart {
405 text: Some(text.clone()),
406 inline_data: None,
407 file_data: None,
408 function_call: None,
409 function_response: None,
410 }],
411 ProviderContentPart::Json(value) => vec![GeminiPart {
412 text: Some(value.to_string()),
413 inline_data: None,
414 file_data: None,
415 function_call: None,
416 function_response: None,
417 }],
418 ProviderContentPart::InlineData(payload) => vec![GeminiPart {
419 text: None,
420 inline_data: Some(InlineData {
421 mime_type: payload.mime_type.clone(),
422 data: STANDARD.encode(&payload.data),
423 }),
424 file_data: None,
425 function_call: None,
426 function_response: None,
427 }],
428 ProviderContentPart::Attachment(attachment) => vec![map_attachment_part(attachment)],
429 ProviderContentPart::ToolResult(result) => vec![GeminiPart {
430 text: None,
431 inline_data: None,
432 file_data: None,
433 function_call: None,
434 function_response: Some(FunctionResponse {
435 name: result.skill_name.clone(),
436 response: json!({
437 "call_id": result.call_id,
438 "output": result.output.as_ref().map_or_else(
439 |err| json!({ "error": err.message }),
440 truncate_tool_output,
441 ),
442 }),
443 }),
444 }],
445 }
446}
447
448fn map_attachment_part(attachment: &AttachmentRef) -> GeminiPart {
449 match &attachment.content {
450 AttachmentContent::Inline { data } => GeminiPart {
451 text: None,
452 inline_data: Some(InlineData {
453 mime_type: attachment.mime_type.clone(),
454 data: STANDARD.encode(data),
455 }),
456 file_data: None,
457 function_call: None,
458 function_response: None,
459 },
460 AttachmentContent::Blob { descriptor } => GeminiPart {
461 text: None,
462 inline_data: None,
463 file_data: Some(FileData {
464 mime_type: attachment.mime_type.clone(),
465 file_uri: descriptor.uri.clone(),
466 }),
467 function_call: None,
468 function_response: None,
469 },
470 }
471}
472
473fn truncate_tool_output(value: &Value) -> Value {
474 match value {
475 Value::String(s) if s.len() > 65536 => {
476 json!(format!(
477 "{}... [TRUNCATED {} bytes for token efficiency]",
478 &s[..65536],
479 s.len() - 65536
480 ))
481 }
482 Value::Object(map) => {
483 let mut new_map = serde_json::Map::new();
484 for (k, v) in map {
485 new_map.insert(k.clone(), truncate_tool_output(v));
486 }
487 Value::Object(new_map)
488 }
489 Value::Array(arr) => Value::Array(arr.iter().map(truncate_tool_output).collect()),
490 _ => value.clone(),
491 }
492}
493
494fn classify_status(status: StatusCode, body: String) -> ProviderError {
495 match status {
496 StatusCode::TOO_MANY_REQUESTS => {
497 ProviderError::new(ProviderErrorKind::RateLimited, body, true)
498 }
499 StatusCode::BAD_REQUEST => {
500 ProviderError::new(ProviderErrorKind::InvalidResponse, body, false)
501 }
502 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
503 ProviderError::new(ProviderErrorKind::Configuration, body, false)
504 }
505 _ if status.is_server_error() => {
506 ProviderError::new(ProviderErrorKind::Transport, body, true)
507 }
508 _ => ProviderError::new(ProviderErrorKind::Internal, body, false),
509 }
510}
511
512#[derive(Debug, Serialize)]
513struct GeminiEmbedRequest {
514 model: String,
515 content: GeminiContent,
516}
517
518#[derive(Debug, Serialize)]
519struct GeminiBatchEmbedRequest {
520 requests: Vec<GeminiEmbedRequest>,
521}
522
523#[derive(Debug, Deserialize)]
524struct GeminiBatchEmbedResponse {
525 embeddings: Vec<GeminiEmbedding>,
526}
527
528#[derive(Debug, Deserialize)]
529struct GeminiEmbedding {
530 values: Vec<f32>,
531}
532
533#[async_trait]
534impl rain_engine_core::EmbeddingProvider for GeminiProvider {
535 async fn generate_embeddings(
536 &self,
537 texts: Vec<String>,
538 ) -> Result<Vec<Vec<f32>>, rain_engine_core::ProviderError> {
539 let model = format!("models/{}", self.config.embedding_model);
540 let requests = texts
541 .into_iter()
542 .map(|text| GeminiEmbedRequest {
543 model: model.clone(),
544 content: GeminiContent {
545 role: "user".to_string(),
546 parts: vec![GeminiPart {
547 text: Some(text),
548 inline_data: None,
549 file_data: None,
550 function_call: None,
551 function_response: None,
552 }],
553 },
554 })
555 .collect::<Vec<_>>();
556
557 let response = self
558 .authorized(self.client.post(format!(
559 "{}/{}:batchEmbedContents",
560 self.config.base_url.trim_end_matches('/'),
561 model
562 )))
563 .json(&GeminiBatchEmbedRequest { requests })
564 .send()
565 .await
566 .map_err(|err| {
567 rain_engine_core::ProviderError::new(
568 rain_engine_core::ProviderErrorKind::Transport,
569 err.to_string(),
570 true,
571 )
572 })?;
573
574 if !response.status().is_success() {
575 let status = response.status();
576 let body = response.text().await.unwrap_or_default();
577 return Err(classify_status(status, body));
578 }
579
580 let body: GeminiBatchEmbedResponse = response.json().await.map_err(|err| {
581 rain_engine_core::ProviderError::new(
582 rain_engine_core::ProviderErrorKind::InvalidResponse,
583 err.to_string(),
584 false,
585 )
586 })?;
587
588 Ok(body.embeddings.into_iter().map(|e| e.values).collect())
589 }
590}
591
592#[derive(Debug, Serialize, Clone)]
593struct GeminiContent {
594 role: String,
595 parts: Vec<GeminiPart>,
596}
597
598#[derive(Debug, Serialize, Clone)]
599struct GeminiPart {
600 #[serde(skip_serializing_if = "Option::is_none")]
601 text: Option<String>,
602 #[serde(rename = "inlineData", skip_serializing_if = "Option::is_none")]
603 inline_data: Option<InlineData>,
604 #[serde(rename = "fileData", skip_serializing_if = "Option::is_none")]
605 file_data: Option<FileData>,
606 #[serde(rename = "functionCall", skip_serializing_if = "Option::is_none")]
607 function_call: Option<FunctionCall>,
608 #[serde(rename = "functionResponse", skip_serializing_if = "Option::is_none")]
609 function_response: Option<FunctionResponse>,
610}
611
612#[derive(Debug, Serialize, Clone)]
613struct InlineData {
614 #[serde(rename = "mimeType")]
615 mime_type: String,
616 data: String,
617}
618
619#[derive(Debug, Serialize, Clone)]
620struct FileData {
621 #[serde(rename = "mimeType")]
622 mime_type: String,
623 #[serde(rename = "fileUri")]
624 file_uri: String,
625}
626
627#[derive(Debug, Serialize, Clone)]
628struct FunctionResponse {
629 name: String,
630 response: Value,
631}
632
633#[derive(Debug, Serialize, Clone)]
634struct GeminiToolEnvelope {
635 #[serde(rename = "functionDeclarations")]
636 function_declarations: Vec<GeminiToolDefinition>,
637}
638
639#[derive(Debug, Serialize, Clone)]
640struct GeminiToolDefinition {
641 name: String,
642 description: String,
643 parameters: Value,
644}
645
646#[derive(Debug, Deserialize)]
647struct CountTokensResponse {
648 #[serde(rename = "totalTokens")]
649 total_tokens: usize,
650}
651
652#[derive(Debug, Deserialize)]
653struct CreateCachedContentResponse {
654 name: String,
655}
656
657#[derive(Debug, Deserialize)]
660#[serde(rename_all = "camelCase")]
661struct GenerateContentResponse {
662 #[serde(default)]
663 candidates: Vec<GenerateCandidate>,
664 #[serde(default)]
665 usage_metadata: Option<UsageMetadata>,
666}
667
668#[derive(Debug, Deserialize)]
669#[serde(rename_all = "camelCase")]
670struct GenerateCandidate {
671 content: Option<GenerateContent>,
673 #[serde(default)]
674 finish_reason: Option<String>,
675}
676
677#[derive(Debug, Deserialize)]
678struct GenerateContent {
679 #[serde(default)]
680 parts: Vec<GeneratePart>,
681}
682
683#[derive(Debug, Deserialize)]
684#[serde(rename_all = "camelCase")]
685struct GeneratePart {
686 text: Option<String>,
687 function_call: Option<FunctionCall>,
688 #[serde(default)]
691 thought: Option<bool>,
692}
693
694#[derive(Debug, Serialize, Deserialize, Clone)]
695struct FunctionCall {
696 id: Option<String>,
697 name: String,
698 args: Option<Value>,
699}
700
701#[allow(dead_code)]
704#[derive(Debug, Deserialize)]
705#[serde(rename_all = "camelCase")]
706struct UsageMetadata {
707 #[serde(default)]
708 prompt_token_count: u64,
709 #[serde(default)]
710 candidates_token_count: u64,
711 #[serde(default)]
712 total_token_count: Option<u64>,
713 #[serde(default)]
715 thoughts_token_count: Option<u64>,
716}
717
718#[derive(Debug, Deserialize)]
719struct StructuredAction {
720 #[serde(rename = "type")]
721 kind: String,
722 content: Option<String>,
723}
724
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Json, Router, extract::State, routing::post};
    use rain_engine_core::{
        AgentContextSnapshot, AgentId, AgentStateSnapshot, AgentTrigger, AttachmentRef,
        EnginePolicy, ProviderMessage, ProviderRequestConfig, ProviderRole, ResourcePolicy,
        SkillDefinition, SkillManifest,
    };
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// State shared with the mock server: every JSON request body, in arrival order.
    #[derive(Clone, Default)]
    struct TestState {
        requests: Arc<Mutex<Vec<Value>>>,
    }

    impl TestState {
        /// Appends a request body to the log.
        fn record(&self, body: Value) {
            self.requests.lock().expect("requests lock").push(body);
        }
    }

    /// Builds a request with a single user message that holds either an inline PNG attachment or
    /// the text "hello". The cache threshold is 10 tokens, so the mock token count of 50 always
    /// triggers cache creation when no cache record is present.
    fn provider_request(with_attachment: bool) -> ProviderRequest {
        let parts = if with_attachment {
            vec![ProviderContentPart::Attachment(AttachmentRef::inline(
                "a1",
                "image/png",
                Some("diagram.png".to_string()),
                vec![1, 2, 3, 4],
            ))]
        } else {
            vec![ProviderContentPart::Text("hello".to_string())]
        };
        let contents = vec![ProviderMessage {
            role: ProviderRole::User,
            parts,
        }];

        let trigger = AgentTrigger::Message {
            user_id: "u".to_string(),
            content: "hello".to_string(),
            attachments: Vec::new(),
        };

        let state = AgentStateSnapshot {
            agent_id: AgentId("s".to_string()),
            profile: None,
            goals: Vec::new(),
            tasks: Vec::new(),
            observations: Vec::new(),
            artifacts: Vec::new(),
            resources: Vec::new(),
            relationships: Vec::new(),
            pending_wake: None,
        };
        let context = AgentContextSnapshot {
            session_id: "s".to_string(),
            granted_scopes: vec!["tool:run".to_string()],
            trigger_id: "t".to_string(),
            idempotency_key: None,
            current_step: 0,
            max_steps: 8,
            history: Vec::new(),
            prior_tool_results: Vec::new(),
            session_cost_usd: 0.0,
            state,
            policy: EnginePolicy::default(),
            active_execution_plan: None,
        };

        let db_fix = SkillDefinition {
            manifest: SkillManifest {
                name: "db_fix".to_string(),
                description: "Fix DB".to_string(),
                input_schema: json!({"type":"object"}),
                required_scopes: vec!["tool:run".to_string()],
                capability_grants: vec![],
                resource_policy: ResourcePolicy::default_for_tools(),
                approval_required: false,
                circuit_breaker_threshold: 0.5,
            },
            executor_kind: "native".to_string(),
        };

        ProviderRequest {
            trigger,
            context,
            available_skills: vec![db_fix],
            config: ProviderRequestConfig {
                model: Some("gemini-1.5-pro".to_string()),
                temperature: None,
                max_tokens: None,
            },
            policy: EnginePolicy {
                cache_threshold_tokens: 10,
                ..EnginePolicy::default()
            },
            contents,
        }
    }

    /// Starts a local mock of the three endpoints the provider calls and returns its base URL
    /// together with the shared request log.
    async fn spawn_test_server() -> (String, TestState) {
        let state = TestState::default();

        let count_tokens = post(
            |State(state): State<TestState>, Json(body): Json<Value>| async move {
                state.record(body);
                Json(json!({"totalTokens": 50}))
            },
        );
        let create_cache = post(
            |State(state): State<TestState>, Json(body): Json<Value>| async move {
                state.record(body);
                Json(json!({"name": "cachedContents/cache-1"}))
            },
        );
        let generate = post(
            |State(state): State<TestState>, Json(body): Json<Value>| async move {
                state.record(body);
                Json(json!({
                    "candidates": [{
                        "content": {
                            "parts": [{
                                "functionCall": {
                                    "id": "fc-1",
                                    "name": "db_fix",
                                    "args": {"apply": true}
                                }
                            }, {
                                "functionCall": {
                                    "id": "fc-2",
                                    "name": "db_fix",
                                    "args": {"apply": false}
                                }
                            }]
                        }
                    }],
                    "usageMetadata": {
                        "promptTokenCount": 123,
                        "candidatesTokenCount": 45
                    }
                }))
            },
        );

        let app = Router::new()
            .route("/models/gemini-1.5-pro:countTokens", count_tokens)
            .route("/cachedContents", create_cache)
            .route("/models/gemini-1.5-pro:generateContent", generate)
            .with_state(state.clone());

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("addr");
        tokio::spawn(async move {
            axum::serve(listener, app).await.expect("server");
        });
        (format!("http://{}", addr), state)
    }

    /// Builds a provider that talks to the mock server at `base_url`.
    fn test_provider(base_url: String) -> GeminiProvider {
        let config = GeminiConfig {
            base_url,
            auth: GeminiAuth::ApiKey("token".to_string()),
            default_request: ProviderRequestConfig::default(),
            system_instruction: "You are helpful".to_string(),
            provider_name: "gemini".to_string(),
            embedding_model: "text-embedding-004".to_string(),
        };
        GeminiProvider::new(config).expect("provider")
    }

    #[tokio::test]
    async fn maps_inline_attachment_and_parallel_calls() {
        let (base_url, state) = spawn_test_server().await;
        let provider = test_provider(base_url);

        let decision = provider
            .generate_action(provider_request(true))
            .await
            .expect("decision");

        match decision.action {
            AgentAction::CallSkills(calls) => assert_eq!(calls.len(), 2),
            other => panic!("expected parallel calls, got {other:?}"),
        }
        assert!(decision.cache.is_some());
        assert!(decision.usage.is_some());

        // The last recorded request is the generateContent call.
        let requests = state.requests.lock().expect("requests");
        let generate = requests.last().expect("generate request");
        let body = generate.to_string();
        assert!(body.contains("inlineData"));
        assert!(body.contains("cachedContents/cache-1"));
    }

    #[tokio::test]
    async fn reuses_existing_cache_without_recreating() {
        let (base_url, state) = spawn_test_server().await;
        let provider = test_provider(base_url);

        let existing_cache = ProviderCacheRecord {
            provider_name: "gemini".to_string(),
            cached_content_id: "cachedContents/existing".to_string(),
            token_count: 99_999,
            cached_at: std::time::SystemTime::now(),
        };
        let mut request = provider_request(false);
        request
            .context
            .history
            .push(SessionRecord::ProviderCache(existing_cache));

        let _ = provider.generate_action(request).await.expect("decision");

        let requests = state.requests.lock().expect("requests");
        let body = requests.last().expect("generate request").to_string();
        assert!(body.contains("cachedContents/existing"));
    }
}