// gemini_rust/client.rs

use crate::{
    batch_builder::BatchBuilder,
    content_builder::ContentBuilder,
    embed_builder::EmbedBuilder,
    models::{
        BatchContentEmbeddingResponse, BatchEmbedContentsRequest, BatchGenerateContentRequest,
        BatchGenerateContentResponse, ContentEmbeddingResponse, EmbedContentRequest,
        GenerateContentRequest, GenerationResponse,
    },
    Error, Result,
};
use futures::stream::Stream;
use reqwest::{Client, Response};
use serde_json::Value;
use std::{pin::Pin, sync::Arc};
use url::Url;

/// Model resource path used whenever the caller does not pick one explicitly.
const DEFAULT_MODEL: &str = "models/gemini-2.5-flash";
/// Root of the Gemini REST API; the model path and endpoint are appended to it.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";

21/// Internal client for making requests to the Gemini API
22pub(crate) struct GeminiClient {
23    http_client: Client,
24    api_key: String,
25    pub model: String,
26    base_url: String,
27}
28
29impl GeminiClient {
30    /// Create a new client
31    #[allow(dead_code)]
32    fn new(api_key: impl Into<String>, model: String) -> Self {
33        Self::with_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
34    }
35
36    /// Create a new client with custom base URL
37    fn with_base_url(api_key: impl Into<String>, model: String, base_url: String) -> Self {
38        Self {
39            http_client: Client::new(),
40            api_key: api_key.into(),
41            model,
42            base_url,
43        }
44    }
45
46    /// Generate content
47    pub(crate) async fn generate_content_raw(
48        &self,
49        request: GenerateContentRequest,
50    ) -> Result<GenerationResponse> {
51        let url = self.build_url("generateContent")?;
52
53        let response = self.http_client.post(url).json(&request).send().await?;
54
55        let status = response.status();
56        if !status.is_success() {
57            let error_text = response.text().await?;
58            return Err(Error::ApiError {
59                status_code: status.as_u16(),
60                message: error_text,
61            });
62        }
63
64        let response = response.json().await?;
65
66        Ok(response)
67    }
68
69    /// Generate content with streaming
70    pub(crate) async fn generate_content_stream(
71        &self,
72        request: GenerateContentRequest,
73    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
74        let url = self.build_url("streamGenerateContent")?;
75
76        let response = self.http_client.post(url).json(&request).send().await?;
77
78        let status = response.status();
79        if !status.is_success() {
80            let error_text = response.text().await?;
81            return Err(Error::ApiError {
82                status_code: status.as_u16(),
83                message: error_text,
84            });
85        }
86
87        // Get the full response as bytes and parse as JSON array
88        let bytes = response.bytes().await?;
89        let text = String::from_utf8_lossy(&bytes);
90
91        // The Gemini API returns a JSON array format like: [{json1}, {json2}, {json3}]
92        let responses: Vec<Result<GenerationResponse>> =
93            match serde_json::from_str::<Vec<GenerationResponse>>(&text) {
94                Ok(json_array) => json_array.into_iter().map(Ok).collect(),
95                Err(e) => {
96                    vec![Err(Error::JsonError(e))]
97                }
98            };
99
100        let stream = futures::stream::iter(responses);
101        Ok(Box::pin(stream))
102    }
103
104    /// Embed content
105    pub(crate) async fn embed_content(
106        &self,
107        request: EmbedContentRequest,
108    ) -> Result<ContentEmbeddingResponse> {
109        let value = self.post_json(request, "embedContent").await?;
110        let response = serde_json::from_value::<ContentEmbeddingResponse>(value)?;
111
112        Ok(response)
113    }
114
115    /// Batch Embed content
116    pub(crate) async fn embed_content_batch(
117        &self,
118        request: BatchEmbedContentsRequest,
119    ) -> Result<BatchContentEmbeddingResponse> {
120        let value = self.post_json(request, "batchEmbedContents").await?;
121        let response = serde_json::from_value::<BatchContentEmbeddingResponse>(value)?;
122
123        Ok(response)
124    }
125
126    /// Synchronous Batch Generate content
127    pub(crate) async fn batch_generate_content_sync(
128        &self,
129        request: BatchGenerateContentRequest,
130    ) -> Result<BatchGenerateContentResponse> {
131        let value = self.post_json(request, "batchGenerateContent").await?;
132        let response = serde_json::from_value::<BatchGenerateContentResponse>(value)?;
133        Ok(response)
134    }
135
136    /// Post JSON to an endpoint
137    async fn post_json<T: serde::Serialize>(&self, request: T, endpoint: &str) -> Result<Value> {
138        let url = self.build_url(endpoint)?;
139
140        let response = self.http_client.post(url).json(&request).send().await?;
141
142        let status = response.status();
143        if !status.is_success() {
144            let error_text = response.text().await?;
145            return Err(Error::ApiError {
146                status_code: status.as_u16(),
147                message: error_text,
148            });
149        }
150
151        let response = response.json().await?;
152        Ok(response)
153    }
154
155    /// Build a URL for the API
156    fn build_url(&self, endpoint: &str) -> Result<Url> {
157        let url_str = format!(
158            "{}{}:{}?key={}",
159            self.base_url, self.model, endpoint, self.api_key
160        );
161        Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
162    }
163}
164
165/// Client for the Gemini API
166#[derive(Clone)]
167pub struct Gemini {
168    client: Arc<GeminiClient>,
169}
170
171impl Gemini {
172    /// Create a new client with the specified API key
173    pub fn new(api_key: impl Into<String>) -> Self {
174        Self::with_model(api_key, DEFAULT_MODEL.to_string())
175    }
176
177    /// Create a new client for the Gemini Pro model
178    pub fn pro(api_key: impl Into<String>) -> Self {
179        Self::with_model(api_key, "models/gemini-2.5-pro".to_string())
180    }
181
182    /// Create a new client with the specified API key and model
183    pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
184        Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
185    }
186
187    /// Create a new client with custom base URL
188    pub fn with_base_url(api_key: impl Into<String>, base_url: String) -> Self {
189        Self::with_model_and_base_url(api_key, DEFAULT_MODEL.to_string(), base_url)
190    }
191
192    /// Create a new client with the specified API key, model, and base URL
193    pub fn with_model_and_base_url(
194        api_key: impl Into<String>,
195        model: String,
196        base_url: String,
197    ) -> Self {
198        let client = GeminiClient::with_base_url(api_key, model, base_url);
199        Self {
200            client: Arc::new(client),
201        }
202    }
203
204    /// Start building a content generation request
205    pub fn generate_content(&self) -> ContentBuilder {
206        ContentBuilder::new(self.client.clone())
207    }
208
209    /// Start building a content generation request
210    pub fn embed_content(&self) -> EmbedBuilder {
211        EmbedBuilder::new(self.client.clone())
212    }
213
214    /// Start building a synchronous batch content generation request
215    pub fn batch_generate_content_sync(&self) -> BatchBuilder {
216        BatchBuilder::new(self.client.clone())
217    }
218}