// qfall_math/integer/mat_z/vector/dot_product.rs

// Copyright © 2023 Niklas Siemer
//
// This file is part of qFALL-math.
//
// qFALL-math is free software: you can redistribute it and/or modify it under
// the terms of the Mozilla Public License Version 2.0 as published by the
// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.

//! This module includes functionality to compute the dot product of two vectors.

11use crate::error::MathError;
12use crate::integer::{MatZ, Z};
13use crate::traits::MatrixDimensions;
14use flint_sys::fmpz::fmpz_addmul;
15
16impl MatZ {
17    /// Returns the dot product of two vectors of type [`MatZ`].
18    ///
19    /// Parameters:
20    /// - `other`: specifies the other vector the dot product is calculated over
21    ///
22    /// Returns the resulting `dot_product` as a [`Z`] or an error
23    /// if the given [`MatZ`] instances aren't vectors or have different
24    /// numbers of entries.
25    ///
26    /// # Examples
27    /// ```
28    /// use qfall_math::integer::MatZ;
29    /// use std::str::FromStr;
30    /// # use qfall_math::integer::Z;
31    ///
32    /// let vec_1 = MatZ::from_str("[[1],[2],[3]]").unwrap();
33    /// let vec_2 = MatZ::from_str("[[1, 3, 2]]").unwrap();
34    ///
35    /// let dot_prod = vec_1.dot_product(&vec_2).unwrap();
36    ///
37    /// // 1*1 + 2*3 + 3*2 = 13
38    /// assert_eq!(Z::from(13), dot_prod);
39    /// ```
40    ///
41    /// # Errors and Failures
42    /// - Returns a [`MathError`] of type [`MathError::VectorFunctionCalledOnNonVector`] if
43    ///   the given [`MatZ`] instance is not a (row or column) vector.
44    /// - Returns a [`MathError`] of type [`MathError::MismatchingMatrixDimension`] if
45    ///   the given vectors have different lengths.
46    pub fn dot_product(&self, other: &Self) -> Result<Z, MathError> {
47        if !self.is_vector() {
48            return Err(MathError::VectorFunctionCalledOnNonVector(
49                String::from("dot_product"),
50                self.get_num_rows(),
51                self.get_num_columns(),
52            ));
53        } else if !other.is_vector() {
54            return Err(MathError::VectorFunctionCalledOnNonVector(
55                String::from("dot_product"),
56                other.get_num_rows(),
57                other.get_num_columns(),
58            ));
59        }
60
61        let self_entries = self.collect_entries();
62        let other_entries = other.collect_entries();
63
64        if self_entries.len() != other_entries.len() {
65            return Err(MathError::MismatchingMatrixDimension(format!(
66                "You called the function 'dot_product' for vectors of different lengths: {} and {}",
67                self_entries.len(),
68                other_entries.len()
69            )));
70        }
71
72        // calculate dot product of vectors
73        let mut result = Z::ZERO;
74        for i in 0..self_entries.len() {
75            // sets result = result + self.entry[i] * other.entry[i] without cloned Z element
76            unsafe { fmpz_addmul(&mut result.value, &self_entries[i], &other_entries[i]) }
77        }
78
79        Ok(result)
80    }
81}
82
#[cfg(test)]
mod test_dot_product {
    use super::{MatZ, Z};
    use crate::error::MathError;
    use std::str::FromStr;

    /// Check whether the dot product is calculated correctly for the combination:
    /// `self`: row vector, `other`: row vector
    #[test]
    fn row_with_row() {
        let vec_1 = MatZ::from_str("[[1, 2, -3]]").unwrap();
        let vec_2 = MatZ::from_str("[[1, 3, 2]]").unwrap();

        let dot_prod = vec_1.dot_product(&vec_2).unwrap();

        assert_eq!(dot_prod, Z::ONE);
    }

    /// Check whether the dot product is calculated correctly for the combination:
    /// `self`: column vector, `other`: column vector
    #[test]
    fn column_with_column() {
        let vec_1 = MatZ::from_str("[[1],[2],[-3]]").unwrap();
        let vec_2 = MatZ::from_str("[[1],[3],[2]]").unwrap();

        let dot_prod = vec_1.dot_product(&vec_2).unwrap();

        assert_eq!(dot_prod, Z::ONE);
    }

    /// Check whether the dot product is calculated correctly for the combination:
    /// `self`: row vector, `other`: column vector
    #[test]
    fn row_with_column() {
        let vec_1 = MatZ::from_str("[[1, 2, -3]]").unwrap();
        let vec_2 = MatZ::from_str("[[1],[3],[2]]").unwrap();

        let dot_prod = vec_1.dot_product(&vec_2).unwrap();

        assert_eq!(dot_prod, Z::ONE);
    }

    /// Check whether the dot product is calculated correctly for the combination:
    /// `self`: column vector, `other`: row vector
    #[test]
    fn column_with_row() {
        let vec_1 = MatZ::from_str("[[1],[2],[-3]]").unwrap();
        let vec_2 = MatZ::from_str("[[1, 3, 2]]").unwrap();

        let dot_prod = vec_1.dot_product(&vec_2).unwrap();

        assert_eq!(dot_prod, Z::ONE);
    }

    /// Check whether the dot product is calculated correctly with large numbers
    #[test]
    fn large_numbers() {
        let vec_1 = MatZ::from_str(&format!("[[1, -1, {}]]", i64::MAX)).unwrap();
        let vec_2 = MatZ::from_str(&format!("[[1, {}, 1]]", i64::MIN)).unwrap();
        let cmp = Z::from(-1) * Z::from(i64::MIN) + Z::from(i64::MAX) + Z::ONE;

        let dot_prod = vec_1.dot_product(&vec_2).unwrap();

        assert_eq!(dot_prod, cmp);
    }

    /// Check whether the dot product calculation on non vector instances
    /// yields the dedicated [`MathError::VectorFunctionCalledOnNonVector`] error
    #[test]
    fn non_vector_yield_error() {
        let vec = MatZ::from_str("[[1, 3, 2]]").unwrap();
        let mat = MatZ::from_str("[[1, 2],[2, 3],[-3, 4]]").unwrap();

        assert!(matches!(
            vec.dot_product(&mat),
            Err(MathError::VectorFunctionCalledOnNonVector(..))
        ));
        assert!(matches!(
            mat.dot_product(&vec),
            Err(MathError::VectorFunctionCalledOnNonVector(..))
        ));
        assert!(matches!(
            mat.dot_product(&mat),
            Err(MathError::VectorFunctionCalledOnNonVector(..))
        ));
        assert!(vec.dot_product(&vec).is_ok());
    }

    /// Check whether the dot product calculation on vectors of different
    /// lengths yields the dedicated [`MathError::MismatchingMatrixDimension`] error
    #[test]
    fn different_lengths_yield_error() {
        let vec_1 = MatZ::from_str("[[1, 2, 3]]").unwrap();
        let vec_2 = MatZ::from_str("[[1, 2, 3, 4]]").unwrap();

        assert!(matches!(
            vec_1.dot_product(&vec_2),
            Err(MathError::MismatchingMatrixDimension(_))
        ));
        assert!(matches!(
            vec_2.dot_product(&vec_1),
            Err(MathError::MismatchingMatrixDimension(_))
        ));
    }

    /// Check whether a length mismatch is also detected when a row vector
    /// is combined with a column vector
    #[test]
    fn different_lengths_row_with_column_yield_error() {
        let row = MatZ::from_str("[[1, 2, 3]]").unwrap();
        let column = MatZ::from_str("[[1],[2]]").unwrap();

        assert!(matches!(
            row.dot_product(&column),
            Err(MathError::MismatchingMatrixDimension(_))
        ));
        assert!(matches!(
            column.dot_product(&row),
            Err(MathError::MismatchingMatrixDimension(_))
        ));
    }
}