rig/providers/gemini/client.rs
1use crate::{
2 agent::AgentBuilder,
3 embeddings::{self},
4 extractor::ExtractorBuilder,
5 Embed,
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9
10use super::{
11 completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
12};
13
14// ================================================================
15// Google Gemini Client
16// ================================================================
17const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
18
19#[derive(Clone)]
20pub struct Client {
21 base_url: String,
22 api_key: String,
23 http_client: reqwest::Client,
24}
25
26impl Client {
27 pub fn new(api_key: &str) -> Self {
28 Self::from_url(api_key, GEMINI_API_BASE_URL)
29 }
30 pub fn from_url(api_key: &str, base_url: &str) -> Self {
31 Self {
32 base_url: base_url.to_string(),
33 api_key: api_key.to_string(),
34 http_client: reqwest::Client::builder()
35 .default_headers({
36 let mut headers = reqwest::header::HeaderMap::new();
37 headers.insert(
38 reqwest::header::CONTENT_TYPE,
39 "application/json".parse().unwrap(),
40 );
41 headers
42 })
43 .build()
44 .expect("Gemini reqwest client should build"),
45 }
46 }
47
48 /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
49 /// Panics if the environment variable is not set.
50 pub fn from_env() -> Self {
51 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
52 Self::new(&api_key)
53 }
54
55 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
56 let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
57
58 tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
59 self.http_client.post(url)
60 }
61
62 pub fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
63 let url =
64 format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/");
65
66 tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
67 self.http_client.post(url)
68 }
69
70 /// Create an embedding model with the given name.
71 /// Note: default embedding dimension of 0 will be used if model is not known.
72 /// If this is the case, it's better to use function `embedding_model_with_ndims`
73 ///
74 /// # Example
75 /// ```
76 /// use rig::providers::gemini::{Client, self};
77 ///
78 /// // Initialize the Google Gemini client
79 /// let gemini = Client::new("your-google-gemini-api-key");
80 ///
81 /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
82 /// ```
83 pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
84 EmbeddingModel::new(self.clone(), model, None)
85 }
86
87 /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
88 ///
89 /// # Example
90 /// ```
91 /// use rig::providers::gemini::{Client, self};
92 ///
93 /// // Initialize the Google Gemini client
94 /// let gemini = Client::new("your-google-gemini-api-key");
95 ///
96 /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
97 /// ```
98 pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
99 EmbeddingModel::new(self.clone(), model, Some(ndims))
100 }
101
102 /// Create an embedding builder with the given embedding model.
103 ///
104 /// # Example
105 /// ```
106 /// use rig::providers::gemini::{Client, self};
107 ///
108 /// // Initialize the Google Gemini client
109 /// let gemini = Client::new("your-google-gemini-api-key");
110 ///
111 /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
112 /// .simple_document("doc0", "Hello, world!")
113 /// .simple_document("doc1", "Goodbye, world!")
114 /// .build()
115 /// .await
116 /// .expect("Failed to embed documents");
117 /// ```
118 pub fn embeddings<D: Embed>(
119 &self,
120 model: &str,
121 ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
122 embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
123 }
124
125 /// Create a completion model with the given name.
126 /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
127 /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
128 pub fn completion_model(&self, model: &str) -> CompletionModel {
129 CompletionModel::new(self.clone(), model)
130 }
131
132 /// Create a transcription model with the given name.
133 /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
134 /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
135 pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
136 TranscriptionModel::new(self.clone(), model)
137 }
138
139 /// Create an agent builder with the given completion model.
140 /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
141 /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
142 /// # Example
143 /// ```
144 /// use rig::providers::gemini::{Client, self};
145 ///
146 /// // Initialize the Google Gemini client
147 /// let gemini = Client::new("your-google-gemini-api-key");
148 ///
149 /// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
150 /// .preamble("You are comedian AI with a mission to make people laugh.")
151 /// .temperature(0.0)
152 /// .build();
153 /// ```
154 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
155 AgentBuilder::new(self.completion_model(model))
156 }
157
158 /// Create an extractor builder with the given completion model.
159 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
160 &self,
161 model: &str,
162 ) -> ExtractorBuilder<T, CompletionModel> {
163 ExtractorBuilder::new(self.completion_model(model))
164 }
165}
166
167#[derive(Debug, Deserialize)]
168pub struct ApiErrorResponse {
169 pub message: String,
170}
171
172#[derive(Debug, Deserialize)]
173#[serde(untagged)]
174pub enum ApiResponse<T> {
175 Ok(T),
176 Err(ApiErrorResponse),
177}