// qai_sdk/google/mod.rs
//! # QAI Google
//!
//! Google Gemini provider for the QAI SDK. Supports chat, streaming,
//! tool calling, vision, embeddings, and image generation via the
//! Generative Language API.
//!
//! ## Usage
//!
//! ```rust,no_run
//! use qai_sdk::google::create_google;
//! use qai_sdk::core::types::ProviderSettings;
//!
//! let provider = create_google(ProviderSettings {
//!     api_key: Some("AIza...".to_string()),
//!     ..Default::default()
//! });
//!
//! let model = provider.chat("gemini-2.0-flash");
//! ```

21pub mod embedding;
22pub mod error;
23pub mod image;
24#[cfg(test)]
25mod tests;
26pub mod tools;
27pub mod types;
28
29use crate::core::types::{
30    Content, FileSource, GenerateOptions, GenerateResult, ImageSource, Prompt, Role, StreamPart,
31    Usage,
32};
33use crate::google::types::{
34    GoogleContent, GoogleFunctionDeclaration, GoogleGenerationConfig, GooglePart, GoogleRequest,
35    GoogleResponse, GoogleTool,
36};
37use anyhow::anyhow;
38use async_trait::async_trait;
39use eventsource_stream::Eventsource;
40use futures::stream::BoxStream;
41use futures_util::StreamExt;
42use reqwest::Client;
43
/// Google Gemini chat model talking to the Generative Language API.
pub struct GoogleModel {
    // API key appended as the `key` query parameter on every request.
    pub api_key: String,
    // Endpoint root; defaults to the public `v1beta` API in `new`.
    pub base_url: String,
    // Shared HTTP client reused across all requests.
    pub client: Client,
}
49
50impl GoogleModel {
51    #[must_use]
52    pub fn new(api_key: String) -> Self {
53        Self {
54            api_key,
55            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
56            client: Client::new(),
57        }
58    }
59}
60
#[async_trait]
impl crate::core::LanguageModel for GoogleModel {
    /// Performs a single non-streaming `generateContent` call.
    ///
    /// Builds a Gemini request from `prompt` and `options`, POSTs it, and
    /// maps the first candidate into a `GenerateResult`: all text parts are
    /// concatenated, function calls are collected as tool calls.
    ///
    /// # Errors
    ///
    /// Returns an error on transport failure, a non-success HTTP status
    /// (with the raw response body embedded), or when no candidates are
    /// returned.
    #[tracing::instrument(skip(self, prompt), fields(model = options.model_id))]
    async fn generate(
        &self,
        prompt: Prompt,
        options: GenerateOptions,
    ) -> crate::core::Result<GenerateResult> {
        let request = self.prepare_request(prompt, &options)?;

        // The API key travels as a query parameter on the model endpoint.
        let url = format!(
            "{}/models/{}:generateContent?key={}",
            self.base_url, options.model_id, self.api_key
        );

        let response = self.client.post(&url).json(&request).send().await?;

        if !response.status().is_success() {
            // Surface the raw error body; Google returns JSON error details here.
            let error_text = response.text().await?;
            return Err(anyhow!("Google API error: {error_text}").into());
        }

        // Clone headers first: `json()` consumes the response.
        let headers = response.headers().clone();
        let google_response: GoogleResponse = response.json().await?;

        // Token counts from the response body's usageMetadata.
        let mut usage = Usage {
            prompt_tokens: google_response.usage_metadata.prompt_token_count,
            completion_tokens: google_response.usage_metadata.candidates_token_count,
        };

        // Header extraction as fallback/supplement
        // NOTE(review): when header usage is present it fully REPLACES the
        // body-derived counts above, it does not supplement them — confirm
        // that override is intended.
        if let Some(header_usage) = Usage::from_headers(&headers) {
            usage = header_usage;
        }

        // Only the first candidate is used; additional candidates are ignored.
        let candidate =
            google_response
                .candidates
                .first()
                .ok_or_else(|| -> crate::core::ProviderError {
                    crate::core::ProviderError::Other(anyhow::anyhow!(
                        "No candidates returned from Google"
                    ))
                })?;

        // Partition the candidate's parts into plain text and tool invocations.
        let mut text_parts = Vec::new();
        let mut tool_calls = Vec::new();

        for part in &candidate.content.parts {
            match part {
                GooglePart::Text { text } => {
                    text_parts.push(text.clone());
                }
                GooglePart::FunctionCall { name, args } => {
                    tool_calls.push(crate::core::types::ToolCallResult {
                        name: name.clone(),
                        arguments: args.clone(),
                    });
                }
                // Other part kinds (inline data, function responses) are
                // ignored on the output path.
                _ => {}
            }
        }

        let text = text_parts.join("");

        Ok(GenerateResult {
            text,
            usage,
            // Gemini may omit finishReason; default to "stop".
            finish_reason: candidate
                .finish_reason
                .clone()
                .unwrap_or_else(|| "stop".to_string()),
            tool_calls,
        })
    }

    /// Streams a `streamGenerateContent` response over SSE, translating each
    /// chunk into `StreamPart`s: usage, text deltas, tool-call deltas,
    /// finish reason, or errors.
    ///
    /// # Errors
    ///
    /// Fails up-front on transport errors or a non-success HTTP status;
    /// per-chunk parse failures are reported in-band as `StreamPart::Error`.
    async fn generate_stream(
        &self,
        prompt: Prompt,
        options: GenerateOptions,
    ) -> crate::core::Result<BoxStream<'static, StreamPart>> {
        let request = self.prepare_request(prompt, &options)?;
        // `alt=sse` asks the API to frame the stream as server-sent events.
        let url = format!(
            "{}/models/{}:streamGenerateContent?alt=sse&key={}",
            self.base_url, options.model_id, self.api_key
        );

        let response = self.client.post(&url).json(&request).send().await?;

        if !response.status().is_success() {
            let error_text = response.text().await?;
            return Err(anyhow!("Google API error: {error_text}").into());
        }

        let mut event_stream = response.bytes_stream().eventsource();

        let stream = async_stream::stream! {
            while let Some(event) = event_stream.next().await {
                match event {
                    Ok(event) => {
                        // Each SSE event's data field is a full GoogleResponse chunk.
                        let parsed: Result<GoogleResponse, _> = serde_json::from_str(&event.data);
                        match parsed {
                            Ok(google_response) => {
                                // Gemini sends usage in the last chunk or sometimes in every chunk
                                yield StreamPart::Usage {
                                    usage: Usage {
                                        prompt_tokens: google_response.usage_metadata.prompt_token_count,
                                        completion_tokens: google_response.usage_metadata.candidates_token_count
                                    }
                                };

                                if let Some(candidate) = google_response.candidates.first() {
                                    for part in &candidate.content.parts {
                                        match part {
                                            GooglePart::Text { text } => {
                                                yield StreamPart::TextDelta { delta: text.clone() };
                                            }
                                            GooglePart::FunctionCall { name, args } => {
                                                // NOTE(review): index is hard-coded to 0, so
                                                // multiple parallel function calls in one stream
                                                // would collide — confirm downstream accumulation.
                                                yield StreamPart::ToolCallDelta {
                                                    index: 0,
                                                    id: None,
                                                    name: Some(name.clone()),
                                                    // Gemini sends complete args, not deltas.
                                                    arguments_delta: Some(args.to_string()),
                                                };
                                            }
                                            _ => {}
                                        }
                                    }

                                    if let Some(reason) = &candidate.finish_reason {
                                        yield StreamPart::Finish { finish_reason: reason.clone() };
                                    }
                                }
                            }
                            Err(e) => {
                                // Malformed chunk: report in-band, keep streaming.
                                yield StreamPart::Error { message: e.to_string() };
                            }
                        }
                    }
                    Err(e) => {
                        // Transport-level SSE error: report in-band.
                        yield StreamPart::Error { message: e.to_string() };
                    }
                }
            }
        };

        Ok(Box::pin(stream))
    }
}
210
211impl GoogleModel {
212    fn prepare_request(
213        &self,
214        prompt: Prompt,
215        options: &GenerateOptions,
216    ) -> crate::core::Result<GoogleRequest> {
217        let mut contents = Vec::new();
218        let mut system_instruction = None;
219
220        for msg in prompt.messages {
221            let role = match msg.role {
222                Role::System => {
223                    let mut parts = Vec::new();
224                    for content in msg.content {
225                        if let Content::Text { text } = content {
226                            parts.push(GooglePart::Text { text });
227                        }
228                    }
229                    system_instruction = Some(GoogleContent {
230                        role: "system".to_string(),
231                        parts,
232                    });
233                    continue;
234                }
235                Role::User => "user",
236                Role::Assistant => "model",
237                Role::Tool => "user",
238            };
239
240            let mut parts = Vec::new();
241            for content in msg.content {
242                match content {
243                    Content::Text { text } => {
244                        parts.push(GooglePart::Text { text });
245                    }
246                    Content::Image { source } => {
247                        let (mime_type, data) = match source {
248                            ImageSource::Base64 { media_type, data } => (media_type, data),
249                            _ => return Err(anyhow!("Unsupported image source for Google").into()),
250                        };
251                        parts.push(GooglePart::InlineData { mime_type, data });
252                    }
253                    Content::File { source } => {
254                        let FileSource::Base64 { media_type, data } = source;
255                        parts.push(GooglePart::InlineData {
256                            mime_type: media_type,
257                            data,
258                        });
259                    }
260                    Content::ToolCall {
261                        name, arguments, ..
262                    } => {
263                        parts.push(GooglePart::FunctionCall {
264                            name,
265                            args: arguments,
266                        });
267                    }
268                    Content::ToolResult { id, result } => {
269                        parts.push(GooglePart::FunctionResponse {
270                            name: id,
271                            response: result,
272                        });
273                    }
274                }
275            }
276
277            contents.push(GoogleContent {
278                role: role.to_string(),
279                parts,
280            });
281        }
282
283        let google_tools = if options.tools.as_ref().is_some_and(|t| !t.is_empty()) {
284            Some(vec![GoogleTool {
285                function_declarations: options
286                    .tools
287                    .as_ref()
288                    .unwrap()
289                    .iter()
290                    .map(|t| GoogleFunctionDeclaration {
291                        name: t.name.clone(),
292                        description: t.description.clone(),
293                        parameters: t.parameters.clone(),
294                    })
295                    .collect(),
296            }])
297        } else {
298            None
299        };
300
301        let mut response_mime_type = None;
302        let mut response_schema = None;
303        if let Some(format) = &options.response_format {
304            if format.get("type").and_then(|t| t.as_str()) == Some("json_schema") {
305                response_mime_type = Some("application/json".to_string());
306                if let Some(schema) = format.get("json_schema").and_then(|s| s.get("schema")) {
307                    response_schema = Some(schema.clone());
308                }
309            } else if format.get("type").and_then(|t| t.as_str()) == Some("json_object") {
310                response_mime_type = Some("application/json".to_string());
311            }
312        }
313
314        Ok(GoogleRequest {
315            contents,
316            system_instruction,
317            generation_config: Some(GoogleGenerationConfig {
318                max_output_tokens: options.max_tokens,
319                temperature: options.temperature,
320                top_p: options.top_p,
321                top_k: None,
322                stop_sequences: options.stop_sequences.clone(),
323                response_mime_type,
324                response_schema,
325            }),
326            tools: google_tools,
327        })
328    }
329}
330
331// --- Provider Factory ---
332
333use crate::core::types::ProviderSettings;
334
/// Google provider with configurable settings.
///
/// Construct via [`create_google`]; hands out chat, embedding, and image
/// models that share the same API key / base-URL configuration.
pub struct GoogleProvider {
    // API key may be absent here; model constructors fall back to the
    // GOOGLE_GENERATIVE_AI_API_KEY environment variable at creation time.
    settings: ProviderSettings,
}
339
340impl GoogleProvider {
341    /// Creates a chat language model.
342    #[must_use]
343    pub fn chat(&self, _model_id: &str) -> GoogleModel {
344        let api_key = self
345            .settings
346            .api_key
347            .clone()
348            .or_else(|| std::env::var("GOOGLE_GENERATIVE_AI_API_KEY").ok())
349            .unwrap_or_default();
350        let mut model = GoogleModel::new(api_key);
351        if let Some(ref base_url) = self.settings.base_url {
352            model.base_url = base_url.clone();
353        }
354        model
355    }
356
357    /// Alias for `chat`.
358    #[must_use]
359    pub fn language_model(&self, model_id: &str) -> GoogleModel {
360        self.chat(model_id)
361    }
362
363    /// Creates an embedding model.
364    #[must_use]
365    pub fn embedding(&self, _model_id: &str) -> embedding::GoogleEmbeddingModel {
366        let api_key = self
367            .settings
368            .api_key
369            .clone()
370            .or_else(|| std::env::var("GOOGLE_GENERATIVE_AI_API_KEY").ok())
371            .unwrap_or_default();
372        let mut model = embedding::GoogleEmbeddingModel::new(api_key);
373        if let Some(ref base_url) = self.settings.base_url {
374            model.base_url = base_url.clone();
375        }
376        model
377    }
378
379    /// Creates an image generation model.
380    #[must_use]
381    pub fn image(&self, _model_id: &str) -> image::GoogleImageModel {
382        let api_key = self
383            .settings
384            .api_key
385            .clone()
386            .or_else(|| std::env::var("GOOGLE_GENERATIVE_AI_API_KEY").ok())
387            .unwrap_or_default();
388        let mut model = image::GoogleImageModel::new(api_key);
389        if let Some(ref base_url) = self.settings.base_url {
390            model.base_url = base_url.clone();
391        }
392        model
393    }
394}
395
/// Create a Google provider instance with the given settings.
///
/// If `settings.api_key` is `None`, models created from this provider fall
/// back to the `GOOGLE_GENERATIVE_AI_API_KEY` environment variable.
#[must_use]
pub fn create_google(settings: ProviderSettings) -> GoogleProvider {
    GoogleProvider { settings }
}
401
402impl crate::core::registry::Provider for GoogleProvider {
403    fn language_model(&self, model_id: &str) -> Option<Box<dyn crate::core::LanguageModel>> {
404        Some(Box::new(self.chat(model_id)))
405    }
406
407    fn embedding_model(&self, model_id: &str) -> Option<Box<dyn crate::core::EmbeddingModel>> {
408        Some(Box::new(self.embedding(model_id)))
409    }
410
411    fn image_model(&self, model_id: &str) -> Option<Box<dyn crate::core::ImageModel>> {
412        Some(Box::new(self.image(model_id)))
413    }
414}