rig/providers/gemini/
client.rs

1use crate::{
2    agent::AgentBuilder,
3    embeddings::{self},
4    extractor::ExtractorBuilder,
5    Embed,
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9
10use super::{
11    completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
12};
13
// ----------------------------------------------------------------
// Google Gemini client
// ----------------------------------------------------------------

/// Default base URL of the Google Generative Language (Gemini) REST API.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
18
19#[derive(Clone)]
20pub struct Client {
21    base_url: String,
22    api_key: String,
23    http_client: reqwest::Client,
24}
25
26impl Client {
27    pub fn new(api_key: &str) -> Self {
28        Self::from_url(api_key, GEMINI_API_BASE_URL)
29    }
30    pub fn from_url(api_key: &str, base_url: &str) -> Self {
31        Self {
32            base_url: base_url.to_string(),
33            api_key: api_key.to_string(),
34            http_client: reqwest::Client::builder()
35                .default_headers({
36                    let mut headers = reqwest::header::HeaderMap::new();
37                    headers.insert(
38                        reqwest::header::CONTENT_TYPE,
39                        "application/json".parse().unwrap(),
40                    );
41                    headers
42                })
43                .build()
44                .expect("Gemini reqwest client should build"),
45        }
46    }
47
48    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
49    /// Panics if the environment variable is not set.
50    pub fn from_env() -> Self {
51        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
52        Self::new(&api_key)
53    }
54
55    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
56        let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
57
58        tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
59        self.http_client.post(url)
60    }
61
62    pub fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
63        let url =
64            format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/");
65
66        tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
67        self.http_client.post(url)
68    }
69
70    /// Create an embedding model with the given name.
71    /// Note: default embedding dimension of 0 will be used if model is not known.
72    /// If this is the case, it's better to use function `embedding_model_with_ndims`
73    ///
74    /// # Example
75    /// ```
76    /// use rig::providers::gemini::{Client, self};
77    ///
78    /// // Initialize the Google Gemini client
79    /// let gemini = Client::new("your-google-gemini-api-key");
80    ///
81    /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
82    /// ```
83    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
84        EmbeddingModel::new(self.clone(), model, None)
85    }
86
87    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
88    ///
89    /// # Example
90    /// ```
91    /// use rig::providers::gemini::{Client, self};
92    ///
93    /// // Initialize the Google Gemini client
94    /// let gemini = Client::new("your-google-gemini-api-key");
95    ///
96    /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
97    /// ```
98    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
99        EmbeddingModel::new(self.clone(), model, Some(ndims))
100    }
101
102    /// Create an embedding builder with the given embedding model.
103    ///
104    /// # Example
105    /// ```
106    /// use rig::providers::gemini::{Client, self};
107    ///
108    /// // Initialize the Google Gemini client
109    /// let gemini = Client::new("your-google-gemini-api-key");
110    ///
111    /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
112    ///     .simple_document("doc0", "Hello, world!")
113    ///     .simple_document("doc1", "Goodbye, world!")
114    ///     .build()
115    ///     .await
116    ///     .expect("Failed to embed documents");
117    /// ```
118    pub fn embeddings<D: Embed>(
119        &self,
120        model: &str,
121    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
122        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
123    }
124
125    /// Create a completion model with the given name.
126    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
127    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
128    pub fn completion_model(&self, model: &str) -> CompletionModel {
129        CompletionModel::new(self.clone(), model)
130    }
131
132    /// Create a transcription model with the given name.
133    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
134    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
135    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
136        TranscriptionModel::new(self.clone(), model)
137    }
138
139    /// Create an agent builder with the given completion model.
140    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
141    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
142    /// # Example
143    /// ```
144    /// use rig::providers::gemini::{Client, self};
145    ///
146    /// // Initialize the Google Gemini client
147    /// let gemini = Client::new("your-google-gemini-api-key");
148    ///
149    /// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
150    ///    .preamble("You are comedian AI with a mission to make people laugh.")
151    ///    .temperature(0.0)
152    ///    .build();
153    /// ```
154    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
155        AgentBuilder::new(self.completion_model(model))
156    }
157
158    /// Create an extractor builder with the given completion model.
159    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
160        &self,
161        model: &str,
162    ) -> ExtractorBuilder<T, CompletionModel> {
163        ExtractorBuilder::new(self.completion_model(model))
164    }
165}
166
167#[derive(Debug, Deserialize)]
168pub struct ApiErrorResponse {
169    pub message: String,
170}
171
172#[derive(Debug, Deserialize)]
173#[serde(untagged)]
174pub enum ApiResponse<T> {
175    Ok(T),
176    Err(ApiErrorResponse),
177}