rig/providers/gemini/
client.rs1use crate::client::{
2 self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3 ProviderClient, Transport,
4};
5use crate::http_client;
6use serde::Deserialize;
7use std::fmt::Debug;
8
/// Default base URL of Google's Generative Language (Gemini) API.
///
/// Both provider builders in this module expose it as their `BASE_URL`; request
/// paths such as `/v1beta/models` are appended to it.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
13
/// Per-client state for the `generateContent`-style Gemini API.
///
/// Holds the API key used to authenticate each request. `Debug` is implemented
/// by hand instead of derived so that formatting this value with `{:?}` prints
/// a fixed mask rather than leaking the secret into logs or panic messages.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl std::fmt::Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the real key; the mask matches the one used by `DebugExt`.
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
19
/// Stateless marker type whose `ProviderBuilder` impl produces a `GeminiExt`.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
23
/// Per-client state for the Gemini Interactions API.
///
/// Holds the API key that is sent with each request. `Debug` is implemented by
/// hand instead of derived so that formatting this value with `{:?}` prints a
/// fixed mask rather than leaking the secret into logs or panic messages.
#[derive(Default, Clone)]
pub struct GeminiInteractionsExt {
    api_key: String,
}

impl std::fmt::Debug for GeminiInteractionsExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the real key; the mask matches the one used by `DebugExt`.
        f.debug_struct("GeminiInteractionsExt")
            .field("api_key", &"******")
            .finish()
    }
}
29
/// Stateless marker type whose `ProviderBuilder` impl produces a
/// `GeminiInteractionsExt`.
#[derive(Clone, Debug, Default)]
pub struct GeminiInteractionsBuilder;
33
/// API key used to authenticate against the Gemini APIs.
///
/// Anything convertible into a `String` (e.g. `&str` or `String`) can be turned
/// into a `GeminiApiKey` through `From`/`Into`.
pub struct GeminiApiKey(String);

impl<S: Into<String>> From<S> for GeminiApiKey {
    fn from(key: S) -> Self {
        let owned: String = key.into();
        GeminiApiKey(owned)
    }
}
45
46pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
48pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
50pub type InteractionsClient<H = reqwest::Client> = client::Client<GeminiInteractionsExt, H>;
52
53impl ApiKey for GeminiApiKey {}
54
55impl DebugExt for GeminiExt {
56 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
57 std::iter::once(("api_key", (&"******") as &dyn Debug))
58 }
59}
60
61impl DebugExt for GeminiInteractionsExt {
62 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
63 std::iter::once(("api_key", (&"******") as &dyn Debug))
64 }
65}
66
67impl Provider for GeminiExt {
68 type Builder = GeminiBuilder;
69
70 const VERIFY_PATH: &'static str = "/v1beta/models";
71
72 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
73 match transport {
74 Transport::Sse => {
75 format!(
76 "{}/{}?alt=sse&key={}",
77 base_url,
78 path.trim_start_matches('/'),
79 self.api_key
80 )
81 }
82 _ => {
83 format!(
84 "{}/{}?key={}",
85 base_url,
86 path.trim_start_matches('/'),
87 self.api_key
88 )
89 }
90 }
91 }
92}
93
94impl Provider for GeminiInteractionsExt {
95 type Builder = GeminiInteractionsBuilder;
96
97 const VERIFY_PATH: &'static str = "/v1beta/models";
98
99 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
100 let trimmed = path.trim_start_matches('/');
101 match transport {
102 Transport::Sse => {
103 if trimmed.contains('?') {
104 format!("{}/{}&alt=sse", base_url, trimmed)
105 } else {
106 format!("{}/{}?alt=sse", base_url, trimmed)
107 }
108 }
109 _ => format!("{}/{}", base_url, trimmed),
110 }
111 }
112
113 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
114 Ok(req.header("x-goog-api-key", self.api_key.clone()))
115 }
116}
117
118impl<H> Capabilities<H> for GeminiExt {
119 type Completion = Capable<super::completion::CompletionModel>;
120 type Embeddings = Capable<super::embedding::EmbeddingModel>;
121 type Transcription = Capable<super::transcription::TranscriptionModel>;
122 type ModelListing = Nothing;
123
124 #[cfg(feature = "image")]
125 type ImageGeneration = Nothing;
126 #[cfg(feature = "audio")]
127 type AudioGeneration = Nothing;
128}
129
130impl<H> Capabilities<H> for GeminiInteractionsExt {
131 type Completion = Capable<super::interactions_api::InteractionsCompletionModel<H>>;
132 type Embeddings = Capable<super::embedding::EmbeddingModel>;
133 type Transcription = Capable<super::transcription::TranscriptionModel>;
134 type ModelListing = Nothing;
135
136 #[cfg(feature = "image")]
137 type ImageGeneration = Nothing;
138 #[cfg(feature = "audio")]
139 type AudioGeneration = Nothing;
140}
141
142impl ProviderBuilder for GeminiBuilder {
143 type Extension<H>
144 = GeminiExt
145 where
146 H: http_client::HttpClientExt;
147 type ApiKey = GeminiApiKey;
148
149 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
150
151 fn build<H>(
152 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
153 ) -> http_client::Result<Self::Extension<H>>
154 where
155 H: http_client::HttpClientExt,
156 {
157 Ok(GeminiExt {
158 api_key: builder.get_api_key().0.clone(),
159 })
160 }
161}
162
163impl ProviderBuilder for GeminiInteractionsBuilder {
164 type Extension<H>
165 = GeminiInteractionsExt
166 where
167 H: http_client::HttpClientExt;
168 type ApiKey = GeminiApiKey;
169
170 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
171
172 fn build<H>(
173 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
174 ) -> http_client::Result<Self::Extension<H>>
175 where
176 H: http_client::HttpClientExt,
177 {
178 Ok(GeminiInteractionsExt {
179 api_key: builder.get_api_key().0.clone(),
180 })
181 }
182}
183
184impl ProviderClient for Client {
185 type Input = GeminiApiKey;
186
187 fn from_env() -> Self {
190 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
191 Self::new(api_key).unwrap()
192 }
193
194 fn from_val(input: Self::Input) -> Self {
195 Self::new(input).unwrap()
196 }
197}
198
199impl ProviderClient for InteractionsClient {
200 type Input = GeminiApiKey;
201
202 fn from_env() -> Self {
205 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
206 Self::new(api_key).unwrap()
207 }
208
209 fn from_val(input: Self::Input) -> Self {
210 Self::new(input).unwrap()
211 }
212}
213
214impl<H> Client<H> {
215 pub fn interactions_api(self) -> InteractionsClient<H> {
217 let api_key = self.ext().api_key.clone();
218 self.with_ext(GeminiInteractionsExt { api_key })
219 }
220}
221
222impl<H> InteractionsClient<H> {
223 pub fn generate_content_api(self) -> Client<H> {
225 let api_key = self.ext().api_key.clone();
226 self.with_ext(GeminiExt { api_key })
227 }
228}
229
230#[derive(Debug, Deserialize)]
232pub struct ApiErrorResponse {
233 pub message: String,
234}
235
236#[derive(Debug, Deserialize)]
238#[serde(untagged)]
239pub enum ApiResponse<T> {
240 Ok(T),
241 Err(ApiErrorResponse),
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Both construction paths (direct `new` and the builder) must succeed
    /// with a syntactically valid key.
    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }

    /// Switching between the generateContent and Interactions clients must
    /// carry the API key across in both directions.
    #[test]
    fn test_api_switch_preserves_api_key() {
        let client: Client = Client::new("dummy-key").expect("Client::new() failed");

        let interactions = client.interactions_api();
        assert_eq!(interactions.ext().api_key, "dummy-key");

        let back = interactions.generate_content_api();
        assert_eq!(back.ext().api_key, "dummy-key");
    }
}