Skip to main content

rlx_fusion/
limits.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Per-backend caps for fused IR (elementwise region chains, etc.).
17
18use std::cell::Cell;
19
/// Upper bounds that fusion passes must respect for a given backend.
///
/// The value is plain data (`Copy`), so it can be passed around and compared
/// freely.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FusionLimits {
    /// Largest number of steps allowed in a single `Op::ElementwiseRegion`
    /// chain (Metal/wgpu/CUDA: 32).
    pub max_elementwise_steps: u32,
    /// Largest number of distinct external inputs a single region may read
    /// (Metal/wgpu/CUDA: 16).
    pub max_elementwise_inputs: u32,
}
28
29impl FusionLimits {
30    /// Caps shared by native elementwise-region kernels today.
31    pub const GPU_NATIVE: Self = Self {
32        max_elementwise_steps: 32,
33        max_elementwise_inputs: 16,
34    };
35
36    /// No practical cap — used when regions are unfused to primitives (CPU).
37    pub const UNBOUNDED: Self = Self {
38        max_elementwise_steps: u32::MAX,
39        max_elementwise_inputs: u32::MAX,
40    };
41}
42
43impl Default for FusionLimits {
44    fn default() -> Self {
45        Self::GPU_NATIVE
46    }
47}
48
49thread_local! {
50    static ACTIVE_LIMITS: Cell<FusionLimits> = Cell::new(FusionLimits::default());
51}
52
53/// Limits used by [`crate::fusion::MarkElementwiseRegions`] during this compile.
54pub fn active_fusion_limits() -> FusionLimits {
55    ACTIVE_LIMITS.with(|c| c.get())
56}
57
58/// Run `f` with `limits` installed for mark/clip passes (single-threaded compile).
59pub fn with_fusion_limits<T>(limits: FusionLimits, f: impl FnOnce() -> T) -> T {
60    ACTIVE_LIMITS.with(|c| {
61        let prev = c.get();
62        c.set(limits);
63        let out = f();
64        c.set(prev);
65        out
66    })
67}