Skip to main content

rig_core/providers/gemini/
client.rs

1use crate::client::{
2    self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
3    Transport,
4};
5use crate::http_client::{self};
6use crate::providers::gemini::model_listing::{GeminiInteractionsModelLister, GeminiModelLister};
7use serde::Deserialize;
8use std::fmt::Debug;
9
10#[cfg(any(feature = "image", feature = "audio"))]
11use crate::client::Nothing;
12
// ================================================================
// Google Gemini Client
// ================================================================
/// Default base URL for Gemini requests.
///
/// Both `GeminiBuilder` and `GeminiInteractionsBuilder` expose this value as
/// their `BASE_URL`. It has no trailing slash; `build_uri` inserts the `/`
/// between the base URL and the request path.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
17
/// Provider extension for the Gemini GenerateContent API.
///
/// Stores the API key that `build_uri` appends to every request URI as the
/// `key` query parameter.
///
/// `Debug` is implemented by hand rather than derived so that formatting this
/// value with `{:?}` never prints the key; the placeholder matches the one
/// used by the `DebugExt` implementation for this type.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl std::fmt::Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never emit the real secret: show a fixed mask instead.
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
23
/// Builder marker for the Gemini GenerateContent client.
///
/// A unit struct with no data of its own; its `ProviderBuilder`
/// implementation produces a [`GeminiExt`] from the configured API key.
#[derive(Debug, Default, Clone)]
pub struct GeminiBuilder;
27
/// Provider extension for the Gemini Interactions API.
///
/// Stores the API key that `with_custom` attaches to each request as the
/// `x-goog-api-key` header.
///
/// `Debug` is implemented by hand rather than derived so that formatting this
/// value with `{:?}` never prints the key; the placeholder matches the one
/// used by the `DebugExt` implementation for this type.
#[derive(Default, Clone)]
pub struct GeminiInteractionsExt {
    api_key: String,
}

impl std::fmt::Debug for GeminiInteractionsExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never emit the real secret: show a fixed mask instead.
        f.debug_struct("GeminiInteractionsExt")
            .field("api_key", &"******")
            .finish()
    }
}
33
/// Builder marker for the Gemini Interactions client.
///
/// A unit struct with no data of its own; its `ProviderBuilder`
/// implementation produces a [`GeminiInteractionsExt`] from the configured
/// API key.
#[derive(Debug, Default, Clone)]
pub struct GeminiInteractionsBuilder;
37
/// Newtype holding the raw Gemini API key string.
///
/// Deliberately does not implement `Debug`, so the key cannot be printed by
/// accident through this wrapper.
pub struct GeminiApiKey(String);

/// Builds a key from anything convertible into a `String`
/// (for example `&str` or `String`).
impl<S: Into<String>> From<S> for GeminiApiKey {
    fn from(raw: S) -> Self {
        let key: String = raw.into();
        GeminiApiKey(key)
    }
}
49
/// Gemini GenerateContent client.
///
/// Alias for the generic `client::Client` specialised with [`GeminiExt`];
/// the HTTP backend `H` defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
/// Builder for the Gemini GenerateContent client.
///
/// Alias for the generic `client::ClientBuilder` using [`GeminiBuilder`] and
/// [`GeminiApiKey`]; `H` defaults to the `crate::markers::Missing` marker.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
/// Gemini Interactions API client.
///
/// Alias for the generic `client::Client` specialised with
/// [`GeminiInteractionsExt`]; the HTTP backend `H` defaults to
/// `reqwest::Client`.
pub type InteractionsClient<H = reqwest::Client> = client::Client<GeminiInteractionsExt, H>;
57
// Marker impl with an empty body: allows `GeminiApiKey` to be named as the
// `ApiKey` associated type of the provider builders in this file.
impl ApiKey for GeminiApiKey {}
59
60impl DebugExt for GeminiExt {
61    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
62        std::iter::once(("api_key", (&"******") as &dyn Debug))
63    }
64}
65
66impl DebugExt for GeminiInteractionsExt {
67    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
68        std::iter::once(("api_key", (&"******") as &dyn Debug))
69    }
70}
71
72impl Provider for GeminiExt {
73    type Builder = GeminiBuilder;
74
75    const VERIFY_PATH: &'static str = "/v1beta/models";
76
77    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
78        let trimmed = path.trim_start_matches('/');
79        let separator = if trimmed.contains('?') { "&" } else { "?" };
80
81        match transport {
82            Transport::Sse => format!(
83                "{base_url}/{trimmed}{separator}alt=sse&key={}",
84                self.api_key
85            ),
86            _ => format!("{base_url}/{trimmed}{separator}key={}", self.api_key),
87        }
88    }
89}
90
91impl Provider for GeminiInteractionsExt {
92    type Builder = GeminiInteractionsBuilder;
93
94    const VERIFY_PATH: &'static str = "/v1beta/models";
95
96    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
97        let trimmed = path.trim_start_matches('/');
98        match transport {
99            Transport::Sse => {
100                if trimmed.contains('?') {
101                    format!("{}/{}&alt=sse", base_url, trimmed)
102                } else {
103                    format!("{}/{}?alt=sse", base_url, trimmed)
104                }
105            }
106            _ => format!("{}/{}", base_url, trimmed),
107        }
108    }
109
110    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
111        Ok(req.header("x-goog-api-key", self.api_key.clone()))
112    }
113}
114
/// Capability table for the GenerateContent client.
///
/// Each associated type names the model type backing that capability.
impl<H> Capabilities<H> for GeminiExt {
    // Completions go through the GenerateContent completion model.
    type Completion = Capable<super::completion::CompletionModel<H>>;
    type Embeddings = Capable<super::embedding::EmbeddingModel<H>>;
    type Transcription = Capable<super::transcription::TranscriptionModel<H>>;
    type ModelListing = Capable<GeminiModelLister<H>>;

    // Image and audio generation are set to `Nothing` rather than `Capable<_>`
    // (presumably marking them as unsupported — confirm against
    // `crate::client`). They only exist when the matching feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
126
/// Capability table for the Interactions API client.
///
/// Compared with the `GeminiExt` table in this file, the completion model and
/// the model lister differ; embeddings and transcription use the same model
/// types.
impl<H> Capabilities<H> for GeminiInteractionsExt {
    // Completions go through the Interactions API completion model.
    type Completion = Capable<super::interactions_api::InteractionsCompletionModel<H>>;
    type Embeddings = Capable<super::embedding::EmbeddingModel<H>>;
    type Transcription = Capable<super::transcription::TranscriptionModel<H>>;
    type ModelListing = Capable<GeminiInteractionsModelLister<H>>;

    // Image and audio generation are set to `Nothing` rather than `Capable<_>`
    // (presumably marking them as unsupported — confirm against
    // `crate::client`). They only exist when the matching feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
138
139impl ProviderBuilder for GeminiBuilder {
140    type Extension<H>
141        = GeminiExt
142    where
143        H: http_client::HttpClientExt;
144    type ApiKey = GeminiApiKey;
145
146    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
147
148    fn build<H>(
149        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
150    ) -> http_client::Result<Self::Extension<H>>
151    where
152        H: http_client::HttpClientExt,
153    {
154        Ok(GeminiExt {
155            api_key: builder.get_api_key().0.clone(),
156        })
157    }
158}
159
160impl ProviderBuilder for GeminiInteractionsBuilder {
161    type Extension<H>
162        = GeminiInteractionsExt
163    where
164        H: http_client::HttpClientExt;
165    type ApiKey = GeminiApiKey;
166
167    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
168
169    fn build<H>(
170        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
171    ) -> http_client::Result<Self::Extension<H>>
172    where
173        H: http_client::HttpClientExt,
174    {
175        Ok(GeminiInteractionsExt {
176            api_key: builder.get_api_key().0.clone(),
177        })
178    }
179}
180
181impl ProviderClient for Client {
182    type Input = GeminiApiKey;
183    type Error = crate::client::ProviderClientError;
184
185    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
186    fn from_env() -> Result<Self, Self::Error> {
187        let api_key = crate::client::required_env_var("GEMINI_API_KEY")?;
188        Self::new(api_key).map_err(Into::into)
189    }
190
191    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
192        Self::new(input).map_err(Into::into)
193    }
194}
195
196impl ProviderClient for InteractionsClient {
197    type Input = GeminiApiKey;
198    type Error = crate::client::ProviderClientError;
199
200    /// Create a new Google Gemini interactions client from the `GEMINI_API_KEY` environment variable.
201    fn from_env() -> Result<Self, Self::Error> {
202        let api_key = crate::client::required_env_var("GEMINI_API_KEY")?;
203        Self::new(api_key).map_err(Into::into)
204    }
205
206    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
207        Self::new(input).map_err(Into::into)
208    }
209}
210
211impl<H> Client<H> {
212    /// Create an Interactions API client from this GenerateContent client.
213    pub fn interactions_api(self) -> InteractionsClient<H> {
214        let api_key = self.ext().api_key.clone();
215        self.with_ext(GeminiInteractionsExt { api_key })
216    }
217}
218
219impl<H> InteractionsClient<H> {
220    /// Create a GenerateContent API client from this Interactions client.
221    pub fn generate_content_api(self) -> Client<H> {
222        let api_key = self.ext().api_key.clone();
223        self.with_ext(GeminiExt { api_key })
224    }
225}
226
/// Error response payload returned by Gemini.
///
/// Deserializes from a JSON object that has a string `message` field.
///
/// NOTE(review): Google APIs commonly wrap errors in an outer
/// `{"error": {...}}` object; confirm that callers unwrap that envelope
/// before deserializing into this type, otherwise `message` will not be
/// found at the top level.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Error message text taken from the response.
    pub message: String,
}
232
/// Wrapper for successful or error Gemini API responses.
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in
/// declaration order: the payload is first deserialized as `T`, and only if
/// that fails is it tried as [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The payload deserialized as the expected success type.
    Ok(T),
    /// The payload deserialized as an error body instead.
    Err(ApiErrorResponse),
}
240
241// ================================================================
242// Tests
243// ================================================================
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Both construction paths (direct `new` and the builder) must accept a
    /// placeholder key without error.
    #[test]
    fn test_client_initialization() {
        let direct = Client::new("dummy-key");
        let _direct_client: Client = direct.expect("Client::new() failed");

        let via_builder = Client::builder().api_key("dummy-key").build();
        let _builder_client: Client = via_builder.expect("Client::builder() failed");
    }
}