use ferray_core::Array;
use ferray_core::creation::{array, linspace, ones, zeros};
use ferray_core::dimension::broadcast::broadcast_shapes;
use ferray_core::dimension::{Ix1, Ix2, IxDyn};
use ferray_core::indexing::basic::SliceSpec;
use ferray_core::manipulation::{flatten, reshape, transpose};
use proptest::prelude::*;
/// Strategy for the length of a 1-D array: any `usize` in the inclusive range `1..=50`.
fn shape_1d() -> impl Strategy<Value = usize> {
    const MIN_LEN: usize = 1;
    const MAX_LEN: usize = 50;
    MIN_LEN..=MAX_LEN
}
/// Strategy for a 2-D shape `(rows, cols)`, each dimension drawn independently from `1..=20`.
fn shape_2d() -> impl Strategy<Value = (usize, usize)> {
    let side = 1usize..=20;
    (side.clone(), side)
}
/// Strategy for a `Vec<f64>` of exactly `len` elements, each drawn from the half-open
/// range `[-100.0, 100.0)`, so NaN and infinities never appear.
fn vec_f64(len: usize) -> impl Strategy<Value = Vec<f64>> {
    let element = -100.0f64..100.0;
    proptest::collection::vec(element, len)
}
/// Deterministic test data `[0.0, 1.0, ..., (n - 1) as f64]`.
///
/// Distinct, ordered values make element-order mistakes (e.g. a reshape or
/// transpose that scrambles data) visible in the assertions below.
fn ramp(n: usize) -> Vec<f64> {
    (0..n).map(|i| i as f64).collect()
}

proptest! {
    #![proptest_config(ProptestConfig::with_cases(256))]

    // Reshaping keeps the element count, yields the requested shape, and
    // (for 2-D -> 1-D) keeps the row-major element order.
    #[test]
    fn prop_reshape_preserves_element_count(
        rows in 1usize..=10,
        cols in 1usize..=10,
    ) {
        let n = rows * cols;
        let a = array(Ix2::new([rows, cols]), ramp(n)).unwrap();
        let orig: Vec<f64> = a.iter().copied().collect();

        let b = reshape(&a, &[n]).unwrap();
        prop_assert_eq!(b.size(), a.size());
        prop_assert_eq!(b.shape(), &[n]);
        let b_data: Vec<f64> = b.iter().copied().collect();
        prop_assert_eq!(orig, b_data);

        if n >= 2 {
            // Smallest factor >= 2; `n` always divides itself, so a prime `n`
            // yields the shape [n, 1].
            let factor = (2..=n).find(|&f| n % f == 0).unwrap_or(n);
            let other = n / factor;
            let c = reshape(&a, &[factor, other]).unwrap();
            prop_assert_eq!(c.size(), a.size());
            prop_assert_eq!(c.shape(), &[factor, other]);
        }
    }

    // Transposing twice restores the original shape and data; a single
    // transpose of a 2-D array swaps its two axes.
    #[test]
    fn prop_transpose_involutory((rows, cols) in shape_2d()) {
        let n = rows * cols;
        let a = array(Ix2::new([rows, cols]), ramp(n)).unwrap();
        let t1 = transpose(&a, None).unwrap();
        prop_assert_eq!(t1.shape(), &[cols, rows]);
        let t2 = transpose(&t1, None).unwrap();
        prop_assert_eq!(t2.shape(), a.shape());
        let orig: Vec<f64> = a.iter().copied().collect();
        let back: Vec<f64> = t2.iter().copied().collect();
        prop_assert_eq!(orig, back);
    }

    // Flattening produces a 1-D array with the same elements in iteration order.
    #[test]
    fn prop_flatten_preserves_elements((rows, cols) in shape_2d()) {
        let n = rows * cols;
        let a = array(Ix2::new([rows, cols]), ramp(n)).unwrap();
        let flat = flatten(&a).unwrap();
        prop_assert_eq!(flat.size(), n);
        prop_assert_eq!(flat.shape(), &[n]);
        let orig: Vec<f64> = a.iter().copied().collect();
        let flat_data: Vec<f64> = flat.iter().copied().collect();
        prop_assert_eq!(orig, flat_data);
    }

    // broadcast_shapes(a, b) and broadcast_shapes(b, a) either both fail or
    // both succeed with the same result.
    #[test]
    fn prop_broadcast_shapes_commutative(
        a_shape in proptest::collection::vec(1usize..=5, 1..=4),
        b_shape in proptest::collection::vec(1usize..=5, 1..=4),
    ) {
        let result_ab = broadcast_shapes(&a_shape, &b_shape);
        let result_ba = broadcast_shapes(&b_shape, &a_shape);
        match (result_ab, result_ba) {
            (Ok(ab), Ok(ba)) => prop_assert_eq!(ab, ba),
            (Err(_), Err(_)) => {}
            _ => prop_assert!(false, "broadcast_shapes not commutative"),
        }
    }

    // `zeros` fills every element with 0.0.
    #[test]
    fn prop_zeros_all_zero(n in shape_1d()) {
        let a = zeros::<f64, Ix1>(Ix1::new([n])).unwrap();
        prop_assert_eq!(a.size(), n);
        for &v in a.iter() {
            prop_assert_eq!(v, 0.0);
        }
    }

    // `ones` fills every element with 1.0.
    #[test]
    fn prop_ones_all_one(n in shape_1d()) {
        let a = ones::<f64, Ix1>(Ix1::new([n])).unwrap();
        prop_assert_eq!(a.size(), n);
        for &v in a.iter() {
            prop_assert_eq!(v, 1.0);
        }
    }

    // Slicing `[0, end)` along axis 0 yields exactly `end` elements.
    // `end` is drawn uniformly from `1..=total`, rather than clamped, so that
    // partial slices are exercised as often as the full-length one.
    #[test]
    fn prop_slice_returns_correct_count(
        (total, end) in (2usize..=50).prop_flat_map(|total| (Just(total), 1usize..=total)),
    ) {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([total]), ramp(total)).unwrap();
        let spec = SliceSpec::new(0, end as isize);
        let view = a.slice_axis(ferray_core::Axis(0), spec).unwrap();
        prop_assert_eq!(view.size(), end);
    }

    // Reshaping to a shape with a different element count is rejected.
    #[test]
    fn prop_reshape_incompatible_shape_fails(
        n in 2usize..=50,
    ) {
        let a = array(Ix1::new([n]), ramp(n)).unwrap();
        let bad_size = n + 1;
        let result = reshape(&a, &[bad_size]);
        prop_assert!(result.is_err());
    }

    // `linspace` returns exactly `num` samples, including `num == 0`.
    #[test]
    fn prop_linspace_count(
        num in 0usize..=100,
        start in -100.0f64..100.0,
        stop in -100.0f64..100.0,
    ) {
        let a = linspace(start, stop, num, true).unwrap();
        prop_assert_eq!(a.size(), num);
    }

    // `zeros` and `ones` both produce exactly the requested 2-D shape.
    #[test]
    fn prop_zeros_ones_same_shape((rows, cols) in shape_2d()) {
        let z = zeros::<f64, Ix2>(Ix2::new([rows, cols])).unwrap();
        let o = ones::<f64, Ix2>(Ix2::new([rows, cols])).unwrap();
        prop_assert_eq!(z.shape(), &[rows, cols]);
        prop_assert_eq!(z.shape(), o.shape());
    }

    // Reshaping 1-D -> 2-D and flattening back recovers the original data.
    #[test]
    fn prop_reshape_flatten_roundtrip(
        rows in 1usize..=10,
        cols in 1usize..=10,
    ) {
        let n = rows * cols;
        let data = ramp(n);
        let a = array(Ix1::new([n]), data.clone()).unwrap();
        let reshaped = reshape(&a, &[rows, cols]).unwrap();
        let flat = flatten(&reshaped).unwrap();
        let flat_data: Vec<f64> = flat.iter().copied().collect();
        prop_assert_eq!(data, flat_data);
    }

    // `from_vec` stores arbitrary finite data unchanged and in order.
    #[test]
    fn prop_from_vec_preserves_data(data in vec_f64(20)) {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([20]), data.clone()).unwrap();
        let stored: Vec<f64> = a.iter().copied().collect();
        prop_assert_eq!(data, stored);
    }

    // A dynamic-rank array keeps the shape, rank and data it was built with.
    #[test]
    fn prop_dynamic_rank_preserves(
        rows in 1usize..=10,
        cols in 1usize..=10,
    ) {
        let n = rows * cols;
        let data = ramp(n);
        let a = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[rows, cols]), data.clone()).unwrap();
        prop_assert_eq!(a.shape(), &[rows, cols]);
        prop_assert_eq!(a.ndim(), 2);
        let stored: Vec<f64> = a.iter().copied().collect();
        prop_assert_eq!(data, stored);
    }
}