1use serde::{Deserialize, Serialize};
2use std::fmt::{self, Debug, Display};
3use std::iter::Sum;
4use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};
5use std::str::FromStr;
6
/// Floating-point value abstraction implemented for `f64` and `f32` in this file.
///
/// Implementors are copyable, thread-safe (`Send + Sync`), printable, ordered
/// (partially) and closed under the basic arithmetic operators. The generic
/// parameter `T` is the value type the operators and constants produce; the
/// provided implementations use `T = Self`.
pub trait FloatData<T>:
    Copy
    + Send
    + Sync
    + Debug
    + Display
    + PartialEq
    + PartialOrd
    + Add<Output = T>
    + Sub<Output = T>
    + Mul<Output = T>
    + Div<Output = T>
    + Neg<Output = T>
    + AddAssign
    + SubAssign
    + Sum
{
    /// The value zero.
    const ZERO: T;
    /// The value one.
    const ONE: T;
    /// Lower bound constant (the provided impls use the type's `MIN`, its most negative finite value).
    const MIN: T;
    /// Upper bound constant (the provided impls use the type's `MAX`, its largest finite value).
    const MAX: T;
    /// A not-a-number value.
    const NAN: T;
    /// Positive infinity.
    const INFINITY: T;

    /// Converts a `usize` into `T` (the provided impls use an `as` cast, which may round).
    fn from_usize(v: usize) -> T;
    /// Converts a `u16` into `T`.
    fn from_u16(v: u16) -> T;
    /// Returns `true` if the value is NaN.
    fn is_nan(self) -> bool;
    /// Natural logarithm.
    fn ln(self) -> T;
    /// Exponential function, `e^self`.
    fn exp(self) -> T;
}
38impl FloatData<f64> for f64 {
39 const ZERO: f64 = 0.0;
40 const ONE: f64 = 1.0;
41 const MIN: f64 = f64::MIN;
42 const MAX: f64 = f64::MAX;
43 const NAN: f64 = f64::NAN;
44 const INFINITY: f64 = f64::INFINITY;
45
46 fn from_usize(v: usize) -> f64 {
47 v as f64
48 }
49 fn from_u16(v: u16) -> f64 {
50 f64::from(v)
51 }
52 fn is_nan(self) -> bool {
53 self.is_nan()
54 }
55 fn ln(self) -> f64 {
56 self.ln()
57 }
58 fn exp(self) -> f64 {
59 self.exp()
60 }
61}
62
63impl FloatData<f32> for f32 {
64 const ZERO: f32 = 0.0;
65 const ONE: f32 = 1.0;
66 const MIN: f32 = f32::MIN;
67 const MAX: f32 = f32::MAX;
68 const NAN: f32 = f32::NAN;
69 const INFINITY: f32 = f32::INFINITY;
70
71 fn from_usize(v: usize) -> f32 {
72 v as f32
73 }
74 fn from_u16(v: u16) -> f32 {
75 f32::from(v)
76 }
77 fn is_nan(self) -> bool {
78 self.is_nan()
79 }
80 fn ln(self) -> f32 {
81 self.ln()
82 }
83 fn exp(self) -> f32 {
84 self.exp()
85 }
86}
87
/// A borrowed matrix view over a flat slice, stored column by column.
///
/// Element `(i, j)` (row `i`, column `j`) is read from `data[i + j * rows]`,
/// so every column is a contiguous run of `rows` elements.
pub struct Matrix<'a, T> {
    /// Backing storage, one column after another.
    pub data: &'a [T],
    /// Row indices `0..rows`, populated by `Matrix::new`.
    pub index: Vec<usize>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Offset between adjacent columns (set to `rows` by `new`).
    stride1: usize,
    /// Offset between adjacent rows (set to 1 by `new`).
    stride2: usize,
}

impl<'a, T> Matrix<'a, T> {
    /// Views `data` as a `rows` x `cols` column-major matrix.
    ///
    /// The length of `data` is not checked; it is expected to be `rows * cols`.
    pub fn new(data: &'a [T], rows: usize, cols: usize) -> Self {
        let index: Vec<usize> = (0..rows).collect();
        Self {
            data,
            index,
            rows,
            cols,
            stride1: rows,
            stride2: 1,
        }
    }

    /// Returns a reference to the element at row `i`, column `j`.
    ///
    /// # Panics
    /// Panics if the computed offset lies outside `data`.
    pub fn get(&self, i: usize, j: usize) -> &T {
        let offset = self.item_index(i, j);
        &self.data[offset]
    }

    /// Translates a `(row, column)` pair into an offset into `data`.
    fn item_index(&self, i: usize, j: usize) -> usize {
        i * self.stride2 + j * self.stride1
    }

    /// Iterates over row `row`, yielding one element per column.
    ///
    /// # Panics
    /// Panics if `rows` is zero, because `step_by` rejects a zero step.
    pub fn get_row_iter(
        &self,
        row: usize,
    ) -> std::iter::StepBy<std::iter::Skip<std::slice::Iter<'_, T>>> {
        let from_row = self.data.iter().skip(row);
        from_row.step_by(self.rows)
    }

    /// Borrows rows `start_row..end_row` of column `col` as one contiguous slice.
    ///
    /// # Panics
    /// Panics if the range is reversed or extends past the end of `data`.
    pub fn get_col_slice(&self, col: usize, start_row: usize, end_row: usize) -> &[T] {
        let (begin, end) = (
            self.item_index(start_row, col),
            self.item_index(end_row, col),
        );
        &self.data[begin..end]
    }

    /// Borrows every row of column `col` as a contiguous slice.
    pub fn get_col(&self, col: usize) -> &[T] {
        let all_rows = self.rows;
        self.get_col_slice(col, 0, all_rows)
    }
}

impl<'a, T: Copy> Matrix<'a, T> {
    /// Copies the elements of row `row` into a freshly allocated vector.
    pub fn get_row(&self, row: usize) -> Vec<T> {
        self.get_row_iter(row).copied().collect::<Vec<T>>()
    }
}
163
/// An owned matrix whose elements are stored row by row in a single `Vec`.
///
/// Element `(i, j)` is read from `data[i * stride2 + j * stride1]`;
/// `RowMajorMatrix::new` sets `stride1` to 1 and `stride2` to `cols`.
///
/// All fields, including the private strides, are (de)serialized by the serde
/// derive. Treat field order as part of the serialized layout for
/// order-sensitive formats.
#[derive(Debug, Serialize, Deserialize)]
pub struct RowMajorMatrix<T> {
    /// Flat element storage, one row after another.
    pub data: Vec<T>,
    /// Number of rows currently stored.
    pub rows: usize,
    /// Number of columns (elements per row).
    pub cols: usize,
    /// Offset between elements in adjacent columns of the same row.
    stride1: usize,
    /// Offset between elements in adjacent rows of the same column.
    stride2: usize,
}
178
179impl<T> RowMajorMatrix<T> {
180 pub fn new(data: Vec<T>, rows: usize, cols: usize) -> Self {
182 RowMajorMatrix {
183 data,
184 rows,
185 cols,
186 stride1: 1,
187 stride2: cols,
188 }
189 }
190
191 pub fn get(&self, i: usize, j: usize) -> &T {
196 &self.data[self.item_index(i, j)]
197 }
198
199 fn item_index(&self, i: usize, j: usize) -> usize {
200 let mut idx = self.stride2 * i;
201 idx += j * self.stride1;
202 idx
203 }
204
205 pub fn append_row(&mut self, items: Vec<T>) {
208 assert!(items.len() % self.cols == 0);
209 let new_rows = items.len() / self.cols;
210 self.rows += new_rows;
211 self.data.extend(items);
212 }
213}
214
215impl<'a, T> fmt::Display for Matrix<'a, T>
216where
217 T: FromStr + std::fmt::Display,
218 <T as FromStr>::Err: 'static + std::error::Error,
219{
220 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
223 let mut val = String::new();
224 for i in 0..self.rows {
225 for j in 0..self.cols {
226 val.push_str(self.get(i, j).to_string().as_str());
227 if j == (self.cols - 1) {
228 val.push('\n');
229 } else {
230 val.push(' ');
231 }
232 }
233 }
234 write!(f, "{}", val)
235 }
236}
237
/// Variable-length columns stored back to back in one flat buffer.
///
/// Column `c` occupies `data[start..ends[c]]`, where `start` is 0 for the first
/// column and `ends[c - 1]` otherwise.
///
/// All fields are (de)serialized by the serde derive. Treat field order as part
/// of the serialized layout for order-sensitive formats.
#[derive(Debug, Deserialize, Serialize)]
pub struct JaggedMatrix<T> {
    /// Concatenated contents of every column.
    pub data: Vec<T>,
    /// Exclusive end offset of each column within `data`.
    pub ends: Vec<usize>,
    /// Number of columns.
    pub cols: usize,
    /// Record count set by the constructors (0 for `new`).
    /// NOTE(review): presumably the total number of elements in `data` — confirm.
    pub n_records: usize,
}
250
251impl<T> JaggedMatrix<T>
252where
253 T: Copy,
254{
255 pub fn from_vecs(vecs: &[Vec<T>]) -> Self {
257 let mut data = Vec::new();
258 let mut ends = Vec::new();
259 let mut e = 0;
260 let mut n_records = 0;
261 for vec in vecs {
262 for v in vec {
263 data.push(*v);
264 }
265 e += vec.len();
266 ends.push(e);
267 n_records += e;
268 }
269 let cols = vecs.len();
270
271 JaggedMatrix {
272 data,
273 ends,
274 cols,
275 n_records,
276 }
277 }
278}
279
280impl<T> JaggedMatrix<T> {
281 pub fn new() -> Self {
283 JaggedMatrix {
284 data: Vec::new(),
285 ends: Vec::new(),
286 cols: 0,
287 n_records: 0,
288 }
289 }
290 pub fn get_col(&self, col: usize) -> &[T] {
292 assert!(col < self.ends.len());
293 let (i, j) = if col == 0 {
294 (0, self.ends[col])
295 } else {
296 (self.ends[col - 1], self.ends[col])
297 };
298 &self.data[i..j]
299 }
300
301 pub fn get_col_mut(&mut self, col: usize) -> &mut [T] {
303 assert!(col < self.ends.len());
304 let (i, j) = if col == 0 {
305 (0, self.ends[col])
306 } else {
307 (self.ends[col - 1], self.ends[col])
308 };
309 &mut self.data[i..j]
310 }
311}
312
313impl<T> Default for JaggedMatrix<T> {
314 fn default() -> Self {
315 Self::new()
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample values shared by the matrix tests below.
    const VALUES: [i32; 6] = [1, 2, 3, 5, 6, 7];

    #[test]
    fn test_rowmatrix_get() {
        let m = RowMajorMatrix::new(VALUES.to_vec(), 2, 3);
        println!("{:?}", m);
        let expected = [((0, 0), 1), ((1, 0), 5), ((0, 2), 3), ((1, 1), 6)];
        for &((i, j), want) in expected.iter() {
            assert_eq!(*m.get(i, j), want);
        }
    }

    #[test]
    fn test_rowmatrix_append() {
        let mut m = RowMajorMatrix::new(VALUES.to_vec(), 2, 3);
        m.append_row(vec![-1, -2, -3]);
        assert_eq!(*m.get(2, 1), -2);
    }

    #[test]
    fn test_matrix_get() {
        let m = Matrix::new(&VALUES, 2, 3);
        println!("{}", m);
        assert_eq!(*m.get(0, 0), 1);
        assert_eq!(*m.get(1, 0), 2);
    }

    #[test]
    fn test_matrix_get_col_slice() {
        let m = Matrix::new(&VALUES, 3, 2);
        let cases: [((usize, usize, usize), &[i32]); 4] = [
            ((0, 0, 3), &[1, 2, 3]),
            ((1, 0, 2), &[5, 6]),
            ((1, 1, 3), &[6, 7]),
            ((0, 1, 2), &[2]),
        ];
        for &((col, start, end), want) in cases.iter() {
            assert_eq!(m.get_col_slice(col, start, end), want);
        }
    }

    #[test]
    fn test_matrix_get_col() {
        let m = Matrix::new(&VALUES, 3, 2);
        assert_eq!(m.get_col(1), &[5, 6, 7]);
    }

    #[test]
    fn test_matrix_row() {
        let m = Matrix::new(&VALUES, 3, 2);
        for (row, want) in [[1, 5], [2, 6], [3, 7]].iter().enumerate() {
            assert_eq!(m.get_row(row), want.to_vec());
        }
    }

    #[test]
    fn test_jaggedmatrix_get_col() {
        let columns = vec![vec![0], vec![5, 4, 3, 2], vec![4, 5]];
        let jmatrix = JaggedMatrix::from_vecs(&columns);
        for (c, want) in columns.iter().enumerate() {
            assert_eq!(jmatrix.get_col(c), want.as_slice());
        }
    }
}