use crate::error::MathError;
use crate::rational::{MatQ, Q};
use crate::traits::MatrixDimensions;
use flint_sys::fmpq::fmpq_addmul;
impl MatQ {
    /// Computes the dot product of two vectors given as [`MatQ`] instances.
    ///
    /// Only the number of entries matters, not the orientation. Row and column
    /// vectors can be freely combined: the entries of both vectors (as returned
    /// by `collect_entries`) are multiplied pairwise and summed.
    ///
    /// Parameters:
    /// - `other`: the second vector of the dot product
    ///
    /// Returns the dot product as a [`Q`], or an error if either matrix is not a
    /// vector or the two vectors have different numbers of entries.
    ///
    /// # Errors and Failures
    /// - Returns a [`MathError`] of type
    ///   [`MathError::VectorFunctionCalledOnNonVector`] if `self` is not a vector.
    ///   `self` is checked before `other`.
    /// - Returns a [`MathError`] of type
    ///   [`MathError::VectorFunctionCalledOnNonVector`] if `other` is not a vector.
    /// - Returns a [`MathError`] of type [`MathError::MismatchingMatrixDimension`]
    ///   if the two vectors have different numbers of entries.
    pub fn dot_product(&self, other: &Self) -> Result<Q, MathError> {
        if !self.is_vector() {
            return Err(MathError::VectorFunctionCalledOnNonVector(
                String::from("dot_product"),
                self.get_num_rows(),
                self.get_num_columns(),
            ));
        } else if !other.is_vector() {
            return Err(MathError::VectorFunctionCalledOnNonVector(
                String::from("dot_product"),
                other.get_num_rows(),
                other.get_num_columns(),
            ));
        }

        let self_entries = self.collect_entries();
        let other_entries = other.collect_entries();

        // Orientation is irrelevant, but the entry counts must agree.
        // Otherwise the pairwise product below would silently truncate.
        if self_entries.len() != other_entries.len() {
            return Err(MathError::MismatchingMatrixDimension(format!(
                "You called the function 'dot_product' for vectors of different lengths: {} and {}",
                self_entries.len(),
                other_entries.len()
            )));
        }

        // Accumulate in place, starting from `Q::default()`, to avoid creating a
        // temporary `Q` per product.
        let mut result = Q::default();
        for (self_entry, other_entry) in self_entries.iter().zip(other_entries.iter()) {
            // SAFETY: `result.value` is the value held by a `Q` constructed above
            // and owned by this function. `self_entry` and `other_entry` borrow
            // elements of `self_entries` / `other_entries`, which stay alive for
            // the whole loop. FLINT's `fmpq_addmul` writes only its first
            // argument and reads the other two, and the destination does not
            // alias either operand.
            unsafe { fmpq_addmul(&mut result.value, self_entry, other_entry) };
        }

        Ok(result)
    }
}
#[cfg(test)]
mod test_dot_product {
    use super::{MatQ, Q};
    use std::str::FromStr;

    /// Builds a [`MatQ`] from its string representation.
    fn parse(matrix: &str) -> MatQ {
        MatQ::from_str(matrix).unwrap()
    }

    /// Parses both vectors and checks that their dot product equals `1/2`.
    fn assert_dot_is_one_half(left: &str, right: &str) {
        let product = parse(left).dot_product(&parse(right)).unwrap();
        assert_eq!(product, Q::from((1, 2)));
    }

    /// Row vector times row vector.
    #[test]
    fn row_with_row() {
        assert_dot_is_one_half("[[1/2, 2/7, -3]]", "[[1, 3, 2/7]]");
    }

    /// Column vector times column vector.
    #[test]
    fn column_with_column() {
        assert_dot_is_one_half("[[1/2],[2/7],[-3]]", "[[1],[3],[2/7]]");
    }

    /// Row vector times column vector: orientation must not matter.
    #[test]
    fn row_with_column() {
        assert_dot_is_one_half("[[1/2, 2/7, -3]]", "[[1],[3],[2/7]]");
    }

    /// Column vector times row vector: orientation must not matter.
    #[test]
    fn column_with_row() {
        assert_dot_is_one_half("[[1/2],[2/7],[-3]]", "[[1, 3, 2/7]]");
    }

    /// Entries at the `i64` extremes must not overflow the accumulation.
    #[test]
    fn large_numbers() {
        let left = parse(&format!("[[1, -1, {}]]", i64::MAX));
        let right = parse(&format!("[[1, {}, 1]]", i64::MIN));
        let expected = -1 * Q::from(i64::MIN) + Q::from(i64::MAX) + 1;

        assert_eq!(left.dot_product(&right).unwrap(), expected);
    }

    /// Any operand that is not a vector is rejected; two vectors are accepted.
    #[test]
    fn non_vector_yield_error() {
        let vector = parse("[[1/2, 3, 2/7]]");
        let matrix = parse("[[1, 2],[2/7, 3],[-3, 4]]");

        assert!(vector.dot_product(&matrix).is_err());
        assert!(matrix.dot_product(&vector).is_err());
        assert!(matrix.dot_product(&matrix).is_err());
        assert!(vector.dot_product(&vector).is_ok());
    }

    /// Vectors with different entry counts are rejected in either order.
    #[test]
    fn different_lengths_yield_error() {
        let shorter = parse("[[1, 2, 3]]");
        let longer = parse("[[1, 2, 3, 4]]");

        assert!(shorter.dot_product(&longer).is_err());
        assert!(longer.dot_product(&shorter).is_err());
    }
}