use ferray_core::error::{FerrayError, FerrayResult};
/// Converts a possibly-negative axis index into a concrete axis in `0..ndim`.
///
/// Negative values count from the end, so `-1` is the last axis and `-ndim`
/// is the first.
///
/// # Errors
///
/// Returns `FerrayError::axis_out_of_bounds` when `axis` does not name an
/// existing axis. This includes every `axis` when `ndim == 0`. The reported
/// axis is the magnitude of the requested value. For example, `-3` is
/// reported as `3` rather than as its `usize` bit-reinterpretation.
#[inline]
pub(crate) fn normalize_axis(ndim: usize, axis: isize) -> FerrayResult<usize> {
    // Stay in `usize`. Casting `ndim` to `isize` could wrap, and casting a
    // negative `axis` to `usize` produces a huge, meaningless number.
    let magnitude = axis.unsigned_abs();
    let normalized = if axis < 0 {
        // `-k` maps to `ndim - k`; there is no such axis when `k > ndim`.
        ndim.checked_sub(magnitude)
    } else {
        Some(magnitude)
    };
    match normalized {
        Some(ax) if ax < ndim => Ok(ax),
        // The error constructor takes a `usize`, so the sign cannot be
        // conveyed. Report the magnitude, which is still out of range.
        _ => Err(FerrayError::axis_out_of_bounds(magnitude, ndim)),
    }
}
/// Picks the single axis a 1-D transform should run along.
///
/// An explicit `axis` is validated by `normalize_axis`, where negative values
/// count from the end. When `axis` is `None`, the last axis (`ndim - 1`) is
/// chosen.
///
/// # Errors
///
/// Fails in two cases:
/// - an explicit axis is out of range;
/// - `axis` is `None` and the array is 0-dimensional, so it has no last axis.
#[inline]
pub(crate) fn resolve_axis(ndim: usize, axis: Option<isize>) -> FerrayResult<usize> {
    if let Some(requested) = axis {
        return normalize_axis(ndim, requested);
    }
    // Default to the trailing axis; a 0-d array has none to offer.
    ndim.checked_sub(1).ok_or_else(|| {
        FerrayError::invalid_value("cannot compute FFT on a 0-dimensional array")
    })
}
/// Resolves the list of axes a multi-dimensional transform should run along.
///
/// With `Some(axes)`, each entry is normalized independently by
/// `normalize_axis`, keeping the given order and any repeats. With `None`,
/// every axis `0..ndim` is returned in ascending order.
///
/// # Errors
///
/// Returns the error for the first entry that is out of range.
#[inline]
pub(crate) fn resolve_axes(ndim: usize, axes: Option<&[isize]>) -> FerrayResult<Vec<usize>> {
    if let Some(requested) = axes {
        let mut resolved = Vec::with_capacity(requested.len());
        for &ax in requested {
            // `?` stops at the first invalid axis.
            resolved.push(normalize_axis(ndim, ax)?);
        }
        Ok(resolved)
    } else {
        Ok((0..ndim).collect())
    }
}
/// Computes row-major (C-order) strides, in elements, for a contiguous array
/// of the given `shape`.
///
/// The innermost stride is `1`. Each outer stride is the next-inner stride
/// times the next-inner extent. The outermost extent (`shape[0]`) never
/// contributes. An empty `shape` yields an empty vector.
///
/// Extents are converted with `as isize`, and the running product uses plain
/// `isize` multiplication. That multiplication panics on overflow in debug
/// builds and wraps in release builds.
#[inline]
pub(crate) fn compute_strides(shape: &[usize]) -> Vec<isize> {
    let mut strides = Vec::with_capacity(shape.len());
    if shape.is_empty() {
        return strides;
    }
    // Build innermost-first, then flip into outer-to-inner order.
    let mut running: isize = 1;
    strides.push(running);
    for &extent in shape[1..].iter().rev() {
        running *= extent as isize;
        strides.push(running);
    }
    strides.reverse();
    strides
}
#[cfg(test)]
mod tests {
    // Unit tests for the axis-resolution and stride helpers in this module.
    use super::*;

    #[test]
    fn resolve_axis_defaults_to_last() {
        assert_eq!(resolve_axis(3, None).unwrap(), 2);
        assert_eq!(resolve_axis(1, None).unwrap(), 0);
    }

    #[test]
    fn resolve_axis_zero_dim_errors() {
        assert!(resolve_axis(0, None).is_err());
    }

    #[test]
    fn resolve_axis_out_of_bounds() {
        assert!(resolve_axis(2, Some(3)).is_err());
        assert!(resolve_axis(2, Some(-3)).is_err());
    }

    #[test]
    fn resolve_axis_negative() {
        assert_eq!(resolve_axis(3, Some(-1)).unwrap(), 2);
        assert_eq!(resolve_axis(3, Some(-2)).unwrap(), 1);
        assert_eq!(resolve_axis(3, Some(-3)).unwrap(), 0);
    }

    #[test]
    fn resolve_axes_defaults_to_all() {
        assert_eq!(resolve_axes(3, None).unwrap(), vec![0, 1, 2]);
    }

    #[test]
    fn resolve_axes_validates_each() {
        assert!(resolve_axes(3, Some(&[0, 5])).is_err());
        assert_eq!(resolve_axes(3, Some(&[0, 2])).unwrap(), vec![0, 2]);
    }

    #[test]
    fn resolve_axes_negative_mix() {
        assert_eq!(resolve_axes(4, Some(&[-2, -1])).unwrap(), vec![2, 3]);
        assert_eq!(resolve_axes(4, Some(&[0, -1])).unwrap(), vec![0, 3]);
    }

    #[test]
    fn resolve_axes_empty_list_is_empty() {
        // An explicit empty axis list is accepted and resolves to nothing.
        assert_eq!(resolve_axes(3, Some(&[])).unwrap(), Vec::<usize>::new());
    }

    #[test]
    fn normalize_axis_round_trip() {
        for ndim in 1..=4usize {
            for ax in 0..ndim {
                assert_eq!(normalize_axis(ndim, ax as isize).unwrap(), ax);
                let neg = (ax as isize) - (ndim as isize);
                assert_eq!(normalize_axis(ndim, neg).unwrap(), ax);
            }
        }
    }

    #[test]
    fn normalize_axis_zero_dim_rejects_everything() {
        // A 0-d array has no axes, so neither 0 nor -1 is valid.
        assert!(normalize_axis(0, 0).is_err());
        assert!(normalize_axis(0, -1).is_err());
    }

    #[test]
    fn normalize_axis_extreme_values_error() {
        // Must reject, not overflow or wrap, at the ends of the isize range.
        assert!(normalize_axis(3, isize::MIN).is_err());
        assert!(normalize_axis(3, isize::MAX).is_err());
        assert!(normalize_axis(3, -4).is_err());
        assert!(normalize_axis(3, 3).is_err());
    }

    #[test]
    fn strides_1d() {
        assert_eq!(compute_strides(&[8]), vec![1]);
    }

    #[test]
    fn strides_2d() {
        assert_eq!(compute_strides(&[3, 4]), vec![4, 1]);
    }

    #[test]
    fn strides_3d() {
        assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    #[test]
    fn strides_empty() {
        assert_eq!(compute_strides(&[]), Vec::<isize>::new());
    }

    #[test]
    fn strides_zero_extent() {
        // A zero-length inner axis zeroes every stride outside it.
        assert_eq!(compute_strides(&[5, 0, 2]), vec![0, 2, 1]);
    }

    #[test]
    fn strides_unit_extents() {
        assert_eq!(compute_strides(&[1, 1, 1]), vec![1, 1, 1]);
    }
}