Skip to main content

morok_ir/uop/constructors/
reduce.rs

1//! Reduction operations: reduce, allreduce.
2//!
3//! This module contains reduction and aggregation operations:
4//! - try_reduce_axis: Reduce along specified axes
5//! - reduce: Reduce across loop ranges
6//! - allreduce: All-reduce across multiple devices
7
8use std::sync::Arc;
9
10use smallvec::SmallVec;
11
12use crate::Result;
13use crate::op::Op;
14use crate::types::ReduceOp;
15use crate::uop::UOp;
16
17impl UOp {
18    /// Reduce along specified axes using reduce_op.
19    ///
20    /// Implements Tinygrad's early-return pattern: when all axes are reduced
21    /// or when all reduction axes have dimension 1, returns self instead of
22    /// creating a ReduceAxis operation.
23    ///
24    /// # Errors
25    /// Returns error if any axis is >= number of dimensions.
26    pub fn try_reduce_axis(self: &Arc<Self>, reduce_op: ReduceOp, axes: Vec<usize>) -> Result<Arc<Self>> {
27        use crate::SInt;
28
29        // Validate axes if source shape is known
30        if let Some(src_shape) = self.shape()? {
31            Self::validate_reduce_axes(&axes, src_shape.len())?;
32
33            // Filter out axes where dimension is 1 (no-op reductions)
34            let active_axes: Vec<usize> = axes
35                .iter()
36                .filter(|&&axis| src_shape.get(axis).map(|dim| !matches!(dim, SInt::Const(1))).unwrap_or(false))
37                .copied()
38                .collect();
39
40            // Tinygrad pattern: if no active axes remain, return self
41            // This prevents creating scalar ReduceAxis operations that would
42            // propagate empty shapes through the pipeline
43            if active_axes.is_empty() {
44                return Ok(self.clone());
45            }
46
47            // Create ReduceAxis only for non-trivial reductions
48            let dtype = self.dtype();
49            return Ok(Self::new(Op::ReduceAxis { src: self.clone(), reduce_op, axes: active_axes }, dtype));
50        }
51
52        // If shape is unknown, create ReduceAxis with original axes
53        let dtype = self.dtype();
54        Ok(Self::new(Op::ReduceAxis { src: self.clone(), reduce_op, axes }, dtype))
55    }
56
57    /// Reduce across loop ranges using reduce_op.
58    ///
59    /// Unlike `try_reduce_axis` (operates on tensor axes), this reduces
60    /// values accumulated across RANGE loop iterations.
61    pub fn reduce(self: &Arc<Self>, ranges: SmallVec<[Arc<Self>; 4]>, reduce_op: ReduceOp) -> Arc<Self> {
62        let dtype = self.dtype();
63        Self::new(Op::Reduce { src: self.clone(), ranges, reduce_op }, dtype)
64    }
65
66    /// All-reduce across multiple devices.
67    pub fn allreduce(src: Arc<Self>, device: Arc<Self>, reduce_op: ReduceOp) -> Arc<Self> {
68        let dtype = src.dtype();
69        Self::new(Op::AllReduce { src, device, reduce_op }, dtype)
70    }
71}