codex_memory/
embedding.rs

use anyhow::{Context, Result};
use backoff::{future::retry, ExponentialBackoff};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{info, warn};

#[derive(Debug, Clone)]
pub struct SimpleEmbedder {
    client: Client,
    api_key: String,
    model: String,
    base_url: String,
    provider: EmbeddingProvider,
    fallback_models: Vec<String>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingProvider {
    OpenAI,
    Ollama,
    Mock, // For testing
}

// OpenAI API request/response structures
#[derive(Debug, Serialize)]
struct OpenAIEmbeddingRequest {
    input: String,
    model: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
}

// Ollama API request/response structures
#[derive(Debug, Serialize)]
struct OllamaEmbeddingRequest {
    model: String,
    prompt: String,
}

#[derive(Debug, Deserialize)]
struct OllamaEmbeddingResponse {
    embedding: Vec<f32>,
}

#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    size: u64,
    #[serde(default)]
    #[allow(dead_code)]
    family: String,
}

#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}

impl SimpleEmbedder {
    pub fn new(api_key: String) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key,
            model: "text-embedding-3-small".to_string(),
            base_url: "https://api.openai.com".to_string(),
            provider: EmbeddingProvider::OpenAI,
            fallback_models: vec![
                "text-embedding-3-large".to_string(),
                "text-embedding-ada-002".to_string(),
            ],
        }
    }

    pub fn new_ollama(base_url: String, model: String) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(60)) // Ollama might be slower
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key: String::new(), // Ollama doesn't need an API key
            model,
            base_url,
            provider: EmbeddingProvider::Ollama,
            fallback_models: vec![
                "nomic-embed-text".to_string(),
                "mxbai-embed-large".to_string(),
                "all-minilm".to_string(),
                "all-mpnet-base-v2".to_string(),
            ],
        }
    }

    pub fn new_mock() -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(1))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client,
            api_key: String::new(),
            model: "mock-model".to_string(),
            base_url: "http://mock:11434".to_string(),
            provider: EmbeddingProvider::Mock,
            fallback_models: vec!["mock-model-2".to_string()],
        }
    }

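    /// Override the embedding model, builder style. A minimal sketch (the key
    /// and model name here are illustrative placeholders):
    /// ```ignore
    /// let embedder = SimpleEmbedder::new("api-key".to_string())
    ///     .with_model("text-embedding-3-large".to_string());
    /// ```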
    pub fn with_model(mut self, model: String) -> Self {
        self.model = model;
        self
    }

    /// Override the service base URL, builder style.
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }

    /// Generate an embedding for `text`, retrying transient failures with
    /// exponential backoff.
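    ///
    /// A minimal usage sketch (assumes a reachable Ollama server; the URL and
    /// model name are illustrative):
    /// ```ignore
    /// let embedder = SimpleEmbedder::new_ollama(
    ///     "http://localhost:11434".to_string(),
    ///     "nomic-embed-text".to_string(),
    /// );
    /// let vector = embedder.generate_embedding("hello world").await?;
    /// assert_eq!(vector.len(), embedder.embedding_dimension());
    /// ```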
    pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        info!("Generating embedding for text of length: {}", text.len());

        let operation = || async {
            match self.generate_embedding_internal(text).await {
                Ok(embedding) => Ok(embedding),
                Err(e) => {
                    // Rate limiting (detected via the error message produced by the
                    // provider calls below) is transient and worth retrying; any
                    // other error fails immediately.
                    if e.to_string().contains("Rate limited") {
                        Err(backoff::Error::transient(e))
                    } else {
                        Err(backoff::Error::permanent(e))
                    }
                }
            }
        };

        // Give up after 60 seconds of cumulative retrying.
        let backoff = ExponentialBackoff {
            max_elapsed_time: Some(Duration::from_secs(60)),
            ..Default::default()
        };

        retry(backoff, operation).await
    }

    async fn generate_embedding_internal(&self, text: &str) -> Result<Vec<f32>> {
        match self.provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::Ollama => self.generate_ollama_embedding(text).await,
            EmbeddingProvider::Mock => self.generate_mock_embedding(text).await,
        }
    }

    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let request = OpenAIEmbeddingRequest {
            input: text.to_string(),
            model: self.model.clone(),
        };

        let response = self
            .client
            .post(&format!("{}/v1/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status.as_u16() == 429 {
                warn!("Rate limited by OpenAI API, will retry");
                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
            }

            return Err(anyhow::anyhow!(
                "OpenAI API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let embedding_response: OpenAIEmbeddingResponse = response.json().await?;

        if let Some(embedding_data) = embedding_response.data.first() {
            Ok(embedding_data.embedding.clone())
        } else {
            Err(anyhow::anyhow!("No embedding data in OpenAI response"))
        }
    }

    async fn generate_ollama_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let request = OllamaEmbeddingRequest {
            model: self.model.clone(),
            prompt: text.to_string(),
        };

        let response = self
            .client
            .post(&format!("{}/api/embeddings", self.base_url))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status.as_u16() == 429 {
                warn!("Rate limited by Ollama API, will retry");
                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
            }

            return Err(anyhow::anyhow!(
                "Ollama API request failed with status {}: {}",
                status,
                error_text
            ));
        }

        let embedding_response: OllamaEmbeddingResponse = response.json().await?;
        Ok(embedding_response.embedding)
    }

    async fn generate_mock_embedding(&self, text: &str) -> Result<Vec<f32>> {
        // Generate a deterministic mock embedding based on the text content.
        // This is useful for testing without requiring real embedding services.
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let hash = hasher.finish();

        // Generate a fixed-size embedding (embedding_dimension() returns 768
        // for the mock provider).
        let dimensions = self.embedding_dimension();
        let mut embedding = Vec::with_capacity(dimensions);

        // Use the hash to seed a simple linear congruential generator (the
        // classic ANSI C constants) so results are reproducible per input.
        let mut seed = hash;
        for _ in 0..dimensions {
            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
            let value = ((seed >> 16) % 1000) as f32 / 1000.0 - 0.5; // roughly -0.5 to 0.5
            embedding.push(value);
        }

        // Normalize the embedding to unit length
        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for val in &mut embedding {
                *val /= magnitude;
            }
        }

        Ok(embedding)
    }

    /// Generate embeddings for multiple texts in batch
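    ///
    /// Texts are processed sequentially in chunks of 10 with a short delay
    /// between calls. A sketch (`embedder` is any configured SimpleEmbedder):
    /// ```ignore
    /// let texts = vec!["first".to_string(), "second".to_string()];
    /// let vectors = embedder.generate_embeddings_batch(&texts).await?;
    /// assert_eq!(vectors.len(), texts.len());
    /// ```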
    pub async fn generate_embeddings_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        info!("Generating embeddings for {} texts", texts.len());

        let mut embeddings = Vec::with_capacity(texts.len());

        // Process in small batches to avoid rate limits
        for chunk in texts.chunks(10) {
            let mut chunk_embeddings = Vec::with_capacity(chunk.len());

            for text in chunk {
                match self.generate_embedding(text).await {
                    Ok(embedding) => chunk_embeddings.push(embedding),
                    Err(e) => {
                        warn!("Failed to generate embedding for text: {}", e);
                        return Err(e);
                    }
                }

                // Small delay to be respectful to the API
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            embeddings.extend(chunk_embeddings);
        }

        Ok(embeddings)
    }

    /// Get the dimension of embeddings for this model
    pub fn embedding_dimension(&self) -> usize {
        match self.provider {
            EmbeddingProvider::OpenAI => match self.model.as_str() {
                "text-embedding-3-small" => 1536,
                "text-embedding-3-large" => 3072,
                "text-embedding-ada-002" => 1536,
                _ => 1536, // Default to small model dimensions
            },
            EmbeddingProvider::Ollama => {
                // Ollama models vary in dimensions, but many common ones use these sizes
                match self.model.as_str() {
                    "gpt-oss:20b" => 4096, // Assuming this model has 4096 dimensions
                    "nomic-embed-text" => 768,
                    "mxbai-embed-large" => 1024,
                    "all-minilm" => 384,
                    _ => 768, // Default dimension for many embedding models
                }
            }
            EmbeddingProvider::Mock => 768, // Consistent mock embedding dimension
        }
    }

    /// Get the provider type
    pub fn provider(&self) -> &EmbeddingProvider {
        &self.provider
    }

    /// Auto-detect and configure the best available embedding model on an
    /// Ollama server.
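    ///
    /// A sketch (assumes an Ollama instance exposing `/api/tags`; the URL is
    /// illustrative):
    /// ```ignore
    /// let embedder = SimpleEmbedder::auto_configure(
    ///     "http://localhost:11434".to_string(),
    /// ).await?;
    /// let probe = embedder.generate_embedding("probe").await?;
    /// ```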
    pub async fn auto_configure(base_url: String) -> Result<Self> {
        info!("🔍 Auto-detecting best available embedding model...");

        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .context("Failed to create HTTP client")?;

        // Try to get available models from Ollama
        let available_models = Self::detect_ollama_models(&client, &base_url).await?;

        if available_models.is_empty() {
            return Err(anyhow::anyhow!(
                "No embedding models found on Ollama server"
            ));
        }

        // Select the best available model
        let selected_model = Self::select_best_model(&available_models)?;

        info!(
            "✅ Selected model: {} ({}D)",
            selected_model.name, selected_model.dimensions
        );

        let mut embedder = Self::new_ollama(base_url, selected_model.name.clone());
        embedder.fallback_models = available_models
            .into_iter()
            .filter(|m| m.name != embedder.model)
            .map(|m| m.name)
            .collect();

        Ok(embedder)
    }

    /// Generate embedding with automatic fallback to alternative models
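    ///
    /// Models are tried in order: the configured model first, then each entry
    /// in `fallback_models`. A sketch:
    /// ```ignore
    /// let vector = embedder.generate_embedding_with_fallback("some text").await?;
    /// ```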
    pub async fn generate_embedding_with_fallback(&self, text: &str) -> Result<Vec<f32>> {
        // Try the primary model first
        match self.generate_embedding(text).await {
            Ok(embedding) => return Ok(embedding),
            Err(e) => {
                warn!("Primary model '{}' failed: {}", self.model, e);
            }
        }

        // Try fallback models
        for fallback_model in &self.fallback_models {
            info!("🔄 Trying fallback model: {}", fallback_model);

            let mut fallback_embedder = self.clone();
            fallback_embedder.model = fallback_model.clone();

            match fallback_embedder.generate_embedding(text).await {
                Ok(embedding) => {
                    info!("✅ Fallback model '{}' succeeded", fallback_model);
                    return Ok(embedding);
                }
                Err(e) => {
                    warn!("Fallback model '{}' failed: {}", fallback_model, e);
                    continue;
                }
            }
        }

        Err(anyhow::anyhow!(
            "All embedding models failed, including fallbacks"
        ))
    }

    /// Health check for the embedding service
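    ///
    /// A sketch of interpreting the result:
    /// ```ignore
    /// let health = embedder.health_check().await?;
    /// if health.status == "healthy" {
    ///     println!("{} answered in {} ms", health.model, health.response_time_ms);
    /// }
    /// ```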
    pub async fn health_check(&self) -> Result<EmbeddingHealth> {
        let start_time = std::time::Instant::now();

        let test_result = self.generate_embedding("Health check test").await;
        let response_time = start_time.elapsed();

        let health = match test_result {
            Ok(embedding) => EmbeddingHealth {
                status: "healthy".to_string(),
                model: self.model.clone(),
                provider: format!("{:?}", self.provider),
                response_time_ms: response_time.as_millis() as u64,
                embedding_dimensions: embedding.len(),
                error: None,
            },
            Err(e) => EmbeddingHealth {
                status: "unhealthy".to_string(),
                model: self.model.clone(),
                provider: format!("{:?}", self.provider),
                response_time_ms: response_time.as_millis() as u64,
                embedding_dimensions: 0,
                error: Some(e.to_string()),
            },
        };

        Ok(health)
    }

    /// Detect available embedding models on Ollama
    async fn detect_ollama_models(
        client: &Client,
        base_url: &str,
    ) -> Result<Vec<EmbeddingModelInfo>> {
        let response = client
            .get(&format!("{}/api/tags", base_url))
            .send()
            .await
            .context("Failed to connect to Ollama API")?;

        if !response.status().is_success() {
            return Err(anyhow::anyhow!(
                "Ollama API returned error: {}",
                response.status()
            ));
        }

        let models_response: OllamaModelsResponse = response
            .json()
            .await
            .context("Failed to parse Ollama models response")?;

        let mut embedding_models = Vec::new();

        for model in models_response.models {
            if let Some(model_info) = Self::classify_embedding_model(&model.name) {
                embedding_models.push(model_info);
            }
        }

        Ok(embedding_models)
    }

    /// Classify a model name as an embedding model
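    ///
    /// For example, "nomic-embed-text:latest" matches the known-model table
    /// (768 dimensions, preferred), "my-custom-embedder" falls through to the
    /// substring heuristic, and "llama3" returns `None`. (Illustrative names,
    /// not drawn from any real server.)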
    fn classify_embedding_model(model_name: &str) -> Option<EmbeddingModelInfo> {
        let name_lower = model_name.to_lowercase();

        // Define known embedding models with their properties
        let known_models = [
            (
                "nomic-embed-text",
                768,
                "High-quality text embeddings",
                true,
            ),
            (
                "mxbai-embed-large",
                1024,
                "Large multilingual embeddings",
                true,
            ),
            ("all-minilm", 384, "Compact sentence embeddings", false),
            (
                "all-mpnet-base-v2",
                768,
                "Sentence transformer embeddings",
                false,
            ),
            ("bge-small-en", 384, "BGE small English embeddings", false),
            ("bge-base-en", 768, "BGE base English embeddings", false),
            ("bge-large-en", 1024, "BGE large English embeddings", false),
            ("e5-small", 384, "E5 small embeddings", false),
            ("e5-base", 768, "E5 base embeddings", false),
            ("e5-large", 1024, "E5 large embeddings", false),
        ];

        for (pattern, dimensions, description, preferred) in known_models {
            // Patterns are lowercase, so matching the lowercased name suffices.
            if name_lower.contains(pattern) {
                return Some(EmbeddingModelInfo {
                    name: model_name.to_string(),
                    dimensions,
                    description: description.to_string(),
                    preferred,
                });
            }
        }

        // Check if it's likely an embedding model based on common patterns
        if name_lower.contains("embed")
            || name_lower.contains("sentence")
            || name_lower.contains("vector")
        {
            return Some(EmbeddingModelInfo {
                name: model_name.to_string(),
                dimensions: 768, // Default assumption
                description: "Detected embedding model".to_string(),
                preferred: false,
            });
        }

        None
    }

    /// Select the best model from available options
    fn select_best_model(available_models: &[EmbeddingModelInfo]) -> Result<&EmbeddingModelInfo> {
        // Prefer recommended models first
        if let Some(preferred) = available_models.iter().find(|m| m.preferred) {
            return Ok(preferred);
        }

        // Fall back to any available model
        available_models
            .first()
            .ok_or_else(|| anyhow::anyhow!("No embedding models available"))
    }
}

/// Information about an embedding model
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    pub name: String,
    pub dimensions: usize,
    pub description: String,
    pub preferred: bool,
}

/// Health status of the embedding service
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingHealth {
    pub status: String,
    pub model: String,
    pub provider: String,
    pub response_time_ms: u64,
    pub embedding_dimensions: usize,
    pub error: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore] // Requires actual OpenAI API key
    async fn test_generate_openai_embedding() {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
        let embedder = SimpleEmbedder::new(api_key);

        let result = embedder.generate_embedding("Hello, world!").await;
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 1536);
    }

    #[tokio::test]
    #[ignore] // Requires running Ollama instance
    async fn test_generate_ollama_embedding() {
        let embedder = SimpleEmbedder::new_ollama(
            "http://192.168.1.110:11434".to_string(),
            "nomic-embed-text".to_string(),
        );

        let result = embedder.generate_embedding("Hello, world!").await;
        assert!(result.is_ok());

        let embedding = result.unwrap();
        assert_eq!(embedding.len(), 768);
    }

    #[test]
    fn test_embedding_dimensions() {
        let embedder = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(embedder.embedding_dimension(), 1536);

        let embedder = embedder.with_model("text-embedding-3-large".to_string());
        assert_eq!(embedder.embedding_dimension(), 3072);

        let ollama_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama_embedder.embedding_dimension(), 768);

        let gpt_oss_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "gpt-oss:20b".to_string(),
        );
        assert_eq!(gpt_oss_embedder.embedding_dimension(), 4096);

        let mock_embedder = SimpleEmbedder::new_mock();
        assert_eq!(mock_embedder.embedding_dimension(), 768);
    }

    #[test]
    fn test_provider_types() {
        let openai_embedder = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(openai_embedder.provider(), &EmbeddingProvider::OpenAI);

        let ollama_embedder = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama_embedder.provider(), &EmbeddingProvider::Ollama);

        let mock_embedder = SimpleEmbedder::new_mock();
        assert_eq!(mock_embedder.provider(), &EmbeddingProvider::Mock);
    }
}