// feanor_math/matrix/owned.rs

1use std::{alloc::{Allocator, Global}, fmt::{Debug, Formatter, Result}};
2
3use self::submatrix::{AsFirstElement, Submatrix, SubmatrixMut};
4
5use super::*;
6
7///
8/// A matrix that owns its elements.
9/// 
10/// To pass it to algorithms, use the `.data()` and `.data_mut()` functions.
11/// 
12/// # Example
13/// ```rust
14/// #![feature(allocator_api)]
15/// # use std::alloc::*;
16/// # use feanor_math::ring::*;
17/// # use feanor_math::primitive_int::*;
18/// # use feanor_math::matrix::*;
19/// # use feanor_math::algorithms::linsolve::*;
20/// let mut A = OwnedMatrix::identity(2, 2, StaticRing::<i32>::RING);
21/// let mut B = OwnedMatrix::identity(2, 2, StaticRing::<i32>::RING);
22/// let mut C = OwnedMatrix::identity(2, 2, StaticRing::<i32>::RING);
23/// StaticRing::<i32>::RING.get_ring().solve_right(A.data_mut(), B.data_mut(), C.data_mut(), Global).assert_solved();
24/// ```
25/// 
26pub struct OwnedMatrix<T, A: Allocator = Global> {
27    data: Vec<T, A>,
28    col_count: usize,
29    row_count: usize
30}
31
32impl<T> OwnedMatrix<T> {
33
34    ///
35    /// Creates the `row_count x col_count` [`OwnedMatrix`] whose `(i, j)`-th entry
36    /// is the output of the given function on `(i, j)`.
37    /// 
38    pub fn from_fn<F>(row_count: usize, col_count: usize, f: F) -> Self
39        where F: FnMut(usize, usize) -> T
40    {
41        Self::from_fn_in(row_count, col_count, f, Global)
42    }
43    
44    ///
45    /// Creates the `row_count x col_count` zero matrix over the given ring.
46    /// 
47    pub fn zero<R: RingStore>(row_count: usize, col_count: usize, ring: R) -> Self
48        where R::Type: RingBase<Element = T>
49    {
50        Self::zero_in(row_count, col_count, ring, Global)
51    }
52
53    ///
54    /// Creates the `row_count x col_count` identity matrix over the given ring.
55    /// 
56    pub fn identity<R: RingStore>(row_count: usize, col_count: usize, ring: R) -> Self
57        where R::Type: RingBase<Element = T>
58    {
59        Self::identity_in(row_count, col_count, ring, Global)
60    }
61}
62
63impl<T, A: Allocator> OwnedMatrix<T, A> {
64
65    ///
66    /// Creates the `row_count x col_count` [`OwnedMatrix`] matrix, whose entries are
67    /// taken from the given vector, interpreted as a row-major matrix. The number of
68    /// rows is `row_count = data.len() / col_count`.
69    /// 
70    /// If `col_count` is zero, this will panic. If that can happen, consider
71    /// using [`OwnedMatrix::new_with_shape()`].
72    /// 
73    pub fn new(data: Vec<T, A>, col_count: usize) -> Self {
74        let row_count = data.len() / col_count;
75        Self::new_with_shape(data, row_count, col_count)
76    }
77
78    ///
79    /// Creates the `row_count x col_count` [`OwnedMatrix`] matrix, whose entries are
80    /// taken from the given vector, interpreted as a row-major matrix.
81    /// 
82    /// # Example
83    /// ```
84    /// # use feanor_math::matrix::*;
85    /// let matrix = OwnedMatrix::new_with_shape(vec![1, 2, 3, 4, 5, 6], 3, 2);
86    /// assert_eq!(3, *matrix.at(1, 0));
87    /// assert_eq!(6, *matrix.at(2, 1));
88    /// ```
89    /// 
90    pub fn new_with_shape(data: Vec<T, A>, row_count: usize, col_count: usize) -> Self {
91        assert_eq!(row_count * col_count, data.len());
92        Self { data, col_count, row_count }
93    }
94
95    ///
96    /// Creates the `row_count x col_count` [`OwnedMatrix`] whose `(i, j)`-th entry
97    /// is the output of the given function on `(i, j)`.
98    /// 
99    #[stability::unstable(feature = "enable")]
100    pub fn from_fn_in<F>(row_count: usize, col_count: usize, mut f: F, allocator: A) -> Self
101        where F: FnMut(usize, usize) -> T
102    {
103        let mut data = Vec::with_capacity_in(row_count * col_count, allocator);
104        for i in 0..row_count {
105            for j in 0..col_count {
106                data.push(f(i, j));
107            }
108        }
109        return Self::new_with_shape(data, row_count, col_count);
110    }
111
112    ///
113    /// Returns a [`Submatrix`] view on the data of this matrix.
114    /// 
115    pub fn data<'a>(&'a self) -> Submatrix<'a, AsFirstElement<T>, T> {
116        Submatrix::<AsFirstElement<_>, _>::from_1d(&self.data, self.row_count(), self.col_count())
117    }
118
119    ///
120    /// Returns a [`SubmatrixMut`] view on the data of this matrix.
121    /// 
122    pub fn data_mut<'a>(&'a mut self) -> SubmatrixMut<'a, AsFirstElement<T>, T> {
123        let row_count = self.row_count();
124        let col_count = self.col_count();
125        SubmatrixMut::<AsFirstElement<_>, _>::from_1d(&mut self.data, row_count, col_count)
126    }
127
128    ///
129    /// Returns a reference to the `(i, j)`-th entry of this matrix.
130    /// 
131    pub fn at(&self, i: usize, j: usize) -> &T {
132        &self.data[i * self.col_count + j]
133    }
134
135    ///
136    /// Returns a mutable reference to the `(i, j)`-th entry of this matrix.
137    /// 
138    pub fn at_mut(&mut self, i: usize, j: usize) -> &mut T {
139        &mut self.data[i * self.col_count + j]
140    }
141
142    ///
143    /// Returns the number of rows of this matrix.
144    /// 
145    pub fn row_count(&self) -> usize {
146        self.row_count
147    }
148    
149    ////
150    /// Returns the number of columns of this matrix.
151    /// 
152    pub fn col_count(&self) -> usize {
153        self.col_count
154    }
155
156    ///
157    /// Creates the `row_count x col_count` zero matrix over the given ring.
158    /// 
159    #[stability::unstable(feature = "enable")]
160    pub fn zero_in<R: RingStore>(row_count: usize, col_count: usize, ring: R, allocator: A) -> Self
161        where R::Type: RingBase<Element = T>
162    {
163        let mut result = Vec::with_capacity_in(row_count * col_count, allocator);
164        for _ in 0..row_count {
165            for _ in 0..col_count {
166                result.push(ring.zero());
167            }
168        }
169        return Self::new_with_shape(result, row_count, col_count);
170    }
171
172    ///
173    /// Creates the `row_count x col_count` identity matrix over the given ring.
174    /// 
175    #[stability::unstable(feature = "enable")]
176    pub fn identity_in<R: RingStore>(row_count: usize, col_count: usize, ring: R, allocator: A) -> Self
177        where R::Type: RingBase<Element = T>
178    {
179        let mut result = Vec::with_capacity_in(row_count * col_count, allocator);
180        for i in 0..row_count {
181            for j in 0..col_count {
182                if i != j {
183                    result.push(ring.zero());
184                } else {
185                    result.push(ring.one());
186                }
187            }
188        }
189        return Self::new_with_shape(result, row_count, col_count);
190    }
191
192    #[stability::unstable(feature = "enable")]
193    pub fn clone_matrix<R: RingStore>(&self, ring: R) -> Self
194        where R::Type: RingBase<Element = T>,
195            A: Clone
196    {
197        let mut result = Vec::with_capacity_in(self.row_count() * self.col_count(), self.data.allocator().clone());
198        for i in 0..self.row_count() {
199            for j in 0..self.col_count() {
200                result.push(ring.clone_el(self.at(i, j)));
201            }
202        }
203        return Self::new_with_shape(result, self.row_count(), self.col_count());
204    }
205
206    #[stability::unstable(feature = "enable")]
207    pub fn set_row_count<F>(&mut self, new_count: usize, new_entries: F)
208        where F: FnMut() -> T
209    {
210        self.data.resize_with(new_count * self.col_count(), new_entries);
211    }
212}
213
214impl<T: Debug, A: Allocator> Debug for OwnedMatrix<T, A> {
215
216    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
217        self.data().fmt(f)
218    }
219}
220
221#[cfg(test)]
222use crate::primitive_int::*;
223
#[test]
fn test_zero_col_matrix() {
    // A matrix without columns must still remember its number of rows, whether it
    // is built from explicit (empty) data or through `zero()`.
    let from_data: OwnedMatrix<i64> = OwnedMatrix::new_with_shape(Vec::new(), 10, 0);
    assert_eq!((10, 0), (from_data.row_count(), from_data.col_count()));

    let zero_matrix: OwnedMatrix<i64> = OwnedMatrix::zero(11, 0, StaticRing::<i64>::RING);
    assert_eq!((11, 0), (zero_matrix.row_count(), zero_matrix.col_count()));
}