use core::ops::{Add, Sub, Mul, Neg};
use crate::fixed_point::imperative::FixedMatrix;
use crate::fixed_point::imperative::FixedPoint;
use crate::fixed_point::imperative::compute_matrix::{ComputeMatrix, compute_lu_decompose};
use crate::fixed_point::imperative::matrix_functions::{
matrix_exp_compute, matrix_log_compute, matrix_sqrt_compute,
};
use crate::fixed_point::imperative::linalg::upscale_to_compute;
use crate::fixed_point::core_types::errors::OverflowDetected;
/// A deferred matrix computation represented as an expression tree.
///
/// Trees are assembled with the constructor/combinator methods (`literal`, `identity`, `exp`,
/// `log`, `sqrt`, `pow`, `transpose`, `inverse`, `scale`) and the `+`, `-`, `*` and unary `-`
/// operators. Nothing is computed until the tree is passed to [`evaluate_matrix`].
///
/// The type is `#[must_use]`: every combinator consumes its input and returns a new tree, so
/// discarding the result throws the whole expression away.
#[must_use = "a LazyMatrixExpr does nothing until it is passed to `evaluate_matrix`"]
#[derive(Debug, Clone)]
pub enum LazyMatrixExpr {
    /// An already-materialised matrix leaf.
    Literal(FixedMatrix),
    /// Leaf standing for the `n × n` identity matrix.
    Identity(usize),
    /// Matrix sum `lhs + rhs`.
    Add(Box<LazyMatrixExpr>, Box<LazyMatrixExpr>),
    /// Matrix difference `lhs - rhs`.
    Sub(Box<LazyMatrixExpr>, Box<LazyMatrixExpr>),
    /// Matrix product `lhs · rhs`.
    Mul(Box<LazyMatrixExpr>, Box<LazyMatrixExpr>),
    /// Scalar multiple `s · inner`.
    ScalarMul(FixedPoint, Box<LazyMatrixExpr>),
    /// Transpose of the inner expression.
    Transpose(Box<LazyMatrixExpr>),
    /// Negation of the inner expression.
    Negate(Box<LazyMatrixExpr>),
    /// Matrix inverse; evaluated through an LU decomposition.
    Inverse(Box<LazyMatrixExpr>),
    /// Matrix exponential.
    Exp(Box<LazyMatrixExpr>),
    /// Matrix logarithm.
    Log(Box<LazyMatrixExpr>),
    /// Matrix square root.
    Sqrt(Box<LazyMatrixExpr>),
    /// Real power `inner^p`; in general evaluated as `exp(p · log(inner))`.
    Pow(Box<LazyMatrixExpr>, FixedPoint),
}
impl LazyMatrixExpr {
pub fn literal(m: FixedMatrix) -> Self {
LazyMatrixExpr::Literal(m)
}
pub fn identity(n: usize) -> Self {
LazyMatrixExpr::Identity(n)
}
pub fn exp(self) -> Self {
LazyMatrixExpr::Exp(Box::new(self))
}
pub fn log(self) -> Self {
LazyMatrixExpr::Log(Box::new(self))
}
pub fn sqrt(self) -> Self {
LazyMatrixExpr::Sqrt(Box::new(self))
}
pub fn pow(self, p: FixedPoint) -> Self {
LazyMatrixExpr::Pow(Box::new(self), p)
}
pub fn transpose(self) -> Self {
LazyMatrixExpr::Transpose(Box::new(self))
}
pub fn inverse(self) -> Self {
LazyMatrixExpr::Inverse(Box::new(self))
}
pub fn scale(self, s: FixedPoint) -> Self {
LazyMatrixExpr::ScalarMul(s, Box::new(self))
}
pub fn depth(&self) -> usize {
match self {
LazyMatrixExpr::Literal(_) | LazyMatrixExpr::Identity(_) => 1,
LazyMatrixExpr::Transpose(inner) | LazyMatrixExpr::Negate(inner)
| LazyMatrixExpr::Inverse(inner) | LazyMatrixExpr::Exp(inner)
| LazyMatrixExpr::Log(inner) | LazyMatrixExpr::Sqrt(inner)
| LazyMatrixExpr::ScalarMul(_, inner) | LazyMatrixExpr::Pow(inner, _) => {
1 + inner.depth()
}
LazyMatrixExpr::Add(l, r) | LazyMatrixExpr::Sub(l, r)
| LazyMatrixExpr::Mul(l, r) => {
1 + l.depth().max(r.depth())
}
}
}
pub fn operation_count(&self) -> usize {
match self {
LazyMatrixExpr::Literal(_) | LazyMatrixExpr::Identity(_) => 0,
LazyMatrixExpr::Transpose(inner) | LazyMatrixExpr::Negate(inner)
| LazyMatrixExpr::Inverse(inner) | LazyMatrixExpr::Exp(inner)
| LazyMatrixExpr::Log(inner) | LazyMatrixExpr::Sqrt(inner)
| LazyMatrixExpr::ScalarMul(_, inner) | LazyMatrixExpr::Pow(inner, _) => {
1 + inner.operation_count()
}
LazyMatrixExpr::Add(l, r) | LazyMatrixExpr::Sub(l, r)
| LazyMatrixExpr::Mul(l, r) => {
1 + l.operation_count() + r.operation_count()
}
}
}
}
impl From<FixedMatrix> for LazyMatrixExpr {
fn from(m: FixedMatrix) -> Self {
LazyMatrixExpr::Literal(m)
}
}
impl Add for LazyMatrixExpr {
    type Output = Self;

    /// `lhs + rhs` builds an `Add` node; nothing is evaluated yet.
    fn add(self, other: Self) -> Self::Output {
        Self::Add(self.into(), other.into())
    }
}
impl Sub for LazyMatrixExpr {
    type Output = Self;

    /// `lhs - rhs` builds a `Sub` node; nothing is evaluated yet.
    fn sub(self, other: Self) -> Self::Output {
        Self::Sub(self.into(), other.into())
    }
}
impl Mul for LazyMatrixExpr {
    type Output = Self;

    /// `lhs * rhs` builds a matrix-product (`Mul`) node; nothing is evaluated yet.
    fn mul(self, other: Self) -> Self::Output {
        Self::Mul(self.into(), other.into())
    }
}
impl Neg for LazyMatrixExpr {
    type Output = Self;

    /// `-expr` builds a `Negate` node; nothing is evaluated yet.
    fn neg(self) -> Self::Output {
        Self::Negate(self.into())
    }
}
impl Mul<LazyMatrixExpr> for FixedPoint {
    type Output = LazyMatrixExpr;

    /// `scalar * expr`: delegates to [`LazyMatrixExpr::scale`], yielding a `ScalarMul` node.
    fn mul(self, matrix: LazyMatrixExpr) -> Self::Output {
        matrix.scale(self)
    }
}
impl Mul<FixedPoint> for LazyMatrixExpr {
    type Output = LazyMatrixExpr;

    /// `expr * scalar`: same node as `scalar * expr`, built via [`LazyMatrixExpr::scale`].
    fn mul(self, scalar: FixedPoint) -> Self::Output {
        self.scale(scalar)
    }
}
impl core::fmt::Display for LazyMatrixExpr {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
LazyMatrixExpr::Literal(_) => write!(f, "Matrix"),
LazyMatrixExpr::Identity(n) => write!(f, "I({})", n),
LazyMatrixExpr::Add(l, r) => write!(f, "({} + {})", l, r),
LazyMatrixExpr::Sub(l, r) => write!(f, "({} - {})", l, r),
LazyMatrixExpr::Mul(l, r) => write!(f, "({} * {})", l, r),
LazyMatrixExpr::ScalarMul(s, inner) => write!(f, "({} * {})", s, inner),
LazyMatrixExpr::Transpose(inner) => write!(f, "({})ᵀ", inner),
LazyMatrixExpr::Negate(inner) => write!(f, "-({})", inner),
LazyMatrixExpr::Inverse(inner) => write!(f, "({})⁻¹", inner),
LazyMatrixExpr::Exp(inner) => write!(f, "exp({})", inner),
LazyMatrixExpr::Log(inner) => write!(f, "log({})", inner),
LazyMatrixExpr::Sqrt(inner) => write!(f, "sqrt({})", inner),
LazyMatrixExpr::Pow(inner, p) => write!(f, "({})^{}", inner, p),
}
}
}
fn eval_compute(expr: &LazyMatrixExpr) -> Result<ComputeMatrix, OverflowDetected> {
match expr {
LazyMatrixExpr::Literal(m) => Ok(ComputeMatrix::from_fixed_matrix(m)),
LazyMatrixExpr::Identity(n) => Ok(ComputeMatrix::identity(*n)),
LazyMatrixExpr::Add(l, r) => {
let lc = eval_compute(l)?;
let rc = eval_compute(r)?;
Ok(lc.add(&rc))
}
LazyMatrixExpr::Sub(l, r) => {
let lc = eval_compute(l)?;
let rc = eval_compute(r)?;
Ok(lc.sub(&rc))
}
LazyMatrixExpr::Mul(l, r) => {
let lc = eval_compute(l)?;
let rc = eval_compute(r)?;
Ok(lc.mat_mul(&rc))
}
LazyMatrixExpr::ScalarMul(s, inner) => {
let mc = eval_compute(inner)?;
let s_compute = upscale_to_compute(s.raw());
Ok(mc.scalar_mul(s_compute))
}
LazyMatrixExpr::Transpose(inner) => {
let mc = eval_compute(inner)?;
Ok(mc.transpose())
}
LazyMatrixExpr::Negate(inner) => {
let mc = eval_compute(inner)?;
Ok(mc.neg())
}
LazyMatrixExpr::Inverse(inner) => {
let mc = eval_compute(inner)?;
let lu = compute_lu_decompose(&mc)?;
lu.inverse()
}
LazyMatrixExpr::Exp(inner) => {
let mc = eval_compute(inner)?;
matrix_exp_compute(&mc)
}
LazyMatrixExpr::Log(inner) => {
let mc = eval_compute(inner)?;
matrix_log_compute(&mc)
}
LazyMatrixExpr::Sqrt(inner) => {
let mc = eval_compute(inner)?;
matrix_sqrt_compute(&mc)
}
LazyMatrixExpr::Pow(inner, p) => {
let mc = eval_compute(inner)?;
let log_a = matrix_log_compute(&mc)?;
let p_compute = upscale_to_compute(p.raw());
let p_log_a = log_a.scalar_mul(p_compute);
matrix_exp_compute(&p_log_a)
}
}
}
/// Evaluates a lazy expression tree and returns the result as a [`FixedMatrix`].
///
/// The tree is evaluated as a `ComputeMatrix` and then converted with `to_fixed_matrix`.
///
/// # Errors
///
/// Returns the `OverflowDetected` raised by any node of the tree during evaluation.
pub fn evaluate_matrix(expr: &LazyMatrixExpr) -> Result<FixedMatrix, OverflowDetected> {
    eval_compute(expr).map(|computed| computed.to_fixed_matrix())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh 2×2 identity wrapped as a lazy literal leaf.
    fn identity_literal() -> LazyMatrixExpr {
        LazyMatrixExpr::from(FixedMatrix::identity(2))
    }

    #[test]
    fn test_lazy_matrix_expr_building() {
        let product = identity_literal() * identity_literal();
        assert_eq!(product.depth(), 2);
        assert_eq!(product.operation_count(), 1);
    }

    #[test]
    fn test_lazy_matrix_identity_eval() {
        let result = evaluate_matrix(&LazyMatrixExpr::Identity(3)).unwrap();
        assert_eq!(result.rows(), 3);
        assert_eq!(result.cols(), 3);
        for (i, j) in (0..3).flat_map(|i| (0..3).map(move |j| (i, j))) {
            if i == j {
                assert_eq!(result.get(i, j), FixedPoint::one());
            } else {
                assert!(result.get(i, j).is_zero());
            }
        }
    }

    #[test]
    fn test_lazy_matrix_add() {
        let sum = identity_literal() + identity_literal();
        let result = evaluate_matrix(&sum).unwrap();
        let two = FixedPoint::from_int(2);
        for k in 0..2 {
            assert_eq!(result.get(k, k), two);
        }
        assert!(result.get(0, 1).is_zero());
    }

    #[test]
    fn test_lazy_matrix_chain_depth() {
        let chain = identity_literal().exp().log();
        assert_eq!(chain.depth(), 3);
        assert_eq!(chain.operation_count(), 2);
    }

    #[test]
    fn test_lazy_matrix_display() {
        let rendered = format!("{}", identity_literal().exp());
        assert!(rendered.contains("exp"));
    }
}