// codex_memory/embedding.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use backoff::{future::retry, ExponentialBackoff};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7use tracing::{info, warn};
8
/// Trait for embedding services
///
/// Object-safe abstraction over the embedding backends implemented in this
/// file (OpenAI, Ollama, Mock) so callers can swap providers freely.
#[async_trait]
pub trait EmbeddingService: Send + Sync {
    /// Produce an embedding vector for `text`.
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;
    /// Confirm the service can successfully produce embeddings.
    async fn health_check(&self) -> Result<()>;
}
15
/// Embedding client that talks to OpenAI or Ollama over HTTP, or serves
/// deterministic mock vectors for tests.
#[derive(Debug, Clone)]
pub struct SimpleEmbedder {
    client: Client,               // shared HTTP client (timeout set per provider in constructors)
    api_key: String,              // bearer token for OpenAI; empty for Ollama/Mock
    model: String,                // active model name
    base_url: String,             // API root, e.g. "https://api.openai.com"
    provider: EmbeddingProvider,  // which backend the URLs/payloads target
    fallback_models: Vec<String>, // tried in order by generate_embedding_with_fallback
}
25
/// Supported embedding backends.
#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingProvider {
    /// OpenAI `/v1/embeddings` API (requires an API key).
    OpenAI,
    /// Ollama server `/api/embeddings` endpoint (no auth).
    Ollama,
    Mock, // For testing
}
32
// OpenAI API request/response structures
#[derive(Debug, Serialize)]
struct OpenAIEmbeddingRequest {
    input: String, // text to embed
    model: String, // e.g. "text-embedding-3-small"
}

// Response envelope: only the `data` array is deserialized; any other
// response fields (usage, object, ...) are ignored by serde.
#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
}

#[derive(Debug, Deserialize)]
struct OpenAIEmbeddingData {
    embedding: Vec<f32>,
}
49
// Ollama API request/response structures
#[derive(Debug, Serialize)]
struct OllamaEmbeddingRequest {
    model: String,  // Ollama model name, e.g. "nomic-embed-text"
    prompt: String, // text to embed (Ollama's field name is "prompt")
}

#[derive(Debug, Deserialize)]
struct OllamaEmbeddingResponse {
    embedding: Vec<f32>,
}

// One entry from Ollama's `/api/tags` model listing; only `name` is used,
// the rest is kept for completeness of the wire format.
#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    size: u64,
    #[serde(default)]
    #[allow(dead_code)]
    family: String,
}

#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}
76
impl SimpleEmbedder {
    /// Create an embedder backed by the OpenAI embeddings API.
    ///
    /// Defaults to `text-embedding-3-small` with the other OpenAI embedding
    /// models registered as fallbacks, and a 30s request timeout.
    pub fn new(api_key: String) -> Self {
        let http = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .expect("Failed to create HTTP client");

        Self {
            client: http,
            api_key,
            model: String::from("text-embedding-3-small"),
            base_url: String::from("https://api.openai.com"),
            provider: EmbeddingProvider::OpenAI,
            fallback_models: vec![
                String::from("text-embedding-3-large"),
                String::from("text-embedding-ada-002"),
            ],
        }
    }
96
97    pub fn new_ollama(base_url: String, model: String) -> Self {
98        let client = Client::builder()
99            .timeout(Duration::from_secs(60)) // Ollama might be slower
100            .build()
101            .expect("Failed to create HTTP client");
102
103        Self {
104            client,
105            api_key: String::new(), // Ollama doesn't need an API key
106            model,
107            base_url,
108            provider: EmbeddingProvider::Ollama,
109            fallback_models: vec![
110                "nomic-embed-text".to_string(),
111                "mxbai-embed-large".to_string(),
112                "all-minilm".to_string(),
113                "all-mpnet-base-v2".to_string(),
114            ],
115        }
116    }
117
118    pub fn new_mock() -> Self {
119        let client = Client::builder()
120            .timeout(Duration::from_secs(1))
121            .build()
122            .expect("Failed to create HTTP client");
123
124        Self {
125            client,
126            api_key: String::new(),
127            model: "mock-model".to_string(),
128            base_url: "http://mock:11434".to_string(),
129            provider: EmbeddingProvider::Mock,
130            fallback_models: vec!["mock-model-2".to_string()],
131        }
132    }
133
    /// Builder-style setter: replace the active model name.
    pub fn with_model(mut self, model: String) -> Self {
        self.model = model;
        self
    }

    /// Builder-style setter: replace the API base URL.
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = base_url;
        self
    }
143
144    /// Generate embedding for text with automatic retry
145    pub async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
146        info!("Generating embedding for text of length: {}", text.len());
147
148        let operation = || async {
149            match self.generate_embedding_internal(text).await {
150                Ok(embedding) => Ok(embedding),
151                Err(e) => {
152                    if e.to_string().contains("Rate limited") {
153                        Err(backoff::Error::transient(e))
154                    } else {
155                        Err(backoff::Error::permanent(e))
156                    }
157                }
158            }
159        };
160
161        let backoff = ExponentialBackoff {
162            max_elapsed_time: Some(Duration::from_secs(60)),
163            ..Default::default()
164        };
165
166        retry(backoff, operation).await
167    }
168
    /// Dispatch one embedding request to the active provider's
    /// implementation (no retry logic here; see `generate_embedding`).
    async fn generate_embedding_internal(&self, text: &str) -> Result<Vec<f32>> {
        match self.provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::Ollama => self.generate_ollama_embedding(text).await,
            EmbeddingProvider::Mock => self.generate_mock_embedding(text).await,
        }
    }
176
177    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
178        let request = OpenAIEmbeddingRequest {
179            input: text.to_string(),
180            model: self.model.clone(),
181        };
182
183        let response = self
184            .client
185            .post(format!("{}/v1/embeddings", self.base_url))
186            .header("Authorization", format!("Bearer {}", self.api_key))
187            .header("Content-Type", "application/json")
188            .json(&request)
189            .send()
190            .await?;
191
192        if !response.status().is_success() {
193            let status = response.status();
194            let error_text = response
195                .text()
196                .await
197                .unwrap_or_else(|_| "Unknown error".to_string());
198
199            if status.as_u16() == 429 {
200                warn!("Rate limited by OpenAI API, will retry");
201                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
202            }
203
204            return Err(anyhow::anyhow!(
205                "OpenAI API request failed with status {}: {}",
206                status,
207                error_text
208            ));
209        }
210
211        let embedding_response: OpenAIEmbeddingResponse = response.json().await?;
212
213        if let Some(embedding_data) = embedding_response.data.first() {
214            Ok(embedding_data.embedding.clone())
215        } else {
216            Err(anyhow::anyhow!("No embedding data in OpenAI response"))
217        }
218    }
219
220    async fn generate_ollama_embedding(&self, text: &str) -> Result<Vec<f32>> {
221        let request = OllamaEmbeddingRequest {
222            model: self.model.clone(),
223            prompt: text.to_string(),
224        };
225
226        let response = self
227            .client
228            .post(format!("{}/api/embeddings", self.base_url))
229            .header("Content-Type", "application/json")
230            .json(&request)
231            .send()
232            .await?;
233
234        if !response.status().is_success() {
235            let status = response.status();
236            let error_text = response
237                .text()
238                .await
239                .unwrap_or_else(|_| "Unknown error".to_string());
240
241            if status.as_u16() == 429 {
242                warn!("Rate limited by Ollama API, will retry");
243                return Err(anyhow::anyhow!("Rate limited: {}", error_text));
244            }
245
246            return Err(anyhow::anyhow!(
247                "Ollama API request failed with status {}: {}",
248                status,
249                error_text
250            ));
251        }
252
253        let embedding_response: OllamaEmbeddingResponse = response.json().await?;
254        Ok(embedding_response.embedding)
255    }
256
257    async fn generate_mock_embedding(&self, text: &str) -> Result<Vec<f32>> {
258        // Generate a deterministic mock embedding based on text content
259        // This is useful for testing without requiring real embedding services
260        use std::collections::hash_map::DefaultHasher;
261        use std::hash::{Hash, Hasher};
262
263        let mut hasher = DefaultHasher::new();
264        text.hash(&mut hasher);
265        let hash = hasher.finish();
266
267        // Generate a fixed-size embedding (768 dimensions for consistency)
268        let dimensions = self.embedding_dimension();
269        let mut embedding = Vec::with_capacity(dimensions);
270
271        // Use the hash to seed a simple PRNG for consistent results
272        let mut seed = hash;
273        for _ in 0..dimensions {
274            seed = seed.wrapping_mul(1103515245).wrapping_add(12345);
275            let value = ((seed >> 16) % 1000) as f32 / 1000.0 - 0.5; // -0.5 to 0.5
276            embedding.push(value);
277        }
278
279        // Normalize the embedding to unit length
280        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
281        if magnitude > 0.0 {
282            for val in &mut embedding {
283                *val /= magnitude;
284            }
285        }
286
287        Ok(embedding)
288    }
289
290    /// Generate embeddings for multiple texts in batch
291    pub async fn generate_embeddings_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
292        info!("Generating embeddings for {} texts", texts.len());
293
294        let mut embeddings = Vec::with_capacity(texts.len());
295
296        // Process in small batches to avoid rate limits
297        for chunk in texts.chunks(10) {
298            let mut chunk_embeddings = Vec::with_capacity(chunk.len());
299
300            for text in chunk {
301                match self.generate_embedding(text).await {
302                    Ok(embedding) => chunk_embeddings.push(embedding),
303                    Err(e) => {
304                        warn!("Failed to generate embedding for text: {}", e);
305                        return Err(e);
306                    }
307                }
308
309                // Small delay to be respectful to the API
310                tokio::time::sleep(Duration::from_millis(100)).await;
311            }
312
313            embeddings.extend(chunk_embeddings);
314        }
315
316        Ok(embeddings)
317    }
318
319    /// Get the dimension of embeddings for this model
320    pub fn embedding_dimension(&self) -> usize {
321        match self.provider {
322            EmbeddingProvider::OpenAI => match self.model.as_str() {
323                "text-embedding-3-small" => 1536,
324                "text-embedding-3-large" => 3072,
325                "text-embedding-ada-002" => 1536,
326                _ => 1536, // Default to small model dimensions
327            },
328            EmbeddingProvider::Ollama => {
329                // Ollama models vary in dimensions, but many common ones use these sizes
330                match self.model.as_str() {
331                    "gpt-oss:20b" => 4096, // Assuming this model has 4096 dimensions
332                    "nomic-embed-text" => 768,
333                    "mxbai-embed-large" => 1024,
334                    "all-minilm" => 384,
335                    _ => 768, // Default dimension for many embedding models
336                }
337            }
338            EmbeddingProvider::Mock => 768, // Consistent mock embedding dimension
339        }
340    }
341
    /// Get the provider type
    /// (borrowed; compare against `EmbeddingProvider` variants).
    pub fn provider(&self) -> &EmbeddingProvider {
        &self.provider
    }
346
347    /// Auto-detect and configure the best available embedding model
348    pub async fn auto_configure(base_url: String) -> Result<Self> {
349        info!("🔍 Auto-detecting best available embedding model...");
350
351        let client = Client::builder()
352            .timeout(Duration::from_secs(30))
353            .build()
354            .context("Failed to create HTTP client")?;
355
356        // Try to get available models from Ollama
357        let available_models = Self::detect_ollama_models(&client, &base_url).await?;
358
359        if available_models.is_empty() {
360            return Err(anyhow::anyhow!(
361                "No embedding models found on Ollama server"
362            ));
363        }
364
365        // Select the best available model
366        let selected_model = Self::select_best_model(&available_models)?;
367
368        info!(
369            "✅ Selected model: {} ({}D)",
370            selected_model.name, selected_model.dimensions
371        );
372
373        let mut embedder = Self::new_ollama(base_url, selected_model.name.clone());
374        embedder.fallback_models = available_models
375            .into_iter()
376            .filter(|m| m.name != embedder.model)
377            .map(|m| m.name)
378            .collect();
379
380        Ok(embedder)
381    }
382
383    /// Generate embedding with automatic fallback to alternative models
384    pub async fn generate_embedding_with_fallback(&self, text: &str) -> Result<Vec<f32>> {
385        // Try the primary model first
386        match self.generate_embedding(text).await {
387            Ok(embedding) => return Ok(embedding),
388            Err(e) => {
389                warn!("Primary model '{}' failed: {}", self.model, e);
390            }
391        }
392
393        // Try fallback models
394        for fallback_model in &self.fallback_models {
395            info!("🔄 Trying fallback model: {}", fallback_model);
396
397            let mut fallback_embedder = self.clone();
398            fallback_embedder.model = fallback_model.clone();
399
400            match fallback_embedder.generate_embedding(text).await {
401                Ok(embedding) => {
402                    info!("✅ Fallback model '{}' succeeded", fallback_model);
403                    return Ok(embedding);
404                }
405                Err(e) => {
406                    warn!("Fallback model '{}' failed: {}", fallback_model, e);
407                    continue;
408                }
409            }
410        }
411
412        Err(anyhow::anyhow!(
413            "All embedding models failed, including fallbacks"
414        ))
415    }
416
417    /// Health check for the embedding service
418    pub async fn health_check(&self) -> Result<EmbeddingHealth> {
419        let start_time = std::time::Instant::now();
420
421        let test_result = self.generate_embedding("Health check test").await;
422        let response_time = start_time.elapsed();
423
424        let health = match test_result {
425            Ok(embedding) => EmbeddingHealth {
426                status: "healthy".to_string(),
427                model: self.model.clone(),
428                provider: format!("{:?}", self.provider),
429                response_time_ms: response_time.as_millis() as u64,
430                embedding_dimensions: embedding.len(),
431                error: None,
432            },
433            Err(e) => EmbeddingHealth {
434                status: "unhealthy".to_string(),
435                model: self.model.clone(),
436                provider: format!("{:?}", self.provider),
437                response_time_ms: response_time.as_millis() as u64,
438                embedding_dimensions: 0,
439                error: Some(e.to_string()),
440            },
441        };
442
443        Ok(health)
444    }
445
446    /// Detect available embedding models on Ollama
447    async fn detect_ollama_models(
448        client: &Client,
449        base_url: &str,
450    ) -> Result<Vec<EmbeddingModelInfo>> {
451        let response = client
452            .get(format!("{base_url}/api/tags"))
453            .send()
454            .await
455            .context("Failed to connect to Ollama API")?;
456
457        if !response.status().is_success() {
458            return Err(anyhow::anyhow!(
459                "Ollama API returned error: {}",
460                response.status()
461            ));
462        }
463
464        let models_response: OllamaModelsResponse = response
465            .json()
466            .await
467            .context("Failed to parse Ollama models response")?;
468
469        let mut embedding_models = Vec::new();
470
471        for model in models_response.models {
472            if let Some(model_info) = Self::classify_embedding_model(&model.name) {
473                embedding_models.push(model_info);
474            }
475        }
476
477        Ok(embedding_models)
478    }
479
480    /// Classify a model name as an embedding model
481    fn classify_embedding_model(model_name: &str) -> Option<EmbeddingModelInfo> {
482        let name_lower = model_name.to_lowercase();
483
484        // Define known embedding models with their properties
485        let known_models = [
486            (
487                "nomic-embed-text",
488                768,
489                "High-quality text embeddings",
490                true,
491            ),
492            (
493                "mxbai-embed-large",
494                1024,
495                "Large multilingual embeddings",
496                true,
497            ),
498            ("all-minilm", 384, "Compact sentence embeddings", false),
499            (
500                "all-mpnet-base-v2",
501                768,
502                "Sentence transformer embeddings",
503                false,
504            ),
505            ("bge-small-en", 384, "BGE small English embeddings", false),
506            ("bge-base-en", 768, "BGE base English embeddings", false),
507            ("bge-large-en", 1024, "BGE large English embeddings", false),
508            ("e5-small", 384, "E5 small embeddings", false),
509            ("e5-base", 768, "E5 base embeddings", false),
510            ("e5-large", 1024, "E5 large embeddings", false),
511        ];
512
513        for (pattern, dimensions, description, preferred) in known_models {
514            if name_lower.contains(pattern) || model_name.contains(pattern) {
515                return Some(EmbeddingModelInfo {
516                    name: model_name.to_string(),
517                    dimensions,
518                    description: description.to_string(),
519                    preferred,
520                });
521            }
522        }
523
524        // Check if it's likely an embedding model based on common patterns
525        if name_lower.contains("embed")
526            || name_lower.contains("sentence")
527            || name_lower.contains("vector")
528        {
529            return Some(EmbeddingModelInfo {
530                name: model_name.to_string(),
531                dimensions: 768, // Default assumption
532                description: "Detected embedding model".to_string(),
533                preferred: false,
534            });
535        }
536
537        None
538    }
539
540    /// Select the best model from available options
541    fn select_best_model(available_models: &[EmbeddingModelInfo]) -> Result<&EmbeddingModelInfo> {
542        // Prefer recommended models first
543        if let Some(preferred) = available_models.iter().find(|m| m.preferred) {
544            return Ok(preferred);
545        }
546
547        // Fall back to any available model
548        available_models
549            .first()
550            .ok_or_else(|| anyhow::anyhow!("No embedding models available"))
551    }
552}
553
554#[async_trait]
555impl EmbeddingService for SimpleEmbedder {
556    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
557        SimpleEmbedder::generate_embedding(self, text).await
558    }
559
560    async fn health_check(&self) -> Result<()> {
561        let health = SimpleEmbedder::health_check(self).await?;
562        if health.status == "healthy" {
563            Ok(())
564        } else {
565            Err(anyhow::anyhow!(
566                "Embedding service unhealthy: {:?}",
567                health.error
568            ))
569        }
570    }
571}
572
/// Information about an embedding model
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    pub name: String,        // model name as reported by the server
    pub dimensions: usize,   // embedding vector length
    pub description: String, // short human-readable summary
    pub preferred: bool,     // true for models auto_configure should favor
}
581
/// Health status of the embedding service
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingHealth {
    pub status: String,               // "healthy" or "unhealthy"
    pub model: String,                // model used for the probe
    pub provider: String,             // Debug-formatted EmbeddingProvider
    pub response_time_ms: u64,        // latency of the probe request
    pub embedding_dimensions: usize,  // observed vector length; 0 on failure
    pub error: Option<String>,        // error message when unhealthy
}
592
#[cfg(test)]
mod tests {
    use super::*;

    /// Hits the real OpenAI API; run manually with a valid key.
    #[tokio::test]
    #[ignore] // Requires actual OpenAI API key
    async fn test_generate_openai_embedding() {
        let key = match std::env::var("OPENAI_API_KEY") {
            Ok(k) => k,
            Err(_) => {
                eprintln!("OPENAI_API_KEY not set, skipping test");
                return;
            }
        };

        let embedding = SimpleEmbedder::new(key)
            .generate_embedding("Hello, world!")
            .await
            .expect("OpenAI embedding request should succeed");
        assert_eq!(embedding.len(), 1536);
    }

    /// Hits a live Ollama server; run manually.
    #[tokio::test]
    #[ignore] // Requires running Ollama instance
    async fn test_generate_ollama_embedding() {
        let embedder = SimpleEmbedder::new_ollama(
            "http://192.168.1.110:11434".to_string(),
            "nomic-embed-text".to_string(),
        );

        let embedding = embedder
            .generate_embedding("Hello, world!")
            .await
            .expect("Ollama embedding request should succeed");
        assert_eq!(embedding.len(), 768);
    }

    /// The dimension lookup covers each provider and known model.
    #[test]
    fn test_embedding_dimensions() {
        let openai = SimpleEmbedder::new("dummy_key".to_string());
        assert_eq!(openai.embedding_dimension(), 1536);
        assert_eq!(
            openai
                .with_model("text-embedding-3-large".to_string())
                .embedding_dimension(),
            3072
        );

        let nomic = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(nomic.embedding_dimension(), 768);

        let gpt_oss = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "gpt-oss:20b".to_string(),
        );
        assert_eq!(gpt_oss.embedding_dimension(), 4096);

        assert_eq!(SimpleEmbedder::new_mock().embedding_dimension(), 768);
    }

    /// Each constructor reports the matching provider variant.
    #[test]
    fn test_provider_types() {
        assert_eq!(
            SimpleEmbedder::new("dummy_key".to_string()).provider(),
            &EmbeddingProvider::OpenAI
        );

        let ollama = SimpleEmbedder::new_ollama(
            "http://localhost:11434".to_string(),
            "nomic-embed-text".to_string(),
        );
        assert_eq!(ollama.provider(), &EmbeddingProvider::Ollama);

        assert_eq!(
            SimpleEmbedder::new_mock().provider(),
            &EmbeddingProvider::Mock
        );
    }
}