Skip to main content

rig_core/providers/gemini/
client.rs

1use crate::client::{
2    self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3    ProviderClient, Transport,
4};
5use crate::http_client::{self};
6use crate::providers::gemini::model_listing::{GeminiInteractionsModelLister, GeminiModelLister};
7use serde::Deserialize;
8use std::fmt::Debug;
9
// ================================================================
// Google Gemini Client
// ================================================================

/// Base URL of the Google Generative Language (Gemini) REST API.
///
/// Used as `BASE_URL` by both provider builders in this file.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
14
/// Provider extension for the Gemini GenerateContent API.
///
/// Holds the API key used to authenticate requests. `Debug` is implemented by
/// hand (rather than derived) so the key is masked instead of printed verbatim.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl std::fmt::Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never render the secret. Uses the same mask as this type's `DebugExt` impl.
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
20
/// Zero-sized marker that selects the Gemini GenerateContent flavour of the client builder.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
24
/// Provider extension for the Gemini Interactions API.
///
/// Holds the API key used to authenticate requests. `Debug` is implemented by
/// hand (rather than derived) so the key is masked instead of printed verbatim.
#[derive(Default, Clone)]
pub struct GeminiInteractionsExt {
    api_key: String,
}

impl std::fmt::Debug for GeminiInteractionsExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never render the secret. Uses the same mask as this type's `DebugExt` impl.
        f.debug_struct("GeminiInteractionsExt")
            .field("api_key", &"******")
            .finish()
    }
}
30
/// Zero-sized marker that selects the Gemini Interactions flavour of the client builder.
#[derive(Clone, Debug, Default)]
pub struct GeminiInteractionsBuilder;
34
/// Newtype holding a Gemini API key.
///
/// Anything convertible into a `String` (for example `&str` or `String`) can be
/// turned into a `GeminiApiKey` through `From`/`Into`.
pub struct GeminiApiKey(String);

impl<S: Into<String>> From<S> for GeminiApiKey {
    fn from(value: S) -> Self {
        let key: String = value.into();
        GeminiApiKey(key)
    }
}
46
47/// Gemini GenerateContent client.
48pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
49/// Builder for the Gemini GenerateContent client.
50pub type ClientBuilder<H = crate::markers::Missing> =
51    client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
52/// Gemini Interactions API client.
53pub type InteractionsClient<H = reqwest::Client> = client::Client<GeminiInteractionsExt, H>;
54
55impl ApiKey for GeminiApiKey {}
56
57impl DebugExt for GeminiExt {
58    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
59        std::iter::once(("api_key", (&"******") as &dyn Debug))
60    }
61}
62
63impl DebugExt for GeminiInteractionsExt {
64    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
65        std::iter::once(("api_key", (&"******") as &dyn Debug))
66    }
67}
68
69impl Provider for GeminiExt {
70    type Builder = GeminiBuilder;
71
72    const VERIFY_PATH: &'static str = "/v1beta/models";
73
74    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
75        let trimmed = path.trim_start_matches('/');
76        let separator = if trimmed.contains('?') { "&" } else { "?" };
77
78        match transport {
79            Transport::Sse => format!(
80                "{base_url}/{trimmed}{separator}alt=sse&key={}",
81                self.api_key
82            ),
83            _ => format!("{base_url}/{trimmed}{separator}key={}", self.api_key),
84        }
85    }
86}
87
88impl Provider for GeminiInteractionsExt {
89    type Builder = GeminiInteractionsBuilder;
90
91    const VERIFY_PATH: &'static str = "/v1beta/models";
92
93    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
94        let trimmed = path.trim_start_matches('/');
95        match transport {
96            Transport::Sse => {
97                if trimmed.contains('?') {
98                    format!("{}/{}&alt=sse", base_url, trimmed)
99                } else {
100                    format!("{}/{}?alt=sse", base_url, trimmed)
101                }
102            }
103            _ => format!("{}/{}", base_url, trimmed),
104        }
105    }
106
107    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
108        Ok(req.header("x-goog-api-key", self.api_key.clone()))
109    }
110}
111
112impl<H> Capabilities<H> for GeminiExt {
113    type Completion = Capable<super::completion::CompletionModel<H>>;
114    type Embeddings = Capable<super::embedding::EmbeddingModel<H>>;
115    type Transcription = Capable<super::transcription::TranscriptionModel<H>>;
116    type ModelListing = Capable<GeminiModelLister<H>>;
117
118    #[cfg(feature = "image")]
119    type ImageGeneration = Capable<super::image_generation::ImageGenerationModel<H>>;
120    #[cfg(feature = "audio")]
121    type AudioGeneration = Nothing;
122    type Rerank = Nothing;
123}
124
125impl<H> Capabilities<H> for GeminiInteractionsExt {
126    type Completion = Capable<super::interactions_api::InteractionsCompletionModel<H>>;
127    type Embeddings = Capable<super::embedding::EmbeddingModel<H>>;
128    type Transcription = Capable<super::transcription::TranscriptionModel<H>>;
129    type ModelListing = Capable<GeminiInteractionsModelLister<H>>;
130
131    #[cfg(feature = "image")]
132    type ImageGeneration = Nothing;
133    #[cfg(feature = "audio")]
134    type AudioGeneration = Nothing;
135    type Rerank = Nothing;
136}
137
138impl ProviderBuilder for GeminiBuilder {
139    type Extension<H>
140        = GeminiExt
141    where
142        H: http_client::HttpClientExt;
143    type ApiKey = GeminiApiKey;
144
145    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
146
147    fn build<H>(
148        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
149    ) -> http_client::Result<Self::Extension<H>>
150    where
151        H: http_client::HttpClientExt,
152    {
153        Ok(GeminiExt {
154            api_key: builder.get_api_key().0.clone(),
155        })
156    }
157}
158
159impl ProviderBuilder for GeminiInteractionsBuilder {
160    type Extension<H>
161        = GeminiInteractionsExt
162    where
163        H: http_client::HttpClientExt;
164    type ApiKey = GeminiApiKey;
165
166    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
167
168    fn build<H>(
169        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
170    ) -> http_client::Result<Self::Extension<H>>
171    where
172        H: http_client::HttpClientExt,
173    {
174        Ok(GeminiInteractionsExt {
175            api_key: builder.get_api_key().0.clone(),
176        })
177    }
178}
179
180impl ProviderClient for Client {
181    type Input = GeminiApiKey;
182    type Error = crate::client::ProviderClientError;
183
184    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
185    fn from_env() -> Result<Self, Self::Error> {
186        let api_key = crate::client::required_env_var("GEMINI_API_KEY")?;
187        Self::new(api_key).map_err(Into::into)
188    }
189
190    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
191        Self::new(input).map_err(Into::into)
192    }
193}
194
195impl ProviderClient for InteractionsClient {
196    type Input = GeminiApiKey;
197    type Error = crate::client::ProviderClientError;
198
199    /// Create a new Google Gemini interactions client from the `GEMINI_API_KEY` environment variable.
200    fn from_env() -> Result<Self, Self::Error> {
201        let api_key = crate::client::required_env_var("GEMINI_API_KEY")?;
202        Self::new(api_key).map_err(Into::into)
203    }
204
205    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
206        Self::new(input).map_err(Into::into)
207    }
208}
209
210impl<H> Client<H> {
211    /// Create an Interactions API client from this GenerateContent client.
212    pub fn interactions_api(self) -> InteractionsClient<H> {
213        let api_key = self.ext().api_key.clone();
214        self.with_ext(GeminiInteractionsExt { api_key })
215    }
216}
217
218impl<H> InteractionsClient<H> {
219    /// Create a GenerateContent API client from this Interactions client.
220    pub fn generate_content_api(self) -> Client<H> {
221        let api_key = self.ext().api_key.clone();
222        self.with_ext(GeminiExt { api_key })
223    }
224}
225
226/// Error response payload returned by Gemini.
227#[derive(Debug, Deserialize)]
228pub struct ApiErrorResponse {
229    pub message: String,
230}
231
232/// Wrapper for successful or error Gemini API responses.
233#[derive(Debug, Deserialize)]
234#[serde(untagged)]
235pub enum ApiResponse<T> {
236    Ok(T),
237    Err(ApiErrorResponse),
238}
239
// ================================================================
// Tests
// ================================================================

#[cfg(test)]
mod tests {
    use super::*;

    const DUMMY_KEY: &str = "dummy-key";

    #[test]
    fn test_client_initialization() {
        // Construct directly from a key...
        let _direct: Client = Client::new(DUMMY_KEY).expect("Client::new() failed");

        // ...and through the builder API.
        let _via_builder: Client = Client::builder()
            .api_key(DUMMY_KEY)
            .build()
            .expect("Client::builder() failed");
    }
}