// rlx_ir/region_encode.rs

// RLX - versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.

//! Shared [`Op::ElementwiseRegion`] metadata encoding for GPU region kernels.

10use crate::op::*;
11use crate::shape::Shape;
12
/// Leading words of the region metadata: one offset per external input.
pub const REGION_META_INPUT_WORDS: usize = 16;
/// Words holding the packed chain: 32 steps of 4 words each (`[kind, sub, lhs, rhs]`).
pub const REGION_META_CHAIN_WORDS: usize = 32 * 4;
/// Trailing words: prologue tag, `n, c, h, w`, and the prologue's input index.
pub const REGION_META_TAIL_WORDS: usize = 6;
/// Total metadata length in `u32` words, laid out as inputs, then chain, then tail.
pub const REGION_META_WORDS: usize = REGION_META_INPUT_WORDS
    + REGION_META_CHAIN_WORDS
    + REGION_META_TAIL_WORDS;
18
/// Upper bound (inclusive) on batch slices that may share one batch-region launch
/// (CUDA/ROCm/Metal/wgpu); larger batches are not dispatched as a single launch.
pub const FK_BATCH_SINGLE_KERNEL_MAX: usize = 64;
21
22/// `RLX_FK_BATCH_SINGLE_KERNEL=1` at compile time.
23pub fn fk_batch_single_kernel_enabled() -> bool {
24    crate::env::flag("RLX_FK_BATCH_SINGLE_KERNEL")
25}
26
27/// Whether `BatchElementwiseRegion` should use one batch-region launch.
28pub fn fk_batch_use_single_launch(num_batch: usize, prologue: RegionPrologue) -> bool {
29    fk_batch_single_kernel_enabled()
30        && prologue == RegionPrologue::None
31        && num_batch <= FK_BATCH_SINGLE_KERNEL_MAX
32}
33
/// Prologue tag (tail word 0) meaning "no prologue"; also the value of an all-zero tail.
pub const REGION_PROLOGUE_NONE: u32 = 0;
/// Prologue tag (tail word 0) for a nearest-neighbour 2x resize over an NCHW output.
pub const REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW: u32 = 1;
36
37/// NCHW output dimensions for prologue kernels (`n,c,h,w`).
38#[derive(Debug, Clone, Copy, PartialEq, Eq)]
39pub struct RegionNchwDims {
40    pub n: u32,
41    pub c: u32,
42    pub h: u32,
43    pub w: u32,
44}
45
46impl RegionNchwDims {
47    pub fn from_shape(shape: &Shape) -> Option<Self> {
48        if shape.rank() != 4 {
49            return None;
50        }
51        Some(Self {
52            n: shape.dim(0).unwrap_static() as u32,
53            c: shape.dim(1).unwrap_static() as u32,
54            h: shape.dim(2).unwrap_static() as u32,
55            w: shape.dim(3).unwrap_static() as u32,
56        })
57    }
58
59    /// Linear element count for an NCHW tensor.
60    pub fn num_elements(self) -> u32 {
61        self.n * self.c * self.h * self.w
62    }
63}
64
65/// 3D launch grid for resize-prologue region kernels (width x height x N*C).
66#[derive(Debug, Clone, Copy, PartialEq, Eq)]
67pub struct PrologueLaunchGrid {
68    pub width: u32,
69    pub height: u32,
70    pub depth: u32,
71}
72
73impl PrologueLaunchGrid {
74    pub fn from_output_shape(shape: &Shape) -> Option<Self> {
75        let d = RegionNchwDims::from_shape(shape)?;
76        Some(Self {
77            width: d.w,
78            height: d.h,
79            depth: d.n * d.c,
80        })
81    }
82}
83
84/// Encode operand for region chain steps (shared across CUDA / Metal / wgpu).
85pub fn encode_chain_operand(op: &ChainOperand) -> u32 {
86    match *op {
87        ChainOperand::Input(i) => i & 0x7FFF_FFFFu32,
88        ChainOperand::Step(i) => 0x8000_0000u32 | (i & 0x7FFF_FFFFu32),
89    }
90}
91
92pub fn activation_sub(a: Activation) -> u32 {
93    match a {
94        Activation::Gelu => 0,
95        Activation::GeluApprox => 1,
96        Activation::Silu => 2,
97        Activation::Relu => 3,
98        Activation::Sigmoid => 4,
99        Activation::Tanh => 5,
100        Activation::Exp => 6,
101        Activation::Log => 7,
102        Activation::Sqrt => 8,
103        Activation::Rsqrt => 9,
104        Activation::Neg => 10,
105        Activation::Abs => 11,
106        Activation::Round => 12,
107        Activation::Sin => 13,
108        Activation::Cos => 14,
109        Activation::Tan => 15,
110        Activation::Atan => 16,
111    }
112}
113
114pub fn binary_sub(b: BinaryOp) -> u32 {
115    match b {
116        BinaryOp::Add => 0,
117        BinaryOp::Sub => 1,
118        BinaryOp::Mul => 2,
119        BinaryOp::Div => 3,
120        BinaryOp::Max => 4,
121        BinaryOp::Min => 5,
122        BinaryOp::Pow => 6,
123    }
124}
125
126pub fn compare_sub(c: CmpOp) -> u32 {
127    match c {
128        CmpOp::Eq => 0,
129        CmpOp::Ne => 1,
130        CmpOp::Lt => 2,
131        CmpOp::Le => 3,
132        CmpOp::Gt => 4,
133        CmpOp::Ge => 5,
134    }
135}
136
137/// Pack chain steps into 128 u32 words.
138pub fn encode_chain_steps(chain: &[ChainStep]) -> [u32; REGION_META_CHAIN_WORDS] {
139    let mut chain_enc = [0u32; REGION_META_CHAIN_WORDS];
140    for (k, step) in chain.iter().enumerate() {
141        let base = k * 4;
142        let (kind, sub, lhs, rhs) = match step {
143            ChainStep::Activation(a, src) => {
144                (0u32, activation_sub(*a), encode_chain_operand(src), 0u32)
145            }
146            ChainStep::Cast(_, src) => (1u32, 0, encode_chain_operand(src), 0u32),
147            ChainStep::Binary(op, l, r) => (
148                2u32,
149                binary_sub(*op),
150                encode_chain_operand(l),
151                encode_chain_operand(r),
152            ),
153            ChainStep::Compare(op, l, r) => (
154                3u32,
155                compare_sub(*op),
156                encode_chain_operand(l),
157                encode_chain_operand(r),
158            ),
159            ChainStep::Where(c, t, f) => (
160                4u32,
161                encode_chain_operand(c),
162                encode_chain_operand(t),
163                encode_chain_operand(f),
164            ),
165        };
166        chain_enc[base] = kind;
167        chain_enc[base + 1] = sub;
168        chain_enc[base + 2] = lhs;
169        chain_enc[base + 3] = rhs;
170    }
171    chain_enc
172}
173
174/// Prologue tag + NCHW shape + external input index for the region output tensor.
175pub fn encode_prologue_tail(
176    prologue: RegionPrologue,
177    out_shape: &Shape,
178    prologue_input: u32,
179) -> [u32; REGION_META_TAIL_WORDS] {
180    let mut tail = [0u32; REGION_META_TAIL_WORDS];
181    match prologue {
182        RegionPrologue::None => {}
183        RegionPrologue::ResizeNearest2x => {
184            if let Some(d) = RegionNchwDims::from_shape(out_shape) {
185                tail[0] = REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW;
186                tail[1] = d.n;
187                tail[2] = d.c;
188                tail[3] = d.h;
189                tail[4] = d.w;
190            }
191        }
192    }
193    tail[5] = prologue_input.min(15);
194    tail
195}
196
197/// Per-slice output shape for [`Op::BatchElementwiseRegion`] (batch axis 0 ? 1).
198pub fn batch_region_slice_shape(batch_out: &Shape) -> Shape {
199    if batch_out.rank() >= 1 {
200        batch_out.clone().with_dim(0, crate::shape::Dim::Static(1))
201    } else {
202        batch_out.clone()
203    }
204}
205
206/// Element count of one batch slice in a contiguous batch output tensor.
207pub fn batch_region_slice_elems(batch_out: &Shape, num_batch: usize) -> Option<u32> {
208    let total = batch_out.num_elements()?;
209    let n = num_batch.max(1);
210    Some((total / n) as u32)
211}
212
/// f32-linear offset of batch slice `index` within a packed output buffer:
/// `base_dst_off + index * slice_elems`.
///
/// Every step saturates at `u32::MAX` instead of overflowing. An `index` that does not fit in
/// `u32` saturates, as do the product and the final sum. Previously the index was truncated by
/// an `as` cast, and the multiplication could panic (debug) or wrap (release) even though the
/// addition already saturated.
pub fn batch_region_slice_dst_off_f32(base_dst_off: u32, slice_elems: u32, index: usize) -> u32 {
    let index = u32::try_from(index).unwrap_or(u32::MAX);
    base_dst_off.saturating_add(index.saturating_mul(slice_elems))
}
217
218/// Full device metadata buffer for [`Op::ElementwiseRegion`].
219pub fn encode_elementwise_region_meta(
220    input_offs: &[u32; REGION_META_INPUT_WORDS],
221    chain: &[ChainStep],
222    prologue: RegionPrologue,
223    out_shape: &Shape,
224    prologue_input: u32,
225) -> [u32; REGION_META_WORDS] {
226    let mut meta = [0u32; REGION_META_WORDS];
227    meta[..REGION_META_INPUT_WORDS].copy_from_slice(input_offs);
228    meta[REGION_META_INPUT_WORDS..REGION_META_INPUT_WORDS + REGION_META_CHAIN_WORDS]
229        .copy_from_slice(&encode_chain_steps(chain));
230    let tail = encode_prologue_tail(prologue, out_shape, prologue_input);
231    let tail_start = REGION_META_INPUT_WORDS + REGION_META_CHAIN_WORDS;
232    meta[tail_start..tail_start + REGION_META_TAIL_WORDS].copy_from_slice(&tail);
233    meta
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DType;

    #[test]
    fn meta_word_count_matches_layout() {
        assert_eq!(REGION_META_WORDS, 150);
    }

    #[test]
    fn batch_slice_elems_and_dst_off() {
        let shape = Shape::new(&[2, 3, 8, 8], DType::F32);
        assert_eq!(batch_region_slice_elems(&shape, 2), Some(192));
        assert_eq!(batch_region_slice_dst_off_f32(100, 192, 1), 100 + 192);
    }

    #[test]
    fn resize_prologue_tail_packed() {
        let shape = Shape::new(&[1, 3, 16, 16], DType::F32);
        let tail = encode_prologue_tail(RegionPrologue::ResizeNearest2x, &shape, 0);
        assert_eq!(tail[0], REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW);
        assert_eq!((tail[1], tail[2], tail[3], tail[4]), (1, 3, 16, 16));
        assert_eq!(tail[5], 0);
        let tail1 = encode_prologue_tail(RegionPrologue::ResizeNearest2x, &shape, 1);
        assert_eq!(tail1[5], 1);
    }

    #[test]
    fn no_prologue_tail_keeps_only_clamped_input_index() {
        let shape = Shape::new(&[1, 3, 16, 16], DType::F32);
        let tail = encode_prologue_tail(RegionPrologue::None, &shape, 20);
        // The input index is clamped to the last input slot (15).
        assert_eq!(tail, [REGION_PROLOGUE_NONE, 0, 0, 0, 0, 15]);
    }

    #[test]
    fn nchw_dims_and_launch_grid() {
        let shape = Shape::new(&[2, 3, 8, 16], DType::F32);
        let dims = RegionNchwDims::from_shape(&shape).expect("rank-4 shape");
        assert_eq!(dims, RegionNchwDims { n: 2, c: 3, h: 8, w: 16 });
        assert_eq!(dims.num_elements(), 2 * 3 * 8 * 16);
        assert_eq!(
            PrologueLaunchGrid::from_output_shape(&shape),
            Some(PrologueLaunchGrid { width: 16, height: 8, depth: 6 })
        );
        assert_eq!(
            RegionNchwDims::from_shape(&Shape::new(&[3, 8, 16], DType::F32)),
            None
        );
    }

    #[test]
    fn chain_operand_tagging_and_empty_chain() {
        assert_eq!(encode_chain_operand(&ChainOperand::Input(5)), 5);
        assert_eq!(encode_chain_operand(&ChainOperand::Step(2)), 0x8000_0002);
        assert!(encode_chain_steps(&[]).iter().all(|&w| w == 0));
    }

    #[test]
    fn fk_batch_single_kernel_cap() {
        assert_eq!(FK_BATCH_SINGLE_KERNEL_MAX, 64);
    }

    #[test]
    fn fk_batch_use_single_launch_gating() {
        // Eligible batches follow the env flag. Assert against its actual value instead of
        // assuming it is unset, so the test also passes with RLX_FK_BATCH_SINGLE_KERNEL set.
        let enabled = fk_batch_single_kernel_enabled();
        assert_eq!(fk_batch_use_single_launch(2, RegionPrologue::None), enabled);
        assert_eq!(
            fk_batch_use_single_launch(FK_BATCH_SINGLE_KERNEL_MAX, RegionPrologue::None),
            enabled
        );
        // Over the cap, or with a prologue, a single launch is never used.
        assert!(!fk_batch_use_single_launch(
            FK_BATCH_SINGLE_KERNEL_MAX + 1,
            RegionPrologue::None,
        ));
        assert!(!fk_batch_use_single_launch(
            2,
            RegionPrologue::ResizeNearest2x
        ));
    }
}