1use super::{
2 completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
6 VerifyClient, VerifyError, impl_conversion_traits,
7};
8use crate::http_client::{self, HttpClientExt};
9use crate::wasm_compat::*;
10use crate::{
11 Embed,
12 embeddings::{self},
13};
14use bytes::Bytes;
15use serde::Deserialize;
16use std::fmt::Debug;
17
/// Base URL used by [`ClientBuilder::new`] and [`ClientBuilder::new_with_client`]
/// unless it is overridden with [`ClientBuilder::base_url`].
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
22
/// Builder for a Gemini [`Client`].
///
/// The API key and base URL are borrowed here and copied into owned `String`s by
/// [`ClientBuilder::build`]. `T` is the HTTP client type and defaults to `reqwest::Client`.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    // API key; the built client appends it to request URLs as the `key` query parameter.
    api_key: &'a str,
    // Root URL that request paths are joined onto.
    base_url: &'a str,
    // HTTP client the built `Client` will send requests through.
    http_client: T,
}
28
29impl<'a, T> ClientBuilder<'a, T>
30where
31 T: HttpClientExt + Default,
32{
33 pub fn new(api_key: &'a str) -> ClientBuilder<'a, T> {
34 ClientBuilder {
35 api_key,
36 base_url: GEMINI_API_BASE_URL,
37 http_client: Default::default(),
38 }
39 }
40
41 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
42 Self {
43 api_key,
44 base_url: GEMINI_API_BASE_URL,
45 http_client,
46 }
47 }
48
49 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U>
50 where
51 U: HttpClientExt,
52 {
53 ClientBuilder {
54 api_key: self.api_key,
55 base_url: self.base_url,
56 http_client,
57 }
58 }
59
60 pub fn base_url(mut self, base_url: &'a str) -> Self {
61 self.base_url = base_url;
62 self
63 }
64
65 pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
66 let mut default_headers = reqwest::header::HeaderMap::new();
67 default_headers.insert(
68 reqwest::header::CONTENT_TYPE,
69 "application/json".parse().unwrap(),
70 );
71
72 Ok(Client {
73 base_url: self.base_url.to_string(),
74 api_key: self.api_key.to_string(),
75 default_headers,
76 http_client: self.http_client,
77 })
78 }
79}
/// Gemini API client.
///
/// Stores the base URL, the API key, the headers copied onto every request it builds
/// (set up by [`ClientBuilder::build`]) and the HTTP client that sends the requests.
/// Its `Debug` output redacts the API key.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    // Root URL that request paths are joined onto.
    base_url: String,
    // API key; sent as the `key` query parameter on every request URL.
    api_key: String,
    // Headers that replace the header map of each request builder.
    default_headers: reqwest::header::HeaderMap,
    // Underlying HTTP client; public so callers can reach its configuration.
    pub http_client: T,
}
87
88impl<T> Debug for Client<T>
89where
90 T: Debug,
91{
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("Client")
94 .field("base_url", &self.base_url)
95 .field("http_client", &self.http_client)
96 .field("default_headers", &self.default_headers)
97 .field("api_key", &"<REDACTED>")
98 .finish()
99 }
100}
101
102impl<T> Client<T>
103where
104 T: HttpClientExt,
105{
106 pub(crate) fn post(&self, path: &str) -> http_client::Builder {
107 let url = format!(
109 "{}/{}?key={}",
110 self.base_url,
111 path.trim_start_matches('/'),
112 self.api_key
113 );
114
115 tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
116 let mut req = http_client::Request::post(url);
117
118 if let Some(hs) = req.headers_mut() {
119 *hs = self.default_headers.clone();
120 }
121
122 req
123 }
124
125 pub(crate) fn post_sse(&self, path: &str) -> http_client::Builder {
126 let url = format!(
127 "{}/{}?alt=sse&key={}",
128 self.base_url,
129 path.trim_start_matches('/'),
130 self.api_key
131 );
132
133 tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
134
135 let mut req = http_client::Request::post(url);
136
137 if let Some(hs) = req.headers_mut() {
138 *hs = self.default_headers.clone();
139 }
140
141 req
142 }
143
144 pub(crate) fn get(&self, path: &str) -> http_client::Builder {
145 let url = format!(
147 "{}/{}?key={}",
148 self.base_url,
149 path.trim_start_matches('/'),
150 self.api_key
151 );
152
153 tracing::debug!("GET {}/{}?key={}", self.base_url, path, "****");
154
155 let mut req = http_client::Request::get(url);
156
157 if let Some(hs) = req.headers_mut() {
158 *hs = self.default_headers.clone();
159 }
160
161 req
162 }
163
164 pub(crate) async fn send<U, R>(
165 &self,
166 req: http_client::Request<U>,
167 ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
168 where
169 U: Into<Bytes> + Send,
170 R: From<Bytes> + Send + 'static,
171 {
172 self.http_client.send(req).await
173 }
174}
175
176impl Client<reqwest::Client> {
177 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
178 ClientBuilder::<reqwest::Client>::new(api_key)
179 }
180
181 pub fn new(api_key: &str) -> Self {
186 ClientBuilder::<reqwest::Client>::new(api_key)
187 .build()
188 .unwrap()
189 }
190
191 pub fn from_env() -> Self {
192 <Self as ProviderClient>::from_env()
193 }
194}
195
196impl<T> ProviderClient for Client<T>
197where
198 T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
199{
200 fn from_env() -> Self {
203 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
204 ClientBuilder::<T>::new(&api_key).build().unwrap()
205 }
206
207 fn from_val(input: crate::client::ProviderValue) -> Self {
208 let crate::client::ProviderValue::Simple(api_key) = input else {
209 panic!("Incorrect provider value type")
210 };
211 ClientBuilder::<T>::new(&api_key).build().unwrap()
212 }
213}
214
215impl<T> CompletionClient for Client<T>
216where
217 T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
218{
219 type CompletionModel = CompletionModel<T>;
220
221 fn completion_model(&self, model: &str) -> Self::CompletionModel {
225 CompletionModel::new(self.clone(), model)
226 }
227}
228
229impl<T> EmbeddingsClient for Client<T>
230where
231 T: HttpClientExt + Clone + Debug + Default + 'static,
232 Client<T>: CompletionClient,
233{
234 type EmbeddingModel = EmbeddingModel<T>;
235
236 fn embedding_model(&self, model: &str) -> EmbeddingModel<T> {
250 EmbeddingModel::new(self.clone(), model, None)
251 }
252
253 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel<T> {
265 EmbeddingModel::new(self.clone(), model, Some(ndims))
266 }
267
268 fn embeddings<D: Embed>(
285 &self,
286 model: &str,
287 ) -> embeddings::EmbeddingsBuilder<EmbeddingModel<T>, D> {
288 embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
289 }
290}
291
292impl<T> TranscriptionClient for Client<T>
293where
294 T: HttpClientExt + Clone + Debug + Default + WasmCompatSend + 'static,
295 Client<T>: CompletionClient,
296{
297 type TranscriptionModel = TranscriptionModel<T>;
298
299 fn transcription_model(&self, model: &str) -> TranscriptionModel<T> {
303 TranscriptionModel::new(self.clone(), model)
304 }
305}
306
307impl<T> VerifyClient for Client<T>
308where
309 T: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
310 Client<T>: CompletionClient,
311{
312 #[cfg_attr(feature = "worker", worker::send)]
313 async fn verify(&self) -> Result<(), VerifyError> {
314 let req = self
315 .get("/v1beta/models")
316 .body(http_client::NoBody)
317 .map_err(|e| VerifyError::HttpError(e.into()))?;
318 let response = self.http_client.send::<_, Vec<u8>>(req).await?;
319
320 match response.status() {
321 reqwest::StatusCode::OK => Ok(()),
322 reqwest::StatusCode::FORBIDDEN => Err(VerifyError::InvalidAuthentication),
323 reqwest::StatusCode::INTERNAL_SERVER_ERROR
324 | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
325 let text = http_client::text(response).await?;
326 Err(VerifyError::ProviderError(text))
327 }
328 _ => {
329 Ok(())
334 }
335 }
336 }
337}
338
// Invokes `crate::client::impl_conversion_traits!` for `AsImageGeneration` and
// `AsAudioGeneration` on `Client<T>`; this file implements neither trait by hand.
// NOTE(review): presumably this supplies the "capability not supported" conversions
// for this provider — confirm against the macro definition.
impl_conversion_traits!(
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
343
/// Error payload deserialized from an API response body.
///
/// Expects a top-level `message` string. NOTE(review): confirm this matches the shape
/// the callers actually receive (Gemini errors are commonly nested under `"error"`).
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Human-readable error description from the response body.
    pub message: String,
}
348
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `#[serde(untagged)]` makes serde try the variants in declaration order: `Ok(T)`
/// first, then `Err`. Keep that order; a `T` that can also deserialize from an error
/// body would otherwise shadow the error variant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Body deserialized successfully as `T`.
    Ok(T),
    /// Body did not match `T` but matched the error shape.
    Err(ApiErrorResponse),
}