Skip to main content

lak/
types.rs

1// types.rs 
2
3use std::ops::Range;
4
/// enum for transpose ops
///
/// * [Transpose::NoTranspose] for no-transpose ops
/// * [Transpose::Transpose] for transpose ops
///
/// derives `PartialEq`/`Eq`/`Hash` in addition to `Clone`/`Copy`/`Debug` so
/// callers can compare (`==`) and hash the flag, not only `match` on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Transpose {
    /// selects the no-transpose variant of an op
    NoTranspose,
    /// selects the transpose variant of an op
    Transpose,
}
13
/// enum for triangular ops
///
/// * [Triangular::Upper] for upper-triangular ops
/// * [Triangular::Lower] for lower-triangular ops
///
/// derives `PartialEq`/`Eq`/`Hash` in addition to `Clone`/`Copy`/`Debug` so
/// callers can compare (`==`) and hash the flag, not only `match` on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Triangular {
    /// selects the upper-triangular variant of an op
    Upper,
    /// selects the lower-triangular variant of an op
    Lower,
}
22
/// immutable vector type
///
/// a thin wrapper around a borrowed slice `&'a [T]`.
///
/// `Clone`/`Copy` are implemented by hand instead of derived: `#[derive]`
/// would add `T: Clone` / `T: Copy` bounds, even though copying the view only
/// copies the slice reference and never touches the elements.
#[derive(Debug)]
pub struct VecRef<'a, T> {
    buffer: &'a [T],
}

impl<'a, T> Clone for VecRef<'a, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, T> Copy for VecRef<'a, T> {}
28
/// mutable vector type
///
/// wraps an exclusively borrowed slice `&'a mut [T]`; only `Debug` is derived
/// because a mutable borrow cannot be duplicated.
#[derive(Debug)]
pub struct VecMut<'a, T> {
    buffer: &'a mut [T],
}
34
/// immutable matrix type
///
/// column major: `buffer` holds the columns back to back, and `dimension` is
/// `(n_rows, n_cols)`.
///
/// `Clone`/`Copy` are implemented by hand instead of derived: `#[derive]`
/// would add `T: Clone` / `T: Copy` bounds, even though copying the view only
/// copies the slice reference and the dimension tuple.
#[derive(Debug)]
pub struct MatRef<'a, T> {
    buffer: &'a [T],
    dimension: (usize, usize),
}

impl<'a, T> Clone for MatRef<'a, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, T> Copy for MatRef<'a, T> {}
42
/// mutable matrix type
///
/// column major: the columns sit back to back in the exclusively borrowed
/// `buffer`, and `dimension` records `(n_rows, n_cols)`. Only `Debug` is
/// derived because a mutable borrow cannot be duplicated.
#[derive(Debug)]
pub struct MatMut<'a, T> {
    buffer: &'a mut [T],
    dimension: (usize, usize),
}
50
51impl<'a, T> VecRef<'a, T> { 
52    /// constructs [VecRef] with given slice 
53    pub fn new(buffer: &'a [T]) -> Self { 
54        Self { buffer }
55    }
56
57    /// returns length of internal slice 
58    pub fn length(&self) -> usize { 
59        self.buffer.len()
60    }
61
62    /// accesses full internal immutable slice 
63    pub fn as_slice(&self) -> &[T] { 
64        self.buffer
65    }
66
67    /// accesses internal immutable slice over a given index range 
68    pub fn slice(&self, range: Range<usize>) -> &[T] { 
69        &self.buffer[range.start..range.end]
70    }
71
72    /// checks whether internal length is equal to given length parameter
73    pub fn has_equal_length(&self, length: usize) -> bool { 
74        self.buffer.len() == length
75    }
76}
77
78impl<'a, T> VecMut<'a, T> { 
79    /// constructs [VecMut] with given slice 
80    pub fn new(buffer: &'a mut [T]) -> Self { 
81        Self { buffer }
82    }
83
84    /// returns length of internal slice 
85    pub fn length(&self) -> usize { 
86        self.buffer.len()
87    }
88
89    /// accesses full internal slice as immutable 
90    pub fn as_slice(&self) -> &[T] { 
91        self.buffer
92    }
93
94    /// accesses internal immutable slice over a given index range 
95    pub fn slice(&self, range: Range<usize>) -> &[T] { 
96        &self.buffer[range.start..range.end]
97    }
98
99    /// accesses internal mutable slice over a given index range 
100    pub fn slice_mut(&mut self, range: Range<usize>) -> &mut [T] { 
101        &mut self.buffer[range.start..range.end]
102    }
103
104    /// accesses full internal slice as mutable 
105    pub fn as_slice_mut(&mut self) -> &mut [T] { 
106        self.buffer
107    }
108
109    /// checks whether internal length is equal to given length parameter
110    pub fn has_equal_length(&self, length: usize) -> bool { 
111        self.buffer.len() == length
112    }
113
114    /// used for calling routines over and over again 
115    /// on the stored internal mutable slice 
116    ///
117    /// borrows self mutably
118    ///
119    /// example:
120    /// ```
121    /// use lak::l1::scal;
122    /// use lak::types::VecMut;
123    ///
124    /// let mut x = [1.0, 2.0, 3.0];
125    /// let mut x = VecMut::new(&mut x);
126    ///
127    /// scal(2.0, x.reborrow());
128    /// scal(3.0, x.reborrow());
129    /// ```
130    pub fn reborrow(&mut self) -> VecMut<'_, T> { 
131        VecMut::new(self.as_slice_mut())
132    }
133}
134
135
136impl<'a, T> MatRef<'a, T> { 
137    /// constructs [MatRef] with given slice and (n_rows, n_cols) dimension
138    ///
139    /// example: 
140    ///
141    /// ``` 
142    /// use lak::types::MatRef; 
143    ///
144    /// // col-major 2 x 3 matrix: 
145    /// // [1 3 5] 
146    /// // [2 4 6] 
147    /// let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; 
148    /// let a = MatRef::new(&a, (2, 3));
149    /// ```
150    pub fn new(buffer: &'a [T], dimension: (usize, usize)) -> Self { 
151        let i = dimension.0; 
152        let j = dimension.1; 
153        let buffer_length = buffer.len(); 
154        let matrix_length = i * j; 
155
156        assert_eq!(
157            matrix_length,
158            buffer_length,
159            "matrix has invalid dimensions: buffer length is {buffer_length}, \
160             but dimensions are {i} x {j} = {matrix_length}",
161        );
162
163        Self { buffer, dimension }
164    }
165
166    /// accesses internal immutable slice 
167    pub fn as_slice(&self) -> &[T] { 
168        self.buffer 
169    }
170
171    /// accesses matrix dimension 
172    /// (n_rows, n_cols)
173    pub fn dimension(&self) -> (usize, usize) { 
174        self.dimension
175    }
176
177    /// accesses matrix number of rows 
178    pub fn n_rows(&self) -> usize { 
179        self.dimension.0
180    }
181
182    /// accesses matrix number of cols 
183    pub fn n_cols(&self) -> usize { 
184        self.dimension.1
185    }
186
187    /// return a [VecRef] for a given column in Self 
188    pub fn col(&self, j: usize) -> VecRef<'_, T> {
189        let n_rows = self.n_rows(); 
190        let beg_idx = n_rows * j; 
191        let end_idx = n_rows * (j + 1); 
192
193        let slice = &self.buffer[beg_idx..end_idx]; 
194        VecRef::new(slice)
195    }
196
197    /// return a [MatRef] for a contiguous column panel of Self
198    ///
199    /// example: 
200    /// ``` 
201    /// use lak::types::MatRef; 
202    ///
203    /// // [1 3 5] 
204    /// // [2 4 6] 
205    /// let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; 
206    /// let a = MatRef::new(&a, (2, 3)); 
207    ///
208    /// // MatRef of columns 1..3 
209    /// // [3 5] 
210    /// // [4 6] 
211    /// let panel = a.col_panel(1..3); 
212    /// ```
213    pub fn col_panel(&self, cols: Range<usize>) -> MatRef<'_, T> { 
214        debug_assert!(
215            cols.start < cols.end,
216            "start of col range must be < end of col range"
217        );
218        debug_assert!(
219            cols.end <= self.dimension.1, 
220            "end of col range must be <= number cols in Self"
221        );
222
223        let n_rows = self.n_rows(); 
224        let n_cols = cols.end - cols.start; 
225        let beg_idx = n_rows * cols.start; 
226        let end_idx = n_rows * cols.end; 
227
228        MatRef::new(
229            &self.buffer[beg_idx..end_idx], 
230            (n_rows, n_cols)
231        )     
232    }
233
234    /// returns an [Iterator] over [MatRef]s containing column panels that 
235    /// span Self. 
236    ///
237    /// each panel holds nc columns, and the last item is the leftover 
238    /// panel with column width < nc 
239    ///
240    /// args: 
241    /// * nc: [usize] - # cols in panel 
242    ///
243    /// returns: 
244    /// * [Iterator] - over ([Range] of column idxs used in panel, [MatRef] of panel itself)
245    pub fn col_panels(&self, nc: usize) -> impl DoubleEndedIterator<Item = (Range<usize>, MatRef<'_, T>)> { 
246        debug_assert!(nc > 0, "nc must be > 0"); 
247
248        let n_cols = self.n_cols(); 
249        (0..n_cols).step_by(nc).map(move |j0| { 
250            let j1 = usize::min(j0 + nc, n_cols); 
251
252            (Range {start: j0, end: j1}, self.col_panel(j0..j1))
253        })
254    }
255
256    /// checks whether matrix n_cols equals given length
257    pub fn has_len_equal_to_n_cols(&self, length: usize) -> bool { 
258        self.dimension.1 == length 
259    }
260
261    /// checks whether matrix n_rows equals given length 
262    pub fn has_len_equal_to_n_rows(&self, length: usize) -> bool { 
263        self.dimension.0 == length
264    }
265}
266
267impl<'a, T> MatMut<'a, T> { 
268    /// constructs [MatMut] with given slice and (n_rows, n_cols) dimension
269    ///
270    /// example: 
271    /// ``` 
272    /// use lak::types::MatMut; 
273    ///
274    /// // col-major 2 x 3 matrix: 
275    /// // [1 3 5] 
276    /// // [2 4 6] 
277    /// let mut a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; 
278    /// let a = MatMut::new(&mut a, (2, 3));
279    /// ```
280    pub fn new(buffer: &'a mut [T], dimension: (usize, usize)) -> Self { 
281        let i = dimension.0; 
282        let j = dimension.1; 
283        let buffer_length = buffer.len(); 
284        let matrix_length = i * j; 
285
286        assert_eq!(
287            matrix_length,
288            buffer_length,
289            "matrix has invalid dimensions: buffer length is {buffer_length}, \
290             but dimensions are {i} x {j} = {matrix_length}"
291        );
292
293        Self { buffer, dimension }
294    }
295
296    /// accesses full internal immutable slice 
297    pub fn as_slice(&self) -> &[T] { 
298        self.buffer 
299    }
300
301    /// accesses full internal slice as mutable 
302    pub fn as_slice_mut(&mut self) -> &mut [T] { 
303        self.buffer
304    }
305
306    /// accesses matrix dimension 
307    /// (n_rows, n_cols)
308    pub fn dimension(&self) -> (usize, usize) { 
309        self.dimension
310    }
311
312    /// accesses matrix number of rows 
313    pub fn n_rows(&self) -> usize { 
314        self.dimension.0
315    }
316
317    /// accesses matrix number of cols 
318    pub fn n_cols(&self) -> usize { 
319        self.dimension.1
320    }
321
322    /// return a [VecRef] for a given column in Self 
323    pub fn col(&self, j: usize) -> VecRef<'_, T> { 
324        let n_rows = self.n_rows(); 
325        let beg_idx = n_rows * j; 
326        let end_idx = n_rows * (j + 1); 
327
328        let slice = &self.buffer[beg_idx..end_idx]; 
329        VecRef::new(slice)
330    }
331
332    /// return a [VecMut] for a given column in Self 
333    pub fn col_mut(&mut self, j: usize) -> VecMut<'_, T> { 
334        let n_rows = self.n_rows(); 
335        let beg_idx = n_rows * j; 
336        let end_idx = n_rows * (j + 1); 
337
338        let slice = &mut self.buffer[beg_idx..end_idx]; 
339        VecMut::new(slice)
340    }
341
342    /// return a [MatRef] for a contiguous column panel of Self 
343    /// 
344    /// contains full columns over a given a range of column indices.
345    ///
346    /// example: 
347    /// ``` 
348    /// use lak::types::{MatRef, MatMut}; 
349    ///
350    /// // [1 3 5] 
351    /// // [2 4 6] 
352    /// let mut a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; 
353    /// let a = MatMut::new(&mut a, (2, 3)); 
354    ///
355    /// // MatRef of columns 1..3 
356    /// // [3 5] 
357    /// // [4 6] 
358    /// let panel = a.col_panel(1..3); 
359    /// ```
360    pub fn col_panel(&self, cols: Range<usize>) -> MatRef<'_, T> { 
361        debug_assert!(
362            cols.start < cols.end,
363            "start of col range must be < end of col range"
364        );
365        debug_assert!(
366            cols.end <= self.dimension.1, 
367            "end of col range must be <= number cols in Self"
368        );
369
370        let n_rows = self.n_rows(); 
371        let n_cols = cols.end - cols.start; 
372        let beg_idx = n_rows * cols.start; 
373        let end_idx = n_rows * cols.end; 
374
375        MatRef::new(
376            &self.buffer[beg_idx..end_idx], 
377            (n_rows, n_cols)
378        )     
379    }
380
381    /// returns a [MatMut] for a contiguous column panel of Self 
382    /// 
383    /// contains full columns over a given a range of column indices. 
384    ///
385    /// example: 
386    /// ``` 
387    /// use lak::types::MatMut; 
388    ///
389    /// // [1 3 5] 
390    /// // [2 4 6] 
391    /// let mut a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; 
392    /// let mut a = MatMut::new(&mut a, (2, 3)); 
393    ///
394    /// // MatMut of columns 1..3 
395    /// // [3 5] 
396    /// // [4 6] 
397    /// let panel = a.col_panel_mut(1..3); 
398    /// ```
399    pub fn col_panel_mut(&mut self, cols: Range<usize>) -> MatMut<'_, T> { 
400        debug_assert!(
401            cols.start < cols.end,
402            "start of col range must be < end of col range"
403        );
404        debug_assert!(
405            cols.end <= self.dimension.1, 
406            "end of col range must be <= number cols in Self"
407        );
408
409        let n_rows = self.n_rows(); 
410        let n_cols = cols.end - cols.start; 
411        let beg_idx = n_rows * cols.start; 
412        let end_idx = n_rows * cols.end; 
413
414        MatMut::new(
415            &mut self.buffer[beg_idx..end_idx], 
416            (n_rows, n_cols)
417        )     
418    }
419
420    /// return an [Iterator] over [MatRef]s chunks containing column 
421    /// panels that span Self. 
422    ///
423    /// each chunk holds nc columns, and the last item is the leftover 
424    /// column panel with n_cols < nc 
425    ///
426    /// args: 
427    /// * nc: [usize] - # cols in panel 
428    ///
429    /// returns: 
430    /// * [Iterator] - over ([Range] of column idxs used in panel, [MatRef] of panel itself)
431    pub fn col_panels(&self, nc: usize) -> impl DoubleEndedIterator<Item = (Range<usize>, MatRef<'_, T>)> { 
432        debug_assert!(nc > 0, "nc must be > 0");         
433        
434        let n_cols = self.n_cols();
435        (0..n_cols).step_by(nc).map(move |j0| { 
436            let j1 = usize::min(j0 + nc, n_cols); 
437
438            (Range {start: j0, end: j1}, self.col_panel(j0..j1))
439        })
440    }
441
442    /// checks whether matrix n_cols equals given length
443    pub fn has_len_equal_to_n_cols(&self, length: usize) -> bool { 
444        self.dimension.1 == length 
445    }
446
447    /// checks whether matrix n_rows equals given length 
448    pub fn has_len_equal_to_n_rows(&self, length: usize) -> bool { 
449        self.dimension.0 == length
450    }
451
452    /// used for calling routines over and over again 
453    /// on the stored internal mutable slice 
454    ///
455    /// borrows self mutably
456    ///
457    /// example:
458    /// ```
459    /// use lak::l2::ger;
460    /// use lak::types::{MatMut, VecRef};
461    ///
462    /// let x = [1.0, 2.0];
463    /// let y = [3.0, 4.0];
464    /// let mut a = [0.0; 4];
465    ///
466    /// let x = VecRef::new(&x);
467    /// let y = VecRef::new(&y);
468    /// let mut a = MatMut::new(&mut a, (2, 2));
469    ///
470    /// ger(1.0, a.reborrow(), x, y);
471    /// ger(1.0, a.reborrow(), x, y);
472    /// ```
473    pub fn reborrow(&mut self) -> MatMut<'_, T> { 
474        let (n_rows, n_cols) = self.dimension();
475        MatMut::new(self.as_slice_mut(), (n_rows, n_cols))
476    }
477}
478
479
/// asserts two [VecRef]/[VecMut] have equal length buffers
///
/// usage: `assert_length_eq!(x, y);` (a trailing comma is accepted, as with
/// the std `assert_eq!` family)
///
/// # Panics
/// with "number of elements must be equal" when `x.has_equal_length(y.length())`
/// is false
#[macro_export]
macro_rules! assert_length_eq {
    ($x:expr, $y:expr $(,)?) => {
        assert!(
            $x.has_equal_length($y.length()),
            "number of elements must be equal"
        );
    };
}
490
491
/// asserts the length of a [VecRef]/[VecMut] buffer
/// equals the number of cols in a [MatRef]/[MatMut]
///
/// usage: `assert_length_eq_n_cols!(a, x);` (a trailing comma is accepted)
///
/// # Panics
/// with "number of cols in a does not match length of x" when
/// `a.has_len_equal_to_n_cols(x.length())` is false
#[macro_export]
macro_rules! assert_length_eq_n_cols {
    ($a:expr, $x:expr $(,)?) => {
        assert!(
            $a.has_len_equal_to_n_cols($x.length()),
            "number of cols in a does not match length of x"
        );
    };
}
505
/// asserts the length of a [VecRef]/[VecMut] buffer
/// equals the number of rows in a [MatRef]/[MatMut]
///
/// usage: `assert_length_eq_n_rows!(a, x);` (a trailing comma is accepted)
///
/// # Panics
/// with "number of rows in a does not match length of x" when
/// `a.has_len_equal_to_n_rows(x.length())` is false
#[macro_export]
macro_rules! assert_length_eq_n_rows {
    ($a:expr, $x:expr $(,)?) => {
        assert!(
            $a.has_len_equal_to_n_rows($x.length()),
            "number of rows in a does not match length of x"
        );
    };
}
519