easy_ml/matrices/views/
partitions.rs

1#[allow(unused)] // used in doc links
2use crate::matrices::Matrix;
3use crate::matrices::views::{DataLayout, MatrixMut, MatrixRef, MatrixView, NoInteriorMutability};
4use crate::matrices::{Column, Row};
5
6/**
7 * A mutably borrowed part of a matrix.
8 *
 * Rust's borrow checker does not permit overlapping exclusive references, so you cannot
10 * simply construct multiple views into a [Matrix](Matrix) by creating each one sequentially as
11 * you can for immutable/shared references to a Matrix.
12 *
13 * ```
14 * use easy_ml::matrices::Matrix;
15 * use easy_ml::matrices::views::MatrixRange;
16 * let matrix = Matrix::row(vec![1, 2, 3]);
17 * let one = MatrixRange::from(&matrix, 0..1, 0..1);
18 * let two = MatrixRange::from(&matrix, 0..1, 1..2);
19 * let three = MatrixRange::from(&matrix, 0..1, 2..3);
20 * let four = MatrixRange::from(&matrix, 0..1, 0..3);
21 * ```
22 *
23 * MatrixPart instead holds only a mutable reference to a slice into a Matrix's buffer. It does
24 * not borrow the entire Matrix, and thus is used as the container for Matrix APIs which partition
25 * a Matrix into multiple non overlapping parts. The Matrix can then be independently mutated
26 * by each of the MatrixParts.
27 *
28 * See [`Matrix::partition_quadrants`](Matrix::partition_quadrants)
29 */
30#[derive(Debug)]
31pub struct MatrixPart<'source, T> {
32    data: Vec<&'source mut [T]>,
33    rows: Row,
34    columns: Column,
35}
36
37impl<'a, T> MatrixPart<'a, T> {
38    pub(crate) fn new(data: Vec<&'a mut [T]>, rows: Row, columns: Column) -> MatrixPart<'a, T> {
39        MatrixPart {
40            data,
41            rows,
42            columns,
43        }
44    }
45}
46
47// # Safety
48//
49// We don't implement interior mutability and we can't be resized anyway since our
50// buffer is not owned.
51/**
52 * A MatrixPart implements MatrixRef.
53 */
54unsafe impl<'a, T> MatrixRef<T> for MatrixPart<'a, T> {
55    fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
56        if row >= self.rows || column >= self.columns {
57            return None;
58        }
59        Some(&self.data[row][column])
60    }
61
62    fn view_rows(&self) -> Row {
63        self.rows
64    }
65
66    fn view_columns(&self) -> Column {
67        self.columns
68    }
69
70    unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
71        unsafe { self.data.get_unchecked(row).get_unchecked(column) }
72    }
73
74    fn data_layout(&self) -> DataLayout {
75        DataLayout::RowMajor
76    }
77}
78
79// # Safety
80//
81// We don't implement interior mutability and we can't be resized anyway since our
82// buffer is not owned.
83/**
84 * A MatrixPart implements MatrixMut.
85 */
86unsafe impl<'a, T> MatrixMut<T> for MatrixPart<'a, T> {
87    fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
88        if row >= self.rows || column >= self.columns {
89            return None;
90        }
91        Some(&mut self.data[row][column])
92    }
93
94    unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
95        unsafe { self.data.get_unchecked_mut(row).get_unchecked_mut(column) }
96    }
97}
98
99// # Safety
100//
101// We don't implement interior mutability and we can't be resized anyway since our
102// buffer is not owned.
103/**
104 * A MatrixPart implements NoInteriorMutability.
105 */
106unsafe impl<'a, T> NoInteriorMutability for MatrixPart<'a, T> {}
107
108/**
109 * Four [parts](MatrixPart) of a Matrix which can be mutated individually.
110 *
111 * See [`Matrix::partition_quadrants`](crate::matrices::Matrix::partition_quadrants).
112 */
113#[derive(Debug)]
114pub struct MatrixQuadrants<'source, T> {
115    pub top_left: MatrixView<T, MatrixPart<'source, T>>,
116    pub top_right: MatrixView<T, MatrixPart<'source, T>>,
117    pub bottom_left: MatrixView<T, MatrixPart<'source, T>>,
118    pub bottom_right: MatrixView<T, MatrixPart<'source, T>>,
119}
120
121impl<'a, T> std::fmt::Display for MatrixQuadrants<'a, T>
122where
123    T: std::fmt::Display,
124{
125    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
126        write!(
127            f,
128            "Top Left:\n{}\nTop Right:\n{}\nBottom Left:\n{}\nBottom Right:\n{}\n",
129            self.top_left, self.top_right, self.bottom_left, self.bottom_right
130        )
131    }
132}