Skip to main content

chat_gemini/
lib.rs

1mod api;
2pub mod client;
3mod tools;
4use std::env;
5use std::marker::PhantomData;
6
7use crate::api::types::request::{
8    EmbeddingsTask, GeminiEmbeddingsConfig, GeminiFunctionCallingConfig,
9};
10use chat_core::transport::Transport;
11use chat_core::types::provider_meta::ProviderMeta;
12
13pub use crate::client::GeminiClient;
14use crate::tools::GeminiNativeTool;
15use crate::tools::code_execution::CodeExecutionTool;
16use crate::tools::google_maps::GoogleMapsTool;
17use crate::tools::google_search::GoogleSearchTool;
18
19pub use chat_core::transport::ReqwestTransport;
20
/// Typestate marker: no model name has been set yet, so `build` is not available.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct WithoutModel;

/// Typestate marker: a model name has been set via `with_model`, enabling `build`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct WithModel;

/// Typestate marker: neither completion-only nor embedding-only options chosen yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BaseConfig;

/// Typestate marker: the builder is committed to completion options
/// (native tools, function calling).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CompletionConfig;

/// Typestate marker: the builder is committed to embedding options.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EmbeddingConfig;
27
28pub struct GeminiBuilder<M = WithoutModel, C = BaseConfig, T: Transport = ReqwestTransport> {
29    model_name: Option<String>,
30    api_key: Option<String>,
31    native_tools: Vec<Box<dyn GeminiNativeTool>>,
32    function_config: Option<GeminiFunctionCallingConfig>,
33    embeddings_config: Option<GeminiEmbeddingsConfig>,
34    include_thoughts: bool,
35    transport: Option<T>,
36    meta: ProviderMeta,
37    _m: PhantomData<M>,
38    _c: PhantomData<C>,
39}
40
41impl Default for GeminiBuilder<WithoutModel, BaseConfig, ReqwestTransport> {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl GeminiBuilder<WithoutModel, BaseConfig, ReqwestTransport> {
48    pub fn new() -> Self {
49        Self {
50            model_name: None,
51            api_key: None,
52            native_tools: Vec::new(),
53            function_config: None,
54            embeddings_config: None,
55            include_thoughts: false,
56            transport: Some(ReqwestTransport::default()),
57            meta: ProviderMeta::default(),
58            _m: PhantomData,
59            _c: PhantomData,
60        }
61    }
62}
63
64impl<M, C, T: Transport> GeminiBuilder<M, C, T> {
65    pub fn with_api_key(mut self, api_key: String) -> Self {
66        self.api_key = Some(api_key);
67        self
68    }
69
70    pub fn with_description(mut self, description: impl Into<String>) -> Self {
71        self.meta.description = Some(description.into());
72        self
73    }
74
75    pub fn with_metadata(
76        mut self,
77        key: impl Into<String>,
78        value: impl std::any::Any + Send + Sync + 'static,
79    ) -> Self {
80        self.meta.data.insert(key.into(), Box::new(value));
81        self
82    }
83
84    /// Supply a custom transport, replacing the default.
85    pub fn with_transport<T2: Transport>(self, transport: T2) -> GeminiBuilder<M, C, T2> {
86        GeminiBuilder {
87            model_name: self.model_name,
88            api_key: self.api_key,
89            native_tools: self.native_tools,
90            function_config: self.function_config,
91            embeddings_config: self.embeddings_config,
92            include_thoughts: self.include_thoughts,
93            transport: Some(transport),
94            meta: self.meta,
95            _m: PhantomData,
96            _c: PhantomData,
97        }
98    }
99}
100
101impl<C, T: Transport> GeminiBuilder<WithoutModel, C, T> {
102    pub fn with_model(self, model_name: String) -> GeminiBuilder<WithModel, C, T> {
103        GeminiBuilder {
104            model_name: Some(model_name),
105            api_key: self.api_key,
106            native_tools: self.native_tools,
107            function_config: self.function_config,
108            embeddings_config: self.embeddings_config,
109            include_thoughts: self.include_thoughts,
110            transport: self.transport,
111            meta: self.meta,
112            _m: PhantomData,
113            _c: PhantomData,
114        }
115    }
116}
117
118impl<M, T: Transport> GeminiBuilder<M, BaseConfig, T> {
119    fn into_completion(self) -> GeminiBuilder<M, CompletionConfig, T> {
120        GeminiBuilder {
121            model_name: self.model_name,
122            api_key: self.api_key,
123            native_tools: self.native_tools,
124            function_config: self.function_config,
125            embeddings_config: self.embeddings_config,
126            include_thoughts: self.include_thoughts,
127            transport: self.transport,
128            meta: self.meta,
129            _m: PhantomData,
130            _c: PhantomData,
131        }
132    }
133
134    pub fn with_code_execution(self) -> GeminiBuilder<M, CompletionConfig, T> {
135        self.into_completion().with_code_execution()
136    }
137
138    pub fn with_google_search(self) -> GeminiBuilder<M, CompletionConfig, T> {
139        self.into_completion().with_google_search()
140    }
141
142    pub fn with_thoughts(mut self, include: bool) -> Self {
143        self.include_thoughts = include;
144        self
145    }
146
147    pub fn with_google_maps(
148        self,
149        lat_lng: Option<(f32, f32)>,
150        widget: bool,
151    ) -> GeminiBuilder<M, CompletionConfig, T> {
152        self.into_completion().with_google_maps(lat_lng, widget)
153    }
154
155    pub fn with_function_calling_mode(
156        self,
157        mode: &str,
158        allowed: Option<Vec<String>>,
159    ) -> GeminiBuilder<M, CompletionConfig, T> {
160        self.into_completion()
161            .with_function_calling_mode(mode, allowed)
162    }
163}
164
165impl<M, T: Transport> GeminiBuilder<M, CompletionConfig, T> {
166    pub fn with_code_execution(mut self) -> Self {
167        self.native_tools.push(Box::new(CodeExecutionTool));
168        self
169    }
170
171    pub fn with_google_search(mut self) -> Self {
172        self.native_tools.push(Box::new(GoogleSearchTool {
173            dynamic_threshold: None,
174        }));
175        self
176    }
177
178    pub fn with_google_search_threshold(mut self, threshold: f32) -> Self {
179        self.native_tools.push(Box::new(GoogleSearchTool {
180            dynamic_threshold: Some(threshold),
181        }));
182        self
183    }
184
185    pub fn with_google_maps(mut self, lat_lng: Option<(f32, f32)>, widget: bool) -> Self {
186        self.native_tools.push(Box::new(GoogleMapsTool {
187            lat_lng,
188            enable_widget: widget,
189        }));
190        self
191    }
192
193    pub fn with_function_calling_mode(mut self, mode: &str, allowed: Option<Vec<String>>) -> Self {
194        self.function_config = Some(GeminiFunctionCallingConfig {
195            mode: mode.to_string(),
196            allowed_function_names: allowed,
197        });
198        self
199    }
200}
201
202impl<M, T: Transport> GeminiBuilder<M, BaseConfig, T> {
203    fn into_embedding(self) -> GeminiBuilder<M, EmbeddingConfig, T> {
204        GeminiBuilder {
205            model_name: self.model_name,
206            api_key: self.api_key,
207            native_tools: vec![],
208            function_config: None,
209            embeddings_config: self.embeddings_config,
210            include_thoughts: false,
211            transport: self.transport,
212            meta: self.meta,
213            _m: PhantomData,
214            _c: PhantomData,
215        }
216    }
217
218    pub fn with_embeddings(self, dimensions: Option<usize>) -> GeminiBuilder<M, EmbeddingConfig, T> {
219        self.into_embedding().with_embeddings(dimensions)
220    }
221}
222
223impl<M, T: Transport> GeminiBuilder<M, EmbeddingConfig, T> {
224    pub fn with_embeddings(mut self, dimensions: Option<usize>) -> Self {
225        let mut config = self.embeddings_config.unwrap_or_default();
226        config.dimensions = dimensions;
227        self.embeddings_config = Some(config);
228        self
229    }
230
231    pub fn with_embeddings_task(mut self, task: EmbeddingsTask) -> Self {
232        let mut config = self.embeddings_config.unwrap_or_default();
233        config.task = task;
234        self.embeddings_config = Some(config);
235        self
236    }
237}
238
239impl<C, T: Transport> GeminiBuilder<WithModel, C, T> {
240    pub fn build(self) -> GeminiClient<T> {
241        GeminiClient {
242            model_name: self.model_name.unwrap(),
243            api_key: self.api_key.unwrap_or_else(|| {
244                env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not found in environment")
245            }),
246            scheme: "https".to_string(),
247            host: "generativelanguage.googleapis.com".to_string(),
248            base_path: "/v1beta".to_string(),
249            transport: self.transport.expect(
250                "No transport provided. Call .with_transport() or use the default GeminiBuilder (which provides ReqwestTransport).",
251            ),
252            native_tools: self.native_tools,
253            function_config: self.function_config,
254            embeddings_config: self.embeddings_config,
255            include_thoughts: self.include_thoughts,
256            meta: self.meta,
257        }
258    }
259}