// manx_cli/rag/embeddings.rs

1//! Text embedding generation using configurable embedding providers
2//!
3//! This module provides flexible text embedding functionality for semantic similarity search.
4//! Supports multiple embedding providers: hash-based (default), local models, and API services.
5//! Users can configure their preferred embedding method via `manx config --embedding-provider`.
6
7use crate::rag::providers::{
8    custom, hash, huggingface, ollama, onnx, openai, EmbeddingProvider as ProviderTrait,
9};
10use crate::rag::{EmbeddingConfig, EmbeddingProvider};
11use anyhow::{anyhow, Result};
12
/// Text embedding model wrapper with configurable providers.
///
/// Supports hash-based embeddings (default), local ONNX models, and API services.
/// Users can configure their preferred embedding method via `manx config`.
pub struct EmbeddingModel {
    // Concrete backend selected at construction time; trait object so any
    // provider (hash/ONNX/Ollama/OpenAI/HuggingFace/custom) can be swapped in.
    provider: Box<dyn ProviderTrait + Send + Sync>,
    // Configuration the provider was built from; exposed via `get_config`.
    config: EmbeddingConfig,
}
20
21impl EmbeddingModel {
22    /// Create a new embedding model with default hash-based provider
23    pub async fn new() -> Result<Self> {
24        Self::new_with_config(EmbeddingConfig::default()).await
25    }
26
27    /// Create a new embedding model with custom configuration
28    pub async fn new_with_config(config: EmbeddingConfig) -> Result<Self> {
29        log::info!(
30            "Initializing embedding model with provider: {:?}",
31            config.provider
32        );
33
34        let provider: Box<dyn ProviderTrait + Send + Sync> = match &config.provider {
35            EmbeddingProvider::Hash => {
36                log::info!("Using hash-based embeddings (default provider)");
37                Box::new(hash::HashProvider::new(384)) // Hash provider always uses 384 dimensions
38            }
39            EmbeddingProvider::Onnx(model_name) => {
40                log::info!("Loading ONNX model: {}", model_name);
41                let onnx_provider = onnx::OnnxProvider::new(model_name).await?;
42                Box::new(onnx_provider)
43            }
44            EmbeddingProvider::Ollama(model_name) => {
45                log::info!("Connecting to Ollama model: {}", model_name);
46                let ollama_provider =
47                    ollama::OllamaProvider::new(model_name.clone(), config.endpoint.clone());
48                // Test connection
49                ollama_provider.health_check().await?;
50                Box::new(ollama_provider)
51            }
52            EmbeddingProvider::OpenAI(model_name) => {
53                log::info!("Connecting to OpenAI model: {}", model_name);
54                let api_key = config.api_key.as_ref().ok_or_else(|| {
55                    anyhow!("OpenAI API key required. Use 'manx config --embedding-api-key <key>'")
56                })?;
57                let openai_provider =
58                    openai::OpenAiProvider::new(api_key.clone(), model_name.clone());
59                Box::new(openai_provider)
60            }
61            EmbeddingProvider::HuggingFace(model_name) => {
62                log::info!("Connecting to HuggingFace model: {}", model_name);
63                let api_key = config.api_key.as_ref().ok_or_else(|| {
64                    anyhow!(
65                        "HuggingFace API key required. Use 'manx config --embedding-api-key <key>'"
66                    )
67                })?;
68                let hf_provider =
69                    huggingface::HuggingFaceProvider::new(api_key.clone(), model_name.clone());
70                Box::new(hf_provider)
71            }
72            EmbeddingProvider::Custom(endpoint) => {
73                log::info!("Connecting to custom endpoint: {}", endpoint);
74                let custom_provider =
75                    custom::CustomProvider::new(endpoint.clone(), config.api_key.clone());
76                Box::new(custom_provider)
77            }
78        };
79
80        Ok(Self { provider, config })
81    }
82
83    /// Generate embeddings for a single text using configured provider
84    pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
85        if text.trim().is_empty() {
86            return Err(anyhow!("Cannot embed empty text"));
87        }
88
89        self.provider.embed_text(text).await
90    }
91
92    /// Get the dimension of embeddings produced by this model
93    pub async fn get_dimension(&self) -> Result<usize> {
94        self.provider.get_dimension().await
95    }
96
97    /// Test if the embedding model is working correctly
98    pub async fn health_check(&self) -> Result<()> {
99        self.provider.health_check().await
100    }
101
102    /// Get information about the current provider
103    pub fn get_provider_info(&self) -> crate::rag::providers::ProviderInfo {
104        self.provider.get_info()
105    }
106
107    /// Get the current configuration
108    pub fn get_config(&self) -> &EmbeddingConfig {
109        &self.config
110    }
111
112    /// Calculate cosine similarity between two embeddings
113    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
114        if a.len() != b.len() {
115            return 0.0;
116        }
117
118        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
119        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
120        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
121
122        if norm_a == 0.0 || norm_b == 0.0 {
123            0.0
124        } else {
125            dot_product / (norm_a * norm_b)
126        }
127    }
128}
129
/// Utility functions for text preprocessing
pub mod preprocessing {
    /// Clean and normalize text for embedding.
    ///
    /// Collapses every run of whitespace (spaces, tabs, newlines) into a
    /// single space and trims the ends, then truncates the result to at
    /// most `MAX_LENGTH` bytes plus a `...` marker. Truncation backs off
    /// to the nearest UTF-8 character boundary, so multi-byte text can
    /// never cause a slice panic.
    pub fn clean_text(text: &str) -> String {
        // split_whitespace + join collapses all whitespace runs (including
        // blank lines) and drops leading/trailing whitespace in one pass.
        let cleaned = text.split_whitespace().collect::<Vec<_>>().join(" ");

        // Limit length to prevent very long embeddings.
        const MAX_LENGTH: usize = 2048;
        if cleaned.len() > MAX_LENGTH {
            // `&cleaned[..MAX_LENGTH]` would panic mid-codepoint; walk back
            // to a valid char boundary before slicing.
            let mut end = MAX_LENGTH;
            while !cleaned.is_char_boundary(end) {
                end -= 1;
            }
            format!("{}...", &cleaned[..end])
        } else {
            cleaned
        }
    }

    /// Split text into overlapping word-based chunks suitable for embedding.
    ///
    /// Each chunk contains up to `chunk_size` words and consecutive chunks
    /// share `overlap` words. `overlap` is clamped to `chunk_size - 1` so
    /// the window always advances by at least one word — previously
    /// `overlap >= chunk_size` caused an infinite loop (or an integer
    /// underflow). A `chunk_size` of 0 returns the whole text as one chunk.
    pub fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
        let words: Vec<&str> = text.split_whitespace().collect();

        // Short input (or a degenerate chunk size): one chunk, verbatim.
        if chunk_size == 0 || words.len() <= chunk_size {
            return vec![text.to_string()];
        }

        // Clamp the overlap so each step moves forward by >= 1 word.
        let step = chunk_size - overlap.min(chunk_size - 1);

        let mut chunks = Vec::new();
        let mut start = 0;
        loop {
            let end = usize::min(start + chunk_size, words.len());
            chunks.push(words[start..end].join(" "));
            if end == words.len() {
                break;
            }
            start += step;
        }
        chunks
    }
}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Default model (hash provider) must yield a non-trivial 384-dim vector.
    #[tokio::test]
    async fn test_embedding_model() {
        let model = EmbeddingModel::new().await.unwrap();

        let text = "This is a test sentence for embedding.";
        let embedding = model.embed_text(text).await.unwrap();

        assert_eq!(embedding.len(), 384); // Hash provider default
        // At least one non-zero component — the embedding carries signal.
        assert!(embedding.iter().any(|&x| x != 0.0));
    }

    /// Identical vectors score ~1.0; exactly opposite vectors score ~-1.0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let similarity = EmbeddingModel::cosine_similarity(&a, &b);
        assert!((similarity - 1.0).abs() < 0.001);

        let c = vec![-1.0, -2.0, -3.0];
        let similarity2 = EmbeddingModel::cosine_similarity(&a, &c);
        assert!((similarity2 + 1.0).abs() < 0.001);
    }

    /// Whitespace runs, blank lines, and surrounding space all collapse.
    #[test]
    fn test_text_preprocessing() {
        let text = "  This is   a test\n\n  with  multiple   lines  \n  ";
        let cleaned = preprocessing::clean_text(text);
        assert_eq!(cleaned, "This is a test with multiple lines");
    }

    /// 10 words, chunk_size 3, overlap 1 → 5 chunks, each sharing one word
    /// with its predecessor.
    #[test]
    fn test_text_chunking() {
        let text = "one two three four five six seven eight nine ten";
        let chunks = preprocessing::chunk_text(text, 3, 1);

        assert_eq!(chunks.len(), 5);
        assert_eq!(chunks[0], "one two three");
        assert_eq!(chunks[1], "three four five");
        assert_eq!(chunks[2], "five six seven");
        assert_eq!(chunks[3], "seven eight nine");
        assert_eq!(chunks[4], "nine ten");
    }

    /// Word-permuted text should embed closer than unrelated text
    /// (sanity check that similarity ranking behaves sensibly).
    #[tokio::test]
    async fn test_similarity_detection() {
        let model = EmbeddingModel::new().await.unwrap();

        let text1 = "React hooks useState";
        let text2 = "useState React hooks";
        let text3 = "Python Django models";

        let emb1 = model.embed_text(text1).await.unwrap();
        let emb2 = model.embed_text(text2).await.unwrap();
        let emb3 = model.embed_text(text3).await.unwrap();

        let sim_12 = EmbeddingModel::cosine_similarity(&emb1, &emb2);
        let sim_13 = EmbeddingModel::cosine_similarity(&emb1, &emb3);

        // Similar texts should have higher similarity
        assert!(sim_12 > sim_13);
    }
}