Skip to main content

anno/
lang.rs

1//! Language detection and classification utilities.
2
/// Languages that the text-analysis utilities can tell apart.
///
/// The discriminants are spelled out because `detect_language` uses
/// `Language as usize` to index its per-language score table, so the numeric
/// value of each variant is part of the contract. `repr(u8)` fixes the integer
/// type those discriminants are stored in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Language {
    /// English.
    English = 0,
    /// German.
    German = 1,
    /// French.
    French = 2,
    /// Spanish.
    Spanish = 3,
    /// Italian.
    Italian = 4,
    /// Portuguese.
    Portuguese = 5,
    /// Russian.
    Russian = 6,
    /// Chinese (Simplified or Traditional).
    Chinese = 7,
    /// Japanese.
    Japanese = 8,
    /// Korean.
    Korean = 9,
    /// Arabic.
    Arabic = 10,
    /// Hebrew.
    Hebrew = 11,
    /// Any other or unknown language. Must stay the last variant.
    Other = 12,
}
37
38impl Language {
39    /// Returns true if this is a CJK (Chinese, Japanese, Korean) language.
40    #[must_use]
41    pub fn is_cjk(&self) -> bool {
42        matches!(
43            self,
44            Language::Chinese | Language::Japanese | Language::Korean
45        )
46    }
47
48    /// Returns true if this is a right-to-left language (Arabic, Hebrew).
49    #[must_use]
50    pub fn is_rtl(&self) -> bool {
51        matches!(self, Language::Arabic | Language::Hebrew)
52    }
53
54    /// Get ISO 639-1 language code (2-letter).
55    #[must_use]
56    pub fn iso_code(&self) -> &'static str {
57        match self {
58            Language::English => "en",
59            Language::German => "de",
60            Language::French => "fr",
61            Language::Spanish => "es",
62            Language::Italian => "it",
63            Language::Portuguese => "pt",
64            Language::Russian => "ru",
65            Language::Chinese => "zh",
66            Language::Japanese => "ja",
67            Language::Korean => "ko",
68            Language::Arabic => "ar",
69            Language::Hebrew => "he",
70            Language::Other => "xx",
71        }
72    }
73}
74
75/// Simple heuristic language detection based on Unicode scripts.
76///
77/// Returns the most likely language based on character counts.
78pub fn detect_language(text: &str) -> Language {
79    let mut counts = [0usize; 13];
80    let mut total = 0;
81
82    for c in text.chars() {
83        match c {
84            // CJK Unified Ideographs
85            '\u{4e00}'..='\u{9fff}' => {
86                total += 1;
87                counts[Language::Chinese as usize] += 1;
88            }
89            // Hiragana/Katakana
90            '\u{3040}'..='\u{30ff}' => {
91                total += 1;
92                counts[Language::Japanese as usize] += 1;
93            }
94            // Hangul
95            '\u{ac00}'..='\u{d7af}' => {
96                total += 1;
97                counts[Language::Korean as usize] += 1;
98            }
99            // Arabic
100            '\u{0600}'..='\u{06ff}' => {
101                total += 1;
102                counts[Language::Arabic as usize] += 1;
103            }
104            // Hebrew
105            '\u{0590}'..='\u{05ff}' => {
106                total += 1;
107                counts[Language::Hebrew as usize] += 1;
108            }
109            // Cyrillic
110            '\u{0400}'..='\u{04ff}' => {
111                total += 1;
112                counts[Language::Russian as usize] += 1;
113            }
114            // Latin - distinguishing languages is hard without dictionary,
115            // but we can check for specific chars
116            'a'..='z' | 'A'..='Z' => {
117                total += 1;
118                counts[Language::English as usize] += 1; // Generic Latin
119            }
120            // German specific (ß, ä, ö, ü)
121            'ß' | 'ä' | 'ö' | 'ü' | 'Ä' | 'Ö' | 'Ü' => {
122                total += 1;
123                counts[Language::German as usize] += 10
124            }
125            // French (à, â, ç, é, è, ê, ë, î, ï, ô, û, ù)
126            'à' | 'â' | 'ç' | 'é' | 'è' | 'ê' | 'ë' | 'î' | 'ï' | 'ô' | 'û' | 'ù' => {
127                total += 1;
128                counts[Language::French as usize] += 5
129            }
130            // Spanish (ñ, ¿, ¡, á, é, í, ó, ú)
131            'ñ' | '¿' | '¡' | 'á' | 'í' | 'ó' | 'ú' => {
132                total += 1;
133                counts[Language::Spanish as usize] += 5
134            }
135            _ => {}
136        }
137    }
138
139    if total == 0 {
140        return Language::English; // Default
141    }
142
143    // Find max
144    let mut max_idx = 0;
145    let mut max_val = 0;
146    for (i, &val) in counts.iter().enumerate() {
147        if val > max_val {
148            max_val = val;
149            max_idx = i;
150        }
151    }
152
153    // If we detected CJK chars but classified as Chinese, check if Japanese specific chars exist
154    if max_idx == Language::Chinese as usize && counts[Language::Japanese as usize] > 0 {
155        return Language::Japanese; // Japanese uses Kanji (Chinese chars) too
156    }
157
158    // Convert index to Language variant safely
159    // Using explicit match instead of transmute for compile-time safety
160    match max_idx {
161        0 => Language::English,
162        1 => Language::German,
163        2 => Language::French,
164        3 => Language::Spanish,
165        4 => Language::Italian,
166        5 => Language::Portuguese,
167        6 => Language::Russian,
168        7 => Language::Chinese,
169        8 => Language::Japanese,
170        9 => Language::Korean,
171        10 => Language::Arabic,
172        11 => Language::Hebrew,
173        _ => Language::Other,
174    }
175}
176
177/// Detect code-switching (mixed languages) in text.
178///
179/// Returns a vector of language segments with their positions.
180/// Useful for processing multilingual text where languages switch mid-sentence.
181///
182/// # Example
183///
184/// ```rust
185/// use anno::lang::{detect_code_switching, Language};
186///
187/// let segments = detect_code_switching("Dr. 田中 presented at MIT's conference.");
188/// // Returns: [(Language::English, 0, 4), (Language::Japanese, 5, 7), (Language::English, 8, 40)]
189/// ```
190#[must_use]
191pub fn detect_code_switching(text: &str) -> Vec<(Language, usize, usize)> {
192    if text.is_empty() {
193        return vec![];
194    }
195
196    let mut segments = Vec::new();
197    let chars: Vec<char> = text.chars().collect();
198    let mut current_lang = detect_language(text);
199    let mut segment_start = 0;
200
201    // Use a sliding window to detect language changes
202    const WINDOW_SIZE: usize = 10; // Characters per window
203    let mut i = 0;
204
205    while i < chars.len() {
206        // Check language in current window
207        let window_end = (i + WINDOW_SIZE).min(chars.len());
208        let window_text: String = chars[i..window_end].iter().collect();
209        let window_lang = detect_language(&window_text);
210
211        // If language changed significantly, start new segment
212        if window_lang != current_lang && window_lang != Language::Other {
213            // Save previous segment
214            if i > segment_start {
215                segments.push((current_lang, segment_start, i));
216            }
217            segment_start = i;
218            current_lang = window_lang;
219        }
220
221        i += WINDOW_SIZE / 2; // Overlap windows for smoother detection
222    }
223
224    // Add final segment
225    if segment_start < chars.len() {
226        segments.push((current_lang, segment_start, chars.len()));
227    }
228
229    // Merge adjacent segments of the same language
230    let mut merged = Vec::new();
231    for (lang, start, end) in segments {
232        if let Some((last_lang, _last_start, last_end)) = merged.last_mut() {
233            if *last_lang == lang && *last_end == start {
234                *last_end = end;
235            } else {
236                merged.push((lang, start, end));
237            }
238        } else {
239            merged.push((lang, start, end));
240        }
241    }
242
243    merged
244}
245
246/// Language clustering for cross-lingual transfer learning.
247///
248/// Groups languages by similarity for better multilingual NER performance.
249/// Based on research showing that semantic clustering outperforms linguistic family grouping.
250///
251/// Returns language clusters where languages in the same cluster benefit from shared training.
252#[must_use]
253pub fn language_clusters() -> Vec<Vec<Language>> {
254    // Research-based clusters (semantic similarity, not linguistic families)
255    vec![
256        // Cluster 1: Germanic + Romance (high-resource, similar syntax)
257        vec![
258            Language::English,
259            Language::German,
260            Language::French,
261            Language::Spanish,
262            Language::Italian,
263            Language::Portuguese,
264        ],
265        // Cluster 2: Slavic
266        vec![Language::Russian],
267        // Cluster 3: CJK (character-based, similar entity patterns)
268        vec![Language::Chinese, Language::Japanese, Language::Korean],
269        // Cluster 4: Semitic (RTL, similar morphology)
270        vec![Language::Arabic, Language::Hebrew],
271    ]
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every sample is classified as `expected`, naming the
    /// offending sample when the check fails.
    fn assert_detects(expected: Language, samples: &[&str]) {
        for &sample in samples {
            assert_eq!(detect_language(sample), expected, "sample: {:?}", sample);
        }
    }

    /// Asserts the structural guarantees of `detect_code_switching` for a
    /// non-empty `text`: segments are non-empty ranges of character offsets
    /// that cover the text from start to end without gaps.
    fn assert_well_formed(text: &str, segments: &[(Language, usize, usize)]) {
        let first = segments
            .first()
            .expect("non-empty text must yield at least one segment");
        let last = segments
            .last()
            .expect("non-empty text must yield at least one segment");
        assert_eq!(first.1, 0, "text: {:?}", text);
        assert_eq!(last.2, text.chars().count(), "text: {:?}", text);

        for &(_, start, end) in segments {
            assert!(start < end, "empty segment in {:?}", text);
        }
        for pair in segments.windows(2) {
            assert_eq!(pair[0].2, pair[1].1, "gap or overlap in {:?}", text);
        }
    }

    #[test]
    fn test_detect_english() {
        assert_detects(Language::English, &["Hello, world!", "The quick brown fox"]);
    }

    #[test]
    fn test_detect_german() {
        // Need enough German-specific characters to outweigh generic Latin
        assert_detects(
            Language::German,
            &["Größe Müller Öffentlichkeit Übung", "ß ä ö ü ß Ä Ö Ü"],
        );
    }

    #[test]
    fn test_detect_french() {
        assert_detects(Language::French, &["Café à Paris", "être où ça"]);
    }

    #[test]
    fn test_detect_spanish() {
        assert_detects(Language::Spanish, &["¿Cómo estás? Mañana"]);
    }

    #[test]
    fn test_detect_chinese() {
        assert_detects(Language::Chinese, &["北京欢迎您", "习近平"]);
    }

    #[test]
    fn test_detect_japanese() {
        // Hiragana/Katakana triggers Japanese detection
        assert_detects(Language::Japanese, &["こんにちは", "東京タワー"]);
    }

    #[test]
    fn test_detect_korean() {
        assert_detects(Language::Korean, &["안녕하세요", "서울"]);
    }

    #[test]
    fn test_detect_arabic() {
        assert_detects(Language::Arabic, &["مرحبا", "القاهرة"]);
    }

    #[test]
    fn test_detect_hebrew() {
        assert_detects(Language::Hebrew, &["שלום", "ירושלים"]);
    }

    #[test]
    fn test_detect_russian() {
        assert_detects(Language::Russian, &["Привет, мир!", "Москва"]);
    }

    #[test]
    fn test_empty_text_defaults_to_english() {
        assert_detects(Language::English, &["", "123 !@# "]);
    }

    #[test]
    fn test_is_cjk() {
        for lang in &[Language::Chinese, Language::Japanese, Language::Korean] {
            assert!(lang.is_cjk(), "{:?} should be CJK", lang);
        }
        for lang in &[Language::English, Language::Arabic] {
            assert!(!lang.is_cjk(), "{:?} should not be CJK", lang);
        }
    }

    #[test]
    fn test_is_rtl() {
        for lang in &[Language::Arabic, Language::Hebrew] {
            assert!(lang.is_rtl(), "{:?} should be RTL", lang);
        }
        for lang in &[Language::English, Language::Chinese] {
            assert!(!lang.is_rtl(), "{:?} should not be RTL", lang);
        }
    }

    #[test]
    fn test_language_repr_matches_index() {
        // `detect_language` indexes its score table with `Language as usize`,
        // so the discriminants must stay exactly as listed here.
        let expected: [(Language, u8); 13] = [
            (Language::English, 0),
            (Language::German, 1),
            (Language::French, 2),
            (Language::Spanish, 3),
            (Language::Italian, 4),
            (Language::Portuguese, 5),
            (Language::Russian, 6),
            (Language::Chinese, 7),
            (Language::Japanese, 8),
            (Language::Korean, 9),
            (Language::Arabic, 10),
            (Language::Hebrew, 11),
            (Language::Other, 12),
        ];
        for &(lang, index) in &expected {
            assert_eq!(lang as u8, index, "discriminant of {:?}", lang);
        }
    }

    #[test]
    fn test_detect_code_switching() {
        assert!(detect_code_switching("").is_empty());

        // Every input, not just the last one, must yield well-formed segments.
        let mixed_inputs = [
            // Mixed English-Japanese (CJK characters embedded in Latin text)
            "Dr. 田中 presented at MIT.",
            // Mixed English-Chinese
            "北京 (Beijing) is the capital.",
        ];
        for text in &mixed_inputs {
            // May collapse into one segment when the foreign run is shorter
            // than the detection window.
            assert_well_formed(text, &detect_code_switching(text));
        }

        // Single language
        let single = "Hello world";
        let segments = detect_code_switching(single);
        assert_eq!(segments.len(), 1);
        assert_well_formed(single, &segments);
    }

    #[test]
    fn test_language_iso_code() {
        let expected = [
            (Language::English, "en"),
            (Language::German, "de"),
            (Language::French, "fr"),
            (Language::Spanish, "es"),
            (Language::Italian, "it"),
            (Language::Portuguese, "pt"),
            (Language::Russian, "ru"),
            (Language::Chinese, "zh"),
            (Language::Japanese, "ja"),
            (Language::Korean, "ko"),
            (Language::Arabic, "ar"),
            (Language::Hebrew, "he"),
            (Language::Other, "xx"),
        ];
        for &(lang, code) in &expected {
            assert_eq!(lang.iso_code(), code, "ISO code of {:?}", lang);
        }
    }

    #[test]
    fn test_language_clusters() {
        let clusters = language_clusters();
        assert!(!clusters.is_empty());

        // Check that major languages are in clusters
        let all_langs: Vec<Language> = clusters.iter().flat_map(|c| c.iter().copied()).collect();
        assert!(all_langs.contains(&Language::English));
        assert!(all_langs.contains(&Language::Chinese));
    }
}