use crate::types::{Symbol, SymbolId, SymbolKind};
use parking_lot::RwLock;
use std::collections::HashMap;
/// Fixed number of bytes added to every symbol's size estimate by
/// `SymbolSet::estimate_symbol_size`, on top of the id size and payload length.
const SYMBOL_OVERHEAD_BYTES: usize = 32;
/// Tunables that decide when a block counts as having enough symbols and how
/// many symbols a single block may hold.
#[derive(Debug, Clone, Copy)]
pub struct ThresholdConfig {
    /// Multiplier applied to a block's `k`; the product, rounded up, is the
    /// factor-based symbol threshold.
    pub overhead_factor: f64,
    /// Extra symbols required beyond `k`; the effective threshold is never
    /// lower than `k + min_overhead` (saturating).
    pub min_overhead: usize,
    /// Maximum number of symbols accepted per block; `0` disables the limit.
    pub max_per_block: usize,
}

impl ThresholdConfig {
    /// Builds a configuration from its three parameters, stored verbatim.
    #[inline]
    #[must_use]
    pub const fn new(overhead_factor: f64, min_overhead: usize, max_per_block: usize) -> Self {
        Self { overhead_factor, min_overhead, max_per_block }
    }
}

impl Default for ThresholdConfig {
    /// Default configuration: 2% factor overhead, no minimum extra symbols,
    /// and no per-block limit.
    #[inline]
    fn default() -> Self {
        Self::new(1.02, 0, 0)
    }
}
/// Per-block counters describing how many symbols have been collected and
/// whether the configured threshold has been met.
#[derive(Debug, Clone, Copy)]
pub struct BlockProgress {
    /// Source block number this progress record belongs to.
    pub sbn: u8,
    /// Number of source symbols currently counted for the block.
    pub source_symbols: usize,
    /// Number of repair symbols currently counted for the block.
    pub repair_symbols: usize,
    /// The block's `k`, once it has been set; `None` until then.
    pub k: Option<u16>,
    /// Cached result of the most recent threshold evaluation for the block.
    pub threshold_reached: bool,
}

impl BlockProgress {
    /// Returns the combined number of source and repair symbols.
    ///
    /// The sum saturates at `usize::MAX` instead of overflowing, matching the
    /// saturating arithmetic used by the rest of the bookkeeping in this file
    /// (the fields are public, so arbitrary values can reach this method).
    #[inline]
    #[must_use]
    pub const fn total(&self) -> usize {
        self.source_symbols.saturating_add(self.repair_symbols)
    }
}
/// Outcome of `SymbolSet::insert`.
#[derive(Debug, Clone)]
pub enum InsertResult {
/// The symbol was stored.
Inserted {
/// Copy of the block's counters taken right after the insert.
block_progress: BlockProgress,
/// Same value as `block_progress.threshold_reached`.
threshold_reached: bool,
},
/// A symbol with the same id is already stored; nothing was changed.
Duplicate,
/// The symbol's estimated size exceeds the remaining memory budget; it was not stored.
MemoryLimitReached,
/// The block already holds `max_per_block` symbols; the symbol was not stored.
BlockLimitReached {
/// Block number of the rejected symbol.
sbn: u8,
},
}
/// Store of symbols keyed by id, with per-block progress counters and
/// optional memory-budget accounting.
#[derive(Debug)]
pub struct SymbolSet {
// Stored symbols, keyed by their id.
symbols: HashMap<SymbolId, Symbol>,
// Progress counters per source block number.
block_counts: HashMap<u8, BlockProgress>,
// Number of symbols currently stored.
total_count: usize,
// Sum of the size estimates of the stored symbols.
total_bytes: usize,
// Optional cap on `total_bytes`; `None` means no limit is enforced.
memory_budget: Option<usize>,
// Bytes of the budget not yet allocated; only maintained when a budget is set.
memory_remaining: usize,
// Threshold and per-block limit settings used by every evaluation.
threshold_config: ThresholdConfig,
}
impl SymbolSet {
/// Creates an empty set with the default [`ThresholdConfig`] and no memory budget.
#[inline]
#[must_use]
pub fn new() -> Self {
    let config = ThresholdConfig::default();
    Self::with_config(config)
}
/// Creates an empty set that uses `config` and enforces no memory budget.
#[inline]
#[must_use]
pub fn with_config(config: ThresholdConfig) -> Self {
    Self {
        threshold_config: config,
        // No budget: allocation never fails and `memory_remaining` stays unused.
        memory_budget: None,
        memory_remaining: 0,
        total_bytes: 0,
        total_count: 0,
        block_counts: HashMap::new(),
        symbols: HashMap::new(),
    }
}
#[inline]
#[must_use]
pub fn with_memory_budget(config: ThresholdConfig, budget_bytes: usize) -> Self {
Self {
symbols: HashMap::new(),
block_counts: HashMap::new(),
total_count: 0,
total_bytes: 0,
memory_budget: Some(budget_bytes),
memory_remaining: budget_bytes,
threshold_config: config,
}
}
/// Stores `symbol` and updates the progress of its block.
///
/// Checks run in this order: duplicate id, memory budget, per-block limit.
/// The first failing check decides the returned variant and leaves the set
/// unchanged (bytes reserved for the memory check are released again when
/// the block limit rejects the symbol).
pub fn insert(&mut self, symbol: Symbol) -> InsertResult {
    let id = symbol.id();
    if self.symbols.contains_key(&id) {
        return InsertResult::Duplicate;
    }

    let size = Self::estimate_symbol_size(&symbol);
    if !self.try_allocate(size) {
        return InsertResult::MemoryLimitReached;
    }

    let sbn = id.sbn();
    let config = self.threshold_config;
    let progress = self.block_counts.entry(sbn).or_insert_with(|| BlockProgress {
        sbn,
        source_symbols: 0,
        repair_symbols: 0,
        k: None,
        threshold_reached: false,
    });

    // A limit of zero means "unlimited".
    let block_is_full = config.max_per_block != 0 && progress.total() >= config.max_per_block;
    if block_is_full {
        // Hand back the bytes reserved above; the symbol is not stored.
        self.deallocate(size);
        return InsertResult::BlockLimitReached { sbn };
    }

    let counter = match symbol.kind() {
        SymbolKind::Source => &mut progress.source_symbols,
        SymbolKind::Repair => &mut progress.repair_symbols,
    };
    *counter += 1;
    progress.threshold_reached = Self::calculate_threshold(progress, &config);
    let snapshot = *progress;

    self.symbols.insert(id, symbol);
    self.total_count += 1;
    InsertResult::Inserted {
        block_progress: snapshot,
        threshold_reached: snapshot.threshold_reached,
    }
}
/// Inserts every symbol from `symbols` in order and returns one
/// [`InsertResult`] per symbol, in the same order.
pub fn insert_batch(&mut self, symbols: impl Iterator<Item = Symbol>) -> Vec<InsertResult> {
    let mut results = Vec::with_capacity(symbols.size_hint().0);
    for symbol in symbols {
        results.push(self.insert(symbol));
    }
    results
}
/// Records `k` for block `sbn`, creating the block's progress entry if it
/// does not exist yet, and re-evaluates the threshold.
///
/// Returns the block's new `threshold_reached` value.
pub fn set_block_k(&mut self, sbn: u8, k: u16) -> bool {
    let config = self.threshold_config;
    let entry = self.block_counts.entry(sbn).or_insert_with(|| BlockProgress {
        sbn,
        source_symbols: 0,
        repair_symbols: 0,
        k: None,
        threshold_reached: false,
    });
    entry.k = Some(k);
    let reached = Self::calculate_threshold(entry, &config);
    entry.threshold_reached = reached;
    reached
}
/// Returns `true` when a symbol with this id is stored.
#[inline]
#[must_use]
pub fn contains(&self, id: &SymbolId) -> bool {
    self.symbols.get(id).is_some()
}
/// Returns the stored symbol with this id, or `None` when it is not present.
#[inline]
#[must_use]
pub fn get(&self, id: &SymbolId) -> Option<&Symbol> {
self.symbols.get(id)
}
/// Removes and returns the symbol with this id, or `None` when it is not stored.
///
/// On removal the symbol count, the memory accounting and the block's
/// counters are decremented (saturating at zero) and the block's threshold
/// is re-evaluated. A block left with no symbols and no `k` is dropped from
/// the progress map.
pub fn remove(&mut self, id: &SymbolId) -> Option<Symbol> {
    let removed = self.symbols.remove(id)?;
    self.total_count = self.total_count.saturating_sub(1);
    let freed = Self::estimate_symbol_size(&removed);
    self.deallocate(freed);

    let sbn = id.sbn();
    let config = self.threshold_config;
    let Some(progress) = self.block_counts.get_mut(&sbn) else {
        return Some(removed);
    };

    let counter = match removed.kind() {
        SymbolKind::Source => &mut progress.source_symbols,
        SymbolKind::Repair => &mut progress.repair_symbols,
    };
    *counter = counter.saturating_sub(1);
    progress.threshold_reached = Self::calculate_threshold(progress, &config);

    // Keep the entry while `k` is known so it survives an emptied block.
    let block_is_unused = progress.total() == 0 && progress.k.is_none();
    if block_is_unused {
        self.block_counts.remove(&sbn);
    }
    Some(removed)
}
/// Iterates over the stored symbols whose block number equals `sbn`.
///
/// Every stored symbol is visited; the order follows the underlying map.
pub fn symbols_for_block(&self, sbn: u8) -> impl Iterator<Item = &Symbol> {
    self.symbols
        .iter()
        .filter_map(move |(_, symbol)| (symbol.sbn() == sbn).then_some(symbol))
}
/// Returns the progress record of block `sbn`, or `None` when the block has no entry.
#[inline]
#[must_use]
pub fn block_progress(&self, sbn: u8) -> Option<&BlockProgress> {
self.block_counts.get(&sbn)
}
/// Returns `true` when block `sbn` has a progress entry whose threshold flag
/// is set; unknown blocks yield `false`.
#[inline]
#[must_use]
pub fn threshold_reached(&self, sbn: u8) -> bool {
    matches!(
        self.block_counts.get(&sbn),
        Some(progress) if progress.threshold_reached
    )
}
/// Returns the block numbers whose threshold flag is set, in ascending order.
#[must_use]
pub fn ready_blocks(&self) -> Vec<u8> {
    let mut ready = Vec::new();
    for (sbn, progress) in &self.block_counts {
        if progress.threshold_reached {
            ready.push(*sbn);
        }
    }
    // Map iteration order is arbitrary; sort for a stable result.
    ready.sort_unstable();
    ready
}
/// Returns the number of symbols currently stored.
#[inline]
#[must_use]
pub const fn len(&self) -> usize {
self.total_count
}
/// Returns `true` when no symbols are stored.
#[inline]
#[must_use]
pub const fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Returns the estimated number of bytes accounted to the stored symbols
/// (the running sum of per-symbol size estimates).
#[inline]
#[must_use]
pub const fn memory_usage(&self) -> usize {
self.total_bytes
}
/// Removes every symbol and all block progress, and resets the counters.
///
/// When a memory budget is configured the full budget becomes available again.
pub fn clear(&mut self) {
    self.total_bytes = 0;
    self.total_count = 0;
    self.block_counts.clear();
    self.symbols.clear();
    // Without a budget `memory_remaining` is left as it is.
    self.memory_remaining = self.memory_budget.unwrap_or(self.memory_remaining);
}
/// Iterates over all stored `(id, symbol)` pairs, in the underlying map's order.
pub fn iter(&self) -> impl Iterator<Item = (&SymbolId, &Symbol)> {
self.symbols.iter()
}
/// Empties the set and returns an iterator over the removed `(id, symbol)` pairs.
///
/// Counters, block progress and the memory budget are reset immediately,
/// before the returned iterator is consumed.
pub fn drain(&mut self) -> impl Iterator<Item = (SymbolId, Symbol)> {
    let drained = std::mem::take(&mut self.symbols);
    self.block_counts.clear();
    self.total_bytes = 0;
    self.total_count = 0;
    // Without a budget `memory_remaining` is left as it is.
    self.memory_remaining = self.memory_budget.unwrap_or(self.memory_remaining);
    drained.into_iter()
}
/// Removes every stored symbol that belongs to block `sbn`.
///
/// If the block's `k` is known, its progress entry is kept with zeroed
/// counters and a re-evaluated threshold; otherwise the entry is dropped.
pub fn clear_block(&mut self, sbn: u8) {
    // Collect the ids first: `remove` needs `&mut self`.
    let doomed: Vec<SymbolId> = self
        .symbols
        .iter()
        .filter(|(_, symbol)| symbol.sbn() == sbn)
        .map(|(id, _)| *id)
        .collect();
    for id in &doomed {
        let _ = self.remove(id);
    }

    let config = self.threshold_config;
    let Some(progress) = self.block_counts.get_mut(&sbn) else {
        return;
    };
    progress.source_symbols = 0;
    progress.repair_symbols = 0;
    progress.threshold_reached = Self::calculate_threshold(progress, &config);
    if progress.k.is_none() {
        self.block_counts.remove(&sbn);
    }
}
/// Decides whether `progress` has collected enough symbols under `config`.
///
/// * No `k`, or `k == 0`: never reached.
/// * At least `k` source symbols: reached, regardless of the overhead settings.
/// * Otherwise the total symbol count must reach the larger of
///   `ceil(k * overhead_factor)` and `k + min_overhead` (saturating).
///   A NaN or negative product falls back to `k + min_overhead` alone, and a
///   positively infinite product can never be reached.
fn calculate_threshold(progress: &BlockProgress, config: &ThresholdConfig) -> bool {
    let Some(k) = progress.k else {
        return false;
    };
    if k == 0 {
        return false;
    }

    let needed = usize::from(k);
    if progress.source_symbols >= needed {
        return true;
    }

    let received = progress.total();
    let floor = needed.saturating_add(config.min_overhead);
    let scaled = (f64::from(k) * config.overhead_factor).ceil();

    // Unusable factor results: only the additive minimum applies.
    if scaled.is_nan() || scaled.is_sign_negative() {
        return received >= floor;
    }
    // Only +infinity is left here; no finite symbol count satisfies it.
    if scaled.is_infinite() {
        return false;
    }

    #[allow(clippy::cast_sign_loss)]
    let scaled_threshold = scaled as usize;
    received >= scaled_threshold.max(floor)
}
/// Estimates the bytes a stored symbol is charged for: the size of its id,
/// its payload length, and the fixed `SYMBOL_OVERHEAD_BYTES`.
fn estimate_symbol_size(symbol: &Symbol) -> usize {
    let id_bytes = std::mem::size_of::<SymbolId>();
    let payload_bytes = symbol.data().len();
    id_bytes + payload_bytes + SYMBOL_OVERHEAD_BYTES
}
/// Reserves `size` bytes in the memory accounting.
///
/// Returns `false` without changing anything when a budget is set and fewer
/// than `size` bytes remain; otherwise records the bytes and returns `true`.
fn try_allocate(&mut self, size: usize) -> bool {
    if self.memory_budget.is_some() {
        // `checked_sub` is `None` exactly when `size` exceeds what is left.
        match self.memory_remaining.checked_sub(size) {
            Some(left) => self.memory_remaining = left,
            None => return false,
        }
    }
    self.total_bytes = self.total_bytes.saturating_add(size);
    true
}
/// Releases `size` bytes from the memory accounting.
///
/// Usage saturates at zero, and with a budget set the remaining bytes never
/// exceed the budget.
fn deallocate(&mut self, size: usize) {
    if let Some(budget) = self.memory_budget {
        let restored = self.memory_remaining.saturating_add(size);
        self.memory_remaining = std::cmp::min(restored, budget);
    }
    self.total_bytes = self.total_bytes.saturating_sub(size);
}
}
impl Default for SymbolSet {
fn default() -> Self {
Self::new()
}
}
/// A [`SymbolSet`] behind a `parking_lot` read-write lock, usable through `&self`.
#[derive(Debug, Default)]
pub struct ConcurrentSymbolSet {
    /// The wrapped set; each method holds the lock for the duration of one call.
    inner: RwLock<SymbolSet>,
}

impl ConcurrentSymbolSet {
    /// Creates a wrapper around an empty [`SymbolSet::new`].
    #[must_use]
    pub fn new() -> Self {
        let inner = RwLock::new(SymbolSet::new());
        Self { inner }
    }

    /// Inserts `symbol` under the write lock; see [`SymbolSet::insert`].
    pub fn insert(&self, symbol: Symbol) -> InsertResult {
        let mut set = self.inner.write();
        set.insert(symbol)
    }

    /// Sets `k` for block `sbn` under the write lock; see [`SymbolSet::set_block_k`].
    pub fn set_block_k(&self, sbn: u8, k: u16) -> bool {
        let mut set = self.inner.write();
        set.set_block_k(sbn, k)
    }

    /// Reads the threshold flag of block `sbn` under the read lock.
    #[must_use]
    pub fn threshold_reached(&self, sbn: u8) -> bool {
        let set = self.inner.read();
        set.threshold_reached(sbn)
    }
}
/// Unit tests for [`SymbolSet`], [`ConcurrentSymbolSet`] and the threshold rules.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Symbol;

    /// Builds a source symbol for object 1 with `data_len` zero bytes of payload.
    fn test_symbol(sbn: u8, esi: u32, data_len: usize) -> Symbol {
        Symbol::new_source_for_test(1, sbn, esi, &vec![0u8; data_len])
    }

    /// Builds a repair symbol for object 1 with `data_len` zero bytes of payload.
    fn test_repair_symbol(sbn: u8, esi: u32, data_len: usize) -> Symbol {
        Symbol::new_repair_for_test(1, sbn, esi, &vec![0u8; data_len])
    }

    /// A second insert of the same symbol is reported as a duplicate.
    #[test]
    fn insert_and_duplicate() {
        let mut set = SymbolSet::new();
        let symbol = test_symbol(0, 0, 4);
        assert!(matches!(
            set.insert(symbol.clone()),
            InsertResult::Inserted { .. }
        ));
        assert!(matches!(set.insert(symbol), InsertResult::Duplicate));
        assert_eq!(set.len(), 1);
    }

    /// The threshold is only reached once `k` is known.
    #[test]
    fn threshold_tracking() {
        let config = ThresholdConfig::new(1.0, 0, 0);
        let mut set = SymbolSet::with_config(config);
        assert!(!set.threshold_reached(0));
        let _ = set.insert(test_symbol(0, 0, 4));
        assert!(!set.threshold_reached(0));
        set.set_block_k(0, 1);
        assert!(set.threshold_reached(0));
    }

    /// Repair symbols are counted separately from source symbols.
    #[test]
    fn repair_symbols_increment_repair_progress() {
        let mut set = SymbolSet::new();
        let symbol = test_repair_symbol(0, 99, 4);
        let InsertResult::Inserted { block_progress, .. } = set.insert(symbol) else {
            panic!("repair symbol should insert");
        };
        assert_eq!(block_progress.source_symbols, 0);
        assert_eq!(block_progress.repair_symbols, 1);
    }

    /// Inserts beyond `max_per_block` are rejected for that block.
    #[test]
    fn block_limit_enforced() {
        let config = ThresholdConfig::new(1.0, 0, 1);
        let mut set = SymbolSet::with_config(config);
        assert!(matches!(
            set.insert(test_symbol(1, 0, 1)),
            InsertResult::Inserted { .. }
        ));
        assert!(matches!(
            set.insert(test_symbol(1, 1, 1)),
            InsertResult::BlockLimitReached { sbn: 1 }
        ));
    }

    /// A symbol larger than the budget is rejected.
    #[test]
    fn memory_budget_enforced() {
        let config = ThresholdConfig::new(1.0, 0, 0);
        let mut set = SymbolSet::with_memory_budget(config, 8);
        let large = test_symbol(0, 0, 128);
        assert!(matches!(
            set.insert(large),
            InsertResult::MemoryLimitReached
        ));
    }

    /// Clearing one block leaves the other blocks untouched.
    #[test]
    fn clear_block_removes_only_block() {
        let mut set = SymbolSet::new();
        let _ = set.insert(test_symbol(0, 0, 4));
        let _ = set.insert(test_symbol(1, 0, 4));
        assert_eq!(set.len(), 2);
        set.clear_block(0);
        assert_eq!(set.len(), 1);
        assert!(set.symbols_for_block(0).next().is_none());
        assert!(set.symbols_for_block(1).next().is_some());
    }

    /// Removing a symbol rolls back both the count and the memory accounting.
    #[test]
    fn remove_decrements_counts_and_memory() {
        let mut set = SymbolSet::new();
        let sym = test_symbol(0, 0, 16);
        let id = sym.id();
        let _ = set.insert(sym);
        assert_eq!(set.len(), 1);
        let mem_before = set.memory_usage();
        assert!(mem_before > 0);
        let removed = set.remove(&id);
        assert!(removed.is_some());
        assert_eq!(set.len(), 0);
        assert!(set.is_empty());
        assert_eq!(set.memory_usage(), 0);
    }

    /// The locked wrapper forwards inserts and threshold queries.
    #[test]
    fn concurrent_symbol_set_insert_and_threshold() {
        let css = ConcurrentSymbolSet::new();
        let sym = test_symbol(0, 0, 4);
        assert!(matches!(css.insert(sym), InsertResult::Inserted { .. }));
        assert!(!css.threshold_reached(0));
        css.set_block_k(0, 1);
        assert!(css.threshold_reached(0));
    }

    /// Only blocks whose threshold is reached are reported as ready.
    #[test]
    fn ready_blocks_returns_threshold_blocks() {
        let config = ThresholdConfig::new(1.0, 0, 0);
        let mut set = SymbolSet::with_config(config);
        let _ = set.insert(test_symbol(0, 0, 4));
        let _ = set.insert(test_symbol(1, 0, 4));
        set.set_block_k(0, 1);
        let ready = set.ready_blocks();
        assert_eq!(ready.len(), 1);
        assert!(ready.contains(&0));
    }

    /// Ready blocks come back in ascending order regardless of insert order.
    #[test]
    fn ready_blocks_are_sorted() {
        let config = ThresholdConfig::new(1.0, 0, 0);
        let mut set = SymbolSet::with_config(config);
        let _ = set.insert(test_symbol(2, 0, 4));
        let _ = set.insert(test_symbol(0, 0, 4));
        let _ = set.insert(test_symbol(1, 0, 4));
        assert!(set.set_block_k(2, 1));
        assert!(set.set_block_k(0, 1));
        assert!(set.set_block_k(1, 1));
        assert_eq!(set.ready_blocks(), vec![0, 1, 2]);
    }

    /// `clear` drops symbols, counters and block progress.
    #[test]
    fn clear_resets_all_state() {
        let config = ThresholdConfig::new(1.0, 0, 0);
        let mut set = SymbolSet::with_memory_budget(config, 4096);
        let _ = set.insert(test_symbol(0, 0, 4));
        let _ = set.insert(test_symbol(0, 1, 4));
        set.set_block_k(0, 2);
        assert_eq!(set.len(), 2);
        assert!(set.memory_usage() > 0);
        set.clear();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert_eq!(set.memory_usage(), 0);
        assert!(!set.threshold_reached(0));
    }

    /// A block gains a progress entry on its first insert.
    #[test]
    fn block_progress_tracking() {
        let mut set = SymbolSet::new();
        assert!(set.block_progress(0).is_none());
        let _ = set.insert(test_symbol(0, 0, 4));
        let progress = set.block_progress(0).unwrap();
        assert_eq!(
            progress.total(),
            progress.source_symbols + progress.repair_symbols
        );
        assert_eq!(progress.sbn, 0);
    }

    /// `iter` observes the symbols and `drain` removes all of them.
    #[test]
    fn iter_and_drain_symbols() {
        let mut set = SymbolSet::new();
        let _ = set.insert(test_symbol(0, 0, 4));
        let _ = set.insert(test_symbol(0, 1, 4));
        assert_eq!(set.iter().count(), 2);
        assert_eq!(set.len(), 2);
        let drained = set.drain().count();
        assert_eq!(drained, 2);
        assert!(set.is_empty());
        assert_eq!(set.memory_usage(), 0);
    }

    /// A block with `k == 0` is never considered ready.
    #[test]
    fn zero_k_never_reaches_threshold() {
        let config = ThresholdConfig::new(1.0, 0, 0);
        let mut set = SymbolSet::with_config(config);
        let _ = set.insert(test_symbol(0, 0, 4));
        assert!(!set.set_block_k(0, 0));
        assert!(!set.threshold_reached(0));
    }

    /// `k + min_overhead` saturates instead of overflowing.
    #[test]
    fn threshold_calculation_saturates_min_overhead() {
        let config = ThresholdConfig::new(1.0, usize::MAX, 0);
        let mut set = SymbolSet::with_config(config);
        let _ = set.insert(test_symbol(0, 0, 4));
        assert!(!set.set_block_k(0, 2));
        assert!(!set.threshold_reached(0));
    }

    /// With factor 1.05 and 3 extra symbols for k = 10, the threshold is 13.
    #[test]
    fn threshold_reached_when_minimum_extra_dominates_without_double_counting() {
        let config = ThresholdConfig::new(1.05, 3, 0);
        let mut set = SymbolSet::with_config(config);
        assert!(!set.set_block_k(0, 10));
        for esi in 0..5 {
            let _ = set.insert(test_symbol(0, esi, 4));
        }
        for esi in 5..12 {
            let _ = set.insert(test_repair_symbol(0, esi, 4));
        }
        assert!(!set.threshold_reached(0));
        let _ = set.insert(test_repair_symbol(0, 12, 4));
        assert!(set.threshold_reached(0));
    }

    /// With factor 1.5 and 1 extra symbol for k = 10, the threshold is 15.
    #[test]
    fn threshold_reached_when_factor_dominates_without_extra_increment() {
        let config = ThresholdConfig::new(1.5, 1, 0);
        let mut set = SymbolSet::with_config(config);
        assert!(!set.set_block_k(0, 10));
        for esi in 0..5 {
            let _ = set.insert(test_symbol(0, esi, 4));
        }
        for esi in 5..14 {
            let _ = set.insert(test_repair_symbol(0, esi, 4));
        }
        assert!(!set.threshold_reached(0));
        let _ = set.insert(test_repair_symbol(0, 14, 4));
        assert!(set.threshold_reached(0));
    }

    /// An infinite factor can only be satisfied by receiving all source symbols.
    #[test]
    fn threshold_does_not_reach_with_infinite_overhead_before_all_sources_arrive() {
        let config = ThresholdConfig::new(f64::INFINITY, 0, 0);
        let mut set = SymbolSet::with_config(config);
        assert!(!set.set_block_k(0, 10));
        for esi in 0..9 {
            let _ = set.insert(test_symbol(0, esi, 4));
        }
        for esi in 9..65 {
            let _ = set.insert(test_repair_symbol(0, esi, 4));
        }
        assert!(!set.threshold_reached(0));
    }

    /// `ThresholdConfig` implements `Debug`, `Clone` and `Copy`.
    #[test]
    fn threshold_config_debug_clone_copy() {
        let cfg = ThresholdConfig::new(1.02, 0, 0);
        let dbg = format!("{cfg:?}");
        assert!(dbg.contains("ThresholdConfig"), "{dbg}");
        let copied = cfg;
        // Call `clone` explicitly so the derived `Clone` impl is exercised, not just `Copy`.
        #[allow(clippy::clone_on_copy)]
        let cloned = cfg.clone();
        assert!((copied.overhead_factor - cloned.overhead_factor).abs() < f64::EPSILON);
    }

    /// `InsertResult` implements `Debug` and `Clone`, and a clone formats identically.
    #[test]
    fn insert_result_debug_clone() {
        let mut set = SymbolSet::new();
        let result = set.insert(test_symbol(0, 0, 4));
        let dbg = format!("{result:?}");
        assert!(dbg.contains("Insert"), "{dbg}");
        // A real clone (not a move): the original stays usable afterwards.
        let cloned = result.clone();
        let dbg2 = format!("{cloned:?}");
        assert_eq!(dbg, dbg2);
        assert_eq!(format!("{result:?}"), dbg);
    }
}