//! langdetect_rs/detector.rs — core language detection engine.
use std::collections::HashMap;
use std::sync::OnceLock;

use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, Normal};

use crate::language::Language;
use crate::utils::ngram::NGram;
8
/// Errors that can occur during language detection.
#[derive(Debug, Clone)]
pub enum DetectorError {
    /// No detectable features found in the input text.
    NoFeatures,
}

impl std::fmt::Display for DetectorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DetectorError::NoFeatures => write!(f, "No features found in the input text"),
        }
    }
}

// `Display` alone is not enough for `?`-conversion into `Box<dyn Error>`
// or `anyhow`; wire the type into the standard error trait as well.
impl std::error::Error for DetectorError {}
23
/// Core language detection engine.
///
/// A `Detector` identifies the language of accumulated text by extracting
/// character n-grams, looking up their per-language probabilities, and
/// refining a probability vector with an iterative expectation-maximization
/// estimation.
///
/// # Algorithm Overview
///
/// 1. Extract n-grams (1-3 characters) from the input text
/// 2. Look up probabilities for each n-gram across all languages
/// 3. Use iterative EM algorithm to estimate language probabilities
/// 4. Return the language with highest probability
///
/// # Examples
///
/// ```rust
/// use langdetect_rs::detector_factory::DetectorFactory;
///
/// let factory = DetectorFactory::default().build();
/// let mut detector = factory.create(None);
/// detector.append("Hello world!");
/// let language = detector.detect().unwrap();
/// ```
pub struct Detector {
    /// Per-n-gram probability vectors, one entry per language in `langlist`.
    pub word_lang_prob_map: HashMap<String, Vec<f64>>,
    /// Language identifiers, in the same order as the probability vectors.
    pub langlist: Vec<String>,
    /// Optional RNG seed; set it to make detection reproducible.
    pub seed: Option<u64>,
    /// Text accumulated via `append`, pending analysis.
    pub text: String,
    /// Estimated per-language probabilities (populated by detection).
    pub langprob: Option<Vec<f64>>,
    /// Alpha smoothing parameter used during probability estimation.
    pub alpha: f64,
    /// How many independent EM trials to average.
    pub n_trial: usize,
    /// Upper bound on the number of characters kept for analysis.
    pub max_text_length: usize,
    /// Optional prior per-language probabilities.
    pub prior_map: Option<Vec<f64>>,
    /// Enables verbose logging when true.
    pub verbose: bool,
}
69
70impl Detector {
71    /// Default alpha smoothing parameter.
72    pub const ALPHA_DEFAULT: f64 = 0.5;
73    /// Width of alpha variation during randomization.
74    pub const ALPHA_WIDTH: f64 = 0.05;
75    /// Maximum iterations for the EM algorithm.
76    pub const ITERATION_LIMIT: usize = 1000;
77    /// Minimum probability threshold for reporting languages.
78    pub const PROB_THRESHOLD: f64 = 0.1;
79    /// Convergence threshold for the EM algorithm.
80    pub const CONV_THRESHOLD: f64 = 0.99999;
81    /// Base frequency for probability calculations.
82    pub const BASE_FREQ: f64 = 10000.0;
83    /// Language identifier for unknown/undetected languages.
84    pub const UNKNOWN_LANG: &'static str = "unknown";
85
86    /// Creates a new Detector with the given language profiles.
87    ///
88    /// # Arguments
89    /// * `word_lang_prob_map` - Pre-computed word-to-language probability mapping.
90    /// * `langlist` - List of language identifiers.
91    /// * `seed` - Optional seed for reproducible randomization.
92    pub fn new(word_lang_prob_map: HashMap<String, Vec<f64>>, langlist: Vec<String>, seed: Option<u64>) -> Self {
93        Detector {
94            word_lang_prob_map,
95            langlist,
96            seed,
97            text: String::new(),
98            langprob: None,
99            alpha: Self::ALPHA_DEFAULT,
100            n_trial: 7,
101            max_text_length: 10000,
102            prior_map: None,
103            verbose: false,
104        }
105    }
106
107    /// Appends text to the detector for analysis.
108    ///
109    /// The text is preprocessed to remove URLs, emails, and normalize whitespace.
110    /// Vietnamese text is also normalized for better detection.
111    ///
112    /// # Arguments
113    /// * `text` - The text to append for language detection.
114    ///
115    /// # Examples
116    ///
117    /// ```rust
118    /// use langdetect_rs::detector_factory::DetectorFactory;
119    ///
120    /// let factory = DetectorFactory::default().build();
121    /// let mut detector = factory.create(None);
122    /// detector.append("Hello world!");
123    /// ```
124    pub fn append(&mut self, text: &str) {
125        // Remove URLs and emails (simple regex)
126        let url_re = regex::Regex::new(r"https?://[-_.?&~;+=/#0-9A-Za-z]{1,2076}").unwrap();
127        let mail_re = regex::Regex::new(r"[-_.0-9A-Za-z]{1,64}@[-_0-9A-Za-z]{1,255}[-_.0-9A-Za-z]{1,255}").unwrap();
128        let mut text = url_re.replace_all(text, " ").to_string();
129        text = mail_re.replace_all(&text, " ").to_string();
130        text = NGram::normalize_vi(&text);
131        let mut pre = ' ';
132        for ch in text.chars().take(self.max_text_length) {
133            if ch != ' ' || pre != ' ' {
134                self.text.push(ch);
135            }
136            pre = ch;
137        }
138    }
139
140    /// Cleans the text by removing Latin characters if they are outnumbered by non-Latin characters.
141    ///
142    /// This helps improve detection accuracy for texts that mix scripts.
143    fn cleaning_text(&mut self) {
144        let mut latin_count = 0;
145        let mut non_latin_count = 0;
146        for ch in self.text.chars() {
147            if ('A'..='z').contains(&ch) {
148                latin_count += 1;
149            } else if ch >= '\u{0300}' {
150                if let Some(block) = crate::utils::unicode_block::unicode_block(ch) {
151                    if block != crate::utils::unicode_block::UNICODE_LATIN_EXTENDED_ADDITIONAL {
152                        non_latin_count += 1;
153                    }
154                }
155            }
156        }
157        if latin_count * 2 < non_latin_count {
158            let mut text_without_latin = String::new();
159            for ch in self.text.chars() {
160                if ch < 'A' || ch > 'z' {
161                    text_without_latin.push(ch);
162                }
163            }
164            self.text = text_without_latin;
165        }
166    }
167
168    /// Performs language detection on the accumulated text.
169    ///
170    /// # Returns
171    /// The detected language code, or "unknown" if detection fails.
172    ///
173    /// # Errors
174    /// Returns `DetectorError::NoFeatures` if no detectable n-grams are found.
175    ///
176    /// # Examples
177    ///
178    /// ```rust
179    /// use langdetect_rs::detector_factory::DetectorFactory;
180    ///
181    /// let factory = DetectorFactory::default().build();
182    /// let mut detector = factory.create(None);
183    /// detector.append("Bonjour le monde!");
184    /// let language = detector.detect().unwrap();
185    /// assert_eq!(language, "fr");
186    /// ```
187    pub fn detect(&mut self) -> Result<String, DetectorError> {
188        let probabilities = self.get_probabilities()?;
189        if !probabilities.is_empty() {
190            Ok(probabilities[0].lang.clone().unwrap_or_else(|| Self::UNKNOWN_LANG.to_string()))
191        } else {
192            Ok(Self::UNKNOWN_LANG.to_string())
193        }
194    }
195
196    /// Gets detailed language probabilities for the accumulated text.
197    ///
198    /// Returns all languages with probability above the threshold, sorted by probability descending.
199    ///
200    /// # Returns
201    /// A vector of `Language` structs with language codes and probabilities.
202    ///
203    /// # Errors
204    /// Returns `DetectorError::NoFeatures` if no detectable n-grams are found.
205    ///
206    /// # Examples
207    ///
208    /// ```rust
209    /// use langdetect_rs::detector_factory::DetectorFactory;
210    ///
211    /// let factory = DetectorFactory::default().build();
212    /// let mut detector = factory.create(None);
213    /// detector.append("Hello world!");
214    /// let probabilities = detector.get_probabilities().unwrap();
215    /// for lang in probabilities {
216    ///     println!("{}: {:.3}", lang.lang.unwrap_or_default(), lang.prob);
217    /// }
218    /// ```
219    pub fn get_probabilities(&mut self) -> Result<Vec<Language>, DetectorError> {
220        if self.langprob.is_none() {
221            self.detect_block()?;
222        }
223        Ok(self.sort_probability(self.langprob.as_ref().unwrap()))
224    }
225
226    /// Runs the core detection algorithm on the accumulated text.
227    ///
228    /// This method implements the expectation-maximization algorithm for language detection.
229    ///
230    /// # Returns
231    /// Ok(()) on successful detection, or an error if no features are found.
232    fn detect_block(&mut self) -> Result<(), DetectorError> {
233        self.cleaning_text();
234        let ngrams = self.extract_ngrams();
235        if ngrams.is_empty() {
236            return Err(DetectorError::NoFeatures);
237        }
238        self.langprob = Some(vec![0.0; self.langlist.len()]);
239        let mut rng = if let Some(seed) = self.seed {
240            StdRng::seed_from_u64(seed)
241        } else {
242            let mut thread_rng = rand::rng();
243            StdRng::from_rng(&mut thread_rng)
244        };
245        for _t in 0..self.n_trial {
246            let mut prob = self.init_probability();
247            let normal = Normal::new(0.0, 1.0).unwrap();
248            let alpha = self.alpha + normal.sample(&mut rng) * Self::ALPHA_WIDTH;
249            let mut i = 0;
250            loop {
251                let word = ngrams[rng.random_range(0..ngrams.len())].clone();
252                self.update_lang_prob(&mut prob, &word, alpha);
253                if i % 5 == 0 {
254                    if self.normalize_prob(&mut prob) > Self::CONV_THRESHOLD || i >= Self::ITERATION_LIMIT {
255                        break;
256                    }
257                }
258                i += 1;
259            }
260            for j in 0..self.langprob.as_ref().unwrap().len() {
261                self.langprob.as_mut().unwrap()[j] += prob[j] / self.n_trial as f64;
262            }
263        }
264        Ok(())
265    }
266
267    /// Initializes probability estimates for the EM algorithm.
268    ///
269    /// Uses prior probabilities if available, otherwise uniform distribution.
270    fn init_probability(&self) -> Vec<f64> {
271        if let Some(ref prior) = self.prior_map {
272            prior.clone()
273        } else {
274            vec![1.0 / self.langlist.len() as f64; self.langlist.len()]
275        }
276    }
277
278    /// Extracts n-grams from the text for language detection.
279    ///
280    /// Only includes n-grams that exist in the language profiles.
281    fn extract_ngrams(&self) -> Vec<String> {
282        let range = 1..=NGram::N_GRAM;
283        let mut result = Vec::new();
284        let mut ngram = NGram::new();
285        for ch in self.text.chars() {
286            ngram.add_char(ch);
287            if ngram.capitalword {
288                continue;
289            }
290            for n in range.clone() {
291                if ngram.grams.len() < n {
292                    break;
293                }
294                let w: String = ngram.grams.chars().rev().take(n).collect::<Vec<_>>().into_iter().rev().collect();
295                if !w.is_empty() && w != " " && self.word_lang_prob_map.contains_key(&w) {
296                    result.push(w);
297                }
298            }
299        }
300        result
301    }
302
303    /// Updates language probabilities based on an n-gram observation.
304    ///
305    /// # Arguments
306    /// * `prob` - Current probability estimates (modified in-place).
307    /// * `word` - The n-gram to use for updating.
308    /// * `alpha` - Smoothing parameter.
309    ///
310    /// # Returns
311    /// true if the n-gram was found in profiles, false otherwise.
312    fn update_lang_prob(&self, prob: &mut [f64], word: &str, alpha: f64) -> bool {
313        if !self.word_lang_prob_map.contains_key(word) {
314            return false;
315        }
316        let lang_prob_map = &self.word_lang_prob_map[word];
317        let weight = alpha / Self::BASE_FREQ;
318        for i in 0..prob.len() {
319            prob[i] *= weight + lang_prob_map[i];
320        }
321        true
322    }
323
324    /// Normalizes probability estimates and returns the maximum probability.
325    ///
326    /// # Arguments
327    /// * `prob` - Probability vector to normalize (modified in-place).
328    ///
329    /// # Returns
330    /// The maximum probability value after normalization.
331    fn normalize_prob(&self, prob: &mut [f64]) -> f64 {
332        let sump: f64 = prob.iter().sum();
333        let mut maxp = 0.0;
334        for p in prob.iter_mut() {
335            *p /= sump;
336            if maxp < *p {
337                maxp = *p;
338            }
339        }
340        maxp
341    }
342
343    /// Converts probability estimates to a sorted list of Language structs.
344    ///
345    /// Only includes languages with probability above the threshold.
346    ///
347    /// # Arguments
348    /// * `prob` - Raw probability estimates.
349    ///
350    /// # Returns
351    /// Sorted vector of Language structs.
352    fn sort_probability(&self, prob: &[f64]) -> Vec<Language> {
353        let mut result: Vec<Language> = self.langlist.iter().zip(prob.iter())
354            .filter(|(_, p)| **p > Self::PROB_THRESHOLD)
355            .map(|(lang, &p)| Language::new(Some(lang.clone()), p)).collect();
356        result.sort_by(|a, b| b.partial_cmp(a).unwrap());
357        result
358    }
359}
360
#[cfg(test)]
mod tests {
    use crate::detector_factory::DetectorFactory;
    use crate::utils::lang_profile::LangProfile;

    /// Builds a factory with three tiny fixture profiles: "en", "fr", "ja".
    fn setup_factory() -> DetectorFactory {
        let mut factory = DetectorFactory::new().build();

        let mut profile_en = LangProfile::new().with_name("en").build();
        for w in ["a", "a", "a", "b", "b", "c", "c", "d", "e"].iter() {
            profile_en.add(w);
        }
        // `expect` replaces the redundant assert-is_ok + trailing unwrap pair.
        factory
            .add_profile(profile_en, 0, 3)
            .expect("Unexpected error in add_profile");

        let mut profile_fr = LangProfile::new().with_name("fr").build();
        for w in ["a", "b", "b", "c", "c", "c", "d", "d", "d"].iter() {
            profile_fr.add(w);
        }
        factory
            .add_profile(profile_fr, 1, 3)
            .expect("Unexpected error in add_profile");

        let mut profile_ja = LangProfile::new().with_name("ja").build();
        for w in ["\u{3042}", "\u{3042}", "\u{3042}", "\u{3044}", "\u{3046}", "\u{3048}", "\u{3048}"].iter() {
            profile_ja.add(w);
        }
        factory
            .add_profile(profile_ja, 2, 3)
            .expect("Unexpected error in add_profile");

        factory
    }

    /// Runs end-to-end detection of `text` against the fixture factory.
    fn detect_lang(text: &str) -> String {
        let factory = setup_factory();
        let mut detect = factory.create(None);
        detect.append(text);
        detect.detect().expect("Unexpected error")
    }

    #[test]
    fn test_detector1() {
        assert_eq!(detect_lang("a"), "en");
    }

    #[test]
    fn test_detector2() {
        assert_eq!(detect_lang("b d"), "fr");
    }

    #[test]
    fn test_detector3() {
        assert_eq!(detect_lang("d e"), "en");
    }

    #[test]
    fn test_detector4() {
        assert_eq!(detect_lang("\u{3042}\u{3042}\u{3042}\u{3042}a"), "ja");
    }

    #[test]
    fn test_lang_list() {
        let factory = setup_factory();
        let langlist = factory.get_lang_list();
        assert_eq!(langlist.len(), 3);
        assert_eq!(langlist[0], "en");
        assert_eq!(langlist[1], "fr");
        assert_eq!(langlist[2], "ja");
    }

    #[test]
    fn test_factory_from_json_string() {
        let mut factory = DetectorFactory::new().build();
        factory.clear();
        let json_lang1 = "{\"freq\":{\"A\":3,\"B\":6,\"C\":3,\"AB\":2,\"BC\":1,\"ABC\":2,\"BBC\":1,\"CBA\":1},\"n_words\":[12,3,4],\"name\":\"lang1\"}";
        let json_lang2 = "{\"freq\":{\"A\":6,\"B\":3,\"C\":3,\"AA\":3,\"AB\":2,\"ABC\":1,\"ABA\":1,\"CAA\":1},\"n_words\":[12,5,3],\"name\":\"lang2\"}";
        // `profiles` is already a Vec<&str>; the extra element-wise copy into
        // a second Vec was dead code.
        let profiles = vec![json_lang1, json_lang2];
        factory.load_json_profile(&profiles).unwrap();
        let langlist = factory.get_lang_list();
        assert_eq!(langlist.len(), 2);
        assert_eq!(langlist[0], "lang1");
        assert_eq!(langlist[1], "lang2");
    }
}