Skip to main content

qfall_math/integer_mod_q/mat_polynomial_ring_zq/
tensor.rs

1// Copyright © 2025 Marcel Luca Schmidt
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module contains the implementation of the `tensor` product.
10
11use super::MatPolynomialRingZq;
12use crate::{
13    error::MathError,
14    integer::PolyOverZ,
15    traits::{CompareBase, MatrixDimensions, MatrixGetEntry, Tensor},
16};
17use flint_sys::{fmpz_poly::fmpz_poly_mul, fmpz_poly_mat::fmpz_poly_mat_entry};
18
19impl Tensor for MatPolynomialRingZq {
20    /// Computes the tensor product of `self` with `other`.
21    ///
22    /// Parameters:
23    /// - `other`: the value with which the tensor product is computed.
24    ///
25    /// Returns the tensor product of `self` with `other`.
26    ///
27    /// # Examples
28    /// ```
29    /// use qfall_math::integer_mod_q::MatPolynomialRingZq;
30    /// use qfall_math::traits::Tensor;
31    /// use std::str::FromStr;
32    ///
33    /// let mat_1 = MatPolynomialRingZq::from_str("[[1  1, 2  1 1]] / 3  1 2 1 mod 17").unwrap();
34    /// let mat_2 = MatPolynomialRingZq::from_str("[[1  1, 1  2]] / 3  1 2 1 mod 17").unwrap();
35    ///
36    /// let mat_ab = mat_1.tensor_product(&mat_2);
37    /// let mat_ba = mat_2.tensor_product(&mat_1);
38    ///
39    /// let res_ab = "[[1  1, 1  2, 2  1 1, 2  2 2]] / 3  1 2 1 mod 17";
40    /// let res_ba = "[[1  1, 2  1 1, 1  2, 2  2 2]] / 3  1 2 1 mod 17";
41    /// assert_eq!(mat_ab, MatPolynomialRingZq::from_str(res_ab).unwrap());
42    /// assert_eq!(mat_ba, MatPolynomialRingZq::from_str(res_ba).unwrap());
43    /// ```
44    ///
45    /// # Panics ...
46    /// - if the moduli of both matrices mismatch.
47    ///   Use [`tensor_product_safe`](crate::integer_mod_q::MatZq::tensor_product_safe) to get an error instead.
48    fn tensor_product(&self, other: &Self) -> Self {
49        self.tensor_product_safe(other).unwrap()
50    }
51}
52
53impl MatPolynomialRingZq {
54    /// Computes the tensor product of `self` with `other`.
55    ///
56    /// Parameters:
57    /// - `other`: the value with which the tensor product is computed.
58    ///
59    /// Returns the tensor product of `self` with `other` or an error if the
60    /// moduli of the provided matrices mismatch.
61    ///
62    /// # Examples
63    /// ```
64    /// use qfall_math::integer_mod_q::MatPolynomialRingZq;
65    /// use std::str::FromStr;
66    ///
67    /// let mat_1 = MatPolynomialRingZq::from_str("[[1  1, 2  1 1]] / 3  1 2 1 mod 17").unwrap();
68    /// let mat_2 = MatPolynomialRingZq::from_str("[[1  1, 1  2]] / 3  1 2 1 mod 17").unwrap();
69    ///
70    /// let mat_ab = mat_1.tensor_product_safe(&mat_2).unwrap();
71    /// let mat_ba = mat_2.tensor_product_safe(&mat_1).unwrap();
72    ///
73    /// let res_ab = "[[1  1, 1  2, 2  1 1, 2  2 2]] / 3  1 2 1 mod 17";
74    /// let res_ba = "[[1  1, 2  1 1, 1  2, 2  2 2]] / 3  1 2 1 mod 17";
75    /// assert_eq!(mat_ab, MatPolynomialRingZq::from_str(res_ab).unwrap());
76    /// assert_eq!(mat_ba, MatPolynomialRingZq::from_str(res_ba).unwrap());
77    /// ```
78    ///
79    /// # Errors and Failures
80    /// - Returns a [`MathError`] of type
81    ///   [`MismatchingModulus`](MathError::MismatchingModulus) if the
82    ///   moduli of the provided matrices mismatch.
83    pub fn tensor_product_safe(&self, other: &Self) -> Result<Self, MathError> {
84        if !self.compare_base(other) {
85            return Err(self.call_compare_base_error(other).unwrap());
86        }
87
88        let mut out = MatPolynomialRingZq::new(
89            self.get_num_rows() * other.get_num_rows(),
90            self.get_num_columns() * other.get_num_columns(),
91            self.get_mod(),
92        );
93
94        for i in 0..self.get_num_rows() {
95            for j in 0..self.get_num_columns() {
96                let entry: PolyOverZ = unsafe { self.get_entry_unchecked(i, j) };
97
98                if !entry.is_zero() {
99                    unsafe { set_matrix_window_mul(&mut out, i, j, entry, other) }
100                }
101            }
102        }
103
104        Ok(out)
105    }
106}
107
108/// This function sets a specific window of the provided matrix `out`
109/// according to the `scalar` multiple of `matrix`.
110///
111/// Sets the entries
112/// `[i*rows_other, j*columns_other]` up till `[row_left*(row_other +1), column_upper*(columns_other + 1)]`
113///
114/// Parameters:
115/// - `out`: the matrix in which the result is saved
116/// - `row_left`: defines the leftmost row of the set window
117/// - `column_upper`: defines the highest column of the set window
118/// - `scalar`: defines the value with which the part of the tensor product
119///   is calculated
120/// - `matrix`: the matrix with which the scalar is multiplied
121///   before setting the entries in `out`
122///
123/// Implicitly sets the entries of the matrix according to the definition
124/// of the tensor product.
125///
126/// # Security
127/// This function accesses memory directly without checking whether the memory is
128/// actually obtained by the matrix out.
129/// This means that this function should only be called wisely.
130/// If `row_left` or `row_upper` together with the length of the matrix exceeds the
131/// range of the matrix other memory could be overwritten.
132/// We included asserts to check whether this occurs, but we advise careful usage.
133unsafe fn set_matrix_window_mul(
134    out: &mut MatPolynomialRingZq,
135    row_left: i64,
136    column_upper: i64,
137    scalar: PolyOverZ,
138    matrix: &MatPolynomialRingZq,
139) {
140    let columns_other = matrix.get_num_columns();
141    let rows_other = matrix.get_num_rows();
142
143    assert!(row_left >= 0 && row_left + rows_other <= out.get_num_rows());
144    assert!(column_upper >= 0 && column_upper + columns_other <= out.get_num_columns());
145
146    for i_other in 0..rows_other {
147        for j_other in 0..columns_other {
148            unsafe {
149                fmpz_poly_mul(
150                    fmpz_poly_mat_entry(
151                        &out.matrix.matrix,
152                        row_left * rows_other + i_other,
153                        column_upper * columns_other + j_other,
154                    ),
155                    &scalar.poly,
156                    fmpz_poly_mat_entry(&matrix.matrix.matrix, i_other, j_other),
157                );
158                out.reduce_entry(
159                    row_left * rows_other + i_other,
160                    column_upper * columns_other + j_other,
161                );
162            }
163        }
164    }
165}
166
#[cfg(test)]
mod test_tensor {
    use crate::{
        integer_mod_q::{MatPolynomialRingZq, ModulusPolynomialRingZq},
        traits::{MatrixDimensions, Tensor},
    };
    use std::str::FromStr;

    /// Ensure that the dimensions of the tensor product are taken over correctly.
    #[test]
    fn dimensions_fit() {
        let mod_poly = ModulusPolynomialRingZq::from_str("3  1 2 1 mod 17").unwrap();
        let mat_1 = MatPolynomialRingZq::new(17, 13, &mod_poly);
        let mat_2 = MatPolynomialRingZq::new(3, 4, &mod_poly);

        let mat_3 = mat_1.tensor_product(&mat_2);

        assert_eq!(51, mat_3.get_num_rows());
        assert_eq!(52, mat_3.get_num_columns());
    }

    /// Ensure that the tensor works correctly with identity.
    #[test]
    fn identity() {
        let mod_poly =
            ModulusPolynomialRingZq::from_str(&format!("3  1 2 1 mod {}", u64::MAX)).unwrap();
        let identity = MatPolynomialRingZq::identity(2, 2, &mod_poly);
        let mat_1 = MatPolynomialRingZq::from_str(&format!(
            "[[1  1, 1  {}, 1  1],[0, 1  {}, 1  -1]] / 3  1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        let mat_2 = identity.tensor_product(&mat_1);
        let mat_3 = mat_1.tensor_product(&identity);

        let cmp_mat_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1  1, 1  {}, 1  1, 0, 0, 0],[0, 1  {}, 1  -1, 0, 0, 0],[0, 0, 0, 1  1, 1  {}, 1  1],[0, 0, 0, 0, 1  {}, 1  -1]] / 3  1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();
        let cmp_mat_3 = MatPolynomialRingZq::from_str(&format!(
            "[[1  1, 0, 1  {}, 0, 1  1, 0],[0, 1  1, 0, 1  {}, 0, 1  1],[0, 0, 1  {}, 0, 1  -1, 0],[0, 0, 0, 1  {}, 0, 1  -1]] / 3  1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MIN,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
        assert_eq!(cmp_mat_3, mat_3);
    }

    /// Ensure the tensor product works where one is a vector and the other is a matrix.
    #[test]
    fn vector_matrix() {
        let vector =
            MatPolynomialRingZq::from_str(&format!("[[1  1],[1  -1]] / 3  1 2 1 mod {}", u64::MAX))
                .unwrap();
        let mat_1 = MatPolynomialRingZq::from_str(&format!(
            "[[1  1, 1  {}, 1  1],[0, 1  {}, 1  -1]] / 3  1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();

        let mat_2 = vector.tensor_product(&mat_1);
        let mat_3 = mat_1.tensor_product(&vector);

        let cmp_mat_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1  1, 1  {}, 1  1],[0, 1  {}, 1  -1],[1  -1, 1  -{}, 1  -1],[0, 1  -{}, 1  1]] / 3  1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();
        let cmp_mat_3 = MatPolynomialRingZq::from_str(&format!(
            "[[1  1, 1  {}, 1  1],[1  -1, 1  -{}, 1  -1],[0, 1  {}, 1  -1],[0, 1  -{}, 1  1]] / 3  1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
        assert_eq!(cmp_mat_3, mat_3);
    }

    /// Ensure that the tensor product works correctly with two vectors.
    #[test]
    fn vector_vector() {
        let vec_1 =
            MatPolynomialRingZq::from_str(&format!("[[1  2],[1  1]] / 3  1 2 1 mod {}", u64::MAX))
                .unwrap();
        let vec_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1  {}],[1  {}]] / 3  1 2 1 mod {}",
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();

        let vec_3 = vec_1.tensor_product(&vec_2);
        let vec_4 = vec_2.tensor_product(&vec_1);

        let cmp_vec_3 = MatPolynomialRingZq::from_str(&format!(
            "[[1  {}],[1  {}],[1  {}],[1  {}]] / 3  1 2 1 mod {}",
            u64::MAX - 1,
            i64::MIN,
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();
        let cmp_vec_4 = MatPolynomialRingZq::from_str(&format!(
            "[[1  {}],[1  {}],[1  {}],[1  {}]] / 3  1 2 1 mod {}",
            u64::MAX - 1,
            (u64::MAX - 1) / 2,
            i64::MIN,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_vec_3, vec_3);
        assert_eq!(cmp_vec_4, vec_4);
    }

    /// Ensure that every block of the result is placed correctly if both factors have
    /// several rows and columns, all entries are non-zero, and the right factor is
    /// not square. This catches mixed-up row and column offsets.
    #[test]
    fn matrix_matrix() {
        let mat_1 =
            MatPolynomialRingZq::from_str("[[1  1, 1  2],[1  3, 1  4]] / 3  1 2 1 mod 17").unwrap();
        let mat_2 = MatPolynomialRingZq::from_str(
            "[[1  1, 1  2, 1  3],[1  4, 1  5, 1  6]] / 3  1 2 1 mod 17",
        )
        .unwrap();

        let mat_3 = mat_1.tensor_product(&mat_2);

        let cmp_mat_3 = MatPolynomialRingZq::from_str(
            "[[1  1, 1  2, 1  3, 1  2, 1  4, 1  6],[1  4, 1  5, 1  6, 1  8, 1  10, 1  12],[1  3, 1  6, 1  9, 1  4, 1  8, 1  12],[1  12, 1  15, 1  1, 1  16, 1  3, 1  7]] / 3  1 2 1 mod 17",
        )
        .unwrap();

        assert_eq!(4, mat_3.get_num_rows());
        assert_eq!(6, mat_3.get_num_columns());
        assert_eq!(cmp_mat_3, mat_3);
    }

    /// Ensures that the tensor product works for higher degree polynomials.
    #[test]
    fn higher_degree() {
        let higher_degree = MatPolynomialRingZq::from_str(&format!(
            "[[1  1, 2  0 1, 2  1 1]] / 3  1 2 1 mod {}",
            u64::MAX
        ))
        .unwrap();
        let mat_1 = MatPolynomialRingZq::from_str(&format!(
            "[[1  1, 1  {}, 2  1 {}]] / 3  1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        let mat_2 = higher_degree.tensor_product(&mat_1);

        let cmp_mat_2 = MatPolynomialRingZq::from_str(&format!(
            "[[1  1, 1  {}, 2  1 {}, 2  0 1, 2  0 {}, 3  0 1 {}, 2  1 1, 2  {} {}, 3  1 {} {}]] / 3  1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MAX,
            i64::MIN + 1,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
    }

    /// Ensure that the tensor product panics, if the moduli mismatch.
    #[test]
    #[should_panic]
    fn moduli_mismatch_panic() {
        let mod_poly_1 = ModulusPolynomialRingZq::from_str("3  1 2 1 mod 17").unwrap();
        let mod_poly_2 = ModulusPolynomialRingZq::from_str("3  1 2 1 mod 16").unwrap();
        let mat_1 = MatPolynomialRingZq::new(17, 13, &mod_poly_1);
        let mat_2 = MatPolynomialRingZq::new(3, 4, &mod_poly_2);

        let _ = mat_1.tensor_product(&mat_2);
    }

    /// Ensure that the safe version of the tensor product returns an error, if the moduli mismatch.
    #[test]
    fn moduli_mismatch_error() {
        let mod_poly_1 = ModulusPolynomialRingZq::from_str("3  1 2 1 mod 17").unwrap();
        let mod_poly_2 = ModulusPolynomialRingZq::from_str("3  1 2 1 mod 16").unwrap();
        let mat_1 = MatPolynomialRingZq::new(17, 13, &mod_poly_1);
        let mat_2 = MatPolynomialRingZq::new(3, 4, &mod_poly_2);

        assert!(mat_1.tensor_product_safe(&mat_2).is_err());
    }

    /// Ensure that the safe version of the tensor product returns an error, if the
    /// moduli share the same `q` but differ in their polynomial.
    #[test]
    fn modulus_polynomial_mismatch_error() {
        let mod_poly_1 = ModulusPolynomialRingZq::from_str("3  1 2 1 mod 17").unwrap();
        let mod_poly_2 = ModulusPolynomialRingZq::from_str("3  1 0 1 mod 17").unwrap();
        let mat_1 = MatPolynomialRingZq::new(2, 3, &mod_poly_1);
        let mat_2 = MatPolynomialRingZq::new(3, 4, &mod_poly_2);

        assert!(mat_1.tensor_product_safe(&mat_2).is_err());
    }
}