1use super::api_client::{ApiClient, AuthMethod};
2use super::base::MessageStream;
3use super::errors::ProviderError;
4use super::retry::ProviderRetry;
5use super::utils::{
6 handle_response_google_compat, handle_status_openai_compat, unescape_json_values, RequestLog,
7};
8use crate::conversation::message::Message;
9
10use crate::model::ModelConfig;
11use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
12use crate::providers::formats::google::{
13 create_request, get_usage, response_to_message, response_to_streaming_message,
14};
15use anyhow::Result;
16use async_stream::try_stream;
17use async_trait::async_trait;
18use futures::TryStreamExt;
19use rmcp::model::Tool;
20use serde_json::Value;
21use std::io;
22use tokio::pin;
23use tokio_stream::StreamExt;
24use tokio_util::codec::{FramedRead, LinesCodec};
25use tokio_util::io::StreamReader;
26
/// Base URL of the Google Generative Language (Gemini) REST API.
/// Can be overridden via the `GOOGLE_HOST` config parameter.
pub const GOOGLE_API_HOST: &str = "https://generativelanguage.googleapis.com";
/// Model used when the caller does not specify one.
pub const GOOGLE_DEFAULT_MODEL: &str = "gemini-2.5-pro";
/// Cheaper/faster companion model, wired in via `ModelConfig::with_fast` in `from_env`.
pub const GOOGLE_DEFAULT_FAST_MODEL: &str = "gemini-2.5-flash";
/// Known Gemini model names surfaced through provider metadata (UI pickers etc.).
/// The live list is fetched separately by `fetch_supported_models`.
pub const GOOGLE_KNOWN_MODELS: &[&str] = &[
    "gemini-3-pro-preview",
    "gemini-3-pro-image-preview",
    "gemini-2.5-pro",
    "gemini-2.5-pro-preview-tts",
    "gemini-2.5-flash",
    "gemini-2.5-flash-preview-09-2025",
    "gemini-2.5-flash-image",
    "gemini-2.5-flash-image-preview",
    "gemini-2.5-flash-native-audio-preview-09-2025",
    "gemini-2.5-flash-preview-tts",
    "gemini-2.5-flash-lite",
    "gemini-2.5-flash-lite-preview-09-2025",
    "gemini-2.0-flash",
    "gemini-2.0-flash-001",
    "gemini-2.0-flash-exp",
    "gemini-2.0-flash-preview-image-generation",
    "gemini-2.0-flash-live-001",
    "gemini-2.0-flash-lite",
    "gemini-2.0-flash-lite-001",
];

/// Link to Google's model catalog, shown alongside provider metadata.
pub const GOOGLE_DOC_URL: &str = "https://ai.google.dev/gemini-api/docs/models";
59
/// Provider backed by the Google Gemini REST API.
///
/// Serializes only the `model` field; the HTTP client and display name are
/// runtime-only state and are skipped.
#[derive(Debug, serde::Serialize)]
pub struct GoogleProvider {
    /// HTTP client preconfigured with the host and `x-goog-api-key` auth.
    #[serde(skip)]
    api_client: ApiClient,
    /// Active model configuration (primary model plus fast-model fallback).
    model: ModelConfig,
    /// Display name, taken from `Self::metadata().name` at construction.
    #[serde(skip)]
    name: String,
}
68
69impl GoogleProvider {
70 pub async fn from_env(model: ModelConfig) -> Result<Self> {
71 let model = model.with_fast(GOOGLE_DEFAULT_FAST_MODEL.to_string());
72
73 let config = crate::config::Config::global();
74 let api_key: String = config.get_secret("GOOGLE_API_KEY")?;
75 let host: String = config
76 .get_param("GOOGLE_HOST")
77 .unwrap_or_else(|_| GOOGLE_API_HOST.to_string());
78
79 let auth = AuthMethod::ApiKey {
80 header_name: "x-goog-api-key".to_string(),
81 key: api_key,
82 };
83
84 let api_client =
85 ApiClient::new(host, auth)?.with_header("Content-Type", "application/json")?;
86
87 Ok(Self {
88 api_client,
89 model,
90 name: Self::metadata().name,
91 })
92 }
93
94 async fn post(&self, model_name: &str, payload: &Value) -> Result<Value, ProviderError> {
95 let path = format!("v1beta/models/{}:generateContent", model_name);
96 let response = self.api_client.response_post(&path, payload).await?;
97 handle_response_google_compat(response).await
98 }
99
100 async fn post_stream(
101 &self,
102 model_name: &str,
103 payload: &Value,
104 ) -> Result<reqwest::Response, ProviderError> {
105 let path = format!("v1beta/models/{}:streamGenerateContent?alt=sse", model_name);
106 let response = self.api_client.response_post(&path, payload).await?;
107 handle_status_openai_compat(response).await
108 }
109}
110
111#[async_trait]
112impl Provider for GoogleProvider {
113 fn metadata() -> ProviderMetadata {
114 ProviderMetadata::new(
115 "google",
116 "Google Gemini",
117 "Gemini models from Google AI",
118 GOOGLE_DEFAULT_MODEL,
119 GOOGLE_KNOWN_MODELS.to_vec(),
120 GOOGLE_DOC_URL,
121 vec![
122 ConfigKey::new("GOOGLE_API_KEY", true, true, None),
123 ConfigKey::new("GOOGLE_HOST", false, false, Some(GOOGLE_API_HOST)),
124 ],
125 )
126 }
127
128 fn get_name(&self) -> &str {
129 &self.name
130 }
131
132 fn get_model_config(&self) -> ModelConfig {
133 self.model.clone()
134 }
135
136 #[tracing::instrument(
137 skip(self, model_config, system, messages, tools),
138 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
139 )]
140 async fn complete_with_model(
141 &self,
142 model_config: &ModelConfig,
143 system: &str,
144 messages: &[Message],
145 tools: &[Tool],
146 ) -> Result<(Message, ProviderUsage), ProviderError> {
147 let payload = create_request(model_config, system, messages, tools)?;
148 let mut log = RequestLog::start(model_config, &payload)?;
149
150 let response = self
151 .with_retry(|| async { self.post(&model_config.model_name, &payload).await })
152 .await?;
153
154 let message = response_to_message(unescape_json_values(&response))?;
155 let usage = get_usage(&response)?;
156 let response_model = match response.get("modelVersion") {
157 Some(model_version) => model_version.as_str().unwrap_or_default().to_string(),
158 None => model_config.model_name.clone(),
159 };
160 log.write(&response, Some(&usage))?;
161 let provider_usage = ProviderUsage::new(response_model, usage);
162 Ok((message, provider_usage))
163 }
164
165 async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
166 let response = self.api_client.response_get("v1beta/models").await?;
167 let json: serde_json::Value = response.json().await?;
168 let arr = match json.get("models").and_then(|v| v.as_array()) {
169 Some(arr) => arr,
170 None => return Ok(None),
171 };
172 let mut models: Vec<String> = arr
173 .iter()
174 .filter_map(|m| m.get("name").and_then(|v| v.as_str()))
175 .map(|name| name.split('/').next_back().unwrap_or(name).to_string())
176 .collect();
177 models.sort();
178 Ok(Some(models))
179 }
180
181 fn supports_streaming(&self) -> bool {
182 true
183 }
184
185 async fn stream(
186 &self,
187 system: &str,
188 messages: &[Message],
189 tools: &[Tool],
190 ) -> Result<MessageStream, ProviderError> {
191 let payload = create_request(&self.model, system, messages, tools)?;
192 let mut log = RequestLog::start(&self.model, &payload)?;
193
194 let response = self
195 .with_retry(|| async { self.post_stream(&self.model.model_name, &payload).await })
196 .await
197 .inspect_err(|e| {
198 let _ = log.error(e);
199 })?;
200
201 let stream = response.bytes_stream().map_err(io::Error::other);
202
203 Ok(Box::pin(try_stream! {
204 let stream_reader = StreamReader::new(stream);
205 let framed = FramedRead::new(stream_reader, LinesCodec::new())
206 .map_err(anyhow::Error::from);
207
208 let message_stream = response_to_streaming_message(framed);
209 pin!(message_stream);
210 while let Some(message) = message_stream.next().await {
211 let (message, usage) = message.map_err(|e|
212 ProviderError::RequestFailed(format!("Stream decode error: {}", e))
213 )?;
214 if message.is_some() || usage.is_some() {
215 log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
216 }
217 yield (message, usage);
218 }
219 }))
220 }
221}