manx_cli/rag/
embeddings.rs

//! Text embedding generation using configurable embedding providers
//!
//! This module provides flexible text embedding functionality for semantic similarity search.
//! Supports multiple embedding providers: hash-based (default), local models, and API services.
//! Users can configure their preferred embedding method via `manx config --embedding-provider`.
6
7use crate::rag::providers::{
8    custom, hash, huggingface, ollama, onnx, openai, EmbeddingProvider as ProviderTrait,
9};
10use crate::rag::{EmbeddingConfig, EmbeddingProvider};
11use anyhow::{anyhow, Result};
12
/// Text embedding model wrapper with configurable providers.
///
/// Supports hash-based embeddings (default), local ONNX models, and API services.
/// Users can configure their preferred embedding method via `manx config`.
pub struct EmbeddingModel {
    /// Active provider implementation; boxed trait object so any backend
    /// (hash, ONNX, Ollama, OpenAI, HuggingFace, custom endpoint) can be swapped in.
    provider: Box<dyn ProviderTrait + Send + Sync>,
    /// The configuration this model was constructed from (provider choice,
    /// optional API key, optional endpoint). Exposed read-only via `get_config`.
    config: EmbeddingConfig,
}
20
impl EmbeddingModel {
    /// Create a new embedding model with the default configuration
    /// (hash-based provider; requires no model files, network, or API key).
    ///
    /// # Errors
    /// Propagates any provider-initialization error from `new_with_config`.
    pub async fn new() -> Result<Self> {
        Self::new_with_config(EmbeddingConfig::default()).await
    }

    /// Create an embedding model with smart auto-selection of the best
    /// available provider: an installed ONNX model if one is found, otherwise
    /// hash-based embeddings (see `auto_select_best_provider`).
    ///
    /// # Errors
    /// Returns any error from provider selection or initialization.
    pub async fn new_auto_select() -> Result<Self> {
        let best_config = Self::auto_select_best_provider().await?;
        Self::new_with_config(best_config).await
    }

    /// Create a new embedding model with a custom configuration.
    ///
    /// Instantiates the backend named by `config.provider`. Ollama is
    /// health-checked here so misconfiguration fails fast at construction;
    /// OpenAI and HuggingFace fail here when `config.api_key` is unset.
    /// ONNX loads the model immediately and may fail if it is not cached.
    ///
    /// # Errors
    /// - ONNX: model loading failure
    /// - Ollama: connection / health-check failure
    /// - OpenAI / HuggingFace: missing API key in `config.api_key`
    pub async fn new_with_config(config: EmbeddingConfig) -> Result<Self> {
        log::info!(
            "Initializing embedding model with provider: {:?}",
            config.provider
        );

        let provider: Box<dyn ProviderTrait + Send + Sync> = match &config.provider {
            EmbeddingProvider::Hash => {
                log::info!("Using hash-based embeddings (default provider)");
                // Hash provider always uses 384 dimensions, regardless of
                // any dimension stored in `config`.
                Box::new(hash::HashProvider::new(384))
            }
            EmbeddingProvider::Onnx(model_name) => {
                log::info!("Loading ONNX model: {}", model_name);
                let onnx_provider = onnx::OnnxProvider::new(model_name).await?;
                Box::new(onnx_provider)
            }
            EmbeddingProvider::Ollama(model_name) => {
                log::info!("Connecting to Ollama model: {}", model_name);
                let ollama_provider =
                    ollama::OllamaProvider::new(model_name.clone(), config.endpoint.clone());
                // Test connection before accepting the provider so a bad
                // endpoint surfaces at construction, not on first embed.
                ollama_provider.health_check().await?;
                Box::new(ollama_provider)
            }
            EmbeddingProvider::OpenAI(model_name) => {
                log::info!("Connecting to OpenAI model: {}", model_name);
                let api_key = config.api_key.as_ref().ok_or_else(|| {
                    anyhow!("OpenAI API key required. Use 'manx config --embedding-api-key <key>'")
                })?;
                // NOTE(review): unlike Ollama, no health check is performed
                // here; a bad key is only detected on first request.
                let openai_provider =
                    openai::OpenAiProvider::new(api_key.clone(), model_name.clone());
                Box::new(openai_provider)
            }
            EmbeddingProvider::HuggingFace(model_name) => {
                log::info!("Connecting to HuggingFace model: {}", model_name);
                let api_key = config.api_key.as_ref().ok_or_else(|| {
                    anyhow!(
                        "HuggingFace API key required. Use 'manx config --embedding-api-key <key>'"
                    )
                })?;
                let hf_provider =
                    huggingface::HuggingFaceProvider::new(api_key.clone(), model_name.clone());
                Box::new(hf_provider)
            }
            EmbeddingProvider::Custom(endpoint) => {
                log::info!("Connecting to custom endpoint: {}", endpoint);
                // API key is optional for custom endpoints (passed as Option).
                let custom_provider =
                    custom::CustomProvider::new(endpoint.clone(), config.api_key.clone());
                Box::new(custom_provider)
            }
        };

        Ok(Self { provider, config })
    }

    /// Generate an embedding vector for a single text using the configured
    /// provider.
    ///
    /// # Errors
    /// Returns an error for empty/whitespace-only input, or any error from
    /// the underlying provider.
    pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
        if text.trim().is_empty() {
            return Err(anyhow!("Cannot embed empty text"));
        }

        self.provider.embed_text(text).await
    }

    /// Get the dimension of embeddings produced by this model.
    ///
    /// # Errors
    /// Propagates provider errors (e.g. a remote provider may need to probe).
    pub async fn get_dimension(&self) -> Result<usize> {
        self.provider.get_dimension().await
    }

    /// Test whether the embedding backend is reachable and working.
    ///
    /// # Errors
    /// Returns the provider's health-check failure, if any.
    pub async fn health_check(&self) -> Result<()> {
        self.provider.health_check().await
    }

    /// Get descriptive information about the current provider.
    pub fn get_provider_info(&self) -> crate::rag::providers::ProviderInfo {
        self.provider.get_info()
    }

    /// Get the configuration this model was constructed with.
    pub fn get_config(&self) -> &EmbeddingConfig {
        &self.config
    }

    /// Calculate cosine similarity between two embeddings.
    ///
    /// Returns a value in [-1.0, 1.0]. Returns 0.0 when the vectors have
    /// different lengths or when either vector has zero magnitude (no
    /// direction to compare).
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot_product / (norm_a * norm_b)
        }
    }

    /// Automatically select the best available embedding provider from
    /// installed models. Prefers the first available ONNX model; falls back
    /// to hash-based embeddings when none are installed.
    ///
    /// # Errors
    /// Currently always returns `Ok` (hash-based config is the fallback),
    /// but the signature allows future selection strategies to fail.
    pub async fn auto_select_best_provider() -> Result<EmbeddingConfig> {
        log::info!("Auto-selecting best available embedding provider from installed models...");

        // Try to find any available ONNX models by checking common model directories
        if let Ok(available_models) = Self::get_available_onnx_models().await {
            if !available_models.is_empty() {
                // Select the first available model (user chose to install it)
                let selected_model = &available_models[0];
                log::info!("Auto-selected installed ONNX model: {}", selected_model);

                // Try to determine dimension by testing the model; on failure
                // we silently fall through to the hash-based default below.
                if let Ok(test_config) = Self::create_config_for_model(selected_model).await {
                    return Ok(test_config);
                }
            }
        }

        // Fallback to hash-based embeddings if no ONNX models available
        log::info!("No ONNX models found, using hash-based embeddings");
        Ok(EmbeddingConfig::default())
    }

    /// Get the list of available ONNX models.
    ///
    /// NOTE(review): despite the intent of "non-hardcoded discovery", this
    /// currently probes a fixed list of popular sentence-transformer models;
    /// a real cache-directory scan is still a TODO.
    async fn get_available_onnx_models() -> Result<Vec<String>> {
        // This would typically scan the model cache directory
        // For now, we'll try a few common models that might be installed
        let potential_models = [
            "sentence-transformers/all-MiniLM-L6-v2",
            "sentence-transformers/all-mpnet-base-v2",
            "BAAI/bge-base-en-v1.5",
            "BAAI/bge-small-en-v1.5",
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        ];

        let mut available = Vec::new();
        for model in &potential_models {
            if Self::is_onnx_model_available(model).await {
                available.push(model.to_string());
            }
        }

        Ok(available)
    }

    /// Create a config for a specific ONNX model, probing it for its
    /// embedding dimension.
    ///
    /// # Errors
    /// Fails when the ONNX provider cannot be constructed for `model_name`.
    async fn create_config_for_model(model_name: &str) -> Result<EmbeddingConfig> {
        // Test the model to get its dimension
        match onnx::OnnxProvider::new(model_name).await {
            Ok(provider) => {
                // NOTE(review): a failed dimension probe is silently masked
                // by the 384 fallback — confirm this default matches the
                // models listed in `get_available_onnx_models`.
                let dimension = provider.get_dimension().await.unwrap_or(384);
                Ok(EmbeddingConfig {
                    provider: EmbeddingProvider::Onnx(model_name.to_string()),
                    dimension,
                    ..EmbeddingConfig::default()
                })
            }
            Err(e) => Err(anyhow!(
                "Failed to create config for model {}: {}",
                model_name,
                e
            )),
        }
    }

    /// Check if an ONNX model is available locally by attempting to
    /// construct its provider (any construction error means "unavailable").
    async fn is_onnx_model_available(model_name: &str) -> bool {
        // Try to create the provider to test availability
        match onnx::OnnxProvider::new(model_name).await {
            Ok(_) => {
                log::debug!("ONNX model '{}' is available", model_name);
                true
            }
            Err(e) => {
                log::debug!("ONNX model '{}' not available: {}", model_name, e);
                false
            }
        }
    }
}
216
/// Utility functions for text preprocessing
pub mod preprocessing {
    /// Clean and normalize text for embedding.
    ///
    /// Dispatches to code-aware cleaning when the text looks like source
    /// code (see `is_code_content`), otherwise to regular text cleaning.
    pub fn clean_text(text: &str) -> String {
        if is_code_content(text) {
            clean_code_text(text)
        } else {
            clean_regular_text(text)
        }
    }

    /// Truncate `s` to at most `max_len` bytes and append "...", backing the
    /// cut up to the nearest UTF-8 char boundary.
    ///
    /// Fixes a panic in the previous implementation, which sliced at a fixed
    /// byte offset and could split a multi-byte character.
    fn truncate_with_ellipsis(s: String, max_len: usize) -> String {
        if s.len() <= max_len {
            return s;
        }
        let mut cut = max_len;
        // Byte 0 is always a char boundary, so this terminates.
        while !s.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}...", &s[..cut])
    }

    /// Clean regular text (documents, markdown, etc.): collapse every run of
    /// whitespace (including newlines) into a single space, trim, and cap
    /// the length.
    fn clean_regular_text(text: &str) -> String {
        // One pass over whitespace is equivalent to the old two-pass
        // (lines -> trim -> join -> split_whitespace -> join) pipeline.
        let cleaned = text.split_whitespace().collect::<Vec<_>>().join(" ");

        // Limit length to prevent very long embeddings
        const MAX_LENGTH: usize = 2048;
        truncate_with_ellipsis(cleaned, MAX_LENGTH)
    }

    /// Clean code text while preserving structure: keeps doc comments,
    /// declarations, and "real" code lines; drops blank lines and noise.
    fn clean_code_text(text: &str) -> String {
        let mut cleaned = String::new();
        let mut in_comment_block = false;

        for line in text.lines() {
            let trimmed = line.trim();

            // Skip empty lines in code (but keep a leading empty line's
            // fall-through behavior identical to the original).
            if trimmed.is_empty() && !cleaned.is_empty() {
                continue;
            }

            // Handle /* ... */ comment blocks: normalize each line to a
            // "// " prefix so downstream treatment is uniform.
            if trimmed.starts_with("/*") {
                in_comment_block = true;
            }
            if in_comment_block {
                if trimmed.ends_with("*/") {
                    in_comment_block = false;
                }
                cleaned.push_str("// ");
                cleaned.push_str(trimmed);
                cleaned.push('\n');
                continue;
            }

            // Preserve important code structure
            if is_important_code_line(trimmed) {
                // Keep indentation context (simplified): halve the original
                // indent and cap it at 4 spaces.
                let indent_level = line.len() - line.trim_start().len();
                let normalized_indent = " ".repeat((indent_level / 2).min(4));
                cleaned.push_str(&normalized_indent);
                cleaned.push_str(trimmed);
                cleaned.push('\n');
            }
        }

        // Limit length (char-boundary safe; see truncate_with_ellipsis)
        const MAX_CODE_LENGTH: usize = 3000;
        truncate_with_ellipsis(cleaned, MAX_CODE_LENGTH)
    }

    /// Heuristically check whether text appears to be source code: true when
    /// at least 3 distinct code keywords/tokens occur in the (lowercased) text.
    fn is_code_content(text: &str) -> bool {
        let code_indicators = [
            "function",
            "const",
            "let",
            "var",
            "def",
            "class",
            "import",
            "export",
            "public",
            "private",
            "protected",
            "return",
            "if (",
            "for (",
            "while (",
            "=>",
            "->",
            "::",
            "<?php",
            "#!/",
            "package",
            "namespace",
            "struct",
        ];

        let text_lower = text.to_lowercase();
        let indicator_count = code_indicators
            .iter()
            .filter(|&&ind| text_lower.contains(ind))
            .count();

        // If multiple code indicators found, likely code
        indicator_count >= 3
    }

    /// Check if a line is important for code context: declarations, imports,
    /// doc comments, and anything that is not pure punctuation.
    fn is_important_code_line(line: &str) -> bool {
        // Skip pure comments unless they're doc comments (/// or //!)
        if line.starts_with("//") && !line.starts_with("///") && !line.starts_with("//!") {
            return false;
        }

        // Keep imports, function definitions, class definitions, etc.
        let important_patterns = [
            "import ",
            "from ",
            "require",
            "include",
            "function ",
            "def ",
            "fn ",
            "func ",
            "class ",
            "struct ",
            "interface ",
            "enum ",
            "public ",
            "private ",
            "protected ",
            "export ",
            "module ",
            "namespace ",
        ];

        for pattern in &important_patterns {
            if line.contains(pattern) {
                return true;
            }
        }

        // Keep lines with actual code (not just brackets/punctuation)
        !line
            .chars()
            .all(|c| c == '{' || c == '}' || c == '(' || c == ')' || c == ';' || c.is_whitespace())
    }

    /// Split text into chunks suitable for embedding.
    ///
    /// `chunk_size` is measured in whitespace-separated words and `overlap`
    /// is the number of words shared between consecutive chunks. Code input
    /// is routed to a structure-aware chunker.
    pub fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
        if is_code_content(text) {
            chunk_code_text(text, chunk_size, overlap)
        } else {
            chunk_regular_text(text, chunk_size, overlap)
        }
    }

    /// Regular text chunking by words with word overlap between chunks.
    ///
    /// Unlike the previous implementation, this cannot underflow or loop
    /// forever when `overlap >= chunk_size` or `chunk_size == 0`: the window
    /// always advances by at least one word. Behavior for the normal case
    /// (`overlap < chunk_size`) is unchanged.
    fn chunk_regular_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
        let words: Vec<&str> = text.split_whitespace().collect();

        // Degenerate or small input: a single chunk with the original text.
        if chunk_size == 0 || words.len() <= chunk_size {
            return vec![text.to_string()];
        }

        let mut chunks = Vec::new();
        let mut start = 0;
        while start < words.len() {
            let end = std::cmp::min(start + chunk_size, words.len());
            chunks.push(words[start..end].join(" "));

            if end == words.len() {
                break;
            }

            // Guaranteed forward progress: `end - overlap` for sane overlap,
            // clamped to at least one word per step otherwise.
            start = end.saturating_sub(overlap).max(start + 1);
        }

        chunks
    }

    /// Code-aware chunking that respects function/class boundaries.
    /// `_overlap` is accepted for signature parity but unused: overlapping
    /// code chunks would duplicate definitions.
    fn chunk_code_text(text: &str, chunk_size: usize, _overlap: usize) -> Vec<String> {
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        let mut current_size = 0;
        let mut brace_depth = 0;
        let mut in_function = false;

        for line in text.lines() {
            let trimmed = line.trim();

            // Detect function/class boundaries
            if trimmed.contains("function ")
                || trimmed.contains("def ")
                || trimmed.contains("class ")
                || trimmed.contains("fn ")
            {
                in_function = true;

                // If current chunk is large enough, save it before starting
                // the new definition.
                if current_size > chunk_size / 2 && brace_depth == 0 && !current_chunk.is_empty() {
                    chunks.push(current_chunk.clone());
                    current_chunk.clear();
                    current_size = 0;
                }
            }

            // Track brace depth for better chunking; clamp at 0 so stray
            // closing braces can't push the depth negative.
            brace_depth += trimmed.chars().filter(|&c| c == '{').count() as i32;
            brace_depth -= trimmed.chars().filter(|&c| c == '}').count() as i32;
            brace_depth = brace_depth.max(0);

            // Add line to current chunk (size counted in words)
            current_chunk.push_str(line);
            current_chunk.push('\n');
            current_size += line.split_whitespace().count();

            // Create new chunk when we hit size limit and are at a good boundary
            if current_size >= chunk_size && brace_depth == 0 && !in_function {
                chunks.push(current_chunk.clone());
                current_chunk.clear();
                current_size = 0;
            }

            // Reset function flag when we exit a function
            if in_function && brace_depth == 0 && trimmed.ends_with('}') {
                in_function = false;
            }
        }

        // Add remaining content
        if !current_chunk.trim().is_empty() {
            chunks.push(current_chunk);
        }

        // If no chunks were created, fall back to regular chunking
        if chunks.is_empty() {
            return chunk_regular_text(text, chunk_size, chunk_size / 10);
        }

        chunks
    }
}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test for the default (hash) provider: embedding a non-empty
    /// string yields a 384-dimensional, non-all-zero vector.
    #[tokio::test]
    async fn test_embedding_model() {
        let model = EmbeddingModel::new().await.unwrap();

        let text = "This is a test sentence for embedding.";
        let embedding = model.embed_text(text).await.unwrap();

        assert_eq!(embedding.len(), 384); // Hash provider default
        assert!(embedding.iter().any(|&x| x != 0.0));
    }

    /// Cosine similarity sanity checks: identical vectors -> 1.0,
    /// opposite vectors -> -1.0 (within float tolerance).
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let similarity = EmbeddingModel::cosine_similarity(&a, &b);
        assert!((similarity - 1.0).abs() < 0.001);

        let c = vec![-1.0, -2.0, -3.0];
        let similarity2 = EmbeddingModel::cosine_similarity(&a, &c);
        assert!((similarity2 + 1.0).abs() < 0.001);
    }

    /// Whitespace normalization: runs of spaces/newlines collapse to single
    /// spaces and the result is trimmed.
    #[test]
    fn test_text_preprocessing() {
        let text = "  This is   a test\n\n  with  multiple   lines  \n  ";
        let cleaned = preprocessing::clean_text(text);
        assert_eq!(cleaned, "This is a test with multiple lines");
    }

    /// Word chunking with overlap: 10 words, chunk_size=3, overlap=1 yields
    /// 5 chunks where each chunk repeats the last word of the previous one.
    #[test]
    fn test_text_chunking() {
        let text = "one two three four five six seven eight nine ten";
        let chunks = preprocessing::chunk_text(text, 3, 1);

        assert_eq!(chunks.len(), 5);
        assert_eq!(chunks[0], "one two three");
        assert_eq!(chunks[1], "three four five");
        assert_eq!(chunks[2], "five six seven");
        assert_eq!(chunks[3], "seven eight nine");
        assert_eq!(chunks[4], "nine ten");
    }

    /// Relative-similarity check on the hash provider: texts sharing the
    /// same tokens should score higher than an unrelated text.
    #[tokio::test]
    async fn test_similarity_detection() {
        let model = EmbeddingModel::new().await.unwrap();

        let text1 = "React hooks useState";
        let text2 = "useState React hooks";
        let text3 = "Python Django models";

        let emb1 = model.embed_text(text1).await.unwrap();
        let emb2 = model.embed_text(text2).await.unwrap();
        let emb3 = model.embed_text(text3).await.unwrap();

        let sim_12 = EmbeddingModel::cosine_similarity(&emb1, &emb2);
        let sim_13 = EmbeddingModel::cosine_similarity(&emb1, &emb3);

        // Similar texts should have higher similarity
        assert!(sim_12 > sim_13);
    }
}