1use std::fmt;
13use std::collections::HashMap;
14
15#[derive(Debug, Clone, PartialEq)]
17pub enum FusableOp {
18 Unary(UnaryOp),
20 Binary(BinaryOp),
22 Reduce(ReduceOp),
24 MemoryOp(MemOp),
26}
27
28#[derive(Debug, Clone, Copy, PartialEq)]
30pub enum UnaryOp {
31 Relu, Sigmoid, Tanh, Gelu, Sqrt, Rsqrt, Exp, Log, Neg, Abs,
32 Cast(PrecisionType, PrecisionType), }
34
/// Element-wise two-input operations (`a op b`, applied per element).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Max,
    Min,
    /// `a` raised to the power `b`.
    Pow,
}
40
/// Reductions that collapse an entire buffer to one value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ReduceOp {
    Sum,
    Max,
    Min,
    /// Arithmetic mean of all elements.
    Mean,
}
46
/// Data-movement operations. `FusedKernel::execute` evaluates every variant
/// as a copy of the node's single input.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MemOp {
    Load,
    Store,
    Copy,
}
52
/// Numeric element types named by `UnaryOp::Cast`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PrecisionType {
    Fp16,
    Bf16,
    Fp32,
    Fp64,
    Int8,
    Int32,
}
58
59#[derive(Debug, Clone)]
61pub struct FusionNode {
62 pub id: usize,
63 pub op: FusableOp,
64 pub shape: Vec<usize>,
66 pub inputs: Vec<usize>,
68}
69
70#[derive(Debug, Clone)]
72pub struct FusedKernel {
73 pub id: usize,
74 pub nodes: Vec<FusionNode>,
76 pub external_inputs: Vec<usize>,
78 pub external_outputs: Vec<usize>,
80 pub memory_saved: usize,
82}
83
84impl FusedKernel {
85 pub fn execute(&self, inputs: &HashMap<usize, Vec<f32>>) -> crate::Result<HashMap<usize, Vec<f32>>> {
89 let mut buffers: HashMap<usize, Vec<f32>> = HashMap::new();
90
91 for (&id, data) in inputs {
93 buffers.insert(id, data.clone());
94 }
95
96 for node in &self.nodes {
98 let result = match &node.op {
99 FusableOp::Unary(op) => {
100 let input = buffers.get(&node.inputs[0])
101 .ok_or_else(|| crate::error::CudaRustError::RuntimeError(
102 format!("Missing input {} for node {}", node.inputs[0], node.id)))?;
103 apply_unary(op, input)
104 }
105 FusableOp::Binary(op) => {
106 let a = buffers.get(&node.inputs[0])
107 .ok_or_else(|| crate::error::CudaRustError::RuntimeError("Missing input A".into()))?;
108 let b = buffers.get(&node.inputs[1])
109 .ok_or_else(|| crate::error::CudaRustError::RuntimeError("Missing input B".into()))?;
110 apply_binary(op, a, b)
111 }
112 FusableOp::Reduce(op) => {
113 let input = buffers.get(&node.inputs[0])
114 .ok_or_else(|| crate::error::CudaRustError::RuntimeError("Missing reduce input".into()))?;
115 Ok(apply_reduce(op, input))
116 }
117 FusableOp::MemoryOp(_) => {
118 let input = buffers.get(&node.inputs[0])
120 .ok_or_else(|| crate::error::CudaRustError::RuntimeError("Missing mem input".into()))?;
121 Ok(input.clone())
122 }
123 }?;
124 buffers.insert(node.id, result);
125 }
126
127 let mut outputs = HashMap::new();
129 for &id in &self.external_outputs {
130 if let Some(data) = buffers.get(&id) {
131 outputs.insert(id, data.clone());
132 }
133 }
134 Ok(outputs)
135 }
136
137 pub fn buffers_eliminated(&self) -> usize {
139 let total_nodes = self.nodes.len();
140 let external = self.external_inputs.len() + self.external_outputs.len();
141 if total_nodes > external { total_nodes - external } else { 0 }
142 }
143}
144
145fn apply_unary(op: &UnaryOp, input: &[f32]) -> crate::Result<Vec<f32>> {
146 Ok(input.iter().map(|&x| match op {
147 UnaryOp::Relu => x.max(0.0),
148 UnaryOp::Sigmoid => 1.0 / (1.0 + (-x).exp()),
149 UnaryOp::Tanh => x.tanh(),
150 UnaryOp::Gelu => x * 0.5 * (1.0 + (0.7978845608 * (x + 0.044715 * x * x * x)).tanh()),
151 UnaryOp::Sqrt => x.sqrt(),
152 UnaryOp::Rsqrt => 1.0 / x.sqrt(),
153 UnaryOp::Exp => x.exp(),
154 UnaryOp::Log => x.ln(),
155 UnaryOp::Neg => -x,
156 UnaryOp::Abs => x.abs(),
157 UnaryOp::Cast(_, _) => x, }).collect())
159}
160
161fn apply_binary(op: &BinaryOp, a: &[f32], b: &[f32]) -> crate::Result<Vec<f32>> {
162 if a.len() != b.len() {
163 return Err(crate::error::CudaRustError::RuntimeError(
164 format!("Binary op shape mismatch: {} vs {}", a.len(), b.len()),
165 ));
166 }
167 Ok(a.iter().zip(b.iter()).map(|(&x, &y)| match op {
168 BinaryOp::Add => x + y,
169 BinaryOp::Sub => x - y,
170 BinaryOp::Mul => x * y,
171 BinaryOp::Div => x / y,
172 BinaryOp::Max => x.max(y),
173 BinaryOp::Min => x.min(y),
174 BinaryOp::Pow => x.powf(y),
175 }).collect())
176}
177
178fn apply_reduce(op: &ReduceOp, input: &[f32]) -> Vec<f32> {
179 if input.is_empty() {
180 return vec![0.0];
181 }
182 let result = match op {
183 ReduceOp::Sum => input.iter().sum(),
184 ReduceOp::Max => input.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
185 ReduceOp::Min => input.iter().cloned().fold(f32::INFINITY, f32::min),
186 ReduceOp::Mean => input.iter().sum::<f32>() / input.len() as f32,
187 };
188 vec![result]
189}
190
191pub struct FusionAnalyzer {
193 nodes: Vec<FusionNode>,
194 next_id: usize,
195}
196
197impl FusionAnalyzer {
198 pub fn new() -> Self {
200 Self { nodes: Vec::new(), next_id: 0 }
201 }
202
203 pub fn add_node(&mut self, op: FusableOp, shape: Vec<usize>, inputs: Vec<usize>) -> usize {
205 let id = self.next_id;
206 self.next_id += 1;
207 self.nodes.push(FusionNode { id, op, shape, inputs });
208 id
209 }
210
211 pub fn fuse(&self) -> FusionResult {
213 let mut fused_kernels = Vec::new();
214 let mut visited = vec![false; self.nodes.len()];
215 let mut total_memory_saved = 0usize;
216
217 let mut consumers: HashMap<usize, Vec<usize>> = HashMap::new();
219 for node in &self.nodes {
220 for &input_id in &node.inputs {
221 consumers.entry(input_id).or_default().push(node.id);
222 }
223 }
224
225 for i in 0..self.nodes.len() {
227 if visited[i] {
228 continue;
229 }
230
231 let node = &self.nodes[i];
232 if !is_element_wise(&node.op) {
233 visited[i] = true;
234 fused_kernels.push(FusedKernel {
235 id: fused_kernels.len(),
236 nodes: vec![node.clone()],
237 external_inputs: node.inputs.clone(),
238 external_outputs: vec![node.id],
239 memory_saved: 0,
240 });
241 continue;
242 }
243
244 let mut chain = vec![node.clone()];
246 visited[i] = true;
247 let mut current_id = node.id;
248
249 loop {
251 let next_consumers = consumers.get(¤t_id);
252 if let Some(cons) = next_consumers {
253 if cons.len() == 1 {
254 let next_id = cons[0];
255 if !visited[next_id] && next_id < self.nodes.len() {
256 let next_node = &self.nodes[next_id];
257 if is_element_wise(&next_node.op) && shapes_match(&node.shape, &next_node.shape) {
258 chain.push(next_node.clone());
259 visited[next_id] = true;
260 current_id = next_id;
261 continue;
262 }
263 }
264 }
265 }
266 break;
267 }
268
269 let shape = &chain[0].shape;
270 let elem_size = 4; let elems: usize = shape.iter().product();
272 let intermediates = if chain.len() > 1 { chain.len() - 1 } else { 0 };
273 let saved = intermediates * elems * elem_size;
274 total_memory_saved += saved;
275
276 let chain_ids: Vec<usize> = chain.iter().map(|n| n.id).collect();
278 let external_inputs: Vec<usize> = chain.iter()
279 .flat_map(|n| n.inputs.iter())
280 .filter(|id| !chain_ids.contains(id))
281 .copied()
282 .collect();
283 let last_id = chain.last().unwrap().id;
284
285 fused_kernels.push(FusedKernel {
286 id: fused_kernels.len(),
287 nodes: chain,
288 external_inputs,
289 external_outputs: vec![last_id],
290 memory_saved: saved,
291 });
292 }
293
294 FusionResult {
295 fused_kernels,
296 total_memory_saved,
297 original_kernel_count: self.nodes.len(),
298 }
299 }
300}
301
302fn is_element_wise(op: &FusableOp) -> bool {
303 matches!(op, FusableOp::Unary(_) | FusableOp::Binary(_))
304}
305
/// Exact shape equality: same rank and same extent on every axis.
/// No broadcasting is considered.
fn shapes_match(a: &[usize], b: &[usize]) -> bool {
    a.len() == b.len() && a.iter().zip(b).all(|(x, y)| x == y)
}
309
310#[derive(Debug)]
312pub struct FusionResult {
313 pub fused_kernels: Vec<FusedKernel>,
314 pub total_memory_saved: usize,
315 pub original_kernel_count: usize,
316}
317
318impl FusionResult {
319 pub fn fused_kernel_count(&self) -> usize {
321 self.fused_kernels.len()
322 }
323
324 pub fn kernel_reduction(&self) -> f64 {
326 if self.original_kernel_count == 0 { return 0.0; }
327 1.0 - (self.fused_kernel_count() as f64 / self.original_kernel_count as f64)
328 }
329}
330
331impl fmt::Display for FusionResult {
332 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333 write!(f, "Fusion: {} → {} kernels ({:.0}% reduction), {:.1}KB memory saved",
334 self.original_kernel_count,
335 self.fused_kernel_count(),
336 self.kernel_reduction() * 100.0,
337 self.total_memory_saved as f64 / 1024.0)
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Relu, Neg and Abs on a mixed-sign vector.
    #[test]
    fn test_unary_ops() {
        let xs = [-1.0_f32, 0.0, 1.0, 2.0];
        assert_eq!(apply_unary(&UnaryOp::Relu, &xs).unwrap(), vec![0.0, 0.0, 1.0, 2.0]);
        assert_eq!(apply_unary(&UnaryOp::Neg, &xs).unwrap(), vec![1.0, 0.0, -1.0, -2.0]);
        assert_eq!(apply_unary(&UnaryOp::Abs, &xs).unwrap(), vec![1.0, 0.0, 1.0, 2.0]);
    }

    /// Add and Mul over equal-length operands.
    #[test]
    fn test_binary_ops() {
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 5.0, 6.0];
        assert_eq!(apply_binary(&BinaryOp::Add, &lhs, &rhs).unwrap(), vec![5.0, 7.0, 9.0]);
        assert_eq!(apply_binary(&BinaryOp::Mul, &lhs, &rhs).unwrap(), vec![4.0, 10.0, 18.0]);
    }

    /// Every reduction over 1..=4.
    #[test]
    fn test_reduce_ops() {
        let xs = [1.0_f32, 2.0, 3.0, 4.0];
        let cases = [
            (ReduceOp::Sum, 10.0),
            (ReduceOp::Max, 4.0),
            (ReduceOp::Min, 1.0),
            (ReduceOp::Mean, 2.5),
        ];
        for (op, expected) in cases.iter() {
            assert_eq!(apply_reduce(op, &xs), vec![*expected]);
        }
    }

    /// A linear chain of three same-shape unary ops fuses into one kernel.
    #[test]
    fn test_fusion_chain() {
        let mut graph = FusionAnalyzer::new();
        let relu = graph.add_node(FusableOp::Unary(UnaryOp::Relu), vec![1024], vec![]);
        let sigmoid = graph.add_node(FusableOp::Unary(UnaryOp::Sigmoid), vec![1024], vec![relu]);
        graph.add_node(FusableOp::Unary(UnaryOp::Exp), vec![1024], vec![sigmoid]);

        let fused = graph.fuse();
        assert_eq!(fused.fused_kernel_count(), 1);
        assert!(fused.total_memory_saved > 0);
        assert!(fused.kernel_reduction() > 0.5);
    }

    /// A reduction splits the graph into separate kernels.
    #[test]
    fn test_fusion_with_reduction_break() {
        let mut graph = FusionAnalyzer::new();
        let relu = graph.add_node(FusableOp::Unary(UnaryOp::Relu), vec![1024], vec![]);
        let total = graph.add_node(FusableOp::Reduce(ReduceOp::Sum), vec![1], vec![relu]);
        graph.add_node(FusableOp::Unary(UnaryOp::Exp), vec![1], vec![total]);

        assert!(graph.fuse().fused_kernel_count() >= 2);
    }

    /// relu(x) + y evaluated through a hand-built fused kernel.
    #[test]
    fn test_fused_kernel_execute() {
        let kernel = FusedKernel {
            id: 0,
            nodes: vec![
                FusionNode {
                    id: 1,
                    op: FusableOp::Unary(UnaryOp::Relu),
                    shape: vec![4],
                    inputs: vec![0],
                },
                FusionNode {
                    id: 2,
                    op: FusableOp::Binary(BinaryOp::Add),
                    shape: vec![4],
                    inputs: vec![1, 3],
                },
            ],
            external_inputs: vec![0, 3],
            external_outputs: vec![2],
            memory_saved: 16,
        };

        let mut feeds = HashMap::new();
        feeds.insert(0, vec![-1.0, 0.0, 1.0, 2.0]);
        feeds.insert(3, vec![10.0, 10.0, 10.0, 10.0]);

        let produced = kernel.execute(&feeds).unwrap();
        assert_eq!(produced.get(&2).unwrap(), &vec![10.0, 10.0, 11.0, 12.0]);
    }

    /// Three-node chain exporting only its tail eliminates two buffers.
    #[test]
    fn test_buffers_eliminated() {
        let unary = |id: usize, op: UnaryOp, inputs: Vec<usize>| FusionNode {
            id,
            op: FusableOp::Unary(op),
            shape: vec![1024],
            inputs,
        };
        let kernel = FusedKernel {
            id: 0,
            nodes: vec![
                unary(0, UnaryOp::Relu, vec![]),
                unary(1, UnaryOp::Sigmoid, vec![0]),
                unary(2, UnaryOp::Exp, vec![1]),
            ],
            external_inputs: vec![],
            external_outputs: vec![2],
            memory_saved: 8192,
        };
        assert_eq!(kernel.buffers_eliminated(), 2);
    }

    /// GELU stays finite and sigmoid stays within [0, 1] on a small range.
    #[test]
    fn test_gelu_sigmoid_fusion() {
        let xs = [-2.0_f32, -1.0, 0.0, 1.0, 2.0];
        let gelu = apply_unary(&UnaryOp::Gelu, &xs).unwrap();
        let sigmoid = apply_unary(&UnaryOp::Sigmoid, &xs).unwrap();
        assert!(gelu.iter().all(|v| v.is_finite()));
        assert!(sigmoid.iter().all(|v| (0.0..=1.0).contains(v)));
    }

    /// Display shows the original kernel count and memory saved in KB.
    #[test]
    fn test_fusion_display() {
        let summary = FusionResult {
            fused_kernels: vec![],
            total_memory_saved: 65536,
            original_kernel_count: 10,
        };
        let text = summary.to_string();
        assert!(text.contains("10"));
        assert!(text.contains("64.0KB"));
    }
}