Skip to main content

chat_gemini/
lib.rs

1mod api;
2pub mod client;
3mod tools;
4use std::env;
5use std::marker::PhantomData;
6
7use crate::api::types::request::{
8    EmbeddingsTask, GeminiEmbeddingsConfig, GeminiFunctionCallingConfig,
9};
10use chat_core::types::provider_meta::ProviderMeta;
11
12use crate::client::GeminiClient;
13use crate::tools::GeminiNativeTool;
14use crate::tools::code_execution::CodeExecutionTool;
15use crate::tools::google_maps::GoogleMapsTool;
16use crate::tools::google_search::GoogleSearchTool;
17
/// Type-state marker: the builder has not been given a model name yet.
///
/// This is the default `M` parameter of `GeminiBuilder`; `with_model` moves
/// the builder from this state to [`WithModel`].
#[derive(Debug, Clone, Copy, Default)]
pub struct WithoutModel;

/// Type-state marker: the builder carries a model name.
///
/// `GeminiBuilder::build` is only implemented for builders in this state.
#[derive(Debug, Clone, Copy, Default)]
pub struct WithModel;
20
/// Type-state marker: no mode-specific option has been chosen yet.
///
/// This is the default `C` parameter of `GeminiBuilder`. From here the
/// builder can move to [`CompletionConfig`] or [`EmbeddingConfig`].
#[derive(Debug, Clone, Copy, Default)]
pub struct BaseConfig;

/// Type-state marker: the builder exposes the native-tool and
/// function-calling options.
#[derive(Debug, Clone, Copy, Default)]
pub struct CompletionConfig;

/// Type-state marker: the builder exposes the embeddings options.
#[derive(Debug, Clone, Copy, Default)]
pub struct EmbeddingConfig;
24
25pub struct GeminiBuilder<M = WithoutModel, C = BaseConfig> {
26    model_name: Option<String>,
27    api_key: Option<String>,
28    native_tools: Vec<Box<dyn GeminiNativeTool>>,
29    function_config: Option<GeminiFunctionCallingConfig>,
30    embeddings_config: Option<GeminiEmbeddingsConfig>,
31    include_thoughts: bool,
32    meta: ProviderMeta,
33    _m: PhantomData<M>,
34    _c: PhantomData<C>,
35}
36
37impl Default for GeminiBuilder<WithoutModel, BaseConfig> {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl GeminiBuilder<WithoutModel, BaseConfig> {
44    pub fn new() -> Self {
45        Self {
46            model_name: None,
47            api_key: None,
48            native_tools: Vec::new(),
49            function_config: None,
50            embeddings_config: None,
51            include_thoughts: false,
52            meta: ProviderMeta::default(),
53            _m: PhantomData,
54            _c: PhantomData,
55        }
56    }
57}
58
59impl<M, C> GeminiBuilder<M, C> {
60    pub fn with_api_key(mut self, api_key: String) -> Self {
61        self.api_key = Some(api_key);
62        self
63    }
64
65    pub fn with_description(mut self, description: impl Into<String>) -> Self {
66        self.meta.description = Some(description.into());
67        self
68    }
69
70    pub fn with_metadata(
71        mut self,
72        key: impl Into<String>,
73        value: impl std::any::Any + Send + Sync + 'static,
74    ) -> Self {
75        self.meta.data.insert(key.into(), Box::new(value));
76        self
77    }
78}
79
80impl<C> GeminiBuilder<WithoutModel, C> {
81    pub fn with_model(self, model_name: String) -> GeminiBuilder<WithModel, C> {
82        GeminiBuilder {
83            model_name: Some(model_name),
84            api_key: self.api_key,
85            native_tools: self.native_tools,
86            function_config: self.function_config,
87            embeddings_config: self.embeddings_config,
88            include_thoughts: self.include_thoughts,
89            meta: self.meta,
90            _m: PhantomData,
91            _c: PhantomData,
92        }
93    }
94}
95
96impl<M> GeminiBuilder<M, BaseConfig> {
97    fn into_completion(self) -> GeminiBuilder<M, CompletionConfig> {
98        GeminiBuilder {
99            model_name: self.model_name,
100            api_key: self.api_key,
101            native_tools: self.native_tools,
102            function_config: self.function_config,
103            embeddings_config: self.embeddings_config,
104            include_thoughts: self.include_thoughts,
105            meta: self.meta,
106            _m: PhantomData,
107            _c: PhantomData,
108        }
109    }
110
111    pub fn with_code_execution(self) -> GeminiBuilder<M, CompletionConfig> {
112        self.into_completion().with_code_execution()
113    }
114
115    pub fn with_google_search(self) -> GeminiBuilder<M, CompletionConfig> {
116        self.into_completion().with_google_search()
117    }
118
119    pub fn with_thoughts(mut self, include: bool) -> Self {
120        self.include_thoughts = include;
121        self
122    }
123
124    pub fn with_google_maps(
125        self,
126        lat_lng: Option<(f32, f32)>,
127        widget: bool,
128    ) -> GeminiBuilder<M, CompletionConfig> {
129        self.into_completion().with_google_maps(lat_lng, widget)
130    }
131
132    pub fn with_function_calling_mode(
133        self,
134        mode: &str,
135        allowed: Option<Vec<String>>,
136    ) -> GeminiBuilder<M, CompletionConfig> {
137        self.into_completion()
138            .with_function_calling_mode(mode, allowed)
139    }
140}
141
142impl<M> GeminiBuilder<M, CompletionConfig> {
143    pub fn with_code_execution(mut self) -> Self {
144        self.native_tools.push(Box::new(CodeExecutionTool));
145        self
146    }
147
148    pub fn with_google_search(mut self) -> Self {
149        self.native_tools.push(Box::new(GoogleSearchTool {
150            dynamic_threshold: None,
151        }));
152        self
153    }
154
155    pub fn with_google_search_threshold(mut self, threshold: f32) -> Self {
156        self.native_tools.push(Box::new(GoogleSearchTool {
157            dynamic_threshold: Some(threshold),
158        }));
159        self
160    }
161
162    pub fn with_google_maps(mut self, lat_lng: Option<(f32, f32)>, widget: bool) -> Self {
163        self.native_tools.push(Box::new(GoogleMapsTool {
164            lat_lng,
165            enable_widget: widget,
166        }));
167        self
168    }
169
170    pub fn with_function_calling_mode(mut self, mode: &str, allowed: Option<Vec<String>>) -> Self {
171        self.function_config = Some(GeminiFunctionCallingConfig {
172            mode: mode.to_string(),
173            allowed_function_names: allowed,
174        });
175        self
176    }
177}
178
179impl<M> GeminiBuilder<M, BaseConfig> {
180    fn into_embedding(self) -> GeminiBuilder<M, EmbeddingConfig> {
181        GeminiBuilder {
182            model_name: self.model_name,
183            api_key: self.api_key,
184            native_tools: vec![],
185            function_config: None,
186            embeddings_config: self.embeddings_config,
187            include_thoughts: false,
188            meta: self.meta,
189            _m: PhantomData,
190            _c: PhantomData,
191        }
192    }
193
194    pub fn with_embeddings(self, dimensions: Option<usize>) -> GeminiBuilder<M, EmbeddingConfig> {
195        self.into_embedding().with_embeddings(dimensions)
196    }
197}
198
199impl<M> GeminiBuilder<M, EmbeddingConfig> {
200    pub fn with_embeddings(mut self, dimensions: Option<usize>) -> Self {
201        let mut config = self.embeddings_config.unwrap_or_default();
202        config.dimensions = dimensions;
203        self.embeddings_config = Some(config);
204        self
205    }
206
207    pub fn with_embeddings_task(mut self, task: EmbeddingsTask) -> Self {
208        let mut config = self.embeddings_config.unwrap_or_default();
209        config.task = task;
210        self.embeddings_config = Some(config);
211        self
212    }
213}
214
215impl<C> GeminiBuilder<WithModel, C> {
216    pub fn build(self) -> GeminiClient {
217        GeminiClient {
218            model_name: self.model_name.unwrap(),
219            api_key: self.api_key.unwrap_or_else(|| {
220                env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not found in environment")
221            }),
222            http_client: reqwest::Client::new(),
223            native_tools: self.native_tools,
224            function_config: self.function_config,
225            embeddings_config: self.embeddings_config,
226            include_thoughts: self.include_thoughts,
227            meta: self.meta,
228        }
229    }
230}