Skip to main content

ferray_random/
shape.rs

// ferray-random: IntoShape trait for flexible shape arguments
//
// Mirrors NumPy's convention where `size=` accepts either an int or a
// tuple of ints. Distribution methods take `shape: impl IntoShape` so
// callers can write:
//
//     rng.random(10)           // 1-D, length 10
//     rng.random([3, 4])       // 2-D, shape (3, 4)
//     rng.random(&[2, 3, 4])   // 3-D, shape (2, 3, 4)
//
// See: https://github.com/dollspace-gay/ferray/issues/440
12
13use ferray_core::FerrayError;
14
15/// Convert a size argument into a concrete shape vector.
16///
17/// Implemented for:
18/// - `usize` — 1-D shape `[n]` (including `n == 0` for an empty array)
19/// - `&[usize]` — arbitrary-rank shape
20/// - `[usize; N]` (any N, via const generics)
21/// - `&[usize; N]`
22/// - `Vec<usize>`
23///
24/// Zero-axis shapes (e.g. `0usize`, `[3, 0, 4]`) are now permitted and
25/// produce an empty array, matching NumPy's `np.random.uniform(size=0)`
26/// behaviour (#264, #455). The only rejected shape is a totally
27/// rank-empty `&[]`, which would correspond to a 0-d scalar — that is
28/// not yet wired through the distribution machinery.
29pub trait IntoShape {
30    /// Consume `self` and return the shape as a `Vec<usize>`.
31    ///
32    /// # Errors
33    /// Returns `FerrayError::InvalidValue` if the resulting shape has
34    /// zero rank (i.e. an empty shape slice).
35    fn into_shape(self) -> Result<Vec<usize>, FerrayError>;
36}
37
38fn validate_shape(shape: Vec<usize>) -> Result<Vec<usize>, FerrayError> {
39    if shape.is_empty() {
40        return Err(FerrayError::invalid_value(
41            "shape must have at least one axis",
42        ));
43    }
44    Ok(shape)
45}
46
47impl IntoShape for usize {
48    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
49        validate_shape(vec![self])
50    }
51}
52
53impl IntoShape for &[usize] {
54    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
55        validate_shape(self.to_vec())
56    }
57}
58
59impl<const N: usize> IntoShape for [usize; N] {
60    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
61        validate_shape(self.to_vec())
62    }
63}
64
65impl<const N: usize> IntoShape for &[usize; N] {
66    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
67        validate_shape(self.to_vec())
68    }
69}
70
71impl IntoShape for Vec<usize> {
72    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
73        validate_shape(self)
74    }
75}
76
77impl IntoShape for &Vec<usize> {
78    fn into_shape(self) -> Result<Vec<usize>, FerrayError> {
79        validate_shape(self.clone())
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shape_from_usize() {
        assert_eq!(10usize.into_shape().unwrap(), vec![10]);
    }

    #[test]
    fn shape_from_slice() {
        let s: &[usize] = &[3, 4];
        assert_eq!(s.into_shape().unwrap(), vec![3, 4]);
    }

    #[test]
    fn shape_from_array() {
        assert_eq!([2, 3, 4].into_shape().unwrap(), vec![2, 3, 4]);
    }

    #[test]
    fn shape_from_array_ref() {
        let a = [2, 3, 4];
        assert_eq!((&a).into_shape().unwrap(), vec![2, 3, 4]);
    }

    #[test]
    fn shape_from_vec() {
        assert_eq!(vec![2, 3].into_shape().unwrap(), vec![2, 3]);
    }

    #[test]
    fn shape_from_vec_ref() {
        // Exercises the `&Vec<usize>` impl: the source vector must be
        // left intact for the caller.
        let v: Vec<usize> = vec![5, 6];
        let r: &Vec<usize> = &v;
        assert_eq!(r.into_shape().unwrap(), vec![5, 6]);
        assert_eq!(v, vec![5, 6]);
    }

    #[test]
    fn empty_shape_rejected() {
        let s: &[usize] = &[];
        assert!(s.into_shape().is_err());
    }

    #[test]
    fn empty_shape_rejected_for_owned_inputs() {
        // Rank-0 input must be rejected no matter which impl it
        // arrives through, not just the slice impl.
        assert!(Vec::<usize>::new().into_shape().is_err());

        let no_axes: [usize; 0] = [];
        assert!((&no_axes).into_shape().is_err());
        assert!(no_axes.into_shape().is_err());
    }

    #[test]
    fn zero_axis_allowed() {
        // NumPy allows size=0 / size=(3, 0, 4) and returns an empty
        // array. ferray now matches that (#264, #455).
        assert_eq!(0usize.into_shape().unwrap(), vec![0]);
        assert_eq!([3, 0, 4].into_shape().unwrap(), vec![3, 0, 4]);
    }
}