1use scalarff::FieldElement;
2
3use super::vector::Vector;
4
5#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
7#[derive(Clone, PartialEq)]
8pub struct Matrix2D<T: FieldElement> {
9 pub dimensions: (usize, usize), pub values: Vec<T>,
11}
12
13impl<T: FieldElement> Matrix2D<T> {
14 pub const JL_PROJECTION_SIZE: usize = 256;
15
16 pub fn new(rows: usize, columns: usize) -> Self {
19 Self {
20 dimensions: (rows, columns),
21 values: vec![T::zero(); rows * columns],
22 }
23 }
24
25 pub fn identity(n: usize) -> Self {
27 let mut values: Vec<T> = Vec::new();
28 for x in 0..n {
29 let mut row = vec![T::zero(); n];
30 row[x] = T::one();
31 values.append(&mut row);
32 }
33 Matrix2D {
34 dimensions: (n, n),
35 values,
36 }
37 }
38
39 pub fn zero(rows: usize, cols: usize) -> Self {
41 Matrix2D {
42 dimensions: (rows, cols),
43 values: vec![T::zero(); rows * cols],
44 }
45 }
46
47 pub fn column(&self, index: usize) -> Vector<T> {
50 if index >= self.dimensions.1 {
51 panic!("attempt to retrieve column outside of matrix dimensions. Requested column {index}, number of columns {}", self.dimensions.1);
52 }
53 let mut out = Vec::new();
54 let (m_rows, m_cols) = self.dimensions;
55 for i in 0..m_rows {
56 let column_element = &self.values[i * m_cols + index];
57 out.push(column_element.clone());
58 }
59 Vector::from_vec(out)
60 }
61
62 pub fn row(&self, index: usize) -> Vector<T> {
65 let (rows, cols) = self.dimensions;
66 if index >= rows {
67 panic!("attempt to retrieve a row outside of matrix dimensions. Requested row {index}, number of rows {rows}");
68 }
69 Vector::from_vec(self.values[index * cols..(index + 1) * cols].to_vec())
70 }
71
72 pub fn split_vertical(&self, m1_height: usize, m2_height: usize) -> (Matrix2D<T>, Matrix2D<T>) {
76 assert_eq!(
77 self.dimensions.0,
78 m1_height + m2_height,
79 "matrix vertical split height mismatch"
80 );
81 let (_, cols) = self.dimensions;
82 let mid_offset = m1_height * cols;
83 (
84 Matrix2D {
85 dimensions: (m1_height, cols),
86 values: self.values[..mid_offset].to_vec(),
87 },
88 Matrix2D {
89 dimensions: (m2_height, cols),
90 values: self.values[mid_offset..].to_vec(),
91 },
92 )
93 }
94
95 pub fn compose_vertical(&self, other: Self) -> Self {
97 assert_eq!(
98 self.dimensions.1, other.dimensions.1,
99 "horizontal size mismatch in vertical composition"
100 );
101 Self {
102 dimensions: (self.dimensions.0 + other.dimensions.0, self.dimensions.1),
103 values: self
104 .values
105 .iter()
106 .chain(other.values.iter())
107 .cloned()
108 .collect(),
109 }
110 }
111
112 pub fn compose_horizontal(&self, other: Self) -> Self {
114 let mut values = vec![];
115 let (m1_rows, m1_cols) = self.dimensions;
116 let (m2_rows, m2_cols) = other.dimensions;
117 assert_eq!(
118 m1_rows, m2_rows,
119 "vertical size mismatch in horizontal composition"
120 );
121 for i in 0..m1_rows {
122 values.append(&mut self.values[i * m1_cols..(i + 1) * m1_cols].to_vec());
123 values.append(&mut other.values[i * m2_cols..(i + 1) * m2_cols].to_vec());
124 }
125 Self {
126 dimensions: (self.dimensions.0, self.dimensions.1 + other.dimensions.1),
127 values,
128 }
129 }
130
131 #[cfg(feature = "rand")]
134 pub fn sample_uniform<R: rand::Rng>(rows: usize, columns: usize, rng: &mut R) -> Self {
135 Self {
136 dimensions: (rows, columns),
137 values: Vector::sample_uniform(rows * columns, rng).to_vec(),
138 }
139 }
140
141 #[cfg(feature = "rand")]
148 pub fn sample_jl<R: rand::Rng>(input_dimension: usize, rng: &mut R) -> Self {
149 let mut values = vec![];
150 for _ in 0..(input_dimension * Self::JL_PROJECTION_SIZE) {
154 let v = rng.gen_range(0..=3);
156 match v {
157 0 => values.push(T::one()),
158 1 => values.push(-T::one()),
159 _ => values.push(T::zero()),
160 }
161 }
162 Self {
163 dimensions: (Self::JL_PROJECTION_SIZE, input_dimension),
164 values,
165 }
166 }
167}
168
169impl<T: FieldElement> std::fmt::Display for Matrix2D<T> {
170 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
171 let (rows, cols) = self.dimensions;
172 writeln!(f, "[")?;
173 for i in 0..rows {
174 write!(f, " [ ")?;
175 for j in 0..cols {
176 write!(f, "{}, ", self.values[i * cols + j])?;
177 }
178 writeln!(f, "],")?;
179 writeln!(f, "]")?;
180 }
181 Ok(())
182 }
183}
184
185impl<T: FieldElement> std::ops::Add for Matrix2D<T> {
186 type Output = Matrix2D<T>;
187
188 fn add(self, other: Matrix2D<T>) -> Matrix2D<T> {
189 assert_eq!(
190 self.dimensions, other.dimensions,
191 "matrix addition dimensions mismatch"
192 );
193 Matrix2D {
194 dimensions: self.dimensions,
195 values: self
196 .values
197 .iter()
198 .zip(other.values.iter())
199 .map(|(a, b)| a.clone() + b.clone())
200 .collect(),
201 }
202 }
203}
204
205impl<T: FieldElement> std::ops::Mul<T> for Matrix2D<T> {
206 type Output = Matrix2D<T>;
207
208 fn mul(self, other: T) -> Matrix2D<T> {
211 Matrix2D {
212 dimensions: self.dimensions,
213 values: self
214 .values
215 .iter()
216 .map(|v| v.clone() * other.clone())
217 .collect(),
218 }
219 }
220}
221
222impl<T: FieldElement> std::ops::Mul<Vector<T>> for Matrix2D<T> {
223 type Output = Vector<T>;
224
225 fn mul(self, other: Vector<T>) -> Vector<T> {
226 let mut out = Vec::new();
227 let (m_rows, m_cols) = self.dimensions;
228 for i in 0..m_rows {
229 let row = self.values[i * m_cols..(i + 1) * m_cols].to_vec();
230
231 out.push(
232 (other.clone() * Vector::from_vec(row))
233 .iter()
234 .fold(T::zero(), |acc, v| acc + v.clone()),
235 );
236 }
237 Vector::from_vec(out)
238 }
239}
240
#[cfg(test)]
mod test {
    /// Checks that JL projection matrices have the expected shape and that the
    /// projected norm stays below a fixed multiple of the input norm.
    #[test]
    #[cfg(feature = "rand")]
    fn test_jl_projection() {
        // Imports are scoped to this feature-gated test so that builds without
        // the `rand` feature do not report unused imports.
        use scalarff::BigUint;
        use scalarff::OxfoiFieldElement;

        use super::Matrix2D;
        use super::Vector;

        let input_size = 64;
        let projection_size = Matrix2D::<OxfoiFieldElement>::JL_PROJECTION_SIZE;
        // `thread_rng` returns a handle to the same thread-local generator, so
        // acquiring it once outside the loop is sufficient.
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let m = Matrix2D::<OxfoiFieldElement>::sample_jl(input_size, &mut rng);
            assert_eq!(m.dimensions, (projection_size, input_size));
            let input = Vector::sample_uniform(input_size, &mut rng);

            // NOTE(review): the name suggests an approximation of sqrt(128) ≈ 11.31,
            // but 11 is slightly below that value — confirm the intended bound.
            let root_128_approx = BigUint::from(11u32);
            let out = m * input.clone();
            assert_eq!(out.len(), projection_size);
            assert!(out.norm_l2() < root_128_approx * input.norm_l2());
        }
    }
}