oxirs_embed/models/transformer/preprocessing.rs

1//! Text preprocessing for transformer models
2
3use super::types::{DomainPreprocessingRules, TransformerConfig, TransformerType};
4use anyhow::Result;
5use regex::Regex;
6use std::collections::HashMap;
7
8/// Text preprocessor for transformer models
9#[derive(Debug, Clone)]
10pub struct TransformerPreprocessor {
11    config: TransformerConfig,
12    domain_rules: Option<DomainPreprocessingRules>,
13}
14
15impl TransformerPreprocessor {
16    pub fn new(config: TransformerConfig) -> Self {
17        let domain_rules = match config.transformer_type {
18            TransformerType::SciBERT => Some(DomainPreprocessingRules::scientific()),
19            TransformerType::BioBERT => Some(DomainPreprocessingRules::biomedical()),
20            TransformerType::LegalBERT => Some(DomainPreprocessingRules::legal()),
21            TransformerType::NewsBERT => Some(DomainPreprocessingRules::news()),
22            TransformerType::SocialMediaBERT => Some(DomainPreprocessingRules::social_media()),
23            _ => None,
24        };
25
26        Self {
27            config,
28            domain_rules,
29        }
30    }
31
32    /// Main preprocessing function that routes to appropriate domain-specific methods
33    pub fn preprocess_text(&self, text: &str) -> String {
34        let mut processed = text.to_string();
35
36        // Common preprocessing steps
37        processed = self.clean_uri(&processed);
38        processed = self.normalize_whitespace(&processed);
39
40        // Apply domain-specific preprocessing
41        processed = match self.config.transformer_type {
42            TransformerType::SciBERT => self.preprocess_scientific_text(&processed),
43            TransformerType::BioBERT => self.preprocess_biomedical_text(&processed),
44            TransformerType::CodeBERT => self.preprocess_code_text(&processed),
45            TransformerType::LegalBERT => self.preprocess_legal_text(&processed),
46            TransformerType::NewsBERT => self.preprocess_news_text(&processed),
47            TransformerType::SocialMediaBERT => self.preprocess_social_media_text(&processed),
48            _ => processed,
49        };
50
51        // Apply general domain rules if available
52        if let Some(ref rules) = self.domain_rules {
53            processed = self.apply_domain_rules(&processed, rules);
54        }
55
56        processed
57    }
58
59    /// Clean URI components for better semantic representation
60    fn clean_uri(&self, text: &str) -> String {
61        let mut result = text.to_string();
62
63        // Remove protocol prefixes
64        result = result.replace("http://", "");
65        result = result.replace("https://", "");
66        result = result.replace("ftp://", "");
67
68        // Replace common URI separators with spaces
69        result = result.replace('/', " ");
70        result = result.replace('#', " ");
71        result = result.replace('?', " ");
72        result = result.replace('&', " and ");
73        result = result.replace('=', " equals ");
74
75        // Handle underscores in URIs (common in ontologies)
76        result = result.replace('_', " ");
77
78        result
79    }
80
81    /// Normalize whitespace
82    fn normalize_whitespace(&self, text: &str) -> String {
83        // Replace multiple whitespace with single space
84        let re = Regex::new(r"\s+").unwrap();
85        re.replace_all(text, " ").trim().to_string()
86    }
87
88    /// Apply domain-specific preprocessing rules
89    fn apply_domain_rules(&self, text: &str, rules: &DomainPreprocessingRules) -> String {
90        let mut result = text.to_string();
91
92        // Apply abbreviation expansions
93        for (abbrev, expansion) in &rules.abbreviation_expansions {
94            result = result.replace(abbrev, expansion);
95        }
96
97        // Apply pattern replacements
98        for (pattern, replacement) in &rules.domain_specific_patterns {
99            if let Ok(re) = Regex::new(pattern) {
100                result = re.replace_all(&result, replacement).to_string();
101            }
102        }
103
104        result
105    }
106
107    /// Preprocessing for scientific text (SciBERT)
108    pub fn preprocess_scientific_text(&self, text: &str) -> String {
109        let mut result = text.to_string();
110
111        // Scientific abbreviations
112        let scientific_abbrevs = HashMap::from([
113            ("DNA", "deoxyribonucleic acid"),
114            ("RNA", "ribonucleic acid"),
115            ("ATP", "adenosine triphosphate"),
116            ("GDP", "guanosine diphosphate"),
117            ("GTP", "guanosine triphosphate"),
118            ("Co2", "carbon dioxide"),
119            ("H2O", "water"),
120            ("NaCl", "sodium chloride"),
121        ]);
122
123        for (abbrev, expansion) in scientific_abbrevs {
124            result = result.replace(abbrev, expansion);
125        }
126
127        // Handle scientific notation and units
128        result = result.replace("°C", " degrees celsius");
129        result = result.replace("mg/ml", " milligrams per milliliter");
130        result = result.replace("μg/ml", " micrograms per milliliter");
131        result = result.replace("mM", " millimolar");
132        result = result.replace("μM", " micromolar");
133
134        // Handle chemical formulas and reactions
135        result = result.replace("->", " produces ");
136        result = result.replace("<->", " is in equilibrium with ");
137
138        result
139    }
140
141    /// Preprocessing for biomedical text (BioBERT)
142    pub fn preprocess_biomedical_text(&self, text: &str) -> String {
143        let mut result = text.to_string();
144
145        // Biomedical gene and protein abbreviations
146        let biomedical_abbrevs = HashMap::from([
147            ("p53", "tumor protein p53"),
148            ("BRCA1", "breast cancer gene 1"),
149            ("BRCA2", "breast cancer gene 2"),
150            ("TNF-α", "tumor necrosis factor alpha"),
151            ("IL-1", "interleukin 1"),
152            ("IL-6", "interleukin 6"),
153            ("mRNA", "messenger ribonucleic acid"),
154            ("tRNA", "transfer ribonucleic acid"),
155            ("rRNA", "ribosomal ribonucleic acid"),
156            ("CNS", "central nervous system"),
157            ("PNS", "peripheral nervous system"),
158        ]);
159
160        for (abbrev, expansion) in biomedical_abbrevs {
161            result = result.replace(abbrev, expansion);
162        }
163
164        // Handle medical terminology
165        result = result.replace("bp", " base pairs");
166        result = result.replace("kDa", " kilodaltons");
167        result = result.replace("mg/kg", " milligrams per kilogram");
168
169        result
170    }
171
172    /// Preprocessing for code text (CodeBERT)
173    pub fn preprocess_code_text(&self, text: &str) -> String {
174        let mut result = text.to_string();
175
176        // Programming language keywords and common terms
177        result = result.replace("impl", "implementation");
178        result = result.replace("func", "function");
179        result = result.replace("var", "variable");
180        result = result.replace("const", "constant");
181        result = result.replace("struct", "structure");
182        result = result.replace("enum", "enumeration");
183
184        // Common type names
185        result = result.replace("Vec<i32>", "vector of integers");
186        result = result.replace("HashMap", "hash map");
187        result = result.replace("String", "string");
188        result = result.replace("bool", "boolean");
189
190        // Expand camelCase and PascalCase
191        result = self.expand_camel_case(&result);
192
193        result
194    }
195
196    /// Preprocessing for legal text (LegalBERT)
197    pub fn preprocess_legal_text(&self, text: &str) -> String {
198        let mut result = text.to_string();
199
200        // Legal abbreviations
201        let legal_abbrevs = HashMap::from([
202            ("USC", "United States Code"),
203            ("CFR", "Code of Federal Regulations"),
204            ("plaintiff", "party bringing lawsuit"),
205            ("defendant", "party being sued"),
206            ("tort", "civil wrong"),
207            ("v.", "versus"),
208            ("vs.", "versus"),
209            ("et al.", "and others"),
210            ("cf.", "compare"),
211            ("ibid.", "in the same place"),
212            ("supra", "above mentioned"),
213        ]);
214
215        for (abbrev, expansion) in legal_abbrevs {
216            result = result.replace(abbrev, expansion);
217        }
218
219        // Handle legal citations
220        let section_re = Regex::new(r"§(\d+)").unwrap();
221        result = section_re.replace_all(&result, "section $1").to_string();
222
223        result
224    }
225
226    /// Preprocessing for news text (NewsBERT)
227    pub fn preprocess_news_text(&self, text: &str) -> String {
228        let mut result = text.to_string();
229
230        // Business and economics abbreviations
231        let news_abbrevs = HashMap::from([
232            ("CEO", "chief executive officer"),
233            ("CFO", "chief financial officer"),
234            ("CTO", "chief technology officer"),
235            ("IPO", "initial public offering"),
236            ("SEC", "Securities and Exchange Commission"),
237            ("GDP", "gross domestic product"),
238            ("CPI", "consumer price index"),
239            ("NYSE", "New York Stock Exchange"),
240            (
241                "NASDAQ",
242                "National Association of Securities Dealers Automated Quotations",
243            ),
244        ]);
245
246        for (abbrev, expansion) in news_abbrevs {
247            result = result.replace(abbrev, expansion);
248        }
249
250        // Handle financial terms
251        result = result.replace("Q1", "first quarter");
252        result = result.replace("Q2", "second quarter");
253        result = result.replace("Q3", "third quarter");
254        result = result.replace("Q4", "fourth quarter");
255
256        // Handle percentages
257        let percent_re = Regex::new(r"(\d+\.?\d*)%").unwrap();
258        result = percent_re.replace_all(&result, "$1 percent").to_string();
259
260        result
261    }
262
263    /// Preprocessing for social media text (SocialMediaBERT)
264    pub fn preprocess_social_media_text(&self, text: &str) -> String {
265        let mut result = text.to_string();
266
267        // Handle social media abbreviations
268        result = result.replace("lol", "laugh out loud");
269        result = result.replace("omg", "oh my god");
270        result = result.replace("btw", "by the way");
271        result = result.replace("fyi", "for your information");
272        result = result.replace("imo", "in my opinion");
273        result = result.replace("tbh", "to be honest");
274        result = result.replace("smh", "shaking my head");
275        result = result.replace("rn", "right now");
276        result = result.replace("irl", "in real life");
277
278        // Handle hashtags and mentions
279        result = result.replace('#', "hashtag ");
280        result = result.replace('@', "mention ");
281
282        // Handle emoticons (basic)
283        result = result.replace(":)", "happy");
284        result = result.replace(":(", "sad");
285        result = result.replace(":D", "very happy");
286        result = result.replace(";)", "winking");
287        result = result.replace(":P", "playful");
288        result = result.replace(":/", "confused");
289
290        // Handle emphasis
291        result = result.replace("!!", "exclamation");
292        result = result.replace("???", "question");
293
294        result
295    }
296
297    /// Expand camelCase to separate words
298    pub fn expand_camel_case(&self, text: &str) -> String {
299        if text.is_empty() {
300            return String::new();
301        }
302
303        let mut result = String::new();
304        let chars: Vec<char> = text.chars().collect();
305
306        for (i, &ch) in chars.iter().enumerate() {
307            // Add space before every uppercase letter (except the first character)
308            if i > 0 && ch.is_uppercase() {
309                result.push(' ');
310            }
311
312            result.push(ch.to_lowercase().next().unwrap_or(ch));
313        }
314
315        result
316    }
317
318    /// Simple tokenization for demonstration
319    pub fn tokenize(&self, text: &str) -> Result<Vec<u32>> {
320        // In real implementation, this would use a proper tokenizer
321        // For now, convert each character to a token
322        let tokens: Vec<u32> = text
323            .chars()
324            .take(self.config.max_sequence_length)
325            .map(|c| c as u32 % 30522) // Map to vocab range
326            .collect();
327
328        Ok(tokens)
329    }
330
331    /// Get maximum sequence length
332    pub fn max_sequence_length(&self) -> usize {
333        self.config.max_sequence_length
334    }
335
336    /// Check if text should be truncated
337    pub fn needs_truncation(&self, text: &str) -> bool {
338        text.len() > self.config.max_sequence_length
339    }
340
341    /// Truncate text to maximum sequence length
342    pub fn truncate_text(&self, text: &str) -> String {
343        if text.len() <= self.config.max_sequence_length {
344            text.to_string()
345        } else {
346            text.chars().take(self.config.max_sequence_length).collect()
347        }
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::transformer::types::TransformerType;

    /// Builds a preprocessor for `transformer_type`, leaving every other
    /// configuration field at its default.
    fn preprocessor_for(transformer_type: TransformerType) -> TransformerPreprocessor {
        let config = TransformerConfig {
            transformer_type,
            ..Default::default()
        };
        TransformerPreprocessor::new(config)
    }

    /// Panics with a descriptive message unless every fragment occurs in `text`.
    fn assert_contains_all(text: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                text.contains(fragment),
                "expected {:?} to contain {:?}",
                text,
                fragment
            );
        }
    }

    #[test]
    fn test_scientific_preprocessing() {
        let preprocessor = preprocessor_for(TransformerType::SciBERT);
        let processed = preprocessor.preprocess_scientific_text(
            "DNA synthesis with ATP and Co2 at 25°C using 5mg/ml concentration",
        );

        assert_contains_all(
            &processed,
            &[
                "deoxyribonucleic acid",
                "adenosine triphosphate",
                "carbon dioxide",
                "degrees celsius",
                "milligrams per milliliter",
            ],
        );
    }

    #[test]
    fn test_biomedical_preprocessing() {
        let preprocessor = preprocessor_for(TransformerType::BioBERT);
        let processed = preprocessor.preprocess_biomedical_text(
            "p53 and BRCA1 mutations affect TNF-α via mRNA expression in CNS",
        );

        assert_contains_all(
            &processed,
            &[
                "tumor protein p53",
                "breast cancer gene 1",
                "tumor necrosis factor",
                "messenger ribonucleic acid",
                "central nervous system",
            ],
        );
    }

    #[test]
    fn test_code_preprocessing() {
        let preprocessor = preprocessor_for(TransformerType::CodeBERT);
        let processed =
            preprocessor.preprocess_code_text("MyClass impl func calculateValue() returns Vec<i32>");

        assert_contains_all(
            &processed,
            &["my class", "implementation", "function", "calculate value"],
        );
    }

    #[test]
    fn test_camel_case_expansion() {
        let preprocessor = TransformerPreprocessor::new(TransformerConfig::default());

        let cases = [
            ("MyClass", "my class"),
            ("calculateValue", "calculate value"),
            ("getUserNameFromAPI", "get user name from a p i"),
            ("", ""),
        ];
        for (input, expected) in cases {
            assert_eq!(
                preprocessor.expand_camel_case(input),
                expected,
                "input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_uri_cleaning() {
        let preprocessor = TransformerPreprocessor::new(TransformerConfig::default());
        let cleaned = preprocessor.clean_uri("http://example.org/DNA_molecule#structure");

        assert_contains_all(&cleaned, &["example", "DNA", "molecule", "structure"]);
        for leftover in ["http://", "#"] {
            assert!(
                !cleaned.contains(leftover),
                "expected {:?} to be stripped from {:?}",
                leftover,
                cleaned
            );
        }
    }
}