use std::sync::OnceLock;
/// Minimum amount of work (in bytes) before parallel execution is considered;
/// 64 KiB. Used by `should_parallelise` whenever the caller passes no explicit
/// threshold.
pub const DEFAULT_PARALLEL_THRESHOLD_BYTES: usize = 1 << 16;
/// Environment variable consulted for the thread budget when a caller asks
/// for `0` threads.
pub const ENV_THREADS: &str = "TENSOGRAM_THREADS";
/// Thread count taken from the `ENV_THREADS` environment variable.
///
/// The variable is read and parsed only on the first call. The result is
/// cached in a process-wide `OnceLock`, so later changes to the environment
/// are not observed. A missing or non-Unicode variable, or a value that does
/// not parse as `u32` after trimming surrounding whitespace, yields `0`.
fn env_threads() -> u32 {
    static PARSED: OnceLock<u32> = OnceLock::new();
    *PARSED.get_or_init(|| match std::env::var(ENV_THREADS) {
        Ok(raw) => raw.trim().parse::<u32>().unwrap_or(0),
        Err(_) => 0,
    })
}
/// Resolves the effective thread budget for an operation.
///
/// A non-zero `requested` value is returned unchanged. `0` means "not
/// specified". In that case the budget comes from `env_threads()`, which is
/// itself `0` when the environment variable is absent or unparsable.
#[inline]
pub(crate) fn resolve_budget(requested: u32) -> u32 {
    match requested {
        0 => env_threads(),
        explicit => explicit,
    }
}
/// Decides whether a workload is large enough to be worth parallelising.
///
/// Always returns `false` when `budget` is `0`. Otherwise it returns `true`
/// exactly when `work_bytes` is at least the threshold. The threshold is
/// `threshold_bytes` when given, and `DEFAULT_PARALLEL_THRESHOLD_BYTES`
/// otherwise. A budget of `1` is not rejected here.
#[inline]
pub(crate) fn should_parallelise(
    budget: u32,
    work_bytes: usize,
    threshold_bytes: Option<usize>,
) -> bool {
    budget != 0 && work_bytes >= threshold_bytes.unwrap_or(DEFAULT_PARALLEL_THRESHOLD_BYTES)
}
/// Reports whether an object's pipeline contains a stage treated as
/// "axis-B friendly".
///
/// Any one of the following is sufficient:
/// * compression is `"blosc2"` or `"zstd"`;
/// * encoding is `"simple_packing"`;
/// * filter is `"shuffle"`.
///
/// Matching is exact, case-sensitive string equality.
///
/// NOTE(review): presumably "axis B" means parallelism inside a codec, as
/// opposed to across objects. TODO confirm against the caller.
#[inline]
pub(crate) fn is_axis_b_friendly(encoding: &str, filter: &str, compression: &str) -> bool {
    let compression_hit = compression == "blosc2" || compression == "zstd";
    let encoding_hit = encoding == "simple_packing";
    let filter_hit = filter == "shuffle";
    compression_hit || encoding_hit || filter_hit
}
/// Decides whether the "axis A" parallelisation strategy should be used.
///
/// Returns `true` only when all of these hold:
/// * there is more than one thread of budget;
/// * there is more than one object;
/// * no object was reported as axis-B friendly.
///
/// NOTE(review): axis A presumably means parallelising across objects.
/// TODO confirm.
#[inline]
pub(crate) fn use_axis_a(
    n_objects: usize,
    budget: u32,
    any_object_axis_b_friendly: bool,
) -> bool {
    let room_to_fan_out = budget > 1 && n_objects > 1;
    room_to_fan_out && !any_object_axis_b_friendly
}
/// Runs `f`, inside a dedicated rayon thread pool of `budget` workers when
/// `budget > 1`, and returns its result.
///
/// * `budget <= 1`: `f` runs directly on the calling thread.
/// * With the `threads` feature, a new pool is built for this call only. Its
///   threads are named `tensogram-worker-{i}` and the pool is not reused
///   across calls. `f` is executed via `ThreadPool::install`, so rayon
///   parallel work started inside `f` uses that pool. If building the pool
///   fails, a one-time warning is logged and `f` runs on the calling thread.
/// * Without the `threads` feature, a one-time warning is logged and `f`
///   runs on the calling thread.
#[inline]
pub(crate) fn with_pool<F, R>(budget: u32, f: F) -> R
where
    F: FnOnce() -> R + Send,
    R: Send,
{
    if budget > 1 {
        #[cfg(feature = "threads")]
        {
            let built = rayon::ThreadPoolBuilder::new()
                .num_threads(budget as usize)
                .thread_name(|i| format!("tensogram-worker-{i}"))
                .build();
            // Every pooled path ends here, so the trailing `f()` below only
            // serves the inline / fallback cases.
            return match built {
                Ok(pool) => pool.install(f),
                Err(err) => {
                    warn_pool_build_failure(&err.to_string());
                    f()
                }
            };
        }
        #[cfg(not(feature = "threads"))]
        warn_threads_feature_disabled();
    }
    f()
}
/// Runs `f`, routing it through `with_pool(budget, f)` only when the caller
/// chose to parallelise (`parallel`) and more than one intra-codec thread is
/// requested. Otherwise `f` runs directly on the calling thread.
///
/// The pool, if any, is sized by `budget`. `intra_codec_threads` only gates
/// whether a pool is used.
#[inline]
pub(crate) fn run_maybe_pooled<F, R>(
    budget: u32,
    parallel: bool,
    intra_codec_threads: u32,
    f: F,
) -> R
where
    F: FnOnce() -> R + Send,
    R: Send,
{
    let wants_pool = parallel && intra_codec_threads > 1;
    if !wants_pool {
        return f();
    }
    with_pool(budget, f)
}
/// Logs a warning that the rayon pool could not be built.
///
/// The warning is emitted at most once per process. A static `OnceLock` runs
/// its initializer a single time, so only the first caller's `msg` is
/// recorded, as the `error` field.
#[cfg(feature = "threads")]
fn warn_pool_build_failure(msg: &str) {
    static ALREADY_LOGGED: OnceLock<()> = OnceLock::new();
    let _ = ALREADY_LOGGED.get_or_init(|| {
        tracing::warn!(
            error = msg,
            "failed to build rayon thread pool; falling back to sequential execution"
        );
    });
}
/// Logs a warning that more than one thread was requested while the
/// `threads` cargo feature is compiled out.
///
/// The warning is emitted at most once per process, guarded by a static
/// `OnceLock` whose initializer runs a single time.
#[cfg(not(feature = "threads"))]
fn warn_threads_feature_disabled() {
    static ALREADY_LOGGED: OnceLock<()> = OnceLock::new();
    let _ = ALREADY_LOGGED.get_or_init(|| {
        tracing::warn!(
            "threads > 1 requested but the 'threads' cargo feature is disabled; \
             falling back to sequential execution"
        );
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolve_zero_with_no_env_returns_zero() {
        // The env lookup is cached process-wide (OnceLock), so we cannot reset
        // it between tests; only assert when the variable is unset.
        if std::env::var(ENV_THREADS).is_err() {
            assert_eq!(resolve_budget(0), 0);
        }
    }

    #[test]
    fn resolve_nonzero_ignores_env() {
        assert_eq!(resolve_budget(4), 4);
    }

    #[test]
    fn should_parallelise_below_threshold() {
        assert!(!should_parallelise(8, 1024, None));
        assert!(!should_parallelise(0, usize::MAX, None));
    }

    #[test]
    fn should_parallelise_above_threshold() {
        assert!(should_parallelise(2, 1024 * 1024, None));
    }

    #[test]
    fn should_parallelise_custom_threshold() {
        assert!(should_parallelise(2, 1, Some(0)));
        assert!(!should_parallelise(2, 1024 * 1024, Some(usize::MAX)));
    }

    #[test]
    fn should_parallelise_boundary_values() {
        assert!(should_parallelise(
            2,
            DEFAULT_PARALLEL_THRESHOLD_BYTES,
            None
        ));
        assert!(!should_parallelise(
            2,
            DEFAULT_PARALLEL_THRESHOLD_BYTES - 1,
            None
        ));
        assert!(!should_parallelise(0, usize::MAX, Some(0)));
        assert!(should_parallelise(1, 0, Some(0)));
        assert!(!should_parallelise(
            u32::MAX,
            usize::MAX - 1,
            Some(usize::MAX)
        ));
    }

    #[test]
    fn with_pool_budget_zero_runs_inline() {
        let result = with_pool(0, || 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn with_pool_budget_one_runs_inline() {
        let result = with_pool(1, || 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn is_axis_b_friendly_reports_known_codecs() {
        assert!(is_axis_b_friendly("none", "none", "blosc2"));
        assert!(is_axis_b_friendly("none", "none", "zstd"));
        assert!(is_axis_b_friendly("simple_packing", "none", "none"));
        assert!(is_axis_b_friendly("none", "shuffle", "none"));
        assert!(!is_axis_b_friendly("none", "none", "none"));
        assert!(!is_axis_b_friendly("none", "none", "lz4"));
        assert!(!is_axis_b_friendly("none", "none", "szip"));
        assert!(!is_axis_b_friendly("none", "none", "zfp"));
    }

    #[test]
    fn use_axis_a_single_object_is_always_false() {
        assert!(!use_axis_a(1, 8, false));
        assert!(!use_axis_a(1, 8, true));
        assert!(!use_axis_a(0, 8, false));
    }

    #[test]
    fn use_axis_a_low_budget_is_always_false() {
        assert!(!use_axis_a(10, 0, false));
        assert!(!use_axis_a(10, 1, false));
    }

    #[test]
    fn use_axis_a_multi_object_b_friendly_prefers_b() {
        assert!(!use_axis_a(10, 4, true));
    }

    #[test]
    fn use_axis_a_multi_object_non_b_friendly_uses_a() {
        assert!(use_axis_a(10, 4, false));
    }

    #[test]
    fn run_maybe_pooled_no_budget_runs_inline() {
        assert_eq!(run_maybe_pooled(0, false, 0, || 7), 7);
        assert_eq!(run_maybe_pooled(4, false, 0, || 7), 7);
        assert_eq!(run_maybe_pooled(4, true, 0, || 7), 7);
        assert_eq!(run_maybe_pooled(4, true, 1, || 7), 7);
    }

    /// Without the `threads` feature, large budgets must still run the closure,
    /// and it must run on the calling thread.
    #[cfg(not(feature = "threads"))]
    #[test]
    fn with_pool_without_threads_feature_runs_inline() {
        let caller = std::thread::current().id();
        assert_eq!(with_pool(4, || std::thread::current().id()), caller);
        assert_eq!(run_maybe_pooled(4, true, 4, || 7), 7);
    }

    #[cfg(feature = "threads")]
    #[test]
    fn run_maybe_pooled_with_budget_installs_pool() {
        let observed = run_maybe_pooled(4, true, 4, rayon::current_num_threads);
        assert_eq!(observed, 4);
    }

    #[cfg(feature = "threads")]
    #[test]
    fn with_pool_budget_four_uses_pool() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let counter = AtomicUsize::new(0);
        let (sum, pool_size) = with_pool(4, || {
            use rayon::prelude::*;
            let sum = (0..1000u64)
                .into_par_iter()
                .map(|i| {
                    counter.fetch_add(1, Ordering::Relaxed);
                    i as usize
                })
                .sum::<usize>();
            // Verifies the work really ran inside the installed 4-thread pool
            // rather than the global one.
            (sum, rayon::current_num_threads())
        });
        assert_eq!(pool_size, 4);
        assert_eq!(sum, (0..1000).sum::<u64>() as usize);
        assert_eq!(counter.load(Ordering::Relaxed), 1000);
    }

    #[cfg(feature = "threads")]
    #[test]
    fn with_pool_runs_on_named_worker_threads() {
        let name = with_pool(2, || std::thread::current().name().map(str::to_owned))
            .expect("pool worker threads are named");
        assert!(
            name.starts_with("tensogram-worker-"),
            "unexpected worker thread name: {name}"
        );
    }
}