rig/providers/gemini/
client.rs

1use super::{
2    completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5    impl_conversion_traits, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
6};
7use crate::{
8    agent::AgentBuilder,
9    embeddings::{self},
10    extractor::ExtractorBuilder,
11    Embed,
12};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15
// ================================================================
// Google Gemini Client
// ================================================================

/// Root URL of the public Google Gemini REST API.
///
/// `Client::new` hands this value to `Client::from_url`; use `from_url` directly to target a
/// different endpoint (for example a proxy).
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
20
21#[derive(Debug, Clone)]
22pub struct Client {
23    base_url: String,
24    api_key: String,
25    http_client: reqwest::Client,
26}
27
28impl Client {
29    pub fn new(api_key: &str) -> Self {
30        Self::from_url(api_key, GEMINI_API_BASE_URL)
31    }
32    pub fn from_url(api_key: &str, base_url: &str) -> Self {
33        Self {
34            base_url: base_url.to_string(),
35            api_key: api_key.to_string(),
36            http_client: reqwest::Client::builder()
37                .default_headers({
38                    let mut headers = reqwest::header::HeaderMap::new();
39                    headers.insert(
40                        reqwest::header::CONTENT_TYPE,
41                        "application/json".parse().unwrap(),
42                    );
43                    headers
44                })
45                .build()
46                .expect("Gemini reqwest client should build"),
47        }
48    }
49
50    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
51        let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
52
53        tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
54        self.http_client.post(url)
55    }
56
57    pub fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
58        let url =
59            format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/");
60
61        tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
62        self.http_client.post(url)
63    }
64
65    /// Create an agent builder with the given completion model.
66    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
67    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
68    /// # Example
69    /// ```
70    /// use rig::providers::gemini::{Client, self};
71    ///
72    /// // Initialize the Google Gemini client
73    /// let gemini = Client::new("your-google-gemini-api-key");
74    ///
75    /// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
76    ///    .preamble("You are comedian AI with a mission to make people laugh.")
77    ///    .temperature(0.0)
78    ///    .build();
79    /// ```
80    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
81        AgentBuilder::new(self.completion_model(model))
82    }
83
84    /// Create an extractor builder with the given completion model.
85    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
86        &self,
87        model: &str,
88    ) -> ExtractorBuilder<T, CompletionModel> {
89        ExtractorBuilder::new(self.completion_model(model))
90    }
91}
92
93impl ProviderClient for Client {
94    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
95    /// Panics if the environment variable is not set.
96    fn from_env() -> Self {
97        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
98        Self::new(&api_key)
99    }
100}
101
102impl CompletionClient for Client {
103    type CompletionModel = CompletionModel;
104
105    /// Create a completion model with the given name.
106    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
107    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
108    fn completion_model(&self, model: &str) -> CompletionModel {
109        CompletionModel::new(self.clone(), model)
110    }
111}
112
113impl EmbeddingsClient for Client {
114    type EmbeddingModel = EmbeddingModel;
115
116    /// Create an embedding model with the given name.
117    /// Note: default embedding dimension of 0 will be used if model is not known.
118    /// If this is the case, it's better to use function `embedding_model_with_ndims`
119    ///
120    /// # Example
121    /// ```
122    /// use rig::providers::gemini::{Client, self};
123    ///
124    /// // Initialize the Google Gemini client
125    /// let gemini = Client::new("your-google-gemini-api-key");
126    ///
127    /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
128    /// ```
129    fn embedding_model(&self, model: &str) -> EmbeddingModel {
130        EmbeddingModel::new(self.clone(), model, None)
131    }
132
133    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
134    ///
135    /// # Example
136    /// ```
137    /// use rig::providers::gemini::{Client, self};
138    ///
139    /// // Initialize the Google Gemini client
140    /// let gemini = Client::new("your-google-gemini-api-key");
141    ///
142    /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
143    /// ```
144    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
145        EmbeddingModel::new(self.clone(), model, Some(ndims))
146    }
147
148    /// Create an embedding builder with the given embedding model.
149    ///
150    /// # Example
151    /// ```
152    /// use rig::providers::gemini::{Client, self};
153    ///
154    /// // Initialize the Google Gemini client
155    /// let gemini = Client::new("your-google-gemini-api-key");
156    ///
157    /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
158    ///     .simple_document("doc0", "Hello, world!")
159    ///     .simple_document("doc1", "Goodbye, world!")
160    ///     .build()
161    ///     .await
162    ///     .expect("Failed to embed documents");
163    /// ```
164    fn embeddings<D: Embed>(
165        &self,
166        model: &str,
167    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
168        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
169    }
170}
171
172impl TranscriptionClient for Client {
173    type TranscriptionModel = TranscriptionModel;
174
175    /// Create a transcription model with the given name.
176    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
177    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
178    fn transcription_model(&self, model: &str) -> TranscriptionModel {
179        TranscriptionModel::new(self.clone(), model)
180    }
181}
182
183impl_conversion_traits!(
184    AsImageGeneration,
185    AsAudioGeneration for Client
186);
187
188#[derive(Debug, Deserialize)]
189pub struct ApiErrorResponse {
190    pub message: String,
191}
192
193#[derive(Debug, Deserialize)]
194#[serde(untagged)]
195pub enum ApiResponse<T> {
196    Ok(T),
197    Err(ApiErrorResponse),
198}