rig/providers/gemini/
client.rs1#[cfg(feature = "image")]
2use crate::client::Nothing;
3use crate::client::{
4 self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
5 Transport,
6};
7use crate::http_client;
8use serde::Deserialize;
9use std::fmt::Debug;
10
/// Default base URL for Gemini API requests.
///
/// Used as `ProviderBuilder::BASE_URL` for `GeminiBuilder`; request paths such as
/// `/v1beta/models` are appended to it by `build_uri`.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
15
/// Provider-specific client state for the Gemini API.
///
/// Holds the API key that `build_uri` appends to every request URL as the
/// `key` query parameter.
///
/// `Debug` is implemented by hand instead of derived so that the key can never
/// end up in logs or panic messages through `{:?}`; the output masks it with
/// the same `******` placeholder the `DebugExt` implementation reports.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl std::fmt::Debug for GeminiExt {
    /// Formats the struct with the `api_key` value replaced by `******`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
20
/// Builder marker type for Gemini clients.
///
/// A unit struct with no fields of its own; it exists to carry the
/// `ProviderBuilder` implementation that produces a `GeminiExt`.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
23
/// Newtype wrapper around the Gemini API key string.
pub struct GeminiApiKey(String);

/// Builds a [`GeminiApiKey`] from anything convertible into a `String`
/// (for example `&str` or `String`).
impl<S: Into<String>> From<S> for GeminiApiKey {
    fn from(raw: S) -> Self {
        let key: String = raw.into();
        GeminiApiKey(key)
    }
}
34
35pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
36pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
37
38impl ApiKey for GeminiApiKey {}
39
40impl DebugExt for GeminiExt {
41 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
42 std::iter::once(("api_key", (&"******") as &dyn Debug))
43 }
44}
45
46impl Provider for GeminiExt {
47 type Builder = GeminiBuilder;
48
49 const VERIFY_PATH: &'static str = "/v1beta/models";
50
51 fn build<H>(
52 builder: &client::ClientBuilder<Self::Builder, GeminiApiKey, H>,
53 ) -> http_client::Result<Self> {
54 Ok(Self {
55 api_key: builder.get_api_key().0.clone(),
56 })
57 }
58
59 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
60 match transport {
61 Transport::Sse => {
62 format!(
63 "{}/{}?alt=sse&key={}",
64 base_url,
65 path.trim_start_matches('/'),
66 self.api_key
67 )
68 }
69 _ => {
70 format!(
71 "{}/{}?key={}",
72 base_url,
73 path.trim_start_matches('/'),
74 self.api_key
75 )
76 }
77 }
78 }
79}
80
81impl<H> Capabilities<H> for GeminiExt {
82 type Completion = Capable<super::completion::CompletionModel>;
83 type Embeddings = Capable<super::embedding::EmbeddingModel>;
84 type Transcription = Capable<super::transcription::TranscriptionModel>;
85
86 #[cfg(feature = "image")]
87 type ImageGeneration = Nothing;
88 #[cfg(feature = "audio")]
89 type AudioGeneration = Nothing;
90}
91
92impl ProviderBuilder for GeminiBuilder {
93 type Output = GeminiExt;
94 type ApiKey = GeminiApiKey;
95
96 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
97}
98
99impl ProviderClient for Client {
100 type Input = GeminiApiKey;
101
102 fn from_env() -> Self {
105 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
106 Self::new(api_key).unwrap()
107 }
108
109 fn from_val(input: Self::Input) -> Self {
110 Self::new(input).unwrap()
111 }
112}
113
114#[derive(Debug, Deserialize)]
115pub struct ApiErrorResponse {
116 pub message: String,
117}
118
119#[derive(Debug, Deserialize)]
120#[serde(untagged)]
121pub enum ApiResponse<T> {
122 Ok(T),
123 Err(ApiErrorResponse),
124}