1use crate::rag::providers::{
8 custom, hash, huggingface, ollama, onnx, openai, EmbeddingProvider as ProviderTrait,
9};
10use crate::rag::{EmbeddingConfig, EmbeddingProvider};
11use anyhow::{anyhow, Result};
12
/// High-level embedding interface that wraps one concrete provider backend
/// (hash, ONNX, Ollama, OpenAI, HuggingFace, or a custom endpoint) behind
/// the shared provider trait.
pub struct EmbeddingModel {
    // Active backend; all embed/dimension/health calls are delegated to it.
    provider: Box<dyn ProviderTrait + Send + Sync>,
    // Configuration the provider was built from; exposed via `get_config()`.
    config: EmbeddingConfig,
}
20
21impl EmbeddingModel {
22 pub async fn new() -> Result<Self> {
24 Self::new_with_config(EmbeddingConfig::default()).await
25 }
26
27 pub async fn new_auto_select() -> Result<Self> {
29 let best_config = Self::auto_select_best_provider().await?;
30 Self::new_with_config(best_config).await
31 }
32
33 pub async fn new_with_config(config: EmbeddingConfig) -> Result<Self> {
35 log::info!(
36 "Initializing embedding model with provider: {:?}",
37 config.provider
38 );
39
40 let provider: Box<dyn ProviderTrait + Send + Sync> = match &config.provider {
41 EmbeddingProvider::Hash => {
42 log::info!("Using hash-based embeddings (default provider)");
43 Box::new(hash::HashProvider::new(384)) }
45 EmbeddingProvider::Onnx(model_name) => {
46 log::info!("Loading ONNX model: {}", model_name);
47 let onnx_provider = onnx::OnnxProvider::new(model_name).await?;
48 Box::new(onnx_provider)
49 }
50 EmbeddingProvider::Ollama(model_name) => {
51 log::info!("Connecting to Ollama model: {}", model_name);
52 let ollama_provider =
53 ollama::OllamaProvider::new(model_name.clone(), config.endpoint.clone());
54 ollama_provider.health_check().await?;
56 Box::new(ollama_provider)
57 }
58 EmbeddingProvider::OpenAI(model_name) => {
59 log::info!("Connecting to OpenAI model: {}", model_name);
60 let api_key = config.api_key.as_ref().ok_or_else(|| {
61 anyhow!("OpenAI API key required. Use 'manx config --embedding-api-key <key>'")
62 })?;
63 let openai_provider =
64 openai::OpenAiProvider::new(api_key.clone(), model_name.clone());
65 Box::new(openai_provider)
66 }
67 EmbeddingProvider::HuggingFace(model_name) => {
68 log::info!("Connecting to HuggingFace model: {}", model_name);
69 let api_key = config.api_key.as_ref().ok_or_else(|| {
70 anyhow!(
71 "HuggingFace API key required. Use 'manx config --embedding-api-key <key>'"
72 )
73 })?;
74 let hf_provider =
75 huggingface::HuggingFaceProvider::new(api_key.clone(), model_name.clone());
76 Box::new(hf_provider)
77 }
78 EmbeddingProvider::Custom(endpoint) => {
79 log::info!("Connecting to custom endpoint: {}", endpoint);
80 let custom_provider =
81 custom::CustomProvider::new(endpoint.clone(), config.api_key.clone());
82 Box::new(custom_provider)
83 }
84 };
85
86 Ok(Self { provider, config })
87 }
88
89 pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
91 if text.trim().is_empty() {
92 return Err(anyhow!("Cannot embed empty text"));
93 }
94
95 self.provider.embed_text(text).await
96 }
97
98 pub async fn get_dimension(&self) -> Result<usize> {
100 self.provider.get_dimension().await
101 }
102
103 pub async fn health_check(&self) -> Result<()> {
105 self.provider.health_check().await
106 }
107
108 pub fn get_provider_info(&self) -> crate::rag::providers::ProviderInfo {
110 self.provider.get_info()
111 }
112
113 pub fn get_config(&self) -> &EmbeddingConfig {
115 &self.config
116 }
117
118 pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
120 if a.len() != b.len() {
121 return 0.0;
122 }
123
124 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
125 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
126 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
127
128 if norm_a == 0.0 || norm_b == 0.0 {
129 0.0
130 } else {
131 dot_product / (norm_a * norm_b)
132 }
133 }
134
135 pub async fn auto_select_best_provider() -> Result<EmbeddingConfig> {
138 log::info!("Auto-selecting best available embedding provider from installed models...");
139
140 if let Ok(available_models) = Self::get_available_onnx_models().await {
142 if !available_models.is_empty() {
143 let selected_model = &available_models[0];
145 log::info!("Auto-selected installed ONNX model: {}", selected_model);
146
147 if let Ok(test_config) = Self::create_config_for_model(selected_model).await {
149 return Ok(test_config);
150 }
151 }
152 }
153
154 log::info!("No ONNX models found, using hash-based embeddings");
156 Ok(EmbeddingConfig::default())
157 }
158
159 async fn get_available_onnx_models() -> Result<Vec<String>> {
161 let potential_models = [
164 "sentence-transformers/all-MiniLM-L6-v2",
165 "sentence-transformers/all-mpnet-base-v2",
166 "BAAI/bge-base-en-v1.5",
167 "BAAI/bge-small-en-v1.5",
168 "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
169 ];
170
171 let mut available = Vec::new();
172 for model in &potential_models {
173 if Self::is_onnx_model_available(model).await {
174 available.push(model.to_string());
175 }
176 }
177
178 Ok(available)
179 }
180
181 async fn create_config_for_model(model_name: &str) -> Result<EmbeddingConfig> {
183 match onnx::OnnxProvider::new(model_name).await {
185 Ok(provider) => {
186 let dimension = provider.get_dimension().await.unwrap_or(384);
187 Ok(EmbeddingConfig {
188 provider: EmbeddingProvider::Onnx(model_name.to_string()),
189 dimension,
190 ..EmbeddingConfig::default()
191 })
192 }
193 Err(e) => Err(anyhow!(
194 "Failed to create config for model {}: {}",
195 model_name,
196 e
197 )),
198 }
199 }
200
201 async fn is_onnx_model_available(model_name: &str) -> bool {
203 match onnx::OnnxProvider::new(model_name).await {
205 Ok(_) => {
206 log::debug!("ONNX model '{}' is available", model_name);
207 true
208 }
209 Err(e) => {
210 log::debug!("ONNX model '{}' not available: {}", model_name, e);
211 false
212 }
213 }
214 }
215}
216
/// Text preprocessing utilities used before embedding: cleaning and
/// chunking, with separate heuristics for prose and source code.
pub mod preprocessing {
    /// Normalize text for embedding, choosing a code-aware or prose cleaning
    /// strategy based on a heuristic content check.
    pub fn clean_text(text: &str) -> String {
        if is_code_content(text) {
            clean_code_text(text)
        } else {
            clean_regular_text(text)
        }
    }

    /// Collapse prose to a single whitespace-normalized line, truncated to a
    /// provider-friendly length.
    fn clean_regular_text(text: &str) -> String {
        let cleaned = text
            .lines()
            .map(|line| line.trim())
            .filter(|line| !line.is_empty())
            .collect::<Vec<_>>()
            .join(" ")
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(" ");

        // Bound the input size; counted in chars (not bytes) so multi-byte
        // UTF-8 text is never split mid-character.
        const MAX_CHARS: usize = 2048;
        if cleaned.chars().count() > MAX_CHARS {
            let truncated: String = cleaned.chars().take(MAX_CHARS).collect();
            format!("{}...", truncated)
        } else {
            cleaned
        }
    }

    /// Clean source code by keeping structurally important lines, normalizing
    /// indentation, and rewriting `/* */` comment blocks as `//` lines.
    fn clean_code_text(text: &str) -> String {
        let mut cleaned = String::new();
        let mut in_comment_block = false;

        for line in text.lines() {
            let trimmed = line.trim();

            // Drop blank lines (except a leading one, which never occurs once
            // `cleaned` is non-empty).
            if trimmed.is_empty() && !cleaned.is_empty() {
                continue;
            }

            if trimmed.starts_with("/*") {
                in_comment_block = true;
            }
            if in_comment_block {
                if trimmed.ends_with("*/") {
                    in_comment_block = false;
                }
                // Keep block-comment content, but normalized to line comments.
                cleaned.push_str("// ");
                cleaned.push_str(trimmed);
                cleaned.push('\n');
                continue;
            }

            if is_important_code_line(trimmed) {
                // Re-derive a shallow, capped indent from the original depth.
                let indent_level = line.len() - line.trim_start().len();
                let normalized_indent = " ".repeat((indent_level / 2).min(4));
                cleaned.push_str(&normalized_indent);
                cleaned.push_str(trimmed);
                cleaned.push('\n');
            }
        }

        // Code gets a larger budget than prose but is still bounded.
        const MAX_CODE_CHARS: usize = 3000;
        if cleaned.chars().count() > MAX_CODE_CHARS {
            let truncated: String = cleaned.chars().take(MAX_CODE_CHARS).collect();
            format!("{}...", truncated)
        } else {
            cleaned
        }
    }

    /// Heuristic: text counts as code when at least three distinct
    /// code-keyword indicators appear (case-insensitive).
    fn is_code_content(text: &str) -> bool {
        let code_indicators = [
            "function",
            "const",
            "let",
            "var",
            "def",
            "class",
            "import",
            "export",
            "public",
            "private",
            "protected",
            "return",
            "if (",
            "for (",
            "while (",
            "=>",
            "->",
            "::",
            "<?php",
            "#!/",
            "package",
            "namespace",
            "struct",
        ];

        let text_lower = text.to_lowercase();
        let indicator_count = code_indicators
            .iter()
            .filter(|&&ind| text_lower.contains(ind))
            .count();

        indicator_count >= 3
    }

    /// Whether a (trimmed) code line is worth keeping: declarations, imports,
    /// doc comments, and any line that is not pure punctuation.
    fn is_important_code_line(line: &str) -> bool {
        // Drop plain `//` comments but keep doc comments (`///`, `//!`).
        if line.starts_with("//") && !line.starts_with("///") && !line.starts_with("//!") {
            return false;
        }

        let important_patterns = [
            "import ",
            "from ",
            "require",
            "include",
            "function ",
            "def ",
            "fn ",
            "func ",
            "class ",
            "struct ",
            "interface ",
            "enum ",
            "public ",
            "private ",
            "protected ",
            "export ",
            "module ",
            "namespace ",
        ];

        for pattern in &important_patterns {
            if line.contains(pattern) {
                return true;
            }
        }

        // Keep anything that is more than braces/parens/semicolons/whitespace.
        !line
            .chars()
            .all(|c| c == '{' || c == '}' || c == '(' || c == ')' || c == ';' || c.is_whitespace())
    }

    /// Split text into overlapping chunks, delegating to a code-aware
    /// splitter when the content looks like source code.
    ///
    /// `chunk_size` and `overlap` are measured in whitespace-separated words.
    pub fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
        if is_code_content(text) {
            chunk_code_text(text, chunk_size, overlap)
        } else {
            chunk_regular_text(text, chunk_size, overlap)
        }
    }

    /// Sliding-window word chunking with overlap.
    fn chunk_regular_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
        let words: Vec<&str> = text.split_whitespace().collect();
        let mut chunks = Vec::new();

        // A zero chunk size is degenerate; return the text whole rather than
        // looping on empty chunks.
        if chunk_size == 0 || words.len() <= chunk_size {
            chunks.push(text.to_string());
            return chunks;
        }

        // Clamp overlap strictly below chunk_size: `start = end - overlap`
        // must advance every iteration, otherwise the loop never terminates
        // (and an overlap larger than the chunk would underflow the usize
        // subtraction and panic).
        let overlap = overlap.min(chunk_size - 1);

        let mut start = 0;
        while start < words.len() {
            let end = std::cmp::min(start + chunk_size, words.len());
            let chunk = words[start..end].join(" ");
            chunks.push(chunk);

            if end == words.len() {
                break;
            }

            start = end - overlap;
        }

        chunks
    }

    /// Code-aware chunking: prefer to break at definition starts and at
    /// brace-balanced points so functions/classes stay intact.
    fn chunk_code_text(text: &str, chunk_size: usize, _overlap: usize) -> Vec<String> {
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        let mut current_size = 0;
        let mut brace_depth = 0;
        let mut in_function = false;

        for line in text.lines() {
            let trimmed = line.trim();

            if trimmed.contains("function ")
                || trimmed.contains("def ")
                || trimmed.contains("class ")
                || trimmed.contains("fn ")
            {
                in_function = true;

                // Flush before a new definition if the current chunk is
                // already half full and brace-balanced.
                if current_size > chunk_size / 2 && brace_depth == 0 && !current_chunk.is_empty() {
                    chunks.push(current_chunk.clone());
                    current_chunk.clear();
                    current_size = 0;
                }
            }

            // Track nesting; clamp at 0 so stray closers don't go negative.
            brace_depth += trimmed.chars().filter(|&c| c == '{').count() as i32;
            brace_depth -= trimmed.chars().filter(|&c| c == '}').count() as i32;
            brace_depth = brace_depth.max(0);

            current_chunk.push_str(line);
            current_chunk.push('\n');
            current_size += line.split_whitespace().count();

            // Flush oversized chunks only outside function bodies and at
            // brace-balanced points.
            if current_size >= chunk_size && brace_depth == 0 && !in_function {
                chunks.push(current_chunk.clone());
                current_chunk.clear();
                current_size = 0;
            }

            if in_function && brace_depth == 0 && trimmed.ends_with('}') {
                in_function = false;
            }
        }

        if !current_chunk.trim().is_empty() {
            chunks.push(current_chunk);
        }

        // If the structural splitter produced nothing, fall back to plain
        // word chunking with a 10% overlap.
        if chunks.is_empty() {
            return chunk_regular_text(text, chunk_size, chunk_size / 10);
        }

        chunks
    }
}
478
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_embedding_model() {
        let model = EmbeddingModel::new().await.unwrap();

        let embedding = model
            .embed_text("This is a test sentence for embedding.")
            .await
            .unwrap();

        // Default provider yields 384-dimensional, non-trivial vectors.
        assert_eq!(embedding.len(), 384);
        assert!(embedding.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn test_cosine_similarity() {
        let base = vec![1.0, 2.0, 3.0];

        // Identical vectors: similarity ~ 1.0.
        let identical = vec![1.0, 2.0, 3.0];
        assert!((EmbeddingModel::cosine_similarity(&base, &identical) - 1.0).abs() < 0.001);

        // Opposite vectors: similarity ~ -1.0.
        let negated = vec![-1.0, -2.0, -3.0];
        assert!((EmbeddingModel::cosine_similarity(&base, &negated) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_text_preprocessing() {
        let raw = "  This is a test\n\n  with multiple lines  \n  ";
        assert_eq!(
            preprocessing::clean_text(raw),
            "This is a test with multiple lines"
        );
    }

    #[test]
    fn test_text_chunking() {
        let chunks =
            preprocessing::chunk_text("one two three four five six seven eight nine ten", 3, 1);

        let expected = [
            "one two three",
            "three four five",
            "five six seven",
            "seven eight nine",
            "nine ten",
        ];
        assert_eq!(chunks.len(), expected.len());
        for (got, want) in chunks.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    }

    #[tokio::test]
    async fn test_similarity_detection() {
        let model = EmbeddingModel::new().await.unwrap();

        let emb1 = model.embed_text("React hooks useState").await.unwrap();
        let emb2 = model.embed_text("useState React hooks").await.unwrap();
        let emb3 = model.embed_text("Python Django models").await.unwrap();

        // Reordered words should stay closer than an unrelated topic.
        let same_topic = EmbeddingModel::cosine_similarity(&emb1, &emb2);
        let different_topic = EmbeddingModel::cosine_similarity(&emb1, &emb3);
        assert!(same_topic > different_topic);
    }
}