use crate::error::MathError;
use crate::integer::{MatZ, Z};
use crate::traits::MatrixDimensions;
use flint_sys::fmpz::fmpz_addmul;
impl MatZ {
    /// Returns the dot product of two vectors of type [`MatZ`].
    ///
    /// Only the number of entries matters, not the orientation: a row vector
    /// can be combined with a column vector (and vice versa) as long as both
    /// contain the same number of entries.
    ///
    /// Parameters:
    /// - `other`: the second vector of the dot product
    ///
    /// Returns the dot product `sum_i self_i * other_i` as a [`Z`], or an error
    /// if one of the inputs is not a vector or the vectors differ in length.
    ///
    /// # Errors and Failures
    /// - Returns a [`MathError`] of type [`MathError::VectorFunctionCalledOnNonVector`]
    ///   if `self` or `other` is not a (row or column) vector. `self` is checked first,
    ///   so the reported dimensions are those of the first offending operand.
    /// - Returns a [`MathError`] of type [`MathError::MismatchingMatrixDimension`]
    ///   if the two vectors have a different number of entries.
    pub fn dot_product(&self, other: &Self) -> Result<Z, MathError> {
        // Check `self` before `other` so the error reports the first non-vector operand.
        for operand in [self, other] {
            if !operand.is_vector() {
                return Err(MathError::VectorFunctionCalledOnNonVector(
                    String::from("dot_product"),
                    operand.get_num_rows(),
                    operand.get_num_columns(),
                ));
            }
        }

        let self_entries = self.collect_entries();
        let other_entries = other.collect_entries();

        if self_entries.len() != other_entries.len() {
            return Err(MathError::MismatchingMatrixDimension(format!(
                "You called the function 'dot_product' for vectors of different lengths: {} and {}",
                self_entries.len(),
                other_entries.len()
            )));
        }

        let mut result = Z::ZERO;
        for (self_entry, other_entry) in self_entries.iter().zip(&other_entries) {
            // SAFETY: `result.value` is an initialised `fmpz` owned by `result`, and
            // `self_entry` / `other_entry` are references into the vectors returned by
            // `collect_entries`, which stay alive for the whole loop and are only read.
            unsafe { fmpz_addmul(&mut result.value, self_entry, other_entry) }
        }
        Ok(result)
    }
}
#[cfg(test)]
mod test_dot_product {
    use super::{MatZ, Z};
    use std::str::FromStr;

    /// Row · row: `1*1 + 2*3 + (-3)*2 = 1`.
    #[test]
    fn row_with_row() {
        let lhs = MatZ::from_str("[[1, 2, -3]]").unwrap();
        let rhs = MatZ::from_str("[[1, 3, 2]]").unwrap();

        let product = lhs.dot_product(&rhs).unwrap();

        assert_eq!(Z::ONE, product);
    }

    /// Column · column with the same entries as the row case.
    #[test]
    fn column_with_column() {
        let lhs = MatZ::from_str("[[1],[2],[-3]]").unwrap();
        let rhs = MatZ::from_str("[[1],[3],[2]]").unwrap();

        let product = lhs.dot_product(&rhs).unwrap();

        assert_eq!(Z::ONE, product);
    }

    /// Orientation is irrelevant: row · column.
    #[test]
    fn row_with_column() {
        let lhs = MatZ::from_str("[[1, 2, -3]]").unwrap();
        let rhs = MatZ::from_str("[[1],[3],[2]]").unwrap();

        let product = lhs.dot_product(&rhs).unwrap();

        assert_eq!(Z::ONE, product);
    }

    /// Orientation is irrelevant: column · row.
    #[test]
    fn column_with_row() {
        let lhs = MatZ::from_str("[[1],[2],[-3]]").unwrap();
        let rhs = MatZ::from_str("[[1, 3, 2]]").unwrap();

        let product = lhs.dot_product(&rhs).unwrap();

        assert_eq!(Z::ONE, product);
    }

    /// Entries at the `i64` extremes: the result must not overflow.
    #[test]
    fn large_numbers() {
        let lhs = MatZ::from_str(&format!("[[1, -1, {}]]", i64::MAX)).unwrap();
        let rhs = MatZ::from_str(&format!("[[1, {}, 1]]", i64::MIN)).unwrap();
        // 1*1 + (-1)*i64::MIN + i64::MAX*1
        let expected = Z::from(-1) * Z::from(i64::MIN) + Z::from(i64::MAX) + Z::ONE;

        let product = lhs.dot_product(&rhs).unwrap();

        assert_eq!(expected, product);
    }

    /// Any non-vector operand (on either side) must be rejected.
    #[test]
    fn non_vector_yield_error() {
        let vector = MatZ::from_str("[[1, 3, 2]]").unwrap();
        let matrix = MatZ::from_str("[[1, 2],[2, 3],[-3, 4]]").unwrap();

        assert!(vector.dot_product(&matrix).is_err());
        assert!(matrix.dot_product(&vector).is_err());
        assert!(matrix.dot_product(&matrix).is_err());
        assert!(vector.dot_product(&vector).is_ok());
    }

    /// Vectors with different numbers of entries must be rejected in both orders.
    #[test]
    fn different_lengths_yield_error() {
        let shorter = MatZ::from_str("[[1, 2, 3]]").unwrap();
        let longer = MatZ::from_str("[[1, 2, 3, 4]]").unwrap();

        assert!(shorter.dot_product(&longer).is_err());
        assert!(longer.dot_product(&shorter).is_err());
    }
}