//! Google Gemini client (`rig/providers/gemini/client.rs`).

1use crate::client::{
2    self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3    ProviderClient, Transport,
4};
5use crate::http_client;
6use serde::Deserialize;
7use std::fmt::Debug;
8
// ================================================================
// Google Gemini Client
// ================================================================

/// Root URL of the Google Generative Language (Gemini) REST API, without a trailing slash.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
13
/// Provider-specific state for the Gemini client: the API key sent with every request.
///
/// `Debug` is implemented by hand instead of derived so that formatting this value
/// with `{:?}` (logs, panic messages, `dbg!`) never prints the secret key.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Always print a fixed mask; the real key must not leak through Debug output.
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
18
/// Zero-sized marker type that selects Gemini as the provider of a `ClientBuilder`.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
21
/// Gemini API key as handed to the Gemini `ClientBuilder`; a thin wrapper over the raw string.
pub struct GeminiApiKey(String);

impl<S: Into<String>> From<S> for GeminiApiKey {
    /// Wraps anything convertible into a `String` (`&str`, `String`, ...) as an API key.
    fn from(key: S) -> Self {
        let owned: String = key.into();
        GeminiApiKey(owned)
    }
}
32
33pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
34pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
35
36impl ApiKey for GeminiApiKey {}
37
38impl DebugExt for GeminiExt {
39    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
40        std::iter::once(("api_key", (&"******") as &dyn Debug))
41    }
42}
43
44impl Provider for GeminiExt {
45    type Builder = GeminiBuilder;
46
47    const VERIFY_PATH: &'static str = "/v1beta/models";
48
49    fn build<H>(
50        builder: &client::ClientBuilder<Self::Builder, GeminiApiKey, H>,
51    ) -> http_client::Result<Self> {
52        Ok(Self {
53            api_key: builder.get_api_key().0.clone(),
54        })
55    }
56
57    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
58        match transport {
59            Transport::Sse => {
60                format!(
61                    "{}/{}?alt=sse&key={}",
62                    base_url,
63                    path.trim_start_matches('/'),
64                    self.api_key
65                )
66            }
67            _ => {
68                format!(
69                    "{}/{}?key={}",
70                    base_url,
71                    path.trim_start_matches('/'),
72                    self.api_key
73                )
74            }
75        }
76    }
77}
78
79impl<H> Capabilities<H> for GeminiExt {
80    type Completion = Capable<super::completion::CompletionModel>;
81    type Embeddings = Capable<super::embedding::EmbeddingModel>;
82    type Transcription = Capable<super::transcription::TranscriptionModel>;
83    type ModelListing = Nothing;
84
85    #[cfg(feature = "image")]
86    type ImageGeneration = Nothing;
87    #[cfg(feature = "audio")]
88    type AudioGeneration = Nothing;
89}
90
91impl ProviderBuilder for GeminiBuilder {
92    type Output = GeminiExt;
93    type ApiKey = GeminiApiKey;
94
95    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
96}
97
98impl ProviderClient for Client {
99    type Input = GeminiApiKey;
100
101    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
102    /// Panics if the environment variable is not set.
103    fn from_env() -> Self {
104        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
105        Self::new(api_key).unwrap()
106    }
107
108    fn from_val(input: Self::Input) -> Self {
109        Self::new(input).unwrap()
110    }
111}
112
113#[derive(Debug, Deserialize)]
114pub struct ApiErrorResponse {
115    pub message: String,
116}
117
118#[derive(Debug, Deserialize)]
119#[serde(untagged)]
120pub enum ApiResponse<T> {
121    Ok(T),
122    Err(ApiErrorResponse),
123}
124
// ================================================================
// Tests
// ================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }

    /// SSE requests must carry both `alt=sse` and the API key in the query string.
    #[test]
    fn test_build_uri_sse_includes_alt_and_key() {
        let ext = GeminiExt {
            api_key: "dummy-key".to_string(),
        };
        let uri = ext.build_uri(
            GEMINI_API_BASE_URL,
            "/v1beta/models/gemini-pro:streamGenerateContent",
            Transport::Sse,
        );
        assert_eq!(
            uri,
            "https://generativelanguage.googleapis.com/v1beta/models/\
             gemini-pro:streamGenerateContent?alt=sse&key=dummy-key"
        );
    }

    /// None of the debug fields exposed via `DebugExt` may reveal the raw API key.
    #[test]
    fn test_debug_ext_masks_api_key() {
        let ext = GeminiExt {
            api_key: "dummy-key".to_string(),
        };
        for (name, value) in ext.fields() {
            let rendered = format!("{value:?}");
            assert!(
                !rendered.contains("dummy-key"),
                "field `{name}` leaks the API key"
            );
        }
    }
}