use anyhow::{anyhow, Result};

use crate::rag::providers::{
    custom, hash, huggingface, ollama, onnx, openai, EmbeddingProvider as ProviderTrait,
};
use crate::rag::{EmbeddingConfig, EmbeddingProvider};

13pub struct EmbeddingModel {
17 provider: Box<dyn ProviderTrait + Send + Sync>,
18 config: EmbeddingConfig,
19}
20
21impl EmbeddingModel {
22 pub async fn new() -> Result<Self> {
24 Self::new_with_config(EmbeddingConfig::default()).await
25 }
26
27 pub async fn new_with_config(config: EmbeddingConfig) -> Result<Self> {
29 log::info!(
30 "Initializing embedding model with provider: {:?}",
31 config.provider
32 );
33
34 let provider: Box<dyn ProviderTrait + Send + Sync> = match &config.provider {
35 EmbeddingProvider::Hash => {
36 log::info!("Using hash-based embeddings (default provider)");
37 Box::new(hash::HashProvider::new(384)) }
39 EmbeddingProvider::Onnx(model_name) => {
40 log::info!("Loading ONNX model: {}", model_name);
41 let onnx_provider = onnx::OnnxProvider::new(model_name).await?;
42 Box::new(onnx_provider)
43 }
44 EmbeddingProvider::Ollama(model_name) => {
45 log::info!("Connecting to Ollama model: {}", model_name);
46 let ollama_provider =
47 ollama::OllamaProvider::new(model_name.clone(), config.endpoint.clone());
48 ollama_provider.health_check().await?;
50 Box::new(ollama_provider)
51 }
52 EmbeddingProvider::OpenAI(model_name) => {
53 log::info!("Connecting to OpenAI model: {}", model_name);
54 let api_key = config.api_key.as_ref().ok_or_else(|| {
55 anyhow!("OpenAI API key required. Use 'manx config --embedding-api-key <key>'")
56 })?;
57 let openai_provider =
58 openai::OpenAiProvider::new(api_key.clone(), model_name.clone());
59 Box::new(openai_provider)
60 }
61 EmbeddingProvider::HuggingFace(model_name) => {
62 log::info!("Connecting to HuggingFace model: {}", model_name);
63 let api_key = config.api_key.as_ref().ok_or_else(|| {
64 anyhow!(
65 "HuggingFace API key required. Use 'manx config --embedding-api-key <key>'"
66 )
67 })?;
68 let hf_provider =
69 huggingface::HuggingFaceProvider::new(api_key.clone(), model_name.clone());
70 Box::new(hf_provider)
71 }
72 EmbeddingProvider::Custom(endpoint) => {
73 log::info!("Connecting to custom endpoint: {}", endpoint);
74 let custom_provider =
75 custom::CustomProvider::new(endpoint.clone(), config.api_key.clone());
76 Box::new(custom_provider)
77 }
78 };
79
80 Ok(Self { provider, config })
81 }
82
83 pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
85 if text.trim().is_empty() {
86 return Err(anyhow!("Cannot embed empty text"));
87 }
88
89 self.provider.embed_text(text).await
90 }
91
92 pub async fn get_dimension(&self) -> Result<usize> {
94 self.provider.get_dimension().await
95 }
96
97 pub async fn health_check(&self) -> Result<()> {
99 self.provider.health_check().await
100 }
101
102 pub fn get_provider_info(&self) -> crate::rag::providers::ProviderInfo {
104 self.provider.get_info()
105 }
106
107 pub fn get_config(&self) -> &EmbeddingConfig {
109 &self.config
110 }
111
112 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
114 if a.len() != b.len() {
115 return 0.0;
116 }
117
118 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
119 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
120 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
121
122 if norm_a == 0.0 || norm_b == 0.0 {
123 0.0
124 } else {
125 dot_product / (norm_a * norm_b)
126 }
127 }
128}
129
130pub mod preprocessing {
132 pub fn clean_text(text: &str) -> String {
134 let cleaned = text
136 .lines()
137 .map(|line| line.trim())
138 .filter(|line| !line.is_empty())
139 .collect::<Vec<_>>()
140 .join(" ")
141 .split_whitespace()
142 .collect::<Vec<_>>()
143 .join(" ");
144
145 const MAX_LENGTH: usize = 2048;
147 if cleaned.len() > MAX_LENGTH {
148 format!("{}...", &cleaned[..MAX_LENGTH])
149 } else {
150 cleaned
151 }
152 }
153
154 pub fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
156 let words: Vec<&str> = text.split_whitespace().collect();
157 let mut chunks = Vec::new();
158
159 if words.len() <= chunk_size {
160 chunks.push(text.to_string());
161 return chunks;
162 }
163
164 let mut start = 0;
165 while start < words.len() {
166 let end = std::cmp::min(start + chunk_size, words.len());
167 let chunk = words[start..end].join(" ");
168 chunks.push(chunk);
169
170 if end == words.len() {
171 break;
172 }
173
174 start = end - overlap;
175 }
176
177 chunks
178 }
179}
180
181#[cfg(test)]
182mod tests {
183 use super::*;
184
185 #[tokio::test]
186 async fn test_embedding_model() {
187 let model = EmbeddingModel::new().await.unwrap();
188
189 let text = "This is a test sentence for embedding.";
190 let embedding = model.embed_text(text).await.unwrap();
191
192 assert_eq!(embedding.len(), 384); assert!(embedding.iter().any(|&x| x != 0.0));
194 }
195
196 #[test]
197 fn test_cosine_similarity() {
198 let a = vec![1.0, 2.0, 3.0];
199 let b = vec![1.0, 2.0, 3.0];
200 let similarity = EmbeddingModel::cosine_similarity(&a, &b);
201 assert!((similarity - 1.0).abs() < 0.001);
202
203 let c = vec![-1.0, -2.0, -3.0];
204 let similarity2 = EmbeddingModel::cosine_similarity(&a, &c);
205 assert!((similarity2 + 1.0).abs() < 0.001);
206 }
207
208 #[test]
209 fn test_text_preprocessing() {
210 let text = " This is a test\n\n with multiple lines \n ";
211 let cleaned = preprocessing::clean_text(text);
212 assert_eq!(cleaned, "This is a test with multiple lines");
213 }
214
215 #[test]
216 fn test_text_chunking() {
217 let text = "one two three four five six seven eight nine ten";
218 let chunks = preprocessing::chunk_text(text, 3, 1);
219
220 assert_eq!(chunks.len(), 5);
221 assert_eq!(chunks[0], "one two three");
222 assert_eq!(chunks[1], "three four five");
223 assert_eq!(chunks[2], "five six seven");
224 assert_eq!(chunks[3], "seven eight nine");
225 assert_eq!(chunks[4], "nine ten");
226 }
227
228 #[tokio::test]
229 async fn test_similarity_detection() {
230 let model = EmbeddingModel::new().await.unwrap();
231
232 let text1 = "React hooks useState";
233 let text2 = "useState React hooks";
234 let text3 = "Python Django models";
235
236 let emb1 = model.embed_text(text1).await.unwrap();
237 let emb2 = model.embed_text(text2).await.unwrap();
238 let emb3 = model.embed_text(text3).await.unwrap();
239
240 let sim_12 = EmbeddingModel::cosine_similarity(&emb1, &emb2);
241 let sim_13 = EmbeddingModel::cosine_similarity(&emb1, &emb3);
242
243 assert!(sim_12 > sim_13);
245 }
246}