gemini_rust/
client.rs

1use crate::{
2    content_builder::ContentBuilder,
3    embed_builder::EmbedBuilder,
4    models::{
5        BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6        EmbedContentRequest, GenerateContentRequest, GenerationResponse,
7    },
8    Error, Result,
9};
10use futures::stream::Stream;
11use reqwest::Client;
12use serde_json::Value;
13use std::pin::Pin;
14use std::sync::Arc;
15use url::Url;
16
/// Model resource name used by constructors that do not take an explicit model.
const DEFAULT_MODEL: &str = "models/gemini-2.5-flash";
/// Root of the public Gemini REST API (v1beta); request paths are appended to it.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
19
20/// Internal client for making requests to the Gemini API
21pub(crate) struct GeminiClient {
22    http_client: Client,
23    api_key: String,
24    pub model: String,
25    base_url: String,
26}
27
28impl GeminiClient {
29    /// Create a new client
30    #[allow(dead_code)]
31    fn new(api_key: impl Into<String>, model: String) -> Self {
32        Self::with_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
33    }
34
35    /// Create a new client with custom base URL
36    fn with_base_url(api_key: impl Into<String>, model: String, base_url: String) -> Self {
37        Self {
38            http_client: Client::new(),
39            api_key: api_key.into(),
40            model,
41            base_url,
42        }
43    }
44
45    /// Generate content
46    pub(crate) async fn generate_content_raw(
47        &self,
48        request: GenerateContentRequest,
49    ) -> Result<GenerationResponse> {
50        let url = self.build_url("generateContent")?;
51
52        let response = self.http_client.post(url).json(&request).send().await?;
53
54        let status = response.status();
55        if !status.is_success() {
56            let error_text = response.text().await?;
57            return Err(Error::ApiError {
58                status_code: status.as_u16(),
59                message: error_text,
60            });
61        }
62
63        let response = response.json().await?;
64
65        Ok(response)
66    }
67
68    /// Generate content with streaming
69    pub(crate) async fn generate_content_stream(
70        &self,
71        request: GenerateContentRequest,
72    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
73        let url = self.build_url("streamGenerateContent")?;
74
75        let response = self.http_client.post(url).json(&request).send().await?;
76
77        let status = response.status();
78        if !status.is_success() {
79            let error_text = response.text().await?;
80            return Err(Error::ApiError {
81                status_code: status.as_u16(),
82                message: error_text,
83            });
84        }
85
86        // Get the full response as bytes and parse as JSON array
87        let bytes = response.bytes().await?;
88        let text = String::from_utf8_lossy(&bytes);
89
90        // The Gemini API returns a JSON array format like: [{json1}, {json2}, {json3}]
91        let responses: Vec<Result<GenerationResponse>> =
92            match serde_json::from_str::<Vec<GenerationResponse>>(&text) {
93                Ok(json_array) => json_array.into_iter().map(Ok).collect(),
94                Err(e) => {
95                    vec![Err(Error::JsonError(e))]
96                }
97            };
98
99        let stream = futures::stream::iter(responses);
100        Ok(Box::pin(stream))
101    }
102
103    /// Embed content
104    pub(crate) async fn embed_content(
105        &self,
106        request: EmbedContentRequest,
107    ) -> Result<ContentEmbeddingResponse> {
108        let value = self.embed(request, "embedContent").await?;
109        let response = serde_json::from_value::<ContentEmbeddingResponse>(value)?;
110
111        Ok(response)
112    }
113
114    /// Batch Embed content
115    pub(crate) async fn embed_content_batch(
116        &self,
117        request: BatchEmbedContentsRequest,
118    ) -> Result<BatchContentEmbeddingResponse> {
119        let value = self.embed(request, "batchEmbedContents").await?;
120        let response = serde_json::from_value::<BatchContentEmbeddingResponse>(value)?;
121
122        Ok(response)
123    }
124
125    /// Embed content base function
126    async fn embed<T: serde::Serialize>(&self, request: T, endpoint: &str) -> Result<Value> {
127        let url = self.build_url(endpoint)?;
128
129        let response = self.http_client.post(url).json(&request).send().await?;
130
131        let status = response.status();
132        if !status.is_success() {
133            let error_text = response.text().await?;
134            return Err(Error::ApiError {
135                status_code: status.as_u16(),
136                message: error_text,
137            });
138        }
139
140        let response = response.json().await?;
141        Ok(response)
142    }
143
144    /// Build a URL for the API
145    fn build_url(&self, endpoint: &str) -> Result<Url> {
146        // All Gemini API endpoints now use the format with colon:
147        // "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key=$API_KEY"
148        let url_str = format!(
149            "{}{}:{}?key={}",
150            self.base_url, self.model, endpoint, self.api_key
151        );
152        Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
153    }
154}
155
156/// Client for the Gemini API
157#[derive(Clone)]
158pub struct Gemini {
159    client: Arc<GeminiClient>,
160}
161
162impl Gemini {
163    /// Create a new client with the specified API key
164    pub fn new(api_key: impl Into<String>) -> Self {
165        Self::with_model(api_key, DEFAULT_MODEL.to_string())
166    }
167
168    /// Create a new client for the Gemini Pro model
169    pub fn pro(api_key: impl Into<String>) -> Self {
170        Self::with_model(api_key, "models/gemini-2.5-pro".to_string())
171    }
172
173    /// Create a new client with the specified API key and model
174    pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
175        Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
176    }
177
178    /// Create a new client with custom base URL
179    pub fn with_base_url(api_key: impl Into<String>, base_url: String) -> Self {
180        Self::with_model_and_base_url(api_key, DEFAULT_MODEL.to_string(), base_url)
181    }
182
183    /// Create a new client with the specified API key, model, and base URL
184    pub fn with_model_and_base_url(
185        api_key: impl Into<String>,
186        model: String,
187        base_url: String,
188    ) -> Self {
189        let client = GeminiClient::with_base_url(api_key, model, base_url);
190        Self {
191            client: Arc::new(client),
192        }
193    }
194
195    /// Start building a content generation request
196    pub fn generate_content(&self) -> ContentBuilder {
197        ContentBuilder::new(self.client.clone())
198    }
199
200    /// Start building a content generation request
201    pub fn embed_content(&self) -> EmbedBuilder {
202        EmbedBuilder::new(self.client.clone())
203    }
204}