// rig/providers/gemini/client.rs

1#[cfg(feature = "image")]
2use crate::client::Nothing;
3use crate::client::{
4    self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
5    Transport,
6};
7use crate::http_client;
8use serde::Deserialize;
9use std::fmt::Debug;
10
// ----------------------------------------------------------------
// Google Gemini client
// ----------------------------------------------------------------

/// Base URL used by default for Gemini requests (no path, no trailing slash).
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
15
/// Gemini-specific client state: the API key attached to every request.
///
/// `Debug` is implemented by hand rather than derived so that the key is never
/// printed in logs or panic messages. It is shown as `******`, matching the
/// masking done by the `DebugExt` impl.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never format `self.api_key` here: it is a credential.
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
20
/// Stateless marker type that selects the Gemini provider in the generic client builder.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
23
/// Newtype around the raw API key string used to authenticate with Gemini.
pub struct GeminiApiKey(String);

/// Accept anything string-like (`&str`, `String`, ...) as a [`GeminiApiKey`].
impl<S: Into<String>> From<S> for GeminiApiKey {
    fn from(key: S) -> Self {
        let owned: String = key.into();
        GeminiApiKey(owned)
    }
}
34
35pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
36pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
37
38impl ApiKey for GeminiApiKey {}
39
40impl DebugExt for GeminiExt {
41    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
42        std::iter::once(("api_key", (&"******") as &dyn Debug))
43    }
44}
45
46impl Provider for GeminiExt {
47    type Builder = GeminiBuilder;
48
49    const VERIFY_PATH: &'static str = "/v1beta/models";
50
51    fn build<H>(
52        builder: &client::ClientBuilder<Self::Builder, GeminiApiKey, H>,
53    ) -> http_client::Result<Self> {
54        Ok(Self {
55            api_key: builder.get_api_key().0.clone(),
56        })
57    }
58
59    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
60        match transport {
61            Transport::Sse => {
62                format!(
63                    "{}/{}?alt=sse&key={}",
64                    base_url,
65                    path.trim_start_matches('/'),
66                    self.api_key
67                )
68            }
69            _ => {
70                format!(
71                    "{}/{}?key={}",
72                    base_url,
73                    path.trim_start_matches('/'),
74                    self.api_key
75                )
76            }
77        }
78    }
79}
80
81impl<H> Capabilities<H> for GeminiExt {
82    type Completion = Capable<super::completion::CompletionModel>;
83    type Embeddings = Capable<super::embedding::EmbeddingModel>;
84    type Transcription = Capable<super::transcription::TranscriptionModel>;
85
86    #[cfg(feature = "image")]
87    type ImageGeneration = Nothing;
88    #[cfg(feature = "audio")]
89    type AudioGeneration = Nothing;
90}
91
92impl ProviderBuilder for GeminiBuilder {
93    type Output = GeminiExt;
94    type ApiKey = GeminiApiKey;
95
96    const BASE_URL: &'static str = GEMINI_API_BASE_URL;
97}
98
99impl ProviderClient for Client {
100    type Input = GeminiApiKey;
101
102    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
103    /// Panics if the environment variable is not set.
104    fn from_env() -> Self {
105        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
106        Self::new(api_key).unwrap()
107    }
108
109    fn from_val(input: Self::Input) -> Self {
110        Self::new(input).unwrap()
111    }
112}
113
114#[derive(Debug, Deserialize)]
115pub struct ApiErrorResponse {
116    pub message: String,
117}
118
119#[derive(Debug, Deserialize)]
120#[serde(untagged)]
121pub enum ApiResponse<T> {
122    Ok(T),
123    Err(ApiErrorResponse),
124}