use ferray_core::FerrayError;
/// Conversion of a value into a shape, expressed as a list of axis lengths.
///
/// The implementations in this file cover `usize`, slices, fixed-size arrays
/// and `Vec<usize>` (by value and by reference), and all of them reject a
/// shape with no axes.
pub trait IntoShape {
/// Consumes `self` and returns the axis lengths as an owned `Vec<usize>`.
///
/// # Errors
///
/// Returns a [`FerrayError`] when the value does not describe an acceptable
/// shape; for the implementations in this file that means an empty shape.
fn into_shape(self) -> Result<Vec<usize>, FerrayError>;
}
/// Checks that `shape` has at least one axis and hands it back unchanged.
///
/// Only the number of axes is inspected; the individual axis lengths are not
/// examined, so zero-length axes pass through.
///
/// # Errors
///
/// Returns the error built by `FerrayError::invalid_value` when `shape` has no
/// elements.
fn validate_shape(shape: Vec<usize>) -> Result<Vec<usize>, FerrayError> {
    match shape.len() {
        0 => Err(FerrayError::invalid_value(
            "shape must have at least one axis",
        )),
        _ => Ok(shape),
    }
}
/// A bare `usize` is interpreted as a shape with a single axis of that length.
impl IntoShape for usize {
    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
        let axes = Vec::from([self]);
        validate_shape(axes)
    }
}
/// A borrowed slice is copied into an owned vector and then validated.
impl IntoShape for &[usize] {
    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
        let axes: Vec<usize> = self.to_owned();
        validate_shape(axes)
    }
}
/// A fixed-size array is converted into a vector of the same elements and
/// then validated, so a zero-length array is rejected at run time.
impl<const N: usize> IntoShape for [usize; N] {
    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
        let axes = Vec::from(self);
        validate_shape(axes)
    }
}
/// A borrowed fixed-size array has its elements copied into a new vector,
/// which is then validated.
impl<const N: usize> IntoShape for &[usize; N] {
    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
        let axes: Vec<usize> = self.iter().copied().collect();
        validate_shape(axes)
    }
}
/// An owned vector is validated as-is; no copy is made.
impl IntoShape for Vec<usize> {
fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
// The vector is moved into the validator and returned from it on success.
validate_shape(self)
}
}
/// A borrowed vector is duplicated so the caller keeps its original, and the
/// copy is validated.
impl IntoShape for &Vec<usize> {
    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
        let axes: Vec<usize> = self.to_vec();
        validate_shape(axes)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A bare integer yields a one-axis shape.
    #[test]
    fn shape_from_usize() {
        let shape = 10usize.into_shape().unwrap();
        assert_eq!(shape, vec![10]);
    }

    // A borrowed slice is copied element-for-element.
    #[test]
    fn shape_from_slice() {
        let input: &[usize] = &[3, 4];
        let shape = input.into_shape().unwrap();
        assert_eq!(shape, vec![3, 4]);
    }

    // An array passed by value keeps its element order.
    #[test]
    fn shape_from_array() {
        let input: [usize; 3] = [2, 3, 4];
        let shape = input.into_shape().unwrap();
        assert_eq!(shape, vec![2, 3, 4]);
    }

    // An array passed by reference goes through the `&[usize; N]` impl.
    #[test]
    fn shape_from_array_ref() {
        let by_ref: &[usize; 3] = &[2, 3, 4];
        let shape = by_ref.into_shape().unwrap();
        assert_eq!(shape, vec![2, 3, 4]);
    }

    // An owned vector comes back unchanged.
    #[test]
    fn shape_from_vec() {
        let input: Vec<usize> = vec![2, 3];
        let shape = input.into_shape().unwrap();
        assert_eq!(shape, vec![2, 3]);
    }

    // A shape with no axes is an error.
    #[test]
    fn empty_shape_rejected() {
        let empty: &[usize] = &[];
        let outcome = empty.into_shape();
        assert!(outcome.is_err());
    }

    // Zero-length axes are accepted; only the axis count is checked.
    #[test]
    fn zero_axis_allowed() {
        let single = 0usize.into_shape().unwrap();
        assert_eq!(single, vec![0]);

        let mixed = [3, 0, 4].into_shape().unwrap();
        assert_eq!(mixed, vec![3, 0, 4]);
    }
}