1use super::{
2 completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
6 VerifyClient, VerifyError, impl_conversion_traits,
7};
8use crate::http_client::{self, HttpClientExt};
9use crate::wasm_compat::*;
10use crate::{
11 Embed,
12 embeddings::{self},
13};
14use bytes::Bytes;
15use serde::Deserialize;
16use std::fmt::Debug;
17
/// Default base URL for the Gemini API.
///
/// Used by [`ClientBuilder`] unless overridden with [`ClientBuilder::base_url`].
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
22
/// Builder for a Gemini [`Client`].
///
/// The API key and base URL are borrowed for `'a`. They are copied into owned `String`s when
/// [`ClientBuilder::build`] is called. `T` is the HTTP client the built client sends requests
/// through; it defaults to `reqwest::Client`.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// API key; the built client appends it to request URLs as the `key` query parameter.
    api_key: &'a str,
    /// Base URL that request paths are joined onto (defaults to `GEMINI_API_BASE_URL`).
    base_url: &'a str,
    /// HTTP client handed to the built [`Client`].
    http_client: T,
}
28
29impl<'a, T> ClientBuilder<'a, T>
30where
31 T: HttpClientExt,
32{
33 pub fn new(api_key: &'a str) -> ClientBuilder<'a, reqwest::Client> {
34 ClientBuilder {
35 api_key,
36 base_url: GEMINI_API_BASE_URL,
37 http_client: Default::default(),
38 }
39 }
40
41 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
42 Self {
43 api_key,
44 base_url: GEMINI_API_BASE_URL,
45 http_client,
46 }
47 }
48
49 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U>
50 where
51 U: HttpClientExt,
52 {
53 ClientBuilder {
54 api_key: self.api_key,
55 base_url: self.base_url,
56 http_client,
57 }
58 }
59
60 pub fn base_url(mut self, base_url: &'a str) -> Self {
61 self.base_url = base_url;
62 self
63 }
64
65 pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
66 let mut default_headers = reqwest::header::HeaderMap::new();
67 default_headers.insert(
68 reqwest::header::CONTENT_TYPE,
69 "application/json".parse().unwrap(),
70 );
71
72 Ok(Client {
73 base_url: self.base_url.to_string(),
74 api_key: self.api_key.to_string(),
75 default_headers,
76 http_client: self.http_client,
77 })
78 }
79}
/// Gemini API client.
///
/// Cloning copies the base URL, API key, and default headers, and clones the underlying HTTP
/// client. The `Debug` implementation for this type prints a placeholder instead of the API key.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// Base URL that request paths are joined onto.
    base_url: String,
    /// API key, appended to request URLs as the `key` query parameter.
    api_key: String,
    /// Headers copied onto every outgoing request (set by [`ClientBuilder::build`]).
    default_headers: reqwest::header::HeaderMap,
    /// Underlying HTTP client used to send requests.
    http_client: T,
}
87
88impl<T> Debug for Client<T>
89where
90 T: Debug,
91{
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("Client")
94 .field("base_url", &self.base_url)
95 .field("http_client", &self.http_client)
96 .field("default_headers", &self.default_headers)
97 .field("api_key", &"<REDACTED>")
98 .finish()
99 }
100}
101
102impl<T> Client<T>
103where
104 T: HttpClientExt + Default,
105{
106 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
117 ClientBuilder::new_with_client(api_key, Default::default())
118 }
119
120 pub fn new(api_key: &str) -> Self {
125 Self::builder(api_key)
126 .build()
127 .expect("Gemini client should build")
128 }
129}
130
131impl Client<reqwest::Client> {
132 pub(crate) fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
133 let url = format!(
134 "{}/{}?alt=sse&key={}",
135 self.base_url,
136 path.trim_start_matches('/'),
137 self.api_key
138 );
139
140 tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
141
142 self.http_client
143 .post(url)
144 .headers(self.default_headers.clone())
145 }
146}
147
148impl<T> Client<T>
149where
150 T: HttpClientExt,
151{
152 pub(crate) fn post(&self, path: &str) -> http_client::Builder {
153 let url = format!(
155 "{}/{}?key={}",
156 self.base_url,
157 path.trim_start_matches('/'),
158 self.api_key
159 );
160
161 tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
162 let mut req = http_client::Request::post(url);
163
164 if let Some(hs) = req.headers_mut() {
165 *hs = self.default_headers.clone();
166 }
167
168 req
169 }
170
171 pub(crate) fn get(&self, path: &str) -> http_client::Builder {
172 let url = format!(
174 "{}/{}?key={}",
175 self.base_url,
176 path.trim_start_matches('/'),
177 self.api_key
178 );
179
180 tracing::debug!("GET {}/{}?key={}", self.base_url, path, "****");
181
182 let mut req = http_client::Request::get(url);
183
184 if let Some(hs) = req.headers_mut() {
185 *hs = self.default_headers.clone();
186 }
187
188 req
189 }
190
191 pub(crate) async fn send<U, R>(
192 &self,
193 req: http_client::Request<U>,
194 ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
195 where
196 U: Into<Bytes> + Send,
197 R: From<Bytes> + Send + 'static,
198 {
199 self.http_client.send(req).await
200 }
201}
202
203impl ProviderClient for Client<reqwest::Client> {
206 fn from_env() -> Self {
209 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
210 Self::new(&api_key)
211 }
212
213 fn from_val(input: crate::client::ProviderValue) -> Self {
214 let crate::client::ProviderValue::Simple(api_key) = input else {
215 panic!("Incorrect provider value type")
216 };
217 Self::new(&api_key)
218 }
219}
220
221impl CompletionClient for Client<reqwest::Client> {
222 type CompletionModel = CompletionModel<reqwest::Client>;
223
224 fn completion_model(&self, model: &str) -> Self::CompletionModel {
228 CompletionModel::new(self.clone(), model)
229 }
230}
231
232impl<T> EmbeddingsClient for Client<T>
233where
234 T: HttpClientExt + Clone + Debug + Default + 'static,
235 Client<T>: CompletionClient,
236{
237 type EmbeddingModel = EmbeddingModel<T>;
238
239 fn embedding_model(&self, model: &str) -> EmbeddingModel<T> {
253 EmbeddingModel::new(self.clone(), model, None)
254 }
255
256 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel<T> {
268 EmbeddingModel::new(self.clone(), model, Some(ndims))
269 }
270
271 fn embeddings<D: Embed>(
288 &self,
289 model: &str,
290 ) -> embeddings::EmbeddingsBuilder<EmbeddingModel<T>, D> {
291 embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
292 }
293}
294
295impl<T> TranscriptionClient for Client<T>
296where
297 T: HttpClientExt + Clone + Debug + Default + 'static,
298 Client<T>: CompletionClient,
299{
300 type TranscriptionModel = TranscriptionModel<T>;
301
302 fn transcription_model(&self, model: &str) -> TranscriptionModel<T> {
306 TranscriptionModel::new(self.clone(), model)
307 }
308}
309
310impl<T> VerifyClient for Client<T>
311where
312 T: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
313 Client<T>: CompletionClient,
314{
315 #[cfg_attr(feature = "worker", worker::send)]
316 async fn verify(&self) -> Result<(), VerifyError> {
317 let req = self
318 .get("/v1beta/models")
319 .body(http_client::NoBody)
320 .map_err(|e| VerifyError::HttpError(e.into()))?;
321 let response = self.http_client.send::<_, Vec<u8>>(req).await?;
322
323 match response.status() {
324 reqwest::StatusCode::OK => Ok(()),
325 reqwest::StatusCode::FORBIDDEN => Err(VerifyError::InvalidAuthentication),
326 reqwest::StatusCode::INTERNAL_SERVER_ERROR
327 | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
328 let text = http_client::text(response).await?;
329 Err(VerifyError::ProviderError(text))
330 }
331 _ => {
332 Ok(())
337 }
338 }
339 }
340}
341
// Generates the `AsImageGeneration` and `AsAudioGeneration` conversion impls for `Client<T>`
// via `crate::client::impl_conversion_traits!`.
// NOTE(review): no image- or audio-generation client impl appears in this file, so these are
// presumably the macro's "capability not supported" defaults — confirm against the macro.
impl_conversion_traits!(
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
346
/// Error payload deserialized as the fallback (`Err`) arm of [`ApiResponse`].
///
/// NOTE(review): this expects `message` at the top level of the JSON body. Gemini error
/// bodies are usually nested as `{"error": {"code": ..., "message": ..., "status": ...}}`, so
/// confirm this shape matches what callers actually receive.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Human-readable error message (JSON key `message`).
    pub message: String,
}
351
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Deserialized with `#[serde(untagged)]`: serde tries `Ok(T)` first and falls back to `Err`
/// only if the body does not deserialize as `T`. Variant order is therefore significant and
/// must not be changed. If `T` would accept an error body (e.g. all of its fields are
/// optional), errors will be parsed as `Ok`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Successful response payload.
    Ok(T),
    /// Error payload returned by the API.
    Err(ApiErrorResponse),
}