1pub mod embedding;
22pub mod error;
23pub mod image;
24#[cfg(test)]
25mod tests;
26pub mod tools;
27pub mod types;
28
29use crate::core::types::{
30 Content, FileSource, GenerateOptions, GenerateResult, ImageSource, Prompt, Role, StreamPart,
31 Usage,
32};
33use crate::google::types::{
34 GoogleContent, GoogleFunctionDeclaration, GoogleGenerationConfig, GooglePart, GoogleRequest,
35 GoogleResponse, GoogleTool,
36};
37use anyhow::anyhow;
38use async_trait::async_trait;
39use eventsource_stream::Eventsource;
40use futures::stream::BoxStream;
41use futures_util::StreamExt;
42use reqwest::Client;
43
44pub struct GoogleModel {
45 pub api_key: String,
46 pub base_url: String,
47 pub client: Client,
48}
49
50impl GoogleModel {
51 #[must_use]
52 pub fn new(api_key: String) -> Self {
53 Self {
54 api_key,
55 base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
56 client: Client::new(),
57 }
58 }
59}
60
61#[async_trait]
62impl crate::core::LanguageModel for GoogleModel {
63 #[tracing::instrument(skip(self, prompt), fields(model = options.model_id))]
64 async fn generate(
65 &self,
66 prompt: Prompt,
67 options: GenerateOptions,
68 ) -> crate::core::Result<GenerateResult> {
69 let request = self.prepare_request(prompt, &options)?;
70
71 let url = format!(
72 "{}/models/{}:generateContent?key={}",
73 self.base_url, options.model_id, self.api_key
74 );
75
76 let response = self.client.post(&url).json(&request).send().await?;
77
78 if !response.status().is_success() {
79 let error_text = response.text().await?;
80 return Err(anyhow!("Google API error: {error_text}").into());
81 }
82
83 let headers = response.headers().clone();
84 let google_response: GoogleResponse = response.json().await?;
85
86 let mut usage = Usage {
87 prompt_tokens: google_response.usage_metadata.prompt_token_count,
88 completion_tokens: google_response.usage_metadata.candidates_token_count,
89 };
90
91 if let Some(header_usage) = Usage::from_headers(&headers) {
93 usage = header_usage;
94 }
95
96 let candidate =
97 google_response
98 .candidates
99 .first()
100 .ok_or_else(|| -> crate::core::ProviderError {
101 crate::core::ProviderError::Other(anyhow::anyhow!(
102 "No candidates returned from Google"
103 ))
104 })?;
105
106 let mut text_parts = Vec::new();
107 let mut tool_calls = Vec::new();
108
109 for part in &candidate.content.parts {
110 match part {
111 GooglePart::Text { text } => {
112 text_parts.push(text.clone());
113 }
114 GooglePart::FunctionCall { name, args } => {
115 tool_calls.push(crate::core::types::ToolCallResult {
116 name: name.clone(),
117 arguments: args.clone(),
118 });
119 }
120 _ => {}
121 }
122 }
123
124 let text = text_parts.join("");
125
126 Ok(GenerateResult {
127 text,
128 usage,
129 finish_reason: candidate
130 .finish_reason
131 .clone()
132 .unwrap_or_else(|| "stop".to_string()),
133 tool_calls,
134 })
135 }
136
137 async fn generate_stream(
138 &self,
139 prompt: Prompt,
140 options: GenerateOptions,
141 ) -> crate::core::Result<BoxStream<'static, StreamPart>> {
142 let request = self.prepare_request(prompt, &options)?;
143 let url = format!(
144 "{}/models/{}:streamGenerateContent?alt=sse&key={}",
145 self.base_url, options.model_id, self.api_key
146 );
147
148 let response = self.client.post(&url).json(&request).send().await?;
149
150 if !response.status().is_success() {
151 let error_text = response.text().await?;
152 return Err(anyhow!("Google API error: {error_text}").into());
153 }
154
155 let mut event_stream = response.bytes_stream().eventsource();
156
157 let stream = async_stream::stream! {
158 while let Some(event) = event_stream.next().await {
159 match event {
160 Ok(event) => {
161 let parsed: Result<GoogleResponse, _> = serde_json::from_str(&event.data);
162 match parsed {
163 Ok(google_response) => {
164 yield StreamPart::Usage {
166 usage: Usage {
167 prompt_tokens: google_response.usage_metadata.prompt_token_count,
168 completion_tokens: google_response.usage_metadata.candidates_token_count
169 }
170 };
171
172 if let Some(candidate) = google_response.candidates.first() {
173 for part in &candidate.content.parts {
174 match part {
175 GooglePart::Text { text } => {
176 yield StreamPart::TextDelta { delta: text.clone() };
177 }
178 GooglePart::FunctionCall { name, args } => {
179 yield StreamPart::ToolCallDelta {
180 index: 0,
181 id: None,
182 name: Some(name.clone()),
183 arguments_delta: Some(args.to_string()),
184 };
185 }
186 _ => {}
187 }
188 }
189
190 if let Some(reason) = &candidate.finish_reason {
191 yield StreamPart::Finish { finish_reason: reason.clone() };
192 }
193 }
194 }
195 Err(e) => {
196 yield StreamPart::Error { message: e.to_string() };
197 }
198 }
199 }
200 Err(e) => {
201 yield StreamPart::Error { message: e.to_string() };
202 }
203 }
204 }
205 };
206
207 Ok(Box::pin(stream))
208 }
209}
210
211impl GoogleModel {
212 fn prepare_request(
213 &self,
214 prompt: Prompt,
215 options: &GenerateOptions,
216 ) -> crate::core::Result<GoogleRequest> {
217 let mut contents = Vec::new();
218 let mut system_instruction = None;
219
220 for msg in prompt.messages {
221 let role = match msg.role {
222 Role::System => {
223 let mut parts = Vec::new();
224 for content in msg.content {
225 if let Content::Text { text } = content {
226 parts.push(GooglePart::Text { text });
227 }
228 }
229 system_instruction = Some(GoogleContent {
230 role: "system".to_string(),
231 parts,
232 });
233 continue;
234 }
235 Role::User => "user",
236 Role::Assistant => "model",
237 Role::Tool => "user",
238 };
239
240 let mut parts = Vec::new();
241 for content in msg.content {
242 match content {
243 Content::Text { text } => {
244 parts.push(GooglePart::Text { text });
245 }
246 Content::Image { source } => {
247 let (mime_type, data) = match source {
248 ImageSource::Base64 { media_type, data } => (media_type, data),
249 _ => return Err(anyhow!("Unsupported image source for Google").into()),
250 };
251 parts.push(GooglePart::InlineData { mime_type, data });
252 }
253 Content::File { source } => {
254 let FileSource::Base64 { media_type, data } = source;
255 parts.push(GooglePart::InlineData {
256 mime_type: media_type,
257 data,
258 });
259 }
260 Content::ToolCall {
261 name, arguments, ..
262 } => {
263 parts.push(GooglePart::FunctionCall {
264 name,
265 args: arguments,
266 });
267 }
268 Content::ToolResult { id, result } => {
269 parts.push(GooglePart::FunctionResponse {
270 name: id,
271 response: result,
272 });
273 }
274 }
275 }
276
277 contents.push(GoogleContent {
278 role: role.to_string(),
279 parts,
280 });
281 }
282
283 let google_tools = if options.tools.as_ref().is_some_and(|t| !t.is_empty()) {
284 Some(vec![GoogleTool {
285 function_declarations: options
286 .tools
287 .as_ref()
288 .unwrap()
289 .iter()
290 .map(|t| GoogleFunctionDeclaration {
291 name: t.name.clone(),
292 description: t.description.clone(),
293 parameters: t.parameters.clone(),
294 })
295 .collect(),
296 }])
297 } else {
298 None
299 };
300
301 let mut response_mime_type = None;
302 let mut response_schema = None;
303 if let Some(format) = &options.response_format {
304 if format.get("type").and_then(|t| t.as_str()) == Some("json_schema") {
305 response_mime_type = Some("application/json".to_string());
306 if let Some(schema) = format.get("json_schema").and_then(|s| s.get("schema")) {
307 response_schema = Some(schema.clone());
308 }
309 } else if format.get("type").and_then(|t| t.as_str()) == Some("json_object") {
310 response_mime_type = Some("application/json".to_string());
311 }
312 }
313
314 Ok(GoogleRequest {
315 contents,
316 system_instruction,
317 generation_config: Some(GoogleGenerationConfig {
318 max_output_tokens: options.max_tokens,
319 temperature: options.temperature,
320 top_p: options.top_p,
321 top_k: None,
322 stop_sequences: options.stop_sequences.clone(),
323 response_mime_type,
324 response_schema,
325 }),
326 tools: google_tools,
327 })
328 }
329}
330
331use crate::core::types::ProviderSettings;
334
335pub struct GoogleProvider {
337 settings: ProviderSettings,
338}
339
340impl GoogleProvider {
341 #[must_use]
343 pub fn chat(&self, _model_id: &str) -> GoogleModel {
344 let api_key = self
345 .settings
346 .api_key
347 .clone()
348 .or_else(|| std::env::var("GOOGLE_GENERATIVE_AI_API_KEY").ok())
349 .unwrap_or_default();
350 let mut model = GoogleModel::new(api_key);
351 if let Some(ref base_url) = self.settings.base_url {
352 model.base_url = base_url.clone();
353 }
354 model
355 }
356
357 #[must_use]
359 pub fn language_model(&self, model_id: &str) -> GoogleModel {
360 self.chat(model_id)
361 }
362
363 #[must_use]
365 pub fn embedding(&self, _model_id: &str) -> embedding::GoogleEmbeddingModel {
366 let api_key = self
367 .settings
368 .api_key
369 .clone()
370 .or_else(|| std::env::var("GOOGLE_GENERATIVE_AI_API_KEY").ok())
371 .unwrap_or_default();
372 let mut model = embedding::GoogleEmbeddingModel::new(api_key);
373 if let Some(ref base_url) = self.settings.base_url {
374 model.base_url = base_url.clone();
375 }
376 model
377 }
378
379 #[must_use]
381 pub fn image(&self, _model_id: &str) -> image::GoogleImageModel {
382 let api_key = self
383 .settings
384 .api_key
385 .clone()
386 .or_else(|| std::env::var("GOOGLE_GENERATIVE_AI_API_KEY").ok())
387 .unwrap_or_default();
388 let mut model = image::GoogleImageModel::new(api_key);
389 if let Some(ref base_url) = self.settings.base_url {
390 model.base_url = base_url.clone();
391 }
392 model
393 }
394}
395
396#[must_use]
398pub fn create_google(settings: ProviderSettings) -> GoogleProvider {
399 GoogleProvider { settings }
400}
401
402impl crate::core::registry::Provider for GoogleProvider {
403 fn language_model(&self, model_id: &str) -> Option<Box<dyn crate::core::LanguageModel>> {
404 Some(Box::new(self.chat(model_id)))
405 }
406
407 fn embedding_model(&self, model_id: &str) -> Option<Box<dyn crate::core::EmbeddingModel>> {
408 Some(Box::new(self.embedding(model_id)))
409 }
410
411 fn image_model(&self, model_id: &str) -> Option<Box<dyn crate::core::ImageModel>> {
412 Some(Box::new(self.image(model_id)))
413 }
414}