manx_cli/rag/embeddings.rs

1//! Text embedding generation using configurable embedding providers
2//!
3//! This module provides flexible text embedding functionality for semantic similarity search.
4//! Supports multiple embedding providers: hash-based (default), local models, and API services.
5//! Users can configure their preferred embedding method via `manx config --embedding-provider`.
6
7use crate::rag::providers::{
8    custom, hash, huggingface, ollama, onnx, openai, EmbeddingProvider as ProviderTrait,
9};
10use crate::rag::{EmbeddingConfig, EmbeddingProvider};
11use anyhow::{anyhow, Result};
12
/// Text embedding model wrapper with configurable providers
/// Supports hash-based embeddings (default), local ONNX models, and API services.
/// Users can configure their preferred embedding method via `manx config`.
pub struct EmbeddingModel {
    // Boxed trait object so any backend (hash, ONNX, Ollama, OpenAI,
    // HuggingFace, custom HTTP endpoint) can be selected at runtime;
    // `Send + Sync` lets the model be shared across async tasks.
    provider: Box<dyn ProviderTrait + Send + Sync>,
    // Configuration this model was built from; exposed via `get_config()`.
    config: EmbeddingConfig,
}
20
21impl EmbeddingModel {
22    /// Create a new embedding model with default hash-based provider
23    pub async fn new() -> Result<Self> {
24        Self::new_with_config(EmbeddingConfig::default()).await
25    }
26
27    /// Create embedding model with smart auto-selection of best available provider
28    pub async fn new_auto_select() -> Result<Self> {
29        let best_config = Self::auto_select_best_provider().await?;
30        Self::new_with_config(best_config).await
31    }
32
33    /// Create a new embedding model with custom configuration
34    pub async fn new_with_config(config: EmbeddingConfig) -> Result<Self> {
35        log::info!(
36            "Initializing embedding model with provider: {:?}",
37            config.provider
38        );
39
40        let provider: Box<dyn ProviderTrait + Send + Sync> = match &config.provider {
41            EmbeddingProvider::Hash => {
42                log::info!("Using hash-based embeddings (default provider)");
43                Box::new(hash::HashProvider::new(384)) // Hash provider always uses 384 dimensions
44            }
45            EmbeddingProvider::Onnx(model_name) => {
46                log::info!("Loading ONNX model: {}", model_name);
47                let onnx_provider = onnx::OnnxProvider::new(model_name).await?;
48                Box::new(onnx_provider)
49            }
50            EmbeddingProvider::Ollama(model_name) => {
51                log::info!("Connecting to Ollama model: {}", model_name);
52                let ollama_provider =
53                    ollama::OllamaProvider::new(model_name.clone(), config.endpoint.clone());
54                // Test connection
55                ollama_provider.health_check().await?;
56                Box::new(ollama_provider)
57            }
58            EmbeddingProvider::OpenAI(model_name) => {
59                log::info!("Connecting to OpenAI model: {}", model_name);
60                let api_key = config.api_key.as_ref().ok_or_else(|| {
61                    anyhow!("OpenAI API key required. Use 'manx config --embedding-api-key <key>'")
62                })?;
63                let openai_provider =
64                    openai::OpenAiProvider::new(api_key.clone(), model_name.clone());
65                Box::new(openai_provider)
66            }
67            EmbeddingProvider::HuggingFace(model_name) => {
68                log::info!("Connecting to HuggingFace model: {}", model_name);
69                let api_key = config.api_key.as_ref().ok_or_else(|| {
70                    anyhow!(
71                        "HuggingFace API key required. Use 'manx config --embedding-api-key <key>'"
72                    )
73                })?;
74                let hf_provider =
75                    huggingface::HuggingFaceProvider::new(api_key.clone(), model_name.clone());
76                Box::new(hf_provider)
77            }
78            EmbeddingProvider::Custom(endpoint) => {
79                log::info!("Connecting to custom endpoint: {}", endpoint);
80                let custom_provider =
81                    custom::CustomProvider::new(endpoint.clone(), config.api_key.clone());
82                Box::new(custom_provider)
83            }
84        };
85
86        Ok(Self { provider, config })
87    }
88
89    /// Generate embeddings for a single text using configured provider
90    pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
91        if text.trim().is_empty() {
92            return Err(anyhow!("Cannot embed empty text"));
93        }
94
95        self.provider.embed_text(text).await
96    }
97
98    /// Get the dimension of embeddings produced by this model
99    pub async fn get_dimension(&self) -> Result<usize> {
100        self.provider.get_dimension().await
101    }
102
103    /// Test if the embedding model is working correctly
104    pub async fn health_check(&self) -> Result<()> {
105        self.provider.health_check().await
106    }
107
108    /// Get information about the current provider
109    pub fn get_provider_info(&self) -> crate::rag::providers::ProviderInfo {
110        self.provider.get_info()
111    }
112
113    /// Get the current configuration
114    pub fn get_config(&self) -> &EmbeddingConfig {
115        &self.config
116    }
117
118    /// Calculate cosine similarity between two embeddings
119    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
120        if a.len() != b.len() {
121            return 0.0;
122        }
123
124        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
125        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
126        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
127
128        if norm_a == 0.0 || norm_b == 0.0 {
129            0.0
130        } else {
131            dot_product / (norm_a * norm_b)
132        }
133    }
134
135    /// Automatically select the best available embedding provider from installed models
136    /// Respects user's installed models and doesn't hardcode specific model names
137    pub async fn auto_select_best_provider() -> Result<EmbeddingConfig> {
138        log::info!("Auto-selecting best available embedding provider from installed models...");
139
140        // Try to find any available ONNX models by checking common model directories
141        if let Ok(available_models) = Self::get_available_onnx_models().await {
142            if !available_models.is_empty() {
143                // Select the first available model (user chose to install it)
144                let selected_model = &available_models[0];
145                log::info!("Auto-selected installed ONNX model: {}", selected_model);
146
147                // Try to determine dimension by testing the model
148                if let Ok(test_config) = Self::create_config_for_model(selected_model).await {
149                    return Ok(test_config);
150                }
151            }
152        }
153
154        // Fallback to hash-based embeddings if no ONNX models available
155        log::info!("No ONNX models found, using hash-based embeddings");
156        Ok(EmbeddingConfig::default())
157    }
158
159    /// Get list of available ONNX models (non-hardcoded discovery)
160    async fn get_available_onnx_models() -> Result<Vec<String>> {
161        // This would typically scan the model cache directory
162        // For now, we'll try a few common models that might be installed
163        let potential_models = [
164            "sentence-transformers/all-MiniLM-L6-v2",
165            "sentence-transformers/all-mpnet-base-v2",
166            "BAAI/bge-base-en-v1.5",
167            "BAAI/bge-small-en-v1.5",
168            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
169        ];
170
171        let mut available = Vec::new();
172        for model in &potential_models {
173            if Self::is_onnx_model_available(model).await {
174                available.push(model.to_string());
175            }
176        }
177
178        Ok(available)
179    }
180
181    /// Create config for a specific model with proper dimension detection
182    async fn create_config_for_model(model_name: &str) -> Result<EmbeddingConfig> {
183        // Test the model to get its dimension
184        match onnx::OnnxProvider::new(model_name).await {
185            Ok(provider) => {
186                let dimension = provider.get_dimension().await.unwrap_or(384);
187                Ok(EmbeddingConfig {
188                    provider: EmbeddingProvider::Onnx(model_name.to_string()),
189                    dimension,
190                    ..EmbeddingConfig::default()
191                })
192            }
193            Err(e) => Err(anyhow!(
194                "Failed to create config for model {}: {}",
195                model_name,
196                e
197            )),
198        }
199    }
200
201    /// Check if an ONNX model is available locally
202    async fn is_onnx_model_available(model_name: &str) -> bool {
203        // Try to create the provider to test availability
204        match onnx::OnnxProvider::new(model_name).await {
205            Ok(_) => {
206                log::debug!("ONNX model '{}' is available", model_name);
207                true
208            }
209            Err(e) => {
210                log::debug!("ONNX model '{}' not available: {}", model_name, e);
211                false
212            }
213        }
214    }
215}
216
/// Utility functions for text preprocessing
pub mod preprocessing {
    /// Clean and normalize text for embedding.
    ///
    /// Routes to code-aware cleaning when the text looks like source code
    /// (see `is_code_content`), otherwise to plain-text cleaning.
    pub fn clean_text(text: &str) -> String {
        if is_code_content(text) {
            clean_code_text(text)
        } else {
            clean_regular_text(text)
        }
    }

    /// Clean regular text (documents, markdown, etc.): trim every line, drop
    /// empty lines, collapse whitespace runs to single spaces, and cap the
    /// result length so embeddings stay bounded.
    fn clean_regular_text(text: &str) -> String {
        let cleaned = text
            .lines()
            .map(|line| line.trim())
            .filter(|line| !line.is_empty())
            .collect::<Vec<_>>()
            .join(" ")
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(" ");

        // Limit length to prevent very long embeddings (char-based, so
        // UTF-8 boundaries are respected).
        const MAX_CHARS: usize = 2048;
        if cleaned.chars().count() > MAX_CHARS {
            let truncated: String = cleaned.chars().take(MAX_CHARS).collect();
            format!("{}...", truncated)
        } else {
            cleaned
        }
    }

    /// Clean code text while preserving structure: keeps comment blocks
    /// (normalized to `// ` lines) and structurally important lines, with a
    /// simplified, normalized indentation.
    fn clean_code_text(text: &str) -> String {
        let mut cleaned = String::new();
        let mut in_comment_block = false;

        for line in text.lines() {
            let trimmed = line.trim();

            // Skip empty lines in code (but keep a leading empty line out
            // of consideration until something has been emitted).
            if trimmed.is_empty() && !cleaned.is_empty() {
                continue;
            }

            // Handle `/* ... */` comment blocks: every line inside (including
            // the delimiters themselves) is re-emitted behind a `// ` prefix.
            // NOTE(review): the opening line keeps its `/*` after the `// `
            // prefix — presumably acceptable for embedding purposes.
            if trimmed.starts_with("/*") {
                in_comment_block = true;
            }
            if in_comment_block {
                if trimmed.ends_with("*/") {
                    in_comment_block = false;
                }
                cleaned.push_str("// ");
                cleaned.push_str(trimmed);
                cleaned.push('\n');
                continue;
            }

            // Preserve important code structure only.
            if is_important_code_line(trimmed) {
                // Keep indentation context (simplified): halve the original
                // indent and cap it at 4 spaces.
                let indent_level = line.len() - line.trim_start().len();
                let normalized_indent = " ".repeat((indent_level / 2).min(4));
                cleaned.push_str(&normalized_indent);
                cleaned.push_str(trimmed);
                cleaned.push('\n');
            }
        }

        // Limit length (char-based, UTF-8 safe); code gets a larger budget
        // than prose.
        const MAX_CODE_CHARS: usize = 3000;
        if cleaned.chars().count() > MAX_CODE_CHARS {
            let truncated: String = cleaned.chars().take(MAX_CODE_CHARS).collect();
            format!("{}...", truncated)
        } else {
            cleaned
        }
    }

    /// Heuristic: does this text appear to be source code?
    ///
    /// Counts occurrences of common code keywords/operators (case-insensitive)
    /// and treats three or more distinct hits as code.
    fn is_code_content(text: &str) -> bool {
        let code_indicators = [
            "function",
            "const",
            "let",
            "var",
            "def",
            "class",
            "import",
            "export",
            "public",
            "private",
            "protected",
            "return",
            "if (",
            "for (",
            "while (",
            "=>",
            "->",
            "::",
            "<?php",
            "#!/",
            "package",
            "namespace",
            "struct",
        ];

        let text_lower = text.to_lowercase();
        let indicator_count = code_indicators
            .iter()
            .filter(|&&ind| text_lower.contains(ind))
            .count();

        indicator_count >= 3
    }

    /// Is this line worth keeping for code context?
    ///
    /// Drops plain `//` comments (doc comments `///` and `//!` are kept),
    /// keeps declaration-like lines, and keeps any line that contains more
    /// than pure punctuation/brackets.
    fn is_important_code_line(line: &str) -> bool {
        // Skip pure comments unless they're doc comments.
        if line.starts_with("//") && !line.starts_with("///") && !line.starts_with("//!") {
            return false;
        }

        // Keep imports, function definitions, class definitions, etc.
        let important_patterns = [
            "import ",
            "from ",
            "require",
            "include",
            "function ",
            "def ",
            "fn ",
            "func ",
            "class ",
            "struct ",
            "interface ",
            "enum ",
            "public ",
            "private ",
            "protected ",
            "export ",
            "module ",
            "namespace ",
        ];

        for pattern in &important_patterns {
            if line.contains(pattern) {
                return true;
            }
        }

        // Keep lines with actual code (not just brackets/terminators).
        !line
            .chars()
            .all(|c| c == '{' || c == '}' || c == '(' || c == ')' || c == ';' || c.is_whitespace())
    }

    /// Split text into chunks suitable for embedding.
    ///
    /// `chunk_size` and `overlap` are measured in words. Code input is
    /// chunked along function/class boundaries; prose is chunked by a
    /// sliding word window.
    pub fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
        if is_code_content(text) {
            chunk_code_text(text, chunk_size, overlap)
        } else {
            chunk_regular_text(text, chunk_size, overlap)
        }
    }

    /// Regular text chunking: sliding window of `chunk_size` words with
    /// `overlap` words shared between consecutive chunks.
    fn chunk_regular_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
        let words: Vec<&str> = text.split_whitespace().collect();
        let mut chunks = Vec::new();

        // BUGFIX: guard degenerate parameters. With `overlap >= chunk_size`
        // (or `chunk_size == 0`) the window start `end - overlap` never
        // advances, looping forever while pushing chunks. Clamp to the
        // nearest sane values; valid parameters are unaffected.
        let chunk_size = chunk_size.max(1);
        let overlap = overlap.min(chunk_size - 1);

        if words.len() <= chunk_size {
            chunks.push(text.to_string());
            return chunks;
        }

        let mut start = 0;
        while start < words.len() {
            let end = std::cmp::min(start + chunk_size, words.len());
            let chunk = words[start..end].join(" ");
            chunks.push(chunk);

            if end == words.len() {
                break;
            }

            // Advances by at least one word because overlap < chunk_size.
            start = end - overlap;
        }

        chunks
    }

    /// Code-aware chunking that respects function/class boundaries.
    ///
    /// `_overlap` is accepted for signature symmetry but unused: code chunks
    /// are split at brace-depth-zero boundaries instead of overlapping.
    fn chunk_code_text(text: &str, chunk_size: usize, _overlap: usize) -> Vec<String> {
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        let mut current_size = 0; // size measured in words
        let mut brace_depth = 0;
        let mut in_function = false;

        for line in text.lines() {
            let trimmed = line.trim();

            // Detect function/class boundaries.
            if trimmed.contains("function ")
                || trimmed.contains("def ")
                || trimmed.contains("class ")
                || trimmed.contains("fn ")
            {
                in_function = true;

                // Flush the current chunk if it is already big enough and we
                // are at top level (not mid-block).
                if current_size > chunk_size / 2 && brace_depth == 0 && !current_chunk.is_empty() {
                    chunks.push(current_chunk.clone());
                    current_chunk.clear();
                    current_size = 0;
                }
            }

            // Track brace depth for better chunk boundaries; clamp at zero so
            // stray closing braces can't push the depth negative.
            brace_depth += trimmed.chars().filter(|&c| c == '{').count() as i32;
            brace_depth -= trimmed.chars().filter(|&c| c == '}').count() as i32;
            brace_depth = brace_depth.max(0);

            // Add line to current chunk.
            current_chunk.push_str(line);
            current_chunk.push('\n');
            current_size += line.split_whitespace().count();

            // Create new chunk when we hit the size limit at a good boundary.
            if current_size >= chunk_size && brace_depth == 0 && !in_function {
                chunks.push(current_chunk.clone());
                current_chunk.clear();
                current_size = 0;
            }

            // Reset the function flag when we exit a function body.
            if in_function && brace_depth == 0 && trimmed.ends_with('}') {
                in_function = false;
            }
        }

        // Add remaining content.
        if !current_chunk.trim().is_empty() {
            chunks.push(current_chunk);
        }

        // If no chunks were created, fall back to regular chunking.
        if chunks.is_empty() {
            return chunk_regular_text(text, chunk_size, chunk_size / 10);
        }

        chunks
    }
}
478
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_embedding_model() {
        let model = EmbeddingModel::new().await.unwrap();
        let embedding = model
            .embed_text("This is a test sentence for embedding.")
            .await
            .unwrap();

        // Default hash provider always emits 384-dimensional vectors.
        assert_eq!(embedding.len(), 384);
        // At least one component must be non-zero for non-empty input.
        assert!(embedding.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn test_cosine_similarity() {
        let base = vec![1.0, 2.0, 3.0];

        // Identical vectors -> similarity ~ 1.
        let identical = vec![1.0, 2.0, 3.0];
        assert!((EmbeddingModel::cosine_similarity(&base, &identical) - 1.0).abs() < 0.001);

        // Opposite vectors -> similarity ~ -1.
        let opposite = vec![-1.0, -2.0, -3.0];
        assert!((EmbeddingModel::cosine_similarity(&base, &opposite) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_text_preprocessing() {
        let cleaned =
            preprocessing::clean_text("  This is   a test\n\n  with  multiple   lines  \n  ");
        assert_eq!(cleaned, "This is a test with multiple lines");
    }

    #[test]
    fn test_text_chunking() {
        let chunks =
            preprocessing::chunk_text("one two three four five six seven eight nine ten", 3, 1);

        let expected = [
            "one two three",
            "three four five",
            "five six seven",
            "seven eight nine",
            "nine ten",
        ];
        assert_eq!(chunks.len(), expected.len());
        for (chunk, want) in chunks.iter().zip(expected.iter()) {
            assert_eq!(chunk, want);
        }
    }

    #[tokio::test]
    async fn test_similarity_detection() {
        let model = EmbeddingModel::new().await.unwrap();

        let emb1 = model.embed_text("React hooks useState").await.unwrap();
        let emb2 = model.embed_text("useState React hooks").await.unwrap();
        let emb3 = model.embed_text("Python Django models").await.unwrap();

        // Reordered words should score closer than an unrelated topic.
        assert!(
            EmbeddingModel::cosine_similarity(&emb1, &emb2)
                > EmbeddingModel::cosine_similarity(&emb1, &emb3)
        );
    }
}