lak/types.rs
1// types.rs
2
3use std::ops::Range;
4
/// Selects between the no-transpose and transpose form of an operation.
///
/// * [Transpose::NoTranspose] for no-transpose ops
/// * [Transpose::Transpose] for transpose ops
///
/// Besides `match`, values can be compared with `==` and used as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Transpose {
    /// The no-transpose form of the operation.
    NoTranspose,
    /// The transpose form of the operation.
    Transpose,
}
13
/// Selects between the upper- and lower-triangular form of an operation.
///
/// * [Triangular::Upper] for upper-triangular ops
/// * [Triangular::Lower] for lower-triangular ops
///
/// Besides `match`, values can be compared with `==` and used as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Triangular {
    /// The upper-triangular form of the operation.
    Upper,
    /// The lower-triangular form of the operation.
    Lower,
}
22
/// Immutable, non-owning view of a contiguous vector of `T`.
///
/// `Clone`/`Copy` are implemented by hand rather than derived: `#[derive]`
/// would add `T: Clone`/`T: Copy` bounds, yet copying the view only copies
/// the `&[T]` reference, which is `Copy` for every `T`.
#[derive(Debug)]
pub struct VecRef<'a, T> {
    /// Borrowed element storage.
    buffer: &'a [T],
}

impl<'a, T> Clone for VecRef<'a, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, T> Copy for VecRef<'a, T> {}
28
/// Mutable, non-owning view of a contiguous vector of `T`.
///
/// The view holds an exclusive borrow of its buffer, so it is neither
/// `Clone` nor `Copy`; use [VecMut::reborrow] to hand it to several
/// routines in turn.
#[derive(Debug)]
pub struct VecMut<'a, T> {
    /// Exclusively borrowed element storage.
    buffer: &'a mut [T],
}
34
/// Immutable, non-owning view of a matrix of `T` stored in column-major
/// order: each column is a contiguous run of `n_rows` elements.
///
/// `Clone`/`Copy` are implemented by hand rather than derived: `#[derive]`
/// would add `T: Clone`/`T: Copy` bounds, yet copying the view only copies a
/// slice reference and a dimension pair, which are `Copy` for every `T`.
#[derive(Debug)]
pub struct MatRef<'a, T> {
    /// Borrowed element storage, column after column.
    buffer: &'a [T],
    /// `(n_rows, n_cols)`.
    dimension: (usize, usize),
}

impl<'a, T> Clone for MatRef<'a, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, T> Copy for MatRef<'a, T> {}
42
/// Mutable, non-owning view of a matrix of `T` stored in column-major order.
///
/// Element `(r, c)` lives at `buffer[c * n_rows + r]`, so each column is a
/// contiguous run of `n_rows` elements.
#[derive(Debug)]
pub struct MatMut<'a, T> {
    /// Exclusively borrowed element storage, column after column.
    buffer: &'a mut [T],
    /// `(n_rows, n_cols)`.
    dimension: (usize, usize),
}
50
51impl<'a, T> VecRef<'a, T> {
52 /// constructs [VecRef] with given slice
53 pub fn new(buffer: &'a [T]) -> Self {
54 Self { buffer }
55 }
56
57 /// returns length of internal slice
58 pub fn length(&self) -> usize {
59 self.buffer.len()
60 }
61
62 /// accesses full internal immutable slice
63 pub fn as_slice(&self) -> &[T] {
64 self.buffer
65 }
66
67 /// accesses internal immutable slice over a given index range
68 pub fn slice(&self, range: Range<usize>) -> &[T] {
69 &self.buffer[range.start..range.end]
70 }
71
72 /// checks whether internal length is equal to given length parameter
73 pub fn has_equal_length(&self, length: usize) -> bool {
74 self.buffer.len() == length
75 }
76}
77
78impl<'a, T> VecMut<'a, T> {
79 /// constructs [VecMut] with given slice
80 pub fn new(buffer: &'a mut [T]) -> Self {
81 Self { buffer }
82 }
83
84 /// returns length of internal slice
85 pub fn length(&self) -> usize {
86 self.buffer.len()
87 }
88
89 /// accesses full internal slice as immutable
90 pub fn as_slice(&self) -> &[T] {
91 self.buffer
92 }
93
94 /// accesses internal immutable slice over a given index range
95 pub fn slice(&self, range: Range<usize>) -> &[T] {
96 &self.buffer[range.start..range.end]
97 }
98
99 /// accesses internal mutable slice over a given index range
100 pub fn slice_mut(&mut self, range: Range<usize>) -> &mut [T] {
101 &mut self.buffer[range.start..range.end]
102 }
103
104 /// accesses full internal slice as mutable
105 pub fn as_slice_mut(&mut self) -> &mut [T] {
106 self.buffer
107 }
108
109 /// checks whether internal length is equal to given length parameter
110 pub fn has_equal_length(&self, length: usize) -> bool {
111 self.buffer.len() == length
112 }
113
114 /// used for calling routines over and over again
115 /// on the stored internal mutable slice
116 ///
117 /// borrows self mutably
118 ///
119 /// example:
120 /// ```
121 /// use lak::l1::scal;
122 /// use lak::types::VecMut;
123 ///
124 /// let mut x = [1.0, 2.0, 3.0];
125 /// let mut x = VecMut::new(&mut x);
126 ///
127 /// scal(2.0, x.reborrow());
128 /// scal(3.0, x.reborrow());
129 /// ```
130 pub fn reborrow(&mut self) -> VecMut<'_, T> {
131 VecMut::new(self.as_slice_mut())
132 }
133}
134
135
136impl<'a, T> MatRef<'a, T> {
137 /// constructs [MatRef] with given slice and (n_rows, n_cols) dimension
138 ///
139 /// example:
140 ///
141 /// ```
142 /// use lak::types::MatRef;
143 ///
144 /// // col-major 2 x 3 matrix:
145 /// // [1 3 5]
146 /// // [2 4 6]
147 /// let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
148 /// let a = MatRef::new(&a, (2, 3));
149 /// ```
150 pub fn new(buffer: &'a [T], dimension: (usize, usize)) -> Self {
151 let i = dimension.0;
152 let j = dimension.1;
153 let buffer_length = buffer.len();
154 let matrix_length = i * j;
155
156 assert_eq!(
157 matrix_length,
158 buffer_length,
159 "matrix has invalid dimensions: buffer length is {buffer_length}, \
160 but dimensions are {i} x {j} = {matrix_length}",
161 );
162
163 Self { buffer, dimension }
164 }
165
166 /// accesses internal immutable slice
167 pub fn as_slice(&self) -> &[T] {
168 self.buffer
169 }
170
171 /// accesses matrix dimension
172 /// (n_rows, n_cols)
173 pub fn dimension(&self) -> (usize, usize) {
174 self.dimension
175 }
176
177 /// accesses matrix number of rows
178 pub fn n_rows(&self) -> usize {
179 self.dimension.0
180 }
181
182 /// accesses matrix number of cols
183 pub fn n_cols(&self) -> usize {
184 self.dimension.1
185 }
186
187 /// return a [VecRef] for a given column in Self
188 pub fn col(&self, j: usize) -> VecRef<'_, T> {
189 let n_rows = self.n_rows();
190 let beg_idx = n_rows * j;
191 let end_idx = n_rows * (j + 1);
192
193 let slice = &self.buffer[beg_idx..end_idx];
194 VecRef::new(slice)
195 }
196
197 /// return a [MatRef] for a contiguous column panel of Self
198 ///
199 /// example:
200 /// ```
201 /// use lak::types::MatRef;
202 ///
203 /// // [1 3 5]
204 /// // [2 4 6]
205 /// let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
206 /// let a = MatRef::new(&a, (2, 3));
207 ///
208 /// // MatRef of columns 1..3
209 /// // [3 5]
210 /// // [4 6]
211 /// let panel = a.col_panel(1..3);
212 /// ```
213 pub fn col_panel(&self, cols: Range<usize>) -> MatRef<'_, T> {
214 debug_assert!(
215 cols.start < cols.end,
216 "start of col range must be < end of col range"
217 );
218 debug_assert!(
219 cols.end <= self.dimension.1,
220 "end of col range must be <= number cols in Self"
221 );
222
223 let n_rows = self.n_rows();
224 let n_cols = cols.end - cols.start;
225 let beg_idx = n_rows * cols.start;
226 let end_idx = n_rows * cols.end;
227
228 MatRef::new(
229 &self.buffer[beg_idx..end_idx],
230 (n_rows, n_cols)
231 )
232 }
233
234 /// returns an [Iterator] over [MatRef]s containing column panels that
235 /// span Self.
236 ///
237 /// each panel holds nc columns, and the last item is the leftover
238 /// panel with column width < nc
239 ///
240 /// args:
241 /// * nc: [usize] - # cols in panel
242 ///
243 /// returns:
244 /// * [Iterator] - over ([Range] of column idxs used in panel, [MatRef] of panel itself)
245 pub fn col_panels(&self, nc: usize) -> impl DoubleEndedIterator<Item = (Range<usize>, MatRef<'_, T>)> {
246 debug_assert!(nc > 0, "nc must be > 0");
247
248 let n_cols = self.n_cols();
249 (0..n_cols).step_by(nc).map(move |j0| {
250 let j1 = usize::min(j0 + nc, n_cols);
251
252 (Range {start: j0, end: j1}, self.col_panel(j0..j1))
253 })
254 }
255
256 /// checks whether matrix n_cols equals given length
257 pub fn has_len_equal_to_n_cols(&self, length: usize) -> bool {
258 self.dimension.1 == length
259 }
260
261 /// checks whether matrix n_rows equals given length
262 pub fn has_len_equal_to_n_rows(&self, length: usize) -> bool {
263 self.dimension.0 == length
264 }
265}
266
267impl<'a, T> MatMut<'a, T> {
268 /// constructs [MatMut] with given slice and (n_rows, n_cols) dimension
269 ///
270 /// example:
271 /// ```
272 /// use lak::types::MatMut;
273 ///
274 /// // col-major 2 x 3 matrix:
275 /// // [1 3 5]
276 /// // [2 4 6]
277 /// let mut a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
278 /// let a = MatMut::new(&mut a, (2, 3));
279 /// ```
280 pub fn new(buffer: &'a mut [T], dimension: (usize, usize)) -> Self {
281 let i = dimension.0;
282 let j = dimension.1;
283 let buffer_length = buffer.len();
284 let matrix_length = i * j;
285
286 assert_eq!(
287 matrix_length,
288 buffer_length,
289 "matrix has invalid dimensions: buffer length is {buffer_length}, \
290 but dimensions are {i} x {j} = {matrix_length}"
291 );
292
293 Self { buffer, dimension }
294 }
295
296 /// accesses full internal immutable slice
297 pub fn as_slice(&self) -> &[T] {
298 self.buffer
299 }
300
301 /// accesses full internal slice as mutable
302 pub fn as_slice_mut(&mut self) -> &mut [T] {
303 self.buffer
304 }
305
306 /// accesses matrix dimension
307 /// (n_rows, n_cols)
308 pub fn dimension(&self) -> (usize, usize) {
309 self.dimension
310 }
311
312 /// accesses matrix number of rows
313 pub fn n_rows(&self) -> usize {
314 self.dimension.0
315 }
316
317 /// accesses matrix number of cols
318 pub fn n_cols(&self) -> usize {
319 self.dimension.1
320 }
321
322 /// return a [VecRef] for a given column in Self
323 pub fn col(&self, j: usize) -> VecRef<'_, T> {
324 let n_rows = self.n_rows();
325 let beg_idx = n_rows * j;
326 let end_idx = n_rows * (j + 1);
327
328 let slice = &self.buffer[beg_idx..end_idx];
329 VecRef::new(slice)
330 }
331
332 /// return a [VecMut] for a given column in Self
333 pub fn col_mut(&mut self, j: usize) -> VecMut<'_, T> {
334 let n_rows = self.n_rows();
335 let beg_idx = n_rows * j;
336 let end_idx = n_rows * (j + 1);
337
338 let slice = &mut self.buffer[beg_idx..end_idx];
339 VecMut::new(slice)
340 }
341
342 /// return a [MatRef] for a contiguous column panel of Self
343 ///
344 /// contains full columns over a given a range of column indices.
345 ///
346 /// example:
347 /// ```
348 /// use lak::types::{MatRef, MatMut};
349 ///
350 /// // [1 3 5]
351 /// // [2 4 6]
352 /// let mut a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
353 /// let a = MatMut::new(&mut a, (2, 3));
354 ///
355 /// // MatRef of columns 1..3
356 /// // [3 5]
357 /// // [4 6]
358 /// let panel = a.col_panel(1..3);
359 /// ```
360 pub fn col_panel(&self, cols: Range<usize>) -> MatRef<'_, T> {
361 debug_assert!(
362 cols.start < cols.end,
363 "start of col range must be < end of col range"
364 );
365 debug_assert!(
366 cols.end <= self.dimension.1,
367 "end of col range must be <= number cols in Self"
368 );
369
370 let n_rows = self.n_rows();
371 let n_cols = cols.end - cols.start;
372 let beg_idx = n_rows * cols.start;
373 let end_idx = n_rows * cols.end;
374
375 MatRef::new(
376 &self.buffer[beg_idx..end_idx],
377 (n_rows, n_cols)
378 )
379 }
380
381 /// returns a [MatMut] for a contiguous column panel of Self
382 ///
383 /// contains full columns over a given a range of column indices.
384 ///
385 /// example:
386 /// ```
387 /// use lak::types::MatMut;
388 ///
389 /// // [1 3 5]
390 /// // [2 4 6]
391 /// let mut a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
392 /// let mut a = MatMut::new(&mut a, (2, 3));
393 ///
394 /// // MatMut of columns 1..3
395 /// // [3 5]
396 /// // [4 6]
397 /// let panel = a.col_panel_mut(1..3);
398 /// ```
399 pub fn col_panel_mut(&mut self, cols: Range<usize>) -> MatMut<'_, T> {
400 debug_assert!(
401 cols.start < cols.end,
402 "start of col range must be < end of col range"
403 );
404 debug_assert!(
405 cols.end <= self.dimension.1,
406 "end of col range must be <= number cols in Self"
407 );
408
409 let n_rows = self.n_rows();
410 let n_cols = cols.end - cols.start;
411 let beg_idx = n_rows * cols.start;
412 let end_idx = n_rows * cols.end;
413
414 MatMut::new(
415 &mut self.buffer[beg_idx..end_idx],
416 (n_rows, n_cols)
417 )
418 }
419
420 /// return an [Iterator] over [MatRef]s chunks containing column
421 /// panels that span Self.
422 ///
423 /// each chunk holds nc columns, and the last item is the leftover
424 /// column panel with n_cols < nc
425 ///
426 /// args:
427 /// * nc: [usize] - # cols in panel
428 ///
429 /// returns:
430 /// * [Iterator] - over ([Range] of column idxs used in panel, [MatRef] of panel itself)
431 pub fn col_panels(&self, nc: usize) -> impl DoubleEndedIterator<Item = (Range<usize>, MatRef<'_, T>)> {
432 debug_assert!(nc > 0, "nc must be > 0");
433
434 let n_cols = self.n_cols();
435 (0..n_cols).step_by(nc).map(move |j0| {
436 let j1 = usize::min(j0 + nc, n_cols);
437
438 (Range {start: j0, end: j1}, self.col_panel(j0..j1))
439 })
440 }
441
442 /// checks whether matrix n_cols equals given length
443 pub fn has_len_equal_to_n_cols(&self, length: usize) -> bool {
444 self.dimension.1 == length
445 }
446
447 /// checks whether matrix n_rows equals given length
448 pub fn has_len_equal_to_n_rows(&self, length: usize) -> bool {
449 self.dimension.0 == length
450 }
451
452 /// used for calling routines over and over again
453 /// on the stored internal mutable slice
454 ///
455 /// borrows self mutably
456 ///
457 /// example:
458 /// ```
459 /// use lak::l2::ger;
460 /// use lak::types::{MatMut, VecRef};
461 ///
462 /// let x = [1.0, 2.0];
463 /// let y = [3.0, 4.0];
464 /// let mut a = [0.0; 4];
465 ///
466 /// let x = VecRef::new(&x);
467 /// let y = VecRef::new(&y);
468 /// let mut a = MatMut::new(&mut a, (2, 2));
469 ///
470 /// ger(1.0, a.reborrow(), x, y);
471 /// ger(1.0, a.reborrow(), x, y);
472 /// ```
473 pub fn reborrow(&mut self) -> MatMut<'_, T> {
474 let (n_rows, n_cols) = self.dimension();
475 MatMut::new(self.as_slice_mut(), (n_rows, n_cols))
476 }
477}
478
479
/// Panics unless two vector views ([VecRef]/[VecMut]) hold the same number
/// of elements.
///
/// The first argument must provide `has_equal_length(usize) -> bool` and the
/// second must provide `length() -> usize`.
///
/// example:
/// ```
/// use lak::assert_length_eq;
/// use lak::types::VecRef;
///
/// let x = [1.0, 2.0];
/// let y = [3.0, 4.0];
/// assert_length_eq!(VecRef::new(&x), VecRef::new(&y));
/// ```
#[macro_export]
macro_rules! assert_length_eq {
    ($x:expr, $y:expr) => {
        if !$x.has_equal_length($y.length()) {
            panic!("number of elements must be equal");
        }
    };
}
490
491
/// Asserts that the length of a [VecRef]/[VecMut] buffer equals the number
/// of columns of a [MatRef]/[MatMut].
///
/// Invoked as a macro with the matrix first and the vector second, as in
/// `assert_length_eq_n_cols!(a, x);`. The matrix argument must provide
/// `has_len_equal_to_n_cols(usize) -> bool` and the vector argument must
/// provide `length() -> usize`.
///
/// example:
/// ```
/// use lak::assert_length_eq_n_cols;
/// use lak::types::{MatRef, VecRef};
///
/// let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let a = MatRef::new(&a, (2, 3));
/// let x = [1.0, 2.0, 3.0];
/// let x = VecRef::new(&x);
///
/// // 3 columns == 3 elements
/// assert_length_eq_n_cols!(a, x);
/// ```
#[macro_export]
macro_rules! assert_length_eq_n_cols {
    ($a:expr, $x:expr) => {
        assert!(
            $a.has_len_equal_to_n_cols($x.length()),
            "number of cols in a does not match length of x"
        );
    };
}
505
/// Asserts that the length of a [VecRef]/[VecMut] buffer equals the number
/// of rows of a [MatRef]/[MatMut].
///
/// Invoked as a macro with the matrix first and the vector second, as in
/// `assert_length_eq_n_rows!(a, x);`. The matrix argument must provide
/// `has_len_equal_to_n_rows(usize) -> bool` and the vector argument must
/// provide `length() -> usize`.
///
/// example:
/// ```
/// use lak::assert_length_eq_n_rows;
/// use lak::types::{MatRef, VecRef};
///
/// let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let a = MatRef::new(&a, (2, 3));
/// let x = [1.0, 2.0];
/// let x = VecRef::new(&x);
///
/// // 2 rows == 2 elements
/// assert_length_eq_n_rows!(a, x);
/// ```
#[macro_export]
macro_rules! assert_length_eq_n_rows {
    ($a:expr, $x:expr) => {
        assert!(
            $a.has_len_equal_to_n_rows($x.length()),
            "number of rows in a does not match length of x"
        );
    };
}
519