qfall_math/integer/mat_z/
tensor.rs

1// Copyright © 2023 Marvin Beckmann
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module contains the implementation of the `tensor` product.
10
11use super::MatZ;
12use crate::traits::{MatrixDimensions, Tensor};
13use flint_sys::fmpz_mat::fmpz_mat_kronecker_product;
14
15impl Tensor for MatZ {
16    /// Computes the tensor product of `self` with `other`.
17    ///
18    /// Parameters:
19    /// - `other`: the value with which the tensor product is computed.
20    ///
21    /// Returns the tensor product of `self` with `other`.
22    ///
23    /// # Examples
24    /// ```
25    /// use qfall_math::integer::MatZ;
26    /// use qfall_math::traits::Tensor;
27    /// use std::str::FromStr;
28    ///
29    /// let mat_1 = MatZ::from_str("[[1, 1],[2, 2]]").unwrap();
30    /// let mat_2 = MatZ::from_str("[[1, 2],[3, 4]]").unwrap();
31    ///
32    /// let mat_ab = mat_1.tensor_product(&mat_2);
33    /// let mat_ba = mat_2.tensor_product(&mat_1);
34    ///
35    /// let res_ab = "[[1, 2, 1, 2],[3, 4, 3, 4],[2, 4, 2, 4],[6, 8, 6, 8]]";
36    /// let res_ba = "[[1, 1, 2, 2],[2, 2, 4, 4],[3, 3, 4, 4],[6, 6, 8, 8]]";
37    /// assert_eq!(mat_ab, MatZ::from_str(res_ab).unwrap());
38    /// assert_eq!(mat_ba, MatZ::from_str(res_ba).unwrap());
39    /// ```
40    fn tensor_product(&self, other: &Self) -> Self {
41        let mut out = MatZ::new(
42            self.get_num_rows() * other.get_num_rows(),
43            self.get_num_columns() * other.get_num_columns(),
44        );
45
46        unsafe { fmpz_mat_kronecker_product(&mut out.matrix, &self.matrix, &other.matrix) };
47
48        out
49    }
50}
#[cfg(test)]
mod test_tensor {
    use crate::{
        integer::MatZ,
        traits::{MatrixDimensions, Tensor},
    };
    use std::str::FromStr;

    /// Checks that the product of an `m x n` and a `p x q` matrix
    /// has `m * p` rows and `n * q` columns.
    #[test]
    fn dimensions_fit() {
        let left = MatZ::new(17, 13);
        let right = MatZ::new(3, 4);

        let product = left.tensor_product(&right);

        assert_eq!(
            (51, 52),
            (product.get_num_rows(), product.get_num_columns())
        );
    }

    /// Checks that tensoring with the `2 x 2` identity from either side
    /// produces the expected block layout, also for entries at the
    /// `u64`/`i64` boundaries.
    #[test]
    fn identity() {
        let big = u64::MAX;
        let small = i64::MIN;

        let id = MatZ::identity(2, 2);
        let mat = MatZ::from_str(&format!("[[1, {}, 1],[0, {}, -1]]", big, small)).unwrap();

        let expected_id_mat = MatZ::from_str(&format!(
            "[[1, {}, 1, 0, 0, 0],[0, {}, -1, 0, 0, 0],[0, 0, 0, 1, {}, 1],[0, 0, 0, 0, {}, -1]]",
            big, small, big, small
        ))
        .unwrap();
        let expected_mat_id = MatZ::from_str(&format!(
            "[[1, 0, {}, 0, 1, 0],[0, 1, 0, {}, 0, 1],[0, 0, {}, 0, -1, 0],[0, 0, 0, {}, 0, -1]]",
            big, big, small, small
        ))
        .unwrap();

        assert_eq!(expected_id_mat, id.tensor_product(&mat));
        assert_eq!(expected_mat_id, mat.tensor_product(&id));
    }

    /// Checks the product of a column vector with a matrix in both orders,
    /// including sign propagation of large entries.
    #[test]
    fn vector_matrix() {
        let big = u64::MAX;
        let large = i64::MAX;

        let column = MatZ::from_str("[[1],[-1]]").unwrap();
        let mat = MatZ::from_str(&format!("[[1, {}, 1],[0, {}, -1]]", big, large)).unwrap();

        let expected_col_mat = MatZ::from_str(&format!(
            "[[1, {}, 1],[0, {}, -1],[-1, -{}, -1],[0, -{}, 1]]",
            big, large, big, large
        ))
        .unwrap();
        let expected_mat_col = MatZ::from_str(&format!(
            "[[1, {}, 1],[-1, -{}, -1],[0, {}, -1],[0, -{}, 1]]",
            big, big, large, large
        ))
        .unwrap();

        assert_eq!(expected_col_mat, column.tensor_product(&mat));
        assert_eq!(expected_mat_col, mat.tensor_product(&column));
    }

    /// Checks the product of two column vectors in both orders, where
    /// doubling the entries reaches the `u64`/`i64` boundaries exactly.
    #[test]
    fn vector_vector() {
        let half_big = (u64::MAX - 1) / 2;
        let half_small = i64::MIN / 2;

        let scales = MatZ::from_str("[[2],[1]]").unwrap();
        let halves = MatZ::from_str(&format!("[[{}],[{}]]", half_big, half_small)).unwrap();

        let expected_scales_halves = MatZ::from_str(&format!(
            "[[{}],[{}],[{}],[{}]]",
            u64::MAX - 1,
            i64::MIN,
            half_big,
            half_small
        ))
        .unwrap();
        let expected_halves_scales = MatZ::from_str(&format!(
            "[[{}],[{}],[{}],[{}]]",
            u64::MAX - 1,
            half_big,
            i64::MIN,
            half_small
        ))
        .unwrap();

        assert_eq!(expected_scales_halves, scales.tensor_product(&halves));
        assert_eq!(expected_halves_scales, halves.tensor_product(&scales));
    }
}