1use crate::op::*;
11use crate::shape::Shape;
12
/// Leading `u32` words of the region metadata, filled from `input_offs`.
pub const REGION_META_INPUT_WORDS: usize = 16;
/// Words reserved for the encoded chain: four words per step (kind, sub, lhs, rhs),
/// which leaves room for 32 steps.
pub const REGION_META_CHAIN_WORDS: usize = 4 * 32;
/// Trailing words for the prologue tail: kind, N, C, H, W, prologue input.
pub const REGION_META_TAIL_WORDS: usize = 6;
/// Total metadata length; sections are laid out as inputs, then chain, then tail.
pub const REGION_META_WORDS: usize =
    REGION_META_INPUT_WORDS + REGION_META_CHAIN_WORDS + REGION_META_TAIL_WORDS;
18
/// Largest batch count for which `fk_batch_use_single_launch` may choose the single-launch path.
pub const FK_BATCH_SINGLE_KERNEL_MAX: usize = 64;
21
22pub fn fk_batch_single_kernel_enabled() -> bool {
24 crate::env::flag("RLX_FK_BATCH_SINGLE_KERNEL")
25}
26
27pub fn fk_batch_use_single_launch(num_batch: usize, prologue: RegionPrologue) -> bool {
29 fk_batch_single_kernel_enabled()
30 && prologue == RegionPrologue::None
31 && num_batch <= FK_BATCH_SINGLE_KERNEL_MAX
32}
33
/// Tail word 0 value meaning "no prologue".
pub const REGION_PROLOGUE_NONE: u32 = 0;
/// Tail word 0 value for the nearest-neighbour 2x resize prologue on an NCHW output.
pub const REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW: u32 = 1;
36
/// The four extents of a rank-4 shape, taken in dimension order 0..=3 as N, C, H, W.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RegionNchwDims {
    /// Extent of dimension 0 (N).
    pub n: u32,
    /// Extent of dimension 1 (C).
    pub c: u32,
    /// Extent of dimension 2 (H).
    pub h: u32,
    /// Extent of dimension 3 (W).
    pub w: u32,
}
45
46impl RegionNchwDims {
47 pub fn from_shape(shape: &Shape) -> Option<Self> {
48 if shape.rank() != 4 {
49 return None;
50 }
51 Some(Self {
52 n: shape.dim(0).unwrap_static() as u32,
53 c: shape.dim(1).unwrap_static() as u32,
54 h: shape.dim(2).unwrap_static() as u32,
55 w: shape.dim(3).unwrap_static() as u32,
56 })
57 }
58
59 pub fn num_elements(self) -> u32 {
61 self.n * self.c * self.h * self.w
62 }
63}
64
/// Three-axis launch extents for a prologue pass over a rank-4 output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrologueLaunchGrid {
    /// First grid axis; `from_output_shape` sets it to the output's W.
    pub width: u32,
    /// Second grid axis; `from_output_shape` sets it to the output's H.
    pub height: u32,
    /// Third grid axis; `from_output_shape` sets it to the output's N * C.
    pub depth: u32,
}
72
73impl PrologueLaunchGrid {
74 pub fn from_output_shape(shape: &Shape) -> Option<Self> {
75 let d = RegionNchwDims::from_shape(shape)?;
76 Some(Self {
77 width: d.w,
78 height: d.h,
79 depth: d.n * d.c,
80 })
81 }
82}
83
84pub fn encode_chain_operand(op: &ChainOperand) -> u32 {
86 match *op {
87 ChainOperand::Input(i) => i & 0x7FFF_FFFFu32,
88 ChainOperand::Step(i) => 0x8000_0000u32 | (i & 0x7FFF_FFFFu32),
89 }
90}
91
92pub fn activation_sub(a: Activation) -> u32 {
93 match a {
94 Activation::Gelu => 0,
95 Activation::GeluApprox => 1,
96 Activation::Silu => 2,
97 Activation::Relu => 3,
98 Activation::Sigmoid => 4,
99 Activation::Tanh => 5,
100 Activation::Exp => 6,
101 Activation::Log => 7,
102 Activation::Sqrt => 8,
103 Activation::Rsqrt => 9,
104 Activation::Neg => 10,
105 Activation::Abs => 11,
106 Activation::Round => 12,
107 Activation::Sin => 13,
108 Activation::Cos => 14,
109 Activation::Tan => 15,
110 Activation::Atan => 16,
111 }
112}
113
114pub fn binary_sub(b: BinaryOp) -> u32 {
115 match b {
116 BinaryOp::Add => 0,
117 BinaryOp::Sub => 1,
118 BinaryOp::Mul => 2,
119 BinaryOp::Div => 3,
120 BinaryOp::Max => 4,
121 BinaryOp::Min => 5,
122 BinaryOp::Pow => 6,
123 }
124}
125
126pub fn compare_sub(c: CmpOp) -> u32 {
127 match c {
128 CmpOp::Eq => 0,
129 CmpOp::Ne => 1,
130 CmpOp::Lt => 2,
131 CmpOp::Le => 3,
132 CmpOp::Gt => 4,
133 CmpOp::Ge => 5,
134 }
135}
136
137pub fn encode_chain_steps(chain: &[ChainStep]) -> [u32; REGION_META_CHAIN_WORDS] {
139 let mut chain_enc = [0u32; REGION_META_CHAIN_WORDS];
140 for (k, step) in chain.iter().enumerate() {
141 let base = k * 4;
142 let (kind, sub, lhs, rhs) = match step {
143 ChainStep::Activation(a, src) => {
144 (0u32, activation_sub(*a), encode_chain_operand(src), 0u32)
145 }
146 ChainStep::Cast(_, src) => (1u32, 0, encode_chain_operand(src), 0u32),
147 ChainStep::Binary(op, l, r) => (
148 2u32,
149 binary_sub(*op),
150 encode_chain_operand(l),
151 encode_chain_operand(r),
152 ),
153 ChainStep::Compare(op, l, r) => (
154 3u32,
155 compare_sub(*op),
156 encode_chain_operand(l),
157 encode_chain_operand(r),
158 ),
159 ChainStep::Where(c, t, f) => (
160 4u32,
161 encode_chain_operand(c),
162 encode_chain_operand(t),
163 encode_chain_operand(f),
164 ),
165 };
166 chain_enc[base] = kind;
167 chain_enc[base + 1] = sub;
168 chain_enc[base + 2] = lhs;
169 chain_enc[base + 3] = rhs;
170 }
171 chain_enc
172}
173
174pub fn encode_prologue_tail(
176 prologue: RegionPrologue,
177 out_shape: &Shape,
178 prologue_input: u32,
179) -> [u32; REGION_META_TAIL_WORDS] {
180 let mut tail = [0u32; REGION_META_TAIL_WORDS];
181 match prologue {
182 RegionPrologue::None => {}
183 RegionPrologue::ResizeNearest2x => {
184 if let Some(d) = RegionNchwDims::from_shape(out_shape) {
185 tail[0] = REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW;
186 tail[1] = d.n;
187 tail[2] = d.c;
188 tail[3] = d.h;
189 tail[4] = d.w;
190 }
191 }
192 }
193 tail[5] = prologue_input.min(15);
194 tail
195}
196
197pub fn batch_region_slice_shape(batch_out: &Shape) -> Shape {
199 if batch_out.rank() >= 1 {
200 batch_out.clone().with_dim(0, crate::shape::Dim::Static(1))
201 } else {
202 batch_out.clone()
203 }
204}
205
206pub fn batch_region_slice_elems(batch_out: &Shape, num_batch: usize) -> Option<u32> {
208 let total = batch_out.num_elements()?;
209 let n = num_batch.max(1);
210 Some((total / n) as u32)
211}
212
/// Destination offset of batch slice `index`: `base_dst_off + index * slice_elems`.
///
/// The whole computation saturates at `u32::MAX`. Previously only the addition saturated:
/// `index` was truncated by an `as` cast, and the multiplication could overflow (panic in
/// debug, wrap in release). Offset units are presumably f32 elements, per the name.
pub fn batch_region_slice_dst_off_f32(base_dst_off: u32, slice_elems: u32, index: usize) -> u32 {
    use std::convert::TryFrom;

    let index = u32::try_from(index).unwrap_or(u32::MAX);
    base_dst_off.saturating_add(index.saturating_mul(slice_elems))
}
217
218pub fn encode_elementwise_region_meta(
220 input_offs: &[u32; REGION_META_INPUT_WORDS],
221 chain: &[ChainStep],
222 prologue: RegionPrologue,
223 out_shape: &Shape,
224 prologue_input: u32,
225) -> [u32; REGION_META_WORDS] {
226 let mut meta = [0u32; REGION_META_WORDS];
227 meta[..REGION_META_INPUT_WORDS].copy_from_slice(input_offs);
228 meta[REGION_META_INPUT_WORDS..REGION_META_INPUT_WORDS + REGION_META_CHAIN_WORDS]
229 .copy_from_slice(&encode_chain_steps(chain));
230 let tail = encode_prologue_tail(prologue, out_shape, prologue_input);
231 let tail_start = REGION_META_INPUT_WORDS + REGION_META_CHAIN_WORDS;
232 meta[tail_start..tail_start + REGION_META_TAIL_WORDS].copy_from_slice(&tail);
233 meta
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DType;

    #[test]
    fn meta_word_count_matches_layout() {
        assert_eq!(REGION_META_WORDS, 150);
        assert_eq!(
            REGION_META_WORDS,
            REGION_META_INPUT_WORDS + REGION_META_CHAIN_WORDS + REGION_META_TAIL_WORDS
        );
    }

    #[test]
    fn batch_slice_elems_and_dst_off() {
        let shape = Shape::new(&[2, 3, 8, 8], DType::F32);
        assert_eq!(batch_region_slice_elems(&shape, 2), Some(192));
        assert_eq!(batch_region_slice_dst_off_f32(100, 192, 1), 100 + 192);
    }

    #[test]
    fn resize_prologue_tail_packed() {
        let shape = Shape::new(&[1, 3, 16, 16], DType::F32);
        let tail = encode_prologue_tail(RegionPrologue::ResizeNearest2x, &shape, 0);
        assert_eq!(tail[0], REGION_PROLOGUE_RESIZE_NEAREST_2X_NCHW);
        assert_eq!((tail[1], tail[2], tail[3], tail[4]), (1, 3, 16, 16));
        assert_eq!(tail[5], 0);
        let tail1 = encode_prologue_tail(RegionPrologue::ResizeNearest2x, &shape, 1);
        assert_eq!(tail1[5], 1);
    }

    #[test]
    fn no_prologue_tail_is_zero_and_input_is_clamped() {
        let shape = Shape::new(&[1, 3, 16, 16], DType::F32);
        let tail = encode_prologue_tail(RegionPrologue::None, &shape, 20);
        assert_eq!(tail[0], REGION_PROLOGUE_NONE);
        assert_eq!((tail[1], tail[2], tail[3], tail[4]), (0, 0, 0, 0));
        assert_eq!(tail[5], 15);
    }

    #[test]
    fn nchw_dims_require_rank_four() {
        let shape = Shape::new(&[2, 3], DType::F32);
        assert_eq!(RegionNchwDims::from_shape(&shape), None);
        assert_eq!(PrologueLaunchGrid::from_output_shape(&shape), None);
    }

    #[test]
    fn launch_grid_folds_batch_and_channels_into_depth() {
        let shape = Shape::new(&[2, 3, 16, 8], DType::F32);
        assert_eq!(
            PrologueLaunchGrid::from_output_shape(&shape),
            Some(PrologueLaunchGrid {
                width: 8,
                height: 16,
                depth: 6,
            })
        );
    }

    #[test]
    fn chain_operand_tags_steps_with_high_bit() {
        assert_eq!(encode_chain_operand(&ChainOperand::Input(3)), 3);
        assert_eq!(encode_chain_operand(&ChainOperand::Step(3)), 0x8000_0003);
    }

    #[test]
    fn fk_batch_single_kernel_cap() {
        assert_eq!(FK_BATCH_SINGLE_KERNEL_MAX, 64);
    }

    #[test]
    fn fk_batch_use_single_launch_gating() {
        // The env flag may or may not be set where the tests run; a small batch with no
        // prologue must simply follow it rather than assume it is off.
        assert_eq!(
            fk_batch_use_single_launch(2, RegionPrologue::None),
            fk_batch_single_kernel_enabled()
        );
        // These are rejected regardless of the flag.
        assert!(!fk_batch_use_single_launch(
            FK_BATCH_SINGLE_KERNEL_MAX + 1,
            RegionPrologue::None,
        ));
        assert!(!fk_batch_use_single_launch(
            2,
            RegionPrologue::ResizeNearest2x
        ));
    }
}
281}