1use crate::error::OpResult;
5use crate::error_helpers::try_from_numeric;
6use crate::ndarray;
7
8use crate::Float;
9
10pub type NdArray<T> = scirs2_core::ndarray::Array<T, scirs2_core::ndarray::IxDyn>;
12
13pub type NdArrayView<'a, T> = scirs2_core::ndarray::ArrayView<'a, T, scirs2_core::ndarray::IxDyn>;
15
16pub type RawNdArrayView<T> = scirs2_core::ndarray::RawArrayView<T, scirs2_core::ndarray::IxDyn>;
18
19pub type RawNdArrayViewMut<T> =
21 scirs2_core::ndarray::RawArrayViewMut<T, scirs2_core::ndarray::IxDyn>;
22
23pub type NdArrayViewMut<'a, T> =
25 scirs2_core::ndarray::ArrayViewMut<'a, T, scirs2_core::ndarray::IxDyn>;
26
27#[inline]
28pub(crate) fn asshape<T: Float>(x: &NdArrayView<T>) -> Vec<usize> {
30 x.iter().map(|a| a.to_usize().unwrap_or(0)).collect()
31}
32
33#[inline]
34pub(crate) fn expand_dims<T: Float>(x: NdArray<T>, axis: usize) -> NdArray<T> {
35 let mut shape = x.shape().to_vec();
36 shape.insert(axis, 1);
37 x.into_shape_with_order(shape)
38 .expect("Shape conversion failed - this is a bug")
39}
40
41#[inline]
42pub(crate) fn roll_axis<T: Float>(
43 arg: &mut NdArray<T>,
44 to: scirs2_core::ndarray::Axis,
45 from: scirs2_core::ndarray::Axis,
46) {
47 let i = to.index();
48 let mut j = from.index();
49 if j > i {
50 while i != j {
51 arg.swap_axes(i, j);
52 j -= 1;
53 }
54 } else {
55 while i != j {
56 arg.swap_axes(i, j);
57 j += 1;
58 }
59 }
60}
61
/// Converts a possibly negative axis index into a non-negative one.
///
/// Negative values count from the end, so `-1` refers to the last of `ndim`
/// axes. No bounds checking is performed: an `axis` below `-ndim` produces a
/// negative sum that the `as usize` cast turns into a very large value.
#[inline]
pub(crate) fn normalize_negative_axis(axis: isize, ndim: usize) -> usize {
    if axis >= 0 {
        return axis as usize;
    }
    (axis + ndim as isize) as usize
}
70
71#[inline]
72pub(crate) fn normalize_negative_axes<T: Float>(axes: &NdArrayView<T>, ndim: usize) -> Vec<usize> {
73 let mut axes_ret: Vec<usize> = Vec::with_capacity(axes.len());
74 for &axis in axes.iter() {
75 let axis = if axis < T::zero() {
76 (T::from(ndim).unwrap_or_else(|| T::zero()) + axis)
77 .to_usize()
78 .unwrap_or(0)
79 } else {
80 axis.to_usize().unwrap_or(0)
81 };
82 axes_ret.push(axis);
83 }
84 axes_ret
85}
86
87#[inline]
88pub(crate) fn sparse_to_dense<T: Float>(arr: &NdArrayView<T>) -> Vec<usize> {
89 let mut axes: Vec<usize> = vec![];
90 for (i, &a) in arr.iter().enumerate() {
91 if a == T::one() {
92 axes.push(i);
93 }
94 }
95 axes
96}
97
98#[allow(unused)]
99#[inline]
100pub(crate) fn is_fully_transposed(strides: &[scirs2_core::ndarray::Ixs]) -> bool {
101 let mut ret = true;
102 for w in strides.windows(2) {
103 if w[0] > w[1] {
104 ret = false;
105 break;
106 }
107 }
108 ret
109}
110
111#[inline]
113#[allow(dead_code)]
114pub fn zeros<T: Float>(shape: &[usize]) -> NdArray<T> {
115 NdArray::<T>::zeros(shape)
116}
117
118#[inline]
120#[allow(dead_code)]
121pub fn ones<T: Float>(shape: &[usize]) -> NdArray<T> {
122 NdArray::<T>::ones(shape)
123}
124
125#[inline]
127#[allow(dead_code)]
128pub fn constant<T: Float>(value: T, shape: &[usize]) -> NdArray<T> {
129 NdArray::<T>::from_elem(shape, value)
130}
131
132use scirs2_core::random::{ChaCha8Rng, Rng, RngExt, SeedableRng, TryRng};
133
134#[derive(Clone)]
138pub struct ArrayRng<A> {
139 rng: ChaCha8Rng,
140 _phantom: std::marker::PhantomData<A>,
141}
142
143impl<A> TryRng for ArrayRng<A> {
146 type Error = std::convert::Infallible;
147
148 fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
149 Ok(self.rng.next_u32())
150 }
151
152 fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
153 Ok(self.rng.next_u64())
154 }
155
156 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
157 self.rng.fill_bytes(dest);
158 Ok(())
159 }
160}
161
162impl<A: Float> ArrayRng<A> {
163 pub fn new() -> Self {
165 Self::from_seed(0)
166 }
167
168 pub fn from_seed(seed: u64) -> Self {
170 let rng = ChaCha8Rng::seed_from_u64(seed);
171 Self {
172 rng,
173 _phantom: std::marker::PhantomData,
174 }
175 }
176
177 pub fn as_rng(&self) -> &ChaCha8Rng {
179 &self.rng
180 }
181
182 pub fn as_rng_mut(&mut self) -> &mut ChaCha8Rng {
184 &mut self.rng
185 }
186
187 pub fn random(&mut self, shape: &[usize]) -> NdArray<A> {
190 let len = shape.iter().product();
191 let mut data = Vec::with_capacity(len);
192 for _ in 0..len {
193 data.push(
194 A::from(self.rng.random::<f64>()).expect("Shape conversion failed - this is a bug"),
195 );
196 }
197 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
198 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
199 }
200
201 pub fn normal(&mut self, shape: &[usize], mean: f64, std: f64) -> NdArray<A> {
204 use scirs2_core::random::{Distribution, Normal};
205 let normal = Normal::new(mean, std)
206 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
207 let len = shape.iter().product();
208 let mut data = Vec::with_capacity(len);
209 for _ in 0..len {
210 data.push(
211 A::from(normal.sample(&mut self.rng))
212 .expect("Shape conversion failed - this is a bug"),
213 );
214 }
215 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
216 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
217 }
218
219 pub fn uniform(&mut self, shape: &[usize], low: f64, high: f64) -> NdArray<A> {
222 use scirs2_core::random::{Distribution, Uniform};
223 let uniform = Uniform::new(low, high)
224 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
225 let len = shape.iter().product();
226 let mut data = Vec::with_capacity(len);
227 for _ in 0..len {
228 data.push(
229 A::from(uniform.sample(&mut self.rng))
230 .expect("Shape conversion failed - this is a bug"),
231 );
232 }
233 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
234 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
235 }
236
237 pub fn glorot_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
241 assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
242 let fan_in = shape[shape.len() - 2];
243 let fan_out = shape[shape.len() - 1];
244 let scale = (6.0 / (fan_in + fan_out) as f64).sqrt();
245 self.uniform(shape, -scale, scale)
246 }
247
248 pub fn glorot_normal(&mut self, shape: &[usize]) -> NdArray<A> {
252 assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
253 let fan_in = shape[shape.len() - 2];
254 let fan_out = shape[shape.len() - 1];
255 let scale = (2.0 / (fan_in + fan_out) as f64).sqrt();
256 self.normal(shape, 0.0, scale)
257 }
258
259 pub fn he_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
263 assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
264 let fan_in = shape[shape.len() - 2];
265 let scale = (6.0 / fan_in as f64).sqrt();
266 self.uniform(shape, -scale, scale)
267 }
268
269 pub fn he_normal(&mut self, shape: &[usize]) -> NdArray<A> {
273 assert!(shape.len() >= 2, "shape must have at least 2 dimensions");
274 let fan_in = shape[shape.len() - 2];
275 let scale = (2.0 / fan_in as f64).sqrt();
276 self.normal(shape, 0.0, scale)
277 }
278
279 pub fn standard_normal(&mut self, shape: &[usize]) -> NdArray<A> {
281 self.normal(shape, 0.0, 1.0)
282 }
283
284 pub fn standard_uniform(&mut self, shape: &[usize]) -> NdArray<A> {
286 self.uniform(shape, 0.0, 1.0)
287 }
288
289 pub fn bernoulli(&mut self, shape: &[usize], p: f64) -> NdArray<A> {
291 use scirs2_core::random::{Bernoulli, Distribution};
292 let bernoulli =
293 Bernoulli::new(p).unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
294 let len = shape.iter().product();
295 let mut data = Vec::with_capacity(len);
296 for _ in 0..len {
297 let val = if bernoulli.sample(&mut self.rng) {
298 A::one()
299 } else {
300 A::zero()
301 };
302 data.push(val);
303 }
304 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
305 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
306 }
307
308 pub fn exponential(&mut self, shape: &[usize], lambda: f64) -> NdArray<A> {
310 use scirs2_core::random::{Distribution, Exp};
311 let exp =
312 Exp::new(lambda).unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
313 let len = shape.iter().product();
314 let mut data = Vec::with_capacity(len);
315 for _ in 0..len {
316 data.push(
317 A::from(exp.sample(&mut self.rng))
318 .expect("Shape conversion failed - this is a bug"),
319 );
320 }
321 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
322 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
323 }
324
325 pub fn log_normal(&mut self, shape: &[usize], mean: f64, stddev: f64) -> NdArray<A> {
327 use scirs2_core::random::{Distribution, LogNormal};
328 let log_normal = LogNormal::new(mean, stddev)
329 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
330 let len = shape.iter().product();
331 let mut data = Vec::with_capacity(len);
332 for _ in 0..len {
333 data.push(
334 A::from(log_normal.sample(&mut self.rng))
335 .expect("Shape conversion failed - this is a bug"),
336 );
337 }
338 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
339 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
340 }
341
342 pub fn gamma(&mut self, shape: &[usize], shape_param: f64, scale: f64) -> NdArray<A> {
344 use scirs2_core::random::{Distribution, Gamma};
345 let gamma = Gamma::new(shape_param, scale)
346 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"));
347 let len = shape.iter().product();
348 let mut data = Vec::with_capacity(len);
349 for _ in 0..len {
350 data.push(
351 A::from(gamma.sample(&mut self.rng))
352 .expect("Shape conversion failed - this is a bug"),
353 );
354 }
355 NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), data)
356 .unwrap_or_else(|_| panic!("Shape conversion failed - this is a bug"))
357 }
358}
359
360impl<A: Float> Default for ArrayRng<A> {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
/// Returns `true` if `shape` describes a scalar: either rank 0 (`[]`) or the
/// single-element one-dimensional shape `[1]`.
#[inline]
#[allow(dead_code)]
pub fn is_scalarshape(shape: &[usize]) -> bool {
    matches!(shape, [] | [1])
}
372
/// Returns the canonical scalar shape: rank 0, i.e. an empty dimension list.
#[inline]
#[allow(dead_code)]
pub fn scalarshape() -> Vec<usize> {
    Vec::new()
}
379
380#[inline]
382#[allow(dead_code)]
383pub fn from_scalar<T: Float>(value: T) -> NdArray<T> {
384 NdArray::<T>::from_elem(scirs2_core::ndarray::IxDyn(&[1]), value)
385}
386
387#[inline]
389#[allow(dead_code)]
390pub fn shape_of_view<T>(view: &NdArrayView<'_, T>) -> Vec<usize> {
391 view.shape().to_vec()
392}
393
394#[inline]
396#[allow(dead_code)]
397pub fn shape_of<T>(array: &NdArray<T>) -> Vec<usize> {
398 array.shape().to_vec()
399}
400
401#[inline]
403#[allow(dead_code)]
404pub fn get_default_rng<A: Float>() -> ArrayRng<A> {
405 ArrayRng::<A>::default()
406}
407
408#[inline]
410#[allow(dead_code)]
411pub fn deep_copy<T: Float + Clone>(array: &NdArrayView<'_, T>) -> NdArray<T> {
412 array.to_owned()
413}
414
415#[inline]
417#[allow(dead_code)]
418pub fn select<T: Float + Clone>(
419 array: &NdArrayView<'_, T>,
420 axis: scirs2_core::ndarray::Axis,
421 indices: &[usize],
422) -> NdArray<T> {
423 let mut shape = array.shape().to_vec();
424 shape[axis.index()] = indices.len();
425
426 let mut result = NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&shape));
427
428 for (i, &idx) in indices.iter().enumerate() {
429 let slice = array.index_axis(axis, idx);
430 result.index_axis_mut(axis, i).assign(&slice);
431 }
432
433 result
434}
435
/// Returns `true` if `shape1` and `shape2` can be broadcast together.
///
/// Shapes are aligned at their trailing axes; each aligned pair of lengths
/// must be equal or contain a `1`. Leading axes present in only one shape
/// never cause a conflict.
#[inline]
#[allow(dead_code)]
pub fn are_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&a, &b)| a == b || a == 1 || b == 1)
}
453
/// Computes the broadcast shape of `shape1` and `shape2`, following NumPy rules.
///
/// Shapes are aligned at their trailing axes, and a missing leading axis counts
/// as length 1. Two aligned lengths are compatible when they are equal or when
/// either one is 1. In the latter case the result takes the *other* length, so
/// broadcasting `1` against `0` yields `0`, not `1`.
///
/// Returns `None` if any aligned pair is incompatible.
#[inline]
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let ndim = shape1.len().max(shape2.len());
    let mut rev1 = shape1.iter().rev();
    let mut rev2 = shape2.iter().rev();

    // Built from the last axis to the first, then reversed. Compatibility is
    // checked in the same pass.
    let mut reversed = Vec::with_capacity(ndim);
    for _ in 0..ndim {
        let d1 = rev1.next().copied().unwrap_or(1);
        let d2 = rev2.next().copied().unwrap_or(1);
        // Previously `max(d1, d2)`, which was wrong for size-0 axes (max(1, 0) == 1).
        let d = if d1 == d2 || d2 == 1 {
            d1
        } else if d1 == 1 {
            d2
        } else {
            return None;
        };
        reversed.push(d);
    }

    reversed.reverse();
    Some(reversed)
}
476
477pub mod array_gen {
479 use super::*;
480
481 #[inline]
483 pub fn zeros<T: Float>(shape: &[usize]) -> NdArray<T> {
484 NdArray::<T>::zeros(shape)
485 }
486
487 #[inline]
489 pub fn ones<T: Float>(shape: &[usize]) -> NdArray<T> {
490 NdArray::<T>::ones(shape)
491 }
492
493 #[inline]
495 pub fn eye<T: Float>(n: usize) -> NdArray<T> {
496 let mut result = NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&[n, n]));
497 for i in 0..n {
498 result[[i, i]] = T::one();
499 }
500 result
501 }
502
503 #[inline]
505 pub fn constant<T: Float>(value: T, shape: &[usize]) -> NdArray<T> {
506 NdArray::<T>::from_elem(shape, value)
507 }
508
509 pub fn random<T: Float>(shape: &[usize]) -> NdArray<T> {
511 let mut rng = ArrayRng::<T>::default();
512 rng.random(shape)
513 }
514
515 pub fn randn<T: Float>(shape: &[usize]) -> NdArray<T> {
517 let mut rng = ArrayRng::<T>::default();
518 rng.normal(shape, 0.0, 1.0)
519 }
520
521 pub fn glorot_uniform<T: Float>(shape: &[usize]) -> NdArray<T> {
523 let mut rng = ArrayRng::<T>::default();
524 rng.glorot_uniform(shape)
525 }
526
527 pub fn glorot_normal<T: Float>(shape: &[usize]) -> NdArray<T> {
529 let mut rng = ArrayRng::<T>::default();
530 rng.glorot_normal(shape)
531 }
532
533 pub fn he_uniform<T: Float>(shape: &[usize]) -> NdArray<T> {
535 let mut rng = ArrayRng::<T>::default();
536 rng.he_uniform(shape)
537 }
538
539 pub fn he_normal<T: Float>(shape: &[usize]) -> NdArray<T> {
541 let mut rng = ArrayRng::<T>::default();
542 rng.he_normal(shape)
543 }
544
545 pub fn linspace<T: Float>(start: T, end: T, num: usize) -> NdArray<T> {
547 if num <= 1 {
548 return if num == 0 {
549 NdArray::<T>::zeros(scirs2_core::ndarray::IxDyn(&[0]))
550 } else {
551 NdArray::<T>::from_elem(scirs2_core::ndarray::IxDyn(&[1]), start)
552 };
553 }
554
555 let step = (end - start) / T::from(num - 1).unwrap_or_else(|| T::one());
556 let mut data = Vec::with_capacity(num);
557
558 for i in 0..num {
559 data.push(start + step * T::from(i).unwrap_or_else(|| T::zero()));
560 }
561
562 NdArray::<T>::from_shape_vec(scirs2_core::ndarray::IxDyn(&[num]), data)
563 .expect("Shape conversion failed - this is a bug")
564 }
565
566 pub fn arange<T: Float>(start: T, end: T, step: T) -> NdArray<T> {
568 let size = ((end - start) / step).to_f64().unwrap_or(0.0).ceil() as usize;
569 let mut data = Vec::with_capacity(size);
570
571 let mut current = start;
572 while current < end {
573 data.push(current);
574 current += step;
575 }
576
577 NdArray::<T>::from_shape_vec(scirs2_core::ndarray::IxDyn(&[data.len()]), data)
578 .expect("Shape conversion failed - this is a bug")
579 }
580}