bitbelay_cli/commands/
chi_squared.rs

1//! A command for running the chi-squared test suite.
2
3use std::hash::BuildHasher;
4use std::num::NonZeroUsize;
5
6use anyhow::anyhow;
7use anyhow::bail;
8use bitbelay_providers::Provider;
9use bitbelay_report::Config;
10use bitbelay_suites::chi_squared::suite::Builder;
11use bitbelay_suites::r#trait::Suite as _;
12use tracing::Level;
13
/// How many iterations are run for each bucket when the user does not pass
/// an explicit `--iterations` value.
///
/// NOTE: the help text of the `iterations` argument spells this number out,
/// so that documentation must be updated whenever this value changes.
const DEFAULT_ITERATIONS_PER_BUCKET: usize = 1_000;
18
19/// Arguments for the chi-squared command.
20#[derive(clap::Args, Debug)]
21pub struct Args {
22    /// The number of buckets.
23    #[arg(short, long, default_value_t = 64)]
24    buckets: usize,
25
26    /// The number of iterations to test.
27    ///
28    /// If no number is given, then the number will be 100 * the number of
29    /// buckets to ensure enough samples are taken.
30    #[arg(short, long)]
31    iterations: Option<usize>,
32
33    /// The threshold of statistical significance.
34    #[arg(long, default_value_t = 0.05)]
35    threshold: f64,
36}
37
38/// The main function for the chi-squared command.
39pub fn main<H: BuildHasher>(
40    args: Args,
41    build_hasher: H,
42    provider: Box<dyn Provider>,
43) -> anyhow::Result<()> {
44    tracing::info!("Starting chi-squared test suite.");
45
46    let buckets =
47        NonZeroUsize::try_from(args.buckets).map_err(|_| anyhow!("--buckets must be non-zero!"))?;
48
49    let iterations = NonZeroUsize::try_from(
50        args.iterations
51            .unwrap_or(args.buckets * DEFAULT_ITERATIONS_PER_BUCKET),
52    )
53    .map_err(|_| anyhow!("--iterations must be non-zero!"))?;
54
55    if !(0.0..=1.0).contains(&args.threshold) {
56        bail!("--threshold must be between 0.0 and 1.0!");
57    }
58
59    tracing::info!(
60        "Running chi-squared test with {} buckets for {} iterations.",
61        args.buckets,
62        iterations
63    );
64
65    let mut suite = Builder::default()
66        .buckets(buckets)
67        .unwrap()
68        .build_hasher(&build_hasher)
69        .unwrap()
70        .try_build()
71        .unwrap();
72
73    suite.run_goodness_of_fit(provider, iterations, args.threshold);
74
75    if tracing::enabled!(Level::TRACE) {
76        // SAFETY: we know there must be one test because we just ran it above!
77        let test = suite.tests().last().unwrap();
78        for (i, entries) in test
79            .as_goodness_of_fit_test()
80            // SAFETY: we also know that the last test was a goodness of fit test.
81            .unwrap()
82            .buckets()
83            .iter()
84            .enumerate()
85        {
86            tracing::trace!("[Bucket {}] => {}", i + 1, entries);
87        }
88    }
89
90    suite
91        .report()
92        .write_to(&mut std::io::stderr(), &Config::default())?;
93
94    Ok(())
95}