Skip to main content

rain_engine_provider_gemini/
lib.rs

1//! Gemini provider adapter for RainEngine.
2//!
3//! This crate maps provider-neutral content, tool declarations, parallel tool
4//! calls, and cache metadata into Gemini REST requests.
5
6use async_trait::async_trait;
7use base64::{Engine as _, engine::general_purpose::STANDARD};
8use rain_engine_core::{
9    AgentAction, AttachmentContent, AttachmentRef, LlmProvider, PlannedSkillCall,
10    ProviderCacheRecord, ProviderContentPart, ProviderDecision, ProviderError, ProviderErrorKind,
11    ProviderRequest, ProviderUsageRecord, SessionRecord,
12};
13use reqwest::{Client, RequestBuilder, StatusCode};
14use serde::{Deserialize, Serialize};
15use serde_json::{Value, json};
16use thiserror::Error;
17
/// Credential used to authenticate requests sent by the provider.
///
/// `Debug` is implemented by hand instead of derived so that the secret value
/// never ends up in logs, panic messages, or the `Debug` output of types that
/// embed this enum: only the variant name is printed, the value is redacted.
#[derive(Clone)]
pub enum GeminiAuth {
    /// An API key credential.
    ApiKey(String),
    /// A bearer token credential.
    BearerToken(String),
}

impl std::fmt::Debug for GeminiAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::ApiKey(_) => "ApiKey",
            Self::BearerToken(_) => "BearerToken",
        };
        // `format_args!` keeps the placeholder unquoted: `ApiKey([REDACTED])`.
        f.debug_tuple(variant)
            .field(&format_args!("[REDACTED]"))
            .finish()
    }
}
23
/// Static configuration for a [`GeminiProvider`].
#[derive(Debug, Clone)]
pub struct GeminiConfig {
    /// Root URL that endpoint paths are appended to; trailing slashes are
    /// trimmed when request URLs are built.
    pub base_url: String,
    /// Credential attached to every outgoing request.
    pub auth: GeminiAuth,
    /// Fallback request settings; its `model` is used when a request does not
    /// name a model itself.
    pub default_request: rain_engine_core::ProviderRequestConfig,
    /// Text sent as the `systemInstruction` of generate and cache-creation
    /// requests.
    pub system_instruction: String,
    /// Name stamped on usage and cache records, and used to recognise this
    /// provider's cache records in the session history.
    pub provider_name: String,
    /// Embedding model name; embedding requests address it as
    /// `models/{embedding_model}`.
    pub embedding_model: String,
}
33
34impl GeminiConfig {
35    pub fn validated(&self) -> Result<(), GeminiConfigError> {
36        if self.base_url.trim().is_empty() {
37            return Err(GeminiConfigError::Invalid(
38                "base_url must not be empty".to_string(),
39            ));
40        }
41        match &self.auth {
42            GeminiAuth::ApiKey(value) | GeminiAuth::BearerToken(value)
43                if value.trim().is_empty() =>
44            {
45                return Err(GeminiConfigError::Invalid(
46                    "auth credential must not be empty".to_string(),
47                ));
48            }
49            _ => {}
50        }
51        Ok(())
52    }
53}
54
/// Error produced when a [`GeminiConfig`] fails validation.
#[derive(Debug, Error)]
pub enum GeminiConfigError {
    /// The configuration is unusable; the payload is the human-readable
    /// reason and is displayed verbatim.
    #[error("{0}")]
    Invalid(String),
}
60
/// Provider that talks to the Gemini REST API.
///
/// Holds the HTTP client used for every request together with the
/// configuration that was validated in [`GeminiProvider::new`].
#[derive(Clone)]
pub struct GeminiProvider {
    /// HTTP client used for all outgoing requests.
    client: Client,
    /// Validated provider configuration.
    config: GeminiConfig,
}
66
67impl GeminiProvider {
68    pub fn new(config: GeminiConfig) -> Result<Self, GeminiConfigError> {
69        config.validated()?;
70        Ok(Self {
71            client: Client::new(),
72            config,
73        })
74    }
75
76    fn latest_cached_content_id(&self, request: &ProviderRequest) -> Option<String> {
77        request
78            .context
79            .history
80            .iter()
81            .rev()
82            .find_map(|record| match record {
83                SessionRecord::ProviderCache(cache)
84                    if cache.provider_name == self.config.provider_name =>
85                {
86                    Some(cache.cached_content_id.clone())
87                }
88                _ => None,
89            })
90    }
91
92    async fn count_tokens(
93        &self,
94        model: &str,
95        contents: &[GeminiContent],
96    ) -> Result<usize, ProviderError> {
97        let response = self
98            .authorized(self.client.post(format!(
99                "{}/models/{}:countTokens",
100                self.config.base_url.trim_end_matches('/'),
101                model
102            )))
103            .json(&json!({
104                "contents": contents,
105            }))
106            .send()
107            .await
108            .map_err(|err| {
109                ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
110            })?;
111        if !response.status().is_success() {
112            let status = response.status();
113            let body = response.text().await.unwrap_or_default();
114            return Err(classify_status(status, body));
115        }
116        let body: CountTokensResponse = response.json().await.map_err(|err| {
117            ProviderError::new(ProviderErrorKind::InvalidResponse, err.to_string(), false)
118        })?;
119        Ok(body.total_tokens)
120    }
121
122    async fn create_cached_content(
123        &self,
124        model: &str,
125        tools: &[GeminiToolEnvelope],
126        stable_contents: &[GeminiContent],
127        token_count: usize,
128    ) -> Result<ProviderCacheRecord, ProviderError> {
129        let response = self
130            .authorized(self.client.post(format!(
131                "{}/cachedContents",
132                self.config.base_url.trim_end_matches('/')
133            )))
134            .json(&json!({
135                "model": format!("models/{model}"),
136                "systemInstruction": {
137                    "parts": [{ "text": self.config.system_instruction }]
138                },
139                "tools": tools,
140                "contents": stable_contents,
141            }))
142            .send()
143            .await
144            .map_err(|err| {
145                ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
146            })?;
147        if !response.status().is_success() {
148            let status = response.status();
149            let body = response.text().await.unwrap_or_default();
150            return Err(classify_status(status, body));
151        }
152        let body: CreateCachedContentResponse = response.json().await.map_err(|err| {
153            ProviderError::new(ProviderErrorKind::InvalidResponse, err.to_string(), false)
154        })?;
155        Ok(ProviderCacheRecord {
156            provider_name: self.config.provider_name.clone(),
157            cached_content_id: body.name,
158            token_count,
159            cached_at: std::time::SystemTime::now(),
160        })
161    }
162
163    fn authorized(&self, builder: RequestBuilder) -> RequestBuilder {
164        match &self.config.auth {
165            GeminiAuth::ApiKey(key) => builder.query(&[("key", key)]),
166            GeminiAuth::BearerToken(token) => builder.bearer_auth(token),
167        }
168    }
169}
170
171#[async_trait]
172impl LlmProvider for GeminiProvider {
173    async fn generate_action(
174        &self,
175        input: ProviderRequest,
176    ) -> Result<ProviderDecision, ProviderError> {
177        let model = input
178            .config
179            .model
180            .clone()
181            .or_else(|| self.config.default_request.model.clone())
182            .ok_or_else(|| {
183                ProviderError::new(
184                    ProviderErrorKind::Configuration,
185                    "no model configured for Gemini provider",
186                    false,
187                )
188            })?;
189
190        let tools = vec![GeminiToolEnvelope {
191            function_declarations: input
192                .available_skills
193                .iter()
194                .map(|skill| GeminiToolDefinition {
195                    name: skill.manifest.name.clone(),
196                    description: skill.manifest.description.clone(),
197                    parameters: skill.manifest.input_schema.clone(),
198                })
199                .collect(),
200        }];
201        let active_contents = map_provider_contents(&input.contents);
202        let mut cache_record = None;
203        let cached_content = if let Some(existing) = self.latest_cached_content_id(&input) {
204            Some(existing)
205        } else {
206            let token_count = self.count_tokens(&model, &active_contents).await?;
207            if token_count > input.policy.cache_threshold_tokens {
208                // Stable content for caching is the entire history EXCEPT the latest user prompt
209                let stable_contents = if active_contents.len() > 1 {
210                    active_contents[..active_contents.len() - 1].to_vec()
211                } else {
212                    active_contents.clone()
213                };
214
215                let created = self
216                    .create_cached_content(&model, &tools, &stable_contents, token_count)
217                    .await?;
218                let id = created.cached_content_id.clone();
219                cache_record = Some(created);
220                Some(id)
221            } else {
222                None
223            }
224        };
225
226        let request_body = if let Some(cached_content) = &cached_content {
227            // When using a cache, Gemini expects only the NEW turns in the 'contents' array.
228            // Tools and System Instruction are already in the cache and MUST NOT be sent.
229            let suffix = if !active_contents.is_empty() {
230                vec![active_contents.last().unwrap().clone()]
231            } else {
232                vec![]
233            };
234            json!({
235                "cachedContent": cached_content,
236                "contents": suffix,
237            })
238        } else {
239            json!({
240                "systemInstruction": {
241                    "parts": [{ "text": self.config.system_instruction }]
242                },
243                "contents": active_contents,
244                "tools": tools,
245            })
246        };
247
248        let response = self
249            .authorized(self.client.post(format!(
250                "{}/models/{}:generateContent",
251                self.config.base_url.trim_end_matches('/'),
252                model
253            )))
254            .json(&request_body)
255            .send()
256            .await
257            .map_err(|err| {
258                ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
259            })?;
260        if !response.status().is_success() {
261            let status = response.status();
262            let body = response.text().await.unwrap_or_default();
263            return Err(classify_status(status, body));
264        }
265
266        let raw_text = response.text().await.map_err(|err| {
267            ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
268        })?;
269        let body: GenerateContentResponse = serde_json::from_str(&raw_text).map_err(|err| {
270            tracing::error!("Gemini response deserialization failed: {err}\nRaw body: {raw_text}");
271            ProviderError::new(
272                ProviderErrorKind::InvalidResponse,
273                format!("error decoding response body: {err}"),
274                false,
275            )
276        })?;
277        let candidate = body.candidates.into_iter().next().ok_or_else(|| {
278            ProviderError::new(
279                ProviderErrorKind::InvalidResponse,
280                "provider returned no candidates",
281                false,
282            )
283        })?;
284        let content = candidate.content.ok_or_else(|| {
285            let reason = candidate.finish_reason.unwrap_or_else(|| "unknown".into());
286            ProviderError::new(
287                ProviderErrorKind::InvalidResponse,
288                format!("candidate blocked by provider (reason: {reason})"),
289                false,
290            )
291        })?;
292
293        let mut calls = Vec::new();
294        let mut text_parts = Vec::new();
295        for (index, part) in content.parts.into_iter().enumerate() {
296            // Skip internal thinking/reasoning parts from Gemini 2.5+ models
297            if part.thought == Some(true) {
298                continue;
299            }
300            if let Some(function_call) = part.function_call {
301                calls.push(PlannedSkillCall {
302                    call_id: function_call
303                        .id
304                        .unwrap_or_else(|| format!("gemini-call-{index}")),
305                    name: function_call.name,
306                    args: function_call.args.unwrap_or_else(|| json!({})),
307                    priority: 0,
308                    depends_on: Vec::new(),
309                    retry_policy: Default::default(),
310                    dry_run: false,
311                });
312            } else if let Some(text) = part.text {
313                text_parts.push(text);
314            }
315        }
316
317        let usage = body.usage_metadata.map(|usage| ProviderUsageRecord {
318            provider_name: self.config.provider_name.clone(),
319            recorded_at: std::time::SystemTime::now(),
320            input_tokens: usage.prompt_token_count,
321            output_tokens: usage.candidates_token_count,
322            estimated_cost_usd: ((usage.prompt_token_count + usage.candidates_token_count) as f64)
323                / 1_000_000.0,
324            cached_content_id: cached_content,
325        });
326
327        let action = if !calls.is_empty() {
328            AgentAction::CallSkills(calls)
329        } else {
330            let joined = text_parts.join("\n");
331            if joined.trim().is_empty() {
332                AgentAction::Yield { reason: None }
333            } else if let Ok(structured) = serde_json::from_str::<StructuredAction>(&joined) {
334                match structured.kind.as_str() {
335                    "yield" => AgentAction::Yield {
336                        reason: structured.content,
337                    },
338                    _ => AgentAction::Respond {
339                        content: structured.content.unwrap_or_default(),
340                    },
341                }
342            } else {
343                AgentAction::Respond { content: joined }
344            }
345        };
346
347        Ok(ProviderDecision {
348            action,
349            usage,
350            cache: cache_record,
351        })
352    }
353}
354
355fn map_provider_contents(contents: &[rain_engine_core::ProviderMessage]) -> Vec<GeminiContent> {
356    contents
357        .iter()
358        .map(|message| GeminiContent {
359            role: match message.role {
360                rain_engine_core::ProviderRole::System => "user".to_string(),
361                rain_engine_core::ProviderRole::User => "user".to_string(),
362                rain_engine_core::ProviderRole::Assistant => "model".to_string(),
363                rain_engine_core::ProviderRole::Tool => "user".to_string(),
364            },
365            parts: message
366                .parts
367                .iter()
368                .flat_map(|part| map_provider_part_with_role(part, &message.role))
369                .collect::<Vec<_>>(),
370        })
371        .collect()
372}
373
374fn map_provider_part_with_role(
375    part: &ProviderContentPart,
376    role: &rain_engine_core::ProviderRole,
377) -> Vec<GeminiPart> {
378    match part {
379        ProviderContentPart::Json(value) if *role == rain_engine_core::ProviderRole::Assistant => {
380            if let Ok(calls) = serde_json::from_value::<Vec<PlannedSkillCall>>(value.clone()) {
381                return calls
382                    .into_iter()
383                    .map(|c| GeminiPart {
384                        text: None,
385                        inline_data: None,
386                        file_data: None,
387                        function_call: Some(FunctionCall {
388                            id: Some(c.call_id),
389                            name: c.name,
390                            args: Some(c.args),
391                        }),
392                        function_response: None,
393                    })
394                    .collect();
395            }
396            vec![GeminiPart {
397                text: Some(value.to_string()),
398                inline_data: None,
399                file_data: None,
400                function_call: None,
401                function_response: None,
402            }]
403        }
404        ProviderContentPart::Text(text) => vec![GeminiPart {
405            text: Some(text.clone()),
406            inline_data: None,
407            file_data: None,
408            function_call: None,
409            function_response: None,
410        }],
411        ProviderContentPart::Json(value) => vec![GeminiPart {
412            text: Some(value.to_string()),
413            inline_data: None,
414            file_data: None,
415            function_call: None,
416            function_response: None,
417        }],
418        ProviderContentPart::InlineData(payload) => vec![GeminiPart {
419            text: None,
420            inline_data: Some(InlineData {
421                mime_type: payload.mime_type.clone(),
422                data: STANDARD.encode(&payload.data),
423            }),
424            file_data: None,
425            function_call: None,
426            function_response: None,
427        }],
428        ProviderContentPart::Attachment(attachment) => vec![map_attachment_part(attachment)],
429        ProviderContentPart::ToolResult(result) => vec![GeminiPart {
430            text: None,
431            inline_data: None,
432            file_data: None,
433            function_call: None,
434            function_response: Some(FunctionResponse {
435                name: result.skill_name.clone(),
436                response: json!({
437                    "call_id": result.call_id,
438                    "output": result.output.as_ref().map_or_else(
439                        |err| json!({ "error": err.message }),
440                        truncate_tool_output,
441                    ),
442                }),
443            }),
444        }],
445    }
446}
447
448fn map_attachment_part(attachment: &AttachmentRef) -> GeminiPart {
449    match &attachment.content {
450        AttachmentContent::Inline { data } => GeminiPart {
451            text: None,
452            inline_data: Some(InlineData {
453                mime_type: attachment.mime_type.clone(),
454                data: STANDARD.encode(data),
455            }),
456            file_data: None,
457            function_call: None,
458            function_response: None,
459        },
460        AttachmentContent::Blob { descriptor } => GeminiPart {
461            text: None,
462            inline_data: None,
463            file_data: Some(FileData {
464                mime_type: attachment.mime_type.clone(),
465                file_uri: descriptor.uri.clone(),
466            }),
467            function_call: None,
468            function_response: None,
469        },
470    }
471}
472
473fn truncate_tool_output(value: &Value) -> Value {
474    match value {
475        Value::String(s) if s.len() > 65536 => {
476            json!(format!(
477                "{}... [TRUNCATED {} bytes for token efficiency]",
478                &s[..65536],
479                s.len() - 65536
480            ))
481        }
482        Value::Object(map) => {
483            let mut new_map = serde_json::Map::new();
484            for (k, v) in map {
485                new_map.insert(k.clone(), truncate_tool_output(v));
486            }
487            Value::Object(new_map)
488        }
489        Value::Array(arr) => Value::Array(arr.iter().map(truncate_tool_output).collect()),
490        _ => value.clone(),
491    }
492}
493
494fn classify_status(status: StatusCode, body: String) -> ProviderError {
495    match status {
496        StatusCode::TOO_MANY_REQUESTS => {
497            ProviderError::new(ProviderErrorKind::RateLimited, body, true)
498        }
499        StatusCode::BAD_REQUEST => {
500            ProviderError::new(ProviderErrorKind::InvalidResponse, body, false)
501        }
502        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
503            ProviderError::new(ProviderErrorKind::Configuration, body, false)
504        }
505        _ if status.is_server_error() => {
506            ProviderError::new(ProviderErrorKind::Transport, body, true)
507        }
508        _ => ProviderError::new(ProviderErrorKind::Internal, body, false),
509    }
510}
511
/// One entry of a batch embedding request.
#[derive(Debug, Serialize)]
struct GeminiEmbedRequest {
    /// Model identifier; filled with `models/{embedding_model}` by the caller.
    model: String,
    /// The content to embed.
    content: GeminiContent,
}
517
/// Request body posted to the `batchEmbedContents` endpoint.
#[derive(Debug, Serialize)]
struct GeminiBatchEmbedRequest {
    /// One embed request per input text.
    requests: Vec<GeminiEmbedRequest>,
}
522
/// Response body decoded from the `batchEmbedContents` endpoint.
#[derive(Debug, Deserialize)]
struct GeminiBatchEmbedResponse {
    /// Embeddings returned by the service; the field is required when decoding.
    embeddings: Vec<GeminiEmbedding>,
}
527
/// A single embedding vector in a batch embedding response.
#[derive(Debug, Deserialize)]
struct GeminiEmbedding {
    /// The vector components.
    values: Vec<f32>,
}
532
533#[async_trait]
534impl rain_engine_core::EmbeddingProvider for GeminiProvider {
535    async fn generate_embeddings(
536        &self,
537        texts: Vec<String>,
538    ) -> Result<Vec<Vec<f32>>, rain_engine_core::ProviderError> {
539        let model = format!("models/{}", self.config.embedding_model);
540        let requests = texts
541            .into_iter()
542            .map(|text| GeminiEmbedRequest {
543                model: model.clone(),
544                content: GeminiContent {
545                    role: "user".to_string(),
546                    parts: vec![GeminiPart {
547                        text: Some(text),
548                        inline_data: None,
549                        file_data: None,
550                        function_call: None,
551                        function_response: None,
552                    }],
553                },
554            })
555            .collect::<Vec<_>>();
556
557        let response = self
558            .authorized(self.client.post(format!(
559                "{}/{}:batchEmbedContents",
560                self.config.base_url.trim_end_matches('/'),
561                model
562            )))
563            .json(&GeminiBatchEmbedRequest { requests })
564            .send()
565            .await
566            .map_err(|err| {
567                rain_engine_core::ProviderError::new(
568                    rain_engine_core::ProviderErrorKind::Transport,
569                    err.to_string(),
570                    true,
571                )
572            })?;
573
574        if !response.status().is_success() {
575            let status = response.status();
576            let body = response.text().await.unwrap_or_default();
577            return Err(classify_status(status, body));
578        }
579
580        let body: GeminiBatchEmbedResponse = response.json().await.map_err(|err| {
581            rain_engine_core::ProviderError::new(
582                rain_engine_core::ProviderErrorKind::InvalidResponse,
583                err.to_string(),
584                false,
585            )
586        })?;
587
588        Ok(body.embeddings.into_iter().map(|e| e.values).collect())
589    }
590}
591
/// One entry of a Gemini `contents` array: a role plus its parts.
#[derive(Debug, Serialize, Clone)]
struct GeminiContent {
    /// Role string; this file only ever writes `"user"` or `"model"`.
    role: String,
    /// Parts that make up the entry.
    parts: Vec<GeminiPart>,
}
597
598#[derive(Debug, Serialize, Clone)]
599struct GeminiPart {
600    #[serde(skip_serializing_if = "Option::is_none")]
601    text: Option<String>,
602    #[serde(rename = "inlineData", skip_serializing_if = "Option::is_none")]
603    inline_data: Option<InlineData>,
604    #[serde(rename = "fileData", skip_serializing_if = "Option::is_none")]
605    file_data: Option<FileData>,
606    #[serde(rename = "functionCall", skip_serializing_if = "Option::is_none")]
607    function_call: Option<FunctionCall>,
608    #[serde(rename = "functionResponse", skip_serializing_if = "Option::is_none")]
609    function_response: Option<FunctionResponse>,
610}
611
612#[derive(Debug, Serialize, Clone)]
613struct InlineData {
614    #[serde(rename = "mimeType")]
615    mime_type: String,
616    data: String,
617}
618
619#[derive(Debug, Serialize, Clone)]
620struct FileData {
621    #[serde(rename = "mimeType")]
622    mime_type: String,
623    #[serde(rename = "fileUri")]
624    file_uri: String,
625}
626
/// Result of a function call, sent back as a `functionResponse` part.
#[derive(Debug, Serialize, Clone)]
struct FunctionResponse {
    /// Name of the function (skill) the result belongs to.
    name: String,
    /// Arbitrary JSON describing the result.
    response: Value,
}
632
633#[derive(Debug, Serialize, Clone)]
634struct GeminiToolEnvelope {
635    #[serde(rename = "functionDeclarations")]
636    function_declarations: Vec<GeminiToolDefinition>,
637}
638
/// Declaration of a single callable function inside a [`GeminiToolEnvelope`].
#[derive(Debug, Serialize, Clone)]
struct GeminiToolDefinition {
    /// Function name; populated from the skill manifest's name.
    name: String,
    /// Description; populated from the skill manifest's description.
    description: String,
    /// Parameter schema; populated from the skill manifest's input schema.
    parameters: Value,
}
645
646#[derive(Debug, Deserialize)]
647struct CountTokensResponse {
648    #[serde(rename = "totalTokens")]
649    total_tokens: usize,
650}
651
/// Response body decoded from the `cachedContents` creation endpoint.
#[derive(Debug, Deserialize)]
struct CreateCachedContentResponse {
    /// Resource name of the created cache; stored as the cached content id.
    name: String,
}
656
/// Response body decoded from the `generateContent` endpoint.
///
/// Tolerant of extra fields returned by Gemini 2.5+ and 3.x models
/// (e.g. modelVersion, responseId, promptFeedback, etc.): only the two fields
/// below are read, and both fall back to their default when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentResponse {
    /// Returned candidates; empty when the field is missing.
    #[serde(default)]
    candidates: Vec<GenerateCandidate>,
    /// Token usage information (`usageMetadata`), if reported.
    #[serde(default)]
    usage_metadata: Option<UsageMetadata>,
}
667
/// A single candidate of a [`GenerateContentResponse`].
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GenerateCandidate {
    /// Content may be absent when the candidate is blocked by safety filters
    content: Option<GenerateContent>,
    /// Reason generation stopped (`finishReason`), if reported; used in the
    /// error message when `content` is missing.
    #[serde(default)]
    finish_reason: Option<String>,
}
676
/// Content of a candidate: the list of generated parts.
#[derive(Debug, Deserialize)]
struct GenerateContent {
    /// Generated parts; empty when the field is missing.
    #[serde(default)]
    parts: Vec<GeneratePart>,
}
682
/// One generated part of a candidate.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeneratePart {
    /// Text payload, if the part carries text.
    text: Option<String>,
    /// Function call payload (`functionCall`), if the part requests a call.
    function_call: Option<FunctionCall>,
    /// Gemini 2.5+ "thinking" models include a `thought` flag on internal
    /// reasoning parts — we capture it so we can skip those parts.
    #[serde(default)]
    thought: Option<bool>,
}
693
/// A function call, used in both directions: decoded from generated parts and
/// serialized when earlier assistant calls are replayed in the request.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct FunctionCall {
    /// Call identifier; the response decoder substitutes a positional id when
    /// the provider sends none.
    id: Option<String>,
    /// Name of the function to call.
    name: String,
    /// Call arguments; the response decoder substitutes an empty object when
    /// absent.
    args: Option<Value>,
}
700
/// Usage metadata — every field falls back to a default when missing (0 for
/// the two counts, `None` for the optional ones) because field names and
/// availability vary across Gemini model generations.
// `dead_code` is allowed because the two optional counters are decoded but
// not read anywhere in this file.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct UsageMetadata {
    /// Prompt token count (`promptTokenCount`); recorded as input tokens.
    #[serde(default)]
    prompt_token_count: u64,
    /// Candidate token count (`candidatesTokenCount`); recorded as output tokens.
    #[serde(default)]
    candidates_token_count: u64,
    /// Total token count (`totalTokenCount`), if reported.
    #[serde(default)]
    total_token_count: Option<u64>,
    /// Gemini 2.5+ thinking models report thought tokens separately
    #[serde(default)]
    thoughts_token_count: Option<u64>,
}
717
/// JSON action a model may emit as plain text instead of a function call.
///
/// The response decoder tries to parse the joined text parts as this shape;
/// `"yield"` maps to a yield action, any other `type` to a response.
#[derive(Debug, Deserialize)]
struct StructuredAction {
    /// Action discriminator, read from the JSON key `type`.
    #[serde(rename = "type")]
    kind: String,
    /// Optional payload: the yield reason or the response text.
    content: Option<String>,
}
724
/// Integration-style tests that run the provider against a local mock server.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Json, Router, extract::State, routing::post};
    use rain_engine_core::{
        AgentContextSnapshot, AgentId, AgentStateSnapshot, AgentTrigger, AttachmentRef,
        EnginePolicy, ProviderMessage, ProviderRequestConfig, ProviderRole, ResourcePolicy,
        SkillDefinition, SkillManifest,
    };
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// Shared state of the mock server: every received JSON body, in arrival order.
    #[derive(Clone, Default)]
    struct TestState {
        requests: Arc<Mutex<Vec<Value>>>,
    }

    /// Builds a request with a single user message (an inline PNG attachment
    /// or the text "hello"), one `db_fix` skill, model `gemini-1.5-pro`, an
    /// empty history and a cache threshold of 10 tokens.
    fn provider_request(with_attachment: bool) -> ProviderRequest {
        let contents = vec![ProviderMessage {
            role: ProviderRole::User,
            parts: if with_attachment {
                vec![ProviderContentPart::Attachment(AttachmentRef::inline(
                    "a1",
                    "image/png",
                    Some("diagram.png".to_string()),
                    vec![1, 2, 3, 4],
                ))]
            } else {
                vec![ProviderContentPart::Text("hello".to_string())]
            },
        }];
        ProviderRequest {
            trigger: AgentTrigger::Message {
                user_id: "u".to_string(),
                content: "hello".to_string(),
                attachments: Vec::new(),
            },
            context: AgentContextSnapshot {
                session_id: "s".to_string(),
                granted_scopes: vec!["tool:run".to_string()],
                trigger_id: "t".to_string(),
                idempotency_key: None,
                current_step: 0,
                max_steps: 8,
                history: Vec::new(),
                prior_tool_results: Vec::new(),
                session_cost_usd: 0.0,
                state: AgentStateSnapshot {
                    agent_id: AgentId("s".to_string()),
                    profile: None,
                    goals: Vec::new(),
                    tasks: Vec::new(),
                    observations: Vec::new(),
                    artifacts: Vec::new(),
                    resources: Vec::new(),
                    relationships: Vec::new(),
                    pending_wake: None,
                },
                policy: EnginePolicy::default(),
                active_execution_plan: None,
            },
            available_skills: vec![SkillDefinition {
                manifest: SkillManifest {
                    name: "db_fix".to_string(),
                    description: "Fix DB".to_string(),
                    input_schema: json!({"type":"object"}),
                    required_scopes: vec!["tool:run".to_string()],
                    capability_grants: vec![],
                    resource_policy: ResourcePolicy::default_for_tools(),
                    approval_required: false,
                    circuit_breaker_threshold: 0.5,
                },
                executor_kind: "native".to_string(),
            }],
            config: ProviderRequestConfig {
                model: Some("gemini-1.5-pro".to_string()),
                temperature: None,
                max_tokens: None,
            },
            // Low threshold so the mocked count of 50 tokens triggers caching.
            policy: EnginePolicy {
                cache_threshold_tokens: 10,
                ..EnginePolicy::default()
            },
            contents,
        }
    }

    /// Starts a mock server on an ephemeral local port and returns its base
    /// URL together with the state that records request bodies.
    ///
    /// The server answers `countTokens` with 50 tokens, cache creation with
    /// the name `cachedContents/cache-1`, and `generateContent` with two
    /// `db_fix` function calls plus usage metadata.
    async fn spawn_test_server() -> (String, TestState) {
        let state = TestState::default();
        let app = Router::new()
            .route(
                "/models/gemini-1.5-pro:countTokens",
                post(
                    |State(state): State<TestState>, Json(body): Json<Value>| async move {
                        state.requests.lock().expect("requests lock").push(body);
                        Json(json!({"totalTokens": 50}))
                    },
                ),
            )
            .route(
                "/cachedContents",
                post(
                    |State(state): State<TestState>, Json(body): Json<Value>| async move {
                        state.requests.lock().expect("requests lock").push(body);
                        Json(json!({"name": "cachedContents/cache-1"}))
                    },
                ),
            )
            .route(
                "/models/gemini-1.5-pro:generateContent",
                post(
                    |State(state): State<TestState>, Json(body): Json<Value>| async move {
                        state.requests.lock().expect("requests lock").push(body);
                        Json(json!({
                            "candidates": [{
                                "content": {
                                    "parts": [{
                                        "functionCall": {
                                            "id": "fc-1",
                                            "name": "db_fix",
                                            "args": {"apply": true}
                                        }
                                    }, {
                                        "functionCall": {
                                            "id": "fc-2",
                                            "name": "db_fix",
                                            "args": {"apply": false}
                                        }
                                    }]
                                }
                            }],
                            "usageMetadata": {
                                "promptTokenCount": 123,
                                "candidatesTokenCount": 45
                            }
                        }))
                    },
                ),
            )
            .with_state(state.clone());

        // Port 0 lets the OS pick a free port for this test run.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("addr");
        tokio::spawn(async move {
            axum::serve(listener, app).await.expect("server");
        });
        (format!("http://{}", addr), state)
    }

    /// With no cache in the history, the provider should create one, send the
    /// attachment as `inlineData`, and surface both function calls.
    #[tokio::test]
    async fn maps_inline_attachment_and_parallel_calls() {
        let (base_url, state) = spawn_test_server().await;
        let provider = GeminiProvider::new(GeminiConfig {
            base_url,
            auth: GeminiAuth::ApiKey("token".to_string()),
            default_request: ProviderRequestConfig::default(),
            system_instruction: "You are helpful".to_string(),
            provider_name: "gemini".to_string(),
            embedding_model: "text-embedding-004".to_string(),
        })
        .expect("provider");

        let decision = provider
            .generate_action(provider_request(true))
            .await
            .expect("decision");
        match decision.action {
            AgentAction::CallSkills(calls) => assert_eq!(calls.len(), 2),
            other => panic!("expected parallel calls, got {other:?}"),
        }
        assert!(decision.cache.is_some());
        assert!(decision.usage.is_some());

        // The last recorded body is the generateContent request.
        let requests = state.requests.lock().expect("requests");
        let generate = requests.last().expect("generate request");
        let body = generate.to_string();
        assert!(body.contains("inlineData"));
        assert!(body.contains("cachedContents/cache-1"));
    }

    /// With a matching cache record already in the history, the generate
    /// request should reference that existing cache id.
    #[tokio::test]
    async fn reuses_existing_cache_without_recreating() {
        let (base_url, state) = spawn_test_server().await;
        let provider = GeminiProvider::new(GeminiConfig {
            base_url,
            auth: GeminiAuth::ApiKey("token".to_string()),
            default_request: ProviderRequestConfig::default(),
            system_instruction: "You are helpful".to_string(),
            provider_name: "gemini".to_string(),
            embedding_model: "text-embedding-004".to_string(),
        })
        .expect("provider");

        let mut request = provider_request(false);
        request
            .context
            .history
            .push(SessionRecord::ProviderCache(ProviderCacheRecord {
                provider_name: "gemini".to_string(),
                cached_content_id: "cachedContents/existing".to_string(),
                token_count: 99_999,
                cached_at: std::time::SystemTime::now(),
            }));
        let _ = provider.generate_action(request).await.expect("decision");
        let requests = state.requests.lock().expect("requests");
        let body = requests.last().expect("generate request").to_string();
        assert!(body.contains("cachedContents/existing"));
    }
}
931        let requests = state.requests.lock().expect("requests");
932        let body = requests.last().expect("generate request").to_string();
933        assert!(body.contains("cachedContents/existing"));
934    }
935}