use super::super::entry::NgramEntry;
use super::super::trie::NgramTrie;
use liblevenshtein::dictionary::MutableMappedDictionary;
/// Modified Kneser-Ney smoothing parameters.
///
/// Stores one absolute discount per count bucket. The discount subtracted from an
/// n-gram's count depends on whether it was observed once (`d1`), twice (`d2`),
/// or three or more times (`d3_plus`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct KneserNeySmoothing {
    /// Discount for n-grams observed exactly once.
    d1: f64,
    /// Discount for n-grams observed exactly twice.
    d2: f64,
    /// Discount for n-grams observed three or more times.
    d3_plus: f64,
}
impl KneserNeySmoothing {
    /// Creates a smoother with the fixed default discounts.
    ///
    /// The n-gram order is currently ignored: the same discounts are used at
    /// every order.
    pub fn new(_order: usize) -> Self {
        Self::default_discounts()
    }

    /// Estimates modified Kneser-Ney discounts from count-of-count statistics.
    ///
    /// `n1`..`n4` are the number of n-grams occurring exactly 1, 2, 3 and 4
    /// times. The estimator is the Chen & Goodman one:
    /// `Y = n1 / (n1 + 2·n2)` and `D_k = k - (k + 1)·Y·n_{k+1} / n_k`.
    ///
    /// Each statistic is floored at 1 so the divisions are always defined.
    /// Each discount is clamped to `[0, k]` so it can never exceed the count
    /// it is subtracted from.
    pub fn from_counts(n1: u64, n2: u64, n3: u64, n4: u64) -> Self {
        let n1 = n1.max(1) as f64;
        let n2 = n2.max(1) as f64;
        let n3 = n3.max(1) as f64;
        let n4 = n4.max(1) as f64;
        let y = n1 / (n1 + 2.0 * n2);
        let d1 = (1.0 - 2.0 * y * (n2 / n1)).clamp(0.0, 1.0);
        let d2 = (2.0 - 3.0 * y * (n3 / n2)).clamp(0.0, 2.0);
        let d3_plus = (3.0 - 4.0 * y * (n4 / n3)).clamp(0.0, 3.0);
        Self { d1, d2, d3_plus }
    }

    /// Returns the fixed default discounts: 0.75, 0.85 and 0.95 for counts
    /// 1, 2 and 3+.
    pub fn default_discounts() -> Self {
        Self {
            d1: 0.75,
            d2: 0.85,
            d3_plus: 0.95,
        }
    }

    /// Discount subtracted from an n-gram observed `count` times.
    ///
    /// Unseen n-grams (`count == 0`) are not discounted.
    #[inline]
    fn discount(&self, count: u64) -> f64 {
        match count {
            0 => 0.0,
            1 => self.d1,
            2 => self.d2,
            _ => self.d3_plus,
        }
    }

    /// Word-independent discount used to size the back-off weight `γ(h)`.
    ///
    /// Exact modified Kneser-Ney needs, per context, the number of
    /// continuation types seen once, twice and 3+ times. This code only reads
    /// the total number of distinct continuations (`unique_continuations`), so
    /// one discount stands in for all of them.
    ///
    /// NOTE(review): this is an approximation. Replace it with
    /// `(d1·N1(h•) + d2·N2(h•) + d3_plus·N3+(h•))` if the trie entry can
    /// provide those per-context counts.
    #[inline]
    fn backoff_discount(&self) -> f64 {
        self.d1
    }

    /// Natural-log probability of `word` following `context`.
    ///
    /// `context` is ordered oldest word first. Back-off drops words from the
    /// front. `vocab_size` and `total_count` are only used at the unigram
    /// level.
    pub fn log_prob<D>(
        &self,
        word: &str,
        context: &[&str],
        trie: &NgramTrie<D>,
        vocab_size: usize,
        total_count: u64,
    ) -> f64
    where
        D: MutableMappedDictionary<Value = NgramEntry>,
    {
        self.prob_recursive(word, context, trie, vocab_size, total_count, true)
            .ln()
    }

    /// Interpolated Kneser-Ney probability of `word` after `context` (`h`).
    ///
    /// Computes
    /// `max(c(h,w) - D(c(h,w)), 0) / c(h) + γ(h) · P(w | h')`,
    /// where `h'` is `h` without its oldest word and
    /// `γ(h) = D̂ · unique_continuations(h) / c(h)`, with `D̂` independent of `w`.
    ///
    /// Contexts with a zero count, or with no recorded continuations, back off
    /// entirely to `h'`.
    ///
    /// NOTE(review): lower orders above the unigram still use raw `trie.count`
    /// values. Textbook KN uses continuation counts at every lower order.
    fn prob_recursive<D>(
        &self,
        word: &str,
        context: &[&str],
        trie: &NgramTrie<D>,
        vocab_size: usize,
        total_count: u64,
        is_highest_order: bool,
    ) -> f64
    where
        D: MutableMappedDictionary<Value = NgramEntry>,
    {
        if context.is_empty() {
            return self.unigram_prob(word, trie, vocab_size, total_count, is_highest_order);
        }
        let shorter = &context[1..];

        let context_count = trie.count(context);
        if context_count == 0 {
            return self.prob_recursive(word, shorter, trie, vocab_size, total_count, false);
        }

        let unique_continuations = trie
            .get(context)
            .map(|e| e.unique_continuations() as f64)
            .unwrap_or(1.0);
        if unique_continuations <= 0.0 {
            // No observed continuation means no discounted mass and γ(h) = 0.
            // Defer entirely to the shorter context rather than return 0.
            return self.prob_recursive(word, shorter, trie, vocab_size, total_count, false);
        }

        let mut ngram: Vec<&str> = Vec::with_capacity(context.len() + 1);
        ngram.extend_from_slice(context);
        ngram.push(word);
        let ngram_count = trie.count(&ngram);

        let context_total = context_count as f64;
        let discounted_count = (ngram_count as f64 - self.discount(ngram_count)).max(0.0);
        let discounted_prob = discounted_count / context_total;

        // γ(h) must not depend on `word`. Deriving it from the word's own
        // discount made γ = 0 for every unseen n-gram, which gave P = 0 and
        // a log-probability of -inf.
        let lambda = self.backoff_discount() * unique_continuations / context_total;
        let backoff_prob =
            self.prob_recursive(word, shorter, trie, vocab_size, total_count, false);
        discounted_prob + lambda * backoff_prob
    }

    /// Unigram base case of the recursion.
    ///
    /// At the highest order this is the maximum-likelihood estimate
    /// `count / total_count`. At lower orders it is the word's continuation
    /// count divided by `vocab_size`. Words with a zero count fall back to the
    /// uniform `1 / vocab_size`.
    ///
    /// `vocab_size` and `total_count` are floored at 1 so the result stays
    /// finite on degenerate statistics.
    fn unigram_prob<D>(
        &self,
        word: &str,
        trie: &NgramTrie<D>,
        vocab_size: usize,
        total_count: u64,
        is_highest_order: bool,
    ) -> f64
    where
        D: MutableMappedDictionary<Value = NgramEntry>,
    {
        let vocab = vocab_size.max(1) as f64;
        let uniform = 1.0 / vocab;
        let entry = trie.get(&[word]);
        if is_highest_order {
            let count = entry.map(|e| e.count()).unwrap_or(0);
            if count == 0 {
                return uniform;
            }
            count as f64 / total_count.max(1) as f64
        } else {
            let continuation_count = entry.map(|e| e.continuation_count()).unwrap_or(0);
            if continuation_count == 0 {
                return uniform;
            }
            // NOTE(review): textbook KN normalizes continuation counts by the
            // number of distinct bigram types, not `vocab_size`. As written
            // this is not a normalized distribution. Confirm intent.
            continuation_count as f64 / vocab
        }
    }
}
impl Default for KneserNeySmoothing {
fn default() -> Self {
Self::default_discounts()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_discount_computation() {
        let smoothing = KneserNeySmoothing::from_counts(1000, 500, 300, 200);
        assert!(smoothing.d1 > 0.0 && smoothing.d1 < 1.0);
        assert!(smoothing.d2 > 0.0 && smoothing.d2 < 2.0);
        assert!(smoothing.d3_plus > 0.0 && smoothing.d3_plus < 3.0);
    }

    #[test]
    fn test_discount_by_count() {
        let smoothing = KneserNeySmoothing::default_discounts();
        assert_eq!(smoothing.discount(0), 0.0);
        assert_eq!(smoothing.discount(1), smoothing.d1);
        assert_eq!(smoothing.discount(2), smoothing.d2);
        assert_eq!(smoothing.discount(3), smoothing.d3_plus);
        assert_eq!(smoothing.discount(100), smoothing.d3_plus);
    }

    #[test]
    fn test_default_discounts() {
        let smoothing = KneserNeySmoothing::default_discounts();
        assert_eq!(smoothing.d1, 0.75);
        assert_eq!(smoothing.d2, 0.85);
        assert_eq!(smoothing.d3_plus, 0.95);
    }

    /// All-zero count-of-counts must not divide by zero. Each statistic is
    /// floored at 1, so every discount stays finite and within its bucket.
    #[test]
    fn test_from_counts_handles_zero_counts() {
        let smoothing = KneserNeySmoothing::from_counts(0, 0, 0, 0);
        for (d, upper) in [
            (smoothing.d1, 1.0),
            (smoothing.d2, 2.0),
            (smoothing.d3_plus, 3.0),
        ] {
            assert!(d.is_finite());
            assert!((0.0..=upper).contains(&d));
        }
    }

    /// A large `n3` relative to `n2` drives the raw D2 estimate far below
    /// zero. It must be clamped to exactly 0.
    #[test]
    fn test_from_counts_clamps_negative_discounts() {
        let smoothing = KneserNeySmoothing::from_counts(1000, 1, 1000, 1);
        assert_eq!(smoothing.d2, 0.0);
        assert!(smoothing.d1 > 0.0 && smoothing.d1 <= 1.0);
        assert!(smoothing.d3_plus > 0.0 && smoothing.d3_plus <= 3.0);
    }
}