Skip to main content

aster/providers/
google.rs

1use super::api_client::{ApiClient, AuthMethod};
2use super::base::MessageStream;
3use super::errors::ProviderError;
4use super::retry::ProviderRetry;
5use super::utils::{
6    handle_response_google_compat, handle_status_openai_compat, unescape_json_values, RequestLog,
7};
8use crate::conversation::message::Message;
9
10use crate::model::ModelConfig;
11use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
12use crate::providers::formats::google::{
13    create_request, get_usage, response_to_message, response_to_streaming_message,
14};
15use anyhow::Result;
16use async_stream::try_stream;
17use async_trait::async_trait;
18use futures::TryStreamExt;
19use rmcp::model::Tool;
20use serde_json::Value;
21use std::io;
22use tokio::pin;
23use tokio_stream::StreamExt;
24use tokio_util::codec::{FramedRead, LinesCodec};
25use tokio_util::io::StreamReader;
26
/// Default base URL of the Gemini API. `GoogleProvider::from_env` falls back to it when the
/// `GOOGLE_HOST` parameter is not configured.
pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com";

/// Documentation link reported in the provider metadata.
pub const GOOGLE_DOC_URL: &str = "https://ai.google.dev/gemini-api/docs/models";

/// Model advertised as the provider's default in its metadata.
pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.5-pro";

/// Model passed to `ModelConfig::with_fast` when the provider is built from the environment.
pub const GOOGLE_DEFAULT_FAST_MODEL: &str = "gemini-2.5-flash";

/// Static list of model identifiers reported in the provider metadata, grouped by family.
pub const GOOGLE_KNOWN_MODELS: &[&str] = &[
    // --- Gemini 3 ---
    "gemini-3-pro-preview",
    "gemini-3-pro-image-preview",
    // --- Gemini 2.5 Pro ---
    "gemini-2.5-pro",
    "gemini-2.5-pro-preview-tts",
    // --- Gemini 2.5 Flash ---
    "gemini-2.5-flash",
    "gemini-2.5-flash-preview-09-2025",
    "gemini-2.5-flash-image",
    "gemini-2.5-flash-image-preview",
    "gemini-2.5-flash-native-audio-preview-09-2025",
    "gemini-2.5-flash-preview-tts",
    // --- Gemini 2.5 Flash-Lite ---
    "gemini-2.5-flash-lite",
    "gemini-2.5-flash-lite-preview-09-2025",
    // --- Gemini 2.0 Flash ---
    "gemini-2.0-flash",
    "gemini-2.0-flash-001",
    "gemini-2.0-flash-exp",
    "gemini-2.0-flash-preview-image-generation",
    "gemini-2.0-flash-live-001",
    // --- Gemini 2.0 Flash-Lite ---
    "gemini-2.0-flash-lite",
    "gemini-2.0-flash-lite-001",
];
59
60#[derive(Debug, serde::Serialize)]
61pub struct GoogleProvider {
62    #[serde(skip)]
63    api_client: ApiClient,
64    model: ModelConfig,
65    #[serde(skip)]
66    name: String,
67}
68
69impl GoogleProvider {
70    pub async fn from_env(model: ModelConfig) -> Result<Self> {
71        let model = model.with_fast(GOOGLE_DEFAULT_FAST_MODEL.to_string());
72
73        let config = crate::config::Config::global();
74        let api_key: String = config.get_secret("GOOGLE_API_KEY")?;
75        let host: String = config
76            .get_param("GOOGLE_HOST")
77            .unwrap_or_else(|_| GOOGLE_API_HOST.to_string());
78
79        let auth = AuthMethod::ApiKey {
80            header_name: "x-goog-api-key".to_string(),
81            key: api_key,
82        };
83
84        let api_client =
85            ApiClient::new(host, auth)?.with_header("Content-Type", "application/json")?;
86
87        Ok(Self {
88            api_client,
89            model,
90            name: Self::metadata().name,
91        })
92    }
93
94    async fn post(&self, model_name: &str, payload: &Value) -> Result<Value, ProviderError> {
95        let path = format!("v1beta/models/{}:generateContent", model_name);
96        let response = self.api_client.response_post(&path, payload).await?;
97        handle_response_google_compat(response).await
98    }
99
100    async fn post_stream(
101        &self,
102        model_name: &str,
103        payload: &Value,
104    ) -> Result<reqwest::Response, ProviderError> {
105        let path = format!("v1beta/models/{}:streamGenerateContent?alt=sse", model_name);
106        let response = self.api_client.response_post(&path, payload).await?;
107        handle_status_openai_compat(response).await
108    }
109}
110
111#[async_trait]
112impl Provider for GoogleProvider {
113    fn metadata() -> ProviderMetadata {
114        ProviderMetadata::new(
115            "google",
116            "Google Gemini",
117            "Gemini models from Google AI",
118            GOOGLE_DEFAULT_MODEL,
119            GOOGLE_KNOWN_MODELS.to_vec(),
120            GOOGLE_DOC_URL,
121            vec![
122                ConfigKey::new("GOOGLE_API_KEY", true, true, None),
123                ConfigKey::new("GOOGLE_HOST", false, false, Some(GOOGLE_API_HOST)),
124            ],
125        )
126    }
127
128    fn get_name(&self) -> &str {
129        &self.name
130    }
131
132    fn get_model_config(&self) -> ModelConfig {
133        self.model.clone()
134    }
135
136    #[tracing::instrument(
137        skip(self, model_config, system, messages, tools),
138        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
139    )]
140    async fn complete_with_model(
141        &self,
142        model_config: &ModelConfig,
143        system: &str,
144        messages: &[Message],
145        tools: &[Tool],
146    ) -> Result<(Message, ProviderUsage), ProviderError> {
147        let payload = create_request(model_config, system, messages, tools)?;
148        let mut log = RequestLog::start(model_config, &payload)?;
149
150        let response = self
151            .with_retry(|| async { self.post(&model_config.model_name, &payload).await })
152            .await?;
153
154        let message = response_to_message(unescape_json_values(&response))?;
155        let usage = get_usage(&response)?;
156        let response_model = match response.get("modelVersion") {
157            Some(model_version) => model_version.as_str().unwrap_or_default().to_string(),
158            None => model_config.model_name.clone(),
159        };
160        log.write(&response, Some(&usage))?;
161        let provider_usage = ProviderUsage::new(response_model, usage);
162        Ok((message, provider_usage))
163    }
164
165    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
166        let response = self.api_client.response_get("v1beta/models").await?;
167        let json: serde_json::Value = response.json().await?;
168        let arr = match json.get("models").and_then(|v| v.as_array()) {
169            Some(arr) => arr,
170            None => return Ok(None),
171        };
172        let mut models: Vec<String> = arr
173            .iter()
174            .filter_map(|m| m.get("name").and_then(|v| v.as_str()))
175            .map(|name| name.split('/').next_back().unwrap_or(name).to_string())
176            .collect();
177        models.sort();
178        Ok(Some(models))
179    }
180
181    fn supports_streaming(&self) -> bool {
182        true
183    }
184
185    async fn stream(
186        &self,
187        system: &str,
188        messages: &[Message],
189        tools: &[Tool],
190    ) -> Result<MessageStream, ProviderError> {
191        let payload = create_request(&self.model, system, messages, tools)?;
192        let mut log = RequestLog::start(&self.model, &payload)?;
193
194        let response = self
195            .with_retry(|| async { self.post_stream(&self.model.model_name, &payload).await })
196            .await
197            .inspect_err(|e| {
198                let _ = log.error(e);
199            })?;
200
201        let stream = response.bytes_stream().map_err(io::Error::other);
202
203        Ok(Box::pin(try_stream! {
204            let stream_reader = StreamReader::new(stream);
205            let framed = FramedRead::new(stream_reader, LinesCodec::new())
206                .map_err(anyhow::Error::from);
207
208            let message_stream = response_to_streaming_message(framed);
209            pin!(message_stream);
210            while let Some(message) = message_stream.next().await {
211                let (message, usage) = message.map_err(|e|
212                    ProviderError::RequestFailed(format!("Stream decode error: {}", e))
213                )?;
214                if message.is_some() || usage.is_some() {
215                    log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
216                }
217                yield (message, usage);
218            }
219        }))
220    }
221}