use super::Error;
use crate::Config;
use commonware_math::fields::goldilocks::F;
use commonware_utils::BigRationalExt as _;
use num_rational::BigRational;
/// Security parameter, in bits, used as the target of the sample-count calculation in
/// `Topology::required_samples`.
const SECURITY_BITS: usize = 126;
/// Precision (in bits) passed to `log2_ceil` in `Topology::required_samples`.
///
/// Equal to `ceil(log2(SECURITY_BITS))`: the exponent of the smallest power of two that is
/// at least `SECURITY_BITS` (7 for 126).
const LOG2_PRECISION: usize = SECURITY_BITS.next_power_of_two().ilog2() as usize;
/// Layout parameters describing how a payload is arranged into rows and columns of field
/// elements, and how the encoded rows are divided among shards.
///
/// Produced by `Topology::reckon`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Topology {
    /// Length of the payload, in bytes, as passed to `reckon`.
    pub data_bytes: usize,
    /// Number of field elements in each data row.
    pub data_cols: usize,
    /// Number of rows needed to hold the payload at `data_cols` elements per row.
    pub data_rows: usize,
    /// Number of rows after encoding; always a power of two.
    pub encoded_rows: usize,
    /// Rows per shard: `encoded_rows` is sized to hold this many rows for each of the
    /// `total_shards` shards.
    pub samples: usize,
    /// Multiplier derived from `SECURITY_BITS` and the required sample count; zero until
    /// `reckon` fills it in.
    pub column_samples: usize,
    /// Configured minimum number of shards (`n`).
    pub min_shards: usize,
    /// Minimum shards plus extra shards (`n + k`).
    pub total_shards: usize,
}
impl Topology {
const fn with_cols(data_bytes: usize, n: usize, k: usize, cols: usize) -> Self {
let data_els = F::bits_to_elements(8 * data_bytes);
let data_rows = data_els.div_ceil(cols);
let samples = data_rows.div_ceil(n);
Self {
data_bytes,
data_cols: cols,
data_rows,
encoded_rows: ((n + k) * samples).next_power_of_two(),
samples,
column_samples: 0,
min_shards: n,
total_shards: n + k,
}
}
pub(crate) fn required_samples(&self) -> usize {
let k = BigRational::from_usize(self.encoded_rows - self.data_rows);
let m = BigRational::from_usize(self.encoded_rows);
let fraction = (&k + BigRational::from_u64(1)) / (BigRational::from_usize(2) * &m);
let one_minus = BigRational::from_usize(1) - &fraction;
let log_term = one_minus.log2_ceil(LOG2_PRECISION);
if log_term >= BigRational::from_u64(0) {
return usize::MAX;
}
let required = BigRational::from_usize(SECURITY_BITS) / -log_term;
required.ceil_to_u128().unwrap_or(u128::MAX) as usize
}
fn correct_column_samples(&mut self) {
self.column_samples =
F::bits_to_elements(SECURITY_BITS) * self.required_samples().div_ceil(self.samples);
}
pub fn reckon(config: &Config, data_bytes: usize) -> Self {
let n = config.minimum_shards.get() as usize;
let k = config.extra_shards.get() as usize;
let corrected_data_bytes = data_bytes.max(1);
let mut out = Self::with_cols(corrected_data_bytes, n, k, 1);
loop {
let attempt = Self::with_cols(corrected_data_bytes, n, k, out.data_cols + 1);
let required_samples = attempt.required_samples();
if required_samples.saturating_mul(n + k) <= attempt.encoded_rows {
out = Self {
samples: required_samples.max(attempt.samples),
..attempt
};
} else {
break;
}
}
out.correct_column_samples();
out.data_bytes = data_bytes;
out
}
pub fn check_index(&self, i: u16) -> Result<(), Error> {
if (0..self.total_shards).contains(&(i as usize)) {
return Ok(());
}
Err(Error::InvalidIndex(i))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_utils::NZU16;

    #[test]
    fn reckon_handles_small_extra_shards() {
        // Three minimum shards, one extra shard, 16-byte payload.
        let config = Config {
            minimum_shards: NZU16!(3),
            extra_shards: NZU16!(1),
        };
        let topo = Topology::reckon(&config, 16);

        assert_eq!(topo.min_shards, 3);
        assert_eq!(topo.total_shards, 4);
        assert_eq!(topo.data_cols, 1);

        // `samples` scaled by `column_samples / 2` must reach the required sample count.
        let (required, provided) = (
            topo.required_samples(),
            topo.samples * (topo.column_samples / 2),
        );
        assert!(
            provided >= required,
            "security invariant violated: provided {provided} < required {required}"
        );
    }
}