1use super::MatPolynomialRingZq;
12use crate::{
13 error::MathError,
14 integer::PolyOverZ,
15 traits::{CompareBase, MatrixDimensions, MatrixGetEntry, Tensor},
16};
17use flint_sys::{fmpz_poly::fmpz_poly_mul, fmpz_poly_mat::fmpz_poly_mat_entry};
18
19impl Tensor for MatPolynomialRingZq {
20 fn tensor_product(&self, other: &Self) -> Self {
49 self.tensor_product_safe(other).unwrap()
50 }
51}
52
53impl MatPolynomialRingZq {
54 pub fn tensor_product_safe(&self, other: &Self) -> Result<Self, MathError> {
84 if !self.compare_base(other) {
85 return Err(self.call_compare_base_error(other).unwrap());
86 }
87
88 let mut out = MatPolynomialRingZq::new(
89 self.get_num_rows() * other.get_num_rows(),
90 self.get_num_columns() * other.get_num_columns(),
91 self.get_mod(),
92 );
93
94 for i in 0..self.get_num_rows() {
95 for j in 0..self.get_num_columns() {
96 let entry: PolyOverZ = unsafe { self.get_entry_unchecked(i, j) };
97
98 if !entry.is_zero() {
99 unsafe { set_matrix_window_mul(&mut out, i, j, entry, other) }
100 }
101 }
102 }
103
104 Ok(out)
105 }
106}
107
108unsafe fn set_matrix_window_mul(
134 out: &mut MatPolynomialRingZq,
135 row_left: i64,
136 column_upper: i64,
137 scalar: PolyOverZ,
138 matrix: &MatPolynomialRingZq,
139) {
140 let columns_other = matrix.get_num_columns();
141 let rows_other = matrix.get_num_rows();
142
143 assert!(row_left >= 0 && row_left + rows_other <= out.get_num_rows());
144 assert!(column_upper >= 0 && column_upper + columns_other <= out.get_num_columns());
145
146 for i_other in 0..rows_other {
147 for j_other in 0..columns_other {
148 unsafe {
149 fmpz_poly_mul(
150 fmpz_poly_mat_entry(
151 &out.matrix.matrix,
152 row_left * rows_other + i_other,
153 column_upper * columns_other + j_other,
154 ),
155 &scalar.poly,
156 fmpz_poly_mat_entry(&matrix.matrix.matrix, i_other, j_other),
157 );
158 out.reduce_entry(
159 row_left * rows_other + i_other,
160 column_upper * columns_other + j_other,
161 );
162 }
163 }
164 }
165}
166
#[cfg(test)]
/// Tests for the [`Tensor`] implementation of `MatPolynomialRingZq`.
mod test_tensor {
    use crate::{
        integer_mod_q::{MatPolynomialRingZq, ModulusPolynomialRingZq},
        traits::{MatrixDimensions, Tensor},
    };
    use std::str::FromStr;

    #[test]
    /// Ensures that the dimensions of the tensor product are the products of
    /// the dimensions of both factors (here `17 x 13` and `3 x 4`).
    fn dimensions_fit() {
        let mod_poly = ModulusPolynomialRingZq::from_str("3 1 2 1 mod 17").unwrap();
        let mat_1 = MatPolynomialRingZq::new(17, 13, &mod_poly);
        let mat_2 = MatPolynomialRingZq::new(3, 4, &mod_poly);

        let mat_3 = mat_1.tensor_product(&mat_2);

        assert_eq!(51, mat_3.get_num_rows());
        assert_eq!(52, mat_3.get_num_columns());
    }

    #[test]
    /// Ensures that tensoring with a `2 x 2` identity matrix places copies of
    /// the other factor at the correct positions, for both orders of the
    /// factors. Entries at the bounds of `i64` are used under modulus `u64::MAX`.
    fn identity() {
        let mod_poly =
            ModulusPolynomialRingZq::from_str(&format!("3 1 2 1 mod {}", u64::MAX)).unwrap();
        let identity = MatPolynomialRingZq::identity(2, 2, &mod_poly);
        let mat_1 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[0, 1 {}, 1 -1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        let mat_2 = identity.tensor_product(&mat_1);
        let mat_3 = mat_1.tensor_product(&identity);

        // identity ⊗ mat_1 is block-diagonal with two copies of mat_1.
        let cmp_mat_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1, 0, 0, 0],[0, 1 {}, 1 -1, 0, 0, 0],[0, 0, 0, 1 1, 1 {}, 1 1],[0, 0, 0, 0, 1 {}, 1 -1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();
        // mat_1 ⊗ identity replaces every entry e of mat_1 by the block e * I.
        let cmp_mat_3 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 0, 1 {}, 0, 1 1, 0],[0, 1 1, 0, 1 {}, 0, 1 1],[0, 0, 1 {}, 0, 1 -1, 0],[0, 0, 0, 1 {}, 0, 1 -1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MIN,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
        assert_eq!(cmp_mat_3, mat_3);
    }

    #[test]
    /// Ensures that tensoring the column vector `[1, -1]` with a matrix (and
    /// vice versa) stacks the scaled and negated copies in the correct order.
    fn vector_matrix() {
        let vector =
            MatPolynomialRingZq::from_str(&format!("[[1 1],[1 -1]] / 3 1 2 1 mod {}", u64::MAX))
                .unwrap();
        let mat_1 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[0, 1 {}, 1 -1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();

        let mat_2 = vector.tensor_product(&mat_1);
        let mat_3 = mat_1.tensor_product(&vector);

        // vector ⊗ mat_1 = [mat_1; -mat_1]
        let cmp_mat_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[0, 1 {}, 1 -1],[1 -1, 1 -{}, 1 -1],[0, 1 -{}, 1 1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();
        // mat_1 ⊗ vector interleaves each row of mat_1 with its negation.
        let cmp_mat_3 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[1 -1, 1 -{}, 1 -1],[0, 1 {}, 1 -1],[0, 1 -{}, 1 1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
        assert_eq!(cmp_mat_3, mat_3);
    }

    #[test]
    /// Ensures that the tensor product of two column vectors is the column
    /// vector of all pairwise products in the expected order, for both orders
    /// of the factors. The values are chosen so that doubling them yields
    /// `u64::MAX - 1` and `i64::MIN`.
    fn vector_vector() {
        let vec_1 =
            MatPolynomialRingZq::from_str(&format!("[[1 2],[1 1]] / 3 1 2 1 mod {}", u64::MAX))
                .unwrap();
        let vec_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1 {}],[1 {}]] / 3 1 2 1 mod {}",
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();

        let vec_3 = vec_1.tensor_product(&vec_2);
        let vec_4 = vec_2.tensor_product(&vec_1);

        let cmp_vec_3 = MatPolynomialRingZq::from_str(&format!(
            "[[1 {}],[1 {}],[1 {}],[1 {}]] / 3 1 2 1 mod {}",
            u64::MAX - 1,
            i64::MIN,
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();
        let cmp_vec_4 = MatPolynomialRingZq::from_str(&format!(
            "[[1 {}],[1 {}],[1 {}],[1 {}]] / 3 1 2 1 mod {}",
            u64::MAX - 1,
            (u64::MAX - 1) / 2,
            i64::MIN,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_vec_3, vec_3);
        assert_eq!(cmp_vec_4, vec_4);
    }

    #[test]
    /// Ensures that entries of positive degree are multiplied as polynomials.
    /// The expected matrix spells out unreduced products (e.g. `3 0 1 {}`,
    /// of degree 2) and is parsed under the same modulus as the result.
    fn higher_degree() {
        let higher_degree = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 2 0 1, 2 1 1]] / 3 1 2 1 mod {}",
            u64::MAX
        ))
        .unwrap();
        let mat_1 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 2 1 {}]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        let mat_2 = higher_degree.tensor_product(&mat_1);

        // Blocks: 1 * mat_1, X * mat_1, (1 + X) * mat_1.
        let cmp_mat_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 2 1 {}, 2 0 1, 2 0 {}, 3 0 1 {}, 2 1 1, 2 {} {}, 3 1 {} {}]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MAX,
            i64::MIN + 1,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
    }

    #[test]
    /// Ensures that `tensor_product` panics if the moduli of both matrices differ.
    #[should_panic]
    fn moduli_mismatch_panic() {
        let mod_poly_1 = ModulusPolynomialRingZq::from_str("3 1 2 1 mod 17").unwrap();
        let mod_poly_2 = ModulusPolynomialRingZq::from_str("3 1 2 1 mod 16").unwrap();
        let mat_1 = MatPolynomialRingZq::new(17, 13, &mod_poly_1);
        let mat_2 = MatPolynomialRingZq::new(3, 4, &mod_poly_2);

        let _ = mat_1.tensor_product(&mat_2);
    }

    #[test]
    /// Ensures that `tensor_product_safe` returns an error if the moduli of
    /// both matrices differ.
    fn moduli_mismatch_error() {
        let mod_poly_1 = ModulusPolynomialRingZq::from_str("3 1 2 1 mod 17").unwrap();
        let mod_poly_2 = ModulusPolynomialRingZq::from_str("3 1 2 1 mod 16").unwrap();
        let mat_1 = MatPolynomialRingZq::new(17, 13, &mod_poly_1);
        let mat_2 = MatPolynomialRingZq::new(3, 4, &mod_poly_2);

        assert!(mat_1.tensor_product_safe(&mat_2).is_err());
    }
}