#![expect(clippy::float_arithmetic, reason = "computing probabilities")]
use core::{error::Error, fmt};
use rand::prelude::*;
use crate::level_generator::LevelGenerator;
/// Errors returned by [`Geometric::new`] when its parameters are invalid.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[expect(
    clippy::module_name_repetitions,
    reason = "Using 'Error' would be too generic and may cause confusion."
)]
#[non_exhaustive]
pub enum GeometricError {
    /// The requested number of levels was zero.
    ZeroMax,
    /// The requested number of levels does not fit in an `i32`.
    MaxTooLarge,
    /// The probability was not strictly between 0 and 1 (NaN included).
    InvalidProbability,
}

impl fmt::Display for GeometricError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::ZeroMax => "max must be non-zero.",
            // `i32::MAX` itself is accepted (`i32::try_from` succeeds on it),
            // so the bound is inclusive.
            Self::MaxTooLarge => "max must be at most i32::MAX.",
            Self::InvalidProbability => "p must be in (0, 1).",
        };
        f.write_str(message)
    }
}

impl Error for GeometricError {}
/// A [`LevelGenerator`] drawing levels from a geometric distribution
/// truncated to `0..total`.
///
/// With `q = 1 - p`, level `k` is produced with probability
/// `p * q^k / (1 - q^total)`, so larger values of `p` concentrate the mass on
/// low levels.
#[derive(Debug, Clone)]
pub struct Geometric {
    /// Number of levels; every generated level lies in `0..total`.
    total: usize,
    /// `ln(q) = ln(1 - p)`, computed with `ln_1p` so that it stays strictly
    /// negative and accurate even when `1.0 - p` would round to `1.0`.
    ln_q: f64,
    /// `q^total - 1`, computed as `exp_m1(total * ln(q))`; lies in `[-1, 0)`.
    q_pow_total_m1: f64,
    /// Source of the uniform variates fed through the inverse CDF.
    rng: SmallRng,
}

impl Geometric {
    /// Creates a generator producing levels in `0..total`, where level `k`
    /// has probability proportional to `(1 - p)^k`.
    ///
    /// # Errors
    ///
    /// - [`GeometricError::ZeroMax`] if `total` is zero.
    /// - [`GeometricError::MaxTooLarge`] if `total` exceeds `i32::MAX`.
    /// - [`GeometricError::InvalidProbability`] unless `0 < p < 1`
    ///   (NaN is rejected too).
    #[inline]
    pub fn new(total: usize, p: f64) -> Result<Self, GeometricError> {
        if total == 0 {
            return Err(GeometricError::ZeroMax);
        }
        // Keeping `total` within `i32` lets it convert losslessly to `f64`
        // through `f64::from` below.
        let Ok(total_i32) = i32::try_from(total) else {
            return Err(GeometricError::MaxTooLarge);
        };
        // Phrased positively so that NaN fails the check as well.
        if !(0.0 < p && p < 1.0) {
            return Err(GeometricError::InvalidProbability);
        }
        // Work in log space. Forming `q = 1.0 - p` directly rounds to exactly
        // 1.0 for p below roughly 5.6e-17, which makes `ln(q)` zero and the
        // sampler degenerate (0/0 = NaN, cast to level 0). `ln_1p`/`exp_m1`
        // keep full precision for small `p`.
        let ln_q = (-p).ln_1p();
        let q_pow_total_m1 = (f64::from(total_i32) * ln_q).exp_m1();
        Ok(Self {
            total,
            ln_q,
            q_pow_total_m1,
            rng: SmallRng::from_rng(&mut rand::rng()),
        })
    }
}

impl Default for Geometric {
    /// 16 levels with `p = 0.5`.
    #[inline]
    fn default() -> Self {
        #[expect(
            clippy::expect_used,
            reason = "16 levels and p = 0.5 are compile-time constants whose \
            validity is guaranteed by the Geometric invariants"
        )]
        Geometric::new(16, 0.5)
            .expect("16 levels and p = 0.5 are always valid Geometric parameters")
    }
}

impl LevelGenerator for Geometric {
    #[inline]
    fn total(&self) -> usize {
        self.total
    }

    /// Draws a level by inverse-transform sampling.
    ///
    /// For `u` uniform in `[0, 1)`, `log_q(1 + (q^total - 1) * u)` lies in
    /// `[0, total)`, and its floor follows the truncated geometric
    /// distribution. The logarithm is evaluated as
    /// `ln_1p((q^total - 1) * u) / ln(q)` from the values precomputed in
    /// `new`, so no `powi` or `ln(q)` is recomputed per call.
    #[inline]
    #[expect(
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss,
        reason = "float-to-int `as` saturates (NaN maps to 0) and the result is \
        clamped to `total - 1` afterwards"
    )]
    #[expect(clippy::as_conversions, reason = "No other way to do this")]
    fn level(&mut self) -> usize {
        let u = self.rng.random::<f64>();
        // `q_pow_total_m1 * u` is in (-1, 0], so `ln_1p` is finite and <= 0;
        // dividing by the strictly negative `ln_q` yields a value >= 0.
        let level = ((self.q_pow_total_m1 * u).ln_1p() / self.ln_q).floor();
        (level as usize).min(self.total.saturating_sub(1))
    }
}
#[cfg(test)]
mod tests {
    use anyhow::Result;
    #[cfg(not(miri))]
    use anyhow::bail;
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use super::{Geometric, LevelGenerator};
    use crate::level_generator::geometric::GeometricError;

    #[test]
    fn invalid_max() {
        assert_eq!(Geometric::new(0, 0.5).err(), Some(GeometricError::ZeroMax));
    }

    #[test]
    fn max_bound() {
        // `i32::MAX` is the largest accepted value and one more is rejected.
        // Skipped where `usize` is too narrow to express the boundary.
        let Ok(max) = usize::try_from(i32::MAX) else {
            return;
        };
        assert!(Geometric::new(max, 0.5).is_ok());
        if let Some(too_large) = max.checked_add(1) {
            assert_eq!(
                Geometric::new(too_large, 0.5).err(),
                Some(GeometricError::MaxTooLarge)
            );
        }
    }

    #[test]
    fn invalid_p() {
        for p in [
            0.0,
            1.0,
            -0.5,
            1.5,
            f64::NAN,
            f64::INFINITY,
            f64::NEG_INFINITY,
        ] {
            assert_eq!(
                Geometric::new(1, p).err(),
                Some(GeometricError::InvalidProbability),
                "p = {p}"
            );
        }
    }

    #[cfg(miri)]
    #[rstest]
    fn new_miri(
        #[values(1, 2, 128, 1024)] n: usize,
        #[values(0.01, 0.1, 0.5, 0.99)] p: f64,
    ) -> Result<()> {
        const MAX: usize = 10;
        let mut generator = Geometric::new(n, p)?;
        assert_eq!(generator.total(), n);
        for _ in 0..MAX {
            let level = generator.level();
            assert!((0..n).contains(&level));
        }
        Ok(())
    }

    /// Returns `true` if `generator` yields `target` within `max_draws` draws.
    #[cfg(not(miri))]
    fn reaches(generator: &mut Geometric, target: usize, max_draws: usize) -> bool {
        (0..max_draws).any(|_| generator.level() == target)
    }

    /// Checks that levels stay in `0..n` and that both extreme levels occur.
    #[cfg(not(miri))]
    fn check_range_and_extremes(n: usize, p: f64) -> Result<()> {
        // The rarest target exercised (top level for n = 1024, p = 0.01) occurs
        // with probability ~3.4e-7 per draw, so a 1e7 budget would still fail
        // about 3% of the time. With 1e8 a spurious failure is negligible, and
        // passing runs stop as soon as the level is seen.
        const MAX_DRAWS: usize = 100_000_000;
        let mut generator = Geometric::new(n, p)?;
        assert_eq!(generator.total(), n);
        for _ in 0..1_000 {
            let level = generator.level();
            assert!((0..n).contains(&level));
        }
        if !reaches(&mut generator, 0, MAX_DRAWS) {
            bail!("Failed to generate a level-0 node.");
        }
        let top = n.checked_sub(1).expect("n is guaranteed to be > 0");
        if !reaches(&mut generator, top, MAX_DRAWS) {
            bail!("Failed to generate a level-{top} node.");
        }
        Ok(())
    }

    #[cfg(not(miri))]
    #[rstest]
    fn new_small(
        #[values(1, 2, 4, 8)] n: usize,
        #[values(0.01, 0.1, 0.5, 0.8)] p: f64,
    ) -> Result<()> {
        check_range_and_extremes(n, p)
    }

    #[cfg(not(miri))]
    #[rstest]
    fn new_large(#[values(512, 1024)] n: usize, #[values(0.001, 0.01)] p: f64) -> Result<()> {
        check_range_and_extremes(n, p)
    }
}