// manx_cli/rag/embeddings.rs
//! Text embedding generation using configurable embedding providers
//!
//! This module provides flexible text embedding functionality for semantic similarity search.
//! Supports multiple embedding providers: hash-based (default), local models, and API services.
//! Users can configure their preferred embedding method via `manx config --embedding-provider`.

7use crate::rag::providers::{
8    custom, hash, huggingface, ollama, onnx, openai, EmbeddingProvider as ProviderTrait,
9};
10use crate::rag::{EmbeddingConfig, EmbeddingProvider};
11use anyhow::{anyhow, Result};
12use lru::LruCache;
13use std::collections::hash_map::DefaultHasher;
14use std::hash::{Hash, Hasher};
15use std::num::NonZeroUsize;
16use std::sync::Mutex;
17
18/// Text embedding model wrapper with configurable providers and LRU cache
19/// Supports hash-based embeddings (default), local ONNX models, and API services.
20/// Users can configure their preferred embedding method via `manx config`.
21pub struct EmbeddingModel {
22    provider: Box<dyn ProviderTrait + Send + Sync>,
23    config: EmbeddingConfig,
24    /// LRU cache for recent embeddings (text hash -> embedding vector)
25    cache: Mutex<LruCache<u64, Vec<f32>>>,
26}
27
28impl EmbeddingModel {
29    /// Create a new embedding model with default hash-based provider
30    pub async fn new() -> Result<Self> {
31        Self::new_with_config(EmbeddingConfig::default()).await
32    }
33
34    /// Create embedding model with smart auto-selection of best available provider
35    pub async fn new_auto_select() -> Result<Self> {
36        let best_config = Self::auto_select_best_provider().await?;
37        Self::new_with_config(best_config).await
38    }
39
40    /// Create a new embedding model with custom configuration
41    pub async fn new_with_config(config: EmbeddingConfig) -> Result<Self> {
42        log::info!(
43            "Initializing embedding model with provider: {:?}",
44            config.provider
45        );
46
47        let provider: Box<dyn ProviderTrait + Send + Sync> = match &config.provider {
48            EmbeddingProvider::Hash => {
49                log::info!("Using hash-based embeddings (default provider)");
50                Box::new(hash::HashProvider::new(384)) // Hash provider always uses 384 dimensions
51            }
52            EmbeddingProvider::Onnx(model_name) => {
53                log::info!("Loading ONNX model: {}", model_name);
54                let onnx_provider = onnx::OnnxProvider::new(model_name).await?;
55                Box::new(onnx_provider)
56            }
57            EmbeddingProvider::Ollama(model_name) => {
58                log::info!("Connecting to Ollama model: {}", model_name);
59                let ollama_provider =
60                    ollama::OllamaProvider::new(model_name.clone(), config.endpoint.clone());
61                // Test connection
62                ollama_provider.health_check().await?;
63                Box::new(ollama_provider)
64            }
65            EmbeddingProvider::OpenAI(model_name) => {
66                log::info!("Connecting to OpenAI model: {}", model_name);
67                let api_key = config.api_key.as_ref().ok_or_else(|| {
68                    anyhow!("OpenAI API key required. Use 'manx config --embedding-api-key <key>'")
69                })?;
70                let openai_provider =
71                    openai::OpenAiProvider::new(api_key.clone(), model_name.clone());
72                Box::new(openai_provider)
73            }
74            EmbeddingProvider::HuggingFace(model_name) => {
75                log::info!("Connecting to HuggingFace model: {}", model_name);
76                let api_key = config.api_key.as_ref().ok_or_else(|| {
77                    anyhow!(
78                        "HuggingFace API key required. Use 'manx config --embedding-api-key <key>'"
79                    )
80                })?;
81                let hf_provider =
82                    huggingface::HuggingFaceProvider::new(api_key.clone(), model_name.clone());
83                Box::new(hf_provider)
84            }
85            EmbeddingProvider::Custom(endpoint) => {
86                log::info!("Connecting to custom endpoint: {}", endpoint);
87                let custom_provider =
88                    custom::CustomProvider::new(endpoint.clone(), config.api_key.clone());
89                Box::new(custom_provider)
90            }
91        };
92
93        // Initialize LRU cache with capacity for 1000 embeddings (configurable)
94        let cache_capacity = NonZeroUsize::new(1000).unwrap();
95        let cache = Mutex::new(LruCache::new(cache_capacity));
96
97        Ok(Self {
98            provider,
99            config,
100            cache,
101        })
102    }
103
104    /// Generate embeddings for a single text using configured provider with caching
105    pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
106        if text.trim().is_empty() {
107            return Err(anyhow!("Cannot embed empty text"));
108        }
109
110        // Compute hash of the text for cache key
111        let text_hash = Self::hash_text(text);
112
113        // Check cache first
114        {
115            let mut cache = self.cache.lock().unwrap();
116            if let Some(cached_embedding) = cache.get(&text_hash) {
117                log::debug!("Cache hit for text embedding");
118                return Ok(cached_embedding.clone());
119            }
120        }
121
122        // Cache miss - generate embedding with retry logic
123        log::debug!("Cache miss for text embedding, generating...");
124        let embedding = Self::retry_with_backoff(
125            || async { self.provider.embed_text(text).await },
126            3, // max retries
127        )
128        .await?;
129
130        // Store in cache
131        {
132            let mut cache = self.cache.lock().unwrap();
133            cache.put(text_hash, embedding.clone());
134        }
135
136        Ok(embedding)
137    }
138
139    /// Retry async operation with exponential backoff
140    async fn retry_with_backoff<F, Fut, T>(mut operation: F, max_retries: u32) -> Result<T>
141    where
142        F: FnMut() -> Fut,
143        Fut: std::future::Future<Output = Result<T>>,
144    {
145        let mut retries = 0;
146        loop {
147            match operation().await {
148                Ok(result) => return Ok(result),
149                Err(e) => {
150                    retries += 1;
151                    if retries > max_retries {
152                        log::error!("Operation failed after {} retries: {}", max_retries, e);
153                        return Err(e);
154                    }
155
156                    let delay_ms = 100 * (2_u64.pow(retries - 1)); // 100ms, 200ms, 400ms
157                    log::warn!(
158                        "Operation failed (attempt {}/{}), retrying in {}ms: {}",
159                        retries,
160                        max_retries,
161                        delay_ms,
162                        e
163                    );
164                    tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
165                }
166            }
167        }
168    }
169
170    /// Hash text for cache key
171    fn hash_text(text: &str) -> u64 {
172        let mut hasher = DefaultHasher::new();
173        text.hash(&mut hasher);
174        hasher.finish()
175    }
176
177    /// Generate embeddings for multiple texts (batch processing)
178    /// More efficient than calling embed_text repeatedly - uses native batch for ONNX providers
179    pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
180        if texts.is_empty() {
181            return Ok(vec![]);
182        }
183
184        // Check which provider we're using and delegate appropriately
185        match &self.config.provider {
186            EmbeddingProvider::Onnx(_) => {
187                // Use ONNX's native batch processing - bypasses cache for batch operations
188                // This is more efficient as it processes all texts in one go
189                log::debug!(
190                    "Using ONNX native batch processing for {} texts",
191                    texts.len()
192                );
193
194                // Get the ONNX provider and call its batch method directly
195                // Use the as_any method to downcast the trait object
196                if let Some(onnx_provider) =
197                    self.provider.as_any().downcast_ref::<onnx::OnnxProvider>()
198                {
199                    return onnx_provider.embed_batch(texts).await;
200                }
201
202                // Fallback if downcast fails (shouldn't happen)
203                log::warn!("Failed to downcast ONNX provider, using sequential processing");
204                self.embed_batch_sequential(texts).await
205            }
206            _ => {
207                // For other providers, use sequential processing with cache
208                self.embed_batch_sequential(texts).await
209            }
210        }
211    }
212
213    /// Sequential batch processing with caching (fallback for non-ONNX providers)
214    async fn embed_batch_sequential(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
215        let mut embeddings = Vec::with_capacity(texts.len());
216        let mut failed_count = 0;
217
218        for (i, text) in texts.iter().enumerate() {
219            match self.embed_text(text).await {
220                Ok(embedding) => embeddings.push(embedding),
221                Err(e) => {
222                    log::warn!("Failed to embed text {} in batch: {}", i, e);
223                    failed_count += 1;
224                    continue;
225                }
226            }
227        }
228
229        if embeddings.is_empty() {
230            return Err(anyhow!(
231                "Batch embedding failed for all {} texts",
232                texts.len()
233            ));
234        }
235
236        if failed_count > 0 {
237            log::warn!(
238                "Batch embedding completed with {} failures out of {} texts",
239                failed_count,
240                texts.len()
241            );
242        }
243
244        Ok(embeddings)
245    }
246
247    /// Get the dimension of embeddings produced by this model
248    pub async fn get_dimension(&self) -> Result<usize> {
249        self.provider.get_dimension().await
250    }
251
252    /// Test if the embedding model is working correctly
253    pub async fn health_check(&self) -> Result<()> {
254        self.provider.health_check().await
255    }
256
257    /// Get information about the current provider
258    pub fn get_provider_info(&self) -> crate::rag::providers::ProviderInfo {
259        self.provider.get_info()
260    }
261
262    /// Get the current configuration
263    pub fn get_config(&self) -> &EmbeddingConfig {
264        &self.config
265    }
266
267    /// Calculate cosine similarity between two embeddings
268    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
269        if a.len() != b.len() {
270            return 0.0;
271        }
272
273        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
274        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
275        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
276
277        if norm_a == 0.0 || norm_b == 0.0 {
278            0.0
279        } else {
280            dot_product / (norm_a * norm_b)
281        }
282    }
283
284    /// Automatically select the best available embedding provider from installed models
285    /// Respects user's installed models and doesn't hardcode specific model names
286    pub async fn auto_select_best_provider() -> Result<EmbeddingConfig> {
287        log::info!("Auto-selecting best available embedding provider from installed models...");
288
289        // Try to find any available ONNX models by checking common model directories
290        if let Ok(available_models) = Self::get_available_onnx_models().await {
291            if !available_models.is_empty() {
292                // Select the first available model (user chose to install it)
293                let selected_model = &available_models[0];
294                log::info!("Auto-selected installed ONNX model: {}", selected_model);
295
296                // Try to determine dimension by testing the model
297                if let Ok(test_config) = Self::create_config_for_model(selected_model).await {
298                    return Ok(test_config);
299                }
300            }
301        }
302
303        // Fallback to hash-based embeddings if no ONNX models available
304        log::info!("No ONNX models found, using hash-based embeddings");
305        Ok(EmbeddingConfig::default())
306    }
307
308    /// Get list of available ONNX models (non-hardcoded discovery)
309    async fn get_available_onnx_models() -> Result<Vec<String>> {
310        // This would typically scan the model cache directory
311        // For now, we'll try a few common models that might be installed
312        let potential_models = [
313            "sentence-transformers/all-MiniLM-L6-v2",
314            "sentence-transformers/all-mpnet-base-v2",
315            "BAAI/bge-base-en-v1.5",
316            "BAAI/bge-small-en-v1.5",
317            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
318        ];
319
320        let mut available = Vec::new();
321        for model in &potential_models {
322            if Self::is_onnx_model_available(model).await {
323                available.push(model.to_string());
324            }
325        }
326
327        Ok(available)
328    }
329
330    /// Create config for a specific model with proper dimension detection
331    async fn create_config_for_model(model_name: &str) -> Result<EmbeddingConfig> {
332        // Test the model to get its dimension
333        match onnx::OnnxProvider::new(model_name).await {
334            Ok(provider) => {
335                let dimension = provider.get_dimension().await.unwrap_or(384);
336                Ok(EmbeddingConfig {
337                    provider: EmbeddingProvider::Onnx(model_name.to_string()),
338                    dimension,
339                    ..EmbeddingConfig::default()
340                })
341            }
342            Err(e) => Err(anyhow!(
343                "Failed to create config for model {}: {}",
344                model_name,
345                e
346            )),
347        }
348    }
349
350    /// Check if an ONNX model is available locally
351    async fn is_onnx_model_available(model_name: &str) -> bool {
352        // Try to create the provider to test availability
353        match onnx::OnnxProvider::new(model_name).await {
354            Ok(_) => {
355                log::debug!("ONNX model '{}' is available", model_name);
356                true
357            }
358            Err(e) => {
359                log::debug!("ONNX model '{}' not available: {}", model_name, e);
360                false
361            }
362        }
363    }
364}
365
366/// Utility functions for text preprocessing
367pub mod preprocessing {
368    /// Clean and normalize text for embedding
369    pub fn clean_text(text: &str) -> String {
370        // Detect if this is code based on common patterns
371        if is_code_content(text) {
372            clean_code_text(text)
373        } else {
374            clean_regular_text(text)
375        }
376    }
377
378    /// Clean regular text (documents, markdown, etc.)
379    fn clean_regular_text(text: &str) -> String {
380        // Remove excessive whitespace
381        let cleaned = text
382            .lines()
383            .map(|line| line.trim())
384            .filter(|line| !line.is_empty())
385            .collect::<Vec<_>>()
386            .join(" ")
387            .split_whitespace()
388            .collect::<Vec<_>>()
389            .join(" ");
390
391        // Limit length to prevent very long embeddings (respect UTF-8 boundaries)
392        const MAX_CHARS: usize = 2048;
393        if cleaned.chars().count() > MAX_CHARS {
394            let truncated: String = cleaned.chars().take(MAX_CHARS).collect();
395            format!("{}...", truncated)
396        } else {
397            cleaned
398        }
399    }
400
401    /// Clean code text while preserving structure
402    fn clean_code_text(text: &str) -> String {
403        let mut cleaned = String::new();
404        let mut in_comment_block = false;
405
406        for line in text.lines() {
407            let trimmed = line.trim();
408
409            // Skip empty lines in code
410            if trimmed.is_empty() && !cleaned.is_empty() {
411                continue;
412            }
413
414            // Handle comment blocks
415            if trimmed.starts_with("/*") {
416                in_comment_block = true;
417            }
418            if in_comment_block {
419                if trimmed.ends_with("*/") {
420                    in_comment_block = false;
421                }
422                cleaned.push_str("// ");
423                cleaned.push_str(trimmed);
424                cleaned.push('\n');
425                continue;
426            }
427
428            // Preserve important code structure
429            if is_important_code_line(trimmed) {
430                // Keep indentation context (simplified)
431                let indent_level = line.len() - line.trim_start().len();
432                let normalized_indent = " ".repeat((indent_level / 2).min(4));
433                cleaned.push_str(&normalized_indent);
434                cleaned.push_str(trimmed);
435                cleaned.push('\n');
436            }
437        }
438
439        // Limit length (UTF-8 safe)
440        const MAX_CODE_CHARS: usize = 3000;
441        if cleaned.chars().count() > MAX_CODE_CHARS {
442            let truncated: String = cleaned.chars().take(MAX_CODE_CHARS).collect();
443            format!("{}...", truncated)
444        } else {
445            cleaned
446        }
447    }
448
449    /// Check if text appears to be code
450    fn is_code_content(text: &str) -> bool {
451        let code_indicators = [
452            "function",
453            "const",
454            "let",
455            "var",
456            "def",
457            "class",
458            "import",
459            "export",
460            "public",
461            "private",
462            "protected",
463            "return",
464            "if (",
465            "for (",
466            "while (",
467            "=>",
468            "->",
469            "::",
470            "<?php",
471            "#!/",
472            "package",
473            "namespace",
474            "struct",
475        ];
476
477        let text_lower = text.to_lowercase();
478        let indicator_count = code_indicators
479            .iter()
480            .filter(|&&ind| text_lower.contains(ind))
481            .count();
482
483        // If multiple code indicators found, likely code
484        indicator_count >= 3
485    }
486
487    /// Check if a line is important for code context
488    fn is_important_code_line(line: &str) -> bool {
489        // Skip pure comments unless they're doc comments
490        if line.starts_with("//") && !line.starts_with("///") && !line.starts_with("//!") {
491            return false;
492        }
493
494        // Keep imports, function definitions, class definitions, etc.
495        let important_patterns = [
496            "import ",
497            "from ",
498            "require",
499            "include",
500            "function ",
501            "def ",
502            "fn ",
503            "func ",
504            "class ",
505            "struct ",
506            "interface ",
507            "enum ",
508            "public ",
509            "private ",
510            "protected ",
511            "export ",
512            "module ",
513            "namespace ",
514        ];
515
516        for pattern in &important_patterns {
517            if line.contains(pattern) {
518                return true;
519            }
520        }
521
522        // Keep lines with actual code (not just brackets)
523        !line
524            .chars()
525            .all(|c| c == '{' || c == '}' || c == '(' || c == ')' || c == ';' || c.is_whitespace())
526    }
527
528    /// Split text into chunks suitable for embedding
529    pub fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
530        // Use code-aware chunking if this appears to be code
531        if is_code_content(text) {
532            chunk_code_text(text, chunk_size, overlap)
533        } else {
534            chunk_regular_text(text, chunk_size, overlap)
535        }
536    }
537
538    /// Regular text chunking by words
539    fn chunk_regular_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
540        let words: Vec<&str> = text.split_whitespace().collect();
541        let mut chunks = Vec::new();
542
543        if words.len() <= chunk_size {
544            chunks.push(text.to_string());
545            return chunks;
546        }
547
548        let mut start = 0;
549        while start < words.len() {
550            let end = std::cmp::min(start + chunk_size, words.len());
551            let chunk = words[start..end].join(" ");
552            chunks.push(chunk);
553
554            if end == words.len() {
555                break;
556            }
557
558            start = end - overlap;
559        }
560
561        chunks
562    }
563
564    /// Code-aware chunking that respects function/class boundaries
565    fn chunk_code_text(text: &str, chunk_size: usize, _overlap: usize) -> Vec<String> {
566        let mut chunks = Vec::new();
567        let mut current_chunk = String::new();
568        let mut current_size = 0;
569        let mut brace_depth = 0;
570        let mut in_function = false;
571
572        for line in text.lines() {
573            let trimmed = line.trim();
574
575            // Detect function/class boundaries
576            if trimmed.contains("function ")
577                || trimmed.contains("def ")
578                || trimmed.contains("class ")
579                || trimmed.contains("fn ")
580            {
581                in_function = true;
582
583                // If current chunk is large enough, save it
584                if current_size > chunk_size / 2 && brace_depth == 0 && !current_chunk.is_empty() {
585                    chunks.push(current_chunk.clone());
586                    current_chunk.clear();
587                    current_size = 0;
588                }
589            }
590
591            // Track brace depth for better chunking
592            brace_depth += trimmed.chars().filter(|&c| c == '{').count() as i32;
593            brace_depth -= trimmed.chars().filter(|&c| c == '}').count() as i32;
594            brace_depth = brace_depth.max(0);
595
596            // Add line to current chunk
597            current_chunk.push_str(line);
598            current_chunk.push('\n');
599            current_size += line.split_whitespace().count();
600
601            // Create new chunk when we hit size limit and are at a good boundary
602            if current_size >= chunk_size && brace_depth == 0 && !in_function {
603                chunks.push(current_chunk.clone());
604                current_chunk.clear();
605                current_size = 0;
606            }
607
608            // Reset function flag when we exit a function
609            if in_function && brace_depth == 0 && trimmed.ends_with('}') {
610                in_function = false;
611            }
612        }
613
614        // Add remaining content
615        if !current_chunk.trim().is_empty() {
616            chunks.push(current_chunk);
617        }
618
619        // If no chunks were created, fall back to regular chunking
620        if chunks.is_empty() {
621            return chunk_regular_text(text, chunk_size, chunk_size / 10);
622        }
623
624        chunks
625    }
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// The default (hash) provider should produce a non-trivial 384-d vector.
    #[tokio::test]
    async fn test_embedding_model() {
        let model = EmbeddingModel::new().await.unwrap();

        let embedding = model
            .embed_text("This is a test sentence for embedding.")
            .await
            .unwrap();

        assert_eq!(embedding.len(), 384); // Hash provider default
        assert!(embedding.iter().any(|&value| value != 0.0));
    }

    /// Identical vectors score +1; exactly opposite vectors score -1.
    #[test]
    fn test_cosine_similarity() {
        let base = vec![1.0, 2.0, 3.0];
        let same = vec![1.0, 2.0, 3.0];
        assert!((EmbeddingModel::cosine_similarity(&base, &same) - 1.0).abs() < 0.001);

        let negated = vec![-1.0, -2.0, -3.0];
        assert!((EmbeddingModel::cosine_similarity(&base, &negated) + 1.0).abs() < 0.001);
    }

    /// Whitespace runs and blank lines collapse to single spaces.
    #[test]
    fn test_text_preprocessing() {
        let raw = "  This is   a test\n\n  with  multiple   lines  \n  ";
        assert_eq!(
            preprocessing::clean_text(raw),
            "This is a test with multiple lines"
        );
    }

    /// Word-window chunking: size 3, overlap 1, over ten words.
    #[test]
    fn test_text_chunking() {
        let chunks =
            preprocessing::chunk_text("one two three four five six seven eight nine ten", 3, 1);

        let expected = [
            "one two three",
            "three four five",
            "five six seven",
            "seven eight nine",
            "nine ten",
        ];
        assert_eq!(chunks, expected);
    }

    /// Reordered words should stay closer than an unrelated sentence.
    #[tokio::test]
    async fn test_similarity_detection() {
        let model = EmbeddingModel::new().await.unwrap();

        let emb_a = model.embed_text("React hooks useState").await.unwrap();
        let emb_b = model.embed_text("useState React hooks").await.unwrap();
        let emb_c = model.embed_text("Python Django models").await.unwrap();

        let sim_ab = EmbeddingModel::cosine_similarity(&emb_a, &emb_b);
        let sim_ac = EmbeddingModel::cosine_similarity(&emb_a, &emb_c);

        // Similar texts should have higher similarity
        assert!(sim_ab > sim_ac);
    }
}