morok_ir/uop/constructors/reduce.rs
1//! Reduction operations: reduce, allreduce.
2//!
3//! This module contains reduction and aggregation operations:
4//! - try_reduce_axis: Reduce along specified axes
5//! - reduce: Reduce across loop ranges
6//! - allreduce: All-reduce across multiple devices
7
8use std::sync::Arc;
9
10use smallvec::SmallVec;
11
12use crate::Result;
13use crate::op::Op;
14use crate::types::ReduceOp;
15use crate::uop::UOp;
16
17impl UOp {
18 /// Reduce along specified axes using reduce_op.
19 ///
20 /// Implements Tinygrad's early-return pattern: when all axes are reduced
21 /// or when all reduction axes have dimension 1, returns self instead of
22 /// creating a ReduceAxis operation.
23 ///
24 /// # Errors
25 /// Returns error if any axis is >= number of dimensions.
26 pub fn try_reduce_axis(self: &Arc<Self>, reduce_op: ReduceOp, axes: Vec<usize>) -> Result<Arc<Self>> {
27 use crate::SInt;
28
29 // Validate axes if source shape is known
30 if let Some(src_shape) = self.shape()? {
31 Self::validate_reduce_axes(&axes, src_shape.len())?;
32
33 // Filter out axes where dimension is 1 (no-op reductions)
34 let active_axes: Vec<usize> = axes
35 .iter()
36 .filter(|&&axis| src_shape.get(axis).map(|dim| !matches!(dim, SInt::Const(1))).unwrap_or(false))
37 .copied()
38 .collect();
39
40 // Tinygrad pattern: if no active axes remain, return self
41 // This prevents creating scalar ReduceAxis operations that would
42 // propagate empty shapes through the pipeline
43 if active_axes.is_empty() {
44 return Ok(self.clone());
45 }
46
47 // Create ReduceAxis only for non-trivial reductions
48 let dtype = self.dtype();
49 return Ok(Self::new(Op::ReduceAxis { src: self.clone(), reduce_op, axes: active_axes }, dtype));
50 }
51
52 // If shape is unknown, create ReduceAxis with original axes
53 let dtype = self.dtype();
54 Ok(Self::new(Op::ReduceAxis { src: self.clone(), reduce_op, axes }, dtype))
55 }
56
57 /// Reduce across loop ranges using reduce_op.
58 ///
59 /// Unlike `try_reduce_axis` (operates on tensor axes), this reduces
60 /// values accumulated across RANGE loop iterations.
61 pub fn reduce(self: &Arc<Self>, ranges: SmallVec<[Arc<Self>; 4]>, reduce_op: ReduceOp) -> Arc<Self> {
62 let dtype = self.dtype();
63 Self::new(Op::Reduce { src: self.clone(), ranges, reduce_op }, dtype)
64 }
65
66 /// All-reduce across multiple devices.
67 pub fn allreduce(src: Arc<Self>, device: Arc<Self>, reduce_op: ReduceOp) -> Arc<Self> {
68 let dtype = src.dtype();
69 Self::new(Op::AllReduce { src, device, reduce_op }, dtype)
70 }
71}