use std::sync::Arc;
use smallvec::SmallVec;
use crate::Result;
use crate::op::Op;
use crate::types::ReduceOp;
use crate::uop::UOp;
impl UOp {
    /// Builds an [`Op::ReduceAxis`] node that reduces `self` over `axes` with `reduce_op`.
    ///
    /// When the shape of `self` is available:
    /// - `axes` is checked with `validate_reduce_axes` against the shape's length;
    /// - axes whose dimension is the constant `1` (or that have no matching
    ///   dimension) are dropped from the list;
    /// - if no axes survive, `self` is handed back as-is (a clone of the `Arc`)
    ///   and no new node is created.
    ///
    /// When `shape()` yields `None`, `axes` is used exactly as supplied, without
    /// validation or filtering.
    ///
    /// The resulting node carries the same dtype as `self`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `shape()` or `validate_reduce_axes`.
    pub fn try_reduce_axis(
        self: &Arc<Self>,
        reduce_op: ReduceOp,
        axes: Vec<usize>,
    ) -> Result<Arc<Self>> {
        use crate::SInt;

        // Resolve the final axis list first so the node is built in one place.
        let axes = match self.shape()? {
            // Shape unknown: nothing to validate or filter against.
            None => axes,
            Some(src_shape) => {
                Self::validate_reduce_axes(&axes, src_shape.len())?;

                // Keep only the axes whose dimension is not the constant 1.
                let mut kept = axes;
                kept.retain(|&axis| match src_shape.get(axis) {
                    Some(dim) => !matches!(dim, SInt::Const(1)),
                    None => false,
                });

                if kept.is_empty() {
                    return Ok(self.clone());
                }
                kept
            }
        };

        let dtype = self.dtype();
        Ok(Self::new(Op::ReduceAxis { src: self.clone(), reduce_op, axes }, dtype))
    }

    /// Builds an [`Op::Reduce`] node with `self` as the source, the given
    /// `ranges`, and `reduce_op`. The node carries the same dtype as `self`.
    pub fn reduce(
        self: &Arc<Self>,
        ranges: SmallVec<[Arc<Self>; 4]>,
        reduce_op: ReduceOp,
    ) -> Arc<Self> {
        let out_dtype = self.dtype();
        let op = Op::Reduce { src: self.clone(), ranges, reduce_op };
        Self::new(op, out_dtype)
    }

    /// Builds an [`Op::AllReduce`] node from `src`, `device`, and `reduce_op`.
    /// The node carries the same dtype as `src`.
    pub fn allreduce(src: Arc<Self>, device: Arc<Self>, reduce_op: ReduceOp) -> Arc<Self> {
        // Read the dtype before `src` is moved into the op.
        let out_dtype = src.dtype();
        let op = Op::AllReduce { src, device, reduce_op };
        Self::new(op, out_dtype)
    }
}