Skip to main content

rig/providers/gemini/
client.rs

1use crate::client::{
2    self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3    ProviderClient, Transport,
4};
5use crate::http_client;
6use serde::Deserialize;
7use std::fmt::Debug;
8
// ================================================================
// Google Gemini Client
// ================================================================
/// Default base URL of the Gemini API.
///
/// Used as `BASE_URL` by both `GeminiBuilder` (GenerateContent) and
/// `GeminiInteractionsBuilder` (Interactions) in this module.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
13
/// Provider extension for the Gemini GenerateContent API.
///
/// Holds the API key that `build_uri` attaches to every request.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

// `Debug` is implemented by hand rather than derived so that the API key is
// never printed; this matches the masking done by the `DebugExt` impl.
impl Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
19
/// Zero-sized marker that selects the GenerateContent flavour of the Gemini
/// client builder.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
23
/// Provider extension for the Gemini Interactions API.
///
/// Holds the API key that `with_custom` sends in the `x-goog-api-key` header.
#[derive(Default, Clone)]
pub struct GeminiInteractionsExt {
    api_key: String,
}

// `Debug` is implemented by hand rather than derived so that the API key is
// never printed; this matches the masking done by the `DebugExt` impl.
impl Debug for GeminiInteractionsExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiInteractionsExt")
            .field("api_key", &"******")
            .finish()
    }
}
29
/// Zero-sized marker that selects the Interactions flavour of the Gemini
/// client builder.
#[derive(Clone, Debug, Default)]
pub struct GeminiInteractionsBuilder;
33
/// Newtype around a Gemini API key.
///
/// Any value convertible into a `String` (e.g. `&str`, `String`) can be turned
/// into a `GeminiApiKey` via `From`/`Into`.
pub struct GeminiApiKey(String);

impl<S: Into<String>> From<S> for GeminiApiKey {
    fn from(value: S) -> Self {
        let key: String = value.into();
        GeminiApiKey(key)
    }
}
45
46/// Gemini GenerateContent client.
47pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
48/// Builder for the Gemini GenerateContent client.
49pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
50/// Gemini Interactions API client.
51pub type InteractionsClient<H = reqwest::Client> = client::Client<GeminiInteractionsExt, H>;
52
53impl ApiKey for GeminiApiKey {}
54
55impl DebugExt for GeminiExt {
56    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
57        std::iter::once(("api_key", (&"******") as &dyn Debug))
58    }
59}
60
61impl DebugExt for GeminiInteractionsExt {
62    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
63        std::iter::once(("api_key", (&"******") as &dyn Debug))
64    }
65}
66
67impl Provider for GeminiExt {
68    type Builder = GeminiBuilder;
69
70    const VERIFY_PATH: &'static str = "/v1beta/models";
71
72    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
73        match transport {
74            Transport::Sse => {
75                format!(
76                    "{}/{}?alt=sse&key={}",
77                    base_url,
78                    path.trim_start_matches('/'),
79                    self.api_key
80                )
81            }
82            _ => {
83                format!(
84                    "{}/{}?key={}",
85                    base_url,
86                    path.trim_start_matches('/'),
87                    self.api_key
88                )
89            }
90        }
91    }
92}
93
94impl Provider for GeminiInteractionsExt {
95    type Builder = GeminiInteractionsBuilder;
96
97    const VERIFY_PATH: &'static str = "/v1beta/models";
98
99    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
100        let trimmed = path.trim_start_matches('/');
101        match transport {
102            Transport::Sse => {
103                if trimmed.contains('?') {
104                    format!("{}/{}&alt=sse", base_url, trimmed)
105                } else {
106                    format!("{}/{}?alt=sse", base_url, trimmed)
107                }
108            }
109            _ => format!("{}/{}", base_url, trimmed),
110        }
111    }
112
113    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
114        Ok(req.header("x-goog-api-key", self.api_key.clone()))
115    }
116}
117
// Capability table for the GenerateContent client: completion, embeddings and
// transcription are backed by concrete models; model listing is `Nothing`
// (presumably "unsupported" — confirm against `crate::client::Nothing`).
impl<H> Capabilities<H> for GeminiExt {
    // NOTE(review): unlike the Interactions completion model, these model types
    // are not parameterized by `H`; confirm they don't need to follow the
    // client's HTTP backend.
    type Completion = Capable<super::completion::CompletionModel>;
    type Embeddings = Capable<super::embedding::EmbeddingModel>;
    type Transcription = Capable<super::transcription::TranscriptionModel>;
    type ModelListing = Nothing;

    // These associated types only exist when the matching crate feature is
    // enabled; both are set to `Nothing` for this provider.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
129
// Capability table for the Interactions client. Completion goes through the
// Interactions-specific model (generic over the HTTP client `H`); embeddings and
// transcription reuse the same model types as the GenerateContent client.
impl<H> Capabilities<H> for GeminiInteractionsExt {
    type Completion = Capable<super::interactions_api::InteractionsCompletionModel<H>>;
    // NOTE(review): these are not parameterized by `H` while `Completion` is;
    // confirm that is intentional.
    type Embeddings = Capable<super::embedding::EmbeddingModel>;
    type Transcription = Capable<super::transcription::TranscriptionModel>;
    type ModelListing = Nothing;

    // These associated types only exist when the matching crate feature is
    // enabled; both are set to `Nothing` for this provider.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
141
142impl ProviderBuilder for GeminiBuilder {
143    type Extension<H>
144        = GeminiExt
145    where
146        H: http_client::HttpClientExt;
147    type ApiKey = GeminiApiKey;
148
149    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
150
151    fn build<H>(
152        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
153    ) -> http_client::Result<Self::Extension<H>>
154    where
155        H: http_client::HttpClientExt,
156    {
157        Ok(GeminiExt {
158            api_key: builder.get_api_key().0.clone(),
159        })
160    }
161}
162
163impl ProviderBuilder for GeminiInteractionsBuilder {
164    type Extension<H>
165        = GeminiInteractionsExt
166    where
167        H: http_client::HttpClientExt;
168    type ApiKey = GeminiApiKey;
169
170    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
171
172    fn build<H>(
173        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
174    ) -> http_client::Result<Self::Extension<H>>
175    where
176        H: http_client::HttpClientExt,
177    {
178        Ok(GeminiInteractionsExt {
179            api_key: builder.get_api_key().0.clone(),
180        })
181    }
182}
183
184impl ProviderClient for Client {
185    type Input = GeminiApiKey;
186
187    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
188    /// Panics if the environment variable is not set.
189    fn from_env() -> Self {
190        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
191        Self::new(api_key).unwrap()
192    }
193
194    fn from_val(input: Self::Input) -> Self {
195        Self::new(input).unwrap()
196    }
197}
198
199impl ProviderClient for InteractionsClient {
200    type Input = GeminiApiKey;
201
202    /// Create a new Google Gemini interactions client from the `GEMINI_API_KEY` environment variable.
203    /// Panics if the environment variable is not set.
204    fn from_env() -> Self {
205        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
206        Self::new(api_key).unwrap()
207    }
208
209    fn from_val(input: Self::Input) -> Self {
210        Self::new(input).unwrap()
211    }
212}
213
214impl<H> Client<H> {
215    /// Create an Interactions API client from this GenerateContent client.
216    pub fn interactions_api(self) -> InteractionsClient<H> {
217        let api_key = self.ext().api_key.clone();
218        self.with_ext(GeminiInteractionsExt { api_key })
219    }
220}
221
222impl<H> InteractionsClient<H> {
223    /// Create a GenerateContent API client from this Interactions client.
224    pub fn generate_content_api(self) -> Client<H> {
225        let api_key = self.ext().api_key.clone();
226        self.with_ext(GeminiExt { api_key })
227    }
228}
229
/// Error response payload returned by Gemini.
///
/// Deserializes from an object with a top-level string `message` field. Other
/// fields are ignored, because `deny_unknown_fields` is not set.
// NOTE(review): Gemini error bodies appear to nest details under an `error`
// object (`{"error": {"message": ...}}`). Confirm that callers unwrap that
// envelope before deserializing into this type.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Error message text taken from the response payload.
    pub message: String,
}
235
/// Wrapper for successful or error Gemini API responses.
///
/// Because of `#[serde(untagged)]`, serde tries the variants in declaration
/// order. A payload is parsed as `Ok(T)` first and falls back to
/// `Err(ApiErrorResponse)` only if that fails.
// NOTE(review): if some `T` can deserialize from an error body (e.g. all of its
// fields are optional), API errors would silently land in `Ok`. Confirm this
// for each `T` used with this wrapper.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Response body that deserialized as the expected type `T`.
    Ok(T),
    /// Response body that matched the error payload shape instead.
    Err(ApiErrorResponse),
}
243
244// ================================================================
245// Tests
246// ================================================================
247
#[cfg(test)]
mod tests {
    use super::*;

    /// The GenerateContent client can be built directly and through its builder.
    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }

    /// The Interactions client can be built directly from an API key.
    #[test]
    fn test_interactions_client_initialization() {
        let _client: InteractionsClient =
            InteractionsClient::new("dummy-key").expect("InteractionsClient::new() failed");
    }

    /// Converting between the two client flavours keeps the API key.
    #[test]
    fn test_api_conversion_preserves_key() {
        let client: Client = Client::new("dummy-key").expect("Client::new() failed");
        let interactions = client.interactions_api();
        assert_eq!(interactions.ext().api_key, "dummy-key");
        let generate = interactions.generate_content_api();
        assert_eq!(generate.ext().api_key, "dummy-key");
    }

    /// Interactions SSE URIs get `alt=sse`, extending an existing query string if present.
    #[test]
    fn test_interactions_sse_uri_appends_alt() {
        let ext = GeminiInteractionsExt {
            api_key: "dummy-key".into(),
        };
        assert_eq!(
            ext.build_uri(GEMINI_API_BASE_URL, "/v1beta/interactions", Transport::Sse),
            "https://generativelanguage.googleapis.com/v1beta/interactions?alt=sse"
        );
        assert_eq!(
            ext.build_uri(GEMINI_API_BASE_URL, "v1beta/interactions?foo=1", Transport::Sse),
            "https://generativelanguage.googleapis.com/v1beta/interactions?foo=1&alt=sse"
        );
    }
}