use std::cell::RefCell;
/// Determinism requirement carried by a [`RuntimePolicy`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Determinism {
    /// Serialized as `"strict"`; every thermal preset starts from this value.
    Strict,
    /// Serialized as `"relaxed"`; only reachable through an explicit override.
    Relaxed,
}

impl Determinism {
    /// Canonical lower-case name of this setting.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Strict => "strict",
            Self::Relaxed => "relaxed",
        }
    }

    /// Parses the canonical name produced by [`Determinism::as_str`].
    ///
    /// Matching is exact and case-sensitive; anything else yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        [Self::Strict, Self::Relaxed]
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == s)
    }
}
/// Numeric mode carried by a [`RuntimePolicy`].
///
/// NOTE(review): the variant names suggest different floating-point
/// accumulation strategies; the code that consumes this setting is not
/// visible here — confirm before relying on that reading.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NumericMode {
    /// Serialized as `"kahan"`; the mode used by every thermal preset.
    Kahan,
    /// Serialized as `"binned"`.
    Binned,
    /// Serialized as `"fixed-tree"`; parsing also accepts `"fixedtree"` and
    /// `"fixed_tree"`.
    FixedTree,
}

impl NumericMode {
    /// Canonical lower-case name of this mode.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Kahan => "kahan",
            Self::Binned => "binned",
            Self::FixedTree => "fixed-tree",
        }
    }

    /// Parses a mode name, accepting the documented `FixedTree` aliases.
    ///
    /// Matching is exact and case-sensitive; unknown names yield `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        let mode = match s {
            "kahan" => Self::Kahan,
            "binned" => Self::Binned,
            "fixed-tree" | "fixedtree" | "fixed_tree" => Self::FixedTree,
            _ => return None,
        };
        Some(mode)
    }
}
/// Audit level carried by a [`RuntimePolicy`].
///
/// NOTE(review): what each level records is decided by consumers that are
/// not visible here.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AuditMode {
    /// Serialized as `"summary"`; preset for the `Cool` and `MaxPerf` profiles.
    Summary,
    /// Serialized as `"full"`; preset for the `Balanced` profile.
    Full,
    /// Serialized as `"forensic"`; no thermal preset selects it, so it must
    /// be set explicitly.
    Forensic,
}

impl AuditMode {
    /// Every variant, used for name lookup in [`AuditMode::from_str`].
    const ALL: [AuditMode; 3] = [AuditMode::Summary, AuditMode::Full, AuditMode::Forensic];

    /// Canonical lower-case name of this level.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Summary => "summary",
            Self::Full => "full",
            Self::Forensic => "forensic",
        }
    }

    /// Parses the canonical name produced by [`AuditMode::as_str`].
    ///
    /// Matching is exact and case-sensitive; anything else yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        Self::ALL.iter().copied().find(|level| level.as_str() == s)
    }
}
/// Thermal/performance profile of the runtime.
///
/// A profile supplies the default batch size and audit mode (see
/// [`ThermalMode::preset_batch_size`] and [`ThermalMode::preset_audit_mode`]).
/// `effective_threads` also derives the worker count from it when no
/// explicit thread cap is set.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ThermalMode {
    /// Serialized as `"cool"`; smallest batch preset.
    Cool,
    /// Serialized as `"balanced"`; the profile behind `RuntimePolicy::default()`.
    Balanced,
    /// Serialized as `"max-perf"`; parsing also accepts `"maxperf"` and
    /// `"max_perf"`.
    MaxPerf,
}

impl ThermalMode {
    /// Canonical lower-case name of this profile.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Cool => "cool",
            Self::Balanced => "balanced",
            Self::MaxPerf => "max-perf",
        }
    }

    /// Parses a profile name, accepting the documented `MaxPerf` aliases.
    ///
    /// Matching is exact and case-sensitive; unknown names yield `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        let mode = match s {
            "cool" => Self::Cool,
            "balanced" => Self::Balanced,
            "max-perf" | "maxperf" | "max_perf" => Self::MaxPerf,
            _ => return None,
        };
        Some(mode)
    }

    /// Default batch size for this profile: 32 / 128 / 512.
    pub fn preset_batch_size(self) -> usize {
        self.presets().0
    }

    /// Default audit mode for this profile; only `Balanced` defaults to `Full`.
    pub fn preset_audit_mode(self) -> AuditMode {
        self.presets().1
    }

    /// Single table of per-profile presets as `(batch_size, audit_mode)`.
    fn presets(self) -> (usize, AuditMode) {
        match self {
            Self::Cool => (32, AuditMode::Summary),
            Self::Balanced => (128, AuditMode::Full),
            Self::MaxPerf => (512, AuditMode::Summary),
        }
    }
}
/// Modelled energy cost of one floating-point operation, in joules.
pub const ENERGY_PER_FLOP_JOULES: f64 = 1.0e-10;
/// Modelled energy cost per byte, in joules.
pub const ENERGY_PER_BYTE_JOULES: f64 = 1.0e-10;

/// Linear energy estimate, in joules, for `flops` operations plus `bytes`
/// bytes, using the per-unit constants above.
///
/// Negative counts are clamped to zero, so the result is never negative.
pub fn energy_estimate_joules(flops: i64, bytes: i64) -> f64 {
    let non_negative = |count: i64| count.max(0) as f64;
    non_negative(flops) * ENERGY_PER_FLOP_JOULES + non_negative(bytes) * ENERGY_PER_BYTE_JOULES
}
/// Complete runtime configuration: thermal profile, numeric/determinism
/// settings, thread cap, batch size, audit level and adaptive throttling.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RuntimePolicy {
    /// Determinism requirement; `Strict` in every thermal preset.
    pub determinism: Determinism,
    /// Numeric mode; `Kahan` in every thermal preset.
    pub numeric_mode: NumericMode,
    /// Thermal profile the presets were taken from.
    pub thermal_mode: ThermalMode,
    /// Explicit worker-thread cap; `0` lets `effective_threads` derive the
    /// count from `thermal_mode`.
    pub max_threads: usize,
    /// Batch size; preset from `thermal_mode` unless overridden.
    pub batch_size: usize,
    /// Audit level; preset from `thermal_mode` unless overridden.
    pub audit_mode: AuditMode,
    /// Adaptive (race-to-idle) throttling: apply the thread cap only once a
    /// burst of parallel work is sustained. On in every preset.
    pub adaptive: bool,
}

impl RuntimePolicy {
    /// Builds the preset policy for `mode`.
    ///
    /// Batch size and audit level come from the profile; determinism is
    /// `Strict`, numerics `Kahan`, there is no explicit thread cap and
    /// adaptive throttling is enabled.
    pub fn for_thermal_mode(mode: ThermalMode) -> Self {
        let batch_size = mode.preset_batch_size();
        let audit_mode = mode.preset_audit_mode();
        Self {
            determinism: Determinism::Strict,
            numeric_mode: NumericMode::Kahan,
            thermal_mode: mode,
            max_threads: 0,
            batch_size,
            audit_mode,
            adaptive: true,
        }
    }

    /// One-line, human-readable description of every setting.
    ///
    /// `threads=` shows the worker count resolved for this machine (via
    /// `effective_threads`), not the raw `max_threads` field.
    pub fn summary(&self) -> String {
        let threads = effective_threads(self, detect_cores());
        format!(
            "runtime_policy: thermal={} threads={} batch={} audit={} numeric={} determinism={} adaptive={}",
            self.thermal_mode.as_str(),
            threads,
            self.batch_size,
            self.audit_mode.as_str(),
            self.numeric_mode.as_str(),
            self.determinism.as_str(),
            self.adaptive,
        )
    }
}

impl Default for RuntimePolicy {
    /// The `Balanced` preset.
    fn default() -> Self {
        RuntimePolicy::for_thermal_mode(ThermalMode::Balanced)
    }
}
/// Number of threads the standard library estimates this process can run in
/// parallel, or `1` when that estimate is unavailable.
///
/// `available_parallelism` is not cached, so this re-queries the OS on every
/// call.
pub fn detect_cores() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
/// Resolves how many worker threads `policy` allows on a machine reporting
/// `detected_cores` cores (a report of `0` is treated as `1`).
///
/// A non-zero `max_threads` wins but is clamped to the core count. Otherwise
/// the thermal profile picks a quarter (`Cool`), half (`Balanced`) or all
/// (`MaxPerf`) of the cores. The result is always in `1..=cores`.
pub fn effective_threads(policy: &RuntimePolicy, detected_cores: usize) -> usize {
    let cores = detected_cores.max(1);
    if policy.max_threads == 0 {
        let divisor = match policy.thermal_mode {
            ThermalMode::Cool => 4,
            ThermalMode::Balanced => 2,
            ThermalMode::MaxPerf => 1,
        };
        return (cores / divisor).max(1);
    }
    policy.max_threads.min(cores)
}
/// Ensures a rayon pool capped at `n` threads exists and reports how many
/// worker threads work routed through that cap will use.
///
/// When `n` is `0` (no cap) or not below the detected core count, no capped
/// pool is needed and the current rayon registry's thread count is returned.
/// The same fallback applies if the capped pool cannot be built.
#[cfg(feature = "parallel")]
pub fn apply_thread_cap(n: usize) -> usize {
    let full = detect_cores();
    if n > 0 && n < full {
        if let Some(pool) = capped_pool(n) {
            // Outside a pool, `rayon::current_num_threads()` describes the
            // uncapped global pool, so report the capped pool's own size.
            return pool.current_num_threads();
        }
    }
    rayon::current_num_threads()
}

/// Without the `parallel` feature all work runs on the calling thread, so
/// the cap is irrelevant and the effective thread count is always `1`.
#[cfg(not(feature = "parallel"))]
pub fn apply_thread_cap(_n: usize) -> usize {
    1
}
/// How long a burst of closely spaced parallel calls must last before it
/// counts as sustained and the thread cap is enforced.
#[cfg(feature = "parallel")]
const SUSTAIN_WINDOW: std::time::Duration = std::time::Duration::from_secs(2);
/// A gap between calls longer than this ends the current burst.
#[cfg(feature = "parallel")]
const IDLE_RESET: std::time::Duration = std::time::Duration::from_millis(500);

/// Burst-tracking bookkeeping for adaptive throttling.
#[cfg(feature = "parallel")]
#[derive(Default)]
struct AdaptiveState {
    /// When the current burst began, if one is in progress.
    burst_start: Option<std::time::Instant>,
    /// Time of the most recent recorded call.
    last_op: Option<std::time::Instant>,
}

#[cfg(feature = "parallel")]
thread_local! {
    /// Per-thread adaptive state; starts with no burst recorded.
    static ADAPTIVE: RefCell<AdaptiveState> = RefCell::default();
}
/// Records a call at `now` and reports whether the current burst has lasted
/// at least `window`.
///
/// If more than `idle` has passed since the previous recorded call, the old
/// burst is discarded and a new one starts at `now`. A call that starts a
/// burst is therefore never sustained unless `window` is zero.
#[cfg(feature = "parallel")]
fn decide_sustained(
    state: &mut AdaptiveState,
    now: std::time::Instant,
    window: std::time::Duration,
    idle: std::time::Duration,
) -> bool {
    let went_idle = state
        .last_op
        .is_some_and(|previous| now.duration_since(previous) > idle);
    state.last_op = Some(now);
    let burst_began = match state.burst_start {
        Some(began) if !went_idle => began,
        _ => {
            state.burst_start = Some(now);
            now
        }
    };
    now.duration_since(burst_began) >= window
}
/// Records a call on this thread at the current instant and reports whether
/// the ongoing burst has become sustained, using the module's
/// `SUSTAIN_WINDOW` and `IDLE_RESET` settings.
#[cfg(feature = "parallel")]
fn is_sustained_now() -> bool {
    let now = std::time::Instant::now();
    ADAPTIVE.with(|cell| {
        let mut state = cell.borrow_mut();
        decide_sustained(&mut state, now, SUSTAIN_WINDOW, IDLE_RESET)
    })
}
/// Forgets any in-progress burst on the calling thread, so the next parallel
/// call starts a fresh one.
#[cfg(feature = "parallel")]
fn reset_adaptive_state() {
    // `take` swaps in `AdaptiveState::default()`; the old state is discarded.
    ADAPTIVE.with(|cell| cell.take());
}

/// Without the `parallel` feature there is no adaptive state to clear.
#[cfg(not(feature = "parallel"))]
fn reset_adaptive_state() {
    // Intentionally empty.
}
/// Capped rayon pools, one per distinct thread count.
///
/// Keying by cap matters because the cap can change at runtime (thermal
/// mode or `set_threads`). A single cached pool would keep serving whatever
/// size was requested first. Pools are leaked to obtain `&'static`
/// references; callers only request caps below the detected core count, so
/// only a bounded number can ever be created. A failed build is remembered
/// as `None` so it is not retried on every call.
#[cfg(feature = "parallel")]
static CAPPED_POOLS: std::sync::Mutex<
    std::collections::BTreeMap<usize, Option<&'static rayon::ThreadPool>>,
> = std::sync::Mutex::new(std::collections::BTreeMap::new());

/// Returns the shared rayon pool with exactly `cap` worker threads, building
/// it on first request, or `None` if such a pool could not be built.
#[cfg(feature = "parallel")]
fn capped_pool(cap: usize) -> Option<&'static rayon::ThreadPool> {
    // A poisoned lock only means another thread panicked while holding it;
    // the map itself is still consistent, so keep using it.
    let mut pools = CAPPED_POOLS
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let slot = pools.entry(cap).or_insert_with(|| {
        rayon::ThreadPoolBuilder::new()
            .num_threads(cap)
            .build()
            .ok()
            .map(|pool| -> &'static rayon::ThreadPool { Box::leak(Box::new(pool)) })
    });
    *slot
}
/// Runs `work` under the calling thread's thread-cap policy and returns its
/// result.
///
/// `work` runs inline when any of these holds:
/// - the caller is already on a rayon worker;
/// - the policy's resolved thread count is not below the core count;
/// - adaptive mode is on and the current burst is not yet sustained.
///
/// Otherwise it is installed into the capped pool, so rayon operations
/// inside `work` use that pool. If the pool is unavailable, `work` runs
/// inline instead.
#[cfg(feature = "parallel")]
pub fn run_parallel<R, F>(work: F) -> R
where
    R: Send,
    F: FnOnce() -> R + Send,
{
    // Already on a rayon worker (e.g. a nested call): run inline rather than
    // re-entering a pool.
    if rayon::current_thread_index().is_some() {
        return work();
    }
    let policy = get();
    // Query once: `available_parallelism` is recomputed on every call (it may
    // consult the OS), and two readings could disagree with each other.
    let cores = detect_cores();
    let cap = effective_threads(&policy, cores);
    if cap >= cores {
        // No effective cap — run at full speed on the global pool.
        return work();
    }
    // Non-adaptive policies throttle immediately; adaptive ones only once the
    // burst is sustained (`is_sustained_now` is not called otherwise).
    let throttle = if policy.adaptive { is_sustained_now() } else { true };
    if !throttle {
        return work();
    }
    match capped_pool(cap) {
        Some(pool) => pool.install(work),
        None => work(),
    }
}

/// Without the `parallel` feature, `work` always runs inline on the calling
/// thread.
#[cfg(not(feature = "parallel"))]
pub fn run_parallel<R, F>(work: F) -> R
where
    F: FnOnce() -> R,
{
    work()
}
thread_local! {
    /// The calling thread's runtime policy; each thread starts from
    /// `RuntimePolicy::default()`.
    pub(crate) static POLICY: RefCell<RuntimePolicy> = RefCell::new(RuntimePolicy::default());
}

/// Returns a copy of the calling thread's current runtime policy.
pub fn get() -> RuntimePolicy {
    POLICY.with(|cell| {
        let policy = cell.borrow();
        *policy
    })
}
/// Restores the calling thread's policy to `RuntimePolicy::default()` and
/// clears its adaptive burst tracking.
pub fn reset() {
    let defaults = RuntimePolicy::default();
    POLICY.with(|cell| cell.replace(defaults));
    reset_adaptive_state();
}
pub fn set_thermal_mode(mode: ThermalMode) {
POLICY.with(|c| {
let mut p = c.borrow_mut();
p.thermal_mode = mode;
p.batch_size = mode.preset_batch_size();
p.audit_mode = mode.preset_audit_mode();
p.max_threads = 0;
});
}
pub fn set_threads(n: usize) {
POLICY.with(|c| c.borrow_mut().max_threads = n);
}
pub fn set_batch_size(n: usize) {
POLICY.with(|c| c.borrow_mut().batch_size = n);
}
pub fn set_audit_mode(mode: AuditMode) {
POLICY.with(|c| c.borrow_mut().audit_mode = mode);
}
pub fn set_numeric_mode(mode: NumericMode) {
POLICY.with(|c| c.borrow_mut().numeric_mode = mode);
}
pub fn set_determinism(d: Determinism) {
POLICY.with(|c| c.borrow_mut().determinism = d);
}
pub fn set_adaptive(on: bool) {
POLICY.with(|c| c.borrow_mut().adaptive = on);
}
/// Worker-thread count the calling thread's current policy resolves to on
/// this machine.
pub fn current_effective_threads() -> usize {
    let policy = get();
    let cores = detect_cores();
    effective_threads(&policy, cores)
}
// Unit tests for the runtime-policy module. The policy and adaptive state are
// thread-local; tests that mutate them call reset() first and last so they do
// not depend on, or leak into, whatever ran earlier on the same thread.
#[cfg(test)]
mod tests {
use super::*;
// The default policy is the Balanced preset: strict determinism, Kahan
// numerics, no explicit thread cap, and adaptive throttling enabled.
#[test]
fn default_is_balanced() {
let p = RuntimePolicy::default();
assert_eq!(p.thermal_mode, ThermalMode::Balanced);
assert_eq!(p.determinism, Determinism::Strict);
assert_eq!(p.numeric_mode, NumericMode::Kahan);
assert_eq!(p.max_threads, 0);
assert_eq!(p.batch_size, 128);
assert_eq!(p.audit_mode, AuditMode::Full);
assert!(p.adaptive, "adaptive (race-to-idle) is on by default");
}
// set_adaptive toggles the flag and reset() restores the default (on).
#[test]
fn set_adaptive_round_trip() {
reset();
assert!(get().adaptive);
set_adaptive(false);
assert!(!get().adaptive);
reset();
assert!(get().adaptive, "reset restores adaptive default");
}
// Calls spaced 300ms apart (below the 500ms idle limit) form one burst. The
// burst only becomes "sustained" once 2000ms have passed since its first call.
#[cfg(feature = "parallel")]
#[test]
fn decide_sustained_burst_then_throttle() {
use std::time::{Duration, Instant};
let mut st = AdaptiveState::default();
let t0 = Instant::now();
let win = Duration::from_millis(2000);
let idle = Duration::from_millis(500);
let at = |ms: u64| t0 + Duration::from_millis(ms);
assert!(!decide_sustained(&mut st, t0, win, idle), "burst starts");
let mut ms = 300;
while ms < 2000 {
assert!(!decide_sustained(&mut st, at(ms), win, idle), "still burst at {ms}ms");
ms += 300;
}
assert!(decide_sustained(&mut st, at(2100), win, idle), "sustained past window");
}
// Once a burst is sustained, an idle gap longer than the limit starts a fresh
// burst (2400ms -> 3000ms is 600ms > 500ms), so that call is not sustained.
#[cfg(feature = "parallel")]
#[test]
fn decide_sustained_idle_resets_burst() {
use std::time::{Duration, Instant};
let mut st = AdaptiveState::default();
let t0 = Instant::now();
let win = Duration::from_millis(2000);
let idle = Duration::from_millis(500);
let at = |ms: u64| t0 + Duration::from_millis(ms);
assert!(!decide_sustained(&mut st, t0, win, idle));
// Drive the burst with 300ms gaps up to 2100ms, past the 2000ms window.
let mut ms = 300;
while ms <= 2100 {
decide_sustained(&mut st, at(ms), win, idle);
ms += 300;
}
assert!(decide_sustained(&mut st, at(2400), win, idle), "is sustained");
assert!(
!decide_sustained(&mut st, at(3000), win, idle),
"idle gap should reset the burst → full speed again"
);
}
// The same work must yield the same result whether it runs in the capped
// pool or inline. With adaptive off, run_parallel throttles whenever the Cool
// cap is below the core count. With adaptive on and fresh burst state, the
// first call is not sustained and runs inline.
#[cfg(feature = "parallel")]
#[test]
fn run_parallel_preserves_value_under_throttle() {
reset();
set_thermal_mode(ThermalMode::Cool);
set_adaptive(false); let throttled: f64 = run_parallel(|| (0..1000).map(|i| i as f64).sum());
set_adaptive(true);
reset_adaptive_state(); let burst: f64 = run_parallel(|| (0..1000).map(|i| i as f64).sum());
assert_eq!(throttled, burst);
reset();
}
// Pins each thermal profile's batch-size and audit-mode presets.
#[test]
fn thermal_presets_distinct() {
assert_eq!(ThermalMode::Cool.preset_batch_size(), 32);
assert_eq!(ThermalMode::Balanced.preset_batch_size(), 128);
assert_eq!(ThermalMode::MaxPerf.preset_batch_size(), 512);
assert_eq!(ThermalMode::Cool.preset_audit_mode(), AuditMode::Summary);
assert_eq!(ThermalMode::Balanced.preset_audit_mode(), AuditMode::Full);
assert_eq!(ThermalMode::MaxPerf.preset_audit_mode(), AuditMode::Summary);
}
// On 8 cores the profiles resolve to 2 / 4 / 8 threads, and the count never
// decreases from Cool to MaxPerf.
#[test]
fn effective_threads_monotonic_in_thermal_mode() {
let cores = 8;
let cool = effective_threads(&RuntimePolicy::for_thermal_mode(ThermalMode::Cool), cores);
let bal = effective_threads(&RuntimePolicy::for_thermal_mode(ThermalMode::Balanced), cores);
let max = effective_threads(&RuntimePolicy::for_thermal_mode(ThermalMode::MaxPerf), cores);
assert!(cool <= bal, "cool {cool} should not exceed balanced {bal}");
assert!(bal <= max, "balanced {bal} should not exceed max-perf {max}");
assert_eq!(cool, 2);
assert_eq!(bal, 4);
assert_eq!(max, 8);
}
// For any core count, every profile resolves to between 1 and `cores` threads.
#[test]
fn effective_threads_always_in_range() {
for cores in [1usize, 2, 3, 7, 16, 64] {
for mode in [ThermalMode::Cool, ThermalMode::Balanced, ThermalMode::MaxPerf] {
let t = effective_threads(&RuntimePolicy::for_thermal_mode(mode), cores);
assert!(t >= 1, "threads must be >= 1 (cores={cores}, mode={mode:?})");
assert!(t <= cores, "threads {t} must be <= cores {cores}");
}
}
}
// A non-zero max_threads overrides the profile but never exceeds the cores.
#[test]
fn explicit_thread_cap_wins_and_clamps() {
let mut p = RuntimePolicy::for_thermal_mode(ThermalMode::MaxPerf);
p.max_threads = 3;
assert_eq!(effective_threads(&p, 8), 3, "explicit cap should be honored");
p.max_threads = 100;
assert_eq!(effective_threads(&p, 8), 8, "cap clamps to detected cores");
}
// A detected core count of 0 is treated as 1.
#[test]
fn effective_threads_zero_cores_is_one() {
let p = RuntimePolicy::default();
assert_eq!(effective_threads(&p, 0), 1);
}
// Zero work costs zero energy, and negative counts are clamped to zero.
#[test]
fn energy_is_non_negative_and_zero_at_zero() {
assert_eq!(energy_estimate_joules(0, 0), 0.0);
assert!(energy_estimate_joules(-5, -7) >= 0.0);
assert_eq!(energy_estimate_joules(-5, -7), 0.0, "negatives clamp to zero");
}
// Increasing either the flop count or the byte count increases the estimate.
#[test]
fn energy_is_monotonic() {
let a = energy_estimate_joules(1000, 1000);
let b = energy_estimate_joules(2000, 1000);
let c = energy_estimate_joules(2000, 2000);
assert!(b > a, "more flops => more energy");
assert!(c > b, "more bytes => more energy");
}
// The estimate is the exact sum of the flop and byte contributions.
#[test]
fn energy_is_additive_in_components() {
let flop_only = energy_estimate_joules(1_000_000, 0);
let byte_only = energy_estimate_joules(0, 1_000_000);
let both = energy_estimate_joules(1_000_000, 1_000_000);
assert_eq!(both, flop_only + byte_only);
}
// Repeated calls with the same inputs return bit-identical results.
#[test]
fn energy_is_deterministic_across_calls() {
let first = energy_estimate_joules(123_456, 789);
for _ in 0..1000 {
assert_eq!(energy_estimate_joules(123_456, 789), first);
}
}
// For every variant of every mode enum, as_str and from_str are inverses.
#[test]
fn enum_round_trips() {
for m in [ThermalMode::Cool, ThermalMode::Balanced, ThermalMode::MaxPerf] {
assert_eq!(ThermalMode::from_str(m.as_str()), Some(m));
}
for m in [NumericMode::Kahan, NumericMode::Binned, NumericMode::FixedTree] {
assert_eq!(NumericMode::from_str(m.as_str()), Some(m));
}
for m in [AuditMode::Summary, AuditMode::Full, AuditMode::Forensic] {
assert_eq!(AuditMode::from_str(m.as_str()), Some(m));
}
for m in [Determinism::Strict, Determinism::Relaxed] {
assert_eq!(Determinism::from_str(m.as_str()), Some(m));
}
}
// Unknown names (including the empty string) do not parse.
#[test]
fn invalid_mode_strings_return_none() {
assert_eq!(ThermalMode::from_str("blazing"), None);
assert_eq!(NumericMode::from_str(""), None);
assert_eq!(AuditMode::from_str("paranoid"), None);
assert_eq!(Determinism::from_str("yolo"), None);
}
// set_thermal_mode applies the profile presets, the individual setters
// override fields, and reset() restores the default policy.
#[test]
fn set_get_round_trip_and_reset() {
reset();
set_thermal_mode(ThermalMode::Cool);
let p = get();
assert_eq!(p.thermal_mode, ThermalMode::Cool);
assert_eq!(p.batch_size, 32);
assert_eq!(p.audit_mode, AuditMode::Summary);
set_threads(2);
set_batch_size(64);
set_audit_mode(AuditMode::Forensic);
set_numeric_mode(NumericMode::Binned);
let p = get();
assert_eq!(p.max_threads, 2);
assert_eq!(p.batch_size, 64);
assert_eq!(p.audit_mode, AuditMode::Forensic);
assert_eq!(p.numeric_mode, NumericMode::Binned);
reset();
assert_eq!(get(), RuntimePolicy::default());
}
// An explicit setter called after set_thermal_mode wins over the preset
// without changing the thermal mode itself.
#[test]
fn profile_then_override_precedence() {
reset();
set_thermal_mode(ThermalMode::MaxPerf);
assert_eq!(get().batch_size, 512);
set_batch_size(16);
assert_eq!(get().batch_size, 16, "explicit override wins over profile preset");
assert_eq!(get().thermal_mode, ThermalMode::MaxPerf, "mode unchanged by batch override");
reset();
}
// apply_thread_cap always reports at least one thread, with or without the
// `parallel` feature.
#[test]
fn apply_thread_cap_never_panics_and_is_positive() {
let n = apply_thread_cap(2);
assert!(n >= 1);
}
// summary() is stable for an unchanged policy and names the default thermal
// mode and determinism.
#[test]
fn summary_is_stable() {
reset();
let s1 = get().summary();
let s2 = get().summary();
assert_eq!(s1, s2);
assert!(s1.contains("thermal=balanced"));
assert!(s1.contains("determinism=strict"));
reset();
}
}