1use super::MatZq;
12use crate::{
13 error::MathError,
14 traits::{CompareBase, MatrixDimensions, Tensor},
15};
16use flint_sys::{fmpz_mat::fmpz_mat_kronecker_product, fmpz_mod_mat::_fmpz_mod_mat_reduce};
17
18impl Tensor for MatZq {
19 fn tensor_product(&self, other: &Self) -> Self {
49 self.tensor_product_safe(other).unwrap()
50 }
51}
52
53impl MatZq {
54 pub fn tensor_product_safe(&self, other: &Self) -> Result<Self, MathError> {
84 if !self.compare_base(other) {
85 return Err(self.call_compare_base_error(other).unwrap());
86 }
87
88 let mut out = MatZq::new(
89 self.get_num_rows() * other.get_num_rows(),
90 self.get_num_columns() * other.get_num_columns(),
91 self.get_mod(),
92 );
93
94 unsafe {
95 fmpz_mat_kronecker_product(
96 &mut out.matrix.mat[0],
97 &self.matrix.mat[0],
98 &other.matrix.mat[0],
99 )
100 };
101
102 unsafe { _fmpz_mod_mat_reduce(&mut out.matrix) }
103
104 Ok(out)
105 }
106}
107
#[cfg(test)]
mod test_tensor {
    use crate::{
        integer_mod_q::MatZq,
        traits::{MatrixDimensions, Tensor},
    };
    use std::str::FromStr;

    /// Computes `lhs ⊗ rhs` via both [`Tensor::tensor_product`] and
    /// [`MatZq::tensor_product_safe`] and returns `(trait_result, safe_result)`.
    fn both_products(lhs: &MatZq, rhs: &MatZq) -> (MatZq, MatZq) {
        let via_trait = lhs.tensor_product(rhs);
        let via_safe = lhs.tensor_product_safe(rhs).unwrap();
        (via_trait, via_safe)
    }

    /// Asserts that both tensor-product entry points yield `expected` for `lhs ⊗ rhs`.
    fn assert_tensor_eq(lhs: &MatZq, rhs: &MatZq, expected: &MatZq) {
        let (via_trait, via_safe) = both_products(lhs, rhs);
        assert_eq!(expected, &via_trait);
        assert_eq!(expected, &via_safe);
    }

    /// Ensures the result has `rows(lhs)·rows(rhs)` rows and `cols(lhs)·cols(rhs)` columns,
    /// and that both entry points agree.
    #[test]
    fn dimensions_fit() {
        let lhs = MatZq::new(17, 13, 13);
        let rhs = MatZq::new(3, 4, 13);

        let (via_trait, via_safe) = both_products(&lhs, &rhs);

        assert_eq!(51, via_trait.get_num_rows());
        assert_eq!(52, via_trait.get_num_columns());
        assert_eq!(&via_trait, &via_safe);
    }

    /// Ensures tensoring with the identity on either side produces the expected block layout,
    /// using entries at the extremes of the 64-bit integer ranges and modulus `u128::MAX`.
    #[test]
    fn identity() {
        let identity =
            MatZq::from_str(&format!("[[1, 0],[0, 1]] mod {q}", q = u128::MAX)).unwrap();
        let mat = MatZq::from_str(&format!(
            "[[1, {big}, 1],[0, {neg}, -1]] mod {q}",
            big = u64::MAX,
            neg = i64::MIN,
            q = u128::MAX
        ))
        .unwrap();

        let identity_left = MatZq::from_str(&format!(
            "[[1, {big}, 1, 0, 0, 0], \
             [0, {neg}, -1, 0, 0, 0], \
             [0, 0, 0, 1, {big}, 1], \
             [0, 0, 0, 0, {neg}, -1]] mod {q}",
            big = u64::MAX,
            neg = i64::MIN,
            q = u128::MAX
        ))
        .unwrap();
        let identity_right = MatZq::from_str(&format!(
            "[[1, 0, {big}, 0, 1, 0], \
             [0, 1, 0, {big}, 0, 1], \
             [0, 0, {neg}, 0, -1, 0], \
             [0, 0, 0, {neg}, 0, -1]] mod {q}",
            big = u64::MAX,
            neg = i64::MIN,
            q = u128::MAX
        ))
        .unwrap();

        assert_tensor_eq(&identity, &mat, &identity_left);
        assert_tensor_eq(&mat, &identity, &identity_right);
    }

    /// Ensures a column vector tensored with a matrix (in both orders) yields the expected
    /// stacked / interleaved result.
    #[test]
    fn vector_matrix() {
        let vector = MatZq::from_str(&format!("[[1],[-1]] mod {q}", q = u128::MAX)).unwrap();
        let mat = MatZq::from_str(&format!(
            "[[1, {wide}, 1],[0, {signed}, -1]] mod {q}",
            wide = u64::MAX,
            signed = i64::MAX,
            q = u128::MAX
        ))
        .unwrap();

        let vector_left = MatZq::from_str(&format!(
            "[[1, {wide}, 1],[0, {signed}, -1],[-1, -{wide}, -1],[0, -{signed}, 1]] mod {q}",
            wide = u64::MAX,
            signed = i64::MAX,
            q = u128::MAX
        ))
        .unwrap();
        let vector_right = MatZq::from_str(&format!(
            "[[1, {wide}, 1],[-1, -{wide}, -1],[0, {signed}, -1],[0, -{signed}, 1]] mod {q}",
            wide = u64::MAX,
            signed = i64::MAX,
            q = u128::MAX
        ))
        .unwrap();

        assert_tensor_eq(&vector, &mat, &vector_left);
        assert_tensor_eq(&mat, &vector, &vector_right);
    }

    /// Ensures the tensor product of two column vectors matches the expected entries
    /// in both operand orders.
    #[test]
    fn vector_vector() {
        let vec_1 = MatZq::from_str(&format!("[[2],[1]] mod {q}", q = u128::MAX)).unwrap();
        let vec_2 = MatZq::from_str(&format!(
            "[[{x}],[{y}]] mod {q}",
            x = (u64::MAX - 1) / 2,
            y = i64::MIN / 2,
            q = u128::MAX
        ))
        .unwrap();

        let vec_1_then_2 = MatZq::from_str(&format!(
            "[[{two_x}],[{two_y}],[{x}],[{y}]] mod {q}",
            two_x = u64::MAX - 1,
            two_y = i64::MIN,
            x = (u64::MAX - 1) / 2,
            y = i64::MIN / 2,
            q = u128::MAX
        ))
        .unwrap();
        let vec_2_then_1 = MatZq::from_str(&format!(
            "[[{two_x}],[{x}],[{two_y}],[{y}]] mod {q}",
            two_x = u64::MAX - 1,
            x = (u64::MAX - 1) / 2,
            two_y = i64::MIN,
            y = i64::MIN / 2,
            q = u128::MAX
        ))
        .unwrap();

        assert_tensor_eq(&vec_1, &vec_2, &vec_1_then_2);
        assert_tensor_eq(&vec_2, &vec_1, &vec_2_then_1);
    }

    /// Ensures products of reduced entries (e.g. `2 · (q - 1)`) are reduced modulo `q` again.
    #[test]
    fn entries_reduced() {
        let modulus = u64::MAX - 58;
        let mat_1 = MatZq::from_str(&format!("[[1, 2],[3, 4]] mod {q}", q = modulus)).unwrap();
        let mat_2 = MatZq::from_str(&format!("[[1, 58],[0, -1]] mod {q}", q = modulus)).unwrap();

        let expected = MatZq::from_str(&format!(
            "[[1, 58, 2, 116],[0, -1, 0, -2],[3, 174, 4, 232],[0, -3, 0, -4]] mod {q}",
            q = modulus
        ))
        .unwrap();

        assert_tensor_eq(&mat_1, &mat_2, &expected);
    }

    /// Ensures `tensor_product` panics if the moduli of the operands differ.
    #[test]
    #[should_panic]
    fn mismatching_moduli_tensor_product() {
        let lhs = MatZq::new(1, 2, u64::MAX);
        let rhs = MatZq::new(1, 2, u64::MAX - 58);

        let _ = lhs.tensor_product(&rhs);
    }

    /// Ensures `tensor_product_safe` returns an error in both operand orders if the moduli differ.
    #[test]
    fn mismatching_moduli_tensor_product_safe() {
        let lhs = MatZq::new(1, 2, u64::MAX);
        let rhs = MatZq::new(1, 2, u64::MAX - 58);

        assert!(lhs.tensor_product_safe(&rhs).is_err());
        assert!(rhs.tensor_product_safe(&lhs).is_err());
    }
}