1use serde::{Deserialize, Serialize};
2use crate::embedders::{Embedder, http_client::HttpClient};
3use crate::error::{ContragError, Result};
4use crate::types::ConnectionTestResult;
5
6pub struct GeminiEmbedder {
8 api_key: String,
9 model: String,
10 dimensions: usize,
11 api_endpoint: String,
12 http_client: HttpClient,
13}
14
15impl GeminiEmbedder {
16 pub fn new(api_key: String, model: String) -> Self {
18 let dimensions = match model.as_str() {
19 "embedding-001" => 768,
20 "text-embedding-004" => 768,
21 _ => 768, };
23
24 Self {
25 api_key,
26 model,
27 dimensions,
28 api_endpoint: "https://generativelanguage.googleapis.com/v1beta/models".to_string(),
29 http_client: HttpClient::new(),
30 }
31 }
32
33 pub fn with_endpoint(mut self, endpoint: String) -> Self {
35 self.api_endpoint = endpoint;
36 self
37 }
38
39 fn get_embed_url(&self) -> String {
40 format!(
41 "{}/{}:embedContent?key={}",
42 self.api_endpoint, self.model, self.api_key
43 )
44 }
45
46 fn get_batch_embed_url(&self) -> String {
47 format!(
48 "{}/{}:batchEmbedContents?key={}",
49 self.api_endpoint, self.model, self.api_key
50 )
51 }
52}
53
54#[async_trait::async_trait]
55impl Embedder for GeminiEmbedder {
56 fn name(&self) -> &str {
57 "gemini"
58 }
59
60 async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
61 if texts.is_empty() {
62 return Ok(vec![]);
63 }
64
65 if texts.len() > 1 {
67 return self.batch_embed(texts).await;
68 }
69
70 let request = GeminiEmbedRequest {
72 content: GeminiContent {
73 parts: vec![GeminiPart {
74 text: texts[0].clone(),
75 }],
76 },
77 };
78
79 let body = serde_json::to_vec(&request)
80 .map_err(|e| ContragError::SerializationError(e.to_string()))?;
81
82 let headers = vec![("Content-Type".to_string(), "application/json".to_string())];
83
84 let response = self
85 .http_client
86 .post(self.get_embed_url(), headers, body)
87 .await?;
88
89 if response.status != 200 {
90 let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
91 return Err(ContragError::EmbedderError(format!(
92 "Gemini API returned status {}: {}",
93 response.status, error_text
94 )));
95 }
96
97 let embed_response: GeminiEmbedResponse = response.json()?;
98
99 Ok(vec![embed_response.embedding.values])
100 }
101
102 fn dimensions(&self) -> usize {
103 self.dimensions
104 }
105
106 async fn test_connection(&self) -> Result<ConnectionTestResult> {
107 let start = ic_cdk::api::time();
108
109 match self.embed(vec!["test connection".to_string()]).await {
110 Ok(_) => {
111 let latency = (ic_cdk::api::time() - start) / 1_000_000; Ok(ConnectionTestResult {
113 plugin: self.name().to_string(),
114 connected: true,
115 latency: Some(latency),
116 error: None,
117 details: Some(format!(
118 "model: {}, dimensions: {}",
119 self.model, self.dimensions
120 )),
121 })
122 }
123 Err(e) => Ok(ConnectionTestResult {
124 plugin: self.name().to_string(),
125 connected: false,
126 latency: None,
127 error: Some(e.to_string()),
128 details: None,
129 }),
130 }
131 }
132
133 async fn generate_with_prompt(
134 &self,
135 text: String,
136 system_prompt: String,
137 ) -> Result<String> {
138 let request = GeminiGenerateRequest {
139 contents: vec![GeminiContent {
140 parts: vec![GeminiPart {
141 text: format!("{}\n\n{}", system_prompt, text),
142 }],
143 }],
144 generation_config: Some(GeminiGenerationConfig {
145 temperature: 0.7,
146 max_output_tokens: 1000,
147 }),
148 };
149
150 let body = serde_json::to_vec(&request)
151 .map_err(|e| ContragError::SerializationError(e.to_string()))?;
152
153 let headers = vec![("Content-Type".to_string(), "application/json".to_string())];
154
155 let url = format!(
156 "{}/gemini-pro:generateContent?key={}",
157 self.api_endpoint, self.api_key
158 );
159
160 let response = self.http_client.post(url, headers, body).await?;
161
162 if response.status != 200 {
163 let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
164 return Err(ContragError::EmbedderError(format!(
165 "Gemini API returned status {}: {}",
166 response.status, error_text
167 )));
168 }
169
170 let generate_response: GeminiGenerateResponse = response.json()?;
171
172 Ok(generate_response
173 .candidates
174 .get(0)
175 .and_then(|c| c.content.parts.get(0))
176 .map(|p| p.text.clone())
177 .unwrap_or_default())
178 }
179}
180
181impl GeminiEmbedder {
182 async fn batch_embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
183 let requests: Vec<GeminiEmbedRequest> = texts
184 .into_iter()
185 .map(|text| GeminiEmbedRequest {
186 content: GeminiContent {
187 parts: vec![GeminiPart { text }],
188 },
189 })
190 .collect();
191
192 let batch_request = GeminiBatchEmbedRequest { requests };
193
194 let body = serde_json::to_vec(&batch_request)
195 .map_err(|e| ContragError::SerializationError(e.to_string()))?;
196
197 let headers = vec![("Content-Type".to_string(), "application/json".to_string())];
198
199 let response = self
200 .http_client
201 .post(self.get_batch_embed_url(), headers, body)
202 .await?;
203
204 if response.status != 200 {
205 let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
206 return Err(ContragError::EmbedderError(format!(
207 "Gemini API returned status {}: {}",
208 response.status, error_text
209 )));
210 }
211
212 let batch_response: GeminiBatchEmbedResponse = response.json()?;
213
214 Ok(batch_response
215 .embeddings
216 .into_iter()
217 .map(|e| e.values)
218 .collect())
219 }
220}
221
222#[derive(Serialize)]
225struct GeminiEmbedRequest {
226 content: GeminiContent,
227}
228
229#[derive(Serialize)]
230struct GeminiBatchEmbedRequest {
231 requests: Vec<GeminiEmbedRequest>,
232}
233
234#[derive(Serialize, Deserialize)]
235struct GeminiContent {
236 parts: Vec<GeminiPart>,
237}
238
239#[derive(Serialize, Deserialize)]
240struct GeminiPart {
241 text: String,
242}
243
244#[derive(Deserialize)]
245struct GeminiEmbedResponse {
246 embedding: GeminiEmbedding,
247}
248
249#[derive(Deserialize)]
250struct GeminiBatchEmbedResponse {
251 embeddings: Vec<GeminiEmbedding>,
252}
253
254#[derive(Deserialize)]
255struct GeminiEmbedding {
256 values: Vec<f32>,
257}
258
259#[derive(Serialize)]
260struct GeminiGenerateRequest {
261 contents: Vec<GeminiContent>,
262 #[serde(skip_serializing_if = "Option::is_none")]
263 generation_config: Option<GeminiGenerationConfig>,
264}
265
266#[derive(Serialize)]
267struct GeminiGenerationConfig {
268 temperature: f32,
269 max_output_tokens: u32,
270}
271
272#[derive(Deserialize)]
273struct GeminiGenerateResponse {
274 candidates: Vec<GeminiCandidate>,
275}
276
277#[derive(Deserialize)]
278struct GeminiCandidate {
279 content: GeminiContent,
280}