#![expect(clippy::float_arithmetic, reason = "computing probabilities")]
use core::{error::Error, fmt};
use rand::prelude::*;
use crate::level_generator::LevelGenerator;
/// Reasons why constructing a [`Geometric`] level generator can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[expect(
clippy::module_name_repetitions,
reason = "Using 'Error' would be too generic and may cause confusion."
)]
#[non_exhaustive]
pub enum GeometricError {
/// The requested maximum level (`total`) was zero.
ZeroMax,
/// The requested maximum level plus one does not fit in an `i32`.
MaxTooLarge,
/// The probability `q` was not strictly between 0 and 1 (this includes NaN).
InvalidProbability,
}
impl fmt::Display for GeometricError {
    /// Writes the fixed, human-readable message associated with the variant.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every message is a static string, so pick it first and emit it in one call.
        let message = match *self {
            Self::ZeroMax => "max must be non-zero.",
            Self::MaxTooLarge => "max must be less than i32::MAX.",
            Self::InvalidProbability => "q must be in (0, 1).",
        };
        f.write_str(message)
    }
}
// No methods are overridden: the trait's defaults are used as-is.
impl Error for GeometricError {}
/// Random level generator whose `level` method returns values in `0..=total`.
///
/// Instances are built with `Geometric::new` or `Geometric::new_with_seed`, which
/// validate `total` and `q` before storing them.
#[derive(Debug, Clone)]
pub struct Geometric {
/// Highest level that can be generated; also the value returned by `total()`.
total: usize,
/// `total + 1` converted to `i32`; used as the exponent for `powi` in `level`.
total_inclusive: i32,
/// Probability parameter, validated at construction to lie strictly between 0 and 1.
q: f64,
/// Random number generator from which `level` draws its `f64` sample.
rng: SmallRng,
}
impl Geometric {
    /// Checks the constructor arguments and returns `total + 1` as an `i32`.
    ///
    /// Both public constructors go through this helper so that they apply exactly the
    /// same checks, in the same order, and cannot drift apart.
    ///
    /// # Errors
    ///
    /// - [`GeometricError::ZeroMax`] if `total` is zero.
    /// - [`GeometricError::MaxTooLarge`] if `total + 1` does not fit in an `i32`.
    /// - [`GeometricError::InvalidProbability`] if `q` is not strictly between 0 and 1.
    #[inline]
    fn validate(total: usize, q: f64) -> Result<i32, GeometricError> {
        if total == 0 {
            return Err(GeometricError::ZeroMax);
        }
        let Some(total_inclusive) = i32::try_from(total)
            .ok()
            .and_then(|levels| levels.checked_add(1))
        else {
            return Err(GeometricError::MaxTooLarge);
        };
        // Written as a negated conjunction so that NaN, which fails every comparison,
        // is rejected together with out-of-range values.
        if !(0.0 < q && q < 1.0) {
            return Err(GeometricError::InvalidProbability);
        }
        Ok(total_inclusive)
    }

    /// Creates a generator for levels in `0..=total` with probability parameter `q`.
    ///
    /// The internal `SmallRng` is seeded from `rand::rng()`, and only after the
    /// arguments have been validated.
    ///
    /// # Errors
    ///
    /// - [`GeometricError::ZeroMax`] if `total` is zero.
    /// - [`GeometricError::MaxTooLarge`] if `total + 1` does not fit in an `i32`.
    /// - [`GeometricError::InvalidProbability`] if `q` is not strictly between 0 and 1.
    #[inline]
    pub fn new(total: usize, q: f64) -> Result<Self, GeometricError> {
        let total_inclusive = Self::validate(total, q)?;
        Ok(Self {
            total,
            total_inclusive,
            q,
            rng: SmallRng::from_rng(&mut rand::rng()),
        })
    }

    /// Creates a generator like [`Geometric::new`], but seeds the internal `SmallRng`
    /// with `SmallRng::seed_from_u64(seed)` instead of `rand::rng()`.
    ///
    /// # Errors
    ///
    /// - [`GeometricError::ZeroMax`] if `total` is zero.
    /// - [`GeometricError::MaxTooLarge`] if `total + 1` does not fit in an `i32`.
    /// - [`GeometricError::InvalidProbability`] if `q` is not strictly between 0 and 1.
    #[inline]
    pub fn new_with_seed(total: usize, q: f64, seed: u64) -> Result<Self, GeometricError> {
        let total_inclusive = Self::validate(total, q)?;
        Ok(Self {
            total,
            total_inclusive,
            q,
            rng: SmallRng::seed_from_u64(seed),
        })
    }
}
impl Default for Geometric {
#[inline]
fn default() -> Self {
#[expect(
clippy::expect_used,
reason = "16 levels and q = 0.5 are compile-time constants whose \
validity is guaranteed by the Geometric invariants"
)]
Geometric::new(16, 0.5)
.expect("16 levels and q = 0.5 are always valid Geometric parameters")
}
}
impl LevelGenerator for Geometric {
    /// Returns the maximum level this generator can produce.
    #[inline]
    fn total(&self) -> usize {
        self.total
    }

    /// Draws a random level in `0..=total`.
    ///
    /// A sample `u` from the RNG is mapped through
    /// `floor(log_q(1 + (q^(total + 1) - 1) * u))`, and the result is capped at `total`.
    #[inline]
    #[expect(
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss,
        reason = "CDF domain is [0, total] so the cast is safe after clamping"
    )]
    #[expect(clippy::as_conversions, reason = "No other way to do this")]
    fn level(&mut self) -> usize {
        // NOTE(review): rand documents the standard `f64` sample as lying in [0, 1) — confirm
        // for the rand version in use.
        let uniform: f64 = self.rng.random();
        // q^(total + 1): the value the logarithm's argument approaches as `uniform` nears 1.
        let lower_bound = self.q.powi(self.total_inclusive);
        let argument = 1.0 + (lower_bound - 1.0) * uniform;
        let raw_level = argument.log(self.q).floor() as usize;
        // Cap at `total` in case rounding lands the logarithm on `total + 1`.
        raw_level.min(self.total)
    }
}
#[cfg(test)]
mod tests {
    use anyhow::{Result, anyhow};
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::{Geometric, LevelGenerator};
    use crate::level_generator::geometric::GeometricError;

    /// Draw budget for the "reachable level" tests: 50 under Miri, 10 million otherwise.
    const DRAWS: usize = if cfg!(miri) { 50 } else { 10_000_000 };

    /// Draws up to [`DRAWS`] levels from a generator seeded with 42, asserting that every
    /// draw lies in `0..=n`, and (outside Miri) that `target` is produced at least once.
    fn assert_level_reachable(n: usize, q: f64, target: usize) -> Result<()> {
        let mut generator = Geometric::new_with_seed(n, q, 42)?;
        let found = (0..DRAWS).any(|_| {
            let level = generator.level();
            assert!(
                (0..=n).contains(&level),
                "level {level} out of range 0..={n}"
            );
            level == target
        });
        // Under Miri only the range checks above are enforced; a hit is not required.
        if !cfg!(miri) {
            assert!(
                found,
                "Failed to generate a level-{target} node after {DRAWS} attempts"
            );
        }
        Ok(())
    }

    /// A maximum level of zero is rejected.
    #[test]
    fn invalid_max() {
        assert_eq!(Geometric::new(0, 0.5).err(), Some(GeometricError::ZeroMax));
    }

    /// A maximum level that cannot be converted to `i32` is rejected.
    #[test]
    fn max_too_large() {
        assert_eq!(
            Geometric::new(usize::MAX, 0.5).err(),
            Some(GeometricError::MaxTooLarge)
        );
    }

    /// `q` must lie strictly inside `(0, 1)`; both endpoints and NaN are rejected.
    #[test]
    fn invalid_p() {
        assert_eq!(
            Geometric::new(1, 0.0).err(),
            Some(GeometricError::InvalidProbability)
        );
        assert_eq!(
            Geometric::new(1, 1.0).err(),
            Some(GeometricError::InvalidProbability)
        );
        assert_eq!(
            Geometric::new(1, f64::NAN).err(),
            Some(GeometricError::InvalidProbability)
        );
    }

    /// `total()` reports the value passed to the constructor.
    #[rstest]
    fn total_is_correct(
        #[values(1, 2, 4, 8, 128, 512, 1024)] n: usize,
        #[values(0.01, 0.1, 0.5, 0.99)] q: f64,
    ) -> Result<()> {
        let generator = Geometric::new_with_seed(n, q, 42)?;
        assert_eq!(generator.total(), n);
        Ok(())
    }

    /// Level 0 is reachable for every tested configuration.
    #[rstest]
    fn generates_level_zero(
        #[values(1, 2, 4, 8, 128, 512, 1024)] n: usize,
        #[values(0.01, 0.1, 0.2, 0.5, 0.8, 0.99)] q: f64,
    ) -> Result<()> {
        assert_level_reachable(n, q, 0)
    }

    /// The maximum level is reachable when `n` is small.
    #[rstest]
    fn generates_max_level_small_n(
        #[values(1, 2, 4, 8)] n: usize,
        #[values(0.2, 0.5, 0.8, 0.9, 0.99)] q: f64,
    ) -> Result<()> {
        assert_level_reachable(n, q, n)
    }

    /// The maximum level is reachable for larger `n` when `q` is close to 1.
    #[rstest]
    fn generates_max_level_large_n(
        #[values(32, 64)] n: usize,
        #[values(0.99, 0.999)] q: f64,
    ) -> Result<()> {
        assert_level_reachable(n, q, n)
    }

    /// The ratio between the counts of consecutive levels is approximately `q`.
    #[rstest]
    fn distribution_ratio(
        #[values(4, 8, 16)] n: usize,
        #[values(0.1, 0.2, 0.5, 0.8, 0.9)] q: f64,
    ) -> Result<()> {
        const SAMPLES: usize = if cfg!(miri) { 50 } else { 10_000_000 };
        const MIN_COUNT: u32 = 1_000;
        const TOLERANCE: f64 = 0.05;

        let mut counts = vec![0_u32; n.strict_add(1)];
        let mut generator = Geometric::new_with_seed(n, q, 42)?;
        for _ in 0..SAMPLES {
            // Sample exactly once, so a failure reports the level that was actually out of
            // range instead of a freshly drawn one.
            let level = generator.level();
            let Some(count) = counts.get_mut(level) else {
                panic!("Generated level {level} out of range 0..={n}");
            };
            *count = count.strict_add(1);
        }
        if cfg!(miri) {
            return Ok(());
        }

        for k in 0..n {
            let next_k = k.strict_add(1);
            let count_k = counts
                .get(k)
                .copied()
                .ok_or_else(|| anyhow!("invalid count bin"))?;
            let count_next_k = counts
                .get(next_k)
                .copied()
                .ok_or_else(|| anyhow!("invalid count bin"))?;
            // Stop once a bin is too sparsely populated for the ratio to be meaningful.
            if count_k < MIN_COUNT || count_next_k < MIN_COUNT {
                break;
            }
            let ratio = f64::from(count_next_k) / f64::from(count_k);
            let relative_err = (ratio - q).abs() / q;
            assert!(
                relative_err < TOLERANCE,
                "level {k}→{next_k}: count[{k}]={count_k}, count[{next_k}]={count_next_k}, \
                 ratio={ratio:.4}, expected q={q:.4} (err {:.1}%)",
                relative_err * 100.0,
            );
        }
        Ok(())
    }
}