use crate::{embeddings::EmbeddableContent, Vector, VectorData};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Text preprocessing applied before embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreprocessingPipeline {
    /// Tokenization options.
    pub tokenizer: TokenizerConfig,
    /// Token normalization options.
    pub normalization: NormalizationConfig,
    /// Tokens removed during filtering.
    pub stop_words: HashSet<String>,
    /// Optional entity recognition (currently unused by `process`).
    pub entity_recognition: Option<EntityRecognitionConfig>,
}

/// Controls how raw text is split into tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
    pub lowercase: bool,
    pub remove_punctuation: bool,
    /// Minimum token length in characters; shorter tokens are dropped.
    pub min_token_length: usize,
    /// Maximum token length in characters; longer tokens are dropped.
    pub max_token_length: usize,
    pub split_camel_case: bool,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            lowercase: true,
            remove_punctuation: true,
            min_token_length: 2,
            max_token_length: 50,
            split_camel_case: true,
        }
    }
}

/// Token normalization options applied after tokenization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationConfig {
    pub unicode_normalization: bool,
    pub accent_removal: bool,
    pub stemming: bool,
    pub lemmatization: bool,
}

impl Default for NormalizationConfig {
    fn default() -> Self {
        Self {
            unicode_normalization: true,
            accent_removal: true,
            stemming: false,
            lemmatization: false,
        }
    }
}

/// Flags for optional entity recognition during preprocessing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityRecognitionConfig {
    pub recognize_uris: bool,
    pub recognize_dates: bool,
    pub recognize_numbers: bool,
    pub entity_linking: bool,
}

impl Default for PreprocessingPipeline {
    fn default() -> Self {
        // A small built-in English stop word list.
        let mut stop_words = HashSet::new();
        for word in &[
            "the", "is", "at", "which", "on", "a", "an", "and", "or", "but", "in", "with", "to",
            "for", "of", "as", "by", "that", "this", "it", "from", "be", "are", "was", "were",
            "been",
        ] {
            stop_words.insert(word.to_string());
        }

        Self {
            tokenizer: TokenizerConfig::default(),
            normalization: NormalizationConfig::default(),
            stop_words,
            entity_recognition: None,
        }
    }
}

impl PreprocessingPipeline {
    /// Runs the full pipeline: tokenize, normalize, filter, and
    /// optionally stem. Lemmatization is configurable but not implemented.
    pub fn process(&self, text: &str) -> Vec<String> {
        let mut tokens = self.tokenize(text);

        if self.normalization.unicode_normalization {
            tokens = self.normalize_unicode(tokens);
        }

        if self.normalization.accent_removal {
            tokens = self.remove_accents(tokens);
        }

        tokens = self.filter_tokens(tokens);

        if self.normalization.stemming {
            tokens = self.stem_tokens(tokens);
        }

        tokens
    }

    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut processed = text.to_string();

        if self.tokenizer.remove_punctuation {
            // Replace every non-alphanumeric, non-whitespace character
            // with a space so it acts as a token boundary.
            processed = processed
                .chars()
                .map(|c| {
                    if c.is_alphanumeric() || c.is_whitespace() {
                        c
                    } else {
                        ' '
                    }
                })
                .collect();
        }

        let mut tokens: Vec<String> = processed
            .split_whitespace()
            .map(|s| s.to_string())
            .collect();

        if self.tokenizer.split_camel_case {
            tokens = tokens
                .into_iter()
                .flat_map(|token| self.split_camel_case(&token))
                .collect();
        }

        if self.tokenizer.lowercase {
            tokens = tokens.into_iter().map(|s| s.to_lowercase()).collect();
        }

        tokens
    }

    /// Splits a token at each uppercase letter. Runs of capitals
    /// (e.g. "HTTPS") become single-letter tokens, which the length
    /// filter later removes.
    fn split_camel_case(&self, word: &str) -> Vec<String> {
        let mut result = Vec::new();
        let mut current = String::new();

        for (i, ch) in word.chars().enumerate() {
            if i > 0 && ch.is_uppercase() && !current.is_empty() {
                result.push(current.clone());
                current.clear();
            }
            current.push(ch);
        }

        if !current.is_empty() {
            result.push(current);
        }

        if result.is_empty() {
            vec![word.to_string()]
        } else {
            result
        }
    }

    fn normalize_unicode(&self, tokens: Vec<String>) -> Vec<String> {
        // Placeholder: proper NFC/NFD normalization would require a
        // crate such as `unicode-normalization`; tokens pass through.
        tokens
    }

    /// Folds common accented Latin characters to their ASCII base form.
    fn remove_accents(&self, tokens: Vec<String>) -> Vec<String> {
        tokens
            .into_iter()
            .map(|token| {
                token
                    .chars()
                    .map(|c| match c {
                        'à' | 'á' | 'â' | 'ã' | 'ä' | 'å' => 'a',
                        'è' | 'é' | 'ê' | 'ë' => 'e',
                        'ì' | 'í' | 'î' | 'ï' => 'i',
                        'ò' | 'ó' | 'ô' | 'õ' | 'ö' => 'o',
                        'ù' | 'ú' | 'û' | 'ü' => 'u',
                        'ñ' => 'n',
                        'ç' => 'c',
                        _ => c,
                    })
                    .collect()
            })
            .collect()
    }

    fn filter_tokens(&self, tokens: Vec<String>) -> Vec<String> {
        tokens
            .into_iter()
            .filter(|token| {
                // Measure length in characters rather than bytes so
                // multi-byte tokens are filtered consistently.
                let len = token.chars().count();
                len >= self.tokenizer.min_token_length
                    && len <= self.tokenizer.max_token_length
                    && !self.stop_words.contains(token)
            })
            .collect()
    }

    fn stem_tokens(&self, tokens: Vec<String>) -> Vec<String> {
        tokens
            .into_iter()
            .map(|token| self.porter_stem(&token))
            .collect()
    }

    /// A simplified Porter stemmer: the five classic suffix-stripping
    /// steps, without the full rule set of the reference algorithm.
    fn porter_stem(&self, word: &str) -> String {
        let word = word.to_lowercase();
        if word.len() <= 2 {
            return word;
        }

        let mut stem = word.clone();

        stem = self.stem_step_1a(stem);
        stem = self.stem_step_1b(stem);
        stem = self.stem_step_2(stem);
        stem = self.stem_step_3(stem);
        stem = self.stem_step_4(stem);
        stem = self.stem_step_5(stem);

        stem
    }

    /// Step 1a: plural endings ("sses" -> "ss", "ies" -> "i", strip a
    /// trailing "s").
    fn stem_step_1a(&self, mut word: String) -> String {
        if word.ends_with("sses") {
            word.truncate(word.len() - 2);
        } else if word.ends_with("ies") {
            word.truncate(word.len() - 2);
        } else if word.ends_with("ss") {
            // "ss" is kept as-is.
        } else if word.ends_with('s') && word.len() > 1 {
            word.truncate(word.len() - 1);
        }
        word
    }

    /// Step 1b: "eed", "ed", and "ing" endings.
    fn stem_step_1b(&self, mut word: String) -> String {
        if word.ends_with("eed") {
            if self.measure(&word[..word.len() - 3]) > 0 {
                word.truncate(word.len() - 1);
            }
        } else if word.ends_with("ed") && self.contains_vowel(&word[..word.len() - 2]) {
            word.truncate(word.len() - 2);
            word = self.post_process_1b(word);
        } else if word.ends_with("ing") && self.contains_vowel(&word[..word.len() - 3]) {
            word.truncate(word.len() - 3);
            word = self.post_process_1b(word);
        }
        word
    }

    /// Step 2: maps double suffixes to single ones (e.g. "ization" -> "ize").
    fn stem_step_2(&self, mut word: String) -> String {
        let suffixes = [
            ("ational", "ate"),
            ("tional", "tion"),
            ("enci", "ence"),
            ("anci", "ance"),
            ("izer", "ize"),
            ("abli", "able"),
            ("alli", "al"),
            ("entli", "ent"),
            ("eli", "e"),
            ("ousli", "ous"),
            ("ization", "ize"),
            ("ation", "ate"),
            ("ator", "ate"),
            ("alism", "al"),
            ("iveness", "ive"),
            ("fulness", "ful"),
            ("ousness", "ous"),
            ("aliti", "al"),
            ("iviti", "ive"),
            ("biliti", "ble"),
        ];

        for (suffix, replacement) in &suffixes {
            if word.ends_with(suffix) {
                let stem = &word[..word.len() - suffix.len()];
                if self.measure(stem) > 0 {
                    word = format!("{stem}{replacement}");
                }
                break;
            }
        }
        word
    }

    /// Step 3: "icate", "ative", "alize", and similar endings.
    fn stem_step_3(&self, mut word: String) -> String {
        let suffixes = [
            ("icate", "ic"),
            ("ative", ""),
            ("alize", "al"),
            ("iciti", "ic"),
            ("ical", "ic"),
            ("ful", ""),
            ("ness", ""),
        ];

        for (suffix, replacement) in &suffixes {
            if word.ends_with(suffix) {
                let stem = &word[..word.len() - suffix.len()];
                if self.measure(stem) > 0 {
                    word = format!("{stem}{replacement}");
                }
                break;
            }
        }
        word
    }

    /// Step 4: strips residual suffixes when the remaining stem is long
    /// enough (measure > 1).
    fn stem_step_4(&self, mut word: String) -> String {
        let suffixes = [
            "al", "ance", "ence", "er", "ic", "able", "ible", "ant", "ement", "ment", "ent", "ion",
            "ou", "ism", "ate", "iti", "ous", "ive", "ize",
        ];

        for suffix in &suffixes {
            if word.ends_with(suffix) {
                let stem = &word[..word.len() - suffix.len()];
                if self.measure(stem) > 1
                    && (*suffix != "ion" || (stem.ends_with('s') || stem.ends_with('t')))
                {
                    word = stem.to_string();
                }
                break;
            }
        }
        word
    }

    /// Step 5: drops a trailing "e" and collapses a final double "l".
    fn stem_step_5(&self, mut word: String) -> String {
        if word.ends_with('e') {
            let stem = &word[..word.len() - 1];
            let m = self.measure(stem);
            if m > 1 || (m == 1 && !self.cvc(stem)) {
                word.truncate(word.len() - 1);
            }
        }

        if word.ends_with("ll") && self.measure(&word) > 1 {
            word.truncate(word.len() - 1);
        }

        word
    }

    /// Cleanup after step 1b: restore "e" for "at"/"bl"/"iz" stems,
    /// undouble trailing consonants, and add "e" to short CVC stems.
    fn post_process_1b(&self, mut word: String) -> String {
        if word.ends_with("at") || word.ends_with("bl") || word.ends_with("iz") {
            word.push('e');
        } else if self.double_consonant(&word)
            && !word.ends_with('l')
            && !word.ends_with('s')
            && !word.ends_with('z')
        {
            word.truncate(word.len() - 1);
        } else if self.measure(&word) == 1 && self.cvc(&word) {
            word.push('e');
        }
        word
    }

    /// Porter's "measure" m: the number of vowel-to-consonant
    /// transitions, i.e. the count of VC sequences in the word.
    fn measure(&self, word: &str) -> usize {
        let chars: Vec<char> = word.chars().collect();
        let mut m = 0;
        let mut prev_was_vowel = false;

        for (i, &ch) in chars.iter().enumerate() {
            let is_vowel = self.is_vowel(ch, i, &chars);
            if !is_vowel && prev_was_vowel {
                m += 1;
            }
            prev_was_vowel = is_vowel;
        }
        m
    }

    fn contains_vowel(&self, word: &str) -> bool {
        let chars: Vec<char> = word.chars().collect();
        chars
            .iter()
            .enumerate()
            .any(|(i, &ch)| self.is_vowel(ch, i, &chars))
    }

    /// "y" counts as a vowel only when it follows a consonant.
    #[allow(clippy::only_used_in_recursion)]
    fn is_vowel(&self, ch: char, pos: usize, chars: &[char]) -> bool {
        match ch {
            'a' | 'e' | 'i' | 'o' | 'u' => true,
            'y' => pos > 0 && !self.is_vowel(chars[pos - 1], pos - 1, chars),
            _ => false,
        }
    }

    /// True when the word ends consonant-vowel-consonant and the final
    /// consonant is not 'w', 'x', or 'y' (Porter's *o condition).
    fn cvc(&self, word: &str) -> bool {
        let chars: Vec<char> = word.chars().collect();
        if chars.len() < 3 {
            return false;
        }

        let len = chars.len();
        !self.is_vowel(chars[len - 3], len - 3, &chars)
            && self.is_vowel(chars[len - 2], len - 2, &chars)
            && !self.is_vowel(chars[len - 1], len - 1, &chars)
            && chars[len - 1] != 'w'
            && chars[len - 1] != 'x'
            && chars[len - 1] != 'y'
    }

    fn double_consonant(&self, word: &str) -> bool {
        let chars: Vec<char> = word.chars().collect();
        if chars.len() < 2 {
            return false;
        }

        let len = chars.len();
        chars[len - 1] == chars[len - 2] && !self.is_vowel(chars[len - 1], len - 1, &chars)
    }
}

/// Vector postprocessing applied after embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostprocessingPipeline {
    pub dimensionality_reduction: Option<DimensionalityReduction>,
    pub normalization: VectorNormalization,
    pub outlier_detection: Option<OutlierDetection>,
    pub quality_scoring: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DimensionalityReduction {
    PCA { target_dims: usize },
    RandomProjection { target_dims: usize },
    AutoEncoder { target_dims: usize },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VectorNormalization {
    None,
    L2,
    L1,
    MinMax,
    ZScore,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutlierDetection {
    pub method: OutlierMethod,
    /// Threshold above which a component marks the vector as an outlier.
    pub threshold: f32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OutlierMethod {
    ZScore,
    IsolationForest,
    LocalOutlierFactor,
}

impl Default for PostprocessingPipeline {
    fn default() -> Self {
        Self {
            dimensionality_reduction: None,
            normalization: VectorNormalization::L2,
            outlier_detection: None,
            quality_scoring: true,
        }
    }
}

impl PostprocessingPipeline {
    /// Applies reduction, normalization, and scoring in place, returning
    /// a quality score in [0, 1]; outliers have their score halved.
    pub fn process(&self, vector: &mut Vector) -> Result<f32> {
        if let Some(ref dr) = self.dimensionality_reduction {
            self.apply_dimensionality_reduction(vector, dr)?;
        }

        self.apply_normalization(vector)?;

        let quality_score = if self.quality_scoring {
            self.calculate_quality_score(vector)
        } else {
            1.0
        };

        if let Some(ref od) = self.outlier_detection {
            if self.is_outlier(vector, od) {
                return Ok(quality_score * 0.5);
            }
        }

        Ok(quality_score)
    }

    fn apply_dimensionality_reduction(
        &self,
        vector: &mut Vector,
        method: &DimensionalityReduction,
    ) -> Result<()> {
        match method {
            DimensionalityReduction::PCA { target_dims } => {
                let values = vector.as_f32();
                if values.len() <= *target_dims {
                    return Ok(());
                }

                // Simplified stand-in for PCA: keep the leading
                // components rather than computing principal axes.
                let reduced: Vec<f32> = values.into_iter().take(*target_dims).collect();
                vector.values = VectorData::F32(reduced);
                vector.dimensions = *target_dims;
            }
            DimensionalityReduction::RandomProjection { target_dims } => {
                let values = vector.as_f32();
                if values.len() <= *target_dims {
                    return Ok(());
                }

                // Project onto random directions with a fixed seed so
                // the projection is deterministic across runs.
                use scirs2_core::random::Random;
                let mut rng = Random::seed(42);
                let mut projected = vec![0.0; *target_dims];

                for projected_val in projected.iter_mut() {
                    for &val in values.iter() {
                        let random_weight: f32 = rng.gen_range(-1.0..1.0);
                        *projected_val += val * random_weight;
                    }
                    *projected_val /= (values.len() as f32).sqrt();
                }

                vector.values = VectorData::F32(projected);
                vector.dimensions = *target_dims;
            }
            DimensionalityReduction::AutoEncoder { .. } => {
                // Not implemented; the vector passes through unchanged.
            }
        }
        Ok(())
    }

    fn apply_normalization(&self, vector: &mut Vector) -> Result<()> {
        match self.normalization {
            VectorNormalization::None => Ok(()),
            VectorNormalization::L2 => {
                vector.normalize();
                Ok(())
            }
            VectorNormalization::L1 => {
                let values = vector.as_f32();
                let l1_norm: f32 = values.iter().map(|x| x.abs()).sum();
                if l1_norm > 0.0 {
                    let normalized: Vec<f32> = values.into_iter().map(|x| x / l1_norm).collect();
                    vector.values = VectorData::F32(normalized);
                }
                Ok(())
            }
            VectorNormalization::MinMax => {
                let values = vector.as_f32();
                let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
                let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let range = max - min;

                if range > 0.0 {
                    let normalized: Vec<f32> =
                        values.into_iter().map(|x| (x - min) / range).collect();
                    vector.values = VectorData::F32(normalized);
                }
                Ok(())
            }
            VectorNormalization::ZScore => {
                let values = vector.as_f32();
                let n = values.len() as f32;
                let mean: f32 = values.iter().sum::<f32>() / n;
                let variance: f32 = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
                let std_dev = variance.sqrt();

                if std_dev > 0.0 {
                    let normalized: Vec<f32> =
                        values.into_iter().map(|x| (x - mean) / std_dev).collect();
                    vector.values = VectorData::F32(normalized);
                }
                Ok(())
            }
        }
    }

    /// Heuristic quality score in [0, 1]: non-finite values score 0,
    /// while high sparsity, low variance, and small magnitude each
    /// apply a multiplicative penalty.
    fn calculate_quality_score(&self, vector: &Vector) -> f32 {
        let values = vector.as_f32();

        let mut score = 1.0;

        if values.iter().any(|x| !x.is_finite()) {
            return 0.0;
        }

        let zero_count = values.iter().filter(|&&x| x.abs() < f32::EPSILON).count();
        let sparsity = zero_count as f32 / values.len() as f32;
        if sparsity > 0.9 {
            score *= 0.5;
        }

        let mean = values.iter().sum::<f32>() / values.len() as f32;
        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;

        if variance < 0.01 {
            score *= 0.7;
        }

        let magnitude = vector.magnitude();
        if magnitude < 0.1 {
            score *= 0.8;
        }

        score
    }

    fn is_outlier(&self, vector: &Vector, config: &OutlierDetection) -> bool {
        match config.method {
            OutlierMethod::ZScore => {
                let values = vector.as_f32();
                let mean = values.iter().sum::<f32>() / values.len() as f32;
                let variance =
                    values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
                let std_dev = variance.sqrt();

                // Guard against a zero standard deviation so constant
                // vectors are never flagged via NaN z-scores.
                std_dev > 0.0
                    && values
                        .iter()
                        .any(|&x| ((x - mean) / std_dev).abs() > config.threshold)
            }
            // IsolationForest and LocalOutlierFactor are not implemented.
            _ => false,
        }
    }
}

/// Combined pre- and postprocessing around embedding generation.
#[derive(Debug, Clone, Default)]
pub struct EmbeddingPipeline {
    pub preprocessing: PreprocessingPipeline,
    pub postprocessing: PostprocessingPipeline,
}

impl EmbeddingPipeline {
    /// Preprocesses content into tokens. The returned quality score is
    /// currently a placeholder (always 1.0).
    pub fn process_content(&self, content: &EmbeddableContent) -> Result<(Vec<String>, f32)> {
        let text = content.to_text();

        let tokens = self.preprocessing.process(&text);

        Ok((tokens, 1.0))
    }

    /// Postprocesses a generated vector, returning its quality score.
    pub fn process_vector(&self, vector: &mut Vector) -> Result<f32> {
        self.postprocessing.process(vector)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_preprocessing_pipeline() {
        let pipeline = PreprocessingPipeline::default();

        let text = "The quick brown fox jumps over the lazy dog!";
        let tokens = pipeline.process(text);

        assert!(!tokens.contains(&"the".to_string()));
        assert!(tokens.contains(&"quick".to_string()));
        assert!(tokens.contains(&"brown".to_string()));
    }
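
    // A hedged sketch of the stemmer's behavior: the expected outputs
    // below were traced against this simplified implementation, not the
    // full reference Porter algorithm.
    #[test]
    fn test_porter_stemming() {
        let pipeline = PreprocessingPipeline::default();

        assert_eq!(pipeline.porter_stem("running"), "run");
        assert_eq!(pipeline.porter_stem("jumping"), "jump");
        assert_eq!(pipeline.porter_stem("cats"), "cat");
    }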

    #[test]
    fn test_camel_case_splitting() {
        let mut pipeline = PreprocessingPipeline::default();
        pipeline.tokenizer.split_camel_case = true;

        let text = "CamelCaseWord HTTPSConnection";
        let tokens = pipeline.process(text);

        assert!(tokens.contains(&"camel".to_string()));
        assert!(tokens.contains(&"case".to_string()));
        assert!(tokens.contains(&"word".to_string()));
    }
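
    // A small sketch of the accent-folding table: accented Latin
    // characters are mapped to their ASCII base form before filtering.
    #[test]
    fn test_accent_removal() {
        let pipeline = PreprocessingPipeline::default();

        let tokens = pipeline.process("café naïve");

        assert!(tokens.contains(&"cafe".to_string()));
        assert!(tokens.contains(&"naive".to_string()));
    }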

    #[test]
    fn test_postprocessing_normalization() {
        let pipeline = PostprocessingPipeline {
            normalization: VectorNormalization::L2,
            ..Default::default()
        };

        let mut vector = Vector::new(vec![3.0, 4.0, 0.0]);
        let quality = pipeline.process(&mut vector).unwrap();

        let magnitude = vector.magnitude();
        assert!((magnitude - 1.0).abs() < 1e-6);
        assert!(quality > 0.0);
    }
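
    // A minimal sketch of min-max scaling: values are mapped onto
    // [0, 1] using the vector's own min and max.
    #[test]
    fn test_minmax_normalization() {
        let pipeline = PostprocessingPipeline {
            normalization: VectorNormalization::MinMax,
            ..Default::default()
        };

        let mut vector = Vector::new(vec![1.0, 2.0, 3.0]);
        pipeline.process(&mut vector).unwrap();

        let values = vector.as_f32();
        assert!((values[0] - 0.0).abs() < 1e-6);
        assert!((values[1] - 0.5).abs() < 1e-6);
        assert!((values[2] - 1.0).abs() < 1e-6);
    }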

    #[test]
    fn test_quality_scoring() {
        let pipeline = PostprocessingPipeline::default();

        let mut good_vector = Vector::new(vec![0.5, 0.3, -0.2, 0.8]);
        let good_quality = pipeline.process(&mut good_vector).unwrap();
        assert!(good_quality > 0.9);

        let poor_vector = Vector::new(vec![0.0, 0.0, 0.0, 0.0]);
        let poor_quality = pipeline.calculate_quality_score(&poor_vector);
        assert!(poor_quality < 0.5);
    }
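
    // Sketches for the remaining postprocessing paths. The projected
    // values depend on the seeded scirs2 RNG, so only the shape is
    // asserted; the outlier case was traced by hand against this
    // implementation (z of the 10.0 component is about 1.73).
    #[test]
    fn test_random_projection_reduction() {
        let pipeline = PostprocessingPipeline {
            dimensionality_reduction: Some(DimensionalityReduction::RandomProjection {
                target_dims: 2,
            }),
            ..Default::default()
        };

        let mut vector = Vector::new(vec![1.0, 2.0, 3.0, 4.0]);
        pipeline.process(&mut vector).unwrap();

        assert_eq!(vector.dimensions, 2);
        assert_eq!(vector.as_f32().len(), 2);
    }

    #[test]
    fn test_zscore_outlier_halves_quality() {
        let pipeline = PostprocessingPipeline {
            normalization: VectorNormalization::None,
            outlier_detection: Some(OutlierDetection {
                method: OutlierMethod::ZScore,
                threshold: 1.5,
            }),
            ..Default::default()
        };

        // One component far from the rest trips the z-score check, so
        // the quality score (1.0 here) is halved.
        let mut vector = Vector::new(vec![0.0, 0.0, 0.0, 10.0]);
        let quality = pipeline.process(&mut vector).unwrap();
        assert!((quality - 0.5).abs() < 1e-6);
    }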
}