// oxirs_vec/embedding_pipeline.rs

use crate::{embeddings::EmbeddableContent, Vector, VectorData};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Text preprocessing pipeline
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreprocessingPipeline {
    /// Tokenization settings
    pub tokenizer: TokenizerConfig,
    /// Normalization settings
    pub normalization: NormalizationConfig,
    /// Stop words to remove
    pub stop_words: HashSet<String>,
    /// Entity recognition settings
    pub entity_recognition: Option<EntityRecognitionConfig>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
    pub lowercase: bool,
    pub remove_punctuation: bool,
    pub min_token_length: usize,
    pub max_token_length: usize,
    pub split_camel_case: bool,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            lowercase: true,
            remove_punctuation: true,
            min_token_length: 2,
            max_token_length: 50,
            split_camel_case: true,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationConfig {
    pub unicode_normalization: bool,
    pub accent_removal: bool,
    pub stemming: bool,
    pub lemmatization: bool,
}

impl Default for NormalizationConfig {
    fn default() -> Self {
        Self {
            unicode_normalization: true,
            accent_removal: true,
            stemming: false,
            lemmatization: false,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityRecognitionConfig {
    pub recognize_uris: bool,
    pub recognize_dates: bool,
    pub recognize_numbers: bool,
    pub entity_linking: bool,
}

impl Default for PreprocessingPipeline {
    fn default() -> Self {
        // Common English stop words
        let mut stop_words = HashSet::new();
        for word in &[
            "the", "is", "at", "which", "on", "a", "an", "and", "or", "but", "in", "with", "to",
            "for", "of", "as", "by", "that", "this", "it", "from", "be", "are", "was", "were",
            "been",
        ] {
            stop_words.insert(word.to_string());
        }

        Self {
            tokenizer: TokenizerConfig::default(),
            normalization: NormalizationConfig::default(),
            stop_words,
            entity_recognition: None,
        }
    }
}

impl PreprocessingPipeline {
    /// Process text through the preprocessing pipeline
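    ///
    /// A minimal sketch of the expected behavior under the default
    /// configuration (marked `ignore` because the exact crate paths are
    /// assumed here):
    ///
    /// ```ignore
    /// let pipeline = PreprocessingPipeline::default();
    /// let tokens = pipeline.process("The QuickBrown fox!");
    /// assert_eq!(tokens, vec!["quick", "brown", "fox"]);
    /// ```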
    pub fn process(&self, text: &str) -> Vec<String> {
        let mut tokens = self.tokenize(text);

        if self.normalization.unicode_normalization {
            tokens = self.normalize_unicode(tokens);
        }

        if self.normalization.accent_removal {
            tokens = self.remove_accents(tokens);
        }

        tokens = self.filter_tokens(tokens);

        if self.normalization.stemming {
            tokens = self.stem_tokens(tokens);
        }

        tokens
    }

    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut processed = text.to_string();

        if self.tokenizer.remove_punctuation {
            processed = processed
                .chars()
                .map(|c| {
                    if c.is_alphanumeric() || c.is_whitespace() {
                        c
                    } else {
                        ' '
                    }
                })
                .collect();
        }

        // Split on whitespace and filter
        let mut tokens: Vec<String> = processed
            .split_whitespace()
            .map(|s| s.to_string())
            .collect();

        // Split camelCase if enabled (must happen before lowercasing)
        if self.tokenizer.split_camel_case {
            tokens = tokens
                .into_iter()
                .flat_map(|token| self.split_camel_case(&token))
                .collect();
        }

        // Lowercase after camel case splitting
        if self.tokenizer.lowercase {
            tokens = tokens.into_iter().map(|s| s.to_lowercase()).collect();
        }

        tokens
    }

    /// Splits on uppercase boundaries; note that all-caps acronyms like
    /// "HTTPS" are split into single letters, which the length filter
    /// later discards.
    fn split_camel_case(&self, word: &str) -> Vec<String> {
        let mut result = Vec::new();
        let mut current = String::new();

        for (i, ch) in word.chars().enumerate() {
            if i > 0 && ch.is_uppercase() && !current.is_empty() {
                result.push(current.clone());
                current.clear();
            }
            current.push(ch);
        }

        if !current.is_empty() {
            result.push(current);
        }

        if result.is_empty() {
            vec![word.to_string()]
        } else {
            result
        }
    }

    fn normalize_unicode(&self, tokens: Vec<String>) -> Vec<String> {
        // Placeholder: returns tokens unchanged. In production, apply NFC/NFKC
        // normalization via the unicode-normalization crate.
        tokens
    }

    fn remove_accents(&self, tokens: Vec<String>) -> Vec<String> {
        tokens
            .into_iter()
            .map(|token| {
                token
                    .chars()
                    .map(|c| match c {
                        'à' | 'á' | 'â' | 'ã' | 'ä' | 'å' => 'a',
                        'è' | 'é' | 'ê' | 'ë' => 'e',
                        'ì' | 'í' | 'î' | 'ï' => 'i',
                        'ò' | 'ó' | 'ô' | 'õ' | 'ö' => 'o',
                        'ù' | 'ú' | 'û' | 'ü' => 'u',
                        'ñ' => 'n',
                        'ç' => 'c',
                        _ => c,
                    })
                    .collect()
            })
            .collect()
    }

    fn filter_tokens(&self, tokens: Vec<String>) -> Vec<String> {
        tokens
            .into_iter()
            .filter(|token| {
                token.len() >= self.tokenizer.min_token_length
                    && token.len() <= self.tokenizer.max_token_length
                    && !self.stop_words.contains(token)
            })
            .collect()
    }

    fn stem_tokens(&self, tokens: Vec<String>) -> Vec<String> {
        // Simplified Porter stemmer; see porter_stem below
        tokens
            .into_iter()
            .map(|token| self.porter_stem(&token))
            .collect()
    }

    /// Simplified implementation of the Porter stemming algorithm
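    ///
    /// For example, under the rule steps as implemented here, "relational"
    /// stems to "relat" and "ponies" to "poni".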
    fn porter_stem(&self, word: &str) -> String {
        let word = word.to_lowercase();
        if word.len() <= 2 {
            return word;
        }

        let mut stem = word.clone();

        // Step 1a: plurals and past participles
        stem = self.stem_step_1a(stem);

        // Step 1b: past tense and gerunds
        stem = self.stem_step_1b(stem);

        // Step 2: derivational suffixes
        stem = self.stem_step_2(stem);

        // Step 3: more derivational suffixes
        stem = self.stem_step_3(stem);

        // Step 4: remove derivational suffixes
        stem = self.stem_step_4(stem);

        // Step 5: remove final e and double l
        stem = self.stem_step_5(stem);

        stem
    }

    fn stem_step_1a(&self, mut word: String) -> String {
        if word.ends_with("sses") {
            word.truncate(word.len() - 2); // sses -> ss
        } else if word.ends_with("ies") {
            word.truncate(word.len() - 2); // ies -> i
        } else if word.ends_with("ss") {
            // ss -> ss (no change)
        } else if word.ends_with("s") && word.len() > 1 {
            word.truncate(word.len() - 1); // s -> (empty)
        }
        word
    }

    fn stem_step_1b(&self, mut word: String) -> String {
        if word.ends_with("eed") {
            if self.measure(&word[..word.len() - 3]) > 0 {
                word.truncate(word.len() - 1); // eed -> ee
            }
        } else if word.ends_with("ed") && self.contains_vowel(&word[..word.len() - 2]) {
            word.truncate(word.len() - 2);
            word = self.post_process_1b(word);
        } else if word.ends_with("ing") && self.contains_vowel(&word[..word.len() - 3]) {
            word.truncate(word.len() - 3);
            word = self.post_process_1b(word);
        }
        word
    }

    fn stem_step_2(&self, mut word: String) -> String {
        let suffixes = [
            ("ational", "ate"),
            ("tional", "tion"),
            ("enci", "ence"),
            ("anci", "ance"),
            ("izer", "ize"),
            ("abli", "able"),
            ("alli", "al"),
            ("entli", "ent"),
            ("eli", "e"),
            ("ousli", "ous"),
            ("ization", "ize"),
            ("ation", "ate"),
            ("ator", "ate"),
            ("alism", "al"),
            ("iveness", "ive"),
            ("fulness", "ful"),
            ("ousness", "ous"),
            ("aliti", "al"),
            ("iviti", "ive"),
            ("biliti", "ble"),
        ];

        for (suffix, replacement) in &suffixes {
            if word.ends_with(suffix) {
                let stem = &word[..word.len() - suffix.len()];
                if self.measure(stem) > 0 {
                    word = format!("{stem}{replacement}");
                }
                break;
            }
        }
        word
    }

    fn stem_step_3(&self, mut word: String) -> String {
        let suffixes = [
            ("icate", "ic"),
            ("ative", ""),
            ("alize", "al"),
            ("iciti", "ic"),
            ("ical", "ic"),
            ("ful", ""),
            ("ness", ""),
        ];

        for (suffix, replacement) in &suffixes {
            if word.ends_with(suffix) {
                let stem = &word[..word.len() - suffix.len()];
                if self.measure(stem) > 0 {
                    word = format!("{stem}{replacement}");
                }
                break;
            }
        }
        word
    }

    fn stem_step_4(&self, mut word: String) -> String {
        let suffixes = [
            "al", "ance", "ence", "er", "ic", "able", "ible", "ant", "ement", "ment", "ent", "ion",
            "ou", "ism", "ate", "iti", "ous", "ive", "ize",
        ];

        for suffix in &suffixes {
            if word.ends_with(suffix) {
                let stem = &word[..word.len() - suffix.len()];
                if self.measure(stem) > 1
                    && (*suffix != "ion" || (stem.ends_with("s") || stem.ends_with("t")))
                {
                    word = stem.to_string();
                }
                break;
            }
        }
        word
    }

    fn stem_step_5(&self, mut word: String) -> String {
        if word.ends_with("e") {
            let stem = &word[..word.len() - 1];
            let m = self.measure(stem);
            if m > 1 || (m == 1 && !self.cvc(stem)) {
                word.truncate(word.len() - 1);
            }
        }

        if word.ends_with("ll") && self.measure(&word) > 1 {
            word.truncate(word.len() - 1);
        }

        word
    }

    fn post_process_1b(&self, mut word: String) -> String {
        if word.ends_with("at") || word.ends_with("bl") || word.ends_with("iz") {
            word.push('e');
        } else if self.double_consonant(&word)
            && !word.ends_with("l")
            && !word.ends_with("s")
            && !word.ends_with("z")
        {
            word.truncate(word.len() - 1);
        } else if self.measure(&word) == 1 && self.cvc(&word) {
            word.push('e');
        }
        word
    }

    fn measure(&self, word: &str) -> usize {
        let chars: Vec<char> = word.chars().collect();
        let mut m = 0;
        let mut prev_was_vowel = false;

        for (i, &ch) in chars.iter().enumerate() {
            let is_vowel = self.is_vowel(ch, i, &chars);
            if !is_vowel && prev_was_vowel {
                m += 1;
            }
            prev_was_vowel = is_vowel;
        }
        m
    }

    fn contains_vowel(&self, word: &str) -> bool {
        let chars: Vec<char> = word.chars().collect();
        chars
            .iter()
            .enumerate()
            .any(|(i, &ch)| self.is_vowel(ch, i, &chars))
    }

    #[allow(clippy::only_used_in_recursion)]
    fn is_vowel(&self, ch: char, pos: usize, chars: &[char]) -> bool {
        match ch {
            'a' | 'e' | 'i' | 'o' | 'u' => true,
            'y' => pos > 0 && !self.is_vowel(chars[pos - 1], pos - 1, chars),
            _ => false,
        }
    }

    fn cvc(&self, word: &str) -> bool {
        let chars: Vec<char> = word.chars().collect();
        if chars.len() < 3 {
            return false;
        }

        let len = chars.len();
        !self.is_vowel(chars[len - 3], len - 3, &chars)
            && self.is_vowel(chars[len - 2], len - 2, &chars)
            && !self.is_vowel(chars[len - 1], len - 1, &chars)
            && chars[len - 1] != 'w'
            && chars[len - 1] != 'x'
            && chars[len - 1] != 'y'
    }

    fn double_consonant(&self, word: &str) -> bool {
        let chars: Vec<char> = word.chars().collect();
        if chars.len() < 2 {
            return false;
        }

        let len = chars.len();
        chars[len - 1] == chars[len - 2] && !self.is_vowel(chars[len - 1], len - 1, &chars)
    }
}

/// Vector postprocessing pipeline
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostprocessingPipeline {
    pub dimensionality_reduction: Option<DimensionalityReduction>,
    pub normalization: VectorNormalization,
    pub outlier_detection: Option<OutlierDetection>,
    pub quality_scoring: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DimensionalityReduction {
    PCA { target_dims: usize },
    RandomProjection { target_dims: usize },
    AutoEncoder { target_dims: usize },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VectorNormalization {
    None,
    L2,
    L1,
    MinMax,
    ZScore,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutlierDetection {
    pub method: OutlierMethod,
    pub threshold: f32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OutlierMethod {
    ZScore,
    IsolationForest,
    LocalOutlierFactor,
}

impl Default for PostprocessingPipeline {
    fn default() -> Self {
        Self {
            dimensionality_reduction: None,
            normalization: VectorNormalization::L2,
            outlier_detection: None,
            quality_scoring: true,
        }
    }
}

impl PostprocessingPipeline {
    /// Process vector through the postprocessing pipeline
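    ///
    /// Returns a quality score in `[0.0, 1.0]`; when outlier detection is
    /// enabled and the vector is flagged, the score is halved. A minimal
    /// sketch (marked `ignore` because the crate paths are assumed here):
    ///
    /// ```ignore
    /// let pipeline = PostprocessingPipeline::default();
    /// let mut v = Vector::new(vec![3.0, 4.0, 0.0]);
    /// let quality = pipeline.process(&mut v)?; // v is now L2-normalized
    /// ```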
    pub fn process(&self, vector: &mut Vector) -> Result<f32> {
        // Apply dimensionality reduction if configured
        if let Some(ref dr) = self.dimensionality_reduction {
            self.apply_dimensionality_reduction(vector, dr)?;
        }

        // Apply normalization
        self.apply_normalization(vector)?;

        // Calculate quality score
        let quality_score = if self.quality_scoring {
            self.calculate_quality_score(vector)
        } else {
            1.0
        };

        // Check for outliers
        if let Some(ref od) = self.outlier_detection {
            if self.is_outlier(vector, od) {
                return Ok(quality_score * 0.5); // Reduce quality score for outliers
            }
        }

        Ok(quality_score)
    }

    fn apply_dimensionality_reduction(
        &self,
        vector: &mut Vector,
        method: &DimensionalityReduction,
    ) -> Result<()> {
        match method {
            DimensionalityReduction::PCA { target_dims } => {
                // Simplified placeholder: keeps the first target_dims
                // components by truncation; a real PCA would project onto
                // the top principal components
                let values = vector.as_f32();
                if values.len() <= *target_dims {
                    return Ok(());
                }

                let reduced: Vec<f32> = values.into_iter().take(*target_dims).collect();
                vector.values = VectorData::F32(reduced);
                vector.dimensions = *target_dims;
            }
            DimensionalityReduction::RandomProjection { target_dims } => {
                // Random projection implementation
                let values = vector.as_f32();
                if values.len() <= *target_dims {
                    return Ok(());
                }

                // Generate random projection weights on the fly (simplified;
                // a full implementation would materialize a projection matrix)
                use scirs2_core::random::Random;
                let mut rng = Random::seed(42);
                let mut projected = vec![0.0; *target_dims];

                for projected_val in projected.iter_mut() {
                    for &val in values.iter() {
                        let random_weight: f32 = rng.gen_range(-1.0..1.0);
                        *projected_val += val * random_weight;
                    }
                    *projected_val /= (values.len() as f32).sqrt();
                }

                vector.values = VectorData::F32(projected);
                vector.dimensions = *target_dims;
            }
            DimensionalityReduction::AutoEncoder { .. } => {
                // AutoEncoder reduction would require a trained neural
                // network; left as a placeholder
            }
        }
        Ok(())
    }

    fn apply_normalization(&self, vector: &mut Vector) -> Result<()> {
        match self.normalization {
            VectorNormalization::None => Ok(()),
            VectorNormalization::L2 => {
                vector.normalize();
                Ok(())
            }
            VectorNormalization::L1 => {
                let values = vector.as_f32();
                let l1_norm: f32 = values.iter().map(|x| x.abs()).sum();
                if l1_norm > 0.0 {
                    let normalized: Vec<f32> = values.into_iter().map(|x| x / l1_norm).collect();
                    vector.values = VectorData::F32(normalized);
                }
                Ok(())
            }
            VectorNormalization::MinMax => {
                let values = vector.as_f32();
                let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
                let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let range = max - min;

                if range > 0.0 {
                    let normalized: Vec<f32> =
                        values.into_iter().map(|x| (x - min) / range).collect();
                    vector.values = VectorData::F32(normalized);
                }
                Ok(())
            }
            VectorNormalization::ZScore => {
                let values = vector.as_f32();
                let n = values.len() as f32;
                let mean: f32 = values.iter().sum::<f32>() / n;
                let variance: f32 = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
                let std_dev = variance.sqrt();

                if std_dev > 0.0 {
                    let normalized: Vec<f32> =
                        values.into_iter().map(|x| (x - mean) / std_dev).collect();
                    vector.values = VectorData::F32(normalized);
                }
                Ok(())
            }
        }
    }

    fn calculate_quality_score(&self, vector: &Vector) -> f32 {
        let values = vector.as_f32();

        // Quality based on several factors
        let mut score = 1.0;

        // Check for NaN or infinite values
        if values.iter().any(|x| !x.is_finite()) {
            return 0.0;
        }

        // Check sparsity (too many zeros might indicate poor quality)
        let zero_count = values.iter().filter(|&&x| x.abs() < f32::EPSILON).count();
        let sparsity = zero_count as f32 / values.len() as f32;
        if sparsity > 0.9 {
            score *= 0.5;
        }

        // Check variance (too low variance might indicate poor quality)
        let mean = values.iter().sum::<f32>() / values.len() as f32;
        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;

        if variance < 0.01 {
            score *= 0.7;
        }

        // Check magnitude (vectors that are too small might be problematic)
        let magnitude = vector.magnitude();
        if magnitude < 0.1 {
            score *= 0.8;
        }

        score
    }

    fn is_outlier(&self, vector: &Vector, config: &OutlierDetection) -> bool {
        match config.method {
            OutlierMethod::ZScore => {
                let values = vector.as_f32();
                let mean = values.iter().sum::<f32>() / values.len() as f32;
                let variance =
                    values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
                let std_dev = variance.sqrt();

                // Guard against division by zero for constant vectors
                if std_dev <= f32::EPSILON {
                    return false;
                }

                // Check if any dimension is beyond threshold standard deviations
                values
                    .iter()
                    .any(|&x| ((x - mean) / std_dev).abs() > config.threshold)
            }
            _ => false, // Other methods would require more complex implementations
        }
    }
}

/// Complete embedding pipeline combining preprocessing and postprocessing
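///
/// A minimal usage sketch (marked `ignore` because the `content` binding
/// and crate paths are assumed here):
///
/// ```ignore
/// let pipeline = EmbeddingPipeline::default();
/// let (tokens, _quality) = pipeline.process_content(&content)?;
/// ```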
#[derive(Debug, Clone, Default)]
pub struct EmbeddingPipeline {
    pub preprocessing: PreprocessingPipeline,
    pub postprocessing: PostprocessingPipeline,
}

impl EmbeddingPipeline {
    /// Process content through the complete pipeline
    pub fn process_content(&self, content: &EmbeddableContent) -> Result<(Vec<String>, f32)> {
        // Extract text from content
        let text = content.to_text();

        // Apply preprocessing
        let tokens = self.preprocessing.process(&text);

        // Return tokens and a placeholder quality score
        // In a real implementation, this would generate embeddings and apply postprocessing
        Ok((tokens, 1.0))
    }

    /// Process a vector through postprocessing
    pub fn process_vector(&self, vector: &mut Vector) -> Result<f32> {
        self.postprocessing.process(vector)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_preprocessing_pipeline() {
        let pipeline = PreprocessingPipeline::default();

        let text = "The quick brown fox jumps over the lazy dog!";
        let tokens = pipeline.process(text);

        // Should remove stop words and punctuation
        assert!(!tokens.contains(&"the".to_string()));
        assert!(tokens.contains(&"quick".to_string()));
        assert!(tokens.contains(&"brown".to_string()));
    }

    #[test]
    fn test_camel_case_splitting() {
        let mut pipeline = PreprocessingPipeline::default();
        pipeline.tokenizer.split_camel_case = true;

        let text = "CamelCaseWord HTTPSConnection";
        let tokens = pipeline.process(text);

        assert!(tokens.contains(&"camel".to_string()));
        assert!(tokens.contains(&"case".to_string()));
        assert!(tokens.contains(&"word".to_string()));
        // "HTTPS" splits into single letters, which the min-length filter
        // drops; only "connection" survives
        assert!(tokens.contains(&"connection".to_string()));
        assert!(!tokens.contains(&"h".to_string()));
    }

    #[test]
    fn test_postprocessing_normalization() {
        let pipeline = PostprocessingPipeline {
            normalization: VectorNormalization::L2,
            ..Default::default()
        };

        let mut vector = Vector::new(vec![3.0, 4.0, 0.0]);
        let quality = pipeline.process(&mut vector).unwrap();

        // Check L2 normalization
        let magnitude = vector.magnitude();
        assert!((magnitude - 1.0).abs() < 1e-6);
        assert!(quality > 0.0);
    }
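
    // Additional sanity checks for the other normalization modes; the
    // expected values follow directly from the formulas in apply_normalization
    #[test]
    fn test_l1_and_minmax_normalization() {
        let l1_pipeline = PostprocessingPipeline {
            normalization: VectorNormalization::L1,
            ..Default::default()
        };
        let mut v = Vector::new(vec![1.0, -2.0, 3.0]);
        l1_pipeline.process(&mut v).unwrap();
        // After L1 normalization the absolute values sum to 1
        let l1_sum: f32 = v.as_f32().iter().map(|x| x.abs()).sum();
        assert!((l1_sum - 1.0).abs() < 1e-6);

        let minmax_pipeline = PostprocessingPipeline {
            normalization: VectorNormalization::MinMax,
            ..Default::default()
        };
        let mut v = Vector::new(vec![1.0, 2.0, 3.0]);
        minmax_pipeline.process(&mut v).unwrap();
        // MinMax maps the smallest value to 0 and the largest to 1
        assert_eq!(v.as_f32(), vec![0.0, 0.5, 1.0]);
    }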

    #[test]
    fn test_quality_scoring() {
        let pipeline = PostprocessingPipeline::default();

        // Good quality vector
        let mut good_vector = Vector::new(vec![0.5, 0.3, -0.2, 0.8]);
        let good_quality = pipeline.process(&mut good_vector).unwrap();
        assert!(good_quality > 0.9);

        // Poor quality vector (all zeros)
        let poor_vector = Vector::new(vec![0.0, 0.0, 0.0, 0.0]);
        let poor_quality = pipeline.calculate_quality_score(&poor_vector);
        assert!(poor_quality < 0.5);
    }
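
    // Illustrative checks for the simplified Porter stemmer; the expected
    // stems were derived by tracing the rule steps implemented above
    #[test]
    fn test_porter_stemming() {
        let pipeline = PreprocessingPipeline::default();

        assert_eq!(pipeline.porter_stem("caresses"), "caress");
        assert_eq!(pipeline.porter_stem("ponies"), "poni");
        assert_eq!(pipeline.porter_stem("running"), "run");
        assert_eq!(pipeline.porter_stem("relational"), "relat");
    }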
}