qfall_math/integer_mod_q/mat_polynomial_ring_zq/
tensor.rs

1// Copyright © 2025 Marcel Luca Schmidt
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module contains the implementation of the `tensor` product.
10
11use super::MatPolynomialRingZq;
12use crate::{
13    error::MathError,
14    integer::PolyOverZ,
15    traits::{CompareBase, MatrixDimensions, MatrixGetEntry, Tensor},
16};
17use flint_sys::{fmpz_poly_mat::fmpz_poly_mat_entry, fq::fq_mul};
18
19impl Tensor for MatPolynomialRingZq {
20    /// Computes the tensor product of `self` with `other`.
21    ///
22    /// Parameters:
23    /// - `other`: the value with which the tensor product is computed.
24    ///
25    /// Returns the tensor product of `self` with `other`.
26    ///
27    /// # Examples
28    /// ```
29    /// use qfall_math::integer_mod_q::MatPolynomialRingZq;
30    /// use qfall_math::traits::Tensor;
31    /// use std::str::FromStr;
32    ///
33    /// let mat_1 = MatPolynomialRingZq::from_str("[[1  1, 2  1 1]] / 3  1 2 3 mod 17").unwrap();
34    /// let mat_2 = MatPolynomialRingZq::from_str("[[1  1, 1  2]] / 3  1 2 3 mod 17").unwrap();
35    ///
36    /// let mat_ab = mat_1.tensor_product(&mat_2);
37    /// let mat_ba = mat_2.tensor_product(&mat_1);
38    ///
39    /// let res_ab = "[[1  1, 1  2, 2  1 1, 2  2 2]] / 3  1 2 3 mod 17";
40    /// let res_ba = "[[1  1, 2  1 1, 1  2, 2  2 2]] / 3  1 2 3 mod 17";
41    /// assert_eq!(mat_ab, MatPolynomialRingZq::from_str(res_ab).unwrap());
42    /// assert_eq!(mat_ba, MatPolynomialRingZq::from_str(res_ba).unwrap());
43    /// ```
44    ///
45    /// # Panics ...
46    /// - if the moduli of both matrices mismatch.
47    ///   Use [`tensor_product_safe`](crate::integer_mod_q::MatZq::tensor_product_safe) to get an error instead.
48    fn tensor_product(&self, other: &Self) -> Self {
49        self.tensor_product_safe(other).unwrap()
50    }
51}
52
53impl MatPolynomialRingZq {
54    /// Computes the tensor product of `self` with `other`.
55    ///
56    /// Parameters:
57    /// - `other`: the value with which the tensor product is computed.
58    ///
59    /// Returns the tensor product of `self` with `other` or an error if the
60    /// moduli of the provided matrices mismatch.
61    ///
62    /// # Examples
63    /// ```
64    /// use qfall_math::integer_mod_q::MatPolynomialRingZq;
65    /// use std::str::FromStr;
66    ///
67    /// let mat_1 = MatPolynomialRingZq::from_str("[[1  1, 2  1 1]] / 3  1 2 3 mod 17").unwrap();
68    /// let mat_2 = MatPolynomialRingZq::from_str("[[1  1, 1  2]] / 3  1 2 3 mod 17").unwrap();
69    ///
70    /// let mat_ab = mat_1.tensor_product_safe(&mat_2).unwrap();
71    /// let mat_ba = mat_2.tensor_product_safe(&mat_1).unwrap();
72    ///
73    /// let res_ab = "[[1  1, 1  2, 2  1 1, 2  2 2]] / 3  1 2 3 mod 17";
74    /// let res_ba = "[[1  1, 2  1 1, 1  2, 2  2 2]] / 3  1 2 3 mod 17";
75    /// assert_eq!(mat_ab, MatPolynomialRingZq::from_str(res_ab).unwrap());
76    /// assert_eq!(mat_ba, MatPolynomialRingZq::from_str(res_ba).unwrap());
77    /// ```
78    ///
79    /// # Errors and Failures
80    /// - Returns a [`MathError`] of type
81    ///   [`MismatchingModulus`](MathError::MismatchingModulus) if the
82    ///   moduli of the provided matrices mismatch.
83    pub fn tensor_product_safe(&self, other: &Self) -> Result<Self, MathError> {
84        if !self.compare_base(other) {
85            return Err(self.call_compare_base_error(other).unwrap());
86        }
87
88        let mut out = MatPolynomialRingZq::new(
89            self.get_num_rows() * other.get_num_rows(),
90            self.get_num_columns() * other.get_num_columns(),
91            self.get_mod(),
92        );
93
94        for i in 0..self.get_num_rows() {
95            for j in 0..self.get_num_columns() {
96                let entry: PolyOverZ = unsafe { self.get_entry_unchecked(i, j) };
97
98                if !entry.is_zero() {
99                    unsafe { set_matrix_window_mul(&mut out, i, j, entry, other) }
100                }
101            }
102        }
103
104        Ok(out)
105    }
106}
107
108/// This function sets a specific window of the provided matrix `out`
109/// according to the `scalar` multiple of `matrix`.
110///
111/// Sets the entries
112/// `[i*rows_other, j*columns_other]` up till `[row_left*(row_other +1), column_upper*(columns_other + 1)]`
113///
114/// Parameters:
115/// - `out`: the matrix in which the result is saved
116/// - `row_left`: defines the leftmost row of the set window
117/// - `column_upper`: defines the highest column of the set window
118/// - `scalar`: defines the value with which the part of the tensor product
119///   is calculated
120/// - `matrix`: the matrix with which the scalar is multiplied
121///   before setting the entries in `out`
122///
123/// Implicitly sets the entries of the matrix according to the definition
124/// of the tensor product.
125///
126/// # Security
127/// This function accesses memory directly without checking whether the memory is
128/// actually obtained by the matrix out.
129/// This means that this function should only be called wisely.
130/// If `row_left` or `row_upper` together with the length of the matrix exceeds the
131/// range of the matrix other memory could be overwritten.
132/// We included asserts to check whether this occurs, but we advise careful usage.
133unsafe fn set_matrix_window_mul(
134    out: &mut MatPolynomialRingZq,
135    row_left: i64,
136    column_upper: i64,
137    scalar: PolyOverZ,
138    matrix: &MatPolynomialRingZq,
139) {
140    let columns_other = matrix.get_num_columns();
141    let rows_other = matrix.get_num_rows();
142
143    assert!(row_left >= 0 && row_left + rows_other <= out.get_num_rows());
144    assert!(column_upper >= 0 && column_upper + columns_other <= out.get_num_columns());
145
146    for i_other in 0..rows_other {
147        for j_other in 0..columns_other {
148            unsafe {
149                fq_mul(
150                    fmpz_poly_mat_entry(
151                        &out.matrix.matrix,
152                        row_left * rows_other + i_other,
153                        column_upper * columns_other + j_other,
154                    ),
155                    &scalar.poly,
156                    fmpz_poly_mat_entry(&matrix.matrix.matrix, i_other, j_other),
157                    matrix.modulus.get_fq_ctx(),
158                )
159            }
160        }
161    }
162}
163
#[cfg(test)]
mod test_tensor {
    use crate::{
        integer_mod_q::{MatPolynomialRingZq, ModulusPolynomialRingZq},
        traits::{MatrixDimensions, Tensor},
    };
    use std::str::FromStr;

    /// Parses a [`MatPolynomialRingZq`] from its string representation,
    /// panicking if the representation is invalid.
    fn mat(representation: &str) -> MatPolynomialRingZq {
        MatPolynomialRingZq::from_str(representation).unwrap()
    }

    /// Parses a [`ModulusPolynomialRingZq`] from its string representation,
    /// panicking if the representation is invalid.
    fn modulus(representation: &str) -> ModulusPolynomialRingZq {
        ModulusPolynomialRingZq::from_str(representation).unwrap()
    }

    /// Ensure that the dimensions of the tensor product are taken over correctly.
    #[test]
    fn dimensions_fit() {
        let q = modulus("3  1 2 3 mod 17");
        let tall = MatPolynomialRingZq::new(17, 13, &q);
        let small = MatPolynomialRingZq::new(3, 4, &q);

        let product = tall.tensor_product(&small);

        // 17 * 3 rows and 13 * 4 columns
        assert_eq!(51, product.get_num_rows());
        assert_eq!(52, product.get_num_columns());
    }

    /// Ensure that the tensor product works correctly with the identity matrix,
    /// both as the left and as the right factor.
    #[test]
    fn identity() {
        let q = modulus(&format!("3  1 2 3 mod {}", u64::MAX));
        let id = MatPolynomialRingZq::identity(2, 2, &q);
        let factor = mat(&format!(
            "[[1  1, 1  {}, 1  1],[0, 1  {}, 1  -1]] / 3  1 2 3 mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX
        ));

        // `id ⊗ factor` places `factor` on the block diagonal.
        let expected_id_left = mat(&format!(
            "[[1  1, 1  {}, 1  1, 0, 0, 0],[0, 1  {}, 1  -1, 0, 0, 0],[0, 0, 0, 1  1, 1  {}, 1  1],[0, 0, 0, 0, 1  {}, 1  -1]] / 3  1 2 3 mod {}",
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MIN,
            u64::MAX
        ));
        // `factor ⊗ id` scales a 2x2 identity block by every entry of `factor`.
        let expected_id_right = mat(&format!(
            "[[1  1, 0, 1  {}, 0, 1  1, 0],[0, 1  1, 0, 1  {}, 0, 1  1],[0, 0, 1  {}, 0, 1  -1, 0],[0, 0, 0, 1  {}, 0, 1  -1]] / 3  1 2 3 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MIN,
            i64::MIN,
            u64::MAX
        ));

        assert_eq!(expected_id_left, id.tensor_product(&factor));
        assert_eq!(expected_id_right, factor.tensor_product(&id));
    }

    /// Ensure the tensor product works where one is a vector and the other is a matrix.
    #[test]
    fn vector_matrix() {
        let column = mat(&format!("[[1  1],[1  -1]] / 3  1 2 3 mod {}", u64::MAX));
        let matrix = mat(&format!(
            "[[1  1, 1  {}, 1  1],[0, 1  {}, 1  -1]] / 3  1 2 3 mod {}",
            i64::MAX,
            i64::MAX,
            u64::MAX
        ));

        let expected_vec_mat = mat(&format!(
            "[[1  1, 1  {}, 1  1],[0, 1  {}, 1  -1],[1  -1, 1  -{}, 1  -1],[0, 1  -{}, 1  1]] / 3  1 2 3 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MAX,
            i64::MAX,
            u64::MAX
        ));
        let expected_mat_vec = mat(&format!(
            "[[1  1, 1  {}, 1  1],[1  -1, 1  -{}, 1  -1],[0, 1  {}, 1  -1],[0, 1  -{}, 1  1]] / 3  1 2 3 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MAX,
            i64::MAX,
            u64::MAX
        ));

        assert_eq!(expected_vec_mat, column.tensor_product(&matrix));
        assert_eq!(expected_mat_vec, matrix.tensor_product(&column));
    }

    /// Ensure that the tensor product works correctly with two vectors.
    #[test]
    fn vector_vector() {
        let small = mat(&format!("[[1  2],[1  1]] / 3  1 2 3 mod {}", u64::MAX));
        let large = mat(&format!(
            "[[1  {}],[1  {}]] / 3  1 2 3 mod {}",
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u64::MAX
        ));

        let expected_small_large = mat(&format!(
            "[[1  {}],[1  {}],[1  {}],[1  {}]] / 3  1 2 3 mod {}",
            u64::MAX - 1,
            i64::MIN,
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u64::MAX
        ));
        let expected_large_small = mat(&format!(
            "[[1  {}],[1  {}],[1  {}],[1  {}]] / 3  1 2 3 mod {}",
            u64::MAX - 1,
            (u64::MAX - 1) / 2,
            i64::MIN,
            i64::MIN / 2,
            u64::MAX
        ));

        assert_eq!(expected_small_large, small.tensor_product(&large));
        assert_eq!(expected_large_small, large.tensor_product(&small));
    }

    /// Ensures that the tensor product works for higher degree polynomials.
    #[test]
    fn higher_degree() {
        let polys = mat(&format!(
            "[[1  1, 2  0 1, 2  1 1]] / 3  1 2 3 mod {}",
            u64::MAX
        ));
        let extremes = mat(&format!(
            "[[1  1, 1  {}, 2  1 {}]] / 3  1 2 3 mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX
        ));

        let expected = mat(&format!(
            "[[1  1, 1  {}, 2  1 {}, 2  0 1, 2  0 {}, 3  0 1 {}, 2  1 1, 2  {} {}, 3  1 {} {}]] / 3  1 2 3 mod {}",
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MAX,
            i64::MIN + 1,
            i64::MIN,
            u64::MAX
        ));

        assert_eq!(expected, polys.tensor_product(&extremes));
    }

    /// Ensure that the tensor product panics, if the moduli mismatch.
    #[test]
    #[should_panic]
    fn moduli_mismatch_panic() {
        let (q_17, q_16) = (modulus("3  1 2 3 mod 17"), modulus("3  1 2 3 mod 16"));
        let left = MatPolynomialRingZq::new(17, 13, &q_17);
        let right = MatPolynomialRingZq::new(3, 4, &q_16);

        let _ = left.tensor_product(&right);
    }

    /// Ensure that the safe version of the tensor product returns an error, if the moduli mismatch.
    #[test]
    fn moduli_mismatch_error() {
        let (q_17, q_16) = (modulus("3  1 2 3 mod 17"), modulus("3  1 2 3 mod 16"));
        let left = MatPolynomialRingZq::new(17, 13, &q_17);
        let right = MatPolynomialRingZq::new(3, 4, &q_16);

        assert!(left.tensor_product_safe(&right).is_err());
    }
}