rlx_fusion/limits.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Per-backend limits for fused IR (e.g. the length and external-input count of
//! elementwise-region chains), plus the thread-local slot fusion passes read them from.
17
18use std::cell::Cell;
19
/// Upper bounds that fusion passes must respect for a backend, reflecting
/// hardware and kernel-encoder limits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FusionLimits {
    /// Longest permitted `Op::ElementwiseRegion` step chain (32 on Metal, wgpu, and CUDA).
    pub max_elementwise_steps: u32,
    /// Most distinct external inputs a single region may read (16 on Metal, wgpu, and CUDA).
    pub max_elementwise_inputs: u32,
}
28
29impl FusionLimits {
30 /// Caps shared by native elementwise-region kernels today.
31 pub const GPU_NATIVE: Self = Self {
32 max_elementwise_steps: 32,
33 max_elementwise_inputs: 16,
34 };
35
36 /// No practical cap — used when regions are unfused to primitives (CPU).
37 pub const UNBOUNDED: Self = Self {
38 max_elementwise_steps: u32::MAX,
39 max_elementwise_inputs: u32::MAX,
40 };
41}
42
43impl Default for FusionLimits {
44 fn default() -> Self {
45 Self::GPU_NATIVE
46 }
47}
48
49thread_local! {
50 static ACTIVE_LIMITS: Cell<FusionLimits> = Cell::new(FusionLimits::default());
51}
52
53/// Limits used by [`crate::fusion::MarkElementwiseRegions`] during this compile.
54pub fn active_fusion_limits() -> FusionLimits {
55 ACTIVE_LIMITS.with(|c| c.get())
56}
57
58/// Run `f` with `limits` installed for mark/clip passes (single-threaded compile).
59pub fn with_fusion_limits<T>(limits: FusionLimits, f: impl FnOnce() -> T) -> T {
60 ACTIVE_LIMITS.with(|c| {
61 let prev = c.get();
62 c.set(limits);
63 let out = f();
64 c.set(prev);
65 out
66 })
67}