hyperreal 0.10.6

Exact rational and computable real arithmetic in Rust
Documentation
use std::cell::Cell;
use std::collections::BTreeMap;
use std::sync::{Mutex, OnceLock};

use num::{BigUint, One, Zero};

/// One row of the dispatch trace: a `(layer, operation, path)` triple together
/// with the number of times it was recorded.
///
/// Rows are produced by [`snapshot`] and [`take`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DispatchCount {
    /// Layer label that was passed to [`record`].
    pub layer: &'static str,
    /// Operation label that was passed to [`record`].
    pub operation: &'static str,
    /// Path label that was passed to [`record`].
    pub path: &'static str,
    /// How many times this exact triple was recorded.
    pub count: u64,
}

/// Histogram of the common factors reported to the rational trace, bucketed by
/// the shape and bit length of the divisor.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CommonFactorBuckets {
    /// Divisor was zero or one, or a reported power-of-two shift was zero.
    pub none: u64,
    /// Divisor was a power of two greater than one, or a reported shift was non-zero.
    pub power_of_two: u64,
    /// Divisor was not a power of two and was at most 8 bits long.
    pub small: u64,
    /// Divisor was not a power of two and was 9 to 64 bits long.
    pub medium: u64,
    /// Divisor was not a power of two and was more than 64 bits long.
    pub large: u64,
}

/// Aggregate counters collected by the `record_rational_*` functions while
/// recording is enabled.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RationalTraceStats {
    /// Number of calls to [`record_rational_temporary`].
    pub temporary_rationals: u64,
    /// Number of calls to [`record_rational_reduction`].
    pub reductions: u64,
    /// Number of calls to [`record_rational_gcd`].
    pub gcds: u64,
    /// Bucketed counts of the common factors that were reported.
    pub common_factors: CommonFactorBuckets,
    /// Largest bit length seen among the operands passed to
    /// [`record_rational_reduction`] and [`record_rational_gcd`].
    pub peak_operand_bits: u64,
}

/// Private map key for the dispatch counters.
///
/// Derives `Ord` so it can be used as a `BTreeMap` key; the derived ordering
/// compares `layer`, then `operation`, then `path`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
struct DispatchKey {
    layer: &'static str,
    operation: &'static str,
    path: &'static str,
}

// Process-wide trace storage. Both cells start empty and are initialised on
// first access through `counts()` and `rational_stats()` respectively. Unlike
// the `RECORDING` flag these are shared by every thread.
static COUNTS: OnceLock<Mutex<BTreeMap<DispatchKey, u64>>> = OnceLock::new();
static RATIONAL_STATS: OnceLock<Mutex<RationalTraceStats>> = OnceLock::new();

// Per-thread switch consulted by every `record*` function. It starts out
// `false` on each thread, is set to `true` by `recording_scope`, and is
// restored to its earlier value when the returned `RecordingGuard` is dropped.
thread_local! {
    static RECORDING: Cell<bool> = const { Cell::new(false) };
}

/// Returns the shared dispatch-count table, creating an empty one on first use.
fn counts() -> &'static Mutex<BTreeMap<DispatchKey, u64>> {
    // `Mutex::default` wraps `BTreeMap::default()`, i.e. an empty map.
    COUNTS.get_or_init(Mutex::default)
}

/// Returns the shared rational statistics, creating zeroed statistics on first use.
fn rational_stats() -> &'static Mutex<RationalTraceStats> {
    // `Mutex::default` wraps `RationalTraceStats::default()`.
    RATIONAL_STATS.get_or_init(Mutex::default)
}

fn is_recording() -> bool {
    RECORDING.with(Cell::get)
}

/// Scope guard returned by [`recording_scope`].
///
/// While the guard is alive the calling thread records trace events; dropping
/// it restores the recording flag to the value it had before the scope was
/// opened, so scopes may be nested.
///
/// The guard must be bound to a variable (`let _guard = recording_scope();`).
/// A discarded guard is dropped immediately and nothing is recorded, which is
/// why the type is marked `#[must_use]`.
#[must_use = "recording stops as soon as the guard is dropped; bind it to a variable"]
pub struct RecordingGuard {
    /// Value of the thread's recording flag before this scope was opened.
    previous: bool,
}

impl Drop for RecordingGuard {
    /// Puts the thread's recording flag back to the value captured when the
    /// guard was created.
    fn drop(&mut self) {
        let restored = self.previous;
        RECORDING.with(|flag| flag.set(restored));
    }
}

/// Clears every dispatch counter and resets the rational statistics.
///
/// # Panics
///
/// Panics if either trace lock is poisoned.
pub fn reset() {
    {
        let mut table = counts().lock().expect("dispatch trace lock poisoned");
        table.clear();
    } // The dispatch lock is released here, before the rational lock is taken.
    reset_rational_stats();
}

/// Turns recording on for the calling thread and returns a guard that restores
/// the previous setting when it is dropped.
pub fn recording_scope() -> RecordingGuard {
    // `Cell::replace` stores `true` and hands back the old flag in one step.
    let previous = RECORDING.with(|flag| flag.replace(true));
    RecordingGuard { previous }
}

/// Runs `f` with recording enabled on the calling thread and returns its result.
///
/// The previous recording setting is restored once `f` finishes, including
/// when `f` unwinds, because the guard is dropped either way.
pub fn with_recording<T>(f: impl FnOnce() -> T) -> T {
    let guard = recording_scope();
    let outcome = f();
    drop(guard);
    outcome
}

/// Increments the counter for the `(layer, operation, path)` triple.
///
/// Does nothing unless recording is enabled on the calling thread.
///
/// # Panics
///
/// Panics if the dispatch trace lock is poisoned.
pub fn record(layer: &'static str, operation: &'static str, path: &'static str) {
    if is_recording() {
        let mut table = counts().lock().expect("dispatch trace lock poisoned");
        let slot = table
            .entry(DispatchKey {
                layer,
                operation,
                path,
            })
            .or_default();
        *slot += 1;
    }
}

/// Raises `stats.peak_operand_bits` to the bit length of `value` when that is larger.
fn update_peak(stats: &mut RationalTraceStats, value: &BigUint) {
    let width = value.bits();
    if width > stats.peak_operand_bits {
        stats.peak_operand_bits = width;
    }
}

/// Adds one to the bucket of `stats.common_factors` that matches `divisor`.
///
/// Zero and one count as "none"; a power of two goes to `power_of_two`; every
/// other divisor is bucketed by bit length (up to 8, 9 to 64, above 64).
fn record_common_factor(stats: &mut RationalTraceStats, divisor: &BigUint) {
    let buckets = &mut stats.common_factors;
    if divisor.is_zero() || divisor.is_one() {
        buckets.none += 1;
        return;
    }

    // A non-zero value is a power of two exactly when its only set bit is the
    // top one, i.e. the trailing-zero count equals the bit length minus one.
    let width = divisor.bits();
    let single_bit = divisor.trailing_zeros() == Some(width - 1);
    let bucket = match width {
        _ if single_bit => &mut buckets.power_of_two,
        0..=8 => &mut buckets.small,
        9..=64 => &mut buckets.medium,
        _ => &mut buckets.large,
    };
    *bucket += 1;
}

/// Counts one temporary rational.
///
/// Does nothing unless recording is enabled on the calling thread.
///
/// # Panics
///
/// Panics if the rational trace lock is poisoned.
pub fn record_rational_temporary() {
    if is_recording() {
        let mut stats = rational_stats()
            .lock()
            .expect("rational trace lock poisoned");
        stats.temporary_rationals += 1;
    }
}

pub fn record_rational_reduction(numerator: &BigUint, denominator: &BigUint) {
    if !is_recording() {
        return;
    }
    let mut stats = rational_stats()
        .lock()
        .expect("rational trace lock poisoned");
    stats.reductions += 1;
    update_peak(&mut stats, numerator);
    update_peak(&mut stats, denominator);
}

pub fn record_rational_gcd(left: &BigUint, right: &BigUint, divisor: &BigUint) {
    if !is_recording() {
        return;
    }
    let mut stats = rational_stats()
        .lock()
        .expect("rational trace lock poisoned");
    stats.gcds += 1;
    update_peak(&mut stats, left);
    update_peak(&mut stats, right);
    update_peak(&mut stats, divisor);
    record_common_factor(&mut stats, divisor);
}

/// Buckets a common factor that is reported as a power-of-two `shift`.
///
/// A zero shift is counted as "none"; any other shift is counted as
/// `power_of_two`. Does nothing unless recording is enabled on the calling
/// thread.
///
/// # Panics
///
/// Panics if the rational trace lock is poisoned.
pub fn record_rational_power_of_two_common_factor(shift: u64) {
    if is_recording() {
        let mut guard = rational_stats()
            .lock()
            .expect("rational trace lock poisoned");
        let buckets = &mut guard.common_factors;
        let bucket = if shift == 0 {
            &mut buckets.none
        } else {
            &mut buckets.power_of_two
        };
        *bucket += 1;
    }
}

/// Resets the rational statistics to their default (all-zero) values.
///
/// # Panics
///
/// Panics if the rational trace lock is poisoned.
pub fn reset_rational_stats() {
    let mut stats = rational_stats()
        .lock()
        .expect("rational trace lock poisoned");
    *stats = RationalTraceStats::default();
}

/// Returns a copy of the current rational statistics without modifying them.
///
/// # Panics
///
/// Panics if the rational trace lock is poisoned.
pub fn snapshot_rational_stats() -> RationalTraceStats {
    let stats = rational_stats()
        .lock()
        .expect("rational trace lock poisoned");
    let current = *stats;
    current
}

/// Returns the current rational statistics and resets the stored ones to
/// their default values, all under a single lock acquisition.
///
/// # Panics
///
/// Panics if the rational trace lock is poisoned.
pub fn take_rational_stats() -> RationalTraceStats {
    let mut guard = rational_stats()
        .lock()
        .expect("rational trace lock poisoned");
    // Swaps in `RationalTraceStats::default()` and returns what was stored.
    std::mem::take(&mut *guard)
}

pub fn snapshot() -> Vec<DispatchCount> {
    counts()
        .lock()
        .expect("dispatch trace lock poisoned")
        .iter()
        .map(|(key, count)| DispatchCount {
            layer: key.layer,
            operation: key.operation,
            path: key.path,
            count: *count,
        })
        .collect()
}

pub fn take() -> Vec<DispatchCount> {
    let mut counts = counts().lock().expect("dispatch trace lock poisoned");
    let snapshot = counts
        .iter()
        .map(|(key, count)| DispatchCount {
            layer: key.layer,
            operation: key.operation,
            path: key.path,
            count: *count,
        })
        .collect();
    counts.clear();
    snapshot
}

#[cfg(test)]
mod tests {
    use super::*;

    // The trace tables are process-wide, while the test harness runs tests on
    // several threads at once. Without serialization, `reset()` in one test can
    // wipe the data another test has just recorded, before that test reads it
    // back. Every test in this module therefore holds this lock for its whole
    // body.
    static TRACE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires the module-wide test lock.
    ///
    /// A poisoned lock is still handed out: the lock protects no data of its
    /// own, and one failing test should not make the others fail as well.
    fn exclusive_trace_access() -> std::sync::MutexGuard<'static, ()> {
        TRACE_TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn dispatch_trace_records_only_inside_scope() {
        let _exclusive = exclusive_trace_access();

        reset();
        // Recording is off outside a scope, so this call must leave no trace.
        record("real", "sin", "ignored");
        assert!(snapshot().is_empty());

        with_recording(|| {
            record("real", "sin", "path");
            record("real", "sin", "path");
            record("computable", "sin", "other");
        });

        let counts = take();
        assert_eq!(counts.len(), 2);
        assert!(counts.iter().any(|entry| {
            entry.layer == "real"
                && entry.operation == "sin"
                && entry.path == "path"
                && entry.count == 2
        }));
        // `take` must leave the table empty.
        assert!(snapshot().is_empty());
    }

    #[test]
    fn rational_trace_records_reductions_and_gcds() {
        use crate::Rational;

        let _exclusive = exclusive_trace_access();

        reset();
        with_recording(|| {
            let left = Rational::fraction(6, 8).unwrap();
            let right = Rational::fraction(9, 10).unwrap();
            let _ = left + right;
        });

        let stats = take_rational_stats();
        assert!(stats.temporary_rationals > 0);
        assert!(stats.reductions > 0);
        assert!(stats.gcds > 0);
        assert!(stats.peak_operand_bits > 0);
    }
}