Skip to main content

chat_gemini/
lib.rs

1mod api;
2pub mod client;
3mod tools;
4use std::env;
5use std::marker::PhantomData;
6
7use crate::api::types::request::{
8    EmbeddingsTask, GeminiEmbeddingsConfig, GeminiFunctionCallingConfig,
9};
10use chat_core::transport::Transport;
11use chat_core::types::provider_meta::ProviderMeta;
12
13pub use crate::client::GeminiClient;
14use crate::tools::GeminiNativeTool;
15use crate::tools::code_execution::CodeExecutionTool;
16use crate::tools::google_maps::GoogleMapsTool;
17use crate::tools::google_search::GoogleSearchTool;
18
19pub use chat_core::transport::ReqwestTransport;
20
/// Type-state marker: no model name has been supplied to the builder yet.
///
/// A `GeminiBuilder` in this state has no `build()` method; calling `with_model`
/// moves it to [`WithModel`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WithoutModel;

/// Type-state marker: a model name has been supplied, so `build()` is available.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WithModel;
23
/// Configuration-state marker: neither completion-only nor embedding-only options are set yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct BaseConfig;

/// Configuration-state marker: completion options (native tools, function calling) are in use.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CompletionConfig;

/// Configuration-state marker: embedding options (dimensions, task) are in use.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct EmbeddingConfig;
27
28pub struct GeminiBuilder<M = WithoutModel, C = BaseConfig, T: Transport = ReqwestTransport> {
29    model_name: Option<String>,
30    api_key: Option<String>,
31    native_tools: Vec<Box<dyn GeminiNativeTool>>,
32    function_config: Option<GeminiFunctionCallingConfig>,
33    embeddings_config: Option<GeminiEmbeddingsConfig>,
34    include_thoughts: bool,
35    response_modalities: Option<Vec<String>>,
36    transport: Option<T>,
37    meta: ProviderMeta,
38    _m: PhantomData<M>,
39    _c: PhantomData<C>,
40}
41
42impl Default for GeminiBuilder<WithoutModel, BaseConfig, ReqwestTransport> {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl GeminiBuilder<WithoutModel, BaseConfig, ReqwestTransport> {
49    pub fn new() -> Self {
50        Self {
51            model_name: None,
52            api_key: None,
53            native_tools: Vec::new(),
54            function_config: None,
55            embeddings_config: None,
56            include_thoughts: false,
57            response_modalities: None,
58            transport: Some(ReqwestTransport::default()),
59            meta: ProviderMeta::default(),
60            _m: PhantomData,
61            _c: PhantomData,
62        }
63    }
64}
65
66impl<M, C, T: Transport> GeminiBuilder<M, C, T> {
67    pub fn with_api_key(mut self, api_key: String) -> Self {
68        self.api_key = Some(api_key);
69        self
70    }
71
72    pub fn with_description(mut self, description: impl Into<String>) -> Self {
73        self.meta.description = Some(description.into());
74        self
75    }
76
77    pub fn with_metadata(
78        mut self,
79        key: impl Into<String>,
80        value: impl std::any::Any + Send + Sync + 'static,
81    ) -> Self {
82        self.meta.data.insert(key.into(), Box::new(value));
83        self
84    }
85
86    /// Supply a custom transport, replacing the default.
87    pub fn with_transport<T2: Transport>(self, transport: T2) -> GeminiBuilder<M, C, T2> {
88        GeminiBuilder {
89            model_name: self.model_name,
90            api_key: self.api_key,
91            native_tools: self.native_tools,
92            function_config: self.function_config,
93            embeddings_config: self.embeddings_config,
94            include_thoughts: self.include_thoughts,
95            response_modalities: self.response_modalities,
96            transport: Some(transport),
97            meta: self.meta,
98            _m: PhantomData,
99            _c: PhantomData,
100        }
101    }
102}
103
104impl<C, T: Transport> GeminiBuilder<WithoutModel, C, T> {
105    pub fn with_model(self, model_name: impl Into<String>) -> GeminiBuilder<WithModel, C, T> {
106        GeminiBuilder {
107            model_name: Some(model_name.into()),
108            api_key: self.api_key,
109            native_tools: self.native_tools,
110            function_config: self.function_config,
111            embeddings_config: self.embeddings_config,
112            include_thoughts: self.include_thoughts,
113            response_modalities: self.response_modalities,
114            transport: self.transport,
115            meta: self.meta,
116            _m: PhantomData,
117            _c: PhantomData,
118        }
119    }
120}
121
122impl<M, T: Transport> GeminiBuilder<M, BaseConfig, T> {
123    fn into_completion(self) -> GeminiBuilder<M, CompletionConfig, T> {
124        GeminiBuilder {
125            model_name: self.model_name,
126            api_key: self.api_key,
127            native_tools: self.native_tools,
128            function_config: self.function_config,
129            embeddings_config: self.embeddings_config,
130            include_thoughts: self.include_thoughts,
131            response_modalities: self.response_modalities,
132            transport: self.transport,
133            meta: self.meta,
134            _m: PhantomData,
135            _c: PhantomData,
136        }
137    }
138
139    pub fn with_code_execution(self) -> GeminiBuilder<M, CompletionConfig, T> {
140        self.into_completion().with_code_execution()
141    }
142
143    pub fn with_google_search(self) -> GeminiBuilder<M, CompletionConfig, T> {
144        self.into_completion().with_google_search()
145    }
146
147    pub fn with_thoughts(mut self, include: bool) -> Self {
148        self.include_thoughts = include;
149        self
150    }
151
152    pub fn with_image_output(mut self) -> Self {
153        self.response_modalities = Some(vec!["TEXT".to_string(), "IMAGE".to_string()]);
154        self
155    }
156
157    pub fn with_google_maps(
158        self,
159        lat_lng: Option<(f32, f32)>,
160        widget: bool,
161    ) -> GeminiBuilder<M, CompletionConfig, T> {
162        self.into_completion().with_google_maps(lat_lng, widget)
163    }
164
165    pub fn with_function_calling_mode(
166        self,
167        mode: &str,
168        allowed: Option<Vec<String>>,
169    ) -> GeminiBuilder<M, CompletionConfig, T> {
170        self.into_completion()
171            .with_function_calling_mode(mode, allowed)
172    }
173}
174
175impl<M, T: Transport> GeminiBuilder<M, CompletionConfig, T> {
176    pub fn with_code_execution(mut self) -> Self {
177        self.native_tools.push(Box::new(CodeExecutionTool));
178        self
179    }
180
181    pub fn with_google_search(mut self) -> Self {
182        self.native_tools.push(Box::new(GoogleSearchTool {
183            dynamic_threshold: None,
184        }));
185        self
186    }
187
188    pub fn with_google_search_threshold(mut self, threshold: f32) -> Self {
189        self.native_tools.push(Box::new(GoogleSearchTool {
190            dynamic_threshold: Some(threshold),
191        }));
192        self
193    }
194
195    pub fn with_google_maps(mut self, lat_lng: Option<(f32, f32)>, widget: bool) -> Self {
196        self.native_tools.push(Box::new(GoogleMapsTool {
197            lat_lng,
198            enable_widget: widget,
199        }));
200        self
201    }
202
203    pub fn with_function_calling_mode(mut self, mode: &str, allowed: Option<Vec<String>>) -> Self {
204        self.function_config = Some(GeminiFunctionCallingConfig {
205            mode: mode.to_string(),
206            allowed_function_names: allowed,
207        });
208        self
209    }
210}
211
212impl<M, T: Transport> GeminiBuilder<M, BaseConfig, T> {
213    fn into_embedding(self) -> GeminiBuilder<M, EmbeddingConfig, T> {
214        GeminiBuilder {
215            model_name: self.model_name,
216            api_key: self.api_key,
217            native_tools: vec![],
218            function_config: None,
219            embeddings_config: self.embeddings_config,
220            include_thoughts: false,
221            response_modalities: None,
222            transport: self.transport,
223            meta: self.meta,
224            _m: PhantomData,
225            _c: PhantomData,
226        }
227    }
228
229    pub fn with_embeddings(
230        self,
231        dimensions: Option<usize>,
232    ) -> GeminiBuilder<M, EmbeddingConfig, T> {
233        self.into_embedding().with_embeddings(dimensions)
234    }
235}
236
237impl<M, T: Transport> GeminiBuilder<M, EmbeddingConfig, T> {
238    pub fn with_embeddings(mut self, dimensions: Option<usize>) -> Self {
239        let mut config = self.embeddings_config.unwrap_or_default();
240        config.dimensions = dimensions;
241        self.embeddings_config = Some(config);
242        self
243    }
244
245    pub fn with_embeddings_task(mut self, task: EmbeddingsTask) -> Self {
246        let mut config = self.embeddings_config.unwrap_or_default();
247        config.task = task;
248        self.embeddings_config = Some(config);
249        self
250    }
251}
252
253impl<C, T: Transport> GeminiBuilder<WithModel, C, T> {
254    pub fn build(self) -> GeminiClient<T> {
255        GeminiClient {
256            model_name: self.model_name.unwrap(),
257            api_key: self.api_key.unwrap_or_else(|| {
258                env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not found in environment")
259            }),
260            scheme: "https".to_string(),
261            host: "generativelanguage.googleapis.com".to_string(),
262            base_path: "/v1beta".to_string(),
263            transport: self.transport.expect(
264                "No transport provided. Call .with_transport() or use the default GeminiBuilder (which provides ReqwestTransport).",
265            ),
266            native_tools: self.native_tools,
267            function_config: self.function_config,
268            embeddings_config: self.embeddings_config,
269            include_thoughts: self.include_thoughts,
270            response_modalities: self.response_modalities,
271            meta: self.meta,
272        }
273    }
274}