Skip to main content

entrenar/
sovereign_array.rs

1//! Sovereign array types — Vec-backed replacements for ndarray Array1/Array2.
2//!
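//! These types mirror a subset of ndarray's API on top of plain `Vec` storage
//! (row-major for 2-D arrays), removing the ndarray dependency. Supported
//! operations aim to behave like their ndarray counterparts, but not every
//! ndarray feature is available (e.g. there is no general array broadcasting).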
5
6use std::ops::{Add, Div, Index, IndexMut, Mul, Sub};
7
/// One-dimensional array stored in a plain `Vec<T>`.
///
/// Stands in for `ndarray::Array1<f32>`; the element type defaults to `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct Array1<T = f32> {
    data: Vec<T>,
}

impl<T: Clone + Default> Array1<T> {
    /// Builds an array of `n` elements, each equal to `T::from(0.0f32)`.
    pub fn zeros(n: usize) -> Self
    where
        T: From<f32>,
    {
        let zero = T::from(0.0f32);
        Self::from_vec(vec![zero; n])
    }

    /// Wraps an existing vector as the backing storage without copying it.
    pub fn from_vec(v: Vec<T>) -> Self {
        Array1 { data: v }
    }
}
26
27// Specialization for f32
28impl Array1<f32> {
29    pub fn ones(n: usize) -> Self {
30        Self { data: vec![1.0; n] }
31    }
32
33    pub fn zeros_f32(n: usize) -> Self {
34        Self { data: vec![0.0; n] }
35    }
36
37    pub fn mapv<F: Fn(f32) -> f32>(&self, f: F) -> Self {
38        Self { data: self.data.iter().map(|&x| f(x)).collect() }
39    }
40
41    pub fn mapv_inplace<F: Fn(f32) -> f32>(&mut self, f: F) {
42        for v in &mut self.data {
43            *v = f(*v);
44        }
45    }
46
47    pub fn sum(&self) -> f32 {
48        self.data.iter().sum()
49    }
50
51    pub fn dot(&self, other: &Self) -> f32 {
52        self.data.iter().zip(other.data.iter()).map(|(&a, &b)| a * b).sum()
53    }
54
55    pub fn mean(&self) -> Option<f32> {
56        if self.data.is_empty() {
57            None
58        } else {
59            Some(self.sum() / self.data.len() as f32)
60        }
61    }
62
63    pub fn as_slice(&self) -> &[f32] {
64        &self.data
65    }
66
67    pub fn as_slice_mut(&mut self) -> &mut [f32] {
68        &mut self.data
69    }
70
71    pub fn to_vec(&self) -> Vec<f32> {
72        self.data.clone()
73    }
74
75    pub fn into_raw_vec(self) -> Vec<f32> {
76        self.data
77    }
78
79    /// Vector-matrix multiply: self [m] dot other [m, n] -> [n]
80    pub fn dot_mat(&self, mat: &Array2<f32>) -> Array1<f32> {
81        assert_eq!(
82            self.data.len(),
83            mat.nrows(),
84            "dot_mat: vector length {} != matrix rows {}",
85            self.data.len(),
86            mat.nrows()
87        );
88        let n = mat.ncols();
89        let m = mat.nrows();
90        let mut out = vec![0.0f32; n];
91        for j in 0..n {
92            let mut sum = 0.0f32;
93            for i in 0..m {
94                sum += self.data[i] * mat[[i, j]];
95            }
96            out[j] = sum;
97        }
98        Array1::from(out)
99    }
100}
101
102impl<T> Array1<T> {
103    pub fn len(&self) -> usize {
104        self.data.len()
105    }
106    pub fn is_empty(&self) -> bool {
107        self.data.is_empty()
108    }
109    pub fn iter(&self) -> std::slice::Iter<'_, T> {
110        self.data.iter()
111    }
112    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
113        self.data.iter_mut()
114    }
115}
116
117impl Array1<f32> {
118    pub fn assign(&mut self, other: &Array1<f32>) {
119        self.data.clear();
120        self.data.extend_from_slice(&other.data);
121    }
122}
123
124impl<T> From<Vec<T>> for Array1<T> {
125    fn from(v: Vec<T>) -> Self {
126        Self { data: v }
127    }
128}
129
130impl<T: Clone> From<&[T]> for Array1<T> {
131    fn from(s: &[T]) -> Self {
132        Self { data: s.to_vec() }
133    }
134}
135
136// FromIterator for .collect()
137impl std::iter::FromIterator<f32> for Array1<f32> {
138    fn from_iter<I: IntoIterator<Item = f32>>(iter: I) -> Self {
139        Self { data: iter.into_iter().collect() }
140    }
141}
142
143// += for Array1
144impl std::ops::AddAssign<&Array1<f32>> for Array1<f32> {
145    fn add_assign(&mut self, rhs: &Array1<f32>) {
146        for (a, b) in self.data.iter_mut().zip(rhs.data.iter()) {
147            *a += b;
148        }
149    }
150}
151
152impl<T> Index<usize> for Array1<T> {
153    type Output = T;
154    fn index(&self, i: usize) -> &T {
155        &self.data[i]
156    }
157}
158
159impl<T> IndexMut<usize> for Array1<T> {
160    fn index_mut(&mut self, i: usize) -> &mut T {
161        &mut self.data[i]
162    }
163}
164
165// &Array1 + &Array1
166impl<'b> Add<&'b Array1<f32>> for &Array1<f32> {
167    type Output = Array1<f32>;
168    fn add(self, rhs: &'b Array1<f32>) -> Array1<f32> {
169        Array1 { data: self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a + b).collect() }
170    }
171}
172
173// &Array1 - &Array1
174impl<'b> Sub<&'b Array1<f32>> for &Array1<f32> {
175    type Output = Array1<f32>;
176    fn sub(self, rhs: &'b Array1<f32>) -> Array1<f32> {
177        Array1 { data: self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a - b).collect() }
178    }
179}
180
181// &Array1 * &Array1
182impl<'b> Mul<&'b Array1<f32>> for &Array1<f32> {
183    type Output = Array1<f32>;
184    fn mul(self, rhs: &'b Array1<f32>) -> Array1<f32> {
185        Array1 { data: self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a * b).collect() }
186    }
187}
188
189// &Array1 / &Array1
190impl<'b> Div<&'b Array1<f32>> for &Array1<f32> {
191    type Output = Array1<f32>;
192    fn div(self, rhs: &'b Array1<f32>) -> Array1<f32> {
193        Array1 { data: self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a / b).collect() }
194    }
195}
196
197// Array1 + f32 (scalar broadcast)
198impl Add<f32> for Array1<f32> {
199    type Output = Array1<f32>;
200    fn add(self, rhs: f32) -> Array1<f32> {
201        Array1 { data: self.data.iter().map(|&a| a + rhs).collect() }
202    }
203}
204
205impl Add<f32> for &Array1<f32> {
206    type Output = Array1<f32>;
207    fn add(self, rhs: f32) -> Array1<f32> {
208        Array1 { data: self.data.iter().map(|&a| a + rhs).collect() }
209    }
210}
211
212// Array1 * f32 (scalar broadcast)
213impl Mul<f32> for Array1<f32> {
214    type Output = Array1<f32>;
215    fn mul(self, rhs: f32) -> Array1<f32> {
216        Array1 { data: self.data.iter().map(|&a| a * rhs).collect() }
217    }
218}
219
220impl Mul<f32> for &Array1<f32> {
221    type Output = Array1<f32>;
222    fn mul(self, rhs: f32) -> Array1<f32> {
223        Array1 { data: self.data.iter().map(|&a| a * rhs).collect() }
224    }
225}
226
227// Array1 / f32 (scalar broadcast)
228impl Div<f32> for Array1<f32> {
229    type Output = Array1<f32>;
230    fn div(self, rhs: f32) -> Array1<f32> {
231        Array1 { data: self.data.iter().map(|&a| a / rhs).collect() }
232    }
233}
234
235// Array1 - f32 (scalar broadcast)
236impl Sub<f32> for Array1<f32> {
237    type Output = Array1<f32>;
238    fn sub(self, rhs: f32) -> Array1<f32> {
239        Array1 { data: self.data.iter().map(|&a| a - rhs).collect() }
240    }
241}
242
243impl Sub<f32> for &Array1<f32> {
244    type Output = Array1<f32>;
245    fn sub(self, rhs: f32) -> Array1<f32> {
246        Array1 { data: self.data.iter().map(|&a| a - rhs).collect() }
247    }
248}
249
250// f32 * &Array1 (scalar on left)
251impl Mul<&Array1<f32>> for f32 {
252    type Output = Array1<f32>;
253    fn mul(self, rhs: &Array1<f32>) -> Array1<f32> {
254        Array1 { data: rhs.data.iter().map(|&a| self * a).collect() }
255    }
256}
257
258// IntoIterator support
259impl<T> IntoIterator for Array1<T> {
260    type Item = T;
261    type IntoIter = std::vec::IntoIter<T>;
262    fn into_iter(self) -> Self::IntoIter {
263        self.data.into_iter()
264    }
265}
266
267impl<'a, T> IntoIterator for &'a Array1<T> {
268    type Item = &'a T;
269    type IntoIter = std::slice::Iter<'a, T>;
270    fn into_iter(self) -> Self::IntoIter {
271        self.data.iter()
272    }
273}
274
275/// Free function: create Array1 from slice (replaces `ndarray::arr1`)
276pub fn arr1(data: &[f32]) -> Array1<f32> {
277    Array1 { data: data.to_vec() }
278}
279
/// Axis selector for `Array2` operations (stand-in for `ndarray::Axis`).
///
/// In `Array2::sum_axis`, `Axis(0)` collapses the rows (one sum per column)
/// and `Axis(1)` collapses the columns (one sum per row).
#[derive(Debug, Clone, Copy)]
pub struct Axis(pub usize);
283
/// Two-dimensional array stored row-major in a flat `Vec<T>`.
///
/// Stands in for `ndarray::Array2<T>`. Element `(r, c)` lives at
/// `data[r * cols + c]`, and the element type defaults to `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct Array2<T = f32> {
    data: Vec<T>,
    rows: usize,
    cols: usize,
}

impl<T: Clone + Default> Array2<T> {
    /// `rows x cols` array filled with `T::default()`.
    pub fn zeros(shape: (usize, usize)) -> Self {
        Self::from_elem(shape, T::default())
    }

    /// `rows x cols` array with every element set to a clone of `val`.
    pub fn from_elem(shape: (usize, usize), val: T) -> Self {
        let (rows, cols) = shape;
        Self { data: vec![val; rows * cols], rows, cols }
    }
}
302
303impl Array2<f32> {
304    pub fn ones(shape: (usize, usize)) -> Self {
305        Self { data: vec![1.0; shape.0 * shape.1], rows: shape.0, cols: shape.1 }
306    }
307
308    pub fn from_shape_fn<F: Fn((usize, usize)) -> f32>(shape: (usize, usize), f: F) -> Self {
309        let mut data = Vec::with_capacity(shape.0 * shape.1);
310        for r in 0..shape.0 {
311            for c in 0..shape.1 {
312                data.push(f((r, c)));
313            }
314        }
315        Self { data, rows: shape.0, cols: shape.1 }
316    }
317
318    pub fn from_shape_vec(
319        shape: (usize, usize),
320        data: Vec<f32>,
321    ) -> std::result::Result<Self, String> {
322        if data.len() != shape.0 * shape.1 {
323            return Err(format!(
324                "shape mismatch: expected {} elements, got {}",
325                shape.0 * shape.1,
326                data.len()
327            ));
328        }
329        Ok(Self { data, rows: shape.0, cols: shape.1 })
330    }
331
332    pub fn nrows(&self) -> usize {
333        self.rows
334    }
335
336    pub fn ncols(&self) -> usize {
337        self.cols
338    }
339
340    pub fn shape(&self) -> [usize; 2] {
341        [self.rows, self.cols]
342    }
343
344    pub fn row(&self, r: usize) -> ArrayView1<'_> {
345        let start = r * self.cols;
346        ArrayView1 { data: &self.data[start..start + self.cols] }
347    }
348
349    pub fn mapv<F: Fn(f32) -> f32>(&self, f: F) -> Self {
350        Self { data: self.data.iter().map(|&x| f(x)).collect(), rows: self.rows, cols: self.cols }
351    }
352
353    pub fn sum(&self) -> f32 {
354        self.data.iter().sum()
355    }
356
357    pub fn mean(&self) -> Option<f32> {
358        if self.data.is_empty() {
359            None
360        } else {
361            Some(self.data.iter().sum::<f32>() / self.data.len() as f32)
362        }
363    }
364
365    pub fn t(&self) -> Self {
366        let mut result = vec![0.0f32; self.data.len()];
367        for r in 0..self.rows {
368            for c in 0..self.cols {
369                result[c * self.rows + r] = self.data[r * self.cols + c];
370            }
371        }
372        Self { data: result, rows: self.cols, cols: self.rows }
373    }
374
375    /// Matrix multiply: self [m,k] dot other [k,n] -> [m,n]
376    pub fn dot(&self, other: &Self) -> Self {
377        assert_eq!(
378            self.cols, other.rows,
379            "dot: incompatible shapes [{},{}] x [{},{}]",
380            self.rows, self.cols, other.rows, other.cols
381        );
382        let m = self.rows;
383        let k = self.cols;
384        let n = other.cols;
385        let mut out = vec![0.0f32; m * n];
386        for i in 0..m {
387            for j in 0..n {
388                let mut sum = 0.0f32;
389                for p in 0..k {
390                    sum += self.data[i * k + p] * other.data[p * n + j];
391                }
392                out[i * n + j] = sum;
393            }
394        }
395        Self { data: out, rows: m, cols: n }
396    }
397
398    /// Matrix-vector multiply: self [m,k] dot vec [k] -> [m]
399    pub fn dot_vec(&self, vec: &Array1<f32>) -> Array1<f32> {
400        assert_eq!(
401            self.cols,
402            vec.len(),
403            "dot_vec: matrix cols {} != vector length {}",
404            self.cols,
405            vec.len()
406        );
407        let m = self.rows;
408        let k = self.cols;
409        let mut out = vec![0.0f32; m];
410        for i in 0..m {
411            let mut sum = 0.0f32;
412            for j in 0..k {
413                sum += self.data[i * k + j] * vec[j];
414            }
415            out[i] = sum;
416        }
417        Array1::from(out)
418    }
419
420    /// Iterate rows along Axis(0)
421    pub fn axis_iter(&self, _axis: Axis) -> AxisIter<'_> {
422        AxisIter { data: &self.data, cols: self.cols, row: 0, total_rows: self.rows }
423    }
424
425    /// Iterate mutable rows along Axis(0)
426    pub fn axis_iter_mut(&mut self, _axis: Axis) -> impl Iterator<Item = ArrayViewMut1<'_>> {
427        let cols = self.cols;
428        self.data.chunks_mut(cols).map(move |chunk| ArrayViewMut1 { data: chunk })
429    }
430
431    /// Sum along an axis
432    pub fn sum_axis(&self, axis: Axis) -> Array1<f32> {
433        match axis.0 {
434            0 => {
435                // Sum each column -> Array1 of length cols
436                let mut result = vec![0.0f32; self.cols];
437                for r in 0..self.rows {
438                    for c in 0..self.cols {
439                        result[c] += self.data[r * self.cols + c];
440                    }
441                }
442                Array1::from(result)
443            }
444            1 => {
445                // Sum each row -> Array1 of length rows
446                let mut result = vec![0.0f32; self.rows];
447                for r in 0..self.rows {
448                    for c in 0..self.cols {
449                        result[r] += self.data[r * self.cols + c];
450                    }
451                }
452                Array1::from(result)
453            }
454            _ => panic!("Axis({}) not supported for 2D array", axis.0),
455        }
456    }
457
458    pub fn as_slice(&self) -> &[f32] {
459        &self.data
460    }
461
462    #[allow(clippy::iter_without_into_iter)]
463    pub fn iter(&self) -> std::slice::Iter<'_, f32> {
464        self.data.iter()
465    }
466
467    pub fn to_vec(&self) -> Vec<f32> {
468        self.data.clone()
469    }
470
471    pub fn row_mut(&mut self, r: usize) -> &mut [f32] {
472        let start = r * self.cols;
473        &mut self.data[start..start + self.cols]
474    }
475
476    pub fn rows_mut(&mut self) -> impl Iterator<Item = &mut [f32]> {
477        let cols = self.cols;
478        self.data.chunks_mut(cols)
479    }
480}
481
482impl Index<[usize; 2]> for Array2<f32> {
483    type Output = f32;
484    fn index(&self, idx: [usize; 2]) -> &f32 {
485        &self.data[idx[0] * self.cols + idx[1]]
486    }
487}
488
489impl IndexMut<[usize; 2]> for Array2<f32> {
490    fn index_mut(&mut self, idx: [usize; 2]) -> &mut f32 {
491        &mut self.data[idx[0] * self.cols + idx[1]]
492    }
493}
494
495// &Array2 + &Array2
496impl<'b> Add<&'b Array2<f32>> for &Array2<f32> {
497    type Output = Array2<f32>;
498    fn add(self, rhs: &'b Array2<f32>) -> Array2<f32> {
499        assert_eq!(self.rows, rhs.rows);
500        assert_eq!(self.cols, rhs.cols);
501        Array2 {
502            data: self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a + b).collect(),
503            rows: self.rows,
504            cols: self.cols,
505        }
506    }
507}
508
509// &Array2 - &Array2
510impl<'b> Sub<&'b Array2<f32>> for &Array2<f32> {
511    type Output = Array2<f32>;
512    fn sub(self, rhs: &'b Array2<f32>) -> Array2<f32> {
513        assert_eq!(self.rows, rhs.rows);
514        assert_eq!(self.cols, rhs.cols);
515        Array2 {
516            data: self.data.iter().zip(rhs.data.iter()).map(|(&a, &b)| a - b).collect(),
517            rows: self.rows,
518            cols: self.cols,
519        }
520    }
521}
522
523// Array2 * f32
524impl Mul<f32> for &Array2<f32> {
525    type Output = Array2<f32>;
526    fn mul(self, rhs: f32) -> Array2<f32> {
527        Array2 {
528            data: self.data.iter().map(|&a| a * rhs).collect(),
529            rows: self.rows,
530            cols: self.cols,
531        }
532    }
533}
534
535// Array2 / f32
536impl Div<f32> for Array2<f32> {
537    type Output = Array2<f32>;
538    fn div(self, rhs: f32) -> Array2<f32> {
539        Array2 {
540            data: self.data.iter().map(|&a| a / rhs).collect(),
541            rows: self.rows,
542            cols: self.cols,
543        }
544    }
545}
546
547/// Read-only 1-D view (returned by `Array2::row`)
548pub struct ArrayView1<'a> {
549    data: &'a [f32],
550}
551
552impl ArrayView1<'_> {
553    pub fn to_owned(&self) -> Array1<f32> {
554        Array1::from(self.data.to_vec())
555    }
556
557    pub fn len(&self) -> usize {
558        self.data.len()
559    }
560
561    pub fn is_empty(&self) -> bool {
562        self.data.is_empty()
563    }
564
565    #[allow(clippy::iter_without_into_iter)]
566    pub fn iter(&self) -> std::slice::Iter<'_, f32> {
567        self.data.iter()
568    }
569
570    pub fn sum(&self) -> f32 {
571        self.data.iter().sum()
572    }
573
574    pub fn mapv<F: Fn(f32) -> f32>(&self, f: F) -> Array1<f32> {
575        Array1 { data: self.data.iter().map(|&x| f(x)).collect() }
576    }
577
578    pub fn to_vec(&self) -> Vec<f32> {
579        self.data.to_vec()
580    }
581}
582
583impl Index<usize> for ArrayView1<'_> {
584    type Output = f32;
585    fn index(&self, i: usize) -> &f32 {
586        &self.data[i]
587    }
588}
589
590/// Mutable 1-D view (returned by `axis_iter_mut`)
591pub struct ArrayViewMut1<'a> {
592    data: &'a mut [f32],
593}
594
595impl ArrayViewMut1<'_> {
596    pub fn len(&self) -> usize {
597        self.data.len()
598    }
599
600    pub fn is_empty(&self) -> bool {
601        self.data.is_empty()
602    }
603
604    #[allow(clippy::iter_without_into_iter)]
605    pub fn iter(&self) -> std::slice::Iter<'_, f32> {
606        self.data.iter()
607    }
608
609    #[allow(clippy::iter_without_into_iter)]
610    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, f32> {
611        self.data.iter_mut()
612    }
613
614    pub fn mapv_inplace<F: Fn(f32) -> f32>(&mut self, f: F) {
615        for v in self.data.iter_mut() {
616            *v = f(*v);
617        }
618    }
619
620    pub fn sum(&self) -> f32 {
621        self.data.iter().sum()
622    }
623
624    /// Assign values from an Array1
625    pub fn assign(&mut self, src: &Array1<f32>) {
626        self.data.copy_from_slice(src.as_slice());
627    }
628}
629
630impl Index<usize> for ArrayViewMut1<'_> {
631    type Output = f32;
632    fn index(&self, i: usize) -> &f32 {
633        &self.data[i]
634    }
635}
636
637impl IndexMut<usize> for ArrayViewMut1<'_> {
638    fn index_mut(&mut self, i: usize) -> &mut f32 {
639        &mut self.data[i]
640    }
641}
642
643pub struct AxisIter<'a> {
644    data: &'a [f32],
645    cols: usize,
646    row: usize,
647    total_rows: usize,
648}
649
650impl<'a> Iterator for AxisIter<'a> {
651    type Item = ArrayView1<'a>;
652    fn next(&mut self) -> Option<Self::Item> {
653        if self.row >= self.total_rows {
654            return None;
655        }
656        let start = self.row * self.cols;
657        self.row += 1;
658        Some(ArrayView1 { data: &self.data[start..start + self.cols] })
659    }
660}
661
/// Builds an `Array2<f32>` from nested bracketed rows (stand-in for `ndarray::array!`).
///
/// Every element is converted with `as f32`.
///
/// # Panics
/// Panics if the rows do not all have the same length. Without this check, a
/// ragged literal whose total element count happens to equal
/// `rows * first_row_len` (e.g. `[[1, 2], [3], [4, 5, 6]]`) would be silently
/// reshaped into the wrong matrix.
#[macro_export]
macro_rules! array {
    [$([$($val:expr),* $(,)?]),* $(,)?] => {{
        let rows_data: Vec<Vec<f32>> = vec![$(vec![$($val as f32),*]),*];
        let rows = rows_data.len();
        let cols = if rows > 0 { rows_data[0].len() } else { 0 };
        assert!(
            rows_data.iter().all(|row| row.len() == cols),
            "array!: all rows must have the same length"
        );
        let flat: Vec<f32> = rows_data.into_iter().flatten().collect();
        $crate::sovereign_array::Array2::from_shape_vec((rows, cols), flat)
            .expect("array!: uniform row lengths always match the shape")
    }};
}
673
674impl Array2<u32> {
675    pub fn nrows(&self) -> usize {
676        self.rows
677    }
678    pub fn ncols(&self) -> usize {
679        self.cols
680    }
681    pub fn shape(&self) -> [usize; 2] {
682        [self.rows, self.cols]
683    }
684    pub fn row(&self, r: usize) -> &[u32] {
685        let start = r * self.cols;
686        &self.data[start..start + self.cols]
687    }
688}
689
690impl Array2<u8> {
691    pub fn nrows(&self) -> usize {
692        self.rows
693    }
694    pub fn ncols(&self) -> usize {
695        self.cols
696    }
697    pub fn shape(&self) -> [usize; 2] {
698        [self.rows, self.cols]
699    }
700    pub fn row(&self, r: usize) -> &[u8] {
701        let start = r * self.cols;
702        &self.data[start..start + self.cols]
703    }
704}
705
706macro_rules! impl_array2_index {
707    ($T:ty) => {
708        impl Index<[usize; 2]> for Array2<$T> {
709            type Output = $T;
710            fn index(&self, idx: [usize; 2]) -> &$T {
711                &self.data[idx[0] * self.cols + idx[1]]
712            }
713        }
714        impl IndexMut<[usize; 2]> for Array2<$T> {
715            fn index_mut(&mut self, idx: [usize; 2]) -> &mut $T {
716                &mut self.data[idx[0] * self.cols + idx[1]]
717            }
718        }
719    };
720}
721impl_array2_index!(u32);
722impl_array2_index!(u8);
723
724impl<'a, T> IntoIterator for &'a mut Array1<T> {
725    type Item = &'a mut T;
726    type IntoIter = std::slice::IterMut<'a, T>;
727    fn into_iter(self) -> Self::IntoIter {
728        self.data.iter_mut()
729    }
730}
731
732/// Generate delegating operator impls that forward owned args to &-& impl.
733macro_rules! delegate_binop {
734    ($Op:ident, $method:ident, $Lhs:ty, $Rhs:ty, $Out:ty) => {
735        impl $Op<$Rhs> for $Lhs {
736            type Output = $Out;
737            fn $method(self, rhs: $Rhs) -> $Out {
738                (&self).$method(&rhs)
739            }
740        }
741    };
742    (lref, $Op:ident, $method:ident, $Lhs:ty, $Rhs:ty, $Out:ty) => {
743        impl $Op<$Rhs> for $Lhs {
744            type Output = $Out;
745            fn $method(self, rhs: $Rhs) -> $Out {
746                self.$method(&rhs)
747            }
748        }
749    };
750    (rref, $Op:ident, $method:ident, $Lhs:ty, $Rhs:ty, $Out:ty) => {
751        impl $Op<$Rhs> for $Lhs {
752            type Output = $Out;
753            fn $method(self, rhs: $Rhs) -> $Out {
754                (&self).$method(rhs)
755            }
756        }
757    };
758}
759delegate_binop!(Add, add, Array2<f32>, Array2<f32>, Array2<f32>);
760delegate_binop!(rref, Add, add, Array2<f32>, &Array2<f32>, Array2<f32>);
761delegate_binop!(lref, Add, add, &Array1<f32>, Array1<f32>, Array1<f32>);
762delegate_binop!(rref, Add, add, Array1<f32>, &Array1<f32>, Array1<f32>);
763delegate_binop!(Add, add, Array1<f32>, Array1<f32>, Array1<f32>);
764delegate_binop!(rref, Sub, sub, Array1<f32>, &Array1<f32>, Array1<f32>);
765delegate_binop!(Sub, sub, Array1<f32>, Array1<f32>, Array1<f32>);
766
767impl Div<f32> for &Array2<f32> {
768    type Output = Array2<f32>;
769    fn div(self, rhs: f32) -> Array2<f32> {
770        Array2 {
771            data: self.data.iter().map(|&a| a / rhs).collect(),
772            rows: self.rows,
773            cols: self.cols,
774        }
775    }
776}
777
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn array1_zeros() {
        let zeros = Array1::<f32>::zeros(4);
        assert_eq!(zeros.len(), 4);
        assert!(zeros.as_slice().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn array1_from_vec_and_index() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let (first, last) = (a[0], a[2]);
        assert_eq!(first, 1.0);
        assert_eq!(last, 3.0);
    }

    #[test]
    fn array1_arithmetic() {
        let a = Array1::from_vec(vec![1.0, 2.0]);
        let b = Array1::from_vec(vec![3.0, 4.0]);
        let sum = &a + &b;
        let diff = &a - &b;
        assert_eq!(sum[0], 4.0);
        assert_eq!(diff[0], -2.0);
    }

    #[test]
    fn array1_dot() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        // 1*4 + 2*5 + 3*6
        assert_eq!(a.dot(&b), 32.0);
    }

    #[test]
    fn array2_zeros_and_shape() {
        let m = Array2::<f32>::zeros((2, 3));
        assert_eq!(m.shape(), [2, 3]);
        assert_eq!((m.nrows(), m.ncols()), (2, 3));
    }

    #[test]
    fn array2_index_and_row() {
        let mut m = Array2::<f32>::zeros((2, 2));
        m[[0, 1]] = 5.0;
        assert_eq!(m[[0, 1]], 5.0);

        let wide = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("6 elements fit a 2x3 shape");
        assert_eq!(wide.row(0).len(), 3);
    }

    #[test]
    fn array2_transpose() {
        let m = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("6 elements fit a 2x3 shape");
        let t = m.t();
        assert_eq!(t.shape(), [3, 2]);
        // t[[r, c]] == m[[c, r]]: m[[0, 1]] is 2.0 and m[[1, 0]] is 4.0.
        assert_eq!(t[[1, 0]], 2.0);
        assert_eq!(t[[0, 1]], 4.0);
    }
}