ring_math/
matrix2d.rs

1use scalarff::FieldElement;
2
3use super::vector::Vector;
4
5/// A two dimensional matrix implementation
6#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
7#[derive(Clone, PartialEq)]
8pub struct Matrix2D<T: FieldElement> {
9    pub dimensions: (usize, usize), // (rows, cols)
10    pub values: Vec<T>,
11}
12
13impl<T: FieldElement> Matrix2D<T> {
14    pub const JL_PROJECTION_SIZE: usize = 256;
15
16    /// Create a new 2 dimensional matrix of specified
17    /// rows and columns
18    pub fn new(rows: usize, columns: usize) -> Self {
19        Self {
20            dimensions: (rows, columns),
21            values: vec![T::zero(); rows * columns],
22        }
23    }
24
25    /// Return an identity matrix of size `n`
26    pub fn identity(n: usize) -> Self {
27        let mut values: Vec<T> = Vec::new();
28        for x in 0..n {
29            let mut row = vec![T::zero(); n];
30            row[x] = T::one();
31            values.append(&mut row);
32        }
33        Matrix2D {
34            dimensions: (n, n),
35            values,
36        }
37    }
38
39    /// Return a zero matrix of the specified dimensions
40    pub fn zero(rows: usize, cols: usize) -> Self {
41        Matrix2D {
42            dimensions: (rows, cols),
43            values: vec![T::zero(); rows * cols],
44        }
45    }
46
47    /// Retrieve a column by index. Panics if the index is greater than
48    /// or equal to the number of columns.
49    pub fn column(&self, index: usize) -> Vector<T> {
50        if index >= self.dimensions.1 {
51            panic!("attempt to retrieve column outside of matrix dimensions. Requested column {index}, number of columns {}", self.dimensions.1);
52        }
53        let mut out = Vec::new();
54        let (m_rows, m_cols) = self.dimensions;
55        for i in 0..m_rows {
56            let column_element = &self.values[i * m_cols + index];
57            out.push(column_element.clone());
58        }
59        Vector::from_vec(out)
60    }
61
62    /// Retrieve a row by index. Panics if the index is greater than
63    /// or equal to the number of rows.
64    pub fn row(&self, index: usize) -> Vector<T> {
65        let (rows, cols) = self.dimensions;
66        if index >= rows {
67            panic!("attempt to retrieve a row outside of matrix dimensions. Requested row {index}, number of rows {rows}");
68        }
69        Vector::from_vec(self.values[index * cols..(index + 1) * cols].to_vec())
70    }
71
72    /// Take the matrix and split it into 2 matrices vertically.
73    /// e.g. take the first m1_height rows and return them as a matrix,
74    /// and return the remaining rows as the m2 matrix.
75    pub fn split_vertical(&self, m1_height: usize, m2_height: usize) -> (Matrix2D<T>, Matrix2D<T>) {
76        assert_eq!(
77            self.dimensions.0,
78            m1_height + m2_height,
79            "matrix vertical split height mismatch"
80        );
81        let (_, cols) = self.dimensions;
82        let mid_offset = m1_height * cols;
83        (
84            Matrix2D {
85                dimensions: (m1_height, cols),
86                values: self.values[..mid_offset].to_vec(),
87            },
88            Matrix2D {
89                dimensions: (m2_height, cols),
90                values: self.values[mid_offset..].to_vec(),
91            },
92        )
93    }
94
95    /// Compose the matrix self with another matrix vertically.
96    pub fn compose_vertical(&self, other: Self) -> Self {
97        assert_eq!(
98            self.dimensions.1, other.dimensions.1,
99            "horizontal size mismatch in vertical composition"
100        );
101        Self {
102            dimensions: (self.dimensions.0 + other.dimensions.0, self.dimensions.1),
103            values: self
104                .values
105                .iter()
106                .chain(other.values.iter())
107                .cloned()
108                .collect(),
109        }
110    }
111
112    /// Compose the matrix self with another matrix horizontally.
113    pub fn compose_horizontal(&self, other: Self) -> Self {
114        let mut values = vec![];
115        let (m1_rows, m1_cols) = self.dimensions;
116        let (m2_rows, m2_cols) = other.dimensions;
117        assert_eq!(
118            m1_rows, m2_rows,
119            "vertical size mismatch in horizontal composition"
120        );
121        for i in 0..m1_rows {
122            values.append(&mut self.values[i * m1_cols..(i + 1) * m1_cols].to_vec());
123            values.append(&mut other.values[i * m2_cols..(i + 1) * m2_cols].to_vec());
124        }
125        Self {
126            dimensions: (self.dimensions.0, self.dimensions.1 + other.dimensions.1),
127            values,
128        }
129    }
130
131    /// Sample a uniform random matrix of the specified dimensions
132    /// from the underlying field.
133    #[cfg(feature = "rand")]
134    pub fn sample_uniform<R: rand::Rng>(rows: usize, columns: usize, rng: &mut R) -> Self {
135        Self {
136            dimensions: (rows, columns),
137            values: Vector::sample_uniform(rows * columns, rng).to_vec(),
138        }
139    }
140
141    /// Build a johnson-lindenstrauss projection matrix
142    /// with an input vector size of `input_dimension`.
143    /// Returns a matrix of dimension `Matrix2d::JL_PROJECTION_SIZE x input_dimension`.
144    ///
145    /// Implemented as defined in [LaBRADOR](https://eprint.iacr.org/2022/1341.pdf)
146    /// section 4 (bottom of page 9).
147    #[cfg(feature = "rand")]
148    pub fn sample_jl<R: rand::Rng>(input_dimension: usize, rng: &mut R) -> Self {
149        let mut values = vec![];
150        // the matrix needs to be sampled randomly with
151        // each element being 0 with probabiltiy 1/2,
152        // 1 with probability 1/4 and -1 with probability 1/4
153        for _ in 0..(input_dimension * Self::JL_PROJECTION_SIZE) {
154            // TODO: don't fork on this logic
155            let v = rng.gen_range(0..=3);
156            match v {
157                0 => values.push(T::one()),
158                1 => values.push(-T::one()),
159                _ => values.push(T::zero()),
160            }
161        }
162        Self {
163            dimensions: (Self::JL_PROJECTION_SIZE, input_dimension),
164            values,
165        }
166    }
167}
168
169impl<T: FieldElement> std::fmt::Display for Matrix2D<T> {
170    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
171        let (rows, cols) = self.dimensions;
172        writeln!(f, "[")?;
173        for i in 0..rows {
174            write!(f, "  [ ")?;
175            for j in 0..cols {
176                write!(f, "{}, ", self.values[i * cols + j])?;
177            }
178            writeln!(f, "],")?;
179            writeln!(f, "]")?;
180        }
181        Ok(())
182    }
183}
184
185impl<T: FieldElement> std::ops::Add for Matrix2D<T> {
186    type Output = Matrix2D<T>;
187
188    fn add(self, other: Matrix2D<T>) -> Matrix2D<T> {
189        assert_eq!(
190            self.dimensions, other.dimensions,
191            "matrix addition dimensions mismatch"
192        );
193        Matrix2D {
194            dimensions: self.dimensions,
195            values: self
196                .values
197                .iter()
198                .zip(other.values.iter())
199                .map(|(a, b)| a.clone() + b.clone())
200                .collect(),
201        }
202    }
203}
204
205impl<T: FieldElement> std::ops::Mul<T> for Matrix2D<T> {
206    type Output = Matrix2D<T>;
207
208    /// We'll assume any provided vector is a column vector and
209    /// multiply column-wise by the matrix.
210    fn mul(self, other: T) -> Matrix2D<T> {
211        Matrix2D {
212            dimensions: self.dimensions,
213            values: self
214                .values
215                .iter()
216                .map(|v| v.clone() * other.clone())
217                .collect(),
218        }
219    }
220}
221
222impl<T: FieldElement> std::ops::Mul<Vector<T>> for Matrix2D<T> {
223    type Output = Vector<T>;
224
225    fn mul(self, other: Vector<T>) -> Vector<T> {
226        let mut out = Vec::new();
227        let (m_rows, m_cols) = self.dimensions;
228        for i in 0..m_rows {
229            let row = self.values[i * m_cols..(i + 1) * m_cols].to_vec();
230
231            out.push(
232                (other.clone() * Vector::from_vec(row))
233                    .iter()
234                    .fold(T::zero(), |acc, v| acc + v.clone()),
235            );
236        }
237        Vector::from_vec(out)
238    }
239}
240
#[cfg(test)]
mod test {
    use scalarff::BigUint;
    use scalarff::OxfoiFieldElement;

    use super::Matrix2D;

    /// Sample JL projection matrices and check their shape, and that the
    /// l2 norm of the projected vector stays below floor(sqrt(128)) times
    /// the l2 norm of the input.
    #[test]
    #[cfg(feature = "rand")]
    fn test_jl_projection() {
        let input_size = 64;
        let projection_size = Matrix2D::<OxfoiFieldElement>::JL_PROJECTION_SIZE;
        // One RNG handle is enough for every iteration.
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let m = Matrix2D::<OxfoiFieldElement>::sample_jl(input_size, &mut rng);
            assert_eq!(m.dimensions, (projection_size, input_size));
            let input = super::Vector::sample_uniform(input_size, &mut rng);

            // the floored value of sqrt(128)
            let root_128_approx = BigUint::from(11u32);
            let out = m * input.clone();
            assert_eq!(out.len(), projection_size);
            assert!(out.norm_l2() < root_128_approx * input.norm_l2());
        }
    }
}
269}