use super::coordinator::ShardCoordinator;
use crate::ngram::vocabulary::{
decode_ngram_key_bytes, encode_indices_to_key_bytes, ngram_order_bytes,
};
use rayon::prelude::*;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use thiserror::Error;
use xxhash_rust::xxh3::Xxh3DefaultBuilder;
/// `HashMap` that hashes keys with `Xxh3DefaultBuilder` instead of the standard library default.
type XxHashMap<K, V> = HashMap<K, V, Xxh3DefaultBuilder>;
/// `HashSet` that hashes elements with `Xxh3DefaultBuilder` instead of the standard library default.
type XxHashSet<T> = HashSet<T, Xxh3DefaultBuilder>;
/// Errors produced while aggregating Modified Kneser-Ney (MKN) statistics.
#[derive(Error, Debug)]
pub enum MknError {
/// Error converted from the shard layer's `ShardError`.
#[error("Shard error: {0}")]
Shard(#[from] super::shard::ShardError),
/// Error converted from the coordinator's `CoordinatorError`.
#[error("Coordinator error: {0}")]
Coordinator(#[from] super::coordinator::CoordinatorError),
/// Not enough data was available; the payload is a human-readable description.
#[error("Insufficient data: {0}")]
InsufficientData(String),
/// A computation could not complete; in this file it carries the message
/// `"Cancelled"` when a cancellation flag was observed.
#[error("Computation error: {0}")]
Computation(String),
}
/// Convenience alias for results whose error type is [`MknError`].
pub type MknResult<T> = Result<T, MknError>;
/// Count-of-counts statistics used to derive MKN discount parameters.
///
/// The aggregator in this file keeps one instance per n-gram order.
#[derive(Clone, Debug, Default)]
pub struct FrequencyCounts {
/// Number of observed entries whose count is exactly 1.
pub n1: u64,
/// Number of observed entries whose count is exactly 2.
pub n2: u64,
/// Number of observed entries whose count is exactly 3.
pub n3: u64,
/// Number of observed entries whose count is exactly 4.
pub n4: u64,
/// Number of entries observed, regardless of their count.
pub total_unique: u64,
/// Sum of the counts of all observed entries.
pub total_count: u64,
}
/// Returns `lhs + rhs`.
///
/// # Panics
/// Panics with a message mentioning `field` when the sum does not fit in `u64`.
#[inline]
fn checked_count_add(lhs: u64, rhs: u64, field: &str) -> u64 {
    match lhs.checked_add(rhs) {
        Some(sum) => sum,
        None => panic!("FrequencyCounts overflow in {field}"),
    }
}
/// Atomically adds `delta` to `counter` using relaxed ordering.
///
/// # Panics
/// Panics with a message mentioning `field` when the addition would overflow `u64`;
/// the counter is left unchanged in that case.
#[inline]
fn atomic_count_add(counter: &AtomicU64, delta: u64, field: &str) {
    // Adding zero cannot change the counter, so the read-modify-write is skipped.
    if delta != 0 {
        let outcome = counter.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |seen| {
            seen.checked_add(delta)
        });
        if outcome.is_err() {
            panic!("AtomicFrequencyCounts overflow in {field}");
        }
    }
}
/// Restricts `value` to the interval `[min, max]`.
///
/// # Panics
/// Panics when `value` is NaN or infinite. Debug builds additionally assert that
/// both bounds are finite and that `min <= max`.
#[inline]
fn clamp_discount(value: f64, min: f64, max: f64) -> f64 {
    assert!(value.is_finite(), "non-finite MKN discount intermediate");
    debug_assert!(min.is_finite() && max.is_finite() && min <= max);
    match value {
        below if below < min => min,
        above if above > max => max,
        inside => inside,
    }
}
impl FrequencyCounts {
    /// Adds every counter of `other` onto the matching counter of `self`.
    ///
    /// # Panics
    /// Panics if any individual sum overflows `u64`. Fields are processed in
    /// declaration order, so fields preceding the overflowing one are already updated.
    pub fn merge(&mut self, other: &FrequencyCounts) {
        let updates = [
            (&mut self.n1, other.n1, "n1"),
            (&mut self.n2, other.n2, "n2"),
            (&mut self.n3, other.n3, "n3"),
            (&mut self.n4, other.n4, "n4"),
            (&mut self.total_unique, other.total_unique, "total_unique"),
            (&mut self.total_count, other.total_count, "total_count"),
        ];
        for (slot, addend, field) in updates {
            *slot = checked_count_add(*slot, addend, field);
        }
    }
}
/// Atomic counterpart of [`FrequencyCounts`]: the same six counters stored as
/// `AtomicU64` so they can be updated through a shared reference.
#[derive(Debug, Default)]
pub struct AtomicFrequencyCounts {
/// Number of observed entries whose count is exactly 1.
pub n1: AtomicU64,
/// Number of observed entries whose count is exactly 2.
pub n2: AtomicU64,
/// Number of observed entries whose count is exactly 3.
pub n3: AtomicU64,
/// Number of observed entries whose count is exactly 4.
pub n4: AtomicU64,
/// Number of entries observed, regardless of their count.
pub total_unique: AtomicU64,
/// Sum of the counts of all observed entries.
pub total_count: AtomicU64,
}
impl AtomicFrequencyCounts {
pub fn observe(&self, count: u64) {
match count {
1 => atomic_count_add(&self.n1, 1, "n1"),
2 => atomic_count_add(&self.n2, 1, "n2"),
3 => atomic_count_add(&self.n3, 1, "n3"),
4 => atomic_count_add(&self.n4, 1, "n4"),
_ => {}
};
atomic_count_add(&self.total_unique, 1, "total_unique");
atomic_count_add(&self.total_count, count, "total_count");
}
pub fn into_counts(self) -> FrequencyCounts {
FrequencyCounts {
n1: self.n1.into_inner(),
n2: self.n2.into_inner(),
n3: self.n3.into_inner(),
n4: self.n4.into_inner(),
total_unique: self.total_unique.into_inner(),
total_count: self.total_count.into_inner(),
}
}
pub fn load(&self) -> FrequencyCounts {
FrequencyCounts {
n1: self.n1.load(Ordering::Relaxed),
n2: self.n2.load(Ordering::Relaxed),
n3: self.n3.load(Ordering::Relaxed),
n4: self.n4.load(Ordering::Relaxed),
total_unique: self.total_unique.load(Ordering::Relaxed),
total_count: self.total_count.load(Ordering::Relaxed),
}
}
}
/// Discount parameters for Modified Kneser-Ney smoothing.
#[derive(Clone, Debug)]
pub struct DiscountParams {
/// Discount applied to entries with a count of exactly 1.
pub d1: f64,
/// Discount applied to entries with a count of exactly 2.
pub d2: f64,
/// Discount applied to entries with a count of 3 or more.
pub d3_plus: f64,
/// The intermediate ratio `n1 / (n1 + 2 * n2)` from which the discounts are derived.
pub y: f64,
}
impl Default for DiscountParams {
fn default() -> Self {
Self {
d1: 0.5,
d2: 0.75,
d3_plus: 0.9,
y: 0.5,
}
}
}
impl DiscountParams {
pub fn from_counts(counts: &FrequencyCounts) -> Self {
if counts.n1 == 0 || counts.n2 == 0 {
log::warn!(
"Insufficient data for MKN discounts (n1={}, n2={}), using defaults",
counts.n1,
counts.n2
);
return Self::default();
}
let n1 = counts.n1 as f64;
let n2 = counts.n2 as f64;
let n3 = counts.n3.max(1) as f64; let n4 = counts.n4.max(1) as f64;
let y = n1 / (n1 + 2.0 * n2);
let d1 = clamp_discount(1.0 - 2.0 * y * (n2 / n1), 0.0, 1.0);
let d2 = clamp_discount(2.0 - 3.0 * y * (n3 / n2), 0.0, 2.0);
let d3_plus = clamp_discount(3.0 - 4.0 * y * (n4 / n3), 0.0, 3.0);
Self { d1, d2, d3_plus, y }
}
pub fn discount_for(&self, count: u64) -> f64 {
match count {
0 => 0.0,
1 => self.d1,
2 => self.d2,
_ => self.d3_plus,
}
}
}
/// Continuation statistics for one n-gram order.
#[derive(Clone, Debug, Default)]
pub struct ContinuationCounts {
/// For each encoded context (an n-gram without its first token), the number
/// of distinct first tokens seen in front of it.
pub predecessor_counts: XxHashMap<Vec<u8>, u64>,
/// For each encoded context (an n-gram without its last token), the number
/// of distinct last tokens seen after it.
pub successor_counts: XxHashMap<Vec<u8>, u64>,
/// As computed by the aggregator in this file: the number of keys in
/// `predecessor_counts` plus the number of keys in `successor_counts`.
pub total_contexts: u64,
}
impl ContinuationCounts {
pub fn merge(&mut self, other: ContinuationCounts) {
for (k, v) in other.predecessor_counts {
*self.predecessor_counts.entry(k).or_default() += v;
}
for (k, v) in other.successor_counts {
*self.successor_counts.entry(k).or_default() += v;
}
self.total_contexts += other.total_contexts;
}
}
/// Aggregated MKN statistics. Each vector is indexed by n-gram order; the
/// constructors in this file size them to `max_order + 1` entries.
#[derive(Clone, Debug)]
pub struct MknStats {
/// Count-of-counts statistics, one entry per order.
pub frequency_counts: Vec<FrequencyCounts>,
/// Discount parameters, one entry per order.
pub discounts: Vec<DiscountParams>,
/// Continuation counts, one entry per order.
pub continuation_counts: Vec<ContinuationCounts>,
/// Highest n-gram order covered by these statistics.
pub max_order: u8,
}
impl MknStats {
    /// Creates statistics for orders `0..=max_order`, with every slot set to
    /// its type's default value.
    pub fn new(max_order: u8) -> Self {
        // Widen before adding one: `max_order + 1` evaluated in `u8` overflows
        // for `max_order == 255` (panic in debug builds, zero-sized vectors in release).
        let size = usize::from(max_order) + 1;
        Self {
            frequency_counts: vec![FrequencyCounts::default(); size],
            discounts: vec![DiscountParams::default(); size],
            continuation_counts: vec![ContinuationCounts::default(); size],
            max_order,
        }
    }

    /// Returns the discount for an entry of the given `order` seen `count`
    /// times, or `0.0` when no discount parameters exist for `order`.
    pub fn get_discount(&self, order: u8, count: u64) -> f64 {
        self.discounts
            .get(usize::from(order))
            .map_or(0.0, |params| params.discount_for(count))
    }
}
/// Computes MKN statistics by scanning the shards exposed by a `ShardCoordinator`.
pub struct MknAggregator<'a> {
/// Source of the shards that are scanned.
coordinator: &'a ShardCoordinator,
/// Whether `compute_all` also computes continuation counts.
compute_continuations: bool,
/// Optional flag polled between shards; when it reads `true` the running
/// computation stops with a "Cancelled" error.
cancellation_flag: Option<&'a AtomicBool>,
}
impl<'a> MknAggregator<'a> {
    /// Highest n-gram order for which statistics are collected.
    const MAX_ORDER: u8 = 5;

    /// Creates an aggregator over `coordinator` with continuation counts
    /// disabled and no cancellation flag.
    pub fn new(coordinator: &'a ShardCoordinator) -> Self {
        Self {
            coordinator,
            compute_continuations: false,
            cancellation_flag: None,
        }
    }

    /// Enables continuation-count computation in [`Self::compute_all`].
    pub fn with_continuations(mut self) -> Self {
        self.compute_continuations = true;
        self
    }

    /// Registers a flag that is polled before each shard is processed.
    pub fn with_cancellation_flag(mut self, flag: &'a AtomicBool) -> Self {
        self.cancellation_flag = Some(flag);
        self
    }

    /// Scans all discovered shards in parallel and accumulates count-of-counts
    /// statistics per n-gram order.
    ///
    /// The result has one entry per order in `0..=5`. Keys starting with byte
    /// `0x00` are skipped (presumably reserved/metadata entries — confirm
    /// against the shard key format), as are n-grams whose order exceeds the
    /// maximum. Shards that cannot be opened or iterated are logged and skipped.
    ///
    /// # Errors
    /// Returns `MknError::Coordinator` when shard discovery fails and
    /// `MknError::Computation("Cancelled")` when the cancellation flag was seen.
    ///
    /// # Panics
    /// Panics if any counter overflows `u64`.
    pub fn compute_frequency_counts(&self) -> MknResult<Vec<FrequencyCounts>> {
        let max_order = Self::MAX_ORDER;
        let counts: Vec<AtomicFrequencyCounts> = (0..=max_order)
            .map(|_| AtomicFrequencyCounts::default())
            .collect();
        // Latched once any worker observes the external flag, so the remaining
        // shards bail out early and the caller can report the cancellation.
        let cancelled = AtomicBool::new(false);

        let shard_files = self.coordinator.discover_shard_files()?;
        let shard_keys: Vec<_> = shard_files.into_iter().map(|(key, _)| key).collect();

        shard_keys.par_iter().for_each(|key| {
            if cancelled.load(Ordering::Relaxed) {
                return;
            }
            if let Some(flag) = self.cancellation_flag {
                if flag.load(Ordering::Relaxed) {
                    cancelled.store(true, Ordering::Relaxed);
                    return;
                }
            }
            match self.coordinator.get_or_create_shard(key) {
                Ok(shard) => {
                    let guard = shard.read();
                    match guard.iter_with_counts() {
                        Ok(iter) => {
                            for (ngram, count) in iter {
                                if ngram.starts_with(&[0x00]) {
                                    continue;
                                }
                                let order = ngram_order_bytes(&ngram);
                                if order <= max_order {
                                    counts[order as usize].observe(count);
                                }
                            }
                        }
                        Err(e) => {
                            log::warn!("Failed to iterate shard {}: {}", key, e);
                        }
                    }
                }
                Err(e) => {
                    // Best effort: one unreadable shard must not abort the whole
                    // scan, but it should not disappear silently either.
                    log::warn!("Failed to open shard {}: {}", key, e);
                }
            }
        });

        if cancelled.load(Ordering::Relaxed) {
            return Err(MknError::Computation("Cancelled".to_string()));
        }
        Ok(counts
            .into_iter()
            .map(AtomicFrequencyCounts::into_counts)
            .collect())
    }

    /// Derives discount parameters for each entry of `freq_counts`, in order.
    pub fn compute_discounts(&self, freq_counts: &[FrequencyCounts]) -> Vec<DiscountParams> {
        freq_counts
            .iter()
            .map(DiscountParams::from_counts)
            .collect()
    }

    /// Computes continuation counts for orders `2..=5`.
    ///
    /// The result has one entry per order in `0..=5`; entries for orders 0 and
    /// 1 stay at their default (empty) value. Every order triggers its own
    /// full scan over all shards.
    ///
    /// # Errors
    /// Propagates the errors of the per-order computation (discovery failure
    /// or cancellation).
    pub fn compute_continuation_counts(&self) -> MknResult<Vec<ContinuationCounts>> {
        let mut all_counts: Vec<ContinuationCounts> = (0..=Self::MAX_ORDER)
            .map(|_| ContinuationCounts::default())
            .collect();
        for order in 2..=Self::MAX_ORDER {
            all_counts[usize::from(order)] = self.compute_continuation_counts_for_order(order)?;
        }
        Ok(all_counts)
    }

    /// Scans all shards sequentially and, for n-grams with exactly `order`
    /// tokens, counts the distinct first tokens per trailing context and the
    /// distinct last tokens per leading context.
    ///
    /// Orders below 2 yield empty counts. Shards that cannot be opened or
    /// iterated are logged and skipped.
    fn compute_continuation_counts_for_order(&self, order: u8) -> MknResult<ContinuationCounts> {
        // Compare lengths as `usize`: narrowing `len()` to `u8` would make a
        // 258-token key look like an order-2 n-gram.
        let target_len = usize::from(order);

        let mut predecessor_sets: XxHashMap<Vec<u8>, XxHashSet<u64>> =
            HashMap::with_hasher(Xxh3DefaultBuilder);
        let mut successor_sets: XxHashMap<Vec<u8>, XxHashSet<u64>> =
            HashMap::with_hasher(Xxh3DefaultBuilder);

        let shard_files = self.coordinator.discover_shard_files()?;
        let shard_keys: Vec<_> = shard_files.into_iter().map(|(key, _)| key).collect();

        for key in &shard_keys {
            if let Some(flag) = self.cancellation_flag {
                if flag.load(Ordering::Relaxed) {
                    return Err(MknError::Computation("Cancelled".to_string()));
                }
            }
            let shard = match self.coordinator.get_or_create_shard(key) {
                Ok(shard) => shard,
                Err(e) => {
                    log::warn!("Failed to open shard {}: {}", key, e);
                    continue;
                }
            };
            let guard = shard.read();
            let iter = match guard.iter_with_counts() {
                Ok(iter) => iter,
                Err(e) => {
                    log::warn!("Failed to iterate shard {}: {}", key, e);
                    continue;
                }
            };
            for (ngram, _count) in iter {
                if ngram.starts_with(&[0x00]) {
                    continue;
                }
                let indices = decode_ngram_key_bytes(&ngram);
                if indices.len() != target_len || indices.len() < 2 {
                    continue;
                }

                // Distinct first tokens, keyed by the n-gram without its first token.
                let predecessor = indices[0];
                let pred_context = encode_indices_to_key_bytes(&indices[1..]);
                predecessor_sets
                    .entry(pred_context)
                    .or_insert_with(|| HashSet::with_hasher(Xxh3DefaultBuilder))
                    .insert(predecessor);

                // Distinct last tokens, keyed by the n-gram without its last token.
                let successor = indices[indices.len() - 1];
                let succ_context = encode_indices_to_key_bytes(&indices[..indices.len() - 1]);
                successor_sets
                    .entry(succ_context)
                    .or_insert_with(|| HashSet::with_hasher(Xxh3DefaultBuilder))
                    .insert(successor);
            }
        }

        let predecessor_counts: XxHashMap<Vec<u8>, u64> = predecessor_sets
            .into_iter()
            .map(|(context, tokens)| (context, tokens.len() as u64))
            .collect();
        let successor_counts: XxHashMap<Vec<u8>, u64> = successor_sets
            .into_iter()
            .map(|(context, tokens)| (context, tokens.len() as u64))
            .collect();
        let total_contexts = predecessor_counts.len() as u64 + successor_counts.len() as u64;

        Ok(ContinuationCounts {
            predecessor_counts,
            successor_counts,
            total_contexts,
        })
    }

    /// Computes frequency counts and discounts, plus continuation counts when
    /// enabled via [`Self::with_continuations`] (otherwise empty placeholders,
    /// one per order).
    ///
    /// # Errors
    /// Propagates discovery failures and cancellation from the underlying scans.
    pub fn compute_all(&self) -> MknResult<MknStats> {
        log::info!("Computing MKN frequency counts...");
        let frequency_counts = self.compute_frequency_counts()?;
        log::info!("Computing discount parameters...");
        let discounts = self.compute_discounts(&frequency_counts);
        let continuation_counts = if self.compute_continuations {
            log::info!("Computing continuation counts...");
            self.compute_continuation_counts()?
        } else {
            vec![ContinuationCounts::default(); usize::from(Self::MAX_ORDER) + 1]
        };
        Ok(MknStats {
            frequency_counts,
            discounts,
            continuation_counts,
            max_order: Self::MAX_ORDER,
        })
    }

    /// Computes frequency counts and returns only the discounts derived from them.
    ///
    /// # Errors
    /// Propagates the errors of [`Self::compute_frequency_counts`].
    pub fn compute_discounts_only(&self) -> MknResult<Vec<DiscountParams>> {
        let frequency_counts = self.compute_frequency_counts()?;
        Ok(self.compute_discounts(&frequency_counts))
    }
}
/// Per-order digest of an [`MknStats`] value, as produced by `MknStats::summary`.
#[derive(Clone, Debug)]
pub struct MknSummary {
/// One entry per n-gram order, starting at order 1.
pub orders: Vec<OrderSummary>,
}
/// Summary figures for a single n-gram order.
#[derive(Clone, Debug)]
pub struct OrderSummary {
/// The n-gram order this row describes.
pub order: u8,
/// Copied from `FrequencyCounts::total_unique` for this order.
pub unique_ngrams: u64,
/// Copied from `FrequencyCounts::total_count` for this order.
pub total_count: u64,
/// Discount parameters for this order.
pub discounts: DiscountParams,
}
impl MknStats {
    /// Builds a per-order summary for orders `1..=max_order`.
    ///
    /// Because the fields of `MknStats` are public, the vectors may be shorter
    /// than `max_order + 1`. Orders without both a frequency-count entry and a
    /// discount entry are left out instead of panicking, mirroring the
    /// tolerant lookup in `get_discount`.
    pub fn summary(&self) -> MknSummary {
        let orders = (1..=self.max_order)
            .filter_map(|order| {
                let idx = usize::from(order);
                let fc = self.frequency_counts.get(idx)?;
                let discounts = self.discounts.get(idx)?;
                Some(OrderSummary {
                    order,
                    unique_ngrams: fc.total_unique,
                    total_count: fc.total_count,
                    discounts: discounts.clone(),
                })
            })
            .collect();
        MknSummary { orders }
    }

    /// Renders the summary as a fixed-width text table: a title, an underline,
    /// a column header, a 60-character rule, then one row per summarized order.
    /// Lines are joined with `\n` and there is no trailing newline.
    pub fn format_table(&self) -> String {
        let mut lines = vec![
            "MKN Statistics Summary".to_string(),
            "======================".to_string(),
            format!(
                "{:>5} {:>12} {:>15} {:>8} {:>8} {:>8}",
                "Order", "Unique", "Total Count", "D1", "D2", "D3+"
            ),
            "-".repeat(60),
        ];
        // Reuse `summary` so both views agree on which orders are present.
        lines.extend(self.summary().orders.iter().map(|row| {
            format!(
                "{:>5} {:>12} {:>15} {:>8.4} {:>8.4} {:>8.4}",
                row.order,
                row.unique_ngrams,
                row.total_count,
                row.discounts.d1,
                row.discounts.d2,
                row.discounts.d3_plus
            )
        }));
        lines.join("\n")
    }
}
// Unit and property tests for the MKN statistics types and the shard-backed aggregator.
#[cfg(test)]
mod tests {
use super::super::config::{ShardConfig, ShardGranularity};
use super::*;
use crate::ngram::vocabulary::{
create_vocabulary, encode_indices_to_key_bytes, SharedVocabARTrie,
};
use proptest::prelude::*;
use tempfile::TempDir;
// Builds a coordinator in a temporary directory and stores a small fixture:
// six unigrams, six bigrams and three trigrams. The `TempDir` is returned so
// the caller keeps the directory alive for the duration of the test.
fn create_test_coordinator() -> (TempDir, ShardCoordinator, SharedVocabARTrie) {
let dir = TempDir::new().expect("Failed to create temp dir");
let config =
ShardConfig::new(dir.path().join("shards")).with_granularity(ShardGranularity::TwoChar);
let coordinator = ShardCoordinator::new(config).expect("Failed to create coordinator");
let vocab_path = dir.path().join("vocab.artrie");
let vocab = create_vocabulary(&vocab_path).expect("Failed to create vocab");
// Inserts each word into the vocabulary and concatenates the varint
// encodings of the returned indices into the n-gram key.
let encode = |words: &[&str]| -> String {
let mut buf = Vec::with_capacity(words.len() * 2);
let guard = vocab.write();
for word in words {
let idx = guard.insert(word).expect("test vocab insert");
crate::ngram::vocabulary::encode_varint(idx, &mut buf);
}
String::from_utf8(buf).expect("varint bytes should be valid UTF-8 for small indices")
};
// Unigrams with counts 100, 50, 1, 2, 3 and 4.
coordinator
.store_ngram(&encode(&["the"]), 100)
.expect("store");
coordinator.store_ngram(&encode(&["a"]), 50).expect("store");
coordinator.store_ngram(&encode(&["an"]), 1).expect("store"); coordinator.store_ngram(&encode(&["is"]), 2).expect("store"); coordinator.store_ngram(&encode(&["at"]), 3).expect("store"); coordinator.store_ngram(&encode(&["in"]), 4).expect("store");
// Bigrams with counts 10, 5, 1, 2, 3 and 4.
coordinator
.store_ngram(&encode(&["the", "quick"]), 10)
.expect("store");
coordinator
.store_ngram(&encode(&["the", "slow"]), 5)
.expect("store");
coordinator
.store_ngram(&encode(&["a", "big"]), 1)
.expect("store"); coordinator
.store_ngram(&encode(&["a", "small"]), 2)
.expect("store"); coordinator
.store_ngram(&encode(&["is", "very"]), 3)
.expect("store"); coordinator
.store_ngram(&encode(&["in", "the"]), 4)
.expect("store");
// Trigrams with counts 5, 1 and 2.
coordinator
.store_ngram(&encode(&["the", "quick", "brown"]), 5)
.expect("store");
coordinator
.store_ngram(&encode(&["the", "quick", "red"]), 1)
.expect("store"); coordinator
.store_ngram(&encode(&["the", "slow", "green"]), 2)
.expect("store");
(dir, coordinator, vocab)
}
// Asserts that every discount field is finite and inside its clamp range.
fn assert_valid_discounts(discounts: &DiscountParams) {
assert!(discounts.y.is_finite());
assert!(discounts.d1.is_finite());
assert!(discounts.d2.is_finite());
assert!(discounts.d3_plus.is_finite());
assert!((0.0..=1.0).contains(&discounts.y));
assert!((0.0..=1.0).contains(&discounts.d1));
assert!((0.0..=2.0).contains(&discounts.d2));
assert!((0.0..=3.0).contains(&discounts.d3_plus));
}
// The fixture's unique-entry totals per order must match what was stored.
#[test]
fn test_frequency_counts() {
let (_dir, coordinator, _vocab) = create_test_coordinator();
let aggregator = MknAggregator::new(&coordinator);
let counts = aggregator.compute_frequency_counts().expect("compute");
assert_eq!(counts[1].total_unique, 6);
assert!(counts[1].n1 >= 1); assert!(counts[1].n2 >= 1);
assert_eq!(counts[2].total_unique, 6);
assert_eq!(counts[3].total_unique, 3);
}
// Checks the discount formulas against hand-computed values.
#[test]
fn test_discount_computation() {
let counts = FrequencyCounts {
n1: 100,
n2: 50,
n3: 30,
n4: 20,
total_unique: 200,
total_count: 1000,
};
let discounts = DiscountParams::from_counts(&counts);
assert_valid_discounts(&discounts);
assert!((discounts.y - 0.5).abs() < 0.01);
assert!((discounts.d1 - 0.5).abs() < 0.01);
assert!((discounts.d2 - 1.1).abs() < 0.01);
assert!((discounts.d3_plus - 1.667).abs() < 0.01);
}
// With n1 == 0 and n2 == 0 the default discounts are returned.
#[test]
fn test_discount_default_on_insufficient_data() {
let counts = FrequencyCounts {
n1: 0,
n2: 0,
n3: 0,
n4: 0,
total_unique: 0,
total_count: 0,
};
let discounts = DiscountParams::from_counts(&counts);
assert_eq!(discounts.d1, 0.5);
assert_eq!(discounts.d2, 0.75);
assert_eq!(discounts.d3_plus, 0.9);
}
// Extreme inputs (zeros, u64::MAX, values around 2^53 where f64 stops being
// integer-exact) must still yield finite, in-range discounts.
#[test]
fn test_discount_computation_extreme_counts_are_finite_and_bounded() {
let near_exact_limit = 1u64 << 53;
let cases = [
(1, 1, 0, 0),
(1, u64::MAX, u64::MAX, u64::MAX),
(u64::MAX, 1, 0, u64::MAX),
(u64::MAX, u64::MAX, u64::MAX, u64::MAX),
(
near_exact_limit - 1,
near_exact_limit,
near_exact_limit + 1,
0,
),
(
near_exact_limit + 1,
near_exact_limit - 1,
0,
near_exact_limit,
),
];
for (n1, n2, n3, n4) in cases {
let counts = FrequencyCounts {
n1,
n2,
n3,
n4,
total_unique: 0,
total_count: 0,
};
assert_valid_discounts(&DiscountParams::from_counts(&counts));
}
}
// Property test: for any positive n1/n2 and arbitrary n3/n4 the discounts are
// finite and inside their clamp ranges.
proptest! {
#[test]
fn test_discount_computation_positive_counts_are_finite_and_bounded(
n1 in 1u64..=u64::MAX,
n2 in 1u64..=u64::MAX,
n3 in any::<u64>(),
n4 in any::<u64>(),
) {
let counts = FrequencyCounts {
n1,
n2,
n3,
n4,
total_unique: 0,
total_count: 0,
};
let discounts = DiscountParams::from_counts(&counts);
prop_assert!(discounts.y.is_finite());
prop_assert!(discounts.d1.is_finite());
prop_assert!(discounts.d2.is_finite());
prop_assert!(discounts.d3_plus.is_finite());
prop_assert!((0.0..=1.0).contains(&discounts.y));
prop_assert!((0.0..=1.0).contains(&discounts.d1));
prop_assert!((0.0..=2.0).contains(&discounts.d2));
prop_assert!((0.0..=3.0).contains(&discounts.d3_plus));
}
}
// For bigrams: "quick" must appear as a predecessor context, and "the" must
// have exactly two distinct successors ("quick" and "slow").
#[test]
fn test_continuation_counts() {
let (_dir, coordinator, vocab) = create_test_coordinator();
let aggregator = MknAggregator::new(&coordinator).with_continuations();
let counts = aggregator
.compute_continuation_counts_for_order(2)
.expect("compute");
let quick_idx = vocab
.read()
.get_index("quick")
.expect("quick should be in vocab");
let quick_key = encode_indices_to_key_bytes(&[quick_idx]);
assert!(
counts.predecessor_counts.contains_key(&quick_key),
"Should have predecessor count for context 'quick' (encoded as {:?})",
quick_key
);
let the_idx = vocab
.read()
.get_index("the")
.expect("the should be in vocab");
let the_key = encode_indices_to_key_bytes(&[the_idx]);
assert!(
counts.successor_counts.contains_key(&the_key),
"Should have successor count for context 'the' (encoded as {:?})",
the_key
);
assert_eq!(*counts.successor_counts.get(&the_key).unwrap(), 2);
}
// `compute_all` reports max order 5 and six slots (orders 0..=5) per vector.
#[test]
fn test_compute_all() {
let (_dir, coordinator, _vocab) = create_test_coordinator();
let aggregator = MknAggregator::new(&coordinator);
let stats = aggregator.compute_all().expect("compute");
assert_eq!(stats.max_order, 5);
assert_eq!(stats.frequency_counts.len(), 6); assert_eq!(stats.discounts.len(), 6);
assert!(stats.frequency_counts[1].total_unique > 0);
assert!(stats.frequency_counts[2].total_unique > 0);
assert!(stats.frequency_counts[3].total_unique > 0);
}
// The rendered table contains its title and column headers.
#[test]
fn test_format_table() {
let (_dir, coordinator, _vocab) = create_test_coordinator();
let aggregator = MknAggregator::new(&coordinator);
let stats = aggregator.compute_all().expect("compute");
let table = stats.format_table();
assert!(table.contains("MKN Statistics Summary"));
assert!(table.contains("Order"));
assert!(table.contains("Unique"));
}
// `discount_for` maps counts 0, 1, 2 and 3+ to 0.0, d1, d2 and d3_plus.
#[test]
fn test_discount_for_count() {
let discounts = DiscountParams {
d1: 0.5,
d2: 1.1,
d3_plus: 1.7,
y: 0.5,
};
assert_eq!(discounts.discount_for(0), 0.0);
assert_eq!(discounts.discount_for(1), 0.5);
assert_eq!(discounts.discount_for(2), 1.1);
assert_eq!(discounts.discount_for(3), 1.7);
assert_eq!(discounts.discount_for(100), 1.7);
}
// `observe` fills the n1..n4 buckets and both totals.
#[test]
fn test_atomic_frequency_counts() {
let atomic = AtomicFrequencyCounts::default();
atomic.observe(1);
atomic.observe(1);
atomic.observe(2);
atomic.observe(3);
atomic.observe(4);
atomic.observe(100);
let counts = atomic.into_counts();
assert_eq!(counts.n1, 2);
assert_eq!(counts.n2, 1);
assert_eq!(counts.n3, 1);
assert_eq!(counts.n4, 1);
assert_eq!(counts.total_unique, 6);
assert_eq!(counts.total_count, 1 + 1 + 2 + 3 + 4 + 100);
}
// A pre-set cancellation flag makes `compute_frequency_counts` fail with
// `MknError::Computation("Cancelled")`.
#[test]
fn test_mkn_compute_all_cancellable() {
let (_dir, coordinator, _vocab) = create_test_coordinator();
let flag = std::sync::atomic::AtomicBool::new(true);
let aggregator = MknAggregator::new(&coordinator).with_cancellation_flag(&flag);
let result = aggregator.compute_frequency_counts();
match result {
Err(MknError::Computation(msg)) => {
assert!(
msg.contains("Cancelled"),
"expected 'Cancelled' in error message, got: {}",
msg
);
}
Ok(_) => panic!("compute_frequency_counts should have been cancelled"),
Err(e) => panic!(
"expected MknError::Computation(\"Cancelled\"), got: {:?}",
e
),
}
}
// A pre-set cancellation flag makes `compute_continuation_counts` fail with
// `MknError::Computation("Cancelled")`.
#[test]
fn test_mkn_continuation_cancellable() {
let (_dir, coordinator, _vocab) = create_test_coordinator();
let flag = std::sync::atomic::AtomicBool::new(true);
let aggregator = MknAggregator::new(&coordinator)
.with_continuations()
.with_cancellation_flag(&flag);
let result = aggregator.compute_continuation_counts();
match result {
Err(MknError::Computation(msg)) => {
assert!(
msg.contains("Cancelled"),
"expected 'Cancelled' in error message, got: {}",
msg
);
}
Ok(_) => panic!("compute_continuation_counts should have been cancelled"),
Err(e) => panic!(
"expected MknError::Computation(\"Cancelled\"), got: {:?}",
e
),
}
}
}