codex_memory/embedding.rs

use anyhow::{Context, Result};
use backoff::{future::retry, ExponentialBackoff};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{info, warn};

#[derive(Debug, Clone)]
pub struct SimpleEmbedder {
    client: Client,
    api_key: String,
    model: String,
    base_url: String,
    provider: EmbeddingProvider,
    fallback_models: Vec<String>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingProvider {
    OpenAI,
    Ollama,
    Mock, // For testing
}

// OpenAI API request/response structures
#[derive(Debug, Serialize)]
struct OpenAIEmbeddingRequest {
    input: String,
    model: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
}

// Ollama API request/response structures
#[derive(Debug, Serialize)]
struct OllamaEmbeddingRequest {
    model: String,
    prompt: String,
}

#[derive(Debug, Deserialize)]
struct OllamaEmbeddingResponse {
    embedding: Vec<f32>,
}

#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    size: u64,
    #[serde(default)]
    #[allow(dead_code)]
    family: String,
}

#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}

impl SimpleEmbedder {
    pub fn new(api_key: String) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key,
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com".to_string(),
            provider: EmbeddingProvider::OpenAI,
            fallback_models: vec![
                "text-embedding-3-large".to_string(),
                "text-embedding-ada-002".to_string(),
            ],
        }
    }

    pub fn new_ollama(base_url: String, model: String) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(60)) // Ollama might be slower
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key: String::new(), // Ollama doesn't need an API key
            model,
            base_url,
            provider: EmbeddingProvider::Ollama,
            fallback_models: vec![
                "nomic-embed-text".to_string(),
                "mxbai-embed-large".to_string(),
                "all-minilm".to_string(),
                "all-mpnet-base-v2".to_string(),
            ],
        }
    }

    pub fn new_mock() -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(1))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key: String::new(),
            model: "mock-model".to_string(),
            base_url: "http://mock:11434".to_string(),
            provider: EmbeddingProvider::Mock,
            fallback_models: vec!["mock-model-2".to_string()],
        }
    }

    pub fn with_model(mut self, model: String) -> Self {
        self.model = model;
        self
    }

    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    /// Generate embedding for text with automatic retry
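    ///
    /// A minimal usage sketch, not compiled as a doctest (assumes the
    /// configured provider is reachable and the key is set):
    ///
    /// ```rust,ignore
    /// let embedder = SimpleEmbedder::new(std::env::var("OPENAI_API_KEY")?);
    /// let vector = embedder.generate_embedding("Hello, world!").await?;
    /// assert_eq!(vector.len(), embedder.embedding_dimension());
    /// ```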
    pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        info!("Generating embedding for text of length: {}", text.len());

        let operation = || async {
            match self.generate_embedding_internal(text).await {
                Ok(embedding) => Ok(embedding),
                Err(e) => {
                    if e.to_string().contains("Rate limited") {
                        Err(backoff::Error::transient(e))
                    } else {
                        Err(backoff::Error::permanent(e))
                    }
                }
            }
        };

        let backoff = ExponentialBackoff {
            max_elapsed_time: Some(Duration::from_secs(60)),
            ..Default::default()
        };

        retry(backoff, operation).await
    }

    async fn generate_embedding_internal(&self, text: &str) -> Result<Vec<f32>> {
        match self.provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::Ollama => self.generate_ollama_embedding(text).await,
            EmbeddingProvider::Mock => self.generate_mock_embedding(text).await,
        }
    }

    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let request = OpenAIEmbeddingRequest {
            input: text.to_string(),
            model: self.model.clone(),
        };

        let response = self
            .client
            .post(&format!("{}/v1/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status.as_u16() == 429 {
                warn!("Rate limited by OpenAI API, will retry");
                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
            }

            return Err(anyhow::anyhow!(
                "OpenAI API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let embedding_response: OpenAIEmbeddingResponse = response.json().await?;

        if let Some(embedding_data) = embedding_response.data.first() {
            Ok(embedding_data.embedding.clone())
        } else {
            Err(anyhow::anyhow!("No embedding data in OpenAI response"))
        }
    }

    async fn generate_ollama_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let request = OllamaEmbeddingRequest {
            model: self.model.clone(),
            prompt: text.to_string(),
        };

        let response = self
            .client
            .post(&format!("{}/api/embeddings", self.base_url))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status.as_u16() == 429 {
                warn!("Rate limited by Ollama API, will retry");
                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
            }

            return Err(anyhow::anyhow!(
                "Ollama API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let embedding_response: OllamaEmbeddingResponse = response.json().await?;
        Ok(embedding_response.embedding)
    }

    async fn generate_mock_embedding(&self, text: &str) -> Result<Vec<f32>> {
        // Generate a deterministic mock embedding based on text content.
        // This is useful for testing without requiring real embedding services.
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let hash = hasher.finish();

        // Generate a fixed-size embedding (embedding_dimension() reports 768
        // for the mock provider)
        let dimensions = self.embedding_dimension();
        let mut embedding = Vec::with_capacity(dimensions);

        // Use the hash to seed a simple linear congruential generator so
        // identical input always yields identical output
        let mut seed = hash;
        for _ in 0..dimensions {
            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            let value = ((seed >> 16) % 1000) as f32 / 1000.0 - 0.5; // roughly -0.5 to 0.5
            embedding.push(value);
        }

        // Normalize the embedding to unit length
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for val in &mut embedding {
                *val /= magnitude;
            }
        }

        Ok(embedding)
    }

    /// Generate embeddings for multiple texts in batch
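    ///
    /// Illustrative sketch, not compiled as a doctest (the input texts are
    /// placeholders):
    ///
    /// ```rust,ignore
    /// let texts = vec!["first note".to_string(), "second note".to_string()];
    /// let vectors = embedder.generate_embeddings_batch(&texts).await?;
    /// assert_eq!(vectors.len(), texts.len());
    /// ```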
    pub async fn generate_embeddings_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        info!("Generating embeddings for {} texts", texts.len());

        let mut embeddings = Vec::with_capacity(texts.len());

        // Process in small batches to avoid rate limits
        for chunk in texts.chunks(10) {
            let mut chunk_embeddings = Vec::with_capacity(chunk.len());

            for text in chunk {
                match self.generate_embedding(text).await {
                    Ok(embedding) => chunk_embeddings.push(embedding),
                    Err(e) => {
                        warn!("Failed to generate embedding for text: {}", e);
                        return Err(e);
                    }
                }

                // Small delay to be respectful to the API
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            embeddings.extend(chunk_embeddings);
        }

        Ok(embeddings)
    }

    /// Get the dimension of embeddings for this model
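    ///
    /// For example, the OpenAI default `text-embedding-3-small` reports 1536,
    /// while switching via `with_model("text-embedding-3-large".to_string())`
    /// reports 3072.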
    pub fn embedding_dimension(&self) -> usize {
        match self.provider {
            EmbeddingProvider::OpenAI => match self.model.as_str() {
                "text-embedding-3-small" => 1536,
                "text-embedding-3-large" => 3072,
                "text-embedding-ada-002" => 1536,
                _ => 1536, // Default to small model dimensions
            },
            EmbeddingProvider::Ollama => {
                // Ollama models vary in dimensions, but many common ones use these sizes
                match self.model.as_str() {
                    "gpt-oss:20b" => 4096, // Assuming this model has 4096 dimensions
                    "nomic-embed-text" => 768,
                    "mxbai-embed-large" => 1024,
                    "all-minilm" => 384,
                    _ => 768, // Default dimension for many embedding models
                }
            }
            EmbeddingProvider::Mock => 768, // Consistent mock embedding dimension
        }
    }

    /// Get the provider type
    pub fn provider(&self) -> &EmbeddingProvider {
        &self.provider
    }

    /// Auto-detect and configure the best available embedding model on an Ollama server
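    ///
    /// Illustrative call, assuming an Ollama server is listening at the given
    /// URL:
    ///
    /// ```rust,ignore
    /// let embedder = SimpleEmbedder::auto_configure("http://localhost:11434".to_string()).await?;
    /// println!("auto-configured dimension: {}", embedder.embedding_dimension());
    /// ```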
    pub async fn auto_configure(base_url: String) -> Result<Self> {
        info!("🔍 Auto-detecting best available embedding model...");

        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .context("Failed to create HTTP client")?;

        // Try to get available models from Ollama
        let available_models = Self::detect_ollama_models(&client, &base_url).await?;

        if available_models.is_empty() {
            return Err(anyhow::anyhow!("No embedding models found on Ollama server"));
        }

        // Select the best available model
        let selected_model = Self::select_best_model(&available_models)?;

        info!("✅ Selected model: {} ({}D)", selected_model.name, selected_model.dimensions);

        let mut embedder = Self::new_ollama(base_url, selected_model.name.clone());
        embedder.fallback_models = available_models.into_iter()
            .filter(|m| m.name != embedder.model)
            .map(|m| m.name)
            .collect();

        Ok(embedder)
    }

    /// Generate embedding with automatic fallback to alternative models
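    ///
    /// Sketch (assumes at least one of the configured fallback models is
    /// actually installed on the server):
    ///
    /// ```rust,ignore
    /// let vector = embedder.generate_embedding_with_fallback("Hello, world!").await?;
    /// ```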
    pub async fn generate_embedding_with_fallback(&self, text: &str) -> Result<Vec<f32>> {
        // Try the primary model first
        match self.generate_embedding(text).await {
            Ok(embedding) => return Ok(embedding),
            Err(e) => {
                warn!("Primary model '{}' failed: {}", self.model, e);
            }
        }

        // Try fallback models
        for fallback_model in &self.fallback_models {
            info!("🔄 Trying fallback model: {}", fallback_model);

            let mut fallback_embedder = self.clone();
            fallback_embedder.model = fallback_model.clone();

            match fallback_embedder.generate_embedding(text).await {
                Ok(embedding) => {
                    info!("✅ Fallback model '{}' succeeded", fallback_model);
                    return Ok(embedding);
                }
                Err(e) => {
                    warn!("Fallback model '{}' failed: {}", fallback_model, e);
                    continue;
                }
            }
        }

        Err(anyhow::anyhow!("All embedding models failed, including fallbacks"))
    }

    /// Health check for the embedding service
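    ///
    /// Example of inspecting the result (illustrative only):
    ///
    /// ```rust,ignore
    /// let health = embedder.health_check().await?;
    /// if health.status == "unhealthy" {
    ///     eprintln!("embedding service down: {:?}", health.error);
    /// }
    /// ```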
    pub async fn health_check(&self) -> Result<EmbeddingHealth> {
        let start_time = std::time::Instant::now();

        let test_result = self.generate_embedding("Health check test").await;
        let response_time = start_time.elapsed();

        let health = match test_result {
            Ok(embedding) => EmbeddingHealth {
                status: "healthy".to_string(),
                model: self.model.clone(),
                provider: format!("{:?}", self.provider),
                response_time_ms: response_time.as_millis() as u64,
                embedding_dimensions: embedding.len(),
                error: None,
            },
            Err(e) => EmbeddingHealth {
                status: "unhealthy".to_string(),
                model: self.model.clone(),
                provider: format!("{:?}", self.provider),
                response_time_ms: response_time.as_millis() as u64,
                embedding_dimensions: 0,
                error: Some(e.to_string()),
            },
        };

        Ok(health)
    }

    /// Detect available embedding models on Ollama
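    ///
    /// The `/api/tags` response is expected to look roughly like
    /// `{"models": [{"name": "nomic-embed-text:latest", "size": 274301056}]}`,
    /// matching the `OllamaModelsResponse` shape defined above; the size value
    /// here is illustrative.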
    async fn detect_ollama_models(client: &Client, base_url: &str) -> Result<Vec<EmbeddingModelInfo>> {
        let response = client
            .get(&format!("{}/api/tags", base_url))
            .send()
            .await
            .context("Failed to connect to Ollama API")?;

        if !response.status().is_success() {
            return Err(anyhow::anyhow!("Ollama API returned error: {}", response.status()));
        }

        let models_response: OllamaModelsResponse = response.json().await
            .context("Failed to parse Ollama models response")?;

        let mut embedding_models = Vec::new();

        for model in models_response.models {
            if let Some(model_info) = Self::classify_embedding_model(&model.name) {
                embedding_models.push(model_info);
            }
        }

        Ok(embedding_models)
    }

    /// Classify a model name as an embedding model
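    ///
    /// For example, `classify_embedding_model("nomic-embed-text:latest")`
    /// matches the `nomic-embed-text` pattern below (768 dimensions,
    /// preferred), while a generation model such as `"llama3"` falls through
    /// every check and returns `None`.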
    fn classify_embedding_model(model_name: &str) -> Option<EmbeddingModelInfo> {
        let name_lower = model_name.to_lowercase();

        // Define known embedding models with their properties
        let known_models = [
            ("nomic-embed-text", 768, "High-quality text embeddings", true),
            ("mxbai-embed-large", 1024, "Large multilingual embeddings", true),
            ("all-minilm", 384, "Compact sentence embeddings", false),
            ("all-mpnet-base-v2", 768, "Sentence transformer embeddings", false),
            ("bge-small-en", 384, "BGE small English embeddings", false),
            ("bge-base-en", 768, "BGE base English embeddings", false),
            ("bge-large-en", 1024, "BGE large English embeddings", false),
            ("e5-small", 384, "E5 small embeddings", false),
            ("e5-base", 768, "E5 base embeddings", false),
            ("e5-large", 1024, "E5 large embeddings", false),
        ];

        // Patterns are all lowercase, so matching against the lowercased name
        // gives a case-insensitive match
        for (pattern, dimensions, description, preferred) in known_models {
            if name_lower.contains(pattern) {
                return Some(EmbeddingModelInfo {
                    name: model_name.to_string(),
                    dimensions,
                    description: description.to_string(),
                    preferred,
                });
            }
        }

        // Check if it's likely an embedding model based on common naming patterns
        if name_lower.contains("embed")
            || name_lower.contains("sentence")
            || name_lower.contains("vector")
        {
            return Some(EmbeddingModelInfo {
                name: model_name.to_string(),
                dimensions: 768, // Default assumption
                description: "Detected embedding model".to_string(),
                preferred: false,
            });
        }

        None
    }


    /// Select the best model from available options
    fn select_best_model(available_models: &[EmbeddingModelInfo]) -> Result<&EmbeddingModelInfo> {
        // Prefer recommended models first
        if let Some(preferred) = available_models.iter().find(|m| m.preferred) {
            return Ok(preferred);
        }

        // Fall back to any available model
        available_models.first()
            .ok_or_else(|| anyhow::anyhow!("No embedding models available"))
    }
}

/// Information about an embedding model
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    pub name: String,
    pub dimensions: usize,
    pub description: String,
    pub preferred: bool,
}

/// Health status of the embedding service
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingHealth {
    pub status: String,
    pub model: String,
    pub provider: String,
    pub response_time_ms: u64,
    pub embedding_dimensions: usize,
    pub error: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore] // Requires actual OpenAI API key
    async fn test_generate_openai_embedding() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
        let embedder = SimpleEmbedder::new(api_key);

        let result = embedder.generate_embedding("Hello, world!").await;
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 1536);
    }

    #[tokio::test]
    #[ignore] // Requires running Ollama instance
    async fn test_generate_ollama_embedding() {
        let embedder = SimpleEmbedder::new_ollama(
            "http://192.168.1.110:11434".to_string(),
            "nomic-embed-text".to_string(),
        );

        let result = embedder.generate_embedding("Hello, world!").await;
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 768);
    }

    #[test]
    fn test_embedding_dimensions() {
        let embedder = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(embedder.embedding_dimension(), 1536);

        let embedder = embedder.with_model("text-embedding-3-large".to_string());
        assert_eq!(embedder.embedding_dimension(), 3072);

        let ollama_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama_embedder.embedding_dimension(), 768);

        let gpt_oss_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "gpt-oss:20b".to_string(),
        );
        assert_eq!(gpt_oss_embedder.embedding_dimension(), 4096);

        let mock_embedder = SimpleEmbedder::new_mock();
        assert_eq!(mock_embedder.embedding_dimension(), 768);
    }

    #[test]
    fn test_provider_types() {
        let openai_embedder = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(openai_embedder.provider(), &EmbeddingProvider::OpenAI);

        let ollama_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama_embedder.provider(), &EmbeddingProvider::Ollama);

        let mock_embedder = SimpleEmbedder::new_mock();
        assert_eq!(mock_embedder.provider(), &EmbeddingProvider::Mock);
    }
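
    // Self-contained sketch of a mock-provider test: no external service is
    // needed, and it checks the two properties the mock implementation aims
    // for, namely deterministic output for identical input and normalization
    // to unit length. The tolerance is a loose assumption for f32 accumulation.
    #[tokio::test]
    async fn test_mock_embedding_deterministic_and_normalized() {
        let embedder = SimpleEmbedder::new_mock();

        let a = embedder.generate_embedding("Hello, world!").await.unwrap();
        let b = embedder.generate_embedding("Hello, world!").await.unwrap();

        // Same input must hash to the same vector
        assert_eq!(a, b);
        assert_eq!(a.len(), 768);

        // The mock normalizes its output to unit length
        let magnitude: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 1e-3);
    }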
}