1use rand::{SeedableRng, Rng};
2use rand::rngs::StdRng;
3use rand_distr::{Normal, Distribution};
4
5use crate::language::Language;
6use crate::utils::ngram::NGram;
7use std::collections::HashMap;
8
/// Errors produced by [`Detector`] while running language detection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DetectorError {
    /// None of the n-grams extracted from the appended text are present in
    /// the detector's model, so there is nothing to score.
    NoFeatures,
}

impl std::fmt::Display for DetectorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DetectorError::NoFeatures => write!(f, "No features found in the input text"),
        }
    }
}

/// Lets `DetectorError` flow through `?` into `Box<dyn Error>` and similar
/// error containers; it wraps no underlying cause, so `source` stays `None`.
impl std::error::Error for DetectorError {}
23
/// Language detector over character n-grams.
///
/// Text is accumulated with `append` and classified with `detect` or
/// `get_probabilities`.
#[derive(Debug, Clone)]
pub struct Detector {
    /// Maps each known n-gram to a per-language probability vector whose
    /// entries are ordered like `langlist`.
    pub word_lang_prob_map: HashMap<String, Vec<f64>>,
    /// Language names; position `i` corresponds to index `i` of every
    /// probability vector.
    pub langlist: Vec<String>,
    /// RNG seed: `Some` makes detection reproducible, `None` seeds from the
    /// thread-local RNG.
    pub seed: Option<u64>,
    /// Normalized text accumulated by `append`. Detection may strip Latin
    /// letters from it in place.
    pub text: String,
    /// Cached per-language probabilities from the last detection run, or
    /// `None` if detection has not run yet.
    pub langprob: Option<Vec<f64>>,
    /// Smoothing parameter. Each trial perturbs it and adds
    /// `alpha / BASE_FREQ` to every language's n-gram probability.
    pub alpha: f64,
    /// Number of randomized trials whose results are averaged.
    pub n_trial: usize,
    /// Maximum number of characters taken from each `append` call.
    pub max_text_length: usize,
    /// Optional initial probability vector (ordered like `langlist`). When
    /// `None`, each trial starts from a uniform distribution.
    pub prior_map: Option<Vec<f64>>,
    /// Verbosity flag. NOTE(review): not read by any code in this file.
    pub verbose: bool,
}
69
70impl Detector {
71 pub const ALPHA_DEFAULT: f64 = 0.5;
73 pub const ALPHA_WIDTH: f64 = 0.05;
75 pub const ITERATION_LIMIT: usize = 1000;
77 pub const PROB_THRESHOLD: f64 = 0.1;
79 pub const CONV_THRESHOLD: f64 = 0.99999;
81 pub const BASE_FREQ: f64 = 10000.0;
83 pub const UNKNOWN_LANG: &'static str = "unknown";
85
86 pub fn new(word_lang_prob_map: HashMap<String, Vec<f64>>, langlist: Vec<String>, seed: Option<u64>) -> Self {
93 Detector {
94 word_lang_prob_map,
95 langlist,
96 seed,
97 text: String::new(),
98 langprob: None,
99 alpha: Self::ALPHA_DEFAULT,
100 n_trial: 7,
101 max_text_length: 10000,
102 prior_map: None,
103 verbose: false,
104 }
105 }
106
107 pub fn append(&mut self, text: &str) {
125 let url_re = regex::Regex::new(r"https?://[-_.?&~;+=/#0-9A-Za-z]{1,2076}").unwrap();
127 let mail_re = regex::Regex::new(r"[-_.0-9A-Za-z]{1,64}@[-_0-9A-Za-z]{1,255}[-_.0-9A-Za-z]{1,255}").unwrap();
128 let mut text = url_re.replace_all(text, " ").to_string();
129 text = mail_re.replace_all(&text, " ").to_string();
130 text = NGram::normalize_vi(&text);
131 let mut pre = ' ';
132 for ch in text.chars().take(self.max_text_length) {
133 if ch != ' ' || pre != ' ' {
134 self.text.push(ch);
135 }
136 pre = ch;
137 }
138 }
139
140 fn cleaning_text(&mut self) {
144 let mut latin_count = 0;
145 let mut non_latin_count = 0;
146 for ch in self.text.chars() {
147 if ('A'..='z').contains(&ch) {
148 latin_count += 1;
149 } else if ch >= '\u{0300}' {
150 if let Some(block) = crate::utils::unicode_block::unicode_block(ch) {
151 if block != crate::utils::unicode_block::UNICODE_LATIN_EXTENDED_ADDITIONAL {
152 non_latin_count += 1;
153 }
154 }
155 }
156 }
157 if latin_count * 2 < non_latin_count {
158 let mut text_without_latin = String::new();
159 for ch in self.text.chars() {
160 if ch < 'A' || ch > 'z' {
161 text_without_latin.push(ch);
162 }
163 }
164 self.text = text_without_latin;
165 }
166 }
167
168 pub fn detect(&mut self) -> Result<String, DetectorError> {
188 let probabilities = self.get_probabilities()?;
189 if !probabilities.is_empty() {
190 Ok(probabilities[0].lang.clone().unwrap_or_else(|| Self::UNKNOWN_LANG.to_string()))
191 } else {
192 Ok(Self::UNKNOWN_LANG.to_string())
193 }
194 }
195
196 pub fn get_probabilities(&mut self) -> Result<Vec<Language>, DetectorError> {
220 if self.langprob.is_none() {
221 self.detect_block()?;
222 }
223 Ok(self.sort_probability(self.langprob.as_ref().unwrap()))
224 }
225
226 fn detect_block(&mut self) -> Result<(), DetectorError> {
233 self.cleaning_text();
234 let ngrams = self.extract_ngrams();
235 if ngrams.is_empty() {
236 return Err(DetectorError::NoFeatures);
237 }
238 self.langprob = Some(vec![0.0; self.langlist.len()]);
239 let mut rng = if let Some(seed) = self.seed {
240 StdRng::seed_from_u64(seed)
241 } else {
242 let mut thread_rng = rand::rng();
243 StdRng::from_rng(&mut thread_rng)
244 };
245 for _t in 0..self.n_trial {
246 let mut prob = self.init_probability();
247 let normal = Normal::new(0.0, 1.0).unwrap();
248 let alpha = self.alpha + normal.sample(&mut rng) * Self::ALPHA_WIDTH;
249 let mut i = 0;
250 loop {
251 let word = ngrams[rng.random_range(0..ngrams.len())].clone();
252 self.update_lang_prob(&mut prob, &word, alpha);
253 if i % 5 == 0 {
254 if self.normalize_prob(&mut prob) > Self::CONV_THRESHOLD || i >= Self::ITERATION_LIMIT {
255 break;
256 }
257 }
258 i += 1;
259 }
260 for j in 0..self.langprob.as_ref().unwrap().len() {
261 self.langprob.as_mut().unwrap()[j] += prob[j] / self.n_trial as f64;
262 }
263 }
264 Ok(())
265 }
266
267 fn init_probability(&self) -> Vec<f64> {
271 if let Some(ref prior) = self.prior_map {
272 prior.clone()
273 } else {
274 vec![1.0 / self.langlist.len() as f64; self.langlist.len()]
275 }
276 }
277
278 fn extract_ngrams(&self) -> Vec<String> {
282 let range = 1..=NGram::N_GRAM;
283 let mut result = Vec::new();
284 let mut ngram = NGram::new();
285 for ch in self.text.chars() {
286 ngram.add_char(ch);
287 if ngram.capitalword {
288 continue;
289 }
290 for n in range.clone() {
291 if ngram.grams.len() < n {
292 break;
293 }
294 let w: String = ngram.grams.chars().rev().take(n).collect::<Vec<_>>().into_iter().rev().collect();
295 if !w.is_empty() && w != " " && self.word_lang_prob_map.contains_key(&w) {
296 result.push(w);
297 }
298 }
299 }
300 result
301 }
302
303 fn update_lang_prob(&self, prob: &mut [f64], word: &str, alpha: f64) -> bool {
313 if !self.word_lang_prob_map.contains_key(word) {
314 return false;
315 }
316 let lang_prob_map = &self.word_lang_prob_map[word];
317 let weight = alpha / Self::BASE_FREQ;
318 for i in 0..prob.len() {
319 prob[i] *= weight + lang_prob_map[i];
320 }
321 true
322 }
323
324 fn normalize_prob(&self, prob: &mut [f64]) -> f64 {
332 let sump: f64 = prob.iter().sum();
333 let mut maxp = 0.0;
334 for p in prob.iter_mut() {
335 *p /= sump;
336 if maxp < *p {
337 maxp = *p;
338 }
339 }
340 maxp
341 }
342
343 fn sort_probability(&self, prob: &[f64]) -> Vec<Language> {
353 let mut result: Vec<Language> = self.langlist.iter().zip(prob.iter())
354 .filter(|(_, p)| **p > Self::PROB_THRESHOLD)
355 .map(|(lang, &p)| Language::new(Some(lang.clone()), p)).collect();
356 result.sort_by(|a, b| b.partial_cmp(a).unwrap());
357 result
358 }
359}
360
#[cfg(test)]
mod tests {
    use crate::detector_factory::DetectorFactory;
    use crate::utils::lang_profile::LangProfile;

    /// Builds a factory holding three small hand-written profiles:
    /// "en" (index 0), "fr" (index 1) and "ja" (index 2).
    fn setup_factory() -> DetectorFactory {
        let mut factory = DetectorFactory::new().build();

        let mut profile_en = LangProfile::new().with_name("en").build();
        for w in ["a", "a", "a", "b", "b", "c", "c", "d", "e"].iter() {
            profile_en.add(w);
        }
        factory
            .add_profile(profile_en, 0, 3)
            .expect("Unexpected error in add_profile for en");

        let mut profile_fr = LangProfile::new().with_name("fr").build();
        for w in ["a", "b", "b", "c", "c", "c", "d", "d", "d"].iter() {
            profile_fr.add(w);
        }
        factory
            .add_profile(profile_fr, 1, 3)
            .expect("Unexpected error in add_profile for fr");

        let mut profile_ja = LangProfile::new().with_name("ja").build();
        for w in ["\u{3042}", "\u{3042}", "\u{3042}", "\u{3044}", "\u{3046}", "\u{3048}", "\u{3048}"].iter() {
            profile_ja.add(w);
        }
        factory
            .add_profile(profile_ja, 2, 3)
            .expect("Unexpected error in add_profile for ja");

        factory
    }

    /// Creates a detector from `setup_factory`, feeds it `text` and returns
    /// the detected language, failing the test on a detection error.
    fn detect_lang(text: &str) -> String {
        let factory = setup_factory();
        let mut detector = factory.create(None);
        detector.append(text);
        detector.detect().expect("Unexpected error in detect")
    }

    #[test]
    fn test_detector1() {
        assert_eq!(detect_lang("a"), "en");
    }

    #[test]
    fn test_detector2() {
        assert_eq!(detect_lang("b d"), "fr");
    }

    #[test]
    fn test_detector3() {
        assert_eq!(detect_lang("d e"), "en");
    }

    #[test]
    fn test_detector4() {
        assert_eq!(detect_lang("\u{3042}\u{3042}\u{3042}\u{3042}a"), "ja");
    }

    #[test]
    fn test_lang_list() {
        let factory = setup_factory();
        let langlist = factory.get_lang_list();
        assert_eq!(langlist.len(), 3);
        assert_eq!(langlist[0], "en");
        assert_eq!(langlist[1], "fr");
        assert_eq!(langlist[2], "ja");
    }

    #[test]
    fn test_factory_from_json_string() {
        let mut factory = DetectorFactory::new().build();
        factory.clear();
        let json_lang1 = "{\"freq\":{\"A\":3,\"B\":6,\"C\":3,\"AB\":2,\"BC\":1,\"ABC\":2,\"BBC\":1,\"CBA\":1},\"n_words\":[12,3,4],\"name\":\"lang1\"}";
        let json_lang2 = "{\"freq\":{\"A\":6,\"B\":3,\"C\":3,\"AA\":3,\"AB\":2,\"ABC\":1,\"ABA\":1,\"CAA\":1},\"n_words\":[12,5,3],\"name\":\"lang2\"}";
        let profiles = vec![json_lang1, json_lang2];
        factory.load_json_profile(&profiles).unwrap();
        let langlist = factory.get_lang_list();
        assert_eq!(langlist.len(), 2);
        assert_eq!(langlist[0], "lang1");
        assert_eq!(langlist[1], "lang2");
    }
}