1use ferray_core::FerrayError;
14
15pub trait IntoShape {
30 fn into_shape(self) -> Result<Vec<usize>, FerrayError>;
36}
37
38fn validate_shape(shape: Vec<usize>) -> Result<Vec<usize>, FerrayError> {
39 if shape.is_empty() {
40 return Err(FerrayError::invalid_value(
41 "shape must have at least one axis",
42 ));
43 }
44 Ok(shape)
45}
46
47impl IntoShape for usize {
48 fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
49 validate_shape(vec![self])
50 }
51}
52
53impl IntoShape for &[usize] {
54 fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
55 validate_shape(self.to_vec())
56 }
57}
58
59impl<const N: usize> IntoShape for [usize; N] {
60 fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
61 validate_shape(self.to_vec())
62 }
63}
64
65impl<const N: usize> IntoShape for &[usize; N] {
66 fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
67 validate_shape(self.to_vec())
68 }
69}
70
71impl IntoShape for Vec<usize> {
72 fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
73 validate_shape(self)
74 }
75}
76
77impl IntoShape for &Vec<usize> {
78 fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
79 validate_shape(self.clone())
80 }
81}
82
#[cfg(test)]
mod tests {
    // Unit tests for the `IntoShape` conversions: one test per implementing
    // type, plus the empty-shape and zero-length-axis edge cases.
    use super::*;

    /// A bare `usize` becomes a one-axis shape.
    #[test]
    fn shape_from_usize() {
        assert_eq!(10usize.into_shape().unwrap(), vec![10]);
    }

    /// A borrowed slice is copied element for element.
    #[test]
    fn shape_from_slice() {
        let s: &[usize] = &[3, 4];
        assert_eq!(s.into_shape().unwrap(), vec![3, 4]);
    }

    /// An array passed by value keeps its element order.
    #[test]
    fn shape_from_array() {
        assert_eq!([2, 3, 4].into_shape().unwrap(), vec![2, 3, 4]);
    }

    /// A reference to an array converts the same way as the array itself.
    #[test]
    fn shape_from_array_ref() {
        let a = [2, 3, 4];
        assert_eq!((&a).into_shape().unwrap(), vec![2, 3, 4]);
    }

    /// An owned vector is accepted as-is.
    #[test]
    fn shape_from_vec() {
        assert_eq!(vec![2, 3].into_shape().unwrap(), vec![2, 3]);
    }

    /// A borrowed vector converts without consuming the original.
    #[test]
    fn shape_from_vec_ref() {
        let v: Vec<usize> = vec![2, 3];
        assert_eq!((&v).into_shape().unwrap(), vec![2, 3]);
        // The source vector must still be usable after the conversion.
        assert_eq!(v, vec![2, 3]);
    }

    /// A shape with no axes at all is an error.
    #[test]
    fn empty_shape_rejected() {
        let s: &[usize] = &[];
        assert!(s.into_shape().is_err());
    }

    /// The empty-shape rule also applies to the other container types.
    #[test]
    fn empty_array_and_vec_rejected() {
        let no_axes: [usize; 0] = [];
        assert!(no_axes.into_shape().is_err());
        assert!((&no_axes).into_shape().is_err());
        assert!(Vec::<usize>::new().into_shape().is_err());
    }

    /// An axis of length zero is distinct from having zero axes: it is valid.
    #[test]
    fn zero_axis_allowed() {
        assert_eq!(0usize.into_shape().unwrap(), vec![0]);
        assert_eq!([3, 0, 4].into_shape().unwrap(), vec![3, 0, 4]);
    }
}