use std::f64::consts::PI;
use morok_dtype::DType;
use crate::{ReduceOp, UOp};
#[test]
fn test_reduce_axis_basic() {
let result = UOp::native_const(1.0f32).try_reduce_axis(ReduceOp::Add, vec![]).unwrap();
assert_eq!(result.dtype(), DType::Float32);
}
#[test]
fn test_reduce_axis_preserves_dtype() {
let val_int = UOp::native_const(42i32);
let result = val_int.try_reduce_axis(ReduceOp::Add, vec![]).unwrap();
assert_eq!(result.dtype(), DType::Int32);
let val_float = UOp::native_const(PI);
let result = val_float.try_reduce_axis(ReduceOp::Max, vec![]).unwrap();
assert_eq!(result.dtype(), DType::Float64);
}
#[test]
fn test_reduce_ops() {
let val = UOp::native_const(1.0f32);
let add = val.try_reduce_axis(ReduceOp::Add, vec![]).unwrap();
assert_eq!(add.dtype(), DType::Float32);
let max = val.try_reduce_axis(ReduceOp::Max, vec![]).unwrap();
assert_eq!(max.dtype(), DType::Float32);
let mul = val.try_reduce_axis(ReduceOp::Mul, vec![]).unwrap();
assert_eq!(mul.dtype(), DType::Float32);
let min = val.try_reduce_axis(ReduceOp::Min, vec![]).unwrap();
assert_eq!(min.dtype(), DType::Float32);
}
#[test]
fn test_reduce_axis_full_reduction_returns_self() {
use crate::SInt;
use crate::shape::shape_to_uop;
use smallvec::smallvec;
let shape = smallvec![SInt::Const(2), SInt::Const(3)];
let src = UOp::native_const(1.0f32);
let shaped = UOp::reshape(src, shape_to_uop(&shape));
let result = shaped.try_reduce_axis(ReduceOp::Add, vec![0, 1]).unwrap();
assert_eq!(result.dtype(), shaped.dtype());
}
#[test]
fn test_reduce_axis_size_one_dims_filtered() {
use crate::SInt;
use crate::shape::shape_to_uop;
use smallvec::smallvec;
let shape = smallvec![SInt::Const(1), SInt::Const(3), SInt::Const(1), SInt::Const(4)];
let src = UOp::native_const(1.0f32);
let shaped = UOp::reshape(src, shape_to_uop(&shape));
let result = shaped.try_reduce_axis(ReduceOp::Add, vec![0, 2]).unwrap();
assert_eq!(result.dtype(), shaped.dtype());
}
#[test]
fn test_reduce_axis_mixed_size_dims() {
use crate::SInt;
use crate::op::Op;
use crate::shape::shape_to_uop;
use smallvec::smallvec;
let shape = smallvec![SInt::Const(1), SInt::Const(3), SInt::Const(1), SInt::Const(4)];
let src = UOp::native_const(1.0f32);
let shaped = UOp::reshape(src, shape_to_uop(&shape));
let result = shaped.try_reduce_axis(ReduceOp::Add, vec![0, 1, 2]).unwrap();
if let Op::ReduceAxis { axes, .. } = result.op() {
assert_eq!(axes.len(), 1);
assert_eq!(axes[0], 1);
} else {
panic!("Expected ReduceAxis operation");
}
}