// ai_rs/gemini/client.rs

1use crate::gemini::types::{
2    Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Part,
3    StreamGenerateContentResponse,
4};
5use futures_util::{Stream, StreamExt};
6use log::{debug, error, info, warn};
7use reqwest::Client;
8use serde::de::Error as SerdeError;
9use serde_json::json;
10use std::fmt;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15
/// Errors that can occur while talking to the Gemini API.
///
/// Distinguishes transport failures, HTTP-level failures, JSON parsing
/// failures, and error payloads returned inside an otherwise-successful
/// response body.
#[derive(Debug)]
pub enum GeminiClientError {
    /// The HTTP request completed with a non-success status; carries the raw response body.
    RequestError(String),
    /// Transport-level failure reported by `reqwest` (connect, timeout, ...).
    NetworkError(reqwest::Error),
    /// Failed to (de)serialize a JSON payload via `serde_json`.
    ParseError(serde_json::Error),
    /// A 2xx response whose JSON body nevertheless contains an `error` object.
    ApiError(String),
}
28
29impl fmt::Display for GeminiClientError {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self {
32            GeminiClientError::RequestError(msg) => write!(f, "Request error: {}", msg),
33            GeminiClientError::NetworkError(err) => write!(f, "Network error: {}", err),
34            GeminiClientError::ParseError(err) => write!(f, "Parse error: {}", err),
35            GeminiClientError::ApiError(msg) => write!(f, "API error: {}", msg),
36        }
37    }
38}
39
// Marker impl: no `source()` override (defaults to `None`), but lets the type
// be used behind `Box<dyn std::error::Error>` and with error-trait tooling.
impl std::error::Error for GeminiClientError {}
41
42impl From<reqwest::Error> for GeminiClientError {
43    fn from(err: reqwest::Error) -> Self {
44        GeminiClientError::NetworkError(err)
45    }
46}
47
48impl From<serde_json::Error> for GeminiClientError {
49    fn from(err: serde_json::Error) -> Self {
50        GeminiClientError::ParseError(err)
51    }
52}
53
/// Client for interacting with the Gemini API
#[derive(Debug)]
pub struct GeminiClient {
    // Sent as the `x-goog-api-key` header on every request.
    api_key: String,
    // Model identifier interpolated into the request path (e.g. "gemini-1.5-pro").
    model: String,
    // API root; set to "https://generativelanguage.googleapis.com/v1beta" by `new`.
    base_url: String,
    // Reused reqwest HTTP client for all requests.
    client: Client,
}
62
63impl GeminiClient {
64    /// Creates a new instance of `GeminiClient`
65    ///
66    /// # Arguments
67    ///
68    /// * `api_key` - The API key for Google AI Studio
69    /// * `model` - The model to use (e.g., "gemini-1.5-pro")
70    ///
71    /// # Returns
72    ///
73    /// A new `GeminiClient` instance
74    pub fn new(api_key: &str, model: &str) -> Self {
75        info!("Creating new GeminiClient with model: {}", model);
76        GeminiClient {
77            api_key: api_key.to_string(),
78            model: model.to_string(),
79            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
80            client: Client::new(),
81        }
82    }
83
    /// Legacy method for backward compatibility
    ///
    /// Equivalent to `GeminiClient::new(api_key, "gemini-1.5-pro")`.
    pub fn setup(api_key: &str) -> Self {
        Self::new(api_key, "gemini-1.5-pro")
    }
88
89    /// Sets the model to use
90    pub fn model(mut self, model: &str) -> Self {
91        info!("Setting model to {}", model);
92        self.model = model.to_string();
93        self
94    }
95
96    /// Generates content based on a text prompt
97    ///
98    /// # Arguments
99    ///
100    /// * `prompt` - The text prompt to generate content for
101    ///
102    /// # Returns
103    ///
104    /// A `Result` containing the `GenerateContentResponse` or a `GeminiClientError`
105    pub async fn generate_content(
106        &self,
107        prompt: &str,
108    ) -> Result<GenerateContentResponse, GeminiClientError> {
109        let request = GenerateContentRequest {
110            contents: vec![Content {
111                role: "user".to_string(),
112                parts: vec![Part {
113                    text: Some(prompt.to_string()),
114                    inline_data: None,
115                }],
116            }],
117            generation_config: None,
118            safety_settings: None,
119            tools: None,
120        };
121
122        self.generate_content_with_request(request).await
123    }
124
125    /// Generates content based on a structured request
126    ///
127    /// # Arguments
128    ///
129    /// * `request` - The `GenerateContentRequest` containing the content and configuration
130    ///
131    /// # Returns
132    ///
133    /// A `Result` containing the `GenerateContentResponse` or a `GeminiClientError`
134    pub async fn generate_content_with_request(
135        &self,
136        request: GenerateContentRequest,
137    ) -> Result<GenerateContentResponse, GeminiClientError> {
138        let url = format!("{}/models/{}:generateContent", self.base_url, self.model);
139        info!("Generating content with URL: {}", url);
140        debug!("GenerateContentRequest: {:?}", request);
141
142        let response = self
143            .client
144            .post(&url)
145            .header("x-goog-api-key", &self.api_key)
146            .json(&request)
147            .send()
148            .await?;
149
150        if response.status().is_success() {
151            let response_json: serde_json::Value = response.json().await?;
152            debug!("Response JSON: {:?}", response_json);
153
154            // Check for API errors in the response
155            if let Some(error) = response_json.get("error") {
156                let error_message = error.to_string();
157                error!("Gemini API error: {}", error_message);
158                return Err(GeminiClientError::ApiError(error_message));
159            }
160
161            let generate_response: GenerateContentResponse = serde_json::from_value(response_json)?;
162            info!("Successfully generated content.");
163            debug!("GenerateContentResponse: {:?}", generate_response);
164            Ok(generate_response)
165        } else {
166            let error_message = response.text().await?;
167            error!("Failed to generate content: {}", error_message);
168            Err(GeminiClientError::RequestError(error_message))
169        }
170    }
171
172    /// Streams content generation based on a text prompt
173    ///
174    /// # Arguments
175    ///
176    /// * `prompt` - The text prompt to generate content for
177    ///
178    /// # Returns
179    ///
180    /// A `Result` containing a Stream of `StreamGenerateContentResponse` chunks or a `GeminiClientError`
181    pub async fn stream_content(
182        &self,
183        prompt: &str,
184    ) -> Result<
185        impl Stream<Item = Result<StreamGenerateContentResponse, GeminiClientError>>,
186        GeminiClientError,
187    > {
188        let request = GenerateContentRequest {
189            contents: vec![Content {
190                role: "user".to_string(),
191                parts: vec![Part {
192                    text: Some(prompt.to_string()),
193                    inline_data: None,
194                }],
195            }],
196            generation_config: None,
197            safety_settings: None,
198            tools: None,
199        };
200
201        self.stream_content_with_request(request).await
202    }
203
204    /// Streams content generation based on a structured request
205    ///
206    /// # Arguments
207    ///
208    /// * `request` - The `GenerateContentRequest` containing the content and configuration
209    ///
210    /// # Returns
211    ///
212    /// A `Result` containing a Stream of `StreamGenerateContentResponse` chunks or a `GeminiClientError`
213    pub async fn stream_content_with_request(
214        &self,
215        request: GenerateContentRequest,
216    ) -> Result<
217        impl Stream<Item = Result<StreamGenerateContentResponse, GeminiClientError>>,
218        GeminiClientError,
219    > {
220        let url = format!(
221            "{}/models/{}:streamGenerateContent",
222            self.base_url, self.model
223        );
224        info!("Streaming content with URL: {}", url);
225        debug!("StreamRequest: {:?}", request);
226
227        let response = self
228            .client
229            .post(&url)
230            .header("x-goog-api-key", &self.api_key)
231            .json(&request)
232            .send()
233            .await?;
234
235        if response.status().is_success() {
236            let (tx, rx) = mpsc::channel(100);
237            let stream = response.bytes_stream();
238
239            tokio::spawn(async move {
240                let mut stream = stream;
241                while let Some(chunk) = stream.next().await {
242                    match chunk {
243                        Ok(bytes) => {
244                            let chunk_str = String::from_utf8_lossy(&bytes);
245                            debug!("Received chunk: {}", chunk_str);
246
247                            // Split by newlines and process each JSON object
248                            for line in chunk_str.lines() {
249                                if line.trim().is_empty() {
250                                    continue;
251                                }
252
253                                // Remove "data: " prefix if present
254                                let json_str = if line.starts_with("data: ") {
255                                    &line[6..]
256                                } else {
257                                    line
258                                };
259
260                                if json_str.trim() == "[DONE]" {
261                                    break;
262                                }
263
264                                match serde_json::from_str::<StreamGenerateContentResponse>(
265                                    json_str,
266                                ) {
267                                    Ok(stream_response) => {
268                                        if let Err(e) = tx.send(Ok(stream_response)).await {
269                                            error!("Failed to send stream response: {}", e);
270                                            break;
271                                        }
272                                    }
273                                    Err(e) => {
274                                        error!("Failed to parse stream response: {}", e);
275                                        if let Err(e) =
276                                            tx.send(Err(GeminiClientError::ParseError(e))).await
277                                        {
278                                            error!("Failed to send error: {}", e);
279                                            break;
280                                        }
281                                    }
282                                }
283                            }
284                        }
285                        Err(e) => {
286                            error!("Stream error: {}", e);
287                            if let Err(e) = tx.send(Err(GeminiClientError::NetworkError(e))).await {
288                                error!("Failed to send network error: {}", e);
289                            }
290                            break;
291                        }
292                    }
293                }
294            });
295
296            Ok(ReceiverStream::new(rx))
297        } else {
298            let error_message = response.text().await?;
299            error!("Failed to start streaming: {}", error_message);
300            Err(GeminiClientError::RequestError(error_message))
301        }
302    }
303
304    /// Generates content with specific generation configuration
305    ///
306    /// # Arguments
307    ///
308    /// * `prompt` - The text prompt to generate content for
309    /// * `config` - The generation configuration
310    ///
311    /// # Returns
312    ///
313    /// A `Result` containing the `GenerateContentResponse` or a `GeminiClientError`
314    pub async fn generate_content_with_config(
315        &self,
316        prompt: &str,
317        config: GenerationConfig,
318    ) -> Result<GenerateContentResponse, GeminiClientError> {
319        let request = GenerateContentRequest {
320            contents: vec![Content {
321                role: "user".to_string(),
322                parts: vec![Part {
323                    text: Some(prompt.to_string()),
324                    inline_data: None,
325                }],
326            }],
327            generation_config: Some(config),
328            safety_settings: None,
329            tools: None,
330        };
331
332        self.generate_content_with_request(request).await
333    }
334
335    /// Simple text generation method for backward compatibility
336    pub fn generate_content_sync(&self, prompt: &str) -> String {
337        // This is a blocking wrapper around the async method
338        // Note: This is not ideal for production use, but maintains backward compatibility
339        let rt = tokio::runtime::Runtime::new().unwrap();
340        match rt.block_on(self.generate_content(prompt)) {
341            Ok(response) => response
342                .get_text()
343                .unwrap_or_else(|| "No response generated".to_string()),
344            Err(e) => format!("Error: {}", e),
345        }
346    }
347}