use smallvec::smallvec;
use morok_dtype::DType;
use crate::{ConstValue, SInt, UOp, shape::*};
/// Shapes built purely from integer literals are reported as static.
#[test]
fn test_is_static() {
    let rank3: Shape = smallvec![SInt::from(3), SInt::from(4), SInt::from(5)];
    let rank2: Shape = smallvec![SInt::from(3), SInt::from(10)];

    for dims in [&rank3, &rank2] {
        assert!(is_static(dims));
    }
}
/// `to_static` returns `Some` of the plain integer dimensions for shapes
/// built from integer literals.
#[test]
fn test_to_static() {
    let two_dims: Shape = smallvec![SInt::from(3), SInt::from(4)];
    let lowered = to_static(&two_dims);
    assert_eq!(lowered, Some(smallvec![3, 4]));

    let three_dims: Shape = smallvec![SInt::from(5), SInt::from(6), SInt::from(7)];
    let lowered = to_static(&three_dims);
    assert_eq!(lowered, Some(smallvec![5, 6, 7]));
}
/// The number of dimensions of a `Shape` is its length.
#[test]
fn test_ndim() {
    let dims: Shape = smallvec![SInt::from(3), SInt::from(4), SInt::from(5)];
    let rank = dims.len();
    assert_eq!(rank, 3);
}
/// The product of the dimensions `[2, 3, 4]` folds to the constant 24.
#[test]
fn test_shape_product() {
    let dims: Shape = smallvec![SInt::from(2), SInt::from(3), SInt::from(4)];
    let folded = crate::sint_prod(&dims).as_const();
    assert_eq!(folded, Some(24));
}
/// `validate_shape` accepts positive and zero dimensions and rejects negative ones.
#[test]
fn test_validate_shape() {
    // All-positive dimensions are accepted.
    let positive = validate_shape(&[1, 2, 3]);
    assert!(positive.is_ok());

    // A negative dimension is rejected.
    let negative = validate_shape(&[1, -2, 3]);
    assert!(negative.is_err());

    // Zero-sized dimensions are accepted, alone or mixed with others.
    let lone_zero = validate_shape(&[0]);
    assert!(lone_zero.is_ok());
    let inner_zero = validate_shape(&[1, 0, 3]);
    assert!(inner_zero.is_ok());
}
/// `shapes_equal` is true for identical dimensions and false when one differs.
#[test]
fn test_shapes_equal() {
    let reference: Shape = smallvec![SInt::from(3), SInt::from(4)];
    let twin: Shape = smallvec![SInt::from(3), SInt::from(4)];
    let mismatched: Shape = smallvec![SInt::from(3), SInt::from(5)];

    // Same dimensions in the same order compare equal.
    assert!(shapes_equal(&reference, &twin));
    // A differing trailing dimension breaks equality.
    assert!(!shapes_equal(&reference, &mismatched));
}
/// Three identical `[3, 4]` shapes are all considered equal.
#[test]
fn test_all_shapes_equal() {
    // Each call produces a fresh, independent `[3, 4]` shape.
    let make = || -> Shape { smallvec![SInt::from(3), SInt::from(4)] };
    let group = [make(), make(), make()];
    assert!(all_shapes_equal(&group));
}
/// Aligning `[5]` with `[3, 5]` turns the shorter shape into `[1, 5]`.
#[test]
fn test_align_shapes_left() {
    let short: Shape = smallvec![SInt::from(5)];
    let long: Shape = smallvec![SInt::from(3), SInt::from(5)];

    let aligned = align_shapes_left(&[short, long]);
    assert_eq!(aligned.len(), 2);

    // The first (shorter) input gains a leading dimension of 1.
    let padded = &aligned[0];
    assert_eq!(padded.len(), 2);
    assert_eq!(padded[0].as_const(), Some(1));
    assert_eq!(padded[1].as_const(), Some(5));
}
/// `[1, 5]` can broadcast to `[3, 5]` but not to `[3, 4]`.
#[test]
fn test_can_broadcast() {
    let source: Shape = smallvec![SInt::from(1), SInt::from(5)];

    // The size-1 leading dimension stretches; trailing dimensions match.
    let compatible: Shape = smallvec![SInt::from(3), SInt::from(5)];
    assert!(can_broadcast(&source, &compatible));

    // Trailing dimensions 5 and 4 conflict.
    let incompatible: Shape = smallvec![SInt::from(3), SInt::from(4)];
    assert!(!can_broadcast(&source, &incompatible));
}
/// Broadcasting `[1, 5]` with `[3, 5]` yields leading dimensions `3` and `5`.
#[test]
fn test_broadcast_shape() {
    let lhs: Shape = smallvec![SInt::from(1), SInt::from(5)];
    let rhs: Shape = smallvec![SInt::from(3), SInt::from(5)];
    let merged = broadcast_shape(&lhs, &rhs).unwrap();

    // Check the first two axes in order against their expected constants.
    for (axis, expected) in [3, 5].iter().enumerate() {
        assert_eq!(merged[axis].as_const(), Some(*expected));
    }
}
/// Broadcasting `[3, 4]` with `[3, 5]` is reported as an error.
#[test]
fn test_broadcast_shape_error() {
    let lhs: Shape = smallvec![SInt::from(3), SInt::from(4)];
    let rhs: Shape = smallvec![SInt::from(3), SInt::from(5)];

    let outcome = broadcast_shape(&lhs, &rhs);
    assert!(outcome.is_err());
}
/// Broadcasting `[1, 5]`, `[3, 1]` and `[3, 5]` together yields `3` then `5`.
#[test]
fn test_broadcast_shapes_multiple() {
    let inputs: [Shape; 3] = [
        smallvec![SInt::from(1), SInt::from(5)],
        smallvec![SInt::from(3), SInt::from(1)],
        smallvec![SInt::from(3), SInt::from(5)],
    ];
    let merged = broadcast_shapes(&inputs).unwrap();

    let first = merged[0].as_const();
    assert_eq!(first, Some(3));
    let second = merged[1].as_const();
    assert_eq!(second, Some(5));
}
/// A constant created with `UOp::native_const` has an empty (rank-0) shape.
#[test]
fn test_infer_const_shape() {
    let rank = UOp::native_const(42.0f32)
        .shape()
        .unwrap()
        .expect("Const should have shape")
        .len();
    assert_eq!(rank, 0);
}
/// A `VConst` node with four float values and a 4-wide `Float32` vector dtype
/// reports `None` from `shape()`.
#[test]
fn test_infer_vconst_shape() {
    let values = vec![
        ConstValue::Float(1.0),
        ConstValue::Float(2.0),
        ConstValue::Float(3.0),
        ConstValue::Float(4.0),
    ];
    // `values` is not used again, so move it into the op rather than cloning it.
    let vconst = UOp::new(crate::Op::VConst { values }, DType::Float32.vec(4));
    assert!(vconst.shape().unwrap().is_none(), "VConst should return None for shape");
}
/// Negating a constant created with `UOp::native_const` gives an empty (rank-0) shape.
#[test]
fn test_infer_unary_shape() {
    let negated = UOp::native_const(5.0f32).neg();
    let rank = negated.shape().unwrap().expect("Unary should have shape").len();
    assert_eq!(rank, 0);
}
/// Adding two constants created with `UOp::native_const` gives an empty (rank-0) shape.
#[test]
fn test_infer_binary_shape() {
    let (lhs, rhs) = (UOp::native_const(1.0f32), UOp::native_const(2.0f32));
    let sum = lhs.try_add(&rhs).unwrap();
    let rank = sum.shape().unwrap().expect("Binary should have shape").len();
    assert_eq!(rank, 0);
}
/// Casting a `native_const` value to `Int32` still yields an empty (rank-0) shape.
#[test]
fn test_infer_cast_shape() {
    let as_int = UOp::native_const(1.5f32).cast(DType::Int32);
    let dims = as_int.shape().unwrap().expect("Cast should preserve shape");
    let rank = dims.len();
    assert_eq!(rank, 0);
}
/// Two `shape()` calls on the same node return references to the same object.
#[test]
fn test_shape_caching() {
    let node = UOp::native_const(1.0f32);
    let first = node.shape().unwrap().expect("Should have shape");
    let second = node.shape().unwrap().expect("Should have cached shape");

    // Pointer identity, not structural equality, is what is checked here.
    let same_address = std::ptr::eq(first, second);
    assert!(same_address, "Shape should be cached");
}
/// Converting the shape `[3, 4]` produces a `Vectorize` op with two elements.
#[test]
fn test_shape_to_uop_non_empty() {
    use crate::op::Op;

    let dims: Shape = smallvec![SInt::from(3), SInt::from(4)];
    let encoded = shape_to_uop(&dims);

    match encoded.op() {
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 2, "Shape [3, 4] should have 2 elements");
        }
        // Any other op kind means the conversion did not yield a vector node.
        unexpected => panic!("Expected Vectorize, got {:?}", unexpected),
    }
}