1use std::collections::HashMap;
7
/// Outcome of a language-detection run.
#[derive(Clone, Debug, PartialEq)]
pub struct LangDetection {
    /// Short code of the winning language profile (for example `"en"` or `"fr"`).
    pub code: String,
    /// Human-readable name of the winning language (for example `"English"`).
    pub name: String,
    /// Similarity score of the winning profile; larger means a closer match.
    pub confidence: f64,
}
18
19pub fn detect_language(text: &str) -> Option<LangDetection> {
23 let clean: String = text
24 .chars()
25 .filter(|c| c.is_alphabetic() || c.is_whitespace())
26 .collect::<String>()
27 .to_lowercase();
28
29 if clean.len() < 20 {
30 return None;
31 }
32
33 let trigrams = extract_trigrams(&clean);
34 if trigrams.is_empty() {
35 return None;
36 }
37
38 let mut best = ("en", "English", 0.0_f64);
39
40 for &(code, name, profile) in PROFILES {
41 let score = cosine_similarity(&trigrams, profile);
42 if score > best.2 {
43 best = (code, name, score);
44 }
45 }
46
47 Some(LangDetection {
48 code: best.0.to_string(),
49 name: best.1.to_string(),
50 confidence: best.2,
51 })
52}
53
/// Computes the relative frequency of every character trigram in `text`.
///
/// A trigram is a window of three consecutive characters (Unicode scalar
/// values), not three bytes, so multi-byte letters such as `ó` or `ã` are
/// handled like any other character. Keys borrow directly from `text`.
///
/// Returns an empty map when `text` has fewer than three characters.
/// Otherwise each value is `count / total_windows`, so the values sum to 1.
fn extract_trigrams(text: &str) -> HashMap<&str, f64> {
    // Byte offset at which each character starts, followed by the end of the
    // string. Any four consecutive offsets delimit exactly three characters.
    let mut bounds: Vec<usize> = text.char_indices().map(|(offset, _)| offset).collect();
    bounds.push(text.len());

    let mut counts: HashMap<&str, usize> = HashMap::new();
    for window in bounds.windows(4) {
        let tri = &text[window[0]..window[3]];
        *counts.entry(tri).or_insert(0) += 1;
    }

    let total = counts.values().sum::<usize>() as f64;
    if total == 0.0 {
        return HashMap::new();
    }

    counts
        .into_iter()
        .map(|(tri, count)| (tri, count as f64 / total))
        .collect()
}
78
/// Cosine similarity between observed trigram frequencies and a language profile.
///
/// `trigrams` maps each observed trigram to its frequency; `profile` lists
/// `(trigram, weight)` pairs. The dot product only covers trigrams present on
/// both sides. If a trigram appears more than once in `profile`, the last
/// weight is used for the dot product while every listed weight contributes
/// to the profile's norm.
///
/// Returns `0.0` when the product of the two norms is below `1e-10`
/// (for example when either side is empty).
fn cosine_similarity(trigrams: &HashMap<&str, f64>, profile: &[(&str, f64)]) -> f64 {
    // Index the profile by trigram for constant-time lookups.
    let lookup: HashMap<&str, f64> = profile.iter().copied().collect();

    // Squared length of the observed vector.
    let observed_sq = trigrams
        .values()
        .fold(0.0_f64, |acc, &freq| acc + freq * freq);

    // Dot product over the trigrams both sides share.
    let overlap = trigrams
        .iter()
        .filter_map(|(tri, &freq)| lookup.get(*tri).map(|&weight| freq * weight))
        .fold(0.0_f64, |acc, term| acc + term);

    // Squared length of the profile vector, taken over the raw slice.
    let profile_sq = profile
        .iter()
        .fold(0.0_f64, |acc, &(_, weight)| acc + weight * weight);

    let magnitude = observed_sq.sqrt() * profile_sq.sqrt();
    if magnitude < 1e-10 {
        return 0.0;
    }
    overlap / magnitude
}
104
/// One language profile: `(language code, display name, weighted trigrams)`.
type LangProfile = (&'static str, &'static str, &'static [(&'static str, f64)]);

/// Built-in trigram profiles, scored in this order.
///
/// Every key is exactly three characters long so that it can equal a trigram
/// taken from the input; a space inside a key marks a word boundary. Keys are
/// unique within a profile, because a repeated key would be counted once when
/// matching but more than once in the profile's norm. The number next to each
/// key is its weight in the profile vector.
static PROFILES: &[LangProfile] = &[
    (
        "en",
        "English",
        &[
            ("the", 0.035),
            ("he ", 0.025),
            ("and", 0.020),
            ("ing", 0.018),
            ("tio", 0.015),
            ("er ", 0.014),
            ("ion", 0.013),
            (" th", 0.025),
            ("ed ", 0.012),
            ("in ", 0.012),
            ("to ", 0.011),
            (" to", 0.011),
            ("of ", 0.020),
            (" of", 0.018),
            ("ent", 0.010),
            ("is ", 0.010),
            (" is", 0.009),
            ("hat", 0.009),
            (" an", 0.012),
            ("nd ", 0.010),
        ],
    ),
    (
        "fr",
        "French",
        &[
            ("es ", 0.025),
            ("de ", 0.022),
            (" de", 0.022),
            ("le ", 0.018),
            ("ent", 0.017),
            (" le", 0.016),
            ("ion", 0.015),
            ("les", 0.014),
            ("la ", 0.013),
            (" la", 0.013),
            ("re ", 0.012),
            ("tio", 0.011),
            ("que", 0.013),
            (" qu", 0.011),
            ("ue ", 0.010),
            ("et ", 0.010),
            (" et", 0.009),
            ("des", 0.012),
            ("ont", 0.009),
        ],
    ),
    (
        "de",
        "German",
        &[
            ("en ", 0.030),
            ("er ", 0.025),
            ("der", 0.018),
            ("die", 0.017),
            ("ein", 0.015),
            ("sch", 0.014),
            (" de", 0.016),
            ("ich", 0.014),
            ("und", 0.013),
            (" un", 0.012),
            ("nd ", 0.011),
            ("den", 0.010),
            ("che", 0.012),
            (" di", 0.013),
            ("ie ", 0.012),
            ("ung", 0.010),
            ("gen", 0.009),
            ("ine", 0.009),
            (" ei", 0.010),
            ("das", 0.008),
        ],
    ),
    (
        "es",
        "Spanish",
        &[
            ("de ", 0.025),
            (" de", 0.023),
            ("os ", 0.018),
            ("la ", 0.016),
            (" la", 0.015),
            ("en ", 0.015),
            ("el ", 0.014),
            (" el", 0.013),
            ("ión", 0.012),
            ("es ", 0.020),
            (" en", 0.012),
            ("ent", 0.010),
            ("que", 0.012),
            (" qu", 0.010),
            ("ue ", 0.009),
            ("aci", 0.008),
            ("ado", 0.008),
            ("las", 0.010),
            (" lo", 0.009),
            ("los", 0.010),
        ],
    ),
    (
        "it",
        "Italian",
        &[
            ("la ", 0.020),
            (" la", 0.018),
            (" di", 0.017),
            ("di ", 0.016),
            ("che", 0.015),
            ("re ", 0.014),
            ("ell", 0.013),
            ("lla", 0.012),
            ("to ", 0.011),
            ("ne ", 0.011),
            (" de", 0.012),
            ("del", 0.011),
            ("ent", 0.010),
            ("ion", 0.010),
            ("con", 0.009),
            (" co", 0.009),
            ("per", 0.009),
            (" pe", 0.008),
            ("ato", 0.008),
            ("men", 0.007),
        ],
    ),
    (
        "pt",
        "Portuguese",
        &[
            ("de ", 0.025),
            (" de", 0.023),
            ("os ", 0.016),
            (" qu", 0.012),
            ("que", 0.012),
            ("ão ", 0.014),
            ("ção", 0.012),
            (" do", 0.010),
            ("do ", 0.010),
            ("da ", 0.011),
            (" da", 0.011),
            ("ent", 0.010),
            ("es ", 0.015),
            (" co", 0.009),
            ("com", 0.009),
            ("nte", 0.008),
            ("men", 0.007),
            ("par", 0.007),
            (" pa", 0.007),
            (" no", 0.008),
        ],
    ),
];
266
#[cfg(test)]
mod tests {
    use super::*;

    /// An English sentence is tagged `en` with a positive confidence.
    #[test]
    fn test_detect_english() {
        let detection = detect_language(
            "The quick brown fox jumps over the lazy dog and then runs away into the forest",
        )
        .unwrap();
        assert_eq!(detection.code, "en");
        assert!(detection.confidence > 0.0);
    }

    /// A French sentence is tagged `fr`.
    #[test]
    fn test_detect_french() {
        let detection = detect_language(
            "Le petit prince est un livre que tout le monde devrait lire au moins une fois dans sa vie",
        )
        .unwrap();
        assert_eq!(detection.code, "fr");
    }

    /// A German sentence is tagged `de`.
    #[test]
    fn test_detect_german() {
        let detection = detect_language(
            "Die Bundesrepublik Deutschland ist ein demokratischer und sozialer Bundesstaat",
        )
        .unwrap();
        assert_eq!(detection.code, "de");
    }

    /// Input below the minimum length yields no detection.
    #[test]
    fn test_too_short() {
        assert!(detect_language("hi").is_none());
    }

    /// Empty input yields no detection.
    #[test]
    fn test_empty() {
        assert!(detect_language("").is_none());
    }
}