1use super::{
2    completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
6    VerifyClient, VerifyError, impl_conversion_traits,
7};
8use crate::http_client::{self, HttpClientExt};
9use crate::wasm_compat::*;
10use crate::{
11    Embed,
12    embeddings::{self},
13};
14use bytes::Bytes;
15use serde::Deserialize;
16use std::fmt::Debug;
17
/// Default root URL of the Gemini (Generative Language) REST API.
///
/// Request paths such as `/v1beta/models` are joined onto this value; it can be
/// overridden per client with [`ClientBuilder::base_url`].
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
22
23pub struct ClientBuilder<'a, T = reqwest::Client> {
24    api_key: &'a str,
25    base_url: &'a str,
26    http_client: T,
27}
28
29impl<'a, T> ClientBuilder<'a, T>
30where
31    T: HttpClientExt,
32{
33    pub fn new(api_key: &'a str) -> ClientBuilder<'a, reqwest::Client> {
34        ClientBuilder {
35            api_key,
36            base_url: GEMINI_API_BASE_URL,
37            http_client: Default::default(),
38        }
39    }
40
41    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
42        Self {
43            api_key,
44            base_url: GEMINI_API_BASE_URL,
45            http_client,
46        }
47    }
48
49    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U>
50    where
51        U: HttpClientExt,
52    {
53        ClientBuilder {
54            api_key: self.api_key,
55            base_url: self.base_url,
56            http_client,
57        }
58    }
59
60    pub fn base_url(mut self, base_url: &'a str) -> Self {
61        self.base_url = base_url;
62        self
63    }
64
65    pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
66        let mut default_headers = reqwest::header::HeaderMap::new();
67        default_headers.insert(
68            reqwest::header::CONTENT_TYPE,
69            "application/json".parse().unwrap(),
70        );
71
72        Ok(Client {
73            base_url: self.base_url.to_string(),
74            api_key: self.api_key.to_string(),
75            default_headers,
76            http_client: self.http_client,
77        })
78    }
79}
80#[derive(Clone)]
81pub struct Client<T = reqwest::Client> {
82    base_url: String,
83    api_key: String,
84    default_headers: reqwest::header::HeaderMap,
85    http_client: T,
86}
87
88impl<T> Debug for Client<T>
89where
90    T: Debug,
91{
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("Client")
94            .field("base_url", &self.base_url)
95            .field("http_client", &self.http_client)
96            .field("default_headers", &self.default_headers)
97            .field("api_key", &"<REDACTED>")
98            .finish()
99    }
100}
101
102impl<T> Client<T>
103where
104    T: HttpClientExt + Default,
105{
106    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
117        ClientBuilder::new_with_client(api_key, Default::default())
118    }
119
120    pub fn new(api_key: &str) -> Self {
125        Self::builder(api_key)
126            .build()
127            .expect("Gemini client should build")
128    }
129}
130
131impl Client<reqwest::Client> {
132    pub(crate) fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
133        let url = format!(
134            "{}/{}?alt=sse&key={}",
135            self.base_url,
136            path.trim_start_matches('/'),
137            self.api_key
138        );
139
140        tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
141
142        self.http_client
143            .post(url)
144            .headers(self.default_headers.clone())
145    }
146}
147
148impl<T> Client<T>
149where
150    T: HttpClientExt,
151{
152    pub(crate) fn post(&self, path: &str) -> http_client::Builder {
153        let url = format!(
155            "{}/{}?key={}",
156            self.base_url,
157            path.trim_start_matches('/'),
158            self.api_key
159        );
160
161        tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
162        let mut req = http_client::Request::post(url);
163
164        if let Some(hs) = req.headers_mut() {
165            *hs = self.default_headers.clone();
166        }
167
168        req
169    }
170
171    pub(crate) fn get(&self, path: &str) -> http_client::Builder {
172        let url = format!(
174            "{}/{}?key={}",
175            self.base_url,
176            path.trim_start_matches('/'),
177            self.api_key
178        );
179
180        tracing::debug!("GET {}/{}?key={}", self.base_url, path, "****");
181
182        let mut req = http_client::Request::get(url);
183
184        if let Some(hs) = req.headers_mut() {
185            *hs = self.default_headers.clone();
186        }
187
188        req
189    }
190
191    pub(crate) async fn send<U, R>(
192        &self,
193        req: http_client::Request<U>,
194    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
195    where
196        U: Into<Bytes> + Send,
197        R: From<Bytes> + Send + 'static,
198    {
199        self.http_client.send(req).await
200    }
201}
202
203impl ProviderClient for Client<reqwest::Client> {
206    fn from_env() -> Self {
209        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
210        Self::new(&api_key)
211    }
212
213    fn from_val(input: crate::client::ProviderValue) -> Self {
214        let crate::client::ProviderValue::Simple(api_key) = input else {
215            panic!("Incorrect provider value type")
216        };
217        Self::new(&api_key)
218    }
219}
220
221impl CompletionClient for Client<reqwest::Client> {
222    type CompletionModel = CompletionModel<reqwest::Client>;
223
224    fn completion_model(&self, model: &str) -> Self::CompletionModel {
228        CompletionModel::new(self.clone(), model)
229    }
230}
231
232impl<T> EmbeddingsClient for Client<T>
233where
234    T: HttpClientExt + Clone + Debug + Default + 'static,
235    Client<T>: CompletionClient,
236{
237    type EmbeddingModel = EmbeddingModel<T>;
238
239    fn embedding_model(&self, model: &str) -> EmbeddingModel<T> {
253        EmbeddingModel::new(self.clone(), model, None)
254    }
255
256    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel<T> {
268        EmbeddingModel::new(self.clone(), model, Some(ndims))
269    }
270
271    fn embeddings<D: Embed>(
288        &self,
289        model: &str,
290    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel<T>, D> {
291        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
292    }
293}
294
295impl<T> TranscriptionClient for Client<T>
296where
297    T: HttpClientExt + Clone + Debug + Default + 'static,
298    Client<T>: CompletionClient,
299{
300    type TranscriptionModel = TranscriptionModel<T>;
301
302    fn transcription_model(&self, model: &str) -> TranscriptionModel<T> {
306        TranscriptionModel::new(self.clone(), model)
307    }
308}
309
310impl<T> VerifyClient for Client<T>
311where
312    T: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
313    Client<T>: CompletionClient,
314{
315    #[cfg_attr(feature = "worker", worker::send)]
316    async fn verify(&self) -> Result<(), VerifyError> {
317        let req = self
318            .get("/v1beta/models")
319            .body(http_client::NoBody)
320            .map_err(|e| VerifyError::HttpError(e.into()))?;
321        let response = self.http_client.send::<_, Vec<u8>>(req).await?;
322
323        match response.status() {
324            reqwest::StatusCode::OK => Ok(()),
325            reqwest::StatusCode::FORBIDDEN => Err(VerifyError::InvalidAuthentication),
326            reqwest::StatusCode::INTERNAL_SERVER_ERROR
327            | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
328                let text = http_client::text(response).await?;
329                Err(VerifyError::ProviderError(text))
330            }
331            _ => {
332                Ok(())
337            }
338        }
339    }
340}
341
// Generates the `AsImageGeneration` and `AsAudioGeneration` conversion-trait impls
// for `Client<T>` via the crate's `impl_conversion_traits!` macro.
// NOTE(review): this file defines no image- or audio-generation client, so these are
// presumably the macro's default "unsupported" impls; confirm against the macro definition.
impl_conversion_traits!(
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
346
/// Error payload deserialized from a failed API call; only `message` is captured.
///
/// NOTE(review): Gemini error bodies appear to wrap this in a top-level `error`
/// object; confirm callers unwrap that envelope before deserializing into this type.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Human-readable error description returned by the API.
    pub message: String,
}
351
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, variants are tried in declaration order: a body is
/// parsed as `Ok(T)` whenever it fits `T`, and `Err` is attempted only otherwise.
/// Keep `Ok` first. If `T` could also accept an error body (for example, all of its
/// fields are optional), errors would be misread as `Ok`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Body deserialized as the expected success type.
    Ok(T),
    /// Body deserialized as an error message.
    Err(ApiErrorResponse),
}