qfall_math/integer_mod_q/mat_zq/
tensor.rs

1// Copyright © 2023 Marvin Beckmann
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module contains the implementation of the `tensor` product.
10
11use super::MatZq;
12use crate::{
13    error::MathError,
14    traits::{CompareBase, MatrixDimensions, Tensor},
15};
16use flint_sys::{fmpz_mat::fmpz_mat_kronecker_product, fmpz_mod_mat::_fmpz_mod_mat_reduce};
17
18impl Tensor for MatZq {
19    /// Computes the tensor product of `self` with `other`.
20    ///
21    /// Parameters:
22    /// - `other`: the value with which the tensor product is computed.
23    ///
24    /// Returns the tensor product of `self` with `other` and panics if the
25    /// moduli of the provided matrices mismatch.
26    ///
27    /// # Examples
28    /// ```
29    /// use qfall_math::integer_mod_q::MatZq;
30    /// use qfall_math::traits::Tensor;
31    /// use std::str::FromStr;
32    ///
33    /// let mat_1 = MatZq::from_str("[[1, 1],[2, 2]] mod 7").unwrap();
34    /// let mat_2 = MatZq::from_str("[[1, 2],[3, 4]] mod 7").unwrap();
35    ///
36    /// let mat_ab = mat_1.tensor_product(&mat_2);
37    /// let mat_ba = mat_2.tensor_product(&mat_1);
38    ///
39    /// let res_ab = "[[1, 2, 1, 2],[3, 4, 3, 4],[2, 4, 2, 4],[6, 1, 6, 1]] mod 7";
40    /// let res_ba = "[[1, 1, 2, 2],[2, 2, 4, 4],[3, 3, 4, 4],[6, 6, 1, 1]] mod 7";
41    /// assert_eq!(mat_ab, MatZq::from_str(res_ab).unwrap());
42    /// assert_eq!(mat_ba, MatZq::from_str(res_ba).unwrap());
43    /// ```
44    ///
45    /// # Panics ...
46    /// - if the moduli of both matrices mismatch.
47    ///   Use [`tensor_product_safe`](crate::integer_mod_q::MatZq::tensor_product_safe) to get an error instead.
48    fn tensor_product(&self, other: &Self) -> Self {
49        self.tensor_product_safe(other).unwrap()
50    }
51}
52
53impl MatZq {
54    /// Computes the tensor product of `self` with `other`.
55    ///
56    /// Parameters:
57    /// - `other`: the value with which the tensor product is computed.
58    ///
59    /// Returns the tensor product of `self` with `other` or an error if the
60    /// moduli of the provided matrices mismatch.
61    ///
62    /// # Examples
63    /// ```
64    /// use qfall_math::integer_mod_q::MatZq;
65    /// use std::str::FromStr;
66    ///
67    /// let mat_1 = MatZq::from_str("[[1, 1],[2, 2]] mod 7").unwrap();
68    /// let mat_2 = MatZq::from_str("[[1, 2],[3, 4]] mod 7").unwrap();
69    ///
70    /// let mat_ab = mat_1.tensor_product_safe(&mat_2).unwrap();
71    /// let mat_ba = mat_2.tensor_product_safe(&mat_1).unwrap();
72    ///
73    /// let res_ab = "[[1, 2, 1, 2],[3, 4, 3, 4],[2, 4, 2, 4],[6, 1, 6, 1]] mod 7";
74    /// let res_ba = "[[1, 1, 2, 2],[2, 2, 4, 4],[3, 3, 4, 4],[6, 6, 1, 1]] mod 7";
75    /// assert_eq!(mat_ab, MatZq::from_str(res_ab).unwrap());
76    /// assert_eq!(mat_ba, MatZq::from_str(res_ba).unwrap());
77    /// ```
78    ///
79    /// # Errors and Failures
80    /// - Returns a [`MathError`] of type
81    ///   [`MismatchingModulus`](MathError::MismatchingModulus) if the
82    ///   moduli of the provided matrices mismatch.
83    pub fn tensor_product_safe(&self, other: &Self) -> Result<Self, MathError> {
84        if !self.compare_base(other) {
85            return Err(self.call_compare_base_error(other).unwrap());
86        }
87
88        let mut out = MatZq::new(
89            self.get_num_rows() * other.get_num_rows(),
90            self.get_num_columns() * other.get_num_columns(),
91            self.get_mod(),
92        );
93
94        unsafe {
95            fmpz_mat_kronecker_product(
96                &mut out.matrix.mat[0],
97                &self.matrix.mat[0],
98                &other.matrix.mat[0],
99            )
100        };
101
102        unsafe { _fmpz_mod_mat_reduce(&mut out.matrix) }
103
104        Ok(out)
105    }
106}
107
#[cfg(test)]
mod test_tensor {
    use crate::{
        integer_mod_q::MatZq,
        traits::{CompareBase, MatrixDimensions, Tensor},
    };
    use std::str::FromStr;

    /// Ensure that the dimensions of the tensor product are taken over correctly.
    #[test]
    fn dimensions_fit() {
        let mat_1 = MatZq::new(17, 13, 13);
        let mat_2 = MatZq::new(3, 4, 13);

        let mat_3 = mat_1.tensor_product(&mat_2);
        let mat_3_safe = mat_1.tensor_product_safe(&mat_2).unwrap();

        assert_eq!(51, mat_3.get_num_rows());
        assert_eq!(52, mat_3.get_num_columns());
        assert_eq!(&mat_3, &mat_3_safe);
    }

    /// Ensure that the tensor works correctly with identity.
    #[test]
    fn identity() {
        let identity = MatZq::from_str(&format!("[[1, 0],[0, 1]] mod {}", u128::MAX)).unwrap();
        let mat_1 = MatZq::from_str(&format!(
            "[[1, {}, 1],[0, {}, -1]] mod {}",
            u64::MAX,
            i64::MIN,
            u128::MAX
        ))
        .unwrap();

        let mat_2 = identity.tensor_product(&mat_1);
        let mat_3 = mat_1.tensor_product(&identity);
        let mat_2_safe = identity.tensor_product_safe(&mat_1).unwrap();
        let mat_3_safe = mat_1.tensor_product_safe(&identity).unwrap();

        let cmp_mat_2 = MatZq::from_str(&format!(
            "[[1, {}, 1, 0, 0, 0], \
              [0, {}, -1, 0, 0, 0], \
              [0, 0, 0, 1, {}, 1], \
              [0, 0, 0, 0, {}, -1]] mod {}",
            u64::MAX,
            i64::MIN,
            u64::MAX,
            i64::MIN,
            u128::MAX
        ))
        .unwrap();
        let cmp_mat_3 = MatZq::from_str(&format!(
            "[[1, 0, {}, 0, 1, 0], \
              [0, 1, 0, {}, 0, 1], \
              [0, 0, {}, 0, -1, 0], \
              [0, 0, 0, {}, 0, -1]] mod {}",
            u64::MAX,
            u64::MAX,
            i64::MIN,
            i64::MIN,
            u128::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
        assert_eq!(cmp_mat_3, mat_3);
        assert_eq!(cmp_mat_2, mat_2_safe);
        assert_eq!(cmp_mat_3, mat_3_safe);
    }

    /// Ensure the tensor product works where one is a vector and the other is a matrix.
    #[test]
    fn vector_matrix() {
        let vector = MatZq::from_str(&format!("[[1],[-1]] mod {}", u128::MAX)).unwrap();
        let mat_1 = MatZq::from_str(&format!(
            "[[1, {}, 1],[0, {}, -1]] mod {}",
            u64::MAX,
            i64::MAX,
            u128::MAX
        ))
        .unwrap();

        let mat_2 = vector.tensor_product(&mat_1);
        let mat_3 = mat_1.tensor_product(&vector);
        let mat_2_safe = vector.tensor_product_safe(&mat_1).unwrap();
        let mat_3_safe = mat_1.tensor_product_safe(&vector).unwrap();

        let cmp_mat_2 = MatZq::from_str(&format!(
            "[[1, {}, 1],[0, {}, -1],[-1, -{}, -1],[0, -{}, 1]] mod {}",
            u64::MAX,
            i64::MAX,
            u64::MAX,
            i64::MAX,
            u128::MAX
        ))
        .unwrap();
        let cmp_mat_3 = MatZq::from_str(&format!(
            "[[1, {}, 1],[-1, -{}, -1],[0, {}, -1],[0, -{}, 1]] mod {}",
            u64::MAX,
            u64::MAX,
            i64::MAX,
            i64::MAX,
            u128::MAX
        ))
        .unwrap();

        assert_eq!(cmp_mat_2, mat_2);
        assert_eq!(cmp_mat_3, mat_3);
        assert_eq!(cmp_mat_2, mat_2_safe);
        assert_eq!(cmp_mat_3, mat_3_safe);
    }

    /// Ensure that the tensor product works correctly with two vectors.
    #[test]
    fn vector_vector() {
        let vec_1 = MatZq::from_str(&format!("[[2],[1]] mod {}", u128::MAX)).unwrap();
        let vec_2 = MatZq::from_str(&format!(
            "[[{}],[{}]] mod {}",
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u128::MAX
        ))
        .unwrap();

        let vec_3 = vec_1.tensor_product(&vec_2);
        let vec_4 = vec_2.tensor_product(&vec_1);
        let vec_3_safe = vec_1.tensor_product_safe(&vec_2).unwrap();
        let vec_4_safe = vec_2.tensor_product_safe(&vec_1).unwrap();

        let cmp_vec_3 = MatZq::from_str(&format!(
            "[[{}],[{}],[{}],[{}]] mod {}",
            u64::MAX - 1,
            i64::MIN,
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u128::MAX
        ))
        .unwrap();
        let cmp_vec_4 = MatZq::from_str(&format!(
            "[[{}],[{}],[{}],[{}]] mod {}",
            u64::MAX - 1,
            (u64::MAX - 1) / 2,
            i64::MIN,
            i64::MIN / 2,
            u128::MAX
        ))
        .unwrap();

        assert_eq!(cmp_vec_3, vec_3);
        assert_eq!(cmp_vec_4, vec_4);
        assert_eq!(cmp_vec_3, vec_3_safe);
        assert_eq!(cmp_vec_4, vec_4_safe);
    }

    /// Ensure that entries are reduced by the modulus.
    #[test]
    fn entries_reduced() {
        let mat_1 = MatZq::from_str(&format!("[[1, 2],[3, 4]] mod {}", u64::MAX - 58)).unwrap();
        let mat_2 = MatZq::from_str(&format!("[[1, 58],[0, -1]] mod {}", u64::MAX - 58)).unwrap();

        let mat_3 = mat_1.tensor_product(&mat_2);
        let mat_3_safe = mat_1.tensor_product_safe(&mat_2).unwrap();

        let mat_3_cmp = MatZq::from_str(&format!(
            "[[1, 58, 2, 116],[0, -1, 0, -2],[3, 174, 4, 232],[0, -3, 0, -4]] mod {}",
            u64::MAX - 58
        ))
        .unwrap();
        assert_eq!(mat_3_cmp, mat_3);
        assert_eq!(mat_3_cmp, mat_3_safe);
    }

    /// Ensure that the result is defined over the same modulus as its inputs.
    #[test]
    fn result_keeps_modulus() {
        let mat_1 = MatZq::from_str("[[1, 2],[3, 4]] mod 17").unwrap();
        let mat_2 = MatZq::from_str("[[5],[6]] mod 17").unwrap();

        let mat_3 = mat_1.tensor_product(&mat_2);
        let mat_3_safe = mat_1.tensor_product_safe(&mat_2).unwrap();

        assert!(mat_3.compare_base(&mat_1));
        assert!(mat_3.compare_base(&mat_2));
        assert!(mat_3_safe.compare_base(&mat_1));
        assert!(mat_3_safe.compare_base(&mat_2));
    }

    /// Ensure that a row and a column vector yield their outer product, reduced
    /// by a small modulus that the raw integer products exceed.
    #[test]
    fn outer_product_small_modulus() {
        let row = MatZq::from_str("[[2, 3]] mod 5").unwrap();
        let column = MatZq::from_str("[[4],[1]] mod 5").unwrap();

        // [[2 * 4, 3 * 4], [2 * 1, 3 * 1]] = [[8, 12], [2, 3]] ≡ [[3, 2], [2, 3]] mod 5
        let cmp = MatZq::from_str("[[3, 2],[2, 3]] mod 5").unwrap();

        assert_eq!(cmp, row.tensor_product(&column));
        assert_eq!(cmp, column.tensor_product(&row));
        assert_eq!(cmp, row.tensor_product_safe(&column).unwrap());
        assert_eq!(cmp, column.tensor_product_safe(&row).unwrap());
    }

    /// Ensure that tensor panics if the moduli mismatch.
    #[test]
    #[should_panic]
    fn mismatching_moduli_tensor_product() {
        let mat_1 = MatZq::new(1, 2, u64::MAX);
        let mat_2 = MatZq::new(1, 2, u64::MAX - 58);

        let _ = mat_1.tensor_product(&mat_2);
    }

    /// Ensure that tensor_product_safe returns an error if the moduli mismatch.
    #[test]
    fn mismatching_moduli_tensor_product_safe() {
        let mat_1 = MatZq::new(1, 2, u64::MAX);
        let mat_2 = MatZq::new(1, 2, u64::MAX - 58);

        assert!(mat_1.tensor_product_safe(&mat_2).is_err());
        assert!(mat_2.tensor_product_safe(&mat_1).is_err());
    }
}