1use crate::rag::providers::{
8 custom, hash, huggingface, ollama, onnx, openai, EmbeddingProvider as ProviderTrait,
9};
10use crate::rag::{EmbeddingConfig, EmbeddingProvider};
11use anyhow::{anyhow, Result};
12
/// High-level embedding interface that delegates to a concrete backend
/// selected at construction time (hash, ONNX, Ollama, OpenAI, HuggingFace,
/// or a custom HTTP endpoint).
pub struct EmbeddingModel {
    /// Active embedding backend, behind the shared provider trait.
    provider: Box<dyn ProviderTrait + Send + Sync>,
    /// Configuration the provider was built from (kept for introspection).
    config: EmbeddingConfig,
}
20
21impl EmbeddingModel {
    /// Creates an embedding model with the default configuration
    /// (hash-based embeddings — no external services or model files needed).
    pub async fn new() -> Result<Self> {
        Self::new_with_config(EmbeddingConfig::default()).await
    }

    /// Creates an embedding model by probing locally installed ONNX models,
    /// falling back to the default configuration when none are usable.
    pub async fn new_auto_select() -> Result<Self> {
        let best_config = Self::auto_select_best_provider().await?;
        Self::new_with_config(best_config).await
    }
32
33 pub async fn new_with_config(config: EmbeddingConfig) -> Result<Self> {
35 log::info!(
36 "Initializing embedding model with provider: {:?}",
37 config.provider
38 );
39
40 let provider: Box<dyn ProviderTrait + Send + Sync> = match &config.provider {
41 EmbeddingProvider::Hash => {
42 log::info!("Using hash-based embeddings (default provider)");
43 Box::new(hash::HashProvider::new(384)) }
45 EmbeddingProvider::Onnx(model_name) => {
46 log::info!("Loading ONNX model: {}", model_name);
47 let onnx_provider = onnx::OnnxProvider::new(model_name).await?;
48 Box::new(onnx_provider)
49 }
50 EmbeddingProvider::Ollama(model_name) => {
51 log::info!("Connecting to Ollama model: {}", model_name);
52 let ollama_provider =
53 ollama::OllamaProvider::new(model_name.clone(), config.endpoint.clone());
54 ollama_provider.health_check().await?;
56 Box::new(ollama_provider)
57 }
58 EmbeddingProvider::OpenAI(model_name) => {
59 log::info!("Connecting to OpenAI model: {}", model_name);
60 let api_key = config.api_key.as_ref().ok_or_else(|| {
61 anyhow!("OpenAI API key required. Use 'manx config --embedding-api-key <key>'")
62 })?;
63 let openai_provider =
64 openai::OpenAiProvider::new(api_key.clone(), model_name.clone());
65 Box::new(openai_provider)
66 }
67 EmbeddingProvider::HuggingFace(model_name) => {
68 log::info!("Connecting to HuggingFace model: {}", model_name);
69 let api_key = config.api_key.as_ref().ok_or_else(|| {
70 anyhow!(
71 "HuggingFace API key required. Use 'manx config --embedding-api-key <key>'"
72 )
73 })?;
74 let hf_provider =
75 huggingface::HuggingFaceProvider::new(api_key.clone(), model_name.clone());
76 Box::new(hf_provider)
77 }
78 EmbeddingProvider::Custom(endpoint) => {
79 log::info!("Connecting to custom endpoint: {}", endpoint);
80 let custom_provider =
81 custom::CustomProvider::new(endpoint.clone(), config.api_key.clone());
82 Box::new(custom_provider)
83 }
84 };
85
86 Ok(Self { provider, config })
87 }
88
89 pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
91 if text.trim().is_empty() {
92 return Err(anyhow!("Cannot embed empty text"));
93 }
94
95 self.provider.embed_text(text).await
96 }
97
    /// Returns the embedding dimension reported by the active provider.
    pub async fn get_dimension(&self) -> Result<usize> {
        self.provider.get_dimension().await
    }

    /// Verifies the active provider is reachable and operational.
    pub async fn health_check(&self) -> Result<()> {
        self.provider.health_check().await
    }

    /// Returns descriptive metadata about the active provider.
    pub fn get_provider_info(&self) -> crate::rag::providers::ProviderInfo {
        self.provider.get_info()
    }

    /// Returns the configuration this model was initialized with.
    pub fn get_config(&self) -> &EmbeddingConfig {
        &self.config
    }
117
118 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
120 if a.len() != b.len() {
121 return 0.0;
122 }
123
124 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
125 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
126 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
127
128 if norm_a == 0.0 || norm_b == 0.0 {
129 0.0
130 } else {
131 dot_product / (norm_a * norm_b)
132 }
133 }
134
135 pub async fn auto_select_best_provider() -> Result<EmbeddingConfig> {
138 log::info!("Auto-selecting best available embedding provider from installed models...");
139
140 if let Ok(available_models) = Self::get_available_onnx_models().await {
142 if !available_models.is_empty() {
143 let selected_model = &available_models[0];
145 log::info!("Auto-selected installed ONNX model: {}", selected_model);
146
147 if let Ok(test_config) = Self::create_config_for_model(selected_model).await {
149 return Ok(test_config);
150 }
151 }
152 }
153
154 log::info!("No ONNX models found, using hash-based embeddings");
156 Ok(EmbeddingConfig::default())
157 }
158
159 async fn get_available_onnx_models() -> Result<Vec<String>> {
161 let potential_models = [
164 "sentence-transformers/all-MiniLM-L6-v2",
165 "sentence-transformers/all-mpnet-base-v2",
166 "BAAI/bge-base-en-v1.5",
167 "BAAI/bge-small-en-v1.5",
168 "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
169 ];
170
171 let mut available = Vec::new();
172 for model in &potential_models {
173 if Self::is_onnx_model_available(model).await {
174 available.push(model.to_string());
175 }
176 }
177
178 Ok(available)
179 }
180
181 async fn create_config_for_model(model_name: &str) -> Result<EmbeddingConfig> {
183 match onnx::OnnxProvider::new(model_name).await {
185 Ok(provider) => {
186 let dimension = provider.get_dimension().await.unwrap_or(384);
187 Ok(EmbeddingConfig {
188 provider: EmbeddingProvider::Onnx(model_name.to_string()),
189 dimension,
190 ..EmbeddingConfig::default()
191 })
192 }
193 Err(e) => Err(anyhow!(
194 "Failed to create config for model {}: {}",
195 model_name,
196 e
197 )),
198 }
199 }
200
201 async fn is_onnx_model_available(model_name: &str) -> bool {
203 match onnx::OnnxProvider::new(model_name).await {
205 Ok(_) => {
206 log::debug!("ONNX model '{}' is available", model_name);
207 true
208 }
209 Err(e) => {
210 log::debug!("ONNX model '{}' not available: {}", model_name, e);
211 false
212 }
213 }
214 }
215}
216
217pub mod preprocessing {
219 pub fn clean_text(text: &str) -> String {
221 if is_code_content(text) {
223 clean_code_text(text)
224 } else {
225 clean_regular_text(text)
226 }
227 }
228
229 fn clean_regular_text(text: &str) -> String {
231 let cleaned = text
233 .lines()
234 .map(|line| line.trim())
235 .filter(|line| !line.is_empty())
236 .collect::<Vec<_>>()
237 .join(" ")
238 .split_whitespace()
239 .collect::<Vec<_>>()
240 .join(" ");
241
242 const MAX_LENGTH: usize = 2048;
244 if cleaned.len() > MAX_LENGTH {
245 format!("{}...", &cleaned[..MAX_LENGTH])
246 } else {
247 cleaned
248 }
249 }
250
251 fn clean_code_text(text: &str) -> String {
253 let mut cleaned = String::new();
254 let mut in_comment_block = false;
255
256 for line in text.lines() {
257 let trimmed = line.trim();
258
259 if trimmed.is_empty() && !cleaned.is_empty() {
261 continue;
262 }
263
264 if trimmed.starts_with("/*") {
266 in_comment_block = true;
267 }
268 if in_comment_block {
269 if trimmed.ends_with("*/") {
270 in_comment_block = false;
271 }
272 cleaned.push_str("// ");
273 cleaned.push_str(trimmed);
274 cleaned.push('\n');
275 continue;
276 }
277
278 if is_important_code_line(trimmed) {
280 let indent_level = line.len() - line.trim_start().len();
282 let normalized_indent = " ".repeat((indent_level / 2).min(4));
283 cleaned.push_str(&normalized_indent);
284 cleaned.push_str(trimmed);
285 cleaned.push('\n');
286 }
287 }
288
289 const MAX_CODE_LENGTH: usize = 3000;
291 if cleaned.len() > MAX_CODE_LENGTH {
292 format!("{}...", &cleaned[..MAX_CODE_LENGTH])
293 } else {
294 cleaned
295 }
296 }
297
298 fn is_code_content(text: &str) -> bool {
300 let code_indicators = [
301 "function",
302 "const",
303 "let",
304 "var",
305 "def",
306 "class",
307 "import",
308 "export",
309 "public",
310 "private",
311 "protected",
312 "return",
313 "if (",
314 "for (",
315 "while (",
316 "=>",
317 "->",
318 "::",
319 "<?php",
320 "#!/",
321 "package",
322 "namespace",
323 "struct",
324 ];
325
326 let text_lower = text.to_lowercase();
327 let indicator_count = code_indicators
328 .iter()
329 .filter(|&&ind| text_lower.contains(ind))
330 .count();
331
332 indicator_count >= 3
334 }
335
336 fn is_important_code_line(line: &str) -> bool {
338 if line.starts_with("//") && !line.starts_with("///") && !line.starts_with("//!") {
340 return false;
341 }
342
343 let important_patterns = [
345 "import ",
346 "from ",
347 "require",
348 "include",
349 "function ",
350 "def ",
351 "fn ",
352 "func ",
353 "class ",
354 "struct ",
355 "interface ",
356 "enum ",
357 "public ",
358 "private ",
359 "protected ",
360 "export ",
361 "module ",
362 "namespace ",
363 ];
364
365 for pattern in &important_patterns {
366 if line.contains(pattern) {
367 return true;
368 }
369 }
370
371 !line
373 .chars()
374 .all(|c| c == '{' || c == '}' || c == '(' || c == ')' || c == ';' || c.is_whitespace())
375 }
376
377 pub fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
379 if is_code_content(text) {
381 chunk_code_text(text, chunk_size, overlap)
382 } else {
383 chunk_regular_text(text, chunk_size, overlap)
384 }
385 }
386
387 fn chunk_regular_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
389 let words: Vec<&str> = text.split_whitespace().collect();
390 let mut chunks = Vec::new();
391
392 if words.len() <= chunk_size {
393 chunks.push(text.to_string());
394 return chunks;
395 }
396
397 let mut start = 0;
398 while start < words.len() {
399 let end = std::cmp::min(start + chunk_size, words.len());
400 let chunk = words[start..end].join(" ");
401 chunks.push(chunk);
402
403 if end == words.len() {
404 break;
405 }
406
407 start = end - overlap;
408 }
409
410 chunks
411 }
412
    /// Structure-aware chunking for code: tries to break at function/class
    /// boundaries (tracked via brace depth) rather than at arbitrary word
    /// counts. `_overlap` is unused here; overlap only applies in the
    /// word-window fallback.
    fn chunk_code_text(text: &str, chunk_size: usize, _overlap: usize) -> Vec<String> {
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        // Chunk size is measured in whitespace-separated tokens, not bytes.
        let mut current_size = 0;
        // Net count of unmatched '{' seen so far, clamped at zero.
        let mut brace_depth = 0;
        let mut in_function = false;

        for line in text.lines() {
            let trimmed = line.trim();

            // A new definition starts: flush the current chunk early if it
            // is already at least half-full and we are at top level.
            if trimmed.contains("function ")
                || trimmed.contains("def ")
                || trimmed.contains("class ")
                || trimmed.contains("fn ")
            {
                in_function = true;

                if current_size > chunk_size / 2 && brace_depth == 0 && !current_chunk.is_empty() {
                    chunks.push(current_chunk.clone());
                    current_chunk.clear();
                    current_size = 0;
                }
            }

            brace_depth += trimmed.chars().filter(|&c| c == '{').count() as i32;
            brace_depth -= trimmed.chars().filter(|&c| c == '}').count() as i32;
            brace_depth = brace_depth.max(0);

            current_chunk.push_str(line);
            current_chunk.push('\n');
            current_size += line.split_whitespace().count();

            // Flush when over budget — but never while inside a function
            // body, so definitions stay intact within one chunk.
            if current_size >= chunk_size && brace_depth == 0 && !in_function {
                chunks.push(current_chunk.clone());
                current_chunk.clear();
                current_size = 0;
            }

            // A closing brace at depth zero ends the current function scope.
            if in_function && brace_depth == 0 && trimmed.ends_with('}') {
                in_function = false;
            }
        }

        // Flush whatever remains after the last line.
        if !current_chunk.trim().is_empty() {
            chunks.push(current_chunk);
        }

        // Nothing produced (e.g. empty/whitespace-only input): fall back to
        // word-window chunking with a 10% overlap.
        if chunks.is_empty() {
            return chunk_regular_text(text, chunk_size, chunk_size / 10);
        }

        chunks
    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    // The default (hash) provider must produce a non-trivial 384-dim vector.
    #[tokio::test]
    async fn test_embedding_model() {
        let model = EmbeddingModel::new().await.unwrap();

        let text = "This is a test sentence for embedding.";
        let embedding = model.embed_text(text).await.unwrap();

        assert_eq!(embedding.len(), 384); assert!(embedding.iter().any(|&x| x != 0.0));
    }

    // Identical vectors → similarity 1; negated vectors → similarity -1.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let similarity = EmbeddingModel::cosine_similarity(&a, &b);
        assert!((similarity - 1.0).abs() < 0.001);

        let c = vec![-1.0, -2.0, -3.0];
        let similarity2 = EmbeddingModel::cosine_similarity(&a, &c);
        assert!((similarity2 + 1.0).abs() < 0.001);
    }

    // Whitespace and blank-line normalization of prose text.
    #[test]
    fn test_text_preprocessing() {
        let text = "  This is a test\n\n  with multiple lines  \n  ";
        let cleaned = preprocessing::clean_text(text);
        assert_eq!(cleaned, "This is a test with multiple lines");
    }

    // Word-window chunking with overlap 1 over a ten-word input.
    #[test]
    fn test_text_chunking() {
        let text = "one two three four five six seven eight nine ten";
        let chunks = preprocessing::chunk_text(text, 3, 1);

        assert_eq!(chunks.len(), 5);
        assert_eq!(chunks[0], "one two three");
        assert_eq!(chunks[1], "three four five");
        assert_eq!(chunks[2], "five six seven");
        assert_eq!(chunks[3], "seven eight nine");
        assert_eq!(chunks[4], "nine ten");
    }

    // Word-order variants of the same phrase should embed closer together
    // than an unrelated phrase does.
    #[tokio::test]
    async fn test_similarity_detection() {
        let model = EmbeddingModel::new().await.unwrap();

        let text1 = "React hooks useState";
        let text2 = "useState React hooks";
        let text3 = "Python Django models";

        let emb1 = model.embed_text(text1).await.unwrap();
        let emb2 = model.embed_text(text2).await.unwrap();
        let emb3 = model.embed_text(text3).await.unwrap();

        let sim_12 = EmbeddingModel::cosine_similarity(&emb1, &emb2);
        let sim_13 = EmbeddingModel::cosine_similarity(&emb1, &emb3);

        assert!(sim_12 > sim_13);
    }
}