1use crate::rag::providers::{
8 custom, hash, huggingface, ollama, onnx, openai, EmbeddingProvider as ProviderTrait,
9};
10use crate::rag::{EmbeddingConfig, EmbeddingProvider};
11use anyhow::{anyhow, Result};
12use lru::LruCache;
13use std::collections::hash_map::DefaultHasher;
14use std::hash::{Hash, Hasher};
15use std::num::NonZeroUsize;
16use std::sync::Mutex;
17
pub struct EmbeddingModel {
    /// Active embedding backend, chosen from `EmbeddingConfig::provider`.
    provider: Box<dyn ProviderTrait + Send + Sync>,
    /// Configuration this model was constructed with.
    config: EmbeddingConfig,
    /// LRU cache mapping a hash of the input text to its embedding vector.
    cache: Mutex<LruCache<u64, Vec<f32>>>,
}
27
28impl EmbeddingModel {
29 pub async fn new() -> Result<Self> {
31 Self::new_with_config(EmbeddingConfig::default()).await
32 }
33
34 pub async fn new_auto_select() -> Result<Self> {
36 let best_config = Self::auto_select_best_provider().await?;
37 Self::new_with_config(best_config).await
38 }
39
40 pub async fn new_with_config(config: EmbeddingConfig) -> Result<Self> {
42 log::info!(
43 "Initializing embedding model with provider: {:?}",
44 config.provider
45 );
46
47 let provider: Box<dyn ProviderTrait + Send + Sync> = match &config.provider {
48 EmbeddingProvider::Hash => {
49 log::info!("Using hash-based embeddings (default provider)");
50 Box::new(hash::HashProvider::new(384)) }
52 EmbeddingProvider::Onnx(model_name) => {
53 log::info!("Loading ONNX model: {}", model_name);
54 let onnx_provider = onnx::OnnxProvider::new(model_name).await?;
55 Box::new(onnx_provider)
56 }
57 EmbeddingProvider::Ollama(model_name) => {
58 log::info!("Connecting to Ollama model: {}", model_name);
59 let ollama_provider =
60 ollama::OllamaProvider::new(model_name.clone(), config.endpoint.clone());
61 ollama_provider.health_check().await?;
63 Box::new(ollama_provider)
64 }
65 EmbeddingProvider::OpenAI(model_name) => {
66 log::info!("Connecting to OpenAI model: {}", model_name);
67 let api_key = config.api_key.as_ref().ok_or_else(|| {
68 anyhow!("OpenAI API key required. Use 'manx config --embedding-api-key <key>'")
69 })?;
70 let openai_provider =
71 openai::OpenAiProvider::new(api_key.clone(), model_name.clone());
72 Box::new(openai_provider)
73 }
74 EmbeddingProvider::HuggingFace(model_name) => {
75 log::info!("Connecting to HuggingFace model: {}", model_name);
76 let api_key = config.api_key.as_ref().ok_or_else(|| {
77 anyhow!(
78 "HuggingFace API key required. Use 'manx config --embedding-api-key <key>'"
79 )
80 })?;
81 let hf_provider =
82 huggingface::HuggingFaceProvider::new(api_key.clone(), model_name.clone());
83 Box::new(hf_provider)
84 }
85 EmbeddingProvider::Custom(endpoint) => {
86 log::info!("Connecting to custom endpoint: {}", endpoint);
87 let custom_provider =
88 custom::CustomProvider::new(endpoint.clone(), config.api_key.clone());
89 Box::new(custom_provider)
90 }
91 };
92
93 let cache_capacity = NonZeroUsize::new(1000).unwrap();
95 let cache = Mutex::new(LruCache::new(cache_capacity));
96
97 Ok(Self {
98 provider,
99 config,
100 cache,
101 })
102 }
103
104 pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
106 if text.trim().is_empty() {
107 return Err(anyhow!("Cannot embed empty text"));
108 }
109
110 let text_hash = Self::hash_text(text);
112
113 {
115 let mut cache = self.cache.lock().unwrap();
116 if let Some(cached_embedding) = cache.get(&text_hash) {
117 log::debug!("Cache hit for text embedding");
118 return Ok(cached_embedding.clone());
119 }
120 }
121
122 log::debug!("Cache miss for text embedding, generating...");
124 let embedding = Self::retry_with_backoff(
125 || async { self.provider.embed_text(text).await },
126 3, )
128 .await?;
129
130 {
132 let mut cache = self.cache.lock().unwrap();
133 cache.put(text_hash, embedding.clone());
134 }
135
136 Ok(embedding)
137 }
138
139 async fn retry_with_backoff<F, Fut, T>(mut operation: F, max_retries: u32) -> Result<T>
141 where
142 F: FnMut() -> Fut,
143 Fut: std::future::Future<Output = Result<T>>,
144 {
145 let mut retries = 0;
146 loop {
147 match operation().await {
148 Ok(result) => return Ok(result),
149 Err(e) => {
150 retries += 1;
151 if retries > max_retries {
152 log::error!("Operation failed after {} retries: {}", max_retries, e);
153 return Err(e);
154 }
155
156 let delay_ms = 100 * (2_u64.pow(retries - 1)); log::warn!(
158 "Operation failed (attempt {}/{}), retrying in {}ms: {}",
159 retries,
160 max_retries,
161 delay_ms,
162 e
163 );
164 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
165 }
166 }
167 }
168 }
169
170 fn hash_text(text: &str) -> u64 {
172 let mut hasher = DefaultHasher::new();
173 text.hash(&mut hasher);
174 hasher.finish()
175 }
176
177 pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
180 if texts.is_empty() {
181 return Ok(vec![]);
182 }
183
184 match &self.config.provider {
186 EmbeddingProvider::Onnx(_) => {
187 log::debug!(
190 "Using ONNX native batch processing for {} texts",
191 texts.len()
192 );
193
194 if let Some(onnx_provider) =
197 self.provider.as_any().downcast_ref::<onnx::OnnxProvider>()
198 {
199 return onnx_provider.embed_batch(texts).await;
200 }
201
202 log::warn!("Failed to downcast ONNX provider, using sequential processing");
204 self.embed_batch_sequential(texts).await
205 }
206 _ => {
207 self.embed_batch_sequential(texts).await
209 }
210 }
211 }
212
213 async fn embed_batch_sequential(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
215 let mut embeddings = Vec::with_capacity(texts.len());
216 let mut failed_count = 0;
217
218 for (i, text) in texts.iter().enumerate() {
219 match self.embed_text(text).await {
220 Ok(embedding) => embeddings.push(embedding),
221 Err(e) => {
222 log::warn!("Failed to embed text {} in batch: {}", i, e);
223 failed_count += 1;
224 continue;
225 }
226 }
227 }
228
229 if embeddings.is_empty() {
230 return Err(anyhow!(
231 "Batch embedding failed for all {} texts",
232 texts.len()
233 ));
234 }
235
236 if failed_count > 0 {
237 log::warn!(
238 "Batch embedding completed with {} failures out of {} texts",
239 failed_count,
240 texts.len()
241 );
242 }
243
244 Ok(embeddings)
245 }
246
247 pub async fn get_dimension(&self) -> Result<usize> {
249 self.provider.get_dimension().await
250 }
251
252 pub async fn health_check(&self) -> Result<()> {
254 self.provider.health_check().await
255 }
256
257 pub fn get_provider_info(&self) -> crate::rag::providers::ProviderInfo {
259 self.provider.get_info()
260 }
261
262 pub fn get_config(&self) -> &EmbeddingConfig {
264 &self.config
265 }
266
267 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
269 if a.len() != b.len() {
270 return 0.0;
271 }
272
273 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
274 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
275 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
276
277 if norm_a == 0.0 || norm_b == 0.0 {
278 0.0
279 } else {
280 dot_product / (norm_a * norm_b)
281 }
282 }
283
284 pub async fn auto_select_best_provider() -> Result<EmbeddingConfig> {
287 log::info!("Auto-selecting best available embedding provider from installed models...");
288
289 if let Ok(available_models) = Self::get_available_onnx_models().await {
291 if !available_models.is_empty() {
292 let selected_model = &available_models[0];
294 log::info!("Auto-selected installed ONNX model: {}", selected_model);
295
296 if let Ok(test_config) = Self::create_config_for_model(selected_model).await {
298 return Ok(test_config);
299 }
300 }
301 }
302
303 log::info!("No ONNX models found, using hash-based embeddings");
305 Ok(EmbeddingConfig::default())
306 }
307
308 async fn get_available_onnx_models() -> Result<Vec<String>> {
310 let potential_models = [
313 "sentence-transformers/all-MiniLM-L6-v2",
314 "sentence-transformers/all-mpnet-base-v2",
315 "BAAI/bge-base-en-v1.5",
316 "BAAI/bge-small-en-v1.5",
317 "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
318 ];
319
320 let mut available = Vec::new();
321 for model in &potential_models {
322 if Self::is_onnx_model_available(model).await {
323 available.push(model.to_string());
324 }
325 }
326
327 Ok(available)
328 }
329
330 async fn create_config_for_model(model_name: &str) -> Result<EmbeddingConfig> {
332 match onnx::OnnxProvider::new(model_name).await {
334 Ok(provider) => {
335 let dimension = provider.get_dimension().await.unwrap_or(384);
336 Ok(EmbeddingConfig {
337 provider: EmbeddingProvider::Onnx(model_name.to_string()),
338 dimension,
339 ..EmbeddingConfig::default()
340 })
341 }
342 Err(e) => Err(anyhow!(
343 "Failed to create config for model {}: {}",
344 model_name,
345 e
346 )),
347 }
348 }
349
350 async fn is_onnx_model_available(model_name: &str) -> bool {
352 match onnx::OnnxProvider::new(model_name).await {
354 Ok(_) => {
355 log::debug!("ONNX model '{}' is available", model_name);
356 true
357 }
358 Err(e) => {
359 log::debug!("ONNX model '{}' not available: {}", model_name, e);
360 false
361 }
362 }
363 }
364}
365
pub mod preprocessing {
    /// Cleans text for embedding, dispatching on whether it looks like code.
    pub fn clean_text(text: &str) -> String {
        if is_code_content(text) {
            clean_code_text(text)
        } else {
            clean_regular_text(text)
        }
    }

    /// Collapses prose onto one line: trims each line, drops blanks,
    /// normalizes runs of whitespace to single spaces, and truncates to a
    /// character budget.
    fn clean_regular_text(text: &str) -> String {
        let cleaned = text
            .lines()
            .map(|line| line.trim())
            .filter(|line| !line.is_empty())
            .collect::<Vec<_>>()
            .join(" ")
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(" ");

        // Truncate by chars (not bytes) so multi-byte UTF-8 isn't split.
        const MAX_CHARS: usize = 2048;
        if cleaned.chars().count() > MAX_CHARS {
            let truncated: String = cleaned.chars().take(MAX_CHARS).collect();
            format!("{}...", truncated)
        } else {
            cleaned
        }
    }

    /// Cleans code-like text: keeps block comments (prefixed with `// `),
    /// keeps structurally important lines with normalized indentation,
    /// drops the rest, and truncates to a character budget.
    fn clean_code_text(text: &str) -> String {
        let mut cleaned = String::new();
        let mut in_comment_block = false;

        for line in text.lines() {
            let trimmed = line.trim();

            // Skip blank lines except a leading one (cleaned still empty).
            if trimmed.is_empty() && !cleaned.is_empty() {
                continue;
            }

            if trimmed.starts_with("/*") {
                in_comment_block = true;
            }
            if in_comment_block {
                if trimmed.ends_with("*/") {
                    in_comment_block = false;
                }
                cleaned.push_str("// ");
                cleaned.push_str(trimmed);
                cleaned.push('\n');
                continue;
            }

            if is_important_code_line(trimmed) {
                // Re-indent: one space per two original columns, capped at 4.
                let indent_level = line.len() - line.trim_start().len();
                let normalized_indent = " ".repeat((indent_level / 2).min(4));
                cleaned.push_str(&normalized_indent);
                cleaned.push_str(trimmed);
                cleaned.push('\n');
            }
        }

        // Code gets a larger budget than prose; truncate by chars, not bytes.
        const MAX_CODE_CHARS: usize = 3000;
        if cleaned.chars().count() > MAX_CODE_CHARS {
            let truncated: String = cleaned.chars().take(MAX_CODE_CHARS).collect();
            format!("{}...", truncated)
        } else {
            cleaned
        }
    }

    /// Heuristic: text counts as code when at least three distinct
    /// indicators (keywords/operators common across languages) occur in it.
    fn is_code_content(text: &str) -> bool {
        let code_indicators = [
            "function",
            "const",
            "let",
            "var",
            "def",
            "class",
            "import",
            "export",
            "public",
            "private",
            "protected",
            "return",
            "if (",
            "for (",
            "while (",
            "=>",
            "->",
            "::",
            "<?php",
            "#!/",
            "package",
            "namespace",
            "struct",
        ];

        let text_lower = text.to_lowercase();
        let indicator_count = code_indicators
            .iter()
            .filter(|&&ind| text_lower.contains(ind))
            .count();

        indicator_count >= 3
    }

    /// Keeps declarations/imports and any line with non-structural content;
    /// drops plain `//` comments (but keeps doc comments `///` and `//!`).
    fn is_important_code_line(line: &str) -> bool {
        if line.starts_with("//") && !line.starts_with("///") && !line.starts_with("//!") {
            return false;
        }

        let important_patterns = [
            "import ",
            "from ",
            "require",
            "include",
            "function ",
            "def ",
            "fn ",
            "func ",
            "class ",
            "struct ",
            "interface ",
            "enum ",
            "public ",
            "private ",
            "protected ",
            "export ",
            "module ",
            "namespace ",
        ];

        for pattern in &important_patterns {
            if line.contains(pattern) {
                return true;
            }
        }

        // Keep any line that is more than pure punctuation/whitespace.
        !line
            .chars()
            .all(|c| c == '{' || c == '}' || c == '(' || c == ')' || c == ';' || c.is_whitespace())
    }

    /// Splits text into overlapping chunks, using structure-aware chunking
    /// for code and word-window chunking for prose.
    ///
    /// `chunk_size` is in words; `overlap` is the number of words shared
    /// between consecutive prose chunks.
    pub fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
        if is_code_content(text) {
            chunk_code_text(text, chunk_size, overlap)
        } else {
            chunk_regular_text(text, chunk_size, overlap)
        }
    }

    /// Sliding word-window chunking for prose.
    ///
    /// Fix: the original advanced the window with `start = end - overlap`
    /// unconditionally, which never makes progress when
    /// `overlap >= chunk_size` (infinite loop) and underflows `usize` when
    /// `chunk_size == 0`. Degenerate sizes are now clamped so every
    /// iteration strictly advances; valid inputs behave exactly as before.
    fn chunk_regular_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
        // A zero-word window is meaningless; return the text as one chunk.
        if chunk_size == 0 {
            return vec![text.to_string()];
        }
        // Guarantee forward progress: overlap must be < chunk_size.
        let overlap = overlap.min(chunk_size - 1);

        let words: Vec<&str> = text.split_whitespace().collect();
        let mut chunks = Vec::new();

        if words.len() <= chunk_size {
            chunks.push(text.to_string());
            return chunks;
        }

        let mut start = 0;
        while start < words.len() {
            let end = std::cmp::min(start + chunk_size, words.len());
            let chunk = words[start..end].join(" ");
            chunks.push(chunk);

            if end == words.len() {
                break;
            }

            start = end - overlap;
        }

        chunks
    }

    /// Structure-aware chunking for code: prefers to break at function/class
    /// boundaries and at brace-depth zero, falling back to word-window
    /// chunking when no chunks were produced.
    fn chunk_code_text(text: &str, chunk_size: usize, _overlap: usize) -> Vec<String> {
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        let mut current_size = 0;
        let mut brace_depth = 0;
        let mut in_function = false;

        for line in text.lines() {
            let trimmed = line.trim();

            if trimmed.contains("function ")
                || trimmed.contains("def ")
                || trimmed.contains("class ")
                || trimmed.contains("fn ")
            {
                in_function = true;

                // Flush before a new definition if the chunk is half full
                // and we're at top level.
                if current_size > chunk_size / 2 && brace_depth == 0 && !current_chunk.is_empty() {
                    chunks.push(current_chunk.clone());
                    current_chunk.clear();
                    current_size = 0;
                }
            }

            // Track nesting; clamp at 0 so stray closers don't go negative.
            brace_depth += trimmed.chars().filter(|&c| c == '{').count() as i32;
            brace_depth -= trimmed.chars().filter(|&c| c == '}').count() as i32;
            brace_depth = brace_depth.max(0);

            current_chunk.push_str(line);
            current_chunk.push('\n');
            current_size += line.split_whitespace().count();

            // Flush when over budget, at top level, outside a definition.
            if current_size >= chunk_size && brace_depth == 0 && !in_function {
                chunks.push(current_chunk.clone());
                current_chunk.clear();
                current_size = 0;
            }

            if in_function && brace_depth == 0 && trimmed.ends_with('}') {
                in_function = false;
            }
        }

        if !current_chunk.trim().is_empty() {
            chunks.push(current_chunk);
        }

        // Nothing produced (e.g. brace-less code): fall back to word windows.
        if chunks.is_empty() {
            return chunk_regular_text(text, chunk_size, chunk_size / 10);
        }

        chunks
    }
}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// The default provider yields a 384-dim vector with at least one
    /// non-zero component.
    #[tokio::test]
    async fn test_embedding_model() {
        let model = EmbeddingModel::new().await.unwrap();
        let vector = model
            .embed_text("This is a test sentence for embedding.")
            .await
            .unwrap();

        assert_eq!(vector.len(), 384);
        assert!(vector.iter().any(|&component| component != 0.0));
    }

    /// Identical vectors score ~1.0; exactly opposite vectors score ~-1.0.
    #[test]
    fn test_cosine_similarity() {
        let base = vec![1.0_f32, 2.0, 3.0];
        let same = base.clone();
        let opposite: Vec<f32> = base.iter().map(|component| -component).collect();

        assert!((EmbeddingModel::cosine_similarity(&base, &same) - 1.0).abs() < 0.001);
        assert!((EmbeddingModel::cosine_similarity(&base, &opposite) + 1.0).abs() < 0.001);
    }

    /// Ragged whitespace and blank lines collapse to single-spaced prose.
    #[test]
    fn test_text_preprocessing() {
        let text = "  This is a test\n\n  with multiple lines  \n  ";
        assert_eq!(
            preprocessing::clean_text(text),
            "This is a test with multiple lines"
        );
    }

    /// Word-window chunking: 3-word chunks sharing one word of overlap.
    #[test]
    fn test_text_chunking() {
        let text = "one two three four five six seven eight nine ten";
        let expected = [
            "one two three",
            "three four five",
            "five six seven",
            "seven eight nine",
            "nine ten",
        ];

        let chunks = preprocessing::chunk_text(text, 3, 1);

        assert_eq!(chunks.len(), expected.len());
        for (chunk, want) in chunks.iter().zip(expected.iter()) {
            assert_eq!(chunk, want);
        }
    }

    /// Reordered-word texts should score closer than unrelated texts.
    #[tokio::test]
    async fn test_similarity_detection() {
        let model = EmbeddingModel::new().await.unwrap();

        let emb1 = model.embed_text("React hooks useState").await.unwrap();
        let emb2 = model.embed_text("useState React hooks").await.unwrap();
        let emb3 = model.embed_text("Python Django models").await.unwrap();

        let related = EmbeddingModel::cosine_similarity(&emb1, &emb2);
        let unrelated = EmbeddingModel::cosine_similarity(&emb1, &emb3);

        assert!(related > unrelated);
    }
}