use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::OnceLock;
/// Text of `weights.json`, embedded at compile time; parsed as a JSON map of
/// feature string -> (class label -> weight) when the tagger is built.
const WEIGHTS_JSON: &str = include_str!("../data/tagger/weights.json");
/// Text of `classes.txt`, embedded at compile time; read one class label per
/// line, with surrounding whitespace trimmed and blank lines skipped.
const CLASSES_TXT: &str = include_str!("../data/tagger/classes.txt");
/// Text of `tags.json`, embedded at compile time; parsed as a JSON map of
/// token -> tag that is consulted before the model is asked to predict.
const TAGS_JSON: &str = include_str!("../data/tagger/tags.json");
/// Storage for the shared tagger; filled on the first call to `global_tagger`.
static TAGGER: OnceLock<PerceptronTagger> = OnceLock::new();
/// Returns the shared [`PerceptronTagger`], constructing it on first use.
///
/// The tagger is built from the embedded weights, class list and tag
/// dictionary the first time this function is called; every later call
/// returns a reference to that same instance.
///
/// # Panics
///
/// Panics on first use if the embedded JSON data cannot be parsed.
pub fn global_tagger() -> &'static PerceptronTagger {
    // Named builder instead of an inline closure; `get_or_init` stores its result.
    fn build() -> PerceptronTagger {
        PerceptronTagger::new(WEIGHTS_JSON, CLASSES_TXT, TAGS_JSON)
    }
    TAGGER.get_or_init(build)
}
/// Linear multi-class scorer: every active feature contributes a per-label
/// weight, and the class with the highest summed score is chosen by `predict`.
#[derive(Debug, Serialize, Deserialize)]
pub struct AveragedPerceptron {
/// Weights keyed by feature string, then by class label.
pub feature_weights: HashMap<String, HashMap<String, f32>>,
/// Candidate class labels that `predict` chooses among.
pub classes: Vec<String>,
}
impl AveragedPerceptron {
pub fn new(weights_json: &str, classes_txt: &str) -> Self {
let feature_weights: HashMap<String, HashMap<String, f32>> =
serde_json::from_str(weights_json).expect("Failed to parse tagger weights.json");
let classes: Vec<String> = classes_txt
.lines()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
Self {
feature_weights,
classes,
}
}
/// Scores every class for the given features and returns the best one.
///
/// `word_features` maps feature strings to counts. Features with a count of
/// zero, or with no entry in `feature_weights`, are ignored. Every other
/// feature adds `weight * count` to the score of each label it carries a
/// weight for. A class that receives no contribution scores `0.0`.
///
/// Returns the winning label from `self.classes` together with its raw summed
/// score. When several classes share the top score, the one listed last in
/// `self.classes` wins.
///
/// # Panics
///
/// Panics if `self.classes` is empty.
pub fn predict(&self, word_features: HashMap<String, usize>) -> (&str, f32) {
    let mut scores: HashMap<&str, f32> = HashMap::new();
    for (feature, &count) in &word_features {
        if count == 0 {
            continue;
        }
        let Some(weights) = self.feature_weights.get(feature) else {
            continue;
        };
        let scale = count as f32;
        for (label, weight) in weights {
            *scores.entry(label.as_str()).or_insert(0.0) += weight * scale;
        }
    }

    let score_of = |label: &String| scores.get(label.as_str()).copied().unwrap_or(0.0);
    let best = self
        .classes
        .iter()
        // `total_cmp` never fails, so a NaN score (e.g. from inf - inf) can no
        // longer panic here. For non-NaN scores the ordering is unchanged:
        // accumulation starts at +0.0, so a -0.0 score cannot occur.
        .max_by(|a, b| score_of(*a).total_cmp(&score_of(*b)))
        .expect("classes list must not be empty");
    (best.as_str(), score_of(best))
}
}
/// Token tagger that combines a fixed token -> tag dictionary with an
/// [`AveragedPerceptron`] model for tokens the dictionary does not contain.
pub struct PerceptronTagger {
// Model used to predict a tag when a token is absent from `tags`.
model: AveragedPerceptron,
// Exact-match lookup from token to tag; hits bypass the model and are
// reported with a confidence of 1.0.
tags: HashMap<String, String>,
}
impl PerceptronTagger {
/// Builds a tagger from serialized model weights, class list and tag dictionary.
///
/// `tags_json` must be a JSON object mapping tokens to tags; `weights_json`
/// and `classes_txt` are forwarded to [`AveragedPerceptron::new`].
///
/// # Panics
///
/// Panics if `tags_json` or `weights_json` cannot be parsed.
pub fn new(weights_json: &str, classes_txt: &str, tags_json: &str) -> Self {
    // The tag dictionary is parsed first, then the model.
    let tags = serde_json::from_str::<HashMap<String, String>>(tags_json)
        .expect("Failed to parse tagger tags.json");
    let model = AveragedPerceptron::new(weights_json, classes_txt);
    Self { model, tags }
}
/// Tags every token in `words`, in order.
///
/// Tokens found in the tag dictionary receive that tag with a confidence of
/// `1.0`. All other tokens are scored by the model using features built from
/// the normalized surrounding tokens and the two previously assigned tags;
/// their confidence is the raw score of the winning class.
///
/// The returned vector has one [`Tag`] per input token and borrows the
/// original token strings.
pub fn tag<'a>(&self, words: &[&'a str]) -> Vec<Tag<'a>> {
    // Normalized tokens padded with two sentinels on each side, so the
    // feature extractor can always look two positions left and right.
    let mut context: Vec<&str> = Vec::with_capacity(words.len() + 4);
    context.push("-START-");
    context.push("-START2-");
    context.extend(words.iter().map(|&w| Self::normalize(w)));
    context.push("-END-");
    context.push("-END2-");

    let mut last_tag = String::from("-START-");
    let mut second_last_tag = String::from("-START2-");
    let mut tagged = Vec::with_capacity(words.len());
    for (idx, &word) in words.iter().enumerate() {
        let (label, conf) = match self.tags.get(word) {
            Some(known) => (known.clone(), 1.0),
            None => {
                // `idx + 2` skips the two leading sentinels in `context`.
                let features =
                    Self::get_features(idx + 2, word, &context, &last_tag, &second_last_tag);
                let (predicted, score) = self.model.predict(features);
                (predicted.to_string(), score)
            }
        };
        tagged.push(Tag {
            word,
            tag: label.clone(),
            conf,
        });
        // Shift the tag history: the old `last_tag` becomes `second_last_tag`.
        second_last_tag = std::mem::replace(&mut last_tag, label);
    }
    tagged
}
/// Maps a raw token onto the form used for context-word features.
///
/// * contains a `-` but does not start with one -> `"!HYPHEN"`
/// * consists of exactly four ASCII digits -> `"!YEAR"`
/// * otherwise starts with an ASCII digit -> `"!DIGITS"`
/// * anything else is returned unchanged
fn normalize(token: &str) -> &str {
    // Digits are checked explicitly: `str::parse::<usize>` also accepts a
    // leading `+`, which would classify tokens such as "+123" as a year.
    let is_four_digits = token.len() == 4 && token.bytes().all(|b| b.is_ascii_digit());
    if token.contains('-') && !token.starts_with('-') {
        "!HYPHEN"
    } else if is_four_digits {
        "!YEAR"
    } else if token.chars().next().is_some_and(|c| c.is_ascii_digit()) {
        "!DIGITS"
    } else {
        token
    }
}
/// Builds the feature map for the token at position `i` of `context`.
///
/// `word` is the raw token (used for its suffix and first character),
/// `context` is the padded, normalized token list, and `prev` / `prev2` are
/// the tags assigned to the previous token and the one before it. Every
/// generated feature is mapped to a count of `1`.
///
/// # Panics
///
/// Panics if `context` does not contain indices `i - 2` through `i + 2`.
fn get_features(
    i: usize,
    word: &str,
    context: &[&str],
    prev: &str,
    prev2: &str,
) -> HashMap<String, usize> {
    let word_suffix = suffix3(word);
    let first_char: String = word.chars().take(1).collect();
    // Each key is the feature name followed by its space-separated values.
    let keys = vec![
        String::from("bias"),
        format!("i suffix {}", word_suffix),
        format!("i pref1 {}", first_char),
        format!("i-1 tag {}", prev),
        format!("i-2 tag {}", prev2),
        format!("i tag+i-2 tag {} {}", prev, prev2),
        format!("i word {}", context[i]),
        format!("i-1 tag+i word {} {}", prev, context[i]),
        format!("i-1 word {}", context[i - 1]),
        format!("i-2 word {}", context[i - 2]),
        format!("i+1 word {}", context[i + 1]),
        format!("i+2 word {}", context[i + 2]),
        format!("i+1 suffix {}", suffix3(context[i + 1])),
        format!("i-1 suffix {}", suffix3(context[i - 1])),
    ];
    keys.into_iter().map(|key| (key, 1)).collect()
}
}
/// Returns the last three characters of `s`, or all of `s` when it has fewer
/// than three. Characters are Unicode scalar values, not bytes.
fn suffix3(s: &str) -> String {
    // Byte offset of the third character from the end; 0 when `s` is shorter.
    let start = s
        .char_indices()
        .rev()
        .nth(2)
        .map_or(0, |(offset, _)| offset);
    s[start..].to_string()
}
/// A single tagged token, borrowing the token text from the caller's input.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged, copied
/// and compared in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct Tag<'a> {
    /// The token exactly as it was supplied by the caller.
    pub word: &'a str,
    /// The tag assigned to `word`.
    pub tag: String,
    /// Confidence attached to `tag`: `1.0` for tag-dictionary hits, otherwise
    /// the raw score of the winning class (not a probability).
    pub conf: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded data must load into a non-empty class list and tag dictionary.
    #[test]
    fn test_global_tagger_loads() {
        let shared = global_tagger();
        let has_classes = !shared.model.classes.is_empty();
        let has_tags = !shared.tags.is_empty();
        assert!(has_classes);
        assert!(has_tags);
    }

    /// Tagging yields exactly one non-empty tag per input token.
    #[test]
    fn test_tag_simple_sentence() {
        let sentence = ["The", "cat", "sat", "on", "the", "mat"];
        let tagged = global_tagger().tag(&sentence);
        assert_eq!(tagged.len(), sentence.len());
        assert!(tagged.iter().all(|entry| !entry.tag.is_empty()));
    }

    /// `suffix3` keeps at most the last three characters.
    #[test]
    fn test_suffix3() {
        let cases = [("hello", "llo"), ("hi", "hi"), ("a", "a"), ("", "")];
        for &(input, expected) in &cases {
            assert_eq!(suffix3(input), expected);
        }
    }
}