// bep/providers/gemini/client.rs

1use crate::{
2    agent::AgentBuilder,
3    embeddings::{self},
4    extractor::ExtractorBuilder,
5    Embed,
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9
10use super::{completion::CompletionModel, embedding::EmbeddingModel};
11
// ================================================================
// Google Gemini Client
// ================================================================
/// Root URL of the Google Generative Language (Gemini) REST API.
///
/// `Client::new` passes this value as the base URL of every client it creates.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
16
17#[derive(Clone)]
18pub struct Client {
19    base_url: String,
20    api_key: String,
21    http_client: reqwest::Client,
22}
23
24impl Client {
25    pub fn new(api_key: &str) -> Self {
26        Self::from_url(api_key, GEMINI_API_BASE_URL)
27    }
28    fn from_url(api_key: &str, base_url: &str) -> Self {
29        Self {
30            base_url: base_url.to_string(),
31            api_key: api_key.to_string(),
32            http_client: reqwest::Client::builder()
33                .default_headers({
34                    let mut headers = reqwest::header::HeaderMap::new();
35                    headers.insert(
36                        reqwest::header::CONTENT_TYPE,
37                        "application/json".parse().unwrap(),
38                    );
39                    headers
40                })
41                .build()
42                .expect("Gemini reqwest client should build"),
43        }
44    }
45
46    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
47    /// Panics if the environment variable is not set.
48    pub fn from_env() -> Self {
49        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
50        Self::new(&api_key)
51    }
52
53    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
54        let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
55
56        tracing::debug!("POST {}", url);
57        self.http_client.post(url)
58    }
59
60    /// Create an embedding model with the given name.
61    /// Note: default embedding dimension of 0 will be used if model is not known.
62    /// If this is the case, it's better to use function `embedding_model_with_ndims`
63    ///
64    /// # Example
65    /// ```
66    /// use bep::providers::gemini::{Client, self};
67    ///
68    /// // Initialize the Google Gemini client
69    /// let gemini = Client::new("your-google-gemini-api-key");
70    ///
71    /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
72    /// ```
73    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
74        EmbeddingModel::new(self.clone(), model, None)
75    }
76
77    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
78    ///
79    /// # Example
80    /// ```
81    /// use bep::providers::gemini::{Client, self};
82    ///
83    /// // Initialize the Google Gemini client
84    /// let gemini = Client::new("your-google-gemini-api-key");
85    ///
86    /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-bep", 1024);
87    /// ```
88    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
89        EmbeddingModel::new(self.clone(), model, Some(ndims))
90    }
91
92    /// Create an embedding builder with the given embedding model.
93    ///
94    /// # Example
95    /// ```
96    /// use bep::providers::gemini::{Client, self};
97    ///
98    /// // Initialize the Google Gemini client
99    /// let gemini = Client::new("your-google-gemini-api-key");
100    ///
101    /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
102    ///     .simple_document("doc0", "Hello, world!")
103    ///     .simple_document("doc1", "Goodbye, world!")
104    ///     .build()
105    ///     .await
106    ///     .expect("Failed to embed documents");
107    /// ```
108    pub fn embeddings<D: Embed>(
109        &self,
110        model: &str,
111    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
112        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
113    }
114
115    /// Create a completion model with the given name.
116    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
117    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
118    pub fn completion_model(&self, model: &str) -> CompletionModel {
119        CompletionModel::new(self.clone(), model)
120    }
121
122    /// Create an agent builder with the given completion model.
123    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
124    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
125    /// # Example
126    /// ```
127    /// use bep::providers::gemini::{Client, self};
128    ///
129    /// // Initialize the Google Gemini client
130    /// let gemini = Client::new("your-google-gemini-api-key");
131    ///
132    /// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
133    ///    .preamble("You are comedian AI with a mission to make people laugh.")
134    ///    .temperature(0.0)
135    ///    .build();
136    /// ```
137    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
138        AgentBuilder::new(self.completion_model(model))
139    }
140
141    /// Create an extractor builder with the given completion model.
142    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
143        &self,
144        model: &str,
145    ) -> ExtractorBuilder<T, CompletionModel> {
146        ExtractorBuilder::new(self.completion_model(model))
147    }
148}
149
150#[derive(Debug, Deserialize)]
151pub struct ApiErrorResponse {
152    pub message: String,
153}
154
155#[derive(Debug, Deserialize)]
156#[serde(untagged)]
157pub enum ApiResponse<T> {
158    Ok(T),
159    Err(ApiErrorResponse),
160}