use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display};
use std::hash::Hash;
/// Identifier of a language, stored as a free-form code string such as `"en"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LanguageId(pub String);

impl LanguageId {
    /// Wraps any string-like value as a language identifier.
    pub fn new(id: impl Into<String>) -> Self {
        let code: String = id.into();
        LanguageId(code)
    }

    /// Borrows the underlying code string.
    pub fn code(&self) -> &str {
        self.0.as_str()
    }
}

/// Displays the bare code, e.g. `en`.
impl Display for LanguageId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.code())
    }
}

/// Lets `&str`, `String`, and other string-like values convert directly into a
/// `LanguageId`.
impl<T> From<T> for LanguageId
where
    T: Into<String>,
{
    fn from(value: T) -> Self {
        LanguageId::new(value)
    }
}

/// Convenience constructors for a few common two-letter codes.
impl LanguageId {
    /// Code `"en"`.
    pub fn english() -> Self {
        "en".into()
    }

    /// Code `"es"`.
    pub fn spanish() -> Self {
        "es".into()
    }

    /// Code `"fr"`.
    pub fn french() -> Self {
        "fr".into()
    }

    /// Code `"de"`.
    pub fn german() -> Self {
        "de".into()
    }

    /// Code `"zh"`.
    pub fn mandarin() -> Self {
        "zh".into()
    }

    /// Code `"hi"`.
    pub fn hindi() -> Self {
        "hi".into()
    }

    /// Code `"ar"`.
    pub fn arabic() -> Self {
        "ar".into()
    }

    /// Code `"ja"`.
    pub fn japanese() -> Self {
        "ja".into()
    }
}
/// Per-language settings used by [`LanguageDetector`] to score words.
#[derive(Debug, Clone)]
pub struct LanguageConfig {
/// Identifier of the language this config describes.
pub id: LanguageId,
/// Prior weight of the language (defaults to `1.0`). The detector adds its
/// natural log to every score, so it should be positive: `0.0` gives `-inf`
/// and a negative value gives NaN.
pub prior: f64,
/// Known words, matched as given or by their lowercase form. A vocabulary hit
/// with no explicit `word_probs` entry scores `-ln(vocabulary.len())`.
pub vocabulary: HashSet<String>,
/// Explicit per-word log-probabilities. These are checked (as given, then
/// lowercased) before `vocabulary`.
pub word_probs: HashMap<String, f64>,
/// Log-probability for words found in neither `word_probs` nor `vocabulary`
/// (defaults to `-10.0`).
pub unknown_word_prob: f64,
/// Right-to-left flag (defaults to `false`, set by `rtl()`). None of the
/// scoring methods visible in this file read it.
pub rtl: bool,
/// Writing-system tag (defaults to `Script::Latin`). None of the scoring
/// methods visible in this file read it.
pub script: Script,
}
/// Writing-system tag carried by a [`LanguageConfig`].
///
/// `Latin` is the derived `Default`. None of the scoring methods visible in
/// this file read the script; it is stored as metadata.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Script {
/// The default variant.
#[default]
Latin,
Cyrillic,
Arabic,
Devanagari,
Han,
Japanese,
Hangul,
Greek,
Hebrew,
Thai,
/// NOTE(review): presumably a catch-all for unspecified or mixed scripts — confirm intended use.
Unknown,
}
impl LanguageConfig {
pub fn new(id: impl Into<LanguageId>) -> Self {
Self {
id: id.into(),
prior: 1.0,
vocabulary: HashSet::new(),
word_probs: HashMap::new(),
unknown_word_prob: -10.0, rtl: false,
script: Script::Latin,
}
}
pub fn with_prior(mut self, prior: f64) -> Self {
self.prior = prior;
self
}
pub fn with_vocabulary(mut self, words: impl IntoIterator<Item = impl Into<String>>) -> Self {
self.vocabulary = words.into_iter().map(|w| w.into()).collect();
self
}
pub fn add_words(mut self, words: impl IntoIterator<Item = impl Into<String>>) -> Self {
for word in words {
self.vocabulary.insert(word.into());
}
self
}
pub fn with_word_probs(mut self, probs: HashMap<String, f64>) -> Self {
self.word_probs = probs;
self
}
pub fn add_word_prob(mut self, word: impl Into<String>, log_prob: f64) -> Self {
self.word_probs.insert(word.into(), log_prob);
self
}
pub fn rtl(mut self) -> Self {
self.rtl = true;
self
}
pub fn with_script(mut self, script: Script) -> Self {
self.script = script;
self
}
pub fn with_unknown_prob(mut self, log_prob: f64) -> Self {
self.unknown_word_prob = log_prob;
self
}
pub fn contains_word(&self, word: &str) -> bool {
self.vocabulary.contains(word) || self.vocabulary.contains(&word.to_lowercase())
}
pub fn word_log_prob(&self, word: &str) -> f64 {
if let Some(&prob) = self.word_probs.get(word) {
return prob;
}
if let Some(&prob) = self.word_probs.get(&word.to_lowercase()) {
return prob;
}
if self.contains_word(word) {
let vocab_size = self.vocabulary.len().max(1) as f64;
return -(vocab_size.ln());
}
self.unknown_word_prob
}
pub fn id(&self) -> &LanguageId {
&self.id
}
}
/// A word paired with its log-probability.
#[derive(Debug, Clone)]
pub struct WordProbability {
    pub word: String,
    pub log_prob: f64,
}

impl WordProbability {
    /// Builds an entry from any string-like word and its log-probability.
    pub fn new(word: impl Into<String>, log_prob: f64) -> Self {
        let word: String = word.into();
        WordProbability { word, log_prob }
    }
}
/// Minimal interface for a per-language word model.
///
/// Implementors must also be `Send + Sync + Debug`.
pub trait LanguageModel: Send + Sync + Debug {
    /// Identifier of the language this model scores.
    fn language(&self) -> &LanguageId;

    /// Number of entries the model treats as in-vocabulary.
    fn vocabulary_size(&self) -> usize;

    /// Whether `word` is known to the model.
    fn in_vocabulary(&self, word: &str) -> bool;

    /// Log-probability of `word`, ignoring any context.
    fn word_log_prob(&self, word: &str) -> f64;

    /// Log-probability of `word` given surrounding `context` words.
    ///
    /// The default implementation ignores `context` and delegates to
    /// [`word_log_prob`](Self::word_log_prob). Context-aware models may
    /// override it.
    fn context_log_prob(&self, word: &str, _context: &[&str]) -> f64 {
        Self::word_log_prob(self, word)
    }
}
/// A unigram [`LanguageModel`] backed by a word → log-probability table.
///
/// Lookups try the word exactly as given, then its lowercase form. Words found
/// under neither key receive `unknown_prob`.
#[derive(Debug, Clone)]
pub struct SimpleLanguageModel {
    language: LanguageId,
    word_probs: HashMap<String, f64>,
    unknown_prob: f64,
}

impl SimpleLanguageModel {
    /// Out-of-vocabulary log-probability used when no corpus statistics are available.
    const DEFAULT_UNKNOWN_PROB: f64 = -10.0;

    /// Creates an empty model whose unknown-word log-probability is `-10.0`.
    pub fn new(language: impl Into<LanguageId>) -> Self {
        Self {
            language: language.into(),
            word_probs: HashMap::new(),
            unknown_prob: Self::DEFAULT_UNKNOWN_PROB,
        }
    }

    /// Replaces the whole probability table.
    pub fn with_probs(mut self, probs: HashMap<String, f64>) -> Self {
        self.word_probs = probs;
        self
    }

    /// Inserts (or overwrites) the log-probability of a single word.
    pub fn add_prob(&mut self, word: impl Into<String>, log_prob: f64) {
        self.word_probs.insert(word.into(), log_prob);
    }

    /// Overrides the log-probability returned for out-of-vocabulary words.
    pub fn with_unknown_prob(mut self, log_prob: f64) -> Self {
        self.unknown_prob = log_prob;
        self
    }

    /// Builds a maximum-likelihood model from raw word counts.
    ///
    /// - A word with a positive count gets `ln(count / total)`.
    /// - Unseen words get `ln(1 / (10 * total))`, one tenth of the probability
    ///   of a word seen once.
    /// - Words with a count of zero are left out of the table. Otherwise they
    ///   would score `ln(0) = -inf`, below every unseen word.
    /// - Empty or all-zero `counts` give an empty model with the default
    ///   unknown log-probability of `-10.0`, instead of `ln(1 / 0) = +inf`.
    pub fn from_counts(language: impl Into<LanguageId>, counts: &HashMap<String, usize>) -> Self {
        // Accumulate in f64 so the total cannot overflow `usize`.
        let total: f64 = counts.values().map(|&count| count as f64).sum();
        let word_probs: HashMap<String, f64> = counts
            .iter()
            .filter(|&(_, &count)| count > 0)
            .map(|(word, &count)| (word.clone(), (count as f64 / total).ln()))
            .collect();
        let unknown_prob = if total > 0.0 {
            (1.0 / (total * 10.0)).ln()
        } else {
            Self::DEFAULT_UNKNOWN_PROB
        };
        Self {
            language: language.into(),
            word_probs,
            unknown_prob,
        }
    }

    /// Looks up `word` exactly, then in lowercase; `None` if absent under both.
    fn lookup(&self, word: &str) -> Option<f64> {
        self.word_probs
            .get(word)
            .or_else(|| self.word_probs.get(word.to_lowercase().as_str()))
            .copied()
    }
}

impl LanguageModel for SimpleLanguageModel {
    fn language(&self) -> &LanguageId {
        &self.language
    }

    fn word_log_prob(&self, word: &str) -> f64 {
        self.lookup(word).unwrap_or(self.unknown_prob)
    }

    fn vocabulary_size(&self) -> usize {
        self.word_probs.len()
    }

    fn in_vocabulary(&self, word: &str) -> bool {
        self.lookup(word).is_some()
    }
}
/// Outcome of a detection: the winning language, its probability, and the
/// remaining candidates with their probabilities.
#[derive(Debug, Clone)]
pub struct DetectionResult {
    pub language: LanguageId,
    pub confidence: f64,
    pub alternatives: Vec<(LanguageId, f64)>,
}

impl DetectionResult {
    /// Creates a result with an empty `alternatives` list.
    pub fn new(language: LanguageId, confidence: f64) -> Self {
        let alternatives = Vec::new();
        DetectionResult {
            language,
            confidence,
            alternatives,
        }
    }

    /// Replaces `alternatives` with the given list.
    pub fn with_alternatives(self, alternatives: Vec<(LanguageId, f64)>) -> Self {
        DetectionResult {
            alternatives,
            ..self
        }
    }
}
/// Scores words against a set of [`LanguageConfig`]s and reports the most
/// probable language.
#[derive(Debug, Clone)]
pub struct LanguageDetector {
    configs: Vec<LanguageConfig>,
}

impl LanguageDetector {
    /// Creates a detector with no languages registered.
    pub fn new() -> Self {
        Self {
            configs: Vec::new(),
        }
    }

    /// Registers another language. Registration order breaks probability ties.
    pub fn add_language(&mut self, config: LanguageConfig) {
        self.configs.push(config);
    }

    /// Creates a detector from an existing list of configs.
    pub fn from_configs(configs: Vec<LanguageConfig>) -> Self {
        Self { configs }
    }

    /// Detects the language of a single word.
    ///
    /// Each language scores `ln(prior) + word_log_prob(word)`. The scores are
    /// turned into probabilities with a max-shifted softmax.
    /// - The most probable language becomes `language`/`confidence`.
    /// - The rest go into `alternatives`, sorted by descending probability,
    ///   with ties kept in registration order.
    /// - With no languages registered, the result is `"unknown"` with
    ///   confidence `0.0`.
    pub fn detect_word(&self, word: &str) -> DetectionResult {
        Self::ranked(self.word_posteriors(word))
    }

    /// Detects the dominant language of a word sequence.
    ///
    /// The per-word probabilities (the same ones `detect_word` computes) are
    /// summed per language, then divided by the grand total. Entries sharing a
    /// `LanguageId` are merged, and ties resolve in registration order. An empty
    /// sequence or a detector with no languages gives `"unknown"` with
    /// confidence `0.0`.
    pub fn detect_sequence(&self, words: &[&str]) -> DetectionResult {
        if words.is_empty() || self.configs.is_empty() {
            return Self::unknown();
        }
        // Accumulate in first-seen (registration) order so ties resolve
        // deterministically; `slots` merges duplicate ids into one entry.
        let mut totals: Vec<(LanguageId, f64)> = Vec::new();
        let mut slots: HashMap<LanguageId, usize> = HashMap::new();
        for word in words {
            for (id, prob) in self.word_posteriors(word) {
                match slots.get(&id).copied() {
                    Some(slot) => totals[slot].1 += prob,
                    None => {
                        slots.insert(id.clone(), totals.len());
                        totals.push((id, prob));
                    }
                }
            }
        }
        let grand_total: f64 = totals.iter().map(|(_, score)| score).sum();
        for (_, score) in &mut totals {
            *score /= grand_total;
        }
        Self::ranked(totals)
    }

    /// Iterates over the registered configs in registration order.
    pub fn languages(&self) -> impl Iterator<Item = &LanguageConfig> {
        self.configs.iter()
    }

    /// Posterior probability of every registered language for `word`, in
    /// registration order. Empty when no languages are registered.
    fn word_posteriors(&self, word: &str) -> Vec<(LanguageId, f64)> {
        // `word_log_prob` already falls back to the vocabulary and then to
        // `unknown_word_prob`. Calling it directly means explicit `word_probs`
        // entries count even for words that are not in `vocabulary`.
        let scores: Vec<f64> = self
            .configs
            .iter()
            .map(|config| config.prior.ln() + config.word_log_prob(word))
            .collect();
        let max_score = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        // Shift by the maximum before exponentiating to avoid underflow. If every
        // score is -inf (e.g. all priors are 0), `-inf - -inf` would be NaN, so
        // fall back to a uniform distribution.
        let weights: Vec<f64> = if max_score == f64::NEG_INFINITY {
            vec![1.0; scores.len()]
        } else {
            scores.iter().map(|score| (score - max_score).exp()).collect()
        };
        let sum: f64 = weights.iter().sum();
        self.configs
            .iter()
            .zip(weights)
            .map(|(config, weight)| (config.id.clone(), weight / sum))
            .collect()
    }

    /// Sorts `scored` by descending probability and splits off the winner.
    ///
    /// The sort is stable, so ties keep their input order. Returns the
    /// "unknown" result if `scored` is empty.
    fn ranked(mut scored: Vec<(LanguageId, f64)>) -> DetectionResult {
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let mut rest = scored.into_iter();
        match rest.next() {
            Some((best_lang, best_prob)) => {
                DetectionResult::new(best_lang, best_prob).with_alternatives(rest.collect())
            }
            None => Self::unknown(),
        }
    }

    /// Result reported when there is nothing to score.
    fn unknown() -> DetectionResult {
        DetectionResult::new(LanguageId::new("unknown"), 0.0)
    }
}

impl Default for LanguageDetector {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when two floats differ by less than machine epsilon.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    /// Builds an English/Spanish detector (registered via `add_language`) with
    /// equal priors of 0.5.
    fn english_spanish(en: &[&str], es: &[&str]) -> LanguageDetector {
        let mut detector = LanguageDetector::new();
        for (code, words) in [("en", en), ("es", es)] {
            detector.add_language(
                LanguageConfig::new(code)
                    .with_prior(0.5)
                    .add_words(words.iter().copied()),
            );
        }
        detector
    }

    #[test]
    fn test_language_id() {
        let id = LanguageId::new("en");
        assert_eq!(id.code(), "en");
        assert_eq!(id.to_string(), "en");
    }

    #[test]
    fn test_language_id_presets() {
        let cases = [
            (LanguageId::english(), "en"),
            (LanguageId::spanish(), "es"),
            (LanguageId::french(), "fr"),
        ];
        for (id, code) in cases {
            assert_eq!(id.code(), code);
        }
    }

    #[test]
    fn test_language_config() {
        let config = LanguageConfig::new("en")
            .with_prior(0.7)
            .add_words(["hello", "world", "the"]);
        assert_eq!(config.id.code(), "en");
        assert!(approx_eq(config.prior, 0.7));
        assert!(config.contains_word("hello"));
        assert!(config.contains_word("HELLO"));
        assert!(!config.contains_word("hola"));
    }

    #[test]
    fn test_language_config_word_probs() {
        let config = LanguageConfig::new("en")
            .add_word_prob("the", -1.0)
            .add_word_prob("a", -2.0);
        assert!(approx_eq(config.word_log_prob("the"), -1.0));
        assert!(approx_eq(config.word_log_prob("a"), -2.0));
        assert!(approx_eq(config.word_log_prob("xyz"), config.unknown_word_prob));
    }

    #[test]
    fn test_simple_language_model() {
        let mut lm = SimpleLanguageModel::new("en").with_unknown_prob(-15.0);
        lm.add_prob("hello", -2.0);
        lm.add_prob("world", -3.0);
        assert_eq!(lm.language().code(), "en");
        assert!(approx_eq(lm.word_log_prob("hello"), -2.0));
        assert!(approx_eq(lm.word_log_prob("unknown"), -15.0));
        assert_eq!(lm.vocabulary_size(), 2);
        assert!(lm.in_vocabulary("hello"));
        assert!(!lm.in_vocabulary("xyz"));
    }

    #[test]
    fn test_language_model_from_counts() {
        let counts = HashMap::from([
            ("the".to_string(), 100),
            ("a".to_string(), 50),
            ("an".to_string(), 25),
        ]);
        let lm = SimpleLanguageModel::from_counts("en", &counts);
        assert_eq!(lm.vocabulary_size(), 3);
        assert!(lm.word_log_prob("the") > lm.word_log_prob("a"));
        assert!(lm.word_log_prob("a") > lm.word_log_prob("an"));
    }

    #[test]
    fn test_detection_result() {
        let alternatives = vec![(LanguageId::spanish(), 0.15), (LanguageId::french(), 0.05)];
        let result = DetectionResult::new(LanguageId::english(), 0.8).with_alternatives(alternatives);
        assert_eq!(result.language.code(), "en");
        assert!(approx_eq(result.confidence, 0.8));
        assert_eq!(result.alternatives.len(), 2);
    }

    #[test]
    fn test_language_detector_single_word() {
        let detector = english_spanish(&["hello", "world", "the"], &["hola", "mundo", "el"]);
        for (word, expected) in [("hello", "en"), ("hola", "es")] {
            let result = detector.detect_word(word);
            assert_eq!(result.language.code(), expected);
            assert!(result.confidence > 0.5);
        }
    }

    #[test]
    fn test_language_detector_sequence() {
        let detector = english_spanish(
            &["hello", "world", "the", "is", "good"],
            &["hola", "mundo", "el", "es", "bueno"],
        );
        let english = detector.detect_sequence(&["hello", "world", "is", "good"]);
        assert_eq!(english.language.code(), "en");
        let spanish = detector.detect_sequence(&["hola", "mundo", "es", "bueno"]);
        assert_eq!(spanish.language.code(), "es");
    }

    #[test]
    fn test_script_default() {
        assert_eq!(Script::default(), Script::Latin);
    }

    #[test]
    fn test_language_config_rtl() {
        let config = LanguageConfig::new("ar").rtl().with_script(Script::Arabic);
        assert!(config.rtl);
        assert_eq!(config.script, Script::Arabic);
    }
}