gemini_rust/
client.rs

1use crate::{
2    content_builder::ContentBuilder,
3    embed_builder::EmbedBuilder,
4    models::{
5        BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6        EmbedContentRequest, GenerateContentRequest, GenerationResponse,
7    },
8    Error, Result,
9};
10use futures::stream::Stream;
11use futures_util::StreamExt;
12use reqwest::Client;
13use serde_json::Value;
14use std::pin::Pin;
15use std::sync::Arc;
16use url::Url;
17
/// Base URL of the Gemini REST API (`v1beta`) used when the caller does not supply one.
/// `build_url` appends the model path directly to this string, so it ends with `/`.
18const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
/// Model path used by `Gemini::new` and `Gemini::with_base_url` when no model is specified.
19const DEFAULT_MODEL: &str = "models/gemini-2.0-flash";
20
21/// Internal client for making requests to the Gemini API
22pub(crate) struct GeminiClient {
23    http_client: Client,
24    api_key: String,
25    pub model: String,
26    base_url: String,
27}
28
29impl GeminiClient {
30    /// Create a new client
31    #[allow(dead_code)]
32    fn new(api_key: impl Into<String>, model: String) -> Self {
33        Self::with_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
34    }
35
36    /// Create a new client with custom base URL
37    fn with_base_url(api_key: impl Into<String>, model: String, base_url: String) -> Self {
38        Self {
39            http_client: Client::new(),
40            api_key: api_key.into(),
41            model,
42            base_url,
43        }
44    }
45
46    /// Generate content
47    pub(crate) async fn generate_content_raw(
48        &self,
49        request: GenerateContentRequest,
50    ) -> Result<GenerationResponse> {
51        let url = self.build_url("generateContent")?;
52
53        let response = self.http_client.post(url).json(&request).send().await?;
54
55        let status = response.status();
56        if !status.is_success() {
57            let error_text = response.text().await?;
58            return Err(Error::ApiError {
59                status_code: status.as_u16(),
60                message: error_text,
61            });
62        }
63
64        let response = response.json().await?;
65
66        Ok(response)
67    }
68
69    /// Generate content with streaming
70    pub(crate) async fn generate_content_stream(
71        &self,
72        request: GenerateContentRequest,
73    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
74        let url = self.build_url("streamGenerateContent")?;
75
76        let response = self.http_client.post(url).json(&request).send().await?;
77
78        let status = response.status();
79        if !status.is_success() {
80            let error_text = response.text().await?;
81            return Err(Error::ApiError {
82                status_code: status.as_u16(),
83                message: error_text,
84            });
85        }
86
87        let stream = response
88            .bytes_stream()
89            .map(|result| {
90                match result {
91                    Ok(bytes) => {
92                        let text = String::from_utf8_lossy(&bytes);
93                        // The stream returns each chunk as a separate JSON object
94                        // Each line that starts with "data: " contains a JSON object
95                        let mut responses = Vec::new();
96                        for line in text.lines() {
97                            if let Some(json_str) = line.strip_prefix("data: ") {
98                                if json_str == "[DONE]" {
99                                    continue;
100                                }
101                                match serde_json::from_str::<GenerationResponse>(json_str) {
102                                    Ok(response) => responses.push(Ok(response)),
103                                    Err(e) => responses.push(Err(Error::JsonError(e))),
104                                }
105                            }
106                        }
107                        futures::stream::iter(responses)
108                    }
109                    Err(e) => futures::stream::iter(vec![Err(Error::HttpError(e))]),
110                }
111            })
112            .flatten();
113
114        Ok(Box::pin(stream))
115    }
116
117    /// Embed content
118    pub(crate) async fn embed_content(
119        &self,
120        request: EmbedContentRequest,
121    ) -> Result<ContentEmbeddingResponse> {
122        let value = self.embed(request, "embedContent").await?;
123        let response = serde_json::from_value::<ContentEmbeddingResponse>(value)?;
124
125        Ok(response)
126    }
127
128    /// Batch Embed content
129    pub(crate) async fn embed_content_batch(
130        &self,
131        request: BatchEmbedContentsRequest,
132    ) -> Result<BatchContentEmbeddingResponse> {
133        let value = self.embed(request, "batchEmbedContents").await?;
134        let response = serde_json::from_value::<BatchContentEmbeddingResponse>(value)?;
135
136        Ok(response)
137    }
138
139    /// Embed content base function
140    async fn embed<T: serde::Serialize>(&self, request: T, endpoint: &str) -> Result<Value> {
141        let url = self.build_url(endpoint)?;
142
143        let response = self.http_client.post(url).json(&request).send().await?;
144
145        let status = response.status();
146        if !status.is_success() {
147            let error_text = response.text().await?;
148            return Err(Error::ApiError {
149                status_code: status.as_u16(),
150                message: error_text,
151            });
152        }
153
154        let response = response.json().await?;
155        Ok(response)
156    }
157
158    /// Build a URL for the API
159    fn build_url(&self, endpoint: &str) -> Result<Url> {
160        // All Gemini API endpoints now use the format with colon:
161        // "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key=$API_KEY"
162        let url_str = format!(
163            "{}{}:{}?key={}",
164            self.base_url, self.model, endpoint, self.api_key
165        );
166        Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
167    }
168}
169
170/// Client for the Gemini API
171#[derive(Clone)]
172pub struct Gemini {
173    client: Arc<GeminiClient>,
174}
175
176impl Gemini {
177    /// Create a new client with the specified API key
178    pub fn new(api_key: impl Into<String>) -> Self {
179        Self::with_model(api_key, DEFAULT_MODEL.to_string())
180    }
181
182    /// Create a new client for the Gemini Pro model
183    pub fn pro(api_key: impl Into<String>) -> Self {
184        Self::with_model(api_key, "models/gemini-2.0-pro-exp-02-05".to_string())
185    }
186
187    /// Create a new client with the specified API key and model
188    pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
189        Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
190    }
191
192    /// Create a new client with custom base URL
193    pub fn with_base_url(api_key: impl Into<String>, base_url: String) -> Self {
194        Self::with_model_and_base_url(api_key, DEFAULT_MODEL.to_string(), base_url)
195    }
196
197    /// Create a new client with the specified API key, model, and base URL
198    pub fn with_model_and_base_url(
199        api_key: impl Into<String>,
200        model: String,
201        base_url: String,
202    ) -> Self {
203        let client = GeminiClient::with_base_url(api_key, model, base_url);
204        Self {
205            client: Arc::new(client),
206        }
207    }
208
209    /// Start building a content generation request
210    pub fn generate_content(&self) -> ContentBuilder {
211        ContentBuilder::new(self.client.clone())
212    }
213
214    /// Start building a content generation request
215    pub fn embed_content(&self) -> EmbedBuilder {
216        EmbedBuilder::new(self.client.clone())
217    }
218}