use std::f64::consts::PI;
use smallvec::smallvec;
use morok_dtype::DType;
use crate::{SInt, UOp, error::Error, shape::Shape};
/// Reshaping a scalar constant into a `[1, 1]` shape succeeds and the
/// result keeps the constant's `Float32` dtype.
#[test]
fn test_reshape_basic() {
    let scalar = UOp::native_const(1.0f32);
    let target: Shape = smallvec![SInt::from(1), SInt::from(1)];
    let reshaped = scalar.try_reshape(&target).unwrap();
    assert_eq!(reshaped.dtype(), DType::Float32);
}
/// Reshape must refuse a target whose element count (2 * 3 = 6) differs
/// from the scalar input's element count (1), reporting both sizes.
#[test]
fn test_reshape_size_must_match() {
    let mismatched: Shape = smallvec![SInt::from(2), SInt::from(3)];
    let outcome = UOp::native_const(1.0f32).try_reshape(&mismatched);
    assert!(matches!(
        outcome,
        Err(Error::ReshapeSizeMismatch { input_size: 1, output_size: 6 })
    ));
}
/// An empty permutation is accepted on a scalar and leaves its dtype unchanged.
#[test]
fn test_permute_empty_on_scalar() {
    let permuted = UOp::native_const(1.0f32).try_permute(vec![]).unwrap();
    assert_eq!(permuted.dtype(), DType::Float32);
}
/// A permutation that names axes 0 and 1 is rejected on a scalar, which has
/// no axes to permute.
#[test]
fn test_permute_invalid_on_scalar() {
    let outcome = UOp::native_const(1.0f32).try_permute(vec![0, 1]);
    assert!(matches!(outcome, Err(Error::PermuteInvalidPermutation { .. })));
}
/// Permuting with a repeated axis index must be rejected.
///
/// A non-empty permutation on a scalar is already rejected because the
/// scalar has no axes (see `test_permute_invalid_on_scalar`). To isolate the
/// duplicate-index check, this test first gives the value two axes with a
/// `[1, 1]` reshape. Both indices in `[0, 0]` are then in range, and only the
/// repetition makes the permutation invalid.
#[test]
fn test_permute_duplicate_index() {
    let two_axes: Shape = smallvec![SInt::from(1), SInt::from(1)];
    let val = UOp::native_const(1.0f32).try_reshape(&two_axes).unwrap();
    let dup_perm = vec![0, 0];
    let result = val.try_permute(dup_perm);
    // NOTE(review): this assumes duplicates are reported through the same
    // variant as other invalid permutations. Confirm against `try_permute`.
    assert!(matches!(result, Err(Error::PermuteInvalidPermutation { .. })));
}
/// Expand requires matching rank: a 0-d scalar cannot be expanded straight
/// to a 2-d `[3, 5]` shape.
#[test]
fn test_expand_dimension_mismatch() {
    let target: Shape = smallvec![SInt::from(3), SInt::from(5)];
    let outcome = UOp::native_const(1.0f32).try_expand(&target);
    assert!(matches!(
        outcome,
        Err(Error::ExpandDimensionMismatch { input_dims: 0, output_dims: 2 })
    ));
}
/// Supplying two padding pairs for a 0-d scalar is rejected, and the error
/// reports both the padding length and the shape's rank.
#[test]
fn test_pad_dimension_mismatch() {
    let scalar = UOp::native_const(1.0f32);
    let pairs = vec![
        (SInt::from(0), SInt::from(0)),
        (SInt::from(1), SInt::from(1)),
    ];
    let outcome = scalar.try_pad(&pairs);
    assert!(matches!(
        outcome,
        Err(Error::PadDimensionMismatch { padding_dims: 2, shape_dims: 0 })
    ));
}
/// Empty padding is accepted on a scalar and leaves its dtype unchanged.
#[test]
fn test_pad_empty_on_scalar() {
    let no_padding = vec![];
    let padded = UOp::native_const(1.0f32).try_pad(&no_padding).unwrap();
    assert_eq!(padded.dtype(), DType::Float32);
}
/// An empty list of shrink ranges is accepted on a scalar and leaves its
/// dtype unchanged.
#[test]
fn test_shrink_empty_on_scalar() {
    let no_ranges = vec![];
    let shrunk = UOp::native_const(1.0f32).try_shrink(&no_ranges).unwrap();
    assert_eq!(shrunk.dtype(), DType::Float32);
}
/// A two-entry flip spec does not fit a 0-d scalar, and the error reports the
/// expected and received lengths.
#[test]
fn test_flip_dimension_mismatch() {
    let outcome = UOp::native_const(1.0f32).try_flip(vec![true, false]);
    assert!(matches!(
        outcome,
        Err(Error::FlipInvalidSpec { expected_dims: 0, got_dims: 2 })
    ));
}
/// An empty flip spec is accepted on a scalar and leaves its dtype unchanged.
#[test]
fn test_flip_empty_on_scalar() {
    let flipped = UOp::native_const(1.0f32).try_flip(vec![]).unwrap();
    assert_eq!(flipped.dtype(), DType::Float32);
}
/// Wrapping a `Float32` constant with `UOp::multi` (second argument `0`)
/// keeps the constant's dtype.
#[test]
fn test_multi_basic() {
    let wrapped = UOp::multi(UOp::native_const(1.0f32), 0);
    assert_eq!(wrapped.dtype(), DType::Float32);
}
/// Reshape, permute and `multi` each carry the input dtype through unchanged.
/// The check uses an integer, a float and a boolean constant.
#[test]
fn test_movement_ops_preserve_dtype() {
    // Int64 through a reshape to a single-element 1-d shape.
    let one_elem: Shape = smallvec![SInt::from(1)];
    let reshaped = UOp::native_const(42i64).try_reshape(&one_elem).unwrap();
    assert_eq!(reshaped.dtype(), DType::Int64);

    // Float64 (`PI` is an `f64`) through the empty permutation.
    let permuted = UOp::native_const(PI).try_permute(vec![]).unwrap();
    assert_eq!(permuted.dtype(), DType::Float64);

    // Bool through `multi`.
    let wrapped = UOp::multi(UOp::native_const(true), 0);
    assert_eq!(wrapped.dtype(), DType::Bool);
}