use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::LazyLock;
use std::sync::RwLock;
use crate::*;
/// Process-wide cache of built tables, keyed by the `(cardinality, scaled sigma)`
/// pair returned by `GaussianCDT::identifier`. Created empty on first access.
static CDT_CACHE: LazyLock<RwLock<HashMap<(u128, u32), Arc<GaussianCDT>>>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));
/// Threshold of `10^5 * f64::EPSILON`: densities below it, and non-zero exponent
/// fractions at or below it, are excluded when a table is built.
static MIN_PRECISION: LazyLock<f64> = LazyLock::new(|| f64::EPSILON * 10f64.powi(5));
/// Cumulative distribution table for a discrete Gaussian centred on zero.
pub struct GaussianCDT {
    /// Cardinality of the field type the table was built for (`E::CARDINALITY`).
    pub cardinality: u128,
    /// Standard deviation the table was built with.
    pub sigma: f64,
    /// Inclusive `(lowest, highest)` displacement covered by the table.
    pub tail_bounds: (i32, i32),
    /// Maps each displacement to the cumulative normalized probability of all
    /// displacements up to and including it.
    pub displacements: BTreeMap<i32, f64>,
    /// Value of the running cumulative sum after the highest displacement.
    pub normalized_sum: f64,
}
impl GaussianCDT {
/// Derives the cache key for a table: the field cardinality paired with
/// `sigma` scaled by 10^5 and rounded to the nearest integer.
///
/// Rounding (rather than truncating) keeps values such as `2.3`, whose scaled
/// form is `229999.99999999997` in `f64`, on the intended key `230000`.
///
/// # Panics
///
/// * if `sigma` is NaN, infinite or negative;
/// * if `sigma` carries more than five decimal places (beyond a small
///   floating-point tolerance), since two such values could otherwise share a key;
/// * if the scaled value does not fit in a `u32`.
pub fn identifier<E: FieldScalar>(sigma: f64) -> (u128, u32) {
    // Five decimal places of sigma are preserved in the key.
    const SCALE: f64 = 100_000.0;
    // Largest distance from an integer the scaled sigma may have; generous
    // enough to absorb f64 representation error over the whole u32 key range.
    const TOLERANCE: f64 = 1e-3;

    assert!(
        sigma.is_finite() && sigma >= 0.0,
        "CDT: sigma must be finite and non-negative"
    );
    let scaled = sigma * SCALE;
    let rounded = scaled.round();
    assert!(
        (scaled - rounded).abs() < TOLERANCE,
        "CDT: sigma is too precise"
    );
    assert!(rounded <= u32::MAX as f64, "CDT: sigma is too large");
    // In range and integral after the checks above, so the cast is exact.
    (E::CARDINALITY, rounded as u32)
}
/// Builds the cumulative table for a discrete Gaussian with standard
/// deviation `sigma`.
///
/// Candidate displacements span 15 standard deviations either side of zero.
/// A displacement is kept only if its unnormalized density
/// `exp(-d^2 / (2 sigma^2))` is at least `MIN_PRECISION` and its exponent
/// passes the fractional-part check below. The kept densities are normalized
/// by their sum and stored as a running cumulative total.
///
/// # Panics
///
/// * if `15 * sigma` is not below `2^20`;
/// * if the retained tail is shorter than `7 * sigma` (truncated to `i32`);
/// * if the sum of densities is not below `10^5`, which includes a NaN sum;
/// * if an internal invariant of the table construction is violated.
fn new<E: FieldScalar>(sigma: f64) -> Self {
    let std_devs = 15f64;
    let tail = std_devs * sigma;
    assert!(
        tail < 2f64.powi(20),
        "CDT refusing to build table larger than 2^20 tail (i32 limits)"
    );
    let tail = tail as i32;
    let mut actual_tail = tail;
    let mut displacements = BTreeMap::<i32, f64>::default();
    log::debug!("CDT MIN_PRECISION: {}", *MIN_PRECISION);
    for disp in -tail..=tail {
        let prob_exp = (disp as f64).powi(2) / (2.0 * sigma * sigma);
        // The exponent is accepted when it is zero, integral, or has a
        // fractional part larger than MIN_PRECISION.
        let is_exp_precise =
            disp == 0 || prob_exp.fract() == 0. || prob_exp.fract() > *MIN_PRECISION;
        let prob = f64::exp(-prob_exp);
        if prob < *MIN_PRECISION || !is_exp_precise {
            assert_ne!(disp, 0);
            if disp < 0 {
                // Still walking in from the negative end: shrink the tail
                // to just inside this rejected displacement.
                actual_tail = disp.abs() - 1;
                continue;
            }
            // First rejected positive displacement: it must lie outside the
            // tail established on the negative side.
            assert!(actual_tail < disp);
            break;
        }
        log::debug!("CDT displacement: {} pdf: {}", disp, prob);
        displacements.insert(disp, prob);
    }
    if actual_tail < (7.0 * sigma) as i32 {
        panic!(
            "CDT refusing to build unlikely table with sigma: {} tail: {} stddevs: {}",
            sigma,
            actual_tail,
            ((actual_tail as f64 / sigma) * 10.0).floor() / 10.0
        );
    }
    let tail_bounds = (-actual_tail, actual_tail);

    // First pass: total unnormalized mass inside the tail.
    let mut raw_sum = 0f64;
    for disp in tail_bounds.0..=tail_bounds.1 {
        let prob = displacements
            .get(&disp)
            .unwrap_or_else(|| panic!("CDT displacement {disp} does not exist in table"));
        raw_sum += *prob;
        log::debug!("CDT sigma {}, disp: {} prob: {}", sigma, disp, prob);
    }
    assert!(raw_sum < 10f64.powi(5));

    // Second pass: normalize each density and overwrite it with the running
    // cumulative sum.
    let mut normalized_sum = 0f64;
    for disp in tail_bounds.0..=tail_bounds.1 {
        let prob = displacements
            .get_mut(&disp)
            .unwrap_or_else(|| panic!("CDT displacement {disp} does not exist in table"));
        *prob /= raw_sum;
        normalized_sum += *prob;
        *prob = normalized_sum;
    }
    Self {
        cardinality: E::CARDINALITY,
        sigma,
        tail_bounds,
        displacements,
        normalized_sum,
    }
}
/// Returns the shared table for `sigma` and field `E`, building and caching
/// it if the cache has no entry under the key from `Self::identifier`.
///
/// The read lock is released before the table is built; the write lock is
/// taken only for the insert.
pub fn cache_or_init<E: FieldScalar>(sigma: f64) -> Arc<Self> {
    let key = Self::identifier::<E>(sigma);
    let cached = CDT_CACHE.read().unwrap().get(&key).cloned();
    match cached {
        Some(table) => table,
        None => {
            let table = Arc::new(Self::new::<E>(sigma));
            CDT_CACHE.write().unwrap().insert(key, Arc::clone(&table));
            table
        }
    }
}
/// Iterates `(displacement, cumulative probability)` pairs in ascending
/// displacement order over the inclusive `tail_bounds` range.
///
/// # Panics
///
/// The iterator panics, naming the displacement, if an entry inside
/// `tail_bounds` is missing from the table.
fn displacements_iter(&self) -> impl Iterator<Item = (i32, f64)> {
    let (lo, hi) = self.tail_bounds;
    (lo..=hi).map(|disp| {
        let cumulative = self
            .displacements
            .get(&disp)
            .copied()
            .unwrap_or_else(|| panic!("CDT did not find entry for displacement {disp}"));
        (disp, cumulative)
    })
}
/// Draws one element by inverting the cumulative table.
///
/// A uniform `r` in `[0, 1)` is drawn, then the table is searched for the
/// displacement `i` with `table[i - 1] <= r < table[i]`. The search starts
/// at 0, jumps halfway towards the relevant end of the table, and then
/// halves its step (never below 1) after every move.
///
/// If the search walks off the top of the table, the highest displacement is
/// returned. If there is no entry below `i`, `i` itself is returned.
///
/// # Panics
///
/// Panics if the table is empty.
pub fn sample<E: FieldScalar, R: Rng>(&self, rng: &mut R) -> E {
    let r: f64 = rng.random_range(0.0..1.0);
    let (&min_i, _) = self
        .displacements
        .first_key_value()
        .expect("CDT first disp did not exist");
    let (&max_i, _) = self
        .displacements
        .last_key_value()
        .expect("CDT last disp did not exist");
    let mut i: i32 = 0;
    let mut last_i: i32 = 0;
    loop {
        // Half the distance covered by the previous move, but at least 1.
        let step = (last_i.abs_diff(i) / 2).max(1) as i32;
        last_i = i;
        let Some(&upper) = self.displacements.get(&i) else {
            return E::at_displacement(i - 1);
        };
        let Some(&lower) = self.displacements.get(&(i - 1)) else {
            return E::at_displacement(i);
        };
        if r >= upper {
            // The first jump must leave zero: with `max_i == 1` a plain
            // `max_i / 2` is 0 and the search would never advance.
            i = if i == 0 { (max_i / 2).max(1) } else { i + step };
        } else if r < lower {
            // Same guard on the negative side for `min_i == -1`.
            i = if i == 0 { (min_i / 2).min(-1) } else { i - step };
        } else {
            return E::at_displacement(i);
        }
    }
}
/// Fills a fixed-size array with `N` draws, each produced by a separate call
/// to `sample` in index order.
pub fn sample_arr<const N: usize, E: FieldScalar, R: Rng>(&self, rng: &mut R) -> [E; N] {
    std::array::from_fn(|_slot| self.sample::<E, R>(rng))
}
/// Draws `len` elements and returns them in draw order.
///
/// One uniform value in `[0, 1)` is drawn per output element, in index
/// order. Each is mapped to the first displacement whose cumulative
/// probability exceeds it, found by binary search over the table.
///
/// A draw at or above the last cumulative value (possible when rounding
/// leaves the final sum slightly below 1) maps to the highest displacement,
/// which is what `sample` returns in that case.
pub fn sample_vec<E: FieldScalar, R: Rng>(&self, len: usize, rng: &mut R) -> Vector<E> {
    // Flatten the table once; cumulative values are ascending, which is the
    // precondition `partition_point` needs.
    let table: Vec<(i32, f64)> = self.displacements_iter().collect();
    let top = self.tail_bounds.1;
    (0..len)
        .map(|_| {
            let draw: f64 = rng.random_range(0.0..1.0);
            let idx = table.partition_point(|&(_, cumulative)| cumulative <= draw);
            let disp = table.get(idx).map_or(top, |&(disp, _)| disp);
            E::at_displacement(disp)
        })
        .collect::<Vector<_>>()
}
/// Probability mass of a single displacement: its cumulative value minus the
/// cumulative value of the displacement below it, or minus 0 when there is
/// no entry below.
///
/// # Panics
///
/// Panics if `disp` is not in the table.
pub(crate) fn prob(&self, disp: &i32) -> f64 {
    let here = *disp;
    let below = self
        .displacements
        .get(&(here - 1))
        .copied()
        .unwrap_or_default();
    let cumulative = self
        .displacements
        .get(&here)
        .copied()
        .expect("CDT requested probability of element not in table");
    cumulative - below
}
}
#[cfg(test)]
mod test {
use super::*;
// Number of draws taken per table and per sampling method by the statistical tests.
const SAMPLES_PER_CDT: usize = 100_000;
// Runs `test_logic` against histograms of `SAMPLES_PER_CDT` draws for every
// sigma from 2.0 to 10.0 in steps of 0.1, once per sampling entry point:
// `sample`, `sample_vec` and `sample_arr`. Each histogram maps a displacement
// to the number of times it was drawn.
//
// All randomness, including the `sample_vec` batch size, comes from `rng`.
fn get_cdt_sample_pairs<E: FieldScalar, R: Rng>(
    test_logic: fn(Arc<GaussianCDT>, HashMap<i32, usize>, rng: &mut R),
    rng: &mut R,
) {
    let sigmas = (20..=100).map(|i| (i as f64) / 10.0);

    // Pass 1: one draw at a time.
    for sigma in sigmas.clone() {
        let cdt = GaussianCDT::cache_or_init::<E>(sigma);
        let mut samples = HashMap::<i32, usize>::default();
        for _ in 0..SAMPLES_PER_CDT {
            let disp = cdt.sample::<E, _>(rng).displacement();
            *samples.entry(disp as i32).or_default() += 1;
        }
        test_logic(cdt, samples, rng);
    }

    // Pass 2: batched draws of a random size below 30 per sigma.
    for sigma in sigmas.clone() {
        let cdt = GaussianCDT::cache_or_init::<E>(sigma);
        let mut samples = HashMap::<i32, usize>::default();
        let batch_size: usize = rng.random_range(1..30);
        let mut sample_count = 0usize;
        while sample_count < SAMPLES_PER_CDT {
            // The final batch is trimmed so the total is exactly SAMPLES_PER_CDT.
            let batch = batch_size.min(SAMPLES_PER_CDT - sample_count);
            let disps: Vector<E> = cdt.sample_vec(batch, rng);
            for disp in disps {
                *samples.entry(disp.displacement() as i32).or_default() += 1;
            }
            sample_count += batch;
        }
        test_logic(cdt, samples, rng);
    }

    // Pass 3: fixed-size arrays of 10; surplus draws in the last array are dropped.
    for sigma in sigmas {
        let cdt = GaussianCDT::cache_or_init::<E>(sigma);
        let mut samples = HashMap::<i32, usize>::default();
        let mut sample_count = 0usize;
        while sample_count < SAMPLES_PER_CDT {
            for disp in cdt.sample_arr::<10, E, _>(rng) {
                sample_count += 1;
                *samples.entry(disp.displacement() as i32).or_default() += 1;
                if sample_count >= SAMPLES_PER_CDT {
                    break;
                }
            }
        }
        test_logic(cdt, samples, rng);
    }
}
// Shadows each named binding with its value converted to `f64`. The value is
// `.clone()`d first, so a reference to a number (e.g. `&i32`) works as well
// as a plain number.
macro_rules! as_f64 {
($($name: ident),*) => {
$(
let $name = ($name).clone() as f64;
)*
};
}
// The sample mean of every histogram must lie within 5 standard errors of zero.
#[test]
fn cdt_mean() {
    type Field = MilliScalarMont;
    let rng = &mut rand::rng();
    get_cdt_sample_pairs::<Field, _>(
        |cdt, samples, _rng| {
            let total = SAMPLES_PER_CDT as f64;
            let mean = samples
                .iter()
                .map(|(&disp, &count)| disp as f64 * count as f64)
                .sum::<f64>()
                / total;
            let std_err = cdt.sigma / total.sqrt();
            // This closure runs 243 times per test (81 sigmas x 3 sampling
            // paths). A 3.5 standard-error bound would fail a correct sampler
            // in roughly one run out of ten; 5 standard errors brings that
            // to about 0.01% while still catching any real bias.
            let tolerance = 5.0 * std_err;
            assert!(mean.abs() < tolerance, "mean: {mean}");
        },
        rng,
    );
}
// The sample standard deviation of every histogram must be within 1% of the
// table's sigma.
#[test]
fn cdt_std_dev() {
    type Field = MilliScalarMont;
    let rng = &mut rand::rng();
    get_cdt_sample_pairs::<Field, _>(
        |cdt, samples, _rng| {
            let total = SAMPLES_PER_CDT as f64;
            let mean = samples
                .iter()
                .fold(0f64, |acc, (&disp, &count)| acc + disp as f64 * count as f64)
                / total;
            // Count-weighted squared deviation from the sample mean.
            let mut squared_dev = 0f64;
            for (&disp, &count) in &samples {
                squared_dev += count as f64 * (disp as f64 - mean).powi(2);
            }
            let std_dev = (squared_dev / total).sqrt();
            let percent_diff = ((std_dev - cdt.sigma) / cdt.sigma).abs();
            assert!(
                percent_diff < 0.01,
                "std_dev percent diff: {}%",
                percent_diff * 100.
            );
        },
        rng,
    );
}
// The number of negative draws and the number of positive draws must agree
// to within 3%. Draws of exactly zero are ignored.
#[test]
fn cdt_symmetry() {
    type Field = MilliScalarMont;
    let rng = &mut rand::rng();
    get_cdt_sample_pairs::<Field, _>(
        |_cdt, samples, _rng| {
            let (total_neg, total_pos) = samples.into_iter().fold(
                (0f64, 0f64),
                |(neg, pos), (disp, count)| match disp.signum() {
                    -1 => (neg + count as f64, pos),
                    1 => (neg, pos + count as f64),
                    _ => (neg, pos),
                },
            );
            let ratio = total_neg / total_pos;
            assert!(
                (1.0 - ratio).abs() < 0.03,
                "total_neg: {total_neg} total_pos: {total_pos}"
            );
        },
        rng,
    );
}
// Chi-squared goodness-of-fit of each histogram against the table's
// per-displacement probabilities, compared with the bound from `chi_sq_95`.
// Displacements whose expected count is below 1 contribute nothing.
#[test]
fn cdt_chi_squared_fit() {
    type Field = MilliScalarMont;
    let rng = &mut rand::rng();
    get_cdt_sample_pairs::<Field, _>(
        |cdt, mut samples, _rng| {
            let total = SAMPLES_PER_CDT as f64;
            let (lo, hi) = cdt.tail_bounds;
            let mut statistic = 0f64;
            // NOTE(review): the range excludes `hi` itself, and the degrees of
            // freedom below count every key in `samples`, including the bins
            // skipped here. Confirm both are intended before tightening.
            for disp in lo..hi {
                // `entry` also inserts a zero count for displacements that
                // were never drawn, which feeds into `samples.len()` below.
                let observed = *samples.entry(disp).or_default() as f64;
                let expected = cdt.prob(&disp) * total;
                if expected < 1.0 {
                    continue;
                }
                statistic += (observed - expected).powi(2) / expected;
            }
            let bound = chi_sq_95(samples.len() - 1);
            assert!(
                statistic < bound,
                "{} outside of bound 95% {}",
                statistic,
                bound
            );
        },
        rng,
    );
}
}