// st/proxy/google.rs
//! 🤖 Google Gemini Provider Implementation
//!
//! "Expanding our horizons with Google's Gemini!" - The Cheet 😺

use crate::proxy::{LlmMessage, LlmProvider, LlmRequest, LlmResponse, LlmRole, LlmUsage};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

11pub struct GoogleProvider {
12    client: Client,
13    api_key: String,
14    base_url: String,
15}
16
17impl GoogleProvider {
18    pub fn new(api_key: String) -> Self {
19        Self {
20            client: Client::new(),
21            api_key,
22            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
23        }
24    }
25}
26
27impl Default for GoogleProvider {
28    fn default() -> Self {
29        let api_key = std::env::var("GOOGLE_API_KEY").unwrap_or_default();
30        Self::new(api_key)
31    }
32}
33
34#[async_trait]
35impl LlmProvider for GoogleProvider {
36    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
37        let url = format!(
38            "{}/models/{}:generateContent?key={}",
39            self.base_url, request.model, self.api_key
40        );
41
42        let google_request = GoogleChatRequest {
43            contents: request.messages.into_iter().map(Into::into).collect(),
44            generation_config: Some(GoogleGenerationConfig {
45                temperature: request.temperature,
46                max_output_tokens: request.max_tokens,
47            }),
48        };
49
50        let response = self
51            .client
52            .post(&url)
53            .json(&google_request)
54            .send()
55            .await
56            .context("Failed to send request to Google Gemini")?;
57
58        if !response.status().is_success() {
59            let error_text = response.text().await?;
60            return Err(anyhow::anyhow!("Google Gemini API error: {}", error_text));
61        }
62
63        let google_response: GoogleChatResponse = response.json().await?;
64
65        let content = google_response
66            .candidates
67            .first()
68            .and_then(|c| c.content.parts.first())
69            .map(|p| p.text.clone())
70            .unwrap_or_default();
71
72        Ok(LlmResponse {
73            content,
74            model: request.model,
75            usage: google_response.usage_metadata.map(Into::into),
76        })
77    }
78
79    fn name(&self) -> &'static str {
80        "Google"
81    }
82}
83
84#[derive(Debug, Serialize)]
85struct GoogleChatRequest {
86    contents: Vec<GoogleContent>,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    generation_config: Option<GoogleGenerationConfig>,
89}
90
/// One conversation turn in Gemini's wire format: a role plus text parts.
/// The field names already match the API's lowercase JSON keys.
#[derive(Debug, Serialize, Deserialize)]
struct GoogleContent {
    role: String,           // "user" or "model" (see `From<LlmMessage>` below)
    parts: Vec<GooglePart>, // ordered text fragments for this turn
}

/// A single text fragment within a content turn.
#[derive(Debug, Serialize, Deserialize)]
struct GooglePart {
    text: String,
}

102impl From<LlmMessage> for GoogleContent {
103    fn from(msg: LlmMessage) -> Self {
104        Self {
105            role: match msg.role {
106                LlmRole::System => "user".to_string(), // Gemini uses systemInstruction separately or just user
107                LlmRole::User => "user".to_string(),
108                LlmRole::Assistant => "model".to_string(),
109            },
110            parts: vec![GooglePart { text: msg.content }],
111        }
112    }
113}
114
115#[derive(Debug, Serialize)]
116struct GoogleGenerationConfig {
117    #[serde(skip_serializing_if = "Option::is_none")]
118    temperature: Option<f32>,
119    #[serde(skip_serializing_if = "Option::is_none")]
120    max_output_tokens: Option<usize>,
121}
122
123#[derive(Debug, Deserialize)]
124struct GoogleChatResponse {
125    candidates: Vec<GoogleCandidate>,
126    #[serde(rename = "usageMetadata")]
127    usage_metadata: Option<GoogleUsageMetadata>,
128}
129
/// One generated completion candidate; only its content is consumed here.
#[derive(Debug, Deserialize)]
struct GoogleCandidate {
    content: GoogleContent,
}

135#[derive(Debug, Deserialize)]
136struct GoogleUsageMetadata {
137    #[serde(rename = "promptTokenCount")]
138    prompt_token_count: usize,
139    #[serde(rename = "candidatesTokenCount")]
140    candidates_token_count: usize,
141    #[serde(rename = "totalTokenCount")]
142    total_token_count: usize,
143}
144
145impl From<GoogleUsageMetadata> for LlmUsage {
146    fn from(usage: GoogleUsageMetadata) -> Self {
147        Self {
148            prompt_tokens: usage.prompt_token_count,
149            completion_tokens: usage.candidates_token_count,
150            total_tokens: usage.total_token_count,
151        }
152    }
153}