// rig/providers/gemini/client.rs

use crate::client::{
    self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient, Transport,
};
use crate::http_client;
use serde::Deserialize;
use std::fmt::{self, Debug};

// ================================================================
// Google Gemini Client
// ================================================================

/// Default root URL of the Google Generative Language (Gemini) REST API; request paths such
/// as `/v1beta/models` are appended to it.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";

/// Provider extension for Google Gemini; holds the API key used to authenticate requests.
///
/// `Debug` is implemented by hand rather than derived: a derived impl would print the raw
/// API key into logs and panic messages, bypassing the masking done in `DebugExt`.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl fmt::Debug for GeminiExt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Never expose the secret; mirror the mask used by `DebugExt::fields`.
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}

/// Unit type that implements `ProviderBuilder`, telling the generic client builder how to
/// construct a [`GeminiExt`].
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;

/// API key used to authenticate against the Gemini API.
pub struct GeminiApiKey(String);

/// Accepts anything convertible into a `String` (`&str`, `String`, ...) as a Gemini API key.
impl<K: Into<String>> From<K> for GeminiApiKey {
    fn from(key: K) -> Self {
        let owned: String = key.into();
        GeminiApiKey(owned)
    }
}

33pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
34pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
35
36impl ApiKey for GeminiApiKey {}
37
38impl DebugExt for GeminiExt {
39    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
40        std::iter::once(("api_key", (&"******") as &dyn Debug))
41    }
42}
43
44impl Provider for GeminiExt {
45    type Builder = GeminiBuilder;
46
47    const VERIFY_PATH: &'static str = "/v1beta/models";
48
49    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
50        match transport {
51            Transport::Sse => {
52                format!(
53                    "{}/{}?alt=sse&key={}",
54                    base_url,
55                    path.trim_start_matches('/'),
56                    self.api_key
57                )
58            }
59            _ => {
60                format!(
61                    "{}/{}?key={}",
62                    base_url,
63                    path.trim_start_matches('/'),
64                    self.api_key
65                )
66            }
67        }
68    }
69}
70
71impl<H> Capabilities<H> for GeminiExt {
72    type Completion = Capable<super::completion::CompletionModel>;
73    type Embeddings = Capable<super::embedding::EmbeddingModel>;
74    type Transcription = Capable<super::transcription::TranscriptionModel>;
75    type ModelListing = Nothing;
76
77    #[cfg(feature = "image")]
78    type ImageGeneration = Nothing;
79    #[cfg(feature = "audio")]
80    type AudioGeneration = Nothing;
81}
82
83impl ProviderBuilder for GeminiBuilder {
84    type Extension<H>
85        = GeminiExt
86    where
87        H: http_client::HttpClientExt;
88    type ApiKey = GeminiApiKey;
89
90    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
91
92    fn build<H>(
93        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
94    ) -> http_client::Result<Self::Extension<H>>
95    where
96        H: http_client::HttpClientExt,
97    {
98        Ok(GeminiExt {
99            api_key: builder.get_api_key().0.clone(),
100        })
101    }
102}
103
104impl ProviderClient for Client {
105    type Input = GeminiApiKey;
106
107    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
108    /// Panics if the environment variable is not set.
109    fn from_env() -> Self {
110        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
111        Self::new(api_key).unwrap()
112    }
113
114    fn from_val(input: Self::Input) -> Self {
115        Self::new(input).unwrap()
116    }
117}
118
119#[derive(Debug, Deserialize)]
120pub struct ApiErrorResponse {
121    pub message: String,
122}
123
124#[derive(Debug, Deserialize)]
125#[serde(untagged)]
126pub enum ApiResponse<T> {
127    Ok(T),
128    Err(ApiErrorResponse),
129}
130
// ================================================================
// Tests
// ================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Both construction paths, `Client::new` and the builder, must succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _direct: Client = Client::new("dummy-key").expect("Client::new() failed");

        let _built: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}