Skip to main content

scirs2_autograd/
ndarray_ext.rs

//! A small extension of [ndarray](https://github.com/rust-ndarray/ndarray).
//!
//! Mainly provides `array_gen`, a collection of array generator functions.
4use crate::error::OpResult;
5use crate::error_helpers::try_from_numeric;
6use crate::ndarray;
7
8use crate::Float;
9
10/// alias for `scirs2_core::ndarray::Array<T, IxDyn>`
11pub type NdArray<T> = scirs2_core::ndarray::Array<T, scirs2_core::ndarray::IxDyn>;
12
13/// alias for `scirs2_core::ndarray::ArrayView<T, IxDyn>`
14pub type NdArrayView<'a, T> = scirs2_core::ndarray::ArrayView<'a, T, scirs2_core::ndarray::IxDyn>;
15
16/// alias for `scirs2_core::ndarray::RawArrayView<T, IxDyn>`
17pub type RawNdArrayView<T> = scirs2_core::ndarray::RawArrayView<T, scirs2_core::ndarray::IxDyn>;
18
19/// alias for `scirs2_core::ndarray::RawArrayViewMut<T, IxDyn>`
20pub type RawNdArrayViewMut<T> =
21    scirs2_core::ndarray::RawArrayViewMut<T, scirs2_core::ndarray::IxDyn>;
22
23/// alias for `scirs2_core::ndarray::ArrayViewMut<T, IxDyn>`
24pub type NdArrayViewMut<'a, T> =
25    scirs2_core::ndarray::ArrayViewMut<'a, T, scirs2_core::ndarray::IxDyn>;
26
27#[inline]
28/// This works well only for small arrays
29pub(crate) fn asshape<T: Float>(x: &NdArrayView<T>) -> Vec<usize> {
30    x.iter().map(|a| a.to_usize().unwrap_or(0)).collect()
31}
32
33#[inline]
34pub(crate) fn expand_dims<T: Float>(x: NdArray<T>, axis: usize) -> NdArray<T> {
35    let mut shape = x.shape().to_vec();
36    shape.insert(axis, 1);
37    x.into_shape_with_order(shape)
38        .expect("Shape conversion failed - this is a bug")
39}
40
41#[inline]
42pub(crate) fn roll_axis<T: Float>(
43    arg: &mut NdArray<T>,
44    to: scirs2_core::ndarray::Axis,
45    from: scirs2_core::ndarray::Axis,
46) {
47    let i = to.index();
48    let mut j = from.index();
49    if j > i {
50        while i != j {
51            arg.swap_axes(i, j);
52            j -= 1;
53        }
54    } else {
55        while i != j {
56            arg.swap_axes(i, j);
57            j += 1;
58        }
59    }
60}
61
#[inline]
/// Converts a possibly negative axis index into a non-negative one.
///
/// Non-negative `axis` values are returned unchanged. Negative values count
/// from the end, so `-1` maps to `ndim - 1`. No range check is performed.
pub(crate) fn normalize_negative_axis(axis: isize, ndim: usize) -> usize {
    match axis {
        non_negative if non_negative >= 0 => non_negative as usize,
        negative => (ndim as isize + negative) as usize,
    }
}
70
71#[inline]
72pub(crate) fn normalize_negative_axes<T: Float>(axes: &NdArrayView<T>, ndim: usize) -> Vec<usize> {
73    let mut axes_ret: Vec<usize> = Vec::with_capacity(axes.len());
74    for &axis in axes.iter() {
75        let axis = if axis < T::zero() {
76            (T::from(ndim).unwrap_or_else(|| T::zero()) + axis)
77                .to_usize()
78                .unwrap_or(0)
79        } else {
80            axis.to_usize().unwrap_or(0)
81        };
82        axes_ret.push(axis);
83    }
84    axes_ret
85}
86
87#[inline]
88pub(crate) fn sparse_to_dense<T: Float>(arr: &NdArrayView<T>) -> Vec<usize> {
89    let mut axes: Vec<usize> = vec![];
90    for (i, &a) in arr.iter().enumerate() {
91        if a == T::one() {
92            axes.push(i);
93        }
94    }
95    axes
96}
97
98#[allow(unused)]
99#[inline]
100pub(crate) fn is_fully_transposed(strides: &[scirs2_core::ndarray::Ixs]) -> bool {
101    let mut ret = true;
102    for w in strides.windows(2) {
103        if w[0] > w[1] {
104            ret = false;
105            break;
106        }
107    }
108    ret
109}
110
111/// Creates a zero array in the specified shape.
112#[inline]
113#[allow(dead_code)]
114pub fn zeros<T: Float>(shape: &[usize]) -> NdArray<T> {
115    NdArray::<T>::zeros(shape)
116}
117
118/// Creates a one array in the specified shape.
119#[inline]
120#[allow(dead_code)]
121pub fn ones<T: Float>(shape: &[usize]) -> NdArray<T> {
122    NdArray::<T>::ones(shape)
123}
124
125/// Creates a constant array in the specified shape.
126#[inline]
127#[allow(dead_code)]
128pub fn constant<T: Float>(value: T, shape: &[usize]) -> NdArray<T> {
129    NdArray::<T>::from_elem(shape, value)
130}
131
132use scirs2_core::random::{ChaCha8Rng, Rng, RngExt, SeedableRng, TryRng};
133
134/// Random number generator for ndarray
135///
136/// Uses ChaCha8Rng internally because StdRng does not implement Clone in rand 0.10.
137#[derive(Clone)]
138pub struct ArrayRng<A> {
139    rng: ChaCha8Rng,
140    _phantom: std::marker::PhantomData<A>,
141}
142
143// Implement TryRng for ArrayRng by delegating to the internal ChaCha8Rng.
144// In rand_core 0.10, TryRng<Error=Infallible> auto-provides Rng and (deprecated) RngCore.
145impl<A> TryRng for ArrayRng<A> {
146    type Error = std::convert::Infallible;
147
148    fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
149        Ok(self.rng.next_u32())
150    }
151
152    fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
153        Ok(self.rng.next_u64())
154    }
155
156    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
157        self.rng.fill_bytes(dest);
158        Ok(())
159    }
160}
161
162impl<A: Float> ArrayRng<A> {
163    /// Creates a new random number generator with the default seed.
164    pub fn new() -> Self {
165        Self::from_seed(0)
166    }
167
168    /// Creates a new random number generator with the specified seed.
169    pub fn from_seed(seed: u64) -> Self {
170        let rng = ChaCha8Rng::seed_from_u64(seed);
171        Self {
172            rng,
173            _phantom: std::marker::PhantomData,
174        }
175    }
176
177    /// Returns a reference to the internal RNG
178    pub fn as_rng(&self) -> &ChaCha8Rng {
179        &self.rng
180    }
181
182    /// Returns a mutable reference to the internal RNG
183    pub fn as_rng_mut(&mut self) -> &mut ChaCha8Rng {
184        &mut self.rng
185    }
186
187    /// Creates a uniform random array in the specified shape.
188    /// Values are in the range [0, 1)
189    pub fn random(&mut self, shape: &[usize]) -> NdArray<A> {
190        let len = shape.iter().product();
191        let mut data = Vec::with_capacity(len);
192        for _ in 0..len {
193            data.push(
194                A::from(self.rng.random::<f64>()).expect("Shape conversion failed - this is a bug"),
195            );
196        }
197        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
198            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
199    }
200
201    /// Creates a normal random array in the specified shape.
202    /// Values are drawn from a normal distribution with the specified mean and standard deviation.
203    pub fn normal(&mut self, shape: &[usize], mean: f64, std: f64) -> NdArray<A> {
204        use scirs2_core::random::{Distribution, Normal};
205        let normal = Normal::new(mean, std)
206            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
207        let len = shape.iter().product();
208        let mut data = Vec::with_capacity(len);
209        for _ in 0..len {
210            data.push(
211                A::from(normal.sample(&mut self.rng))
212                    .expect("Shape conversion failed - this is a bug"),
213            );
214        }
215        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
216            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
217    }
218
219    /// Creates a uniform random array in the specified shape.
220    /// Values are in the range [low, high).
221    pub fn uniform(&mut self, shape: &[usize], low: f64, high: f64) -> NdArray<A> {
222        use scirs2_core::random::{Distribution, Uniform};
223        let uniform = Uniform::new(low, high)
224            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
225        let len = shape.iter().product();
226        let mut data = Vec::with_capacity(len);
227        for _ in 0..len {
228            data.push(
229                A::from(uniform.sample(&mut self.rng))
230                    .expect("Shape conversion failed - this is a bug"),
231            );
232        }
233        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
234            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
235    }
236
237    /// Creates a random array with Glorot/Xavier uniform initialization.
238    /// For a tensor with shape (in_features, out_features),
239    /// samples are drawn from Uniform(-sqrt(6/(in_features+out_features)), sqrt(6/(in_features+out_features))).
240    pub fn glorot_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
241        assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
242        let fan_in = shape[shape.len() - 2];
243        let fan_out = shape[shape.len() - 1];
244        let scale = (6.0 / (fan_in + fan_out) as f64).sqrt();
245        self.uniform(shape, -scale, scale)
246    }
247
248    /// Creates a random array with Glorot/Xavier normal initialization.
249    /// For a tensor with shape (in_features, out_features),
250    /// samples are drawn from Normal(0, sqrt(2/(in_features+out_features))).
251    pub fn glorot_normal(&mut self, shape: &[usize]) -> NdArray<A> {
252        assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
253        let fan_in = shape[shape.len() - 2];
254        let fan_out = shape[shape.len() - 1];
255        let scale = (2.0 / (fan_in + fan_out) as f64).sqrt();
256        self.normal(shape, 0.0, scale)
257    }
258
259    /// Creates a random array with He/Kaiming uniform initialization.
260    /// For a tensor with shape (in_features, out_features),
261    /// samples are drawn from Uniform(-sqrt(6/in_features), sqrt(6/in_features)).
262    pub fn he_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
263        assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
264        let fan_in = shape[shape.len() - 2];
265        let scale = (6.0 / fan_in as f64).sqrt();
266        self.uniform(shape, -scale, scale)
267    }
268
269    /// Creates a random array with He/Kaiming normal initialization.
270    /// For a tensor with shape (in_features, out_features),
271    /// samples are drawn from Normal(0, sqrt(2/in_features)).
272    pub fn he_normal(&mut self, shape: &[usize]) -> NdArray<A> {
273        assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
274        let fan_in = shape[shape.len() - 2];
275        let scale = (2.0 / fan_in as f64).sqrt();
276        self.normal(shape, 0.0, scale)
277    }
278
279    /// Creates a random array from the standard normal distribution.
280    pub fn standard_normal(&mut self, shape: &[usize]) -> NdArray<A> {
281        self.normal(shape, 0.0, 1.0)
282    }
283
284    /// Creates a random array from the standard uniform distribution.
285    pub fn standard_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
286        self.uniform(shape, 0.0, 1.0)
287    }
288
289    /// Creates a random array from the bernoulli distribution.
290    pub fn bernoulli(&mut self, shape: &[usize], p: f64) -> NdArray<A> {
291        use scirs2_core::random::{Bernoulli, Distribution};
292        let bernoulli =
293            Bernoulli::new(p).unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
294        let len = shape.iter().product();
295        let mut data = Vec::with_capacity(len);
296        for _ in 0..len {
297            let val = if bernoulli.sample(&mut self.rng) {
298                A::one()
299            } else {
300                A::zero()
301            };
302            data.push(val);
303        }
304        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
305            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
306    }
307
308    /// Creates a random array from the exponential distribution.
309    pub fn exponential(&mut self, shape: &[usize], lambda: f64) -> NdArray<A> {
310        use scirs2_core::random::{Distribution, Exp};
311        let exp =
312            Exp::new(lambda).unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
313        let len = shape.iter().product();
314        let mut data = Vec::with_capacity(len);
315        for _ in 0..len {
316            data.push(
317                A::from(exp.sample(&mut self.rng))
318                    .expect("Shape conversion failed - this is a bug"),
319            );
320        }
321        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
322            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
323    }
324
325    /// Creates a random array from the log-normal distribution.
326    pub fn log_normal(&mut self, shape: &[usize], mean: f64, stddev: f64) -> NdArray<A> {
327        use scirs2_core::random::{Distribution, LogNormal};
328        let log_normal = LogNormal::new(mean, stddev)
329            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
330        let len = shape.iter().product();
331        let mut data = Vec::with_capacity(len);
332        for _ in 0..len {
333            data.push(
334                A::from(log_normal.sample(&mut self.rng))
335                    .expect("Shape conversion failed - this is a bug"),
336            );
337        }
338        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
339            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
340    }
341
342    /// Creates a random array from the gamma distribution.
343    pub fn gamma(&mut self, shape: &[usize], shape_param: f64, scale: f64) -> NdArray<A> {
344        use scirs2_core::random::{Distribution, Gamma};
345        let gamma = Gamma::new(shape_param, scale)
346            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
347        let len = shape.iter().product();
348        let mut data = Vec::with_capacity(len);
349        for _ in 0..len {
350            data.push(
351                A::from(gamma.sample(&mut self.rng))
352                    .expect("Shape conversion failed - this is a bug"),
353            );
354        }
355        NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
356            .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
357    }
358}
359
360impl<A: Float> Default for ArrayRng<A> {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
/// Returns `true` if `shape` describes a scalar: either the empty (0-D)
/// shape or the one-element shape `[1]`.
#[inline]
#[allow(dead_code)]
pub fn is_scalarshape(shape: &[usize]) -> bool {
    matches!(shape, [] | [1])
}
372
/// Returns the shape of a 0-D (scalar) array, which has no dimensions.
#[inline]
#[allow(dead_code)]
pub fn scalarshape() -> Vec<usize> {
    Vec::new()
}
379
380/// Create an array from a scalar value
381#[inline]
382#[allow(dead_code)]
383pub fn from_scalar<T: Float>(value: T) -> NdArray<T> {
384    NdArray::<T>::from_elem(scirs2_core::ndarray::IxDyn(&[1]), value)
385}
386
387/// Get shape of an ndarray view
388#[inline]
389#[allow(dead_code)]
390pub fn shape_of_view<T>(view: &NdArrayView<'_, T>) -> Vec<usize> {
391    view.shape().to_vec()
392}
393
394/// Get shape of an ndarray
395#[inline]
396#[allow(dead_code)]
397pub fn shape_of<T>(array: &NdArray<T>) -> Vec<usize> {
398    array.shape().to_vec()
399}
400
401/// Get default random number generator
402#[inline]
403#[allow(dead_code)]
404pub fn get_default_rng<A: Float>() -> ArrayRng<A> {
405    ArrayRng::<A>::default()
406}
407
408/// Create a deep copy of an ndarray
409#[inline]
410#[allow(dead_code)]
411pub fn deep_copy<T: Float + Clone>(array: &NdArrayView<'_, T>) -> NdArray<T> {
412    array.to_owned()
413}
414
415/// Select elements from an array along an axis
416#[inline]
417#[allow(dead_code)]
418pub fn select<T: Float + Clone>(
419    array: &NdArrayView<'_, T>,
420    axis: scirs2_core::ndarray::Axis,
421    indices: &[usize],
422) -> NdArray<T> {
423    let mut shape = array.shape().to_vec();
424    shape[axis.index()] = indices.len();
425
426    let mut result = NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&shape));
427
428    for (i, &idx) in indices.iter().enumerate() {
429        let slice = array.index_axis(axis, idx);
430        result.index_axis_mut(axis, i).assign(&slice);
431    }
432
433    result
434}
435
/// Returns `true` if `shape1` and `shape2` can be broadcast together.
///
/// The shapes are aligned at their trailing axes; each aligned pair must be
/// equal or contain a `1`. Leading axes present in only one shape are always
/// compatible.
#[inline]
#[allow(dead_code)]
pub fn are_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&lhs, &rhs)| lhs == rhs || lhs == 1 || rhs == 1)
}
453
/// Computes the shape that results from broadcasting two shapes together.
///
/// Shapes are aligned at their trailing axes, and axes missing from the
/// shorter shape count as `1` (NumPy broadcasting rules). An aligned pair is
/// compatible when the two extents are equal or one of them is `1`; the
/// result takes the non-`1` extent. A `0`-length axis broadcast against `1`
/// therefore yields `0` (an empty result), not `1`.
///
/// Returns `None` when the shapes are not broadcast-compatible, under the
/// same rule `are_broadcast_compatible` applies.
#[inline]
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    // Extent of `shape` at offset `i` counted from its last axis, or `1` if
    // `shape` has too few axes.
    fn dim_from_back(shape: &[usize], i: usize) -> usize {
        shape.len().checked_sub(i + 1).map_or(1, |k| shape[k])
    }

    let rank = shape1.len().max(shape2.len());
    let mut result = vec![1; rank];
    for i in 0..rank {
        let a = dim_from_back(shape1, i);
        let b = dim_from_back(shape2, i);
        // Pick the non-1 extent; `max` would wrongly turn (0, 1) into 1.
        result[rank - 1 - i] = if a == b || b == 1 {
            a
        } else if a == 1 {
            b
        } else {
            return None;
        };
    }
    Some(result)
}
476
477/// Array generation functions
478pub mod array_gen {
479    use super::*;
480
481    /// Creates a zero array in the specified shape.
482    #[inline]
483    pub fn zeros<T: Float>(shape: &[usize]) -> NdArray<T> {
484        NdArray::<T>::zeros(shape)
485    }
486
487    /// Creates a one array in the specified shape.
488    #[inline]
489    pub fn ones<T: Float>(shape: &[usize]) -> NdArray<T> {
490        NdArray::<T>::ones(shape)
491    }
492
493    /// Creates a 2D identity matrix of the specified size.
494    #[inline]
495    pub fn eye<T: Float>(n: usize) -> NdArray<T> {
496        let mut result = NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&[n, n]));
497        for i in 0..n {
498            result[[i, i]] = T::one();
499        }
500        result
501    }
502
503    /// Creates a constant array in the specified shape.
504    #[inline]
505    pub fn constant<T: Float>(value: T, shape: &[usize]) -> NdArray<T> {
506        NdArray::<T>::from_elem(shape, value)
507    }
508
509    /// Generates a random array in the specified shape with values between 0 and 1.
510    pub fn random<T: Float>(shape: &[usize]) -> NdArray<T> {
511        let mut rng = ArrayRng::<T>::default();
512        rng.random(shape)
513    }
514
515    /// Generates a random normal array in the specified shape.
516    pub fn randn<T: Float>(shape: &[usize]) -> NdArray<T> {
517        let mut rng = ArrayRng::<T>::default();
518        rng.normal(shape, 0.0, 1.0)
519    }
520
521    /// Creates a Glorot/Xavier uniform initialized array in the specified shape.
522    pub fn glorot_uniform<T: Float>(shape: &[usize]) -> NdArray<T> {
523        let mut rng = ArrayRng::<T>::default();
524        rng.glorot_uniform(shape)
525    }
526
527    /// Creates a Glorot/Xavier normal initialized array in the specified shape.
528    pub fn glorot_normal<T: Float>(shape: &[usize]) -> NdArray<T> {
529        let mut rng = ArrayRng::<T>::default();
530        rng.glorot_normal(shape)
531    }
532
533    /// Creates a He/Kaiming uniform initialized array in the specified shape.
534    pub fn he_uniform<T: Float>(shape: &[usize]) -> NdArray<T> {
535        let mut rng = ArrayRng::<T>::default();
536        rng.he_uniform(shape)
537    }
538
539    /// Creates a He/Kaiming normal initialized array in the specified shape.
540    pub fn he_normal<T: Float>(shape: &[usize]) -> NdArray<T> {
541        let mut rng = ArrayRng::<T>::default();
542        rng.he_normal(shape)
543    }
544
545    /// Creates an array with a linearly spaced sequence from start to end.
546    pub fn linspace<T: Float>(start: T, end: T, num: usize) -> NdArray<T> {
547        if num <= 1 {
548            return if num == 0 {
549                NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&[0]))
550            } else {
551                NdArray::<T>::from_elem(scirs2_core::ndarray::IxDyn(&[1]), start)
552            };
553        }
554
555        let step = (end - start) / T::from(num - 1).unwrap_or_else(|| T::one());
556        let mut data = Vec::with_capacity(num);
557
558        for i in 0..num {
559            data.push(start + step * T::from(i).unwrap_or_else(|| T::zero()));
560        }
561
562        NdArray::<T>::from_shape_vec(scirs2_core::ndarray::IxDyn(&[num]), data)
563            .expect("Shape conversion failed - this is a bug")
564    }
565
566    /// Creates an array of evenly spaced values within a given interval.
567    pub fn arange<T: Float>(start: T, end: T, step: T) -> NdArray<T> {
568        let size = ((end - start) / step).to_f64().unwrap_or(0.0).ceil() as usize;
569        let mut data = Vec::with_capacity(size);
570
571        let mut current = start;
572        while current < end {
573            data.push(current);
574            current += step;
575        }
576
577        NdArray::<T>::from_shape_vec(scirs2_core::ndarray::IxDyn(&[data.len()]), data)
578            .expect("Shape conversion failed - this is a bug")
579    }
580}