//! Locale matching based on the language-distance rules bundled in `data/languageInfo.xml`.
//!
//! [`LanguageMatcher`] maximizes locales with ICU4X's likely-subtags data, scores a
//! desired/supported pair by the `languageMatch` rules that apply at each differing subtag
//! level (region, script, language), and can pick the closest of several supported locales.
#![warn(missing_docs)]
#![deny(unsafe_code)]
7
8use icu_locale::{LanguageIdentifier, LocaleExpander};
9use serde::Deserialize;
10use std::collections::{HashMap, HashSet};
11
12trait Rule<T> {
13 fn matches(self, tag: T, vars: &Variables) -> bool;
14}
15
/// Pattern for a single subtag position (language, script or region) of a match rule.
#[derive(PartialEq, Debug)]
enum SubTagRule {
    /// A literal subtag that must be equal to the tested one.
    Str(String),
    /// `$name`: matches any member of the named variable set.
    Var(String),
    /// `$!name`: matches anything that is *not* a member of the named variable set.
    VarExclude(String),
    /// `*`: matches any subtag.
    All,
}
23
24impl From<&'_ str> for SubTagRule {
25 fn from(s: &'_ str) -> Self {
26 if s == "*" {
27 Self::All
28 } else if let Some(name) = s.strip_prefix("$!") {
29 Self::VarExclude(name.to_string())
30 } else if let Some(name) = s.strip_prefix('$') {
31 Self::Var(name.to_string())
32 } else {
33 Self::Str(s.to_string())
34 }
35 }
36}
37
38impl Rule<&'_ str> for &'_ SubTagRule {
39 fn matches(self, tag: &str, vars: &Variables) -> bool {
40 match self {
41 SubTagRule::Str(s) => s == tag,
42 SubTagRule::Var(key) => vars[key].contains(tag),
43 SubTagRule::VarExclude(key) => !vars[key].contains(tag),
44 SubTagRule::All => true,
45 }
46 }
47}
48
49impl Rule<Option<&'_ str>> for Option<&'_ SubTagRule> {
50 fn matches(self, tag: Option<&str>, vars: &Variables) -> bool {
51 match (self, tag) {
52 (None, None) | (Some(SubTagRule::All), _) => true,
53 (Some(s), Some(tag)) => s.matches(tag, vars),
54 _ => false,
55 }
56 }
57}
58
59#[derive(Debug, PartialEq, Deserialize)]
60#[serde(from = "String")]
61struct LanguageIdentifierRule {
62 pub language: SubTagRule,
63 pub script: Option<SubTagRule>,
64 pub region: Option<SubTagRule>,
65}
66
67impl From<&'_ str> for LanguageIdentifierRule {
68 fn from(s: &'_ str) -> Self {
69 let mut parts = s.split('_');
70 let language = parts.next().unwrap().into();
71 let script = parts.next().map(|s| s.into());
72 let region = parts.next().map(|s| s.into());
73 Self {
74 language,
75 script,
76 region,
77 }
78 }
79}
80
81impl From<String> for LanguageIdentifierRule {
82 fn from(s: String) -> Self {
83 s.as_str().into()
84 }
85}
86
87impl Rule<&'_ LanguageIdentifier> for &'_ LanguageIdentifierRule {
88 fn matches(self, lang: &LanguageIdentifier, vars: &Variables) -> bool {
89 self.language.matches(lang.language.as_str(), vars)
90 && self
91 .script
92 .as_ref()
93 .matches(lang.script.as_ref().map(|s| s.as_str()), vars)
94 && self
95 .region
96 .as_ref()
97 .matches(lang.region.as_ref().map(|s| s.as_str()), vars)
98 }
99}
100
101#[derive(Debug, Deserialize, PartialEq)]
102struct ParadigmLocales {
103 #[serde(rename = "@locales")]
104 pub locales: String,
105}
106
107#[derive(Debug, Deserialize, PartialEq)]
108struct MatchVariable {
109 #[serde(rename = "@id")]
110 pub id: String,
111 #[serde(rename = "@value")]
112 pub value: String,
113}
114
115#[derive(Debug, Deserialize, PartialEq)]
116struct LanguageMatch {
117 #[serde(rename = "@desired")]
118 pub desired: LanguageIdentifierRule,
119 #[serde(rename = "@supported")]
120 pub supported: LanguageIdentifierRule,
121 #[serde(rename = "@distance")]
122 pub distance: u16,
123 #[serde(default, rename = "@oneway")]
124 pub oneway: bool,
125}
126
127#[derive(Debug, Deserialize, PartialEq)]
128#[serde(rename_all = "camelCase")]
129struct LanguageMatches {
130 pub paradigm_locales: ParadigmLocales,
131 pub match_variable: Vec<MatchVariable>,
132 pub language_match: Vec<LanguageMatch>,
133}
134
135#[derive(Debug, Deserialize, PartialEq)]
136#[serde(rename_all = "camelCase")]
137struct LanguageMatching {
138 pub language_matches: LanguageMatches,
139}
140
141#[derive(Debug, Deserialize, PartialEq)]
142#[serde(rename_all = "camelCase")]
143struct SupplementalData {
144 pub language_matching: LanguageMatching,
145}
146
147const LANGUAGE_INFO: &str = include_str!(concat!(
148 env!("CARGO_MANIFEST_DIR"),
149 "/data/languageInfo.xml"
150));
151
152pub struct LanguageMatcher {
188 paradigm: HashSet<LanguageIdentifier>,
189 vars: Variables,
190 rules: Vec<LanguageMatch>,
191 expander: LocaleExpander,
192}
193
/// Named match variables: variable name (without `$`) mapped to the set of subtags it covers.
type Variables = std::collections::HashMap<String, std::collections::HashSet<String>>;
195
196impl From<SupplementalData> for LanguageMatcher {
197 fn from(data: SupplementalData) -> Self {
198 let expander = LocaleExpander::new_extended();
199
200 let matches = data.language_matching.language_matches;
201
202 let paradigm = matches
203 .paradigm_locales
204 .locales
205 .split(' ')
206 .map(|s| {
207 let mut lang = s.parse().unwrap();
208 expander.maximize(&mut lang);
209 lang
210 })
211 .collect::<HashSet<_>>();
212 let vars = matches
213 .match_variable
214 .into_iter()
215 .map(|MatchVariable { id, value }| {
216 debug_assert!(id.starts_with('$'));
217 (
219 id[1..].to_string(),
220 value.split('+').map(|s| s.to_string()).collect(),
221 )
222 })
223 .collect::<HashMap<_, _>>();
224 Self {
225 paradigm,
226 vars,
227 rules: matches.language_match,
228 expander,
229 }
230 }
231}
232
233impl LanguageMatcher {
234 pub fn new() -> Self {
236 let data: SupplementalData = quick_xml::de::from_str(LANGUAGE_INFO).unwrap();
237 data.into()
238 }
239
240 pub fn matches<'a>(
246 &self,
247 mut desired: LanguageIdentifier,
248 supported: impl IntoIterator<Item = &'a LanguageIdentifier>,
249 ) -> Option<(&'a LanguageIdentifier, u16)> {
250 self.expander.maximize(&mut desired);
251 supported
252 .into_iter()
253 .map(|s| {
254 let mut max_s = s.clone();
255 self.expander.maximize(&mut max_s);
256 (s, self.distance_impl(desired.clone(), max_s))
257 })
258 .min_by_key(|(_, dis)| *dis)
259 .filter(|(_, dis)| *dis < 1000)
260 }
261
262 pub fn distance(
268 &self,
269 mut desired: LanguageIdentifier,
270 mut supported: LanguageIdentifier,
271 ) -> u16 {
272 self.expander.maximize(&mut desired);
273 self.expander.maximize(&mut supported);
274 self.distance_impl(desired, supported)
275 }
276
277 fn distance_impl(
278 &self,
279 mut desired: LanguageIdentifier,
280 mut supported: LanguageIdentifier,
281 ) -> u16 {
282 debug_assert!(desired.region.is_some());
283 debug_assert!(desired.script.is_some());
284 debug_assert!(supported.region.is_some());
285 debug_assert!(supported.script.is_some());
286
287 let mut distance = 0;
288
289 if desired.region != supported.region {
290 distance += self.distance_match(&desired, &supported);
291 }
292 desired.region = None;
293 supported.region = None;
294
295 if desired.script != supported.script {
296 distance += self.distance_match(&desired, &supported);
297 }
298 desired.script = None;
299 supported.script = None;
300
301 if desired.language != supported.language {
302 distance += self.distance_match(&desired, &supported);
303 }
304
305 distance
306 }
307
308 fn distance_match(&self, desired: &LanguageIdentifier, supported: &LanguageIdentifier) -> u16 {
309 for rule in &self.rules {
310 let mut matches = rule.desired.matches(desired, &self.vars)
311 && rule.supported.matches(supported, &self.vars);
312 if !rule.oneway && !matches {
313 matches = rule.supported.matches(desired, &self.vars)
314 && rule.desired.matches(supported, &self.vars);
315 }
316 if matches {
317 let mut distance = rule.distance * 10;
318 if self.is_paradigm(desired) ^ self.is_paradigm(supported) {
319 distance -= 1
320 }
321 return distance;
322 }
323 }
324 unreachable!()
325 }
326
327 fn is_paradigm(&self, lang: &LanguageIdentifier) -> bool {
328 self.paradigm.contains(lang)
329 }
330}
331
332impl Default for LanguageMatcher {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
#[cfg(test)]
mod test {
    use crate::LanguageMatcher;
    use icu_locale::langid;

    #[test]
    fn distance() {
        let matcher = LanguageMatcher::new();

        // (desired, supported, expected distance)
        let cases: [(_, _, u16); 4] = [
            (langid!("zh-CN"), langid!("zh-Hans"), 0),
            (langid!("zh-TW"), langid!("zh-Hant"), 0),
            (langid!("zh-HK"), langid!("zh-MO"), 40),
            (langid!("zh-HK"), langid!("zh-Hant"), 50),
        ];
        for (desired, supported, expected) in cases {
            let label = format!("{desired:?} vs {supported:?}");
            assert_eq!(matcher.distance(desired, supported), expected, "{label}");
        }
    }

    #[test]
    fn matcher() {
        let matcher = LanguageMatcher::new();
        let accepts = [
            langid!("en"),
            langid!("ja"),
            langid!("zh-Hans"),
            langid!("zh-Hant"),
        ];

        // (desired, expected best supported locale at distance 0)
        let cases = [
            (langid!("zh-CN"), langid!("zh-Hans")),
            (langid!("zh-TW"), langid!("zh-Hant")),
        ];
        for (desired, best) in cases {
            assert_eq!(matcher.matches(desired, &accepts), Some((&best, 0)));
        }
    }
}