/// Dense matrix of `f64` values stored in row-major order.
///
/// The element at row `r`, column `c` lives at `data[r * cols + c]`. Because the
/// fields are public, the invariant `data.len() == rows * cols` is only guaranteed
/// for values built through [`Matrix::new`].
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Elements in row-major order.
    pub data: Vec<f64>,
}

impl Matrix {
    /// Creates a `rows x cols` matrix from row-major `data`.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`, or if `data.len() != rows * cols`.
    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        // An unchecked product wraps in release builds, which could let a tiny `data`
        // pass the length check for absurd dimensions.
        let expected = rows
            .checked_mul(cols)
            .expect("matrix dimensions overflow usize");
        assert_eq!(data.len(), expected, "data length must equal rows * cols");
        Self { rows, cols, data }
    }

    /// Returns the matrix product `self * other` as a new `self.rows x other.cols` matrix.
    ///
    /// # Errors
    ///
    /// Returns an error if `self.cols != other.rows`.
    pub fn multiply(&self, other: &Matrix) -> Result<Matrix, String> {
        if self.cols != other.rows {
            return Err("Incompatible dimensions for multiplication".into());
        }
        let (rows, inner, cols) = (self.rows, self.cols, other.cols);
        let mut result = vec![0.0; rows * cols];
        for i in 0..rows {
            let lhs_row = &self.data[i * inner..(i + 1) * inner];
            for j in 0..cols {
                // Accumulate from 0.0 in increasing `k`, the same order as a plain `+=` loop,
                // so results are bit-for-bit identical to the naive formulation.
                result[i * cols + j] = lhs_row
                    .iter()
                    .enumerate()
                    .fold(0.0, |acc, (k, &a)| acc + a * other.data[k * cols + j]);
            }
        }
        Ok(Matrix {
            rows,
            cols,
            data: result,
        })
    }
}
fn matrix_chain_order(dims: &[usize]) -> (Vec<Vec<usize>>, Vec<Vec<usize>>) {
let n = dims.len() - 1; let mut m = vec![vec![0; n]; n];
let mut s = vec![vec![0; n]; n];
for chain_length in 2..=n {
for i in 0..=n - chain_length {
let j = i + chain_length - 1;
m[i][j] = usize::MAX;
for k in i..j {
let q = m[i][k] + m[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
if q < m[i][j] {
m[i][j] = q;
s[i][j] = k;
}
}
}
}
(m, s)
}
/// Walks the split table `s` (as produced by `matrix_chain_order`) for the sub-chain `i..=j`
/// and returns the leaves of the resulting multiplication tree, left to right.
///
/// Each leaf is recorded as the single-matrix interval `(x, x)`. Only leaves are recorded,
/// so for a well-formed `s` the output is always `(i, i), (i + 1, i + 1), …, (j, j)`. The
/// nesting structure encoded in `s` is not captured.
///
/// Expects every consulted `s[lo][hi]` to lie in `lo..hi`. Otherwise the recursion may index
/// out of bounds (panic) or fail to terminate.
fn construct_optimal_parens(s: &[Vec<usize>], i: usize, j: usize) -> Vec<(usize, usize)> {
    /// Appends the leaves of the subtree covering `lo..=hi` to `out`.
    fn collect_leaves(s: &[Vec<usize>], lo: usize, hi: usize, out: &mut Vec<(usize, usize)>) {
        if lo == hi {
            out.push((lo, lo));
            return;
        }
        let k = s[lo][hi];
        collect_leaves(s, lo, k, out);
        collect_leaves(s, k + 1, hi, out);
    }

    // A single shared buffer avoids re-copying partial results at every recursion level.
    let mut leaves = Vec::with_capacity(j.saturating_sub(i) + 1);
    collect_leaves(s, i, j, &mut leaves);
    leaves
}
/// Multiplies the chain `matrices[0] * matrices[1] * … * matrices[n - 1]`.
///
/// The products are grouped so that the total number of scalar multiplications is minimal.
/// A single matrix is returned as a copy.
///
/// # Errors
///
/// Returns an error if:
/// - `matrices` is empty;
/// - any matrix's `data` length differs from `rows * cols`;
/// - one matrix's column count differs from the next matrix's row count.
pub fn optimal_matrix_chain_multiplication(matrices: &[Matrix]) -> Result<Matrix, String> {
    if matrices.is_empty() {
        return Err("No matrices provided".to_string());
    }
    // The fields are public, so a matrix may bypass `Matrix::new`. Reject inconsistent
    // storage up front rather than risk an out-of-bounds panic inside `Matrix::multiply`.
    if matrices
        .iter()
        .any(|mat| mat.rows.checked_mul(mat.cols) != Some(mat.data.len()))
    {
        return Err("Matrix data length does not match its dimensions".into());
    }
    if matrices.len() == 1 {
        return Ok(matrices[0].clone());
    }
    if matrices.windows(2).any(|pair| pair[0].cols != pair[1].rows) {
        return Err("Dimension mismatch in consecutive matrices".into());
    }
    // `matrices[i]` has shape `dims[i] x dims[i + 1]`.
    let dims: Vec<usize> = std::iter::once(matrices[0].rows)
        .chain(matrices.iter().map(|mat| mat.cols))
        .collect();
    let n = matrices.len();
    let (_, s) = matrix_chain_order(&dims);
    // Debug-only consistency check: the split table must visit every matrix exactly once,
    // in order.
    debug_assert!(construct_optimal_parens(&s, 0, n - 1)
        .into_iter()
        .eq((0..n).map(|x| (x, x))));
    multiply_chain_rec(matrices, &s, 0, n - 1)
}
/// Multiplies `matrices[i..=j]` following the split table `s` (from `matrix_chain_order`).
///
/// Sub-chain `i..=j` is computed as `(i..=s[i][j]) * (s[i][j] + 1..=j)`, recursively.
/// Leaf matrices are borrowed rather than cloned. Only intermediate products are owned;
/// the final result is cloned only when `i == j`.
///
/// # Errors
///
/// Propagates any error returned by `Matrix::multiply`.
fn multiply_chain_rec(
    matrices: &[Matrix],
    s: &[Vec<usize>],
    i: usize,
    j: usize,
) -> Result<Matrix, String> {
    /// Computes the product of `matrices[lo..=hi]`, borrowing when the range is a single matrix.
    fn product<'a>(
        matrices: &'a [Matrix],
        s: &[Vec<usize>],
        lo: usize,
        hi: usize,
    ) -> Result<std::borrow::Cow<'a, Matrix>, String> {
        if lo == hi {
            return Ok(std::borrow::Cow::Borrowed(&matrices[lo]));
        }
        let k = s[lo][hi];
        let left = product(matrices, s, lo, k)?;
        let right = product(matrices, s, k + 1, hi)?;
        left.multiply(&right).map(std::borrow::Cow::Owned)
    }

    product(matrices, s, i, j).map(std::borrow::Cow::into_owned)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimal_matrix_chain_multiplication() {
        let a = Matrix::new(10, 30, vec![1.0; 10 * 30]);
        let b = Matrix::new(30, 5, vec![1.0; 30 * 5]);
        let c = Matrix::new(5, 60, vec![1.0; 5 * 60]);
        let matrices = [a, b, c];
        let res_matrix = optimal_matrix_chain_multiplication(&matrices)
            .expect("compatible chain should multiply");
        assert_eq!(res_matrix.rows, 10);
        assert_eq!(res_matrix.cols, 60);
        // Every entry of this all-ones product equals the product of the inner dimensions, 30 * 5.
        assert_eq!(res_matrix.data, vec![150.0; 10 * 60]);
    }

    #[test]
    fn chain_product_matches_manual_result() {
        // [1 2] * [3; 4] * [5 6] = [11] * [5 6] = [55 66]
        let a = Matrix::new(1, 2, vec![1.0, 2.0]);
        let b = Matrix::new(2, 1, vec![3.0, 4.0]);
        let c = Matrix::new(1, 2, vec![5.0, 6.0]);
        let result = optimal_matrix_chain_multiplication(&[a, b, c]).unwrap();
        assert_eq!((result.rows, result.cols), (1, 2));
        assert_eq!(result.data, vec![55.0, 66.0]);
    }

    #[test]
    fn single_matrix_is_returned_unchanged() {
        let a = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let result = optimal_matrix_chain_multiplication(std::slice::from_ref(&a)).unwrap();
        assert_eq!((result.rows, result.cols), (2, 2));
        assert_eq!(result.data, a.data);
    }

    #[test]
    fn empty_chain_is_an_error() {
        assert_eq!(
            optimal_matrix_chain_multiplication(&[]).unwrap_err(),
            "No matrices provided"
        );
    }

    #[test]
    fn mismatched_chain_is_an_error() {
        let a = Matrix::new(2, 3, vec![0.0; 6]);
        let b = Matrix::new(2, 2, vec![0.0; 4]);
        assert_eq!(
            optimal_matrix_chain_multiplication(&[a, b]).unwrap_err(),
            "Dimension mismatch in consecutive matrices"
        );
    }

    #[test]
    fn chain_order_matches_textbook_example() {
        let (m, s) = matrix_chain_order(&[30, 35, 15, 5, 10, 20, 25]);
        assert_eq!(m[0][5], 15125);
        assert_eq!(s[0][5], 2);
    }
}