rig/providers/gemini/
client.rs1use crate::client::{
2 self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3 ProviderClient, Transport,
4};
5use crate::http_client;
6use serde::Deserialize;
7use std::fmt::Debug;
8
/// Root URL of Google's Generative Language (Gemini) REST API.
///
/// Request paths such as `/v1beta/models` are joined onto this value when URIs are built.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
13
/// Gemini-specific state carried by the generic client: the API key used to sign requests.
///
/// `Debug` is implemented by hand (not derived) so the secret key is never printed when this
/// value is debug-formatted, e.g. in logs or panic messages.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Mask the key with the same placeholder the `DebugExt` impl uses.
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
18
/// Zero-sized marker that selects the Gemini provider when used as the first type
/// parameter of the generic client builder.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
21
/// API key used to authenticate against the Gemini API.
///
/// Any value convertible into a `String` (for example `&str` or `String`) can be turned
/// into a key via `From`/`Into`.
pub struct GeminiApiKey(String);

impl<S: Into<String>> From<S> for GeminiApiKey {
    fn from(raw: S) -> Self {
        let key: String = raw.into();
        GeminiApiKey(key)
    }
}
32
33pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
34pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
35
36impl ApiKey for GeminiApiKey {}
37
38impl DebugExt for GeminiExt {
39 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
40 std::iter::once(("api_key", (&"******") as &dyn Debug))
41 }
42}
43
44impl Provider for GeminiExt {
45 type Builder = GeminiBuilder;
46
47 const VERIFY_PATH: &'static str = "/v1beta/models";
48
49 fn build<H>(
50 builder: &client::ClientBuilder<Self::Builder, GeminiApiKey, H>,
51 ) -> http_client::Result<Self> {
52 Ok(Self {
53 api_key: builder.get_api_key().0.clone(),
54 })
55 }
56
57 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
58 match transport {
59 Transport::Sse => {
60 format!(
61 "{}/{}?alt=sse&key={}",
62 base_url,
63 path.trim_start_matches('/'),
64 self.api_key
65 )
66 }
67 _ => {
68 format!(
69 "{}/{}?key={}",
70 base_url,
71 path.trim_start_matches('/'),
72 self.api_key
73 )
74 }
75 }
76 }
77}
78
79impl<H> Capabilities<H> for GeminiExt {
80 type Completion = Capable<super::completion::CompletionModel>;
81 type Embeddings = Capable<super::embedding::EmbeddingModel>;
82 type Transcription = Capable<super::transcription::TranscriptionModel>;
83 type ModelListing = Nothing;
84
85 #[cfg(feature = "image")]
86 type ImageGeneration = Nothing;
87 #[cfg(feature = "audio")]
88 type AudioGeneration = Nothing;
89}
90
91impl ProviderBuilder for GeminiBuilder {
92 type Output = GeminiExt;
93 type ApiKey = GeminiApiKey;
94
95 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
96}
97
98impl ProviderClient for Client {
99 type Input = GeminiApiKey;
100
101 fn from_env() -> Self {
104 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
105 Self::new(api_key).unwrap()
106 }
107
108 fn from_val(input: Self::Input) -> Self {
109 Self::new(input).unwrap()
110 }
111}
112
113#[derive(Debug, Deserialize)]
114pub struct ApiErrorResponse {
115 pub message: String,
116}
117
118#[derive(Debug, Deserialize)]
119#[serde(untagged)]
120pub enum ApiResponse<T> {
121 Ok(T),
122 Err(ApiErrorResponse),
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Both construction paths — `Client::new` and the builder — must succeed for a
    /// placeholder key.
    #[test]
    fn test_client_initialization() {
        let direct: Client = Client::new("dummy-key").expect("Client::new() failed");
        let via_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
        drop((direct, via_builder));
    }
}