1use super::types::{DomainPreprocessingRules, TransformerConfig, TransformerType};
4use anyhow::Result;
5use regex::Regex;
6use std::collections::HashMap;
7
8#[derive(Debug, Clone)]
10pub struct TransformerPreprocessor {
11 config: TransformerConfig,
12 domain_rules: Option<DomainPreprocessingRules>,
13}
14
15impl TransformerPreprocessor {
16 pub fn new(config: TransformerConfig) -> Self {
17 let domain_rules = match config.transformer_type {
18 TransformerType::SciBERT => Some(DomainPreprocessingRules::scientific()),
19 TransformerType::BioBERT => Some(DomainPreprocessingRules::biomedical()),
20 TransformerType::LegalBERT => Some(DomainPreprocessingRules::legal()),
21 TransformerType::NewsBERT => Some(DomainPreprocessingRules::news()),
22 TransformerType::SocialMediaBERT => Some(DomainPreprocessingRules::social_media()),
23 _ => None,
24 };
25
26 Self {
27 config,
28 domain_rules,
29 }
30 }
31
32 pub fn preprocess_text(&self, text: &str) -> String {
34 let mut processed = text.to_string();
35
36 processed = self.clean_uri(&processed);
38 processed = self.normalize_whitespace(&processed);
39
40 processed = match self.config.transformer_type {
42 TransformerType::SciBERT => self.preprocess_scientific_text(&processed),
43 TransformerType::BioBERT => self.preprocess_biomedical_text(&processed),
44 TransformerType::CodeBERT => self.preprocess_code_text(&processed),
45 TransformerType::LegalBERT => self.preprocess_legal_text(&processed),
46 TransformerType::NewsBERT => self.preprocess_news_text(&processed),
47 TransformerType::SocialMediaBERT => self.preprocess_social_media_text(&processed),
48 _ => processed,
49 };
50
51 if let Some(ref rules) = self.domain_rules {
53 processed = self.apply_domain_rules(&processed, rules);
54 }
55
56 processed
57 }
58
59 fn clean_uri(&self, text: &str) -> String {
61 let mut result = text.to_string();
62
63 result = result.replace("http://", "");
65 result = result.replace("https://", "");
66 result = result.replace("ftp://", "");
67
68 result = result.replace('/', " ");
70 result = result.replace('#', " ");
71 result = result.replace('?', " ");
72 result = result.replace('&', " and ");
73 result = result.replace('=', " equals ");
74
75 result = result.replace('_', " ");
77
78 result
79 }
80
81 fn normalize_whitespace(&self, text: &str) -> String {
83 let re = Regex::new(r"\s+").unwrap();
85 re.replace_all(text, " ").trim().to_string()
86 }
87
88 fn apply_domain_rules(&self, text: &str, rules: &DomainPreprocessingRules) -> String {
90 let mut result = text.to_string();
91
92 for (abbrev, expansion) in &rules.abbreviation_expansions {
94 result = result.replace(abbrev, expansion);
95 }
96
97 for (pattern, replacement) in &rules.domain_specific_patterns {
99 if let Ok(re) = Regex::new(pattern) {
100 result = re.replace_all(&result, replacement).to_string();
101 }
102 }
103
104 result
105 }
106
107 pub fn preprocess_scientific_text(&self, text: &str) -> String {
109 let mut result = text.to_string();
110
111 let scientific_abbrevs = HashMap::from([
113 ("DNA", "deoxyribonucleic acid"),
114 ("RNA", "ribonucleic acid"),
115 ("ATP", "adenosine triphosphate"),
116 ("GDP", "guanosine diphosphate"),
117 ("GTP", "guanosine triphosphate"),
118 ("Co2", "carbon dioxide"),
119 ("H2O", "water"),
120 ("NaCl", "sodium chloride"),
121 ]);
122
123 for (abbrev, expansion) in scientific_abbrevs {
124 result = result.replace(abbrev, expansion);
125 }
126
127 result = result.replace("°C", " degrees celsius");
129 result = result.replace("mg/ml", " milligrams per milliliter");
130 result = result.replace("μg/ml", " micrograms per milliliter");
131 result = result.replace("mM", " millimolar");
132 result = result.replace("μM", " micromolar");
133
134 result = result.replace("->", " produces ");
136 result = result.replace("<->", " is in equilibrium with ");
137
138 result
139 }
140
141 pub fn preprocess_biomedical_text(&self, text: &str) -> String {
143 let mut result = text.to_string();
144
145 let biomedical_abbrevs = HashMap::from([
147 ("p53", "tumor protein p53"),
148 ("BRCA1", "breast cancer gene 1"),
149 ("BRCA2", "breast cancer gene 2"),
150 ("TNF-α", "tumor necrosis factor alpha"),
151 ("IL-1", "interleukin 1"),
152 ("IL-6", "interleukin 6"),
153 ("mRNA", "messenger ribonucleic acid"),
154 ("tRNA", "transfer ribonucleic acid"),
155 ("rRNA", "ribosomal ribonucleic acid"),
156 ("CNS", "central nervous system"),
157 ("PNS", "peripheral nervous system"),
158 ]);
159
160 for (abbrev, expansion) in biomedical_abbrevs {
161 result = result.replace(abbrev, expansion);
162 }
163
164 result = result.replace("bp", " base pairs");
166 result = result.replace("kDa", " kilodaltons");
167 result = result.replace("mg/kg", " milligrams per kilogram");
168
169 result
170 }
171
172 pub fn preprocess_code_text(&self, text: &str) -> String {
174 let mut result = text.to_string();
175
176 result = result.replace("impl", "implementation");
178 result = result.replace("func", "function");
179 result = result.replace("var", "variable");
180 result = result.replace("const", "constant");
181 result = result.replace("struct", "structure");
182 result = result.replace("enum", "enumeration");
183
184 result = result.replace("Vec<i32>", "vector of integers");
186 result = result.replace("HashMap", "hash map");
187 result = result.replace("String", "string");
188 result = result.replace("bool", "boolean");
189
190 result = self.expand_camel_case(&result);
192
193 result
194 }
195
196 pub fn preprocess_legal_text(&self, text: &str) -> String {
198 let mut result = text.to_string();
199
200 let legal_abbrevs = HashMap::from([
202 ("USC", "United States Code"),
203 ("CFR", "Code of Federal Regulations"),
204 ("plaintiff", "party bringing lawsuit"),
205 ("defendant", "party being sued"),
206 ("tort", "civil wrong"),
207 ("v.", "versus"),
208 ("vs.", "versus"),
209 ("et al.", "and others"),
210 ("cf.", "compare"),
211 ("ibid.", "in the same place"),
212 ("supra", "above mentioned"),
213 ]);
214
215 for (abbrev, expansion) in legal_abbrevs {
216 result = result.replace(abbrev, expansion);
217 }
218
219 let section_re = Regex::new(r"§(\d+)").unwrap();
221 result = section_re.replace_all(&result, "section $1").to_string();
222
223 result
224 }
225
226 pub fn preprocess_news_text(&self, text: &str) -> String {
228 let mut result = text.to_string();
229
230 let news_abbrevs = HashMap::from([
232 ("CEO", "chief executive officer"),
233 ("CFO", "chief financial officer"),
234 ("CTO", "chief technology officer"),
235 ("IPO", "initial public offering"),
236 ("SEC", "Securities and Exchange Commission"),
237 ("GDP", "gross domestic product"),
238 ("CPI", "consumer price index"),
239 ("NYSE", "New York Stock Exchange"),
240 (
241 "NASDAQ",
242 "National Association of Securities Dealers Automated Quotations",
243 ),
244 ]);
245
246 for (abbrev, expansion) in news_abbrevs {
247 result = result.replace(abbrev, expansion);
248 }
249
250 result = result.replace("Q1", "first quarter");
252 result = result.replace("Q2", "second quarter");
253 result = result.replace("Q3", "third quarter");
254 result = result.replace("Q4", "fourth quarter");
255
256 let percent_re = Regex::new(r"(\d+\.?\d*)%").unwrap();
258 result = percent_re.replace_all(&result, "$1 percent").to_string();
259
260 result
261 }
262
263 pub fn preprocess_social_media_text(&self, text: &str) -> String {
265 let mut result = text.to_string();
266
267 result = result.replace("lol", "laugh out loud");
269 result = result.replace("omg", "oh my god");
270 result = result.replace("btw", "by the way");
271 result = result.replace("fyi", "for your information");
272 result = result.replace("imo", "in my opinion");
273 result = result.replace("tbh", "to be honest");
274 result = result.replace("smh", "shaking my head");
275 result = result.replace("rn", "right now");
276 result = result.replace("irl", "in real life");
277
278 result = result.replace('#', "hashtag ");
280 result = result.replace('@', "mention ");
281
282 result = result.replace(":)", "happy");
284 result = result.replace(":(", "sad");
285 result = result.replace(":D", "very happy");
286 result = result.replace(";)", "winking");
287 result = result.replace(":P", "playful");
288 result = result.replace(":/", "confused");
289
290 result = result.replace("!!", "exclamation");
292 result = result.replace("???", "question");
293
294 result
295 }
296
297 pub fn expand_camel_case(&self, text: &str) -> String {
299 if text.is_empty() {
300 return String::new();
301 }
302
303 let mut result = String::new();
304 let chars: Vec<char> = text.chars().collect();
305
306 for (i, &ch) in chars.iter().enumerate() {
307 if i > 0 && ch.is_uppercase() {
309 result.push(' ');
310 }
311
312 result.push(ch.to_lowercase().next().unwrap_or(ch));
313 }
314
315 result
316 }
317
318 pub fn tokenize(&self, text: &str) -> Result<Vec<u32>> {
320 let tokens: Vec<u32> = text
323 .chars()
324 .take(self.config.max_sequence_length)
325 .map(|c| c as u32 % 30522) .collect();
327
328 Ok(tokens)
329 }
330
331 pub fn max_sequence_length(&self) -> usize {
333 self.config.max_sequence_length
334 }
335
336 pub fn needs_truncation(&self, text: &str) -> bool {
338 text.len() > self.config.max_sequence_length
339 }
340
341 pub fn truncate_text(&self, text: &str) -> String {
343 if text.len() <= self.config.max_sequence_length {
344 text.to_string()
345 } else {
346 text.chars().take(self.config.max_sequence_length).collect()
347 }
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::transformer::types::TransformerType;

    /// Builds a preprocessor for `transformer_type` with every other setting defaulted.
    fn preprocessor_for(transformer_type: TransformerType) -> TransformerPreprocessor {
        TransformerPreprocessor::new(TransformerConfig {
            transformer_type,
            ..Default::default()
        })
    }

    /// Builds a preprocessor from the fully default configuration.
    fn default_preprocessor() -> TransformerPreprocessor {
        TransformerPreprocessor::new(TransformerConfig::default())
    }

    /// Asserts that `haystack` contains every entry of `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(needle),
                "expected {:?} in {:?}",
                needle,
                haystack
            );
        }
    }

    #[test]
    fn test_scientific_preprocessing() {
        let processed = preprocessor_for(TransformerType::SciBERT).preprocess_scientific_text(
            "DNA synthesis with ATP and Co2 at 25°C using 5mg/ml concentration",
        );

        assert_contains_all(
            &processed,
            &[
                "deoxyribonucleic acid",
                "adenosine triphosphate",
                "carbon dioxide",
                "degrees celsius",
                "milligrams per milliliter",
            ],
        );
    }

    #[test]
    fn test_biomedical_preprocessing() {
        let processed = preprocessor_for(TransformerType::BioBERT).preprocess_biomedical_text(
            "p53 and BRCA1 mutations affect TNF-α via mRNA expression in CNS",
        );

        assert_contains_all(
            &processed,
            &[
                "tumor protein p53",
                "breast cancer gene 1",
                "tumor necrosis factor",
                "messenger ribonucleic acid",
                "central nervous system",
            ],
        );
    }

    #[test]
    fn test_code_preprocessing() {
        let processed = preprocessor_for(TransformerType::CodeBERT)
            .preprocess_code_text("MyClass impl func calculateValue() returns Vec<i32>");

        assert_contains_all(
            &processed,
            &["my class", "implementation", "function", "calculate value"],
        );
    }

    #[test]
    fn test_camel_case_expansion() {
        let preprocessor = default_preprocessor();
        let cases = [
            ("MyClass", "my class"),
            ("calculateValue", "calculate value"),
            ("getUserNameFromAPI", "get user name from a p i"),
            ("", ""),
        ];

        for (input, expected) in cases {
            assert_eq!(preprocessor.expand_camel_case(input), expected);
        }
    }

    #[test]
    fn test_uri_cleaning() {
        let cleaned = default_preprocessor().clean_uri("http://example.org/DNA_molecule#structure");

        assert_contains_all(&cleaned, &["example", "DNA", "molecule", "structure"]);
        assert!(!cleaned.contains("http://"));
        assert!(!cleaned.contains("#"));
    }
}