rig/providers/gemini/
client.rs1use crate::client::{
2 self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3 ProviderClient, Transport,
4};
5use crate::http_client;
6use serde::Deserialize;
7use std::fmt::Debug;
8
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";

/// Provider extension for the Gemini API; holds the API key appended to every request URI.
///
/// `Debug` is implemented by hand rather than derived so that formatting this value with
/// `{:?}` (e.g. in logs or error reports) never prints the secret key verbatim.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Mask the key; a derived impl would leak it.
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}

/// Marker builder type used to construct [`GeminiExt`] through the generic client builder.
#[derive(Debug, Default, Clone)]
pub struct GeminiBuilder;

/// API key newtype used as the key type of the Gemini `ClientBuilder`.
pub struct GeminiApiKey(String);

impl<S: Into<String>> From<S> for GeminiApiKey {
    /// Wraps any string-like value (`&str`, `String`, ...) as a Gemini API key.
    fn from(raw: S) -> Self {
        let key: String = raw.into();
        GeminiApiKey(key)
    }
}

33pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
34pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
35
36impl ApiKey for GeminiApiKey {}
37
38impl DebugExt for GeminiExt {
39 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
40 std::iter::once(("api_key", (&"******") as &dyn Debug))
41 }
42}
43
44impl Provider for GeminiExt {
45 type Builder = GeminiBuilder;
46
47 const VERIFY_PATH: &'static str = "/v1beta/models";
48
49 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
50 match transport {
51 Transport::Sse => {
52 format!(
53 "{}/{}?alt=sse&key={}",
54 base_url,
55 path.trim_start_matches('/'),
56 self.api_key
57 )
58 }
59 _ => {
60 format!(
61 "{}/{}?key={}",
62 base_url,
63 path.trim_start_matches('/'),
64 self.api_key
65 )
66 }
67 }
68 }
69}
70
71impl<H> Capabilities<H> for GeminiExt {
72 type Completion = Capable<super::completion::CompletionModel>;
73 type Embeddings = Capable<super::embedding::EmbeddingModel>;
74 type Transcription = Capable<super::transcription::TranscriptionModel>;
75 type ModelListing = Nothing;
76
77 #[cfg(feature = "image")]
78 type ImageGeneration = Nothing;
79 #[cfg(feature = "audio")]
80 type AudioGeneration = Nothing;
81}
82
83impl ProviderBuilder for GeminiBuilder {
84 type Extension<H>
85 = GeminiExt
86 where
87 H: http_client::HttpClientExt;
88 type ApiKey = GeminiApiKey;
89
90 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
91
92 fn build<H>(
93 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
94 ) -> http_client::Result<Self::Extension<H>>
95 where
96 H: http_client::HttpClientExt,
97 {
98 Ok(GeminiExt {
99 api_key: builder.get_api_key().0.clone(),
100 })
101 }
102}
103
104impl ProviderClient for Client {
105 type Input = GeminiApiKey;
106
107 fn from_env() -> Self {
110 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
111 Self::new(api_key).unwrap()
112 }
113
114 fn from_val(input: Self::Input) -> Self {
115 Self::new(input).unwrap()
116 }
117}
118
119#[derive(Debug, Deserialize)]
120pub struct ApiErrorResponse {
121 pub message: String,
122}
123
124#[derive(Debug, Deserialize)]
125#[serde(untagged)]
126pub enum ApiResponse<T> {
127 Ok(T),
128 Err(ApiErrorResponse),
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Both construction paths, `Client::new` and the builder, must succeed with a
    /// placeholder key.
    #[test]
    fn test_client_initialization() {
        let via_new: Client = Client::new("dummy-key").expect("Client::new() failed");
        let via_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
        let _ = (via_new, via_builder);
    }
}