Skip to main content

rig/providers/gemini/client.rs

1use crate::client::{
2    self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
3    Transport,
4};
5use crate::http_client;
6use crate::providers::gemini::model_listing::{GeminiInteractionsModelLister, GeminiModelLister};
7use serde::Deserialize;
8use std::fmt::Debug;
9
10#[cfg(any(feature = "image", feature = "audio"))]
11use crate::client::Nothing;
12
// ----------------------------------------------------------------
// Google Gemini client
// ----------------------------------------------------------------

/// Base URL of the Google Generative Language API.
///
/// Used as `BASE_URL` by both the GenerateContent and the Interactions builders.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
17
/// Provider extension for the Gemini GenerateContent API.
///
/// Holds the API key that the GenerateContent `build_uri` places in every request
/// URI as the `key` query parameter.
///
/// `Debug` is implemented by hand instead of derived so that formatting the
/// extension with `{:?}` masks the key as `"******"` rather than printing the
/// secret verbatim.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never emit the real key: debug output routinely ends up in logs.
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
23
/// Zero-sized marker selecting the GenerateContent flavour of the Gemini
/// client builder (see its `ProviderBuilder` impl).
#[derive(Clone, Default, Debug)]
pub struct GeminiBuilder;
27
/// Provider extension for the Gemini Interactions API.
///
/// Holds the API key that is sent as the `x-goog-api-key` request header.
///
/// `Debug` is implemented by hand instead of derived so that formatting the
/// extension with `{:?}` masks the key as `"******"` rather than printing the
/// secret verbatim.
#[derive(Default, Clone)]
pub struct GeminiInteractionsExt {
    api_key: String,
}

impl Debug for GeminiInteractionsExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never emit the real key: debug output routinely ends up in logs.
        f.debug_struct("GeminiInteractionsExt")
            .field("api_key", &"******")
            .finish()
    }
}
33
/// Zero-sized marker selecting the Interactions flavour of the Gemini
/// client builder (see its `ProviderBuilder` impl).
#[derive(Clone, Default, Debug)]
pub struct GeminiInteractionsBuilder;
37
/// Wrapper type for Gemini API keys.
///
/// Any `Into<String>` value (e.g. `&str`, `String`) converts into it via `From`.
/// Its `Debug` output is redacted so the key cannot leak through `{:?}`
/// formatting or logging.
pub struct GeminiApiKey(String);

impl<S> From<S> for GeminiApiKey
where
    S: Into<String>,
{
    fn from(value: S) -> Self {
        Self(value.into())
    }
}

impl Debug for GeminiApiKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Mask the secret, matching the "******" used by the extensions' DebugExt.
        f.debug_tuple("GeminiApiKey").field(&"******").finish()
    }
}
49
50/// Gemini GenerateContent client.
51pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
52/// Builder for the Gemini GenerateContent client.
53pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
54/// Gemini Interactions API client.
55pub type InteractionsClient<H = reqwest::Client> = client::Client<GeminiInteractionsExt, H>;
56
57impl ApiKey for GeminiApiKey {}
58
59impl DebugExt for GeminiExt {
60    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
61        std::iter::once(("api_key", (&"******") as &dyn Debug))
62    }
63}
64
65impl DebugExt for GeminiInteractionsExt {
66    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
67        std::iter::once(("api_key", (&"******") as &dyn Debug))
68    }
69}
70
71impl Provider for GeminiExt {
72    type Builder = GeminiBuilder;
73
74    const VERIFY_PATH: &'static str = "/v1beta/models";
75
76    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
77        let trimmed = path.trim_start_matches('/');
78        let separator = if trimmed.contains('?') { "&" } else { "?" };
79
80        match transport {
81            Transport::Sse => format!(
82                "{base_url}/{trimmed}{separator}alt=sse&key={}",
83                self.api_key
84            ),
85            _ => format!("{base_url}/{trimmed}{separator}key={}", self.api_key),
86        }
87    }
88}
89
90impl Provider for GeminiInteractionsExt {
91    type Builder = GeminiInteractionsBuilder;
92
93    const VERIFY_PATH: &'static str = "/v1beta/models";
94
95    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
96        let trimmed = path.trim_start_matches('/');
97        match transport {
98            Transport::Sse => {
99                if trimmed.contains('?') {
100                    format!("{}/{}&alt=sse", base_url, trimmed)
101                } else {
102                    format!("{}/{}?alt=sse", base_url, trimmed)
103                }
104            }
105            _ => format!("{}/{}", base_url, trimmed),
106        }
107    }
108
109    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
110        Ok(req.header("x-goog-api-key", self.api_key.clone()))
111    }
112}
113
114impl<H> Capabilities<H> for GeminiExt {
115    type Completion = Capable<super::completion::CompletionModel>;
116    type Embeddings = Capable<super::embedding::EmbeddingModel>;
117    type Transcription = Capable<super::transcription::TranscriptionModel>;
118    type ModelListing = Capable<GeminiModelLister<H>>;
119
120    #[cfg(feature = "image")]
121    type ImageGeneration = Nothing;
122    #[cfg(feature = "audio")]
123    type AudioGeneration = Nothing;
124}
125
126impl<H> Capabilities<H> for GeminiInteractionsExt {
127    type Completion = Capable<super::interactions_api::InteractionsCompletionModel<H>>;
128    type Embeddings = Capable<super::embedding::EmbeddingModel>;
129    type Transcription = Capable<super::transcription::TranscriptionModel>;
130    type ModelListing = Capable<GeminiInteractionsModelLister<H>>;
131
132    #[cfg(feature = "image")]
133    type ImageGeneration = Nothing;
134    #[cfg(feature = "audio")]
135    type AudioGeneration = Nothing;
136}
137
138impl ProviderBuilder for GeminiBuilder {
139    type Extension<H>
140        = GeminiExt
141    where
142        H: http_client::HttpClientExt;
143    type ApiKey = GeminiApiKey;
144
145    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
146
147    fn build<H>(
148        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
149    ) -> http_client::Result<Self::Extension<H>>
150    where
151        H: http_client::HttpClientExt,
152    {
153        Ok(GeminiExt {
154            api_key: builder.get_api_key().0.clone(),
155        })
156    }
157}
158
159impl ProviderBuilder for GeminiInteractionsBuilder {
160    type Extension<H>
161        = GeminiInteractionsExt
162    where
163        H: http_client::HttpClientExt;
164    type ApiKey = GeminiApiKey;
165
166    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
167
168    fn build<H>(
169        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
170    ) -> http_client::Result<Self::Extension<H>>
171    where
172        H: http_client::HttpClientExt,
173    {
174        Ok(GeminiInteractionsExt {
175            api_key: builder.get_api_key().0.clone(),
176        })
177    }
178}
179
180impl ProviderClient for Client {
181    type Input = GeminiApiKey;
182
183    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
184    /// Panics if the environment variable is not set.
185    fn from_env() -> Self {
186        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
187        Self::new(api_key).unwrap()
188    }
189
190    fn from_val(input: Self::Input) -> Self {
191        Self::new(input).unwrap()
192    }
193}
194
195impl ProviderClient for InteractionsClient {
196    type Input = GeminiApiKey;
197
198    /// Create a new Google Gemini interactions client from the `GEMINI_API_KEY` environment variable.
199    /// Panics if the environment variable is not set.
200    fn from_env() -> Self {
201        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
202        Self::new(api_key).unwrap()
203    }
204
205    fn from_val(input: Self::Input) -> Self {
206        Self::new(input).unwrap()
207    }
208}
209
210impl<H> Client<H> {
211    /// Create an Interactions API client from this GenerateContent client.
212    pub fn interactions_api(self) -> InteractionsClient<H> {
213        let api_key = self.ext().api_key.clone();
214        self.with_ext(GeminiInteractionsExt { api_key })
215    }
216}
217
218impl<H> InteractionsClient<H> {
219    /// Create a GenerateContent API client from this Interactions client.
220    pub fn generate_content_api(self) -> Client<H> {
221        let api_key = self.ext().api_key.clone();
222        self.with_ext(GeminiExt { api_key })
223    }
224}
225
226/// Error response payload returned by Gemini.
227#[derive(Debug, Deserialize)]
228pub struct ApiErrorResponse {
229    pub message: String,
230}
231
232/// Wrapper for successful or error Gemini API responses.
233#[derive(Debug, Deserialize)]
234#[serde(untagged)]
235pub enum ApiResponse<T> {
236    Ok(T),
237    Err(ApiErrorResponse),
238}
239
// ================================================================
// Tests
// ================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }

    /// The GenerateContent API authenticates via the `key` query parameter and
    /// streams with `alt=sse`; an existing query string is extended with `&`.
    #[test]
    fn generate_content_sse_uri_carries_alt_and_key() {
        let ext = GeminiExt {
            api_key: "dummy-key".to_string(),
        };

        assert_eq!(
            ext.build_uri(
                "https://example.com",
                "/v1beta/models/gemini:streamGenerateContent",
                Transport::Sse,
            ),
            "https://example.com/v1beta/models/gemini:streamGenerateContent?alt=sse&key=dummy-key"
        );
        assert_eq!(
            ext.build_uri("https://example.com", "/v1beta/models?pageSize=5", Transport::Sse),
            "https://example.com/v1beta/models?pageSize=5&alt=sse&key=dummy-key"
        );
    }

    /// The Interactions API sends the key as a header, so it must never appear
    /// in the URI.
    #[test]
    fn interactions_sse_uri_omits_key() {
        let ext = GeminiInteractionsExt {
            api_key: "dummy-key".to_string(),
        };

        let uri = ext.build_uri("https://example.com", "/v1beta/interactions", Transport::Sse);
        assert_eq!(uri, "https://example.com/v1beta/interactions?alt=sse");
        assert!(!uri.contains("dummy-key"));
    }

    #[test]
    fn debug_ext_masks_api_key() {
        let ext = GeminiExt {
            api_key: "dummy-key".to_string(),
        };

        let rendered: Vec<String> = ext
            .fields()
            .map(|(name, value)| format!("{name}={value:?}"))
            .collect();
        assert_eq!(rendered, vec!["api_key=\"******\"".to_string()]);
    }

    #[test]
    fn switching_apis_preserves_api_key() {
        let client: Client = Client::new("dummy-key").expect("Client::new() failed");

        let interactions = client.interactions_api();
        assert_eq!(interactions.ext().api_key, "dummy-key");

        let client = interactions.generate_content_api();
        assert_eq!(client.ext().api_key, "dummy-key");
    }
}