use std::cell::Cell;
/// Pair of upper bounds describing how much elementwise work may be fused.
///
/// The value is small and `Copy`, so it is passed and stored by value.
// NOTE(review): the exact unit of a "step" / "input" is defined by whatever
// consumes these limits, which is not visible in this file — confirm there.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FusionLimits {
    /// Upper bound on the number of elementwise steps
    /// (presumably per fused group — verify against the consumer).
    pub max_elementwise_steps: u32,
    /// Upper bound on the number of elementwise inputs
    /// (presumably per fused group — verify against the consumer).
    pub max_elementwise_inputs: u32,
}
impl FusionLimits {
pub const GPU_NATIVE: Self = Self {
max_elementwise_steps: 32,
max_elementwise_inputs: 16,
};
pub const UNBOUNDED: Self = Self {
max_elementwise_steps: u32::MAX,
max_elementwise_inputs: u32::MAX,
};
}
impl Default for FusionLimits {
fn default() -> Self {
Self::GPU_NATIVE
}
}
thread_local! {
static ACTIVE_LIMITS: Cell<FusionLimits> = Cell::new(FusionLimits::default());
}
pub fn active_fusion_limits() -> FusionLimits {
ACTIVE_LIMITS.with(|c| c.get())
}
/// Runs `f` with `limits` installed as the calling thread's active fusion
/// limits, then puts the previously active limits back.
///
/// Scopes nest: an inner call restores whatever the outer call installed.
/// The override is thread-local, so other threads are unaffected.
///
/// The previous limits are restored on *every* exit path, including when `f`
/// panics. Without that, a panic caught further up (e.g. by
/// `std::panic::catch_unwind`, or on a reused worker thread) would leave the
/// override permanently installed on this thread.
///
/// # Parameters
/// - `limits`: the limits that should be active while `f` runs.
/// - `f`: the closure to execute under those limits.
///
/// # Returns
/// Whatever `f` returns.
///
/// # Panics
/// Propagates any panic raised by `f` (after restoring the previous limits).
pub fn with_fusion_limits<T>(limits: FusionLimits, f: impl FnOnce() -> T) -> T {
    /// Drop guard holding the limits to reinstate when the scope ends.
    struct Restore(FusionLimits);

    impl Drop for Restore {
        fn drop(&mut self) {
            // `try_with` rather than `with`: this may run while unwinding or
            // during thread teardown, and panicking inside `drop` in those
            // situations would abort the process. If the thread-local is
            // already gone there is nothing left to restore.
            let _ = ACTIVE_LIMITS.try_with(|c| c.set(self.0));
        }
    }

    // Swap in the new limits and remember the old ones in the guard.
    let _restore = Restore(ACTIVE_LIMITS.with(|c| c.replace(limits)));
    f()
}