Skip to main content

rustic_ai/providers/
gemini.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use base64::Engine;
5use reqwest::{Client, Url};
6use serde::Deserialize;
7use serde_json::{Map, Value, json};
8use uuid::Uuid;
9
10use crate::messages::{
11    ModelMessage, ModelRequestPart, ModelResponse, ModelResponsePart, TextPart, ToolCallPart,
12    UserContent,
13};
14use crate::model::{Model, ModelError, ModelRequestParameters, ModelSettings, OutputMode};
15use crate::providers::{Provider, ProviderError};
16use crate::usage::RequestUsage;
17
18fn map_reqwest_error(label: &str, error: reqwest::Error) -> ModelError {
19    if error.is_timeout() {
20        return ModelError::Timeout;
21    }
22    if error.is_connect() {
23        return ModelError::Transport(format!("{label} connect error: {error}"));
24    }
25    ModelError::Transport(format!("{label} request failed: {error}"))
26}
27
/// Shortens an error response body for logging.
///
/// Bodies of at most 512 bytes are returned unchanged. Longer bodies are cut
/// to at most 512 bytes and suffixed with `... (<total> bytes)`, where
/// `<total>` is the full length of the original body in bytes.
///
/// The cut is moved back to the nearest UTF-8 character boundary, so a body
/// whose 512th byte lies inside a multi-byte character is truncated slightly
/// earlier instead of panicking on an invalid slice.
fn truncate_error_body(body: &str) -> String {
    const LIMIT: usize = 512;
    if body.len() <= LIMIT {
        return body.to_string();
    }
    // Index 0 is always a boundary, so this loop terminates.
    let mut cut = LIMIT;
    while !body.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}... ({} bytes)", &body[..cut], body.len())
}
36
37fn normalize_tool_call_id(id: Option<String>) -> String {
38    match id {
39        Some(value) if !value.trim().is_empty() => value,
40        _ => format!("call_{}", Uuid::new_v4().simple()),
41    }
42}
43
44fn gemini_response_object(value: &Value) -> Value {
45    match value {
46        Value::Object(_) => value.clone(),
47        _ => {
48            let mut wrapped = Map::new();
49            wrapped.insert("return_value".to_string(), value.clone());
50            Value::Object(wrapped)
51        }
52    }
53}
54
55fn is_null_schema(value: &Value) -> bool {
56    matches!(
57        value,
58        Value::Object(map) if matches!(map.get("type"), Some(Value::String(t)) if t == "null")
59    )
60}
61
62fn sanitize_gemini_schema(value: &Value) -> Value {
63    match value {
64        Value::Object(map) => {
65            if let Some(variants) = map.get("anyOf").and_then(|val| val.as_array()) {
66                let mut cleaned = variants
67                    .iter()
68                    .filter(|variant| !is_null_schema(variant))
69                    .map(sanitize_gemini_schema)
70                    .collect::<Vec<_>>();
71                if cleaned.len() == 1 {
72                    return cleaned.pop().unwrap_or(Value::Null);
73                }
74            }
75            if let Some(variants) = map.get("oneOf").and_then(|val| val.as_array()) {
76                let mut cleaned = variants
77                    .iter()
78                    .filter(|variant| !is_null_schema(variant))
79                    .map(sanitize_gemini_schema)
80                    .collect::<Vec<_>>();
81                if cleaned.len() == 1 {
82                    return cleaned.pop().unwrap_or(Value::Null);
83                }
84            }
85
86            let mut out = Map::new();
87            for (key, val) in map {
88                if matches!(
89                    key.as_str(),
90                    "additionalProperties" | "$schema" | "$id" | "title"
91                ) {
92                    continue;
93                }
94                if key == "type"
95                    && let Value::Array(types) = val
96                {
97                    if let Some(first) = types
98                        .iter()
99                        .find(|item| !matches!(item, Value::String(t) if t == "null"))
100                    {
101                        out.insert(key.clone(), first.clone());
102                    }
103                    continue;
104                }
105                out.insert(key.clone(), sanitize_gemini_schema(val));
106            }
107            Value::Object(out)
108        }
109        Value::Array(items) => Value::Array(items.iter().map(sanitize_gemini_schema).collect()),
110        _ => value.clone(),
111    }
112}
113
/// Guesses a MIME type from the file extension of a URL.
///
/// The query string (`?...`) and fragment (`#...`) are ignored, and the
/// extension is taken from the last path segment only. Matching is
/// case-insensitive. Returns `None` when the last segment has no `.`-separated
/// extension or the extension is not one of the known image, document, text,
/// audio or video types listed below.
fn infer_media_type_from_url(url: &str) -> Option<String> {
    let path = url.split(|c| c == '?' || c == '#').next()?;
    let file_name = path.rsplit('/').next()?;
    // Require an actual "name.ext" shape so a bare segment such as "png" is
    // not mistaken for an extension.
    let (_, ext) = file_name.rsplit_once('.')?;
    let ext = ext.to_lowercase();
    let media_type = match ext.as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "pdf" => "application/pdf",
        "txt" => "text/plain",
        "md" | "markdown" => "text/markdown",
        "csv" => "text/csv",
        "json" => "application/json",
        "mp3" => "audio/mpeg",
        "wav" => "audio/wav",
        "ogg" | "oga" => "audio/ogg",
        "flac" => "audio/flac",
        "m4a" | "aac" => "audio/aac",
        "mp4" => "video/mp4",
        "mov" => "video/quicktime",
        "webm" => "video/webm",
        "mkv" => "video/x-matroska",
        _ => return None,
    };
    Some(media_type.to_string())
}
140
141fn file_data_part(url: &str, media_type: &Option<String>) -> Value {
142    let mut file_data = Map::new();
143    file_data.insert("fileUri".to_string(), Value::String(url.to_string()));
144    let inferred = media_type
145        .clone()
146        .or_else(|| infer_media_type_from_url(url));
147    if let Some(media_type) = inferred {
148        file_data.insert("mimeType".to_string(), Value::String(media_type.clone()));
149    }
150    let mut wrapper = Map::new();
151    wrapper.insert("fileData".to_string(), Value::Object(file_data));
152    Value::Object(wrapper)
153}
154
155#[derive(Clone, Debug)]
156pub struct GeminiProvider {
157    api_key: String,
158    base_url: Url,
159}
160
161impl GeminiProvider {
162    pub fn new(
163        api_key: impl Into<String>,
164        base_url: impl AsRef<str>,
165    ) -> Result<Self, ProviderError> {
166        let url = Url::parse(base_url.as_ref())
167            .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
168        Ok(Self {
169            api_key: api_key.into(),
170            base_url: url,
171        })
172    }
173
174    pub fn from_env() -> Result<Self, ProviderError> {
175        let api_key = std::env::var("GEMINI_API_KEY")
176            .or_else(|_| std::env::var("GOOGLE_API_KEY"))
177            .map_err(|_| ProviderError::MissingApiKey("gemini".to_string()))?;
178        Self::new(api_key, "https://generativelanguage.googleapis.com")
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{
        BinaryContent, ImageUrl, ModelMessage, ModelRequest, ModelRequestPart, ModelResponse,
        ModelResponsePart, ToolCallPart, ToolReturnPart,
    };
    use base64::engine::general_purpose::STANDARD;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    /// Reads `tests/fixtures/<name>` relative to the crate root.
    fn fixture_bytes(name: &str) -> Vec<u8> {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.extend(["tests", "fixtures", name]);
        std::fs::read(path).expect("fixture read")
    }

    /// Binary content must become base64 `inlineData`; URL content must become
    /// `fileData` with a MIME type inferred from the extension.
    #[test]
    fn convert_user_content_handles_inline_and_file_data() {
        let pdf_bytes = fixture_bytes("fixture.pdf");
        let audio_bytes = fixture_bytes("fixture.m4a");
        let image_url = "https://example.com/fixture.jpg";

        let parts = convert_user_content(&[
            UserContent::Binary(BinaryContent {
                data: pdf_bytes.clone(),
                media_type: "application/pdf".to_string(),
            }),
            UserContent::Binary(BinaryContent {
                data: audio_bytes.clone(),
                media_type: "audio/aac".to_string(),
            }),
            UserContent::Image(ImageUrl {
                url: image_url.to_string(),
                media_type: None,
            }),
        ]);
        assert_eq!(parts.len(), 3);

        // Shared checks for the two inline (binary) parts.
        let expect_inline = |part: &Value, label: &str, mime: &str, raw: &[u8]| {
            let inline = part.get("inlineData").expect(label);
            assert_eq!(inline.get("mimeType"), Some(&json!(mime)));
            assert_eq!(inline.get("data"), Some(&json!(STANDARD.encode(raw))));
        };
        expect_inline(&parts[0], "pdf inline", "application/pdf", &pdf_bytes);
        expect_inline(&parts[1], "audio inline", "audio/aac", &audio_bytes);

        let file_data = parts[2].get("fileData").expect("file data");
        assert_eq!(file_data.get("fileUri"), Some(&json!(image_url)));
        assert_eq!(file_data.get("mimeType"), Some(&json!("image/jpeg")));
    }

    /// A prior tool call and its return value must be replayed as a `model`
    /// `functionCall` followed by a `user` `functionResponse`.
    #[test]
    fn split_system_replays_tool_calls() {
        let tool_call = ModelMessage::Response(ModelResponse {
            parts: vec![ModelResponsePart::ToolCall(ToolCallPart {
                id: "call-1".to_string(),
                name: "get_data".to_string(),
                arguments: json!({"a": 1}),
            })],
            usage: None,
            model_name: None,
            finish_reason: None,
        });
        let tool_return = ModelMessage::Request(ModelRequest {
            parts: vec![ModelRequestPart::ToolReturn(ToolReturnPart {
                tool_name: "get_data".to_string(),
                tool_call_id: "call-1".to_string(),
                content: json!({"ok": true}),
            })],
            instructions: None,
        });
        let messages = vec![tool_call, tool_return];

        let (_system, contents) = GeminiModel::split_system(&messages);
        assert_eq!(contents.len(), 2);

        let model_msg = contents[0].as_object().expect("model message");
        assert_eq!(model_msg.get("role"), Some(&json!("model")));
        let function_call = model_msg
            .get("parts")
            .and_then(Value::as_array)
            .expect("model parts")
            .iter()
            .find_map(|part| part.get("functionCall"))
            .expect("functionCall");
        assert_eq!(function_call.get("name"), Some(&json!("get_data")));
        assert_eq!(function_call.get("args"), Some(&json!({"a": 1})));

        let user_msg = contents[1].as_object().expect("user message");
        assert_eq!(user_msg.get("role"), Some(&json!("user")));
        let function_response = user_msg
            .get("parts")
            .and_then(Value::as_array)
            .expect("user parts")
            .iter()
            .find_map(|part| part.get("functionResponse"))
            .expect("functionResponse");
        assert_eq!(function_response.get("name"), Some(&json!("get_data")));
        assert_eq!(
            function_response.get("response"),
            Some(&json!({"ok": true}))
        );
    }
}
328
329impl Provider for GeminiProvider {
330    fn name(&self) -> &str {
331        "gemini"
332    }
333
334    fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model> {
335        Arc::new(GeminiModel::new(
336            model,
337            self.api_key.clone(),
338            self.base_url.clone(),
339            settings,
340        ))
341    }
342}
343
344#[derive(Clone, Debug)]
345pub struct GeminiModel {
346    model: String,
347    api_key: String,
348    base_url: Url,
349    client: Client,
350    default_settings: Option<ModelSettings>,
351}
352
353impl GeminiModel {
354    pub fn new(
355        model: impl Into<String>,
356        api_key: String,
357        base_url: Url,
358        settings: Option<ModelSettings>,
359    ) -> Self {
360        let mut model = model.into();
361        if !model.starts_with("models/") {
362            model = format!("models/{model}");
363        }
364        Self {
365            model,
366            api_key,
367            base_url,
368            client: Client::new(),
369            default_settings: settings,
370        }
371    }
372
373    fn endpoint(&self) -> Result<Url, ModelError> {
374        let path = format!("v1beta/{}:generateContent", self.model);
375        let mut url = self
376            .base_url
377            .join(&path)
378            .map_err(|e| ModelError::Provider(format!("invalid base url: {e}")))?;
379        url.query_pairs_mut().append_pair("key", &self.api_key);
380        Ok(url)
381    }
382
383    fn split_system(messages: &[ModelMessage]) -> (Option<String>, Vec<Value>) {
384        let mut system_parts = Vec::new();
385        let mut contents = Vec::new();
386
387        for message in messages {
388            match message {
389                ModelMessage::Request(req) => {
390                    if let Some(instructions) = req
391                        .instructions
392                        .as_ref()
393                        .filter(|value| !value.trim().is_empty())
394                    {
395                        system_parts.push(instructions.to_string());
396                    }
397                    for part in &req.parts {
398                        match part {
399                            ModelRequestPart::SystemPrompt(prompt) => {
400                                system_parts.push(prompt.content.clone());
401                            }
402                            ModelRequestPart::UserPrompt(prompt) => contents.push(json!({
403                                "role": "user",
404                                "parts": convert_user_content(&prompt.content)
405                            })),
406                            ModelRequestPart::ToolReturn(tool_return) => contents.push(json!({
407                                "role": "user",
408                                "parts": [{
409                                    "functionResponse": {
410                                        "name": tool_return.tool_name,
411                                        "response": gemini_response_object(&tool_return.content),
412                                    }
413                                }]
414                            })),
415                            ModelRequestPart::RetryPrompt(retry) => {
416                                let parts = if let Some(tool_name) = &retry.tool_name {
417                                    vec![json!({
418                                        "functionResponse": {
419                                            "name": tool_name,
420                                            "response": {"call_error": retry.content}
421                                        }
422                                    })]
423                                } else {
424                                    vec![json!({"text": retry.content})]
425                                };
426                                contents.push(json!({
427                                    "role": "user",
428                                    "parts": parts
429                                }));
430                            }
431                        }
432                    }
433                }
434                ModelMessage::Response(res) => {
435                    let mut parts = Vec::new();
436                    if let Some(text) = res.text() {
437                        parts.push(json!({"text": text}));
438                    }
439                    for call in res.tool_calls() {
440                        parts.push(json!({
441                            "functionCall": {
442                                "name": call.name,
443                                "args": call.arguments,
444                            }
445                        }));
446                    }
447
448                    if !parts.is_empty() {
449                        contents.push(json!({
450                            "role": "model",
451                            "parts": parts
452                        }));
453                    }
454                }
455            }
456        }
457
458        let system = if system_parts.is_empty() {
459            None
460        } else {
461            Some(system_parts.join("\n\n"))
462        };
463
464        (system, contents)
465    }
466}
467
468fn convert_user_content(content: &[UserContent]) -> Vec<Value> {
469    let mut parts = Vec::new();
470    for item in content {
471        match item {
472            UserContent::Text(text) => parts.push(json!({"text": text})),
473            UserContent::Image(image) => parts.push(file_data_part(&image.url, &image.media_type)),
474            UserContent::Video(video) => parts.push(file_data_part(&video.url, &video.media_type)),
475            UserContent::Audio(audio) => parts.push(file_data_part(&audio.url, &audio.media_type)),
476            UserContent::Document(doc) => parts.push(file_data_part(&doc.url, &doc.media_type)),
477            UserContent::Binary(binary) => parts.push(json!({
478                "inlineData": {
479                    "mimeType": binary.media_type,
480                    "data": base64::engine::general_purpose::STANDARD.encode(&binary.data)
481                }
482            })),
483        }
484    }
485    parts
486}
487
488#[async_trait]
489impl Model for GeminiModel {
490    fn name(&self) -> &str {
491        &self.model
492    }
493
494    async fn request(
495        &self,
496        messages: &[ModelMessage],
497        settings: Option<&ModelSettings>,
498        params: &ModelRequestParameters,
499    ) -> Result<ModelResponse, ModelError> {
500        tracing::debug!(
501            model = %self.model,
502            tool_count = params.function_tools.len(),
503            output_schema = params.output_schema.is_some(),
504            "Gemini request"
505        );
506        let (system, contents) = Self::split_system(messages);
507        let mut body = Map::new();
508        body.insert("contents".to_string(), Value::Array(contents));
509        if let Some(system) = system {
510            body.insert(
511                "systemInstruction".to_string(),
512                json!({"parts": [{"text": system}]}),
513            );
514        }
515
516        if !params.function_tools.is_empty() {
517            let tools = params
518                .function_tools
519                .iter()
520                .map(|tool| {
521                    let schema = sanitize_gemini_schema(&tool.parameters_json_schema);
522                    json!({
523                        "name": tool.name,
524                        "description": tool.description,
525                        "parameters": schema,
526                    })
527                })
528                .collect::<Vec<_>>();
529            body.insert(
530                "tools".to_string(),
531                json!([{ "functionDeclarations": tools }]),
532            );
533            body.insert(
534                "toolConfig".to_string(),
535                json!({"functionCallingConfig": {"mode": "AUTO"}}),
536            );
537        }
538
539        if params.output_mode == OutputMode::JsonSchema
540            && let Some(schema) = params.output_schema.clone()
541        {
542            let schema = sanitize_gemini_schema(&schema);
543            body.insert(
544                "generationConfig".to_string(),
545                json!({
546                    "responseMimeType": "application/json",
547                    "responseSchema": schema
548                }),
549            );
550        }
551
552        if let Some(settings) = &self.default_settings {
553            for (key, value) in settings {
554                body.insert(key.clone(), value.clone());
555            }
556        }
557
558        if let Some(settings) = settings {
559            for (key, value) in settings {
560                body.insert(key.clone(), value.clone());
561            }
562        }
563
564        let response = self
565            .client
566            .post(self.endpoint()?)
567            .json(&Value::Object(body))
568            .send()
569            .await
570            .map_err(|e| map_reqwest_error("Gemini", e))?;
571
572        let status = response.status();
573        if !status.is_success() {
574            let body = response.text().await.unwrap_or_default();
575            tracing::error!(
576                status = status.as_u16(),
577                model = %self.model,
578                body = %truncate_error_body(&body),
579                "Gemini request failed"
580            );
581            return Err(ModelError::HttpStatus {
582                status: status.as_u16(),
583            });
584        }
585
586        let body: GeminiResponse = response.json().await.map_err(|e| {
587            tracing::error!(
588                error = %e,
589                model = %self.model,
590                "Gemini response parse failed"
591            );
592            ModelError::Provider(format!("Gemini response parse failed: {e}"))
593        })?;
594
595        let candidate = body.candidates.into_iter().next().ok_or_else(|| {
596            tracing::error!(model = %self.model, "Gemini response missing candidates");
597            ModelError::Provider("Gemini response missing candidates".to_string())
598        })?;
599
600        let mut parts = Vec::new();
601        if let Some(content) = candidate.content {
602            for part in content.parts {
603                if let Some(text) = part.text {
604                    parts.push(ModelResponsePart::Text(TextPart { content: text }));
605                }
606                if let Some(call) = part.function_call {
607                    parts.push(ModelResponsePart::ToolCall(ToolCallPart {
608                        id: normalize_tool_call_id(call.id),
609                        name: call.name.unwrap_or_else(|| "tool".to_string()),
610                        arguments: call.args.unwrap_or_else(|| Value::Object(Map::new())),
611                    }));
612                }
613            }
614        }
615
616        let usage = body.usage_metadata.map(|usage| RequestUsage {
617            input_tokens: usage.prompt_token_count.unwrap_or(0),
618            output_tokens: usage.candidates_token_count.unwrap_or(0),
619            ..Default::default()
620        });
621
622        Ok(ModelResponse {
623            parts,
624            usage,
625            model_name: Some(self.model.clone()),
626            finish_reason: candidate.finish_reason,
627        })
628    }
629}
630
631#[derive(Debug, Deserialize)]
632struct GeminiResponse {
633    candidates: Vec<GeminiCandidate>,
634    #[serde(rename = "usageMetadata")]
635    usage_metadata: Option<GeminiUsage>,
636}
637
638#[derive(Debug, Deserialize)]
639struct GeminiCandidate {
640    content: Option<GeminiContent>,
641    #[serde(rename = "finishReason")]
642    finish_reason: Option<String>,
643}
644
645#[derive(Debug, Deserialize)]
646struct GeminiContent {
647    parts: Vec<GeminiPart>,
648}
649
650#[derive(Debug, Deserialize)]
651struct GeminiPart {
652    text: Option<String>,
653    #[serde(rename = "functionCall")]
654    function_call: Option<GeminiFunctionCall>,
655}
656
657#[derive(Debug, Deserialize)]
658struct GeminiFunctionCall {
659    id: Option<String>,
660    name: Option<String>,
661    args: Option<Value>,
662}
663
664#[derive(Debug, Deserialize)]
665struct GeminiUsage {
666    #[serde(rename = "promptTokenCount")]
667    prompt_token_count: Option<u64>,
668    #[serde(rename = "candidatesTokenCount")]
669    candidates_token_count: Option<u64>,
670}