Skip to main content

agent_sdk_provider/
gemini.rs

1use std::{fmt, sync::Arc};
2
3use agent_sdk_core::{
4    AgentError, ProviderAdapter, ProviderCapabilities, ProviderMessageRole,
5    ProviderProjectionPolicy, ProviderRequest, ProviderResponse, ProviderStopReason,
6    ProviderToolCall, ProviderUsage, RetryClassification, ToolCallId,
7    tool_records::CanonicalToolName,
8};
9use serde::{Deserialize, Serialize};
10use serde_json::{Value, json};
11
12use crate::{
13    ProviderApiKey, ProviderToolArgumentSink,
14    error::{provider_failure, unsupported_response},
15    http::{CurlJsonHttpTransport, JsonHttpRequest, JsonHttpTransport},
16};
17
/// Settings for the live Gemini `generateContent` adapter.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GeminiGenerateContentConfig {
    /// Stable provider ref exposed through `ProviderCapabilities`.
    pub provider_ref: String,
    /// Gemini model id.
    pub model: String,
    /// Gemini API base URL.
    pub endpoint_base: String,
    /// Maximum input tokens advertised by this route.
    pub max_input_tokens: Option<u32>,
}

impl GeminiGenerateContentConfig {
    /// Builds a config that targets the hosted Gemini `generateContent` API.
    ///
    /// The provider ref defaults to `provider.gemini.generate_content`, the base URL to the
    /// hosted `v1beta` endpoint, and no input-token limit is advertised.
    pub fn new(model: impl Into<String>) -> Self {
        let provider_ref = String::from("provider.gemini.generate_content");
        let endpoint_base = String::from("https://generativelanguage.googleapis.com/v1beta");
        Self {
            provider_ref,
            model: model.into(),
            endpoint_base,
            max_input_tokens: None,
        }
    }

    /// Replaces the stable provider ref used in SDK capability metadata.
    pub fn provider_ref(self, provider_ref: impl Into<String>) -> Self {
        Self {
            provider_ref: provider_ref.into(),
            ..self
        }
    }

    /// Replaces the API base URL.
    pub fn endpoint_base(self, endpoint_base: impl Into<String>) -> Self {
        Self {
            endpoint_base: endpoint_base.into(),
            ..self
        }
    }

    /// Sets the maximum input token limit advertised for this route.
    pub fn max_input_tokens(self, max_input_tokens: u32) -> Self {
        Self {
            max_input_tokens: Some(max_input_tokens),
            ..self
        }
    }

    /// Returns the full `generateContent` URL for the configured model.
    ///
    /// Trailing `/` characters on the base URL and leading `models/` prefixes on the model id
    /// are trimmed so the two always join as `{base}/models/{id}:generateContent`.
    fn endpoint_url(&self) -> String {
        let base = self.endpoint_base.trim_end_matches('/');
        let model_id = self.model.trim_start_matches("models/");
        format!("{base}/models/{model_id}:generateContent")
    }
}
68
69#[derive(Clone)]
70/// Live Gemini generateContent API adapter.
71pub struct GeminiGenerateContentAdapter {
72    config: GeminiGenerateContentConfig,
73    api_key: ProviderApiKey,
74    http: Arc<dyn JsonHttpTransport>,
75    argument_sink: Option<Arc<dyn ProviderToolArgumentSink>>,
76}
77
78impl GeminiGenerateContentAdapter {
79    /// Creates a live adapter using `GEMINI_API_KEY`.
80    pub fn from_env(model: impl Into<String>) -> Result<Self, AgentError> {
81        Self::new(
82            GeminiGenerateContentConfig::new(model),
83            ProviderApiKey::from_env("GEMINI_API_KEY")?,
84        )
85    }
86
87    /// Creates a live adapter with a host-resolved API key.
88    pub fn new(
89        config: GeminiGenerateContentConfig,
90        api_key: ProviderApiKey,
91    ) -> Result<Self, AgentError> {
92        Self::with_transport(config, api_key, Arc::new(CurlJsonHttpTransport::new()))
93    }
94
95    /// Creates an adapter with an injected JSON transport.
96    pub fn with_transport(
97        config: GeminiGenerateContentConfig,
98        api_key: ProviderApiKey,
99        http: Arc<dyn JsonHttpTransport>,
100    ) -> Result<Self, AgentError> {
101        Ok(Self {
102            config,
103            api_key,
104            http,
105            argument_sink: None,
106        })
107    }
108
109    /// Adds an optional host-owned sink for raw tool-call arguments.
110    pub fn with_argument_sink(mut self, sink: Arc<dyn ProviderToolArgumentSink>) -> Self {
111        self.argument_sink = Some(sink);
112        self
113    }
114
115    fn wire_request(&self, request: &ProviderRequest) -> Value {
116        let mut system = Vec::new();
117        let mut contents = Vec::new();
118        for message in &request.messages {
119            match message.role {
120                ProviderMessageRole::System | ProviderMessageRole::Developer => {
121                    system.push(message.content.clone());
122                }
123                ProviderMessageRole::Assistant => {
124                    contents.push(gemini_text_content("model", message.content.clone()));
125                }
126                ProviderMessageRole::Tool => {
127                    contents.push(gemini_text_content(
128                        "user",
129                        format!("Tool result:\n{}", message.content),
130                    ));
131                }
132                ProviderMessageRole::Context | ProviderMessageRole::User => {
133                    contents.push(gemini_text_content("user", message.content.clone()));
134                }
135            }
136        }
137
138        let mut body = json!({ "contents": contents });
139        if !system.is_empty() {
140            body["systemInstruction"] = json!({
141                "parts": [{ "text": system.join("\n\n") }]
142            });
143        }
144        if let Some(generation_config) = gemini_generation_config(request) {
145            body["generationConfig"] = generation_config;
146        }
147        body
148    }
149
150    fn map_response(
151        &self,
152        response: GeminiGenerateContentResponse,
153    ) -> Result<ProviderResponse, AgentError> {
154        let tool_calls = self.tool_calls_from_response(&response)?;
155        let usage = response.usage_metadata.clone().map(ProviderUsage::from);
156        if !tool_calls.is_empty() {
157            let mut mapped = ProviderResponse::tool_use(tool_calls);
158            mapped.usage = usage;
159            return Ok(mapped);
160        }
161        Ok(ProviderResponse {
162            schema_version: ProviderResponse::SCHEMA_VERSION,
163            output_text: response.output_text(),
164            stop_reason: response.stop_reason(),
165            tool_calls: Vec::new(),
166            usage,
167        })
168    }
169
170    fn tool_calls_from_response(
171        &self,
172        response: &GeminiGenerateContentResponse,
173    ) -> Result<Vec<ProviderToolCall>, AgentError> {
174        let mut calls = Vec::new();
175        for candidate in &response.candidates {
176            if let Some(content) = &candidate.content {
177                for part in &content.parts {
178                    let Some(function_call) = &part.function_call else {
179                        continue;
180                    };
181                    let name = function_call.name.as_deref().ok_or_else(|| {
182                        unsupported_response("Gemini generateContent", "functionCall missing name")
183                    })?;
184                    let call_id = function_call
185                        .id
186                        .clone()
187                        .unwrap_or_else(|| format!("gemini_call_{}", calls.len()));
188                    let canonical_tool_name = CanonicalToolName::new(name);
189                    let mut call = ProviderToolCall::new(
190                        ToolCallId::new(call_id.clone()),
191                        canonical_tool_name.clone(),
192                        format!(
193                            "provider requested tool {name} with arguments stored as content refs"
194                        ),
195                    );
196                    if let (Some(sink), Some(args)) =
197                        (self.argument_sink.as_ref(), function_call.args.as_ref())
198                    {
199                        let raw_arguments = serde_json::to_string(args).map_err(|error| {
200                            provider_failure(
201                                RetryClassification::RepairNeeded,
202                                format!(
203                                    "Gemini functionCall args could not be serialized: {error}"
204                                ),
205                            )
206                        })?;
207                        if let Some(args_ref) = sink.store_tool_arguments(
208                            &self.config.provider_ref,
209                            &call_id,
210                            &canonical_tool_name,
211                            &raw_arguments,
212                        )? {
213                            call = call.with_args_ref(args_ref);
214                        }
215                    }
216                    calls.push(call);
217                }
218            }
219        }
220        Ok(calls)
221    }
222}
223
224impl ProviderAdapter for GeminiGenerateContentAdapter {
225    fn capabilities(&self) -> ProviderCapabilities {
226        let mut capabilities = ProviderCapabilities::text_only(self.config.provider_ref.clone());
227        capabilities.supports_usage = true;
228        capabilities.max_input_tokens = self.config.max_input_tokens;
229        capabilities
230    }
231
232    fn project_request(
233        &self,
234        projection: &agent_sdk_core::ContextProjection,
235        policy: &ProviderProjectionPolicy,
236    ) -> Result<ProviderRequest, AgentError> {
237        agent_sdk_core::projection::project_context_projection(projection, policy)
238    }
239
240    fn complete(&self, request: &ProviderRequest) -> Result<ProviderResponse, AgentError> {
241        let http_request =
242            JsonHttpRequest::new(self.config.endpoint_url(), self.wire_request(request))
243                .header("x-goog-api-key", self.api_key.expose_secret())
244                .header("Content-Type", "application/json");
245        let response = self.http.post_json(http_request)?;
246        let message = serde_json::from_value::<GeminiGenerateContentResponse>(response.body)
247            .map_err(|error| unsupported_response("Gemini generateContent", error.to_string()))?;
248        self.map_response(message)
249    }
250}
251
252fn gemini_text_content(role: &str, text: String) -> Value {
253    json!({
254        "role": role,
255        "parts": [{ "text": text }],
256    })
257}
258
259fn gemini_generation_config(request: &ProviderRequest) -> Option<Value> {
260    let hint = request.structured_output_hint.as_ref()?;
261    let schema = hint.redacted_schema.clone()?;
262    Some(json!({
263        "responseMimeType": "application/json",
264        "responseJsonSchema": schema,
265    }))
266}
267
268#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
269/// Minimal Gemini generateContent response shape used by the adapter.
270pub struct GeminiGenerateContentResponse {
271    /// Response candidates.
272    #[serde(default)]
273    pub candidates: Vec<GeminiCandidate>,
274    /// Provider usage accounting.
275    #[serde(rename = "usageMetadata")]
276    pub usage_metadata: Option<GeminiUsage>,
277}
278
279impl GeminiGenerateContentResponse {
280    /// Creates a text response fixture.
281    pub fn text(text: impl Into<String>) -> Self {
282        Self {
283            candidates: vec![GeminiCandidate {
284                content: Some(GeminiContent {
285                    role: Some("model".to_string()),
286                    parts: vec![GeminiPart::text(text)],
287                }),
288                finish_reason: Some("STOP".to_string()),
289            }],
290            usage_metadata: None,
291        }
292    }
293
294    /// Creates a function-call response fixture.
295    pub fn function_call(id: impl Into<String>, name: impl Into<String>, args: Value) -> Self {
296        Self {
297            candidates: vec![GeminiCandidate {
298                content: Some(GeminiContent {
299                    role: Some("model".to_string()),
300                    parts: vec![GeminiPart::function_call(id, name, args)],
301                }),
302                finish_reason: Some("STOP".to_string()),
303            }],
304            usage_metadata: None,
305        }
306    }
307
308    fn output_text(&self) -> String {
309        self.candidates
310            .iter()
311            .filter_map(|candidate| candidate.content.as_ref())
312            .flat_map(|content| content.parts.iter())
313            .filter_map(|part| part.text.as_deref())
314            .collect::<Vec<_>>()
315            .join("")
316    }
317
318    fn stop_reason(&self) -> ProviderStopReason {
319        let reason = self
320            .candidates
321            .first()
322            .and_then(|candidate| candidate.finish_reason.as_deref())
323            .unwrap_or("STOP");
324        match reason {
325            "STOP" => ProviderStopReason::EndTurn,
326            "MAX_TOKENS" => ProviderStopReason::MaxTokens,
327            "SAFETY" | "RECITATION" | "MALFORMED_FUNCTION_CALL" => {
328                ProviderStopReason::ProviderError
329            }
330            _ => ProviderStopReason::Unknown,
331        }
332    }
333}
334
335impl fmt::Debug for GeminiGenerateContentResponse {
336    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
337        formatter
338            .debug_struct("GeminiGenerateContentResponse")
339            .field("candidate_count", &self.candidates.len())
340            .field("candidates", &self.candidates)
341            .field("usage_metadata", &self.usage_metadata)
342            .finish()
343    }
344}
345
346#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
347/// Gemini response candidate.
348pub struct GeminiCandidate {
349    /// Candidate content.
350    pub content: Option<GeminiContent>,
351    /// Provider finish reason.
352    #[serde(rename = "finishReason")]
353    pub finish_reason: Option<String>,
354}
355
356impl fmt::Debug for GeminiCandidate {
357    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
358        formatter
359            .debug_struct("GeminiCandidate")
360            .field("content", &self.content)
361            .field("finish_reason", &self.finish_reason)
362            .finish()
363    }
364}
365
366#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
367/// Gemini content block.
368pub struct GeminiContent {
369    /// Gemini role.
370    pub role: Option<String>,
371    /// Content parts.
372    #[serde(default)]
373    pub parts: Vec<GeminiPart>,
374}
375
376impl fmt::Debug for GeminiContent {
377    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
378        formatter
379            .debug_struct("GeminiContent")
380            .field("role", &self.role)
381            .field("part_count", &self.parts.len())
382            .field("parts", &self.parts)
383            .finish()
384    }
385}
386
387#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
388/// Gemini content part.
389pub struct GeminiPart {
390    /// Text part.
391    pub text: Option<String>,
392    /// Function-call part.
393    #[serde(rename = "functionCall")]
394    pub function_call: Option<GeminiFunctionCall>,
395}
396
397impl GeminiPart {
398    /// Creates a text part.
399    pub fn text(text: impl Into<String>) -> Self {
400        Self {
401            text: Some(text.into()),
402            function_call: None,
403        }
404    }
405
406    /// Creates a function-call part.
407    pub fn function_call(id: impl Into<String>, name: impl Into<String>, args: Value) -> Self {
408        Self {
409            text: None,
410            function_call: Some(GeminiFunctionCall {
411                id: Some(id.into()),
412                name: Some(name.into()),
413                args: Some(args),
414            }),
415        }
416    }
417}
418
419impl fmt::Debug for GeminiPart {
420    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
421        formatter
422            .debug_struct("GeminiPart")
423            .field(
424                "text_chars",
425                &self.text.as_ref().map(|value| value.chars().count()),
426            )
427            .field("function_call", &self.function_call)
428            .finish()
429    }
430}
431
432#[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
433/// Gemini function-call part.
434pub struct GeminiFunctionCall {
435    /// Provider function-call id.
436    pub id: Option<String>,
437    /// Tool/function name.
438    pub name: Option<String>,
439    /// Function-call arguments.
440    pub args: Option<Value>,
441}
442
443impl fmt::Debug for GeminiFunctionCall {
444    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
445        formatter
446            .debug_struct("GeminiFunctionCall")
447            .field("id", &self.id)
448            .field("name", &self.name)
449            .field("args", &"<redacted>")
450            .field("args_present", &self.args.is_some())
451            .finish()
452    }
453}
454
455#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
456/// Gemini usage accounting.
457pub struct GeminiUsage {
458    /// Provider input tokens.
459    #[serde(rename = "promptTokenCount")]
460    pub prompt_token_count: Option<u32>,
461    /// Provider output tokens.
462    #[serde(rename = "candidatesTokenCount")]
463    pub candidates_token_count: Option<u32>,
464    /// Provider total tokens.
465    #[serde(rename = "totalTokenCount")]
466    pub total_token_count: Option<u32>,
467}
468
469impl From<GeminiUsage> for ProviderUsage {
470    fn from(value: GeminiUsage) -> Self {
471        Self {
472            input_tokens: value.prompt_token_count,
473            output_tokens: value.candidates_token_count,
474            total_tokens: value.total_token_count,
475        }
476    }
477}