concision_core/utils/
tensor.rs1pub use self::{gen::*, stack::*};
6use nd::*;
7use num::traits::{NumAssign, Zero};
8
9pub fn concat_iter<D, T>(axis: usize, iter: impl IntoIterator<Item = Array<T, D>>) -> Array<T, D>
11where
12 D: RemoveAxis,
13 T: Clone,
14{
15 let mut arr = iter.into_iter().collect::<Vec<_>>();
16 let mut out = arr.pop().unwrap();
17 for i in arr {
18 out = concatenate!(Axis(axis), out, i);
19 }
20 out
21}
22
23pub fn inverse<T>(matrix: &Array2<T>) -> Option<Array2<T>>
24where
25 T: Copy + NumAssign + ScalarOperand,
26{
27 let (rows, cols) = matrix.dim();
28
29 if !matrix.is_square() {
30 return None; }
32
33 let identity = Array2::eye(rows);
34
35 let mut aug = Array2::zeros((rows, 2 * cols));
37 aug.slice_mut(s![.., ..cols]).assign(matrix);
38 aug.slice_mut(s![.., cols..]).assign(&identity);
39
40 for i in 0..rows {
42 let pivot = aug[[i, i]];
43
44 if pivot == T::zero() {
45 return None; }
47
48 aug.slice_mut(s![i, ..]).mapv_inplace(|x| x / pivot);
49
50 for j in 0..rows {
51 if i != j {
52 let am = aug.clone();
53 let factor = aug[[j, i]];
54 let rhs = am.slice(s![i, ..]);
55 aug.slice_mut(s![j, ..])
56 .zip_mut_with(&rhs, |x, &y| *x -= y * factor);
57 }
58 }
59 }
60
61 let inverted = aug.slice(s![.., cols..]);
63
64 Some(inverted.to_owned())
65}
66
67pub fn tril<T>(a: &Array2<T>) -> Array2<T>
69where
70 T: Clone + Zero,
71{
72 let mut out = a.clone();
73 for i in 0..a.shape()[0] {
74 for j in i + 1..a.shape()[1] {
75 out[[i, j]] = T::zero();
76 }
77 }
78 out
79}
80pub fn triu<T>(a: &Array2<T>) -> Array2<T>
82where
83 T: Clone + Zero,
84{
85 let mut out = a.clone();
86 for i in 0..a.shape()[0] {
87 for j in 0..i {
88 out[[i, j]] = T::zero();
89 }
90 }
91 out
92}
93
94pub(crate) mod gen {
95 use nd::{Array, Array1, Dimension, IntoDimension, ShapeError};
96 use num::traits::{Float, FromPrimitive, Num, NumCast};
97
98 pub fn genspace<T: NumCast>(features: usize) -> Array1<T> {
99 Array1::from_iter((0..features).map(|x| T::from(x).unwrap()))
100 }
101
102 pub fn linarr<A, D>(dim: impl Clone + IntoDimension<Dim = D>) -> Result<Array<A, D>, ShapeError>
103 where
104 A: Float,
105 D: Dimension,
106 {
107 let dim = dim.into_dimension();
108 let n = dim.size();
109 Array::linspace(A::zero(), A::from(n - 1).unwrap(), n).into_shape(dim)
110 }
111
112 pub fn linspace<T>(start: T, end: T, n: usize) -> Vec<T>
113 where
114 T: Copy + FromPrimitive + Num,
115 {
116 if n <= 1 {
117 panic!("linspace requires at least two points");
118 }
119
120 let step = (end - start) / T::from_usize(n - 1).unwrap();
121
122 (0..n)
123 .map(|i| start + step * T::from_usize(i).unwrap())
124 .collect()
125 }
126 pub fn rangespace<A, D>(dim: impl IntoDimension<Dim = D>) -> Array<A, D>
128 where
129 A: FromPrimitive,
130 D: Dimension,
131 {
132 let dim = dim.into_dimension();
133 let iter = (0..dim.size()).map(|i| A::from_usize(i).unwrap()).collect();
134 Array::from_shape_vec(dim, iter).unwrap()
135 }
136}
137
138pub(crate) mod stack {
139 use nd::{s, Array1, Array2};
140 use num::Num;
141 pub fn stack_iter<T>(iter: impl IntoIterator<Item = Array1<T>>) -> Array2<T>
143 where
144 T: Clone + Num,
145 {
146 let mut iter = iter.into_iter();
147 let first = iter.next().unwrap();
148 let shape = [iter.size_hint().0 + 1, first.len()];
149 let mut res = Array2::<T>::zeros(shape);
150 res.slice_mut(s![0, ..]).assign(&first);
151 for (i, s) in iter.enumerate() {
152 res.slice_mut(s![i + 1, ..]).assign(&s);
153 }
154 res
155 }
156 pub fn hstack<T>(iter: impl IntoIterator<Item = Array1<T>>) -> Array2<T>
158 where
159 T: Clone + Num,
160 {
161 let iter = Vec::from_iter(iter);
162 let mut res = Array2::<T>::zeros((iter.first().unwrap().len(), iter.len()));
163 for (i, s) in iter.iter().enumerate() {
164 res.slice_mut(s![.., i]).assign(s);
165 }
166 res
167 }
168 pub fn vstack<T>(iter: impl IntoIterator<Item = Array1<T>>) -> Array2<T>
170 where
171 T: Clone + Num,
172 {
173 let iter = Vec::from_iter(iter);
174 let mut res = Array2::<T>::zeros((iter.len(), iter.first().unwrap().len()));
175 for (i, s) in iter.iter().enumerate() {
176 res.slice_mut(s![i, ..]).assign(s);
177 }
178 res
179 }
180}