1use super::MatPolynomialRingZq;
12use crate::{
13 error::MathError,
14 integer::PolyOverZ,
15 traits::{CompareBase, MatrixDimensions, MatrixGetEntry, Tensor},
16};
17use flint_sys::{fmpz_poly_mat::fmpz_poly_mat_entry, fq::fq_mul};
18
19impl Tensor for MatPolynomialRingZq {
20 fn tensor_product(&self, other: &Self) -> Self {
49 self.tensor_product_safe(other).unwrap()
50 }
51}
52
53impl MatPolynomialRingZq {
54 pub fn tensor_product_safe(&self, other: &Self) -> Result<Self, MathError> {
84 if !self.compare_base(other) {
85 return Err(self.call_compare_base_error(other).unwrap());
86 }
87
88 let mut out = MatPolynomialRingZq::new(
89 self.get_num_rows() * other.get_num_rows(),
90 self.get_num_columns() * other.get_num_columns(),
91 self.get_mod(),
92 );
93
94 for i in 0..self.get_num_rows() {
95 for j in 0..self.get_num_columns() {
96 let entry: PolyOverZ = unsafe { self.get_entry_unchecked(i, j) };
97
98 if !entry.is_zero() {
99 unsafe { set_matrix_window_mul(&mut out, i, j, entry, other) }
100 }
101 }
102 }
103
104 Ok(out)
105 }
106}
107
108unsafe fn set_matrix_window_mul(
134 out: &mut MatPolynomialRingZq,
135 row_left: i64,
136 column_upper: i64,
137 scalar: PolyOverZ,
138 matrix: &MatPolynomialRingZq,
139) {
140 let columns_other = matrix.get_num_columns();
141 let rows_other = matrix.get_num_rows();
142
143 assert!(row_left >= 0 && row_left + rows_other <= out.get_num_rows());
144 assert!(column_upper >= 0 && column_upper + columns_other <= out.get_num_columns());
145
146 for i_other in 0..rows_other {
147 for j_other in 0..columns_other {
148 unsafe {
149 fq_mul(
150 fmpz_poly_mat_entry(
151 &out.matrix.matrix,
152 row_left * rows_other + i_other,
153 column_upper * columns_other + j_other,
154 ),
155 &scalar.poly,
156 fmpz_poly_mat_entry(&matrix.matrix.matrix, i_other, j_other),
157 matrix.modulus.get_fq_ctx(),
158 )
159 }
160 }
161 }
162}
163
/// Tests for [`Tensor::tensor_product`] and `tensor_product_safe` on
/// [`MatPolynomialRingZq`].
#[cfg(test)]
mod test_tensor {
    use crate::{
        integer_mod_q::{MatPolynomialRingZq, ModulusPolynomialRingZq},
        traits::{MatrixDimensions, Tensor},
    };
    use std::str::FromStr;

    /// Ensure that the tensor product of a `17 x 13` and a `3 x 4` matrix has
    /// dimensions `51 x 52`, i.e. the products of the operands' dimensions.
    #[test]
    fn dimensions_fit() {
        let mod_poly = ModulusPolynomialRingZq::from_str("3 1 2 3 mod 17").unwrap();
        let mat_1 = MatPolynomialRingZq::new(17, 13, &mod_poly);
        let mat_2 = MatPolynomialRingZq::new(3, 4, &mod_poly);

        let mat_3 = mat_1.tensor_product(&mat_2);

        assert_eq!(51, mat_3.get_num_rows());
        assert_eq!(52, mat_3.get_num_columns());
    }

    /// Ensure that tensoring with the `2 x 2` identity matrix from either side yields
    /// the expected block structure, using entries at the `i64` bounds and a modulus
    /// of `u64::MAX`.
    #[test]
    fn identity() {
        let mod_poly =
            ModulusPolynomialRingZq::from_str(&format!("3 1 2 3 mod {}", u64::MAX)).unwrap();
        let identity = MatPolynomialRingZq::identity(2, 2, &mod_poly);
        let mat_1 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[0, 1 {}, 1 -1]] / 3 1 2 3 mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        let mat_2 = identity.tensor_product(&mat_1);
        let mat_3 = mat_1.tensor_product(&identity);

        // `identity ⊗ mat_1` is block-diagonal with two copies of `mat_1`.
        let cmp_mat_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1, 0, 0, 0],[0, 1 {}, 1 -1, 0, 0, 0],[0, 0, 0, 1 1, 1 {}, 1 1],[0, 0, 0, 0, 1 {}, 1 -1]] / 3 1 2 3 mod {}",
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();
        // `mat_1 ⊗ identity` replaces every entry `a` of `mat_1` by `[[a, 0],[0, a]]`.
        let cmp_mat_3 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 0, 1 {}, 0, 1 1, 0],[0, 1 1, 0, 1 {}, 0, 1 1],[0, 0, 1 {}, 0, 1 -1, 0],[0, 0, 0, 1 {}, 0, 1 -1]] / 3 1 2 3 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MIN,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
        assert_eq!(cmp_mat_3, mat_3);
    }

    /// Ensure that the tensor product of a column vector and a matrix is computed
    /// correctly in both orders.
    #[test]
    fn vector_matrix() {
        let vector =
            MatPolynomialRingZq::from_str(&format!("[[1 1],[1 -1]] / 3 1 2 3 mod {}", u64::MAX))
                .unwrap();
        let mat_1 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[0, 1 {}, 1 -1]] / 3 1 2 3 mod {}",
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();

        let mat_2 = vector.tensor_product(&mat_1);
        let mat_3 = mat_1.tensor_product(&vector);

        // `vector ⊗ mat_1` stacks `mat_1` on top of `-mat_1`.
        let cmp_mat_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[0, 1 {}, 1 -1],[1 -1, 1 -{}, 1 -1],[0, 1 -{}, 1 1]] / 3 1 2 3 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();
        // `mat_1 ⊗ vector` replaces every entry `a` of `mat_1` by the column `[a, -a]`.
        let cmp_mat_3 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[1 -1, 1 -{}, 1 -1],[0, 1 {}, 1 -1],[0, 1 -{}, 1 1]] / 3 1 2 3 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
        assert_eq!(cmp_mat_3, mat_3);
    }

    /// Ensure that the tensor product of two column vectors is computed correctly in
    /// both orders, with products reaching `u64::MAX - 1` and `i64::MIN`.
    #[test]
    fn vector_vector() {
        let vec_1 =
            MatPolynomialRingZq::from_str(&format!("[[1 2],[1 1]] / 3 1 2 3 mod {}", u64::MAX))
                .unwrap();
        let vec_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1 {}],[1 {}]] / 3 1 2 3 mod {}",
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();

        let vec_3 = vec_1.tensor_product(&vec_2);
        let vec_4 = vec_2.tensor_product(&vec_1);

        // 2 * ((u64::MAX - 1) / 2) = u64::MAX - 1 and 2 * (i64::MIN / 2) = i64::MIN.
        let cmp_vec_3 = MatPolynomialRingZq::from_str(&format!(
            "[[1 {}],[1 {}],[1 {}],[1 {}]] / 3 1 2 3 mod {}",
            u64::MAX - 1,
            i64::MIN,
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();
        let cmp_vec_4 = MatPolynomialRingZq::from_str(&format!(
            "[[1 {}],[1 {}],[1 {}],[1 {}]] / 3 1 2 3 mod {}",
            u64::MAX - 1,
            (u64::MAX - 1) / 2,
            i64::MIN,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_vec_3, vec_3);
        assert_eq!(cmp_vec_4, vec_4);
    }

    /// Ensure that entries of degree greater than zero are multiplied as polynomials.
    #[test]
    fn higher_degree() {
        let higher_degree = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 2 0 1, 2 1 1]] / 3 1 2 3 mod {}",
            u64::MAX
        ))
        .unwrap();
        let mat_1 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 2 1 {}]] / 3 1 2 3 mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        let mat_2 = higher_degree.tensor_product(&mat_1);

        // The blocks are `1 * mat_1`, `X * mat_1` and `(1 + X) * mat_1`, e.g.
        // (1 + X) * (1 + MIN X) = 1 + (MIN + 1) X + MIN X^2. Some products have degree 2,
        // like the modulus; `cmp_mat_2` is parsed under the same modulus, so it presumably
        // undergoes the same reduction as the computed result.
        let cmp_mat_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 2 1 {}, 2 0 1, 2 0 {}, 3 0 1 {}, 2 1 1, 2 {} {}, 3 1 {} {}]] / 3 1 2 3 mod {}",
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MAX,
            i64::MIN + 1,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
    }

    /// Ensure that `tensor_product` panics if the moduli (here `mod 17` vs. `mod 16`)
    /// of the two matrices differ.
    #[test]
    #[should_panic]
    fn moduli_mismatch_panic() {
        let mod_poly_1 = ModulusPolynomialRingZq::from_str("3 1 2 3 mod 17").unwrap();
        let mod_poly_2 = ModulusPolynomialRingZq::from_str("3 1 2 3 mod 16").unwrap();
        let mat_1 = MatPolynomialRingZq::new(17, 13, &mod_poly_1);
        let mat_2 = MatPolynomialRingZq::new(3, 4, &mod_poly_2);

        let _ = mat_1.tensor_product(&mat_2);
    }

    /// Ensure that `tensor_product_safe` returns an error instead of panicking if the
    /// moduli of the two matrices differ.
    #[test]
    fn moduli_mismatch_error() {
        let mod_poly_1 = ModulusPolynomialRingZq::from_str("3 1 2 3 mod 17").unwrap();
        let mod_poly_2 = ModulusPolynomialRingZq::from_str("3 1 2 3 mod 16").unwrap();
        let mat_1 = MatPolynomialRingZq::new(17, 13, &mod_poly_1);
        let mat_2 = MatPolynomialRingZq::new(3, 4, &mod_poly_2);

        assert!(mat_1.tensor_product_safe(&mat_2).is_err());
    }
}
362}