Skip to main content

cuda_rust_wasm/runtime/
kernel_fusion.rs

1//! Kernel Fusion Engine
2//!
3//! Automatically detects and fuses element-wise / pointwise kernel sequences
4//! to eliminate intermediate memory allocations and round-trips. This mirrors
5//! the kernel fusion passes in TensorRT, XLA, and TVM.
6//!
7//! Fusion rules:
8//! 1. Element-wise ops (add, mul, relu, etc.) can always fuse.
9//! 2. Reduction followed by broadcast can fuse (vertical fusion).
//! 3. Producer-consumer pairs with matching shapes can fuse (vertical fusion).
11
12use std::fmt;
13use std::collections::HashMap;
14
15/// An operation that can be part of a fused kernel.
16#[derive(Debug, Clone, PartialEq)]
17pub enum FusableOp {
18    /// Element-wise: output[i] = f(input[i])
19    Unary(UnaryOp),
20    /// Element-wise: output[i] = f(a[i], b[i])
21    Binary(BinaryOp),
22    /// Reduction over a dimension
23    Reduce(ReduceOp),
24    /// Memory operation
25    MemoryOp(MemOp),
26}
27
28/// Unary element-wise operations.
29#[derive(Debug, Clone, Copy, PartialEq)]
30pub enum UnaryOp {
31    Relu, Sigmoid, Tanh, Gelu, Sqrt, Rsqrt, Exp, Log, Neg, Abs,
32    Cast(PrecisionType, PrecisionType), // from, to
33}
34
/// Binary element-wise operations.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BinaryOp {
    /// `a + b`
    Add,
    /// `a - b`
    Sub,
    /// `a * b`
    Mul,
    /// `a / b`
    Div,
    /// Larger of `a` and `b`.
    Max,
    /// Smaller of `a` and `b`.
    Min,
    /// `a` raised to the power `b`.
    Pow,
}
40
/// Reduction operations.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ReduceOp {
    /// Sum of the elements.
    Sum,
    /// Largest element.
    Max,
    /// Smallest element.
    Min,
    /// Arithmetic mean of the elements.
    Mean,
}
46
/// Memory operations.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MemOp {
    /// Load operation.
    Load,
    /// Store operation.
    Store,
    /// Copy operation.
    Copy,
}
52
/// Precision types for cast operations.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PrecisionType {
    /// 16-bit floating point.
    Fp16,
    /// bfloat16.
    Bf16,
    /// 32-bit floating point.
    Fp32,
    /// 64-bit floating point.
    Fp64,
    /// 8-bit integer.
    Int8,
    /// 32-bit integer.
    Int32,
}
58
59/// A node in the fusion graph.
60#[derive(Debug, Clone)]
61pub struct FusionNode {
62    pub id: usize,
63    pub op: FusableOp,
64    /// Shape of the output tensor.
65    pub shape: Vec<usize>,
66    /// Input node IDs.
67    pub inputs: Vec<usize>,
68}
69
70/// A fused kernel — a sequence of operations executed as one kernel.
71#[derive(Debug, Clone)]
72pub struct FusedKernel {
73    pub id: usize,
74    /// Nodes in execution order (topological).
75    pub nodes: Vec<FusionNode>,
76    /// Input node IDs (external inputs to the fused kernel).
77    pub external_inputs: Vec<usize>,
78    /// Output node IDs (nodes whose results are needed externally).
79    pub external_outputs: Vec<usize>,
80    /// Estimated memory saved by fusion (bytes).
81    pub memory_saved: usize,
82}
83
84impl FusedKernel {
85    /// Execute the fused kernel on f32 data.
86    ///
87    /// `inputs` maps external input IDs to their data.
88    pub fn execute(&self, inputs: &HashMap<usize, Vec<f32>>) -> crate::Result<HashMap<usize, Vec<f32>>> {
89        let mut buffers: HashMap<usize, Vec<f32>> = HashMap::new();
90
91        // Copy external inputs
92        for (&id, data) in inputs {
93            buffers.insert(id, data.clone());
94        }
95
96        // Execute each node
97        for node in &self.nodes {
98            let result = match &node.op {
99                FusableOp::Unary(op) => {
100                    let input = buffers.get(&node.inputs[0])
101                        .ok_or_else(|| crate::error::CudaRustError::RuntimeError(
102                            format!("Missing input {} for node {}", node.inputs[0], node.id)))?;
103                    apply_unary(op, input)
104                }
105                FusableOp::Binary(op) => {
106                    let a = buffers.get(&node.inputs[0])
107                        .ok_or_else(|| crate::error::CudaRustError::RuntimeError("Missing input A".into()))?;
108                    let b = buffers.get(&node.inputs[1])
109                        .ok_or_else(|| crate::error::CudaRustError::RuntimeError("Missing input B".into()))?;
110                    apply_binary(op, a, b)
111                }
112                FusableOp::Reduce(op) => {
113                    let input = buffers.get(&node.inputs[0])
114                        .ok_or_else(|| crate::error::CudaRustError::RuntimeError("Missing reduce input".into()))?;
115                    Ok(apply_reduce(op, input))
116                }
117                FusableOp::MemoryOp(_) => {
118                    // Pass-through
119                    let input = buffers.get(&node.inputs[0])
120                        .ok_or_else(|| crate::error::CudaRustError::RuntimeError("Missing mem input".into()))?;
121                    Ok(input.clone())
122                }
123            }?;
124            buffers.insert(node.id, result);
125        }
126
127        // Collect external outputs
128        let mut outputs = HashMap::new();
129        for &id in &self.external_outputs {
130            if let Some(data) = buffers.get(&id) {
131                outputs.insert(id, data.clone());
132            }
133        }
134        Ok(outputs)
135    }
136
137    /// Number of intermediate buffers eliminated by fusion.
138    pub fn buffers_eliminated(&self) -> usize {
139        let total_nodes = self.nodes.len();
140        let external = self.external_inputs.len() + self.external_outputs.len();
141        if total_nodes > external { total_nodes - external } else { 0 }
142    }
143}
144
145fn apply_unary(op: &UnaryOp, input: &[f32]) -> crate::Result<Vec<f32>> {
146    Ok(input.iter().map(|&x| match op {
147        UnaryOp::Relu => x.max(0.0),
148        UnaryOp::Sigmoid => 1.0 / (1.0 + (-x).exp()),
149        UnaryOp::Tanh => x.tanh(),
150        UnaryOp::Gelu => x * 0.5 * (1.0 + (0.7978845608 * (x + 0.044715 * x * x * x)).tanh()),
151        UnaryOp::Sqrt => x.sqrt(),
152        UnaryOp::Rsqrt => 1.0 / x.sqrt(),
153        UnaryOp::Exp => x.exp(),
154        UnaryOp::Log => x.ln(),
155        UnaryOp::Neg => -x,
156        UnaryOp::Abs => x.abs(),
157        UnaryOp::Cast(_, _) => x, // f32→f32 is identity
158    }).collect())
159}
160
161fn apply_binary(op: &BinaryOp, a: &[f32], b: &[f32]) -> crate::Result<Vec<f32>> {
162    if a.len() != b.len() {
163        return Err(crate::error::CudaRustError::RuntimeError(
164            format!("Binary op shape mismatch: {} vs {}", a.len(), b.len()),
165        ));
166    }
167    Ok(a.iter().zip(b.iter()).map(|(&x, &y)| match op {
168        BinaryOp::Add => x + y,
169        BinaryOp::Sub => x - y,
170        BinaryOp::Mul => x * y,
171        BinaryOp::Div => x / y,
172        BinaryOp::Max => x.max(y),
173        BinaryOp::Min => x.min(y),
174        BinaryOp::Pow => x.powf(y),
175    }).collect())
176}
177
178fn apply_reduce(op: &ReduceOp, input: &[f32]) -> Vec<f32> {
179    if input.is_empty() {
180        return vec![0.0];
181    }
182    let result = match op {
183        ReduceOp::Sum => input.iter().sum(),
184        ReduceOp::Max => input.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
185        ReduceOp::Min => input.iter().cloned().fold(f32::INFINITY, f32::min),
186        ReduceOp::Mean => input.iter().sum::<f32>() / input.len() as f32,
187    };
188    vec![result]
189}
190
191/// Fusion analysis engine that detects fusable patterns.
192pub struct FusionAnalyzer {
193    nodes: Vec<FusionNode>,
194    next_id: usize,
195}
196
197impl FusionAnalyzer {
198    /// Create a new analyzer.
199    pub fn new() -> Self {
200        Self { nodes: Vec::new(), next_id: 0 }
201    }
202
203    /// Add an operation node.
204    pub fn add_node(&mut self, op: FusableOp, shape: Vec<usize>, inputs: Vec<usize>) -> usize {
205        let id = self.next_id;
206        self.next_id += 1;
207        self.nodes.push(FusionNode { id, op, shape, inputs });
208        id
209    }
210
211    /// Analyze the graph and produce fused kernels.
212    pub fn fuse(&self) -> FusionResult {
213        let mut fused_kernels = Vec::new();
214        let mut visited = vec![false; self.nodes.len()];
215        let mut total_memory_saved = 0usize;
216
217        // Build consumer map
218        let mut consumers: HashMap<usize, Vec<usize>> = HashMap::new();
219        for node in &self.nodes {
220            for &input_id in &node.inputs {
221                consumers.entry(input_id).or_default().push(node.id);
222            }
223        }
224
225        // Greedy fusion: chain element-wise ops
226        for i in 0..self.nodes.len() {
227            if visited[i] {
228                continue;
229            }
230
231            let node = &self.nodes[i];
232            if !is_element_wise(&node.op) {
233                visited[i] = true;
234                fused_kernels.push(FusedKernel {
235                    id: fused_kernels.len(),
236                    nodes: vec![node.clone()],
237                    external_inputs: node.inputs.clone(),
238                    external_outputs: vec![node.id],
239                    memory_saved: 0,
240                });
241                continue;
242            }
243
244            // Start a fusion chain
245            let mut chain = vec![node.clone()];
246            visited[i] = true;
247            let mut current_id = node.id;
248
249            // Extend chain forward while next consumer is a single element-wise op
250            loop {
251                let next_consumers = consumers.get(&current_id);
252                if let Some(cons) = next_consumers {
253                    if cons.len() == 1 {
254                        let next_id = cons[0];
255                        if !visited[next_id] && next_id < self.nodes.len() {
256                            let next_node = &self.nodes[next_id];
257                            if is_element_wise(&next_node.op) && shapes_match(&node.shape, &next_node.shape) {
258                                chain.push(next_node.clone());
259                                visited[next_id] = true;
260                                current_id = next_id;
261                                continue;
262                            }
263                        }
264                    }
265                }
266                break;
267            }
268
269            let shape = &chain[0].shape;
270            let elem_size = 4; // f32
271            let elems: usize = shape.iter().product();
272            let intermediates = if chain.len() > 1 { chain.len() - 1 } else { 0 };
273            let saved = intermediates * elems * elem_size;
274            total_memory_saved += saved;
275
276            // Determine external inputs and outputs
277            let chain_ids: Vec<usize> = chain.iter().map(|n| n.id).collect();
278            let external_inputs: Vec<usize> = chain.iter()
279                .flat_map(|n| n.inputs.iter())
280                .filter(|id| !chain_ids.contains(id))
281                .copied()
282                .collect();
283            let last_id = chain.last().unwrap().id;
284
285            fused_kernels.push(FusedKernel {
286                id: fused_kernels.len(),
287                nodes: chain,
288                external_inputs,
289                external_outputs: vec![last_id],
290                memory_saved: saved,
291            });
292        }
293
294        FusionResult {
295            fused_kernels,
296            total_memory_saved,
297            original_kernel_count: self.nodes.len(),
298        }
299    }
300}
301
302fn is_element_wise(op: &FusableOp) -> bool {
303    matches!(op, FusableOp::Unary(_) | FusableOp::Binary(_))
304}
305
/// Whether two shapes have the same rank and the same extent in every dimension.
fn shapes_match(a: &[usize], b: &[usize]) -> bool {
    a.len() == b.len() && a.iter().zip(b.iter()).all(|(lhs, rhs)| lhs == rhs)
}
309
310/// Result of fusion analysis.
311#[derive(Debug)]
312pub struct FusionResult {
313    pub fused_kernels: Vec<FusedKernel>,
314    pub total_memory_saved: usize,
315    pub original_kernel_count: usize,
316}
317
318impl FusionResult {
319    /// Number of kernels after fusion.
320    pub fn fused_kernel_count(&self) -> usize {
321        self.fused_kernels.len()
322    }
323
324    /// Reduction in kernel count.
325    pub fn kernel_reduction(&self) -> f64 {
326        if self.original_kernel_count == 0 { return 0.0; }
327        1.0 - (self.fused_kernel_count() as f64 / self.original_kernel_count as f64)
328    }
329}
330
331impl fmt::Display for FusionResult {
332    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333        write!(f, "Fusion: {} → {} kernels ({:.0}% reduction), {:.1}KB memory saved",
334            self.original_kernel_count,
335            self.fused_kernel_count(),
336            self.kernel_reduction() * 100.0,
337            self.total_memory_saved as f64 / 1024.0)
338    }
339}
340
341// ── Tests ──────────────────────────────────────────────────────────
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-check unary ops whose results are exactly representable.
    #[test]
    fn test_unary_ops() {
        let input = vec![-1.0, 0.0, 1.0, 2.0];
        let relu = apply_unary(&UnaryOp::Relu, &input).unwrap();
        assert_eq!(relu, vec![0.0, 0.0, 1.0, 2.0]);

        let neg = apply_unary(&UnaryOp::Neg, &input).unwrap();
        assert_eq!(neg, vec![1.0, 0.0, -1.0, -2.0]);

        let abs_r = apply_unary(&UnaryOp::Abs, &input).unwrap();
        assert_eq!(abs_r, vec![1.0, 0.0, 1.0, 2.0]);
    }

    /// Element-wise binary ops on equally sized inputs.
    #[test]
    fn test_binary_ops() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let add = apply_binary(&BinaryOp::Add, &a, &b).unwrap();
        assert_eq!(add, vec![5.0, 7.0, 9.0]);

        let mul = apply_binary(&BinaryOp::Mul, &a, &b).unwrap();
        assert_eq!(mul, vec![4.0, 10.0, 18.0]);
    }

    /// Inputs of different lengths are rejected instead of being truncated.
    #[test]
    fn test_binary_shape_mismatch() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0];
        assert!(apply_binary(&BinaryOp::Add, &a, &b).is_err());
    }

    /// Each reduction collapses the input to a single value.
    #[test]
    fn test_reduce_ops() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(apply_reduce(&ReduceOp::Sum, &input), vec![10.0]);
        assert_eq!(apply_reduce(&ReduceOp::Max, &input), vec![4.0]);
        assert_eq!(apply_reduce(&ReduceOp::Min, &input), vec![1.0]);
        assert_eq!(apply_reduce(&ReduceOp::Mean, &input), vec![2.5]);
    }

    /// Reducing an empty input yields a single zero for every operation.
    #[test]
    fn test_reduce_empty() {
        let empty: Vec<f32> = Vec::new();
        assert_eq!(apply_reduce(&ReduceOp::Sum, &empty), vec![0.0]);
        assert_eq!(apply_reduce(&ReduceOp::Max, &empty), vec![0.0]);
        assert_eq!(apply_reduce(&ReduceOp::Mean, &empty), vec![0.0]);
    }

    #[test]
    fn test_fusion_chain() {
        let mut analyzer = FusionAnalyzer::new();
        // Chain: relu → sigmoid → exp (the relu node has no recorded inputs)
        let relu_id = analyzer.add_node(
            FusableOp::Unary(UnaryOp::Relu), vec![1024], vec![]
        );
        let sigmoid_id = analyzer.add_node(
            FusableOp::Unary(UnaryOp::Sigmoid), vec![1024], vec![relu_id]
        );
        let _exp_id = analyzer.add_node(
            FusableOp::Unary(UnaryOp::Exp), vec![1024], vec![sigmoid_id]
        );

        let result = analyzer.fuse();
        // Should fuse all 3 into 1 kernel
        assert_eq!(result.fused_kernel_count(), 1);
        // Two intermediates of 1024 f32 elements each are fused away.
        assert_eq!(result.total_memory_saved, 2 * 1024 * 4);
        assert!(result.kernel_reduction() > 0.5);
    }

    #[test]
    fn test_fusion_with_reduction_break() {
        let mut analyzer = FusionAnalyzer::new();
        let relu_id = analyzer.add_node(
            FusableOp::Unary(UnaryOp::Relu), vec![1024], vec![]
        );
        // Reduction breaks the chain
        let reduce_id = analyzer.add_node(
            FusableOp::Reduce(ReduceOp::Sum), vec![1], vec![relu_id]
        );
        let _exp_id = analyzer.add_node(
            FusableOp::Unary(UnaryOp::Exp), vec![1], vec![reduce_id]
        );

        let result = analyzer.fuse();
        // Relu alone, reduce alone, exp alone (reduce breaks fusion)
        assert_eq!(result.fused_kernel_count(), 3);
        assert_eq!(result.total_memory_saved, 0);
    }

    #[test]
    fn test_fused_kernel_execute() {
        // Manually build a fused kernel: relu → add
        let fused = FusedKernel {
            id: 0,
            nodes: vec![
                FusionNode { id: 1, op: FusableOp::Unary(UnaryOp::Relu), shape: vec![4], inputs: vec![0] },
                FusionNode { id: 2, op: FusableOp::Binary(BinaryOp::Add), shape: vec![4], inputs: vec![1, 3] },
            ],
            external_inputs: vec![0, 3],
            external_outputs: vec![2],
            memory_saved: 16,
        };

        let mut inputs = HashMap::new();
        inputs.insert(0, vec![-1.0, 0.0, 1.0, 2.0]);
        inputs.insert(3, vec![10.0, 10.0, 10.0, 10.0]);

        let outputs = fused.execute(&inputs).unwrap();
        let result = outputs.get(&2).unwrap();
        // relu([-1, 0, 1, 2]) = [0, 0, 1, 2], then + [10, 10, 10, 10] = [10, 10, 11, 12]
        assert_eq!(result, &vec![10.0, 10.0, 11.0, 12.0]);
    }

    #[test]
    fn test_buffers_eliminated() {
        let fused = FusedKernel {
            id: 0,
            nodes: vec![
                FusionNode { id: 0, op: FusableOp::Unary(UnaryOp::Relu), shape: vec![1024], inputs: vec![] },
                FusionNode { id: 1, op: FusableOp::Unary(UnaryOp::Sigmoid), shape: vec![1024], inputs: vec![0] },
                FusionNode { id: 2, op: FusableOp::Unary(UnaryOp::Exp), shape: vec![1024], inputs: vec![1] },
            ],
            external_inputs: vec![],
            external_outputs: vec![2],
            memory_saved: 8192,
        };
        assert_eq!(fused.buffers_eliminated(), 2); // 3 nodes - 0 inputs - 1 output
    }

    /// GELU and sigmoid stay finite / in range on a small symmetric input.
    /// (No fusion is involved; the two ops are applied independently.)
    #[test]
    fn test_gelu_sigmoid_fusion() {
        let input = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
        let gelu = apply_unary(&UnaryOp::Gelu, &input).unwrap();
        let sigmoid = apply_unary(&UnaryOp::Sigmoid, &input).unwrap();
        assert!(gelu.iter().all(|v| v.is_finite()));
        assert!(sigmoid.iter().all(|v| *v >= 0.0 && *v <= 1.0));
    }

    #[test]
    fn test_fusion_display() {
        let result = FusionResult {
            fused_kernels: vec![],
            total_memory_saved: 65536,
            original_kernel_count: 10,
        };
        let s = format!("{}", result);
        assert!(s.contains("10"));
        assert!(s.contains("64.0KB"));
    }
}
481}