// forust_ml/data.rs

1use serde::{Deserialize, Serialize};
2use std::fmt::{self, Debug, Display};
3use std::iter::Sum;
4use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};
5use std::str::FromStr;
6
/// Numeric trait used throughout the crate to abstract over the
/// floating point types the algorithms operate on.
///
/// Besides the arithmetic, comparison and summation bounds listed below,
/// implementors must be `Copy` and shareable between threads (`Send + Sync`).
pub trait FloatData<T>:
    Copy
    + Debug
    + Display
    + PartialEq
    + PartialOrd
    + Send
    + Sync
    + Add<Output = T>
    + Sub<Output = T>
    + Mul<Output = T>
    + Div<Output = T>
    + Neg<Output = T>
    + AddAssign
    + SubAssign
    + Sum
{
    /// Additive identity.
    const ZERO: T;
    /// Multiplicative identity.
    const ONE: T;
    /// Lowest finite value of the type.
    const MIN: T;
    /// Largest finite value of the type.
    const MAX: T;
    /// A not-a-number value.
    const NAN: T;
    /// Positive infinity.
    const INFINITY: T;

    /// Convert a `usize` into `T`.
    fn from_usize(v: usize) -> T;
    /// Convert a `u16` into `T`.
    fn from_u16(v: u16) -> T;
    /// Whether `self` is NaN.
    fn is_nan(self) -> bool;
    /// Natural logarithm of `self`.
    fn ln(self) -> T;
    /// `e` raised to the power `self`.
    fn exp(self) -> T;
}
38impl FloatData<f64> for f64 {
39    const ZERO: f64 = 0.0;
40    const ONE: f64 = 1.0;
41    const MIN: f64 = f64::MIN;
42    const MAX: f64 = f64::MAX;
43    const NAN: f64 = f64::NAN;
44    const INFINITY: f64 = f64::INFINITY;
45
46    fn from_usize(v: usize) -> f64 {
47        v as f64
48    }
49    fn from_u16(v: u16) -> f64 {
50        f64::from(v)
51    }
52    fn is_nan(self) -> bool {
53        self.is_nan()
54    }
55    fn ln(self) -> f64 {
56        self.ln()
57    }
58    fn exp(self) -> f64 {
59        self.exp()
60    }
61}
62
63impl FloatData<f32> for f32 {
64    const ZERO: f32 = 0.0;
65    const ONE: f32 = 1.0;
66    const MIN: f32 = f32::MIN;
67    const MAX: f32 = f32::MAX;
68    const NAN: f32 = f32::NAN;
69    const INFINITY: f32 = f32::INFINITY;
70
71    fn from_usize(v: usize) -> f32 {
72        v as f32
73    }
74    fn from_u16(v: u16) -> f32 {
75        f32::from(v)
76    }
77    fn is_nan(self) -> bool {
78        self.is_nan()
79    }
80    fn ln(self) -> f32 {
81        self.ln()
82    }
83    fn exp(self) -> f32 {
84        self.exp()
85    }
86}
87
/// Contiguous, column major matrix data container. This is
/// used throughout the crate to house both the user provided data
/// and the binned data.
///
/// The matrix borrows its values; it does not own them.
pub struct Matrix<'a, T> {
    /// The underlying values, laid out column by column.
    pub data: &'a [T],
    /// Row indices; `Matrix::new` initialises this to `0..rows`.
    pub index: Vec<usize>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Distance in `data` between consecutive columns (multiplies the column index).
    stride1: usize,
    /// Distance in `data` between consecutive rows (multiplies the row index).
    stride2: usize,
}
99
100impl<'a, T> Matrix<'a, T> {
101    // Defaults to column major
102    pub fn new(data: &'a [T], rows: usize, cols: usize) -> Self {
103        Matrix {
104            data,
105            index: (0..rows).collect(),
106            rows,
107            cols,
108            stride1: rows,
109            stride2: 1,
110        }
111    }
112
113    /// Get a single reference to an item in the matrix.
114    ///
115    /// * `i` - The ith row of the data to get.
116    /// * `j` - the jth column of the data to get.
117    pub fn get(&self, i: usize, j: usize) -> &T {
118        &self.data[self.item_index(i, j)]
119    }
120
121    fn item_index(&self, i: usize, j: usize) -> usize {
122        let mut idx = self.stride2 * i;
123        idx += j * self.stride1;
124        idx
125    }
126
127    /// Get access to a row of the data, as an iterator.
128    pub fn get_row_iter(
129        &self,
130        row: usize,
131    ) -> std::iter::StepBy<std::iter::Skip<std::slice::Iter<T>>> {
132        self.data.iter().skip(row).step_by(self.rows)
133    }
134
135    /// Get a slice of a column in the matrix.
136    ///
137    /// * `col` - The index of the column to select.
138    /// * `start_row` - The index of the start of the slice.
139    /// * `end_row` - The index of the end of the slice of the column to select.
140    pub fn get_col_slice(&self, col: usize, start_row: usize, end_row: usize) -> &[T] {
141        let i = self.item_index(start_row, col);
142        let j = self.item_index(end_row, col);
143        &self.data[i..j]
144    }
145
146    /// Get an entire column in the matrix.
147    ///
148    /// * `col` - The index of the column to get.
149    pub fn get_col(&self, col: usize) -> &[T] {
150        self.get_col_slice(col, 0, self.rows)
151    }
152}
153
154impl<'a, T> Matrix<'a, T>
155where
156    T: Copy,
157{
158    /// Get a row of the data as a vector.
159    pub fn get_row(&self, row: usize) -> Vec<T> {
160        self.get_row_iter(row).copied().collect()
161    }
162}
163
164/// A lightweight row major matrix, this is primarily
165/// for returning data to the user, it is especially
166/// suited for appending rows to, such as when building
167/// up a matrix of data to return to the
168/// user, the added benefit is it will be even
169/// faster to return to numpy.
170#[derive(Debug, Serialize, Deserialize)]
171pub struct RowMajorMatrix<T> {
172    pub data: Vec<T>,
173    pub rows: usize,
174    pub cols: usize,
175    stride1: usize,
176    stride2: usize,
177}
178
179impl<T> RowMajorMatrix<T> {
180    // Defaults to column major
181    pub fn new(data: Vec<T>, rows: usize, cols: usize) -> Self {
182        RowMajorMatrix {
183            data,
184            rows,
185            cols,
186            stride1: 1,
187            stride2: cols,
188        }
189    }
190
191    /// Get a single reference to an item in the matrix.
192    ///
193    /// * `i` - The ith row of the data to get.
194    /// * `j` - the jth column of the data to get.
195    pub fn get(&self, i: usize, j: usize) -> &T {
196        &self.data[self.item_index(i, j)]
197    }
198
199    fn item_index(&self, i: usize, j: usize) -> usize {
200        let mut idx = self.stride2 * i;
201        idx += j * self.stride1;
202        idx
203    }
204
205    /// Add a rows to the matrix, this can be multiple
206    /// rows, if they are in sequential order in the items.
207    pub fn append_row(&mut self, items: Vec<T>) {
208        assert!(items.len() % self.cols == 0);
209        let new_rows = items.len() / self.cols;
210        self.rows += new_rows;
211        self.data.extend(items);
212    }
213}
214
215impl<'a, T> fmt::Display for Matrix<'a, T>
216where
217    T: FromStr + std::fmt::Display,
218    <T as FromStr>::Err: 'static + std::error::Error,
219{
220    // This trait requires `fmt` with this exact signature.
221    /// Format a Matrix.
222    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
223        let mut val = String::new();
224        for i in 0..self.rows {
225            for j in 0..self.cols {
226                val.push_str(self.get(i, j).to_string().as_str());
227                if j == (self.cols - 1) {
228                    val.push('\n');
229                } else {
230                    val.push(' ');
231                }
232            }
233        }
234        write!(f, "{}", val)
235    }
236}
237
238/// A jagged column aligned matrix, that owns it's data contents.
239#[derive(Debug, Deserialize, Serialize)]
240pub struct JaggedMatrix<T> {
241    /// The contents of the matrix.
242    pub data: Vec<T>,
243    /// The end index's of the matrix.
244    pub ends: Vec<usize>,
245    /// Number of columns in the matrix
246    pub cols: usize,
247    /// The number of elements in the matrix.
248    pub n_records: usize,
249}
250
251impl<T> JaggedMatrix<T>
252where
253    T: Copy,
254{
255    /// Generate a jagged array from a vector of vectors
256    pub fn from_vecs(vecs: &[Vec<T>]) -> Self {
257        let mut data = Vec::new();
258        let mut ends = Vec::new();
259        let mut e = 0;
260        let mut n_records = 0;
261        for vec in vecs {
262            for v in vec {
263                data.push(*v);
264            }
265            e += vec.len();
266            ends.push(e);
267            n_records += e;
268        }
269        let cols = vecs.len();
270
271        JaggedMatrix {
272            data,
273            ends,
274            cols,
275            n_records,
276        }
277    }
278}
279
280impl<T> JaggedMatrix<T> {
281    /// Create a new jagged matrix.
282    pub fn new() -> Self {
283        JaggedMatrix {
284            data: Vec::new(),
285            ends: Vec::new(),
286            cols: 0,
287            n_records: 0,
288        }
289    }
290    /// Get the column of a jagged array.
291    pub fn get_col(&self, col: usize) -> &[T] {
292        assert!(col < self.ends.len());
293        let (i, j) = if col == 0 {
294            (0, self.ends[col])
295        } else {
296            (self.ends[col - 1], self.ends[col])
297        };
298        &self.data[i..j]
299    }
300
301    /// Get a mutable reference to a column of the array.
302    pub fn get_col_mut(&mut self, col: usize) -> &mut [T] {
303        assert!(col < self.ends.len());
304        let (i, j) = if col == 0 {
305            (0, self.ends[col])
306        } else {
307            (self.ends[col - 1], self.ends[col])
308        };
309        &mut self.data[i..j]
310    }
311}
312
313impl<T> Default for JaggedMatrix<T> {
314    fn default() -> Self {
315        Self::new()
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_float_data_impls() {
        assert_eq!(<f64 as FloatData<f64>>::ZERO, 0.0);
        assert_eq!(<f32 as FloatData<f32>>::ONE, 1.0);
        assert!(<f64 as FloatData<f64>>::is_nan(f64::NAN));
        assert!(!<f32 as FloatData<f32>>::is_nan(2.0));
        assert_eq!(<f32 as FloatData<f32>>::from_u16(3), 3.0);
        assert_eq!(<f64 as FloatData<f64>>::from_usize(10), 10.0);
    }

    #[test]
    fn test_rowmatrix_get() {
        let v = vec![1, 2, 3, 5, 6, 7];
        let m = RowMajorMatrix::new(v, 2, 3);
        println!("{:?}", m);
        assert_eq!(m.get(0, 0), &1);
        assert_eq!(m.get(1, 0), &5);
        assert_eq!(m.get(0, 2), &3);
        assert_eq!(m.get(1, 1), &6);
    }

    #[test]
    fn test_rowmatrix_append() {
        let v = vec![1, 2, 3, 5, 6, 7];
        let mut m = RowMajorMatrix::new(v, 2, 3);
        m.append_row(vec![-1, -2, -3]);
        assert_eq!(m.get(2, 1), &-2);
    }

    #[test]
    fn test_rowmatrix_append_multiple_rows() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3], 1, 3);
        m.append_row(vec![4, 5, 6, 7, 8, 9]);
        assert_eq!(m.rows, 3);
        assert_eq!(m.get(1, 0), &4);
        assert_eq!(m.get(2, 2), &9);
    }

    #[test]
    #[should_panic]
    fn test_rowmatrix_append_partial_row_panics() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3], 1, 3);
        m.append_row(vec![4, 5]);
    }

    #[test]
    fn test_matrix_get() {
        let v = vec![1, 2, 3, 5, 6, 7];
        let m = Matrix::new(&v, 2, 3);
        println!("{}", m);
        assert_eq!(m.get(0, 0), &1);
        assert_eq!(m.get(1, 0), &2);
    }

    #[test]
    fn test_matrix_display() {
        let v = vec![1, 2, 3, 5, 6, 7];
        let m = Matrix::new(&v, 2, 3);
        assert_eq!(m.to_string(), "1 3 6\n2 5 7\n");
    }

    #[test]
    fn test_matrix_get_col_slice() {
        let v = vec![1, 2, 3, 5, 6, 7];
        let m = Matrix::new(&v, 3, 2);
        assert_eq!(m.get_col_slice(0, 0, 3), &vec![1, 2, 3]);
        assert_eq!(m.get_col_slice(1, 0, 2), &vec![5, 6]);
        assert_eq!(m.get_col_slice(1, 1, 3), &vec![6, 7]);
        assert_eq!(m.get_col_slice(0, 1, 2), &vec![2]);
    }

    #[test]
    fn test_matrix_get_col() {
        let v = vec![1, 2, 3, 5, 6, 7];
        let m = Matrix::new(&v, 3, 2);
        assert_eq!(m.get_col(1), &vec![5, 6, 7]);
    }

    #[test]
    fn test_matrix_row() {
        let v = vec![1, 2, 3, 5, 6, 7];
        let m = Matrix::new(&v, 3, 2);
        assert_eq!(m.get_row(2), vec![3, 7]);
        assert_eq!(m.get_row(0), vec![1, 5]);
        assert_eq!(m.get_row(1), vec![2, 6]);
    }

    #[test]
    fn test_matrix_row_iter() {
        let v = vec![1, 2, 3, 5, 6, 7];
        let m = Matrix::new(&v, 3, 2);
        let row: Vec<i32> = m.get_row_iter(1).copied().collect();
        assert_eq!(row, vec![2, 6]);
    }

    #[test]
    fn test_jaggedmatrix_get_col() {
        let vecs = vec![vec![0], vec![5, 4, 3, 2], vec![4, 5]];
        let jmatrix = JaggedMatrix::from_vecs(&vecs);
        assert_eq!(jmatrix.get_col(1), vec![5, 4, 3, 2]);
        assert_eq!(jmatrix.get_col(0), vec![0]);
        assert_eq!(jmatrix.get_col(2), vec![4, 5]);
    }

    #[test]
    fn test_jaggedmatrix_get_col_mut() {
        let vecs = vec![vec![1], vec![2, 3]];
        let mut jmatrix = JaggedMatrix::from_vecs(&vecs);
        jmatrix.get_col_mut(1)[1] = 9;
        assert_eq!(jmatrix.get_col(1), vec![2, 9]);
        assert_eq!(jmatrix.get_col(0), vec![1]);
    }

    #[test]
    fn test_jaggedmatrix_default_is_empty() {
        let jmatrix: JaggedMatrix<f64> = JaggedMatrix::default();
        assert_eq!(jmatrix.cols, 0);
        assert_eq!(jmatrix.n_records, 0);
        assert!(jmatrix.data.is_empty());
        assert!(jmatrix.ends.is_empty());
    }
}