contrag_core/embedders/
gemini.rs

1use serde::{Deserialize, Serialize};
2use crate::embedders::{Embedder, http_client::HttpClient};
3use crate::error::{ContragError, Result};
4use crate::types::ConnectionTestResult;
5
6/// Google Gemini embedder using HTTP outcalls
7pub struct GeminiEmbedder {
8    api_key: String,
9    model: String,
10    dimensions: usize,
11    api_endpoint: String,
12    http_client: HttpClient,
13}
14
15impl GeminiEmbedder {
16    /// Create a new Gemini embedder
17    pub fn new(api_key: String, model: String) -> Self {
18        let dimensions = match model.as_str() {
19            "embedding-001" => 768,
20            "text-embedding-004" => 768,
21            _ => 768, // default
22        };
23
24        Self {
25            api_key,
26            model,
27            dimensions,
28            api_endpoint: "https://generativelanguage.googleapis.com/v1beta/models".to_string(),
29            http_client: HttpClient::new(),
30        }
31    }
32
33    /// Create with custom API endpoint
34    pub fn with_endpoint(mut self, endpoint: String) -> Self {
35        self.api_endpoint = endpoint;
36        self
37    }
38
39    fn get_embed_url(&self) -> String {
40        format!(
41            "{}/{}:embedContent?key={}",
42            self.api_endpoint, self.model, self.api_key
43        )
44    }
45
46    fn get_batch_embed_url(&self) -> String {
47        format!(
48            "{}/{}:batchEmbedContents?key={}",
49            self.api_endpoint, self.model, self.api_key
50        )
51    }
52}
53
54#[async_trait::async_trait]
55impl Embedder for GeminiEmbedder {
56    fn name(&self) -> &str {
57        "gemini"
58    }
59
60    async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
61        if texts.is_empty() {
62            return Ok(vec![]);
63        }
64
65        // Use batch embed for multiple texts
66        if texts.len() > 1 {
67            return self.batch_embed(texts).await;
68        }
69
70        // Single text embedding
71        let request = GeminiEmbedRequest {
72            content: GeminiContent {
73                parts: vec![GeminiPart {
74                    text: texts[0].clone(),
75                }],
76            },
77        };
78
79        let body = serde_json::to_vec(&request)
80            .map_err(|e| ContragError::SerializationError(e.to_string()))?;
81
82        let headers = vec![("Content-Type".to_string(), "application/json".to_string())];
83
84        let response = self
85            .http_client
86            .post(self.get_embed_url(), headers, body)
87            .await?;
88
89        if response.status != 200 {
90            let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
91            return Err(ContragError::EmbedderError(format!(
92                "Gemini API returned status {}: {}",
93                response.status, error_text
94            )));
95        }
96
97        let embed_response: GeminiEmbedResponse = response.json()?;
98
99        Ok(vec![embed_response.embedding.values])
100    }
101
102    fn dimensions(&self) -> usize {
103        self.dimensions
104    }
105
106    async fn test_connection(&self) -> Result<ConnectionTestResult> {
107        let start = ic_cdk::api::time();
108
109        match self.embed(vec!["test connection".to_string()]).await {
110            Ok(_) => {
111                let latency = (ic_cdk::api::time() - start) / 1_000_000; // Convert to ms
112                Ok(ConnectionTestResult {
113                    plugin: self.name().to_string(),
114                    connected: true,
115                    latency: Some(latency),
116                    error: None,
117                    details: Some(format!(
118                        "model: {}, dimensions: {}",
119                        self.model, self.dimensions
120                    )),
121                })
122            }
123            Err(e) => Ok(ConnectionTestResult {
124                plugin: self.name().to_string(),
125                connected: false,
126                latency: None,
127                error: Some(e.to_string()),
128                details: None,
129            }),
130        }
131    }
132
133    async fn generate_with_prompt(
134        &self,
135        text: String,
136        system_prompt: String,
137    ) -> Result<String> {
138        let request = GeminiGenerateRequest {
139            contents: vec![GeminiContent {
140                parts: vec![GeminiPart {
141                    text: format!("{}\n\n{}", system_prompt, text),
142                }],
143            }],
144            generation_config: Some(GeminiGenerationConfig {
145                temperature: 0.7,
146                max_output_tokens: 1000,
147            }),
148        };
149
150        let body = serde_json::to_vec(&request)
151            .map_err(|e| ContragError::SerializationError(e.to_string()))?;
152
153        let headers = vec![("Content-Type".to_string(), "application/json".to_string())];
154
155        let url = format!(
156            "{}/gemini-pro:generateContent?key={}",
157            self.api_endpoint, self.api_key
158        );
159
160        let response = self.http_client.post(url, headers, body).await?;
161
162        if response.status != 200 {
163            let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
164            return Err(ContragError::EmbedderError(format!(
165                "Gemini API returned status {}: {}",
166                response.status, error_text
167            )));
168        }
169
170        let generate_response: GeminiGenerateResponse = response.json()?;
171
172        Ok(generate_response
173            .candidates
174            .get(0)
175            .and_then(|c| c.content.parts.get(0))
176            .map(|p| p.text.clone())
177            .unwrap_or_default())
178    }
179}
180
181impl GeminiEmbedder {
182    async fn batch_embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
183        let requests: Vec<GeminiEmbedRequest> = texts
184            .into_iter()
185            .map(|text| GeminiEmbedRequest {
186                content: GeminiContent {
187                    parts: vec![GeminiPart { text }],
188                },
189            })
190            .collect();
191
192        let batch_request = GeminiBatchEmbedRequest { requests };
193
194        let body = serde_json::to_vec(&batch_request)
195            .map_err(|e| ContragError::SerializationError(e.to_string()))?;
196
197        let headers = vec![("Content-Type".to_string(), "application/json".to_string())];
198
199        let response = self
200            .http_client
201            .post(self.get_batch_embed_url(), headers, body)
202            .await?;
203
204        if response.status != 200 {
205            let error_text = response.text().unwrap_or_else(|_| "Unknown error".to_string());
206            return Err(ContragError::EmbedderError(format!(
207                "Gemini API returned status {}: {}",
208                response.status, error_text
209            )));
210        }
211
212        let batch_response: GeminiBatchEmbedResponse = response.json()?;
213
214        Ok(batch_response
215            .embeddings
216            .into_iter()
217            .map(|e| e.values)
218            .collect())
219    }
220}
221
222// Request/Response types for Gemini API
223
224#[derive(Serialize)]
225struct GeminiEmbedRequest {
226    content: GeminiContent,
227}
228
229#[derive(Serialize)]
230struct GeminiBatchEmbedRequest {
231    requests: Vec<GeminiEmbedRequest>,
232}
233
234#[derive(Serialize, Deserialize)]
235struct GeminiContent {
236    parts: Vec<GeminiPart>,
237}
238
239#[derive(Serialize, Deserialize)]
240struct GeminiPart {
241    text: String,
242}
243
244#[derive(Deserialize)]
245struct GeminiEmbedResponse {
246    embedding: GeminiEmbedding,
247}
248
249#[derive(Deserialize)]
250struct GeminiBatchEmbedResponse {
251    embeddings: Vec<GeminiEmbedding>,
252}
253
254#[derive(Deserialize)]
255struct GeminiEmbedding {
256    values: Vec<f32>,
257}
258
259#[derive(Serialize)]
260struct GeminiGenerateRequest {
261    contents: Vec<GeminiContent>,
262    #[serde(skip_serializing_if = "Option::is_none")]
263    generation_config: Option<GeminiGenerationConfig>,
264}
265
266#[derive(Serialize)]
267struct GeminiGenerationConfig {
268    temperature: f32,
269    max_output_tokens: u32,
270}
271
272#[derive(Deserialize)]
273struct GeminiGenerateResponse {
274    candidates: Vec<GeminiCandidate>,
275}
276
277#[derive(Deserialize)]
278struct GeminiCandidate {
279    content: GeminiContent,
280}