rig/providers/gemini/
client.rs1use super::{
2 completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5 ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
6 impl_conversion_traits,
7};
8use crate::{
9 Embed,
10 agent::AgentBuilder,
11 embeddings::{self},
12 extractor::ExtractorBuilder,
13};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16
/// Default root URL of the Google Generative Language (Gemini) REST API.
///
/// [`ClientBuilder::new`] starts from this value; override it with
/// [`ClientBuilder::base_url`] (e.g. for a proxy or a local test server).
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";

22pub struct ClientBuilder<'a> {
23 api_key: &'a str,
24 base_url: &'a str,
25 http_client: Option<reqwest::Client>,
26}
27
28impl<'a> ClientBuilder<'a> {
29 pub fn new(api_key: &'a str) -> Self {
30 Self {
31 api_key,
32 base_url: GEMINI_API_BASE_URL,
33 http_client: None,
34 }
35 }
36
37 pub fn base_url(mut self, base_url: &'a str) -> Self {
38 self.base_url = base_url;
39 self
40 }
41
42 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
43 self.http_client = Some(client);
44 self
45 }
46
47 pub fn build(self) -> Result<Client, ClientBuilderError> {
48 let mut default_headers = reqwest::header::HeaderMap::new();
49 default_headers.insert(
50 reqwest::header::CONTENT_TYPE,
51 "application/json".parse().unwrap(),
52 );
53 let http_client = if let Some(http_client) = self.http_client {
54 http_client
55 } else {
56 reqwest::Client::builder().build()?
57 };
58
59 Ok(Client {
60 base_url: self.base_url.to_string(),
61 api_key: self.api_key.to_string(),
62 default_headers,
63 http_client,
64 })
65 }
66}
67#[derive(Clone)]
68pub struct Client {
69 base_url: String,
70 api_key: String,
71 default_headers: reqwest::header::HeaderMap,
72 http_client: reqwest::Client,
73}
74
75impl std::fmt::Debug for Client {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("Client")
78 .field("base_url", &self.base_url)
79 .field("http_client", &self.http_client)
80 .field("default_headers", &self.default_headers)
81 .field("api_key", &"<REDACTED>")
82 .finish()
83 }
84}
85
86impl Client {
87 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
98 ClientBuilder::new(api_key)
99 }
100
101 pub fn new(api_key: &str) -> Self {
106 Self::builder(api_key)
107 .build()
108 .expect("Gemini client should build")
109 }
110
111 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
112 let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
114
115 tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
116 self.http_client
117 .post(url)
118 .headers(self.default_headers.clone())
119 }
120
121 pub(crate) fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
122 let url =
123 format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/");
124
125 tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
126 self.http_client
127 .post(url)
128 .headers(self.default_headers.clone())
129 }
130
131 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
147 AgentBuilder::new(self.completion_model(model))
148 }
149
150 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
152 &self,
153 model: &str,
154 ) -> ExtractorBuilder<T, CompletionModel> {
155 ExtractorBuilder::new(self.completion_model(model))
156 }
157}
158
159impl ProviderClient for Client {
160 fn from_env() -> Self {
163 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
164 Self::new(&api_key)
165 }
166
167 fn from_val(input: crate::client::ProviderValue) -> Self {
168 let crate::client::ProviderValue::Simple(api_key) = input else {
169 panic!("Incorrect provider value type")
170 };
171 Self::new(&api_key)
172 }
173}
174
175impl CompletionClient for Client {
176 type CompletionModel = CompletionModel;
177
178 fn completion_model(&self, model: &str) -> CompletionModel {
182 CompletionModel::new(self.clone(), model)
183 }
184}
185
186impl EmbeddingsClient for Client {
187 type EmbeddingModel = EmbeddingModel;
188
189 fn embedding_model(&self, model: &str) -> EmbeddingModel {
203 EmbeddingModel::new(self.clone(), model, None)
204 }
205
206 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
218 EmbeddingModel::new(self.clone(), model, Some(ndims))
219 }
220
221 fn embeddings<D: Embed>(
238 &self,
239 model: &str,
240 ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
241 embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
242 }
243}
244
245impl TranscriptionClient for Client {
246 type TranscriptionModel = TranscriptionModel;
247
248 fn transcription_model(&self, model: &str) -> TranscriptionModel {
252 TranscriptionModel::new(self.clone(), model)
253 }
254}
255
256impl_conversion_traits!(
257 AsImageGeneration,
258 AsAudioGeneration for Client
259);
260
261#[derive(Debug, Deserialize)]
262pub struct ApiErrorResponse {
263 pub message: String,
264}
265
266#[derive(Debug, Deserialize)]
267#[serde(untagged)]
268pub enum ApiResponse<T> {
269 Ok(T),
270 Err(ApiErrorResponse),
271}