qfall_math/rational/mat_q/
tensor.rs

// Copyright © 2023 Marvin Beckmann
//
// This file is part of qFALL-math.
//
// qFALL-math is free software: you can redistribute it and/or modify it under
// the terms of the Mozilla Public License Version 2.0 as published by the
// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.

//! This module contains the implementation of the `Tensor` trait, i.e. the
//! tensor (Kronecker) product, for `MatQ`.

11use flint_sys::fmpq_mat::fmpq_mat_kronecker_product;
12
13use super::MatQ;
14use crate::traits::{MatrixDimensions, Tensor};
15
16impl Tensor for MatQ {
17    /// Computes the tensor product of `self` with `other`.
18    ///
19    /// Parameters:
20    /// - `other`: the value with which the tensor product is computed.
21    ///
22    /// Returns the tensor product of `self` with `other`.
23    ///
24    /// # Examples
25    /// ```
26    /// use qfall_math::rational::MatQ;
27    /// use qfall_math::traits::Tensor;
28    /// use std::str::FromStr;
29    ///
30    /// let mat_1 = MatQ::from_str("[[1, 1],[2/3, 2/3]]").unwrap();
31    /// let mat_2 = MatQ::from_str("[[1, 2],[3, 4]]").unwrap();
32    ///
33    /// let mat_12 = mat_1.tensor_product(&mat_2);
34    /// let mat_21 = mat_2.tensor_product(&mat_1);
35    ///
36    /// let res_12 = "[[1, 2, 1, 2],[3, 4, 3, 4],[2/3, 4/3, 2/3, 4/3],[2, 8/3, 2, 8/3]]";
37    /// let res_21 = "[[1, 1, 2, 2],[2/3, 2/3, 4/3, 4/3],[3, 3, 4, 4],[2, 2, 8/3, 8/3]]";
38    /// assert_eq!(mat_12, MatQ::from_str(res_12).unwrap());
39    /// assert_eq!(mat_21, MatQ::from_str(res_21).unwrap());
40    /// ```
41    fn tensor_product(&self, other: &Self) -> Self {
42        let mut out = MatQ::new(
43            self.get_num_rows() * other.get_num_rows(),
44            self.get_num_columns() * other.get_num_columns(),
45        );
46
47        unsafe { fmpq_mat_kronecker_product(&mut out.matrix, &self.matrix, &other.matrix) };
48
49        out
50    }
51}
52
#[cfg(test)]
mod test_tensor {
    use crate::{
        rational::MatQ,
        traits::{MatrixDimensions, Tensor},
    };
    use std::str::FromStr;

    /// Checks that the result of a `17 x 13` and a `3 x 4` matrix has
    /// `17 * 3` rows and `13 * 4` columns.
    #[test]
    fn dimensions_fit() {
        let left = MatQ::new(17, 13);
        let right = MatQ::new(3, 4);

        let product = left.tensor_product(&right);

        assert_eq!(51, product.get_num_rows());
        assert_eq!(52, product.get_num_columns());
    }

    /// Checks the tensor product with the `2 x 2` identity matrix from the
    /// left and from the right, using entries at the bounds of the 64-bit
    /// integer types.
    #[test]
    fn identity() {
        let id = MatQ::identity(2, 2);
        let mat = MatQ::from_str(&format!(
            "[[1, 2/{u}, 1],[0, {i}, -1]]",
            u = u64::MAX,
            i = i64::MIN
        ))
        .unwrap();

        let id_times_mat = id.tensor_product(&mat);
        let mat_times_id = mat.tensor_product(&id);

        let expected_id_times_mat = MatQ::from_str(&format!(
            "[[1, 2/{u}, 1, 0, 0, 0],[0, {i}, -1, 0, 0, 0],[0, 0, 0, 1, 2/{u}, 1],[0, 0, 0, 0, {i}, -1]]",
            u = u64::MAX,
            i = i64::MIN
        ))
        .unwrap();
        let expected_mat_times_id = MatQ::from_str(&format!(
            "[[1, 0, 2/{u}, 0, 1, 0],[0, 1, 0, 2/{u}, 0, 1],[0, 0, {i}, 0, -1, 0],[0, 0, 0, {i}, 0, -1]]",
            u = u64::MAX,
            i = i64::MIN
        ))
        .unwrap();

        assert_eq!(expected_id_times_mat, id_times_mat);
        assert_eq!(expected_mat_times_id, mat_times_id);
    }

    /// Checks the tensor product between a column vector and a matrix, in
    /// both orders.
    #[test]
    fn vector_matrix() {
        let column = MatQ::from_str("[[1/3],[-1]]").unwrap();
        let mat = MatQ::from_str(&format!(
            "[[1, {u}, 1],[0, {i}, -1]]",
            u = u64::MAX,
            i = i64::MAX
        ))
        .unwrap();

        let column_times_mat = column.tensor_product(&mat);
        let mat_times_column = mat.tensor_product(&column);

        let expected_column_times_mat = MatQ::from_str(&format!(
            "[[1/3, {u}/3, 1/3],[0, {i}/3, -1/3],[-1, -{u}, -1],[0, -{i}, 1]]",
            u = u64::MAX,
            i = i64::MAX
        ))
        .unwrap();
        let expected_mat_times_column = MatQ::from_str(&format!(
            "[[1/3, {u}/3, 1/3],[-1, -{u}, -1],[0, {i}/3, -1/3],[0, -{i}, 1]]",
            u = u64::MAX,
            i = i64::MAX
        ))
        .unwrap();

        assert_eq!(expected_column_times_mat, column_times_mat);
        assert_eq!(expected_mat_times_column, mat_times_column);
    }

    /// Checks the tensor product of two column vectors, in both orders.
    #[test]
    fn vector_vector() {
        let even_u = u64::MAX - 1;
        let half_u = even_u / 2;
        let half_i = i64::MIN / 2;

        let small = MatQ::from_str("[[2],[1]]").unwrap();
        let large = MatQ::from_str(&format!("[[{hu}],[{hi}]]", hu = half_u, hi = half_i)).unwrap();

        let small_times_large = small.tensor_product(&large);
        let large_times_small = large.tensor_product(&small);

        let expected_small_times_large = MatQ::from_str(&format!(
            "[[{eu}],[{mi}],[{hu}],[{hi}]]",
            eu = even_u,
            mi = i64::MIN,
            hu = half_u,
            hi = half_i
        ))
        .unwrap();
        let expected_large_times_small = MatQ::from_str(&format!(
            "[[{eu}],[{hu}],[{mi}],[{hi}]]",
            eu = even_u,
            hu = half_u,
            mi = i64::MIN,
            hi = half_i
        ))
        .unwrap();

        assert_eq!(expected_small_times_large, small_times_large);
        assert_eq!(expected_large_times_small, large_times_small);
    }
}