qfall_math/integer/poly_over_z/
dot_product.rs1use crate::integer::PolyOverZ;
12use crate::{error::MathError, integer::Z};
13use flint_sys::fmpz::{fmpz_add, fmpz_mul};
14use flint_sys::fmpz_poly::fmpz_poly_get_coeff_fmpz;
15
16impl PolyOverZ {
17 pub fn dot_product(&self, other: &Self) -> Result<Z, MathError> {
37 let self_degree = self.get_degree();
38 let other_degree = other.get_degree();
39
40 let mut smaller_degree = self_degree;
41 if smaller_degree > other_degree {
42 smaller_degree = other_degree;
43 }
44
45 let mut result = Z::default();
47 let mut temp = Z::default();
48 for i in 0..=smaller_degree {
49 unsafe {
51 let mut coefficient_1 = Z::default();
52 let mut coefficient_2 = Z::default();
53 fmpz_poly_get_coeff_fmpz(&mut coefficient_1.value, &self.poly, i);
54 fmpz_poly_get_coeff_fmpz(&mut coefficient_2.value, &other.poly, i);
55
56 fmpz_mul(&mut temp.value, &coefficient_1.value, &coefficient_2.value);
57
58 fmpz_add(&mut result.value, &result.value, &temp.value)
59 }
60 }
61
62 Ok(result)
63 }
64}
65
#[cfg(test)]
mod test_dot_product {
    use crate::integer::{PolyOverZ, Z};
    use std::str::FromStr;

    /// Check whether the dot product is calculated correctly for polynomials
    /// of equal length.
    #[test]
    fn dot_product_correct() {
        let poly_1 = PolyOverZ::from_str("2 1 1").unwrap();
        let poly_2 = PolyOverZ::from_str("2 3 4").unwrap();

        let cmp = Z::from(7);
        let dot_prod = poly_1.dot_product(&poly_2).unwrap();

        assert_eq!(dot_prod, cmp);
    }

    /// Check whether the dot product is calculated correctly when the result
    /// is close to the 64-bit boundary.
    #[test]
    fn large_numbers() {
        let poly_1 = PolyOverZ::from_str("3 6 2 4").unwrap();
        let poly_2 = PolyOverZ::from_str(&format!("3 1 2 {}", i64::MAX / 8)).unwrap();

        let cmp = Z::from(10 + 4 * (i64::MAX / 8));
        let dot_prod = poly_1.dot_product(&poly_2).unwrap();

        assert_eq!(dot_prod, cmp);
    }

    /// Check whether coefficients and results exceeding 64 bits are handled
    /// correctly.
    #[test]
    fn multi_limb_numbers() {
        let poly = PolyOverZ::from_str(&format!("2 {} {}", u64::MAX, u64::MAX)).unwrap();

        // 2 * (2^64 - 1)^2
        let cmp = Z::from_str("680564733841876926852962238568698216450").unwrap();
        let dot_prod = poly.dot_product(&poly).unwrap();

        assert_eq!(dot_prod, cmp);
    }

    /// Check whether negative coefficients are handled correctly.
    #[test]
    fn negative_numbers() {
        let poly_1 = PolyOverZ::from_str("3 -1 2 -3").unwrap();
        let poly_2 = PolyOverZ::from_str("3 4 -5 6").unwrap();

        let cmp = Z::from(-32);
        let dot_prod = poly_1.dot_product(&poly_2).unwrap();

        assert_eq!(dot_prod, cmp);
    }

    /// Check whether polynomials of different lengths work, independent of
    /// which one is the receiver.
    #[test]
    fn different_lengths_work() {
        let poly_1 = PolyOverZ::from_str("3 1 2 3").unwrap();
        let poly_2 = PolyOverZ::from_str("2 3 4").unwrap();

        let cmp = Z::from(11);

        assert_eq!(poly_1.dot_product(&poly_2).unwrap(), cmp);
        assert_eq!(poly_2.dot_product(&poly_1).unwrap(), cmp);
    }

    /// Check whether the dot product with the zero polynomial is zero.
    #[test]
    fn zero_length_works() {
        let poly_1 = PolyOverZ::from_str("3 1 2 3").unwrap();
        let poly_2 = PolyOverZ::from(0);

        let cmp = Z::from(0);
        let dot_prod = poly_1.dot_product(&poly_2).unwrap();

        assert_eq!(dot_prod, cmp);
    }

    /// Check whether the dot product of two zero polynomials is zero.
    #[test]
    fn both_zero_works() {
        let poly_1 = PolyOverZ::from(0);
        let poly_2 = PolyOverZ::from(0);

        let cmp = Z::from(0);
        let dot_prod = poly_1.dot_product(&poly_2).unwrap();

        assert_eq!(dot_prod, cmp);
    }
}