qfall_math/integer/mat_z/vector/
dot_product.rs1use crate::error::MathError;
12use crate::integer::{MatZ, Z};
13use crate::traits::MatrixDimensions;
14use flint_sys::fmpz::fmpz_addmul;
15
16impl MatZ {
17 pub fn dot_product(&self, other: &Self) -> Result<Z, MathError> {
47 if !self.is_vector() {
48 return Err(MathError::VectorFunctionCalledOnNonVector(
49 String::from("dot_product"),
50 self.get_num_rows(),
51 self.get_num_columns(),
52 ));
53 } else if !other.is_vector() {
54 return Err(MathError::VectorFunctionCalledOnNonVector(
55 String::from("dot_product"),
56 other.get_num_rows(),
57 other.get_num_columns(),
58 ));
59 }
60
61 let self_entries = self.collect_entries();
62 let other_entries = other.collect_entries();
63
64 if self_entries.len() != other_entries.len() {
65 return Err(MathError::MismatchingMatrixDimension(format!(
66 "You called the function 'dot_product' for vectors of different lengths: {} and {}",
67 self_entries.len(),
68 other_entries.len()
69 )));
70 }
71
72 let mut result = Z::ZERO;
74 for i in 0..self_entries.len() {
75 unsafe { fmpz_addmul(&mut result.value, &self_entries[i], &other_entries[i]) }
77 }
78
79 Ok(result)
80 }
81}
82
#[cfg(test)]
mod test_dot_product {
    use super::{MatZ, Z};
    use std::str::FromStr;

    /// Two row vectors of equal length: 1*1 + 2*3 + (-3)*2 = 1.
    #[test]
    fn row_with_row() {
        let vec_1 = MatZ::from_str("[[1, 2, -3]]").unwrap();
        let vec_2 = MatZ::from_str("[[1, 3, 2]]").unwrap();

        let dot_prod = vec_1.dot_product(&vec_2).unwrap();

        assert_eq!(dot_prod, Z::ONE);
    }

    /// Two column vectors of equal length yield the same result as two rows.
    #[test]
    fn column_with_column() {
        let vec_1 = MatZ::from_str("[[1],[2],[-3]]").unwrap();
        let vec_2 = MatZ::from_str("[[1],[3],[2]]").unwrap();

        let dot_prod = vec_1.dot_product(&vec_2).unwrap();

        assert_eq!(dot_prod, Z::ONE);
    }

    /// A row vector can be combined with a column vector of the same length.
    #[test]
    fn row_with_column() {
        let vec_1 = MatZ::from_str("[[1, 2, -3]]").unwrap();
        let vec_2 = MatZ::from_str("[[1],[3],[2]]").unwrap();

        let dot_prod = vec_1.dot_product(&vec_2).unwrap();

        assert_eq!(dot_prod, Z::ONE);
    }

    /// A column vector can be combined with a row vector of the same length.
    #[test]
    fn column_with_row() {
        let vec_1 = MatZ::from_str("[[1],[2],[-3]]").unwrap();
        let vec_2 = MatZ::from_str("[[1, 3, 2]]").unwrap();

        let dot_prod = vec_1.dot_product(&vec_2).unwrap();

        assert_eq!(dot_prod, Z::ONE);
    }

    /// A vector dotted with itself yields the sum of its squared entries:
    /// 1^2 + 3^2 + 2^2 = 14.
    #[test]
    fn self_dot_product() {
        let vec = MatZ::from_str("[[1, 3, 2]]").unwrap();

        let dot_prod = vec.dot_product(&vec).unwrap();

        assert_eq!(dot_prod, Z::from(14));
    }

    /// Entries at the `i64` boundaries produce a result outside the `i64` range,
    /// which must still be computed exactly.
    #[test]
    fn large_numbers() {
        let vec_1 = MatZ::from_str(&format!("[[1, -1, {}]]", i64::MAX)).unwrap();
        let vec_2 = MatZ::from_str(&format!("[[1, {}, 1]]", i64::MIN)).unwrap();
        let cmp = Z::from(-1) * Z::from(i64::MIN) + Z::from(i64::MAX) + Z::ONE;

        let dot_prod = vec_1.dot_product(&vec_2).unwrap();

        assert_eq!(dot_prod, cmp);
    }

    /// If either operand is not a vector, an error is returned, regardless of
    /// which side the non-vector is on.
    #[test]
    fn non_vector_yield_error() {
        let vec = MatZ::from_str("[[1, 3, 2]]").unwrap();
        let mat = MatZ::from_str("[[1, 2],[2, 3],[-3, 4]]").unwrap();

        assert!(vec.dot_product(&mat).is_err());
        assert!(mat.dot_product(&vec).is_err());
        assert!(mat.dot_product(&mat).is_err());
        assert!(vec.dot_product(&vec).is_ok());
    }

    /// Row vectors with a different number of entries are rejected.
    #[test]
    fn different_lengths_yield_error() {
        let vec_1 = MatZ::from_str("[[1, 2, 3]]").unwrap();
        let vec_2 = MatZ::from_str("[[1, 2, 3, 4]]").unwrap();

        assert!(vec_1.dot_product(&vec_2).is_err());
        assert!(vec_2.dot_product(&vec_1).is_err());
    }

    /// Length mismatches are detected even when the orientations differ, so
    /// mixing row and column vectors does not bypass the length check.
    #[test]
    fn different_lengths_mixed_orientation_yield_error() {
        let row = MatZ::from_str("[[1, 2, 3]]").unwrap();
        let column = MatZ::from_str("[[1],[2],[3],[4]]").unwrap();

        assert!(row.dot_product(&column).is_err());
        assert!(column.dot_product(&row).is_err());
    }
}