use std::collections::HashMap;
use std::fmt;
use std::ops;
use super::{AsVarName, E, constant};
/// Ordered collection of `E` values.
///
/// The wrapped `Vec<E>` is public, so callers can read or modify the
/// elements directly through `.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct SymVec(pub Vec<E>);
impl SymVec {
pub fn new<I>(elems: I) -> Self
where
I: IntoIterator,
I::Item: Into<E>,
{
SymVec(elems.into_iter().map(Into::into).collect())
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn get(&self, i: usize) -> &E {
&self.0[i]
}
pub fn dot(&self, other: &SymVec) -> E {
assert_eq!(self.len(), other.len(), "dot product: length mismatch");
let mut terms: Vec<E> = Vec::with_capacity(self.len());
for i in 0..self.len() {
terms.push(self.0[i].clone() * other.0[i].clone());
}
terms.into_iter().reduce(|a, b| a + b).unwrap_or_else(|| constant(0.0))
}
pub fn diff(&self, var: impl AsVarName) -> SymVec {
let v = var.var_name();
SymVec(self.0.iter().map(|e| e.diff(v)).collect())
}
pub fn eval(&self, vars: &HashMap<&str, f64>) -> Result<Vec<f64>, String> {
self.0.iter().map(|e| e.eval(vars)).collect()
}
pub fn simplify(&self) -> SymVec {
SymVec(self.0.iter().map(|e| e.simplify()).collect())
}
pub fn expand(&self) -> SymVec {
SymVec(self.0.iter().map(|e| e.expand()).collect())
}
pub fn subs(&self, var: impl AsVarName, replacement: &E) -> SymVec {
let name = var.var_name();
SymVec(self.0.iter().map(|e| e.subs(name, replacement)).collect())
}
pub fn to_latex(&self) -> String {
let mut buf = String::from("\\begin{pmatrix} ");
for (i, e) in self.0.iter().enumerate() {
if i > 0 { buf.push_str(" \\\\ "); }
buf.push_str(&e.to_latex());
}
buf.push_str(" \\end{pmatrix}");
buf
}
pub fn to_rust(&self, ft: &str) -> String {
let mut buf = String::from("[");
for (i, e) in self.0.iter().enumerate() {
if i > 0 { buf.push_str(", "); }
buf.push_str(&e.to_rust(ft));
}
buf.push(']');
buf
}
}
impl ops::Index<usize> for SymVec {
type Output = E;
fn index(&self, i: usize) -> &E {
&self.0[i]
}
}
impl ops::Add for SymVec {
type Output = SymVec;
fn add(self, rhs: SymVec) -> SymVec {
assert_eq!(self.len(), rhs.len(), "SymVec add: length mismatch");
SymVec(
self.0.into_iter().zip(rhs.0)
.map(|(a, b)| a + b)
.collect()
)
}
}
impl ops::Mul<E> for SymVec {
type Output = SymVec;
fn mul(self, rhs: E) -> SymVec {
SymVec(self.0.into_iter().map(|e| e * rhs.clone()).collect())
}
}
impl ops::Mul<SymVec> for E {
type Output = SymVec;
fn mul(self, rhs: SymVec) -> SymVec {
SymVec(rhs.0.into_iter().map(|e| self.clone() * e).collect())
}
}
impl fmt::Display for SymVec {
    /// Writes the elements separated by `", "` and wrapped in square brackets.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[")?;
        let mut elems = self.0.iter();
        if let Some(first) = elems.next() {
            // Each element is formatted through the same formatter `f`.
            fmt::Display::fmt(first, f)?;
            for elem in elems {
                f.write_str(", ")?;
                fmt::Display::fmt(elem, f)?;
            }
        }
        f.write_str("]")
    }
}
/// Dense matrix of `E` values stored in a single flat vector.
///
/// `SymMat::get` reads element `(i, j)` from `data[i * cols + j]`, i.e. the
/// entries are laid out row by row.
#[derive(Debug, Clone, PartialEq)]
pub struct SymMat {
/// Number of rows.
pub rows: usize,
/// Number of columns.
pub cols: usize,
/// Entries in row-major order. `SymMat::new` asserts that the length equals
/// `rows * cols`; because the field is public, later edits are not checked.
pub data: Vec<E>,
}
impl SymMat {
    /// Builds a `rows` x `cols` matrix from `data`, given in row-major order.
    ///
    /// # Panics
    /// Panics if `data` does not yield exactly `rows * cols` items.
    pub fn new<I>(rows: usize, cols: usize, data: I) -> Self
    where
        I: IntoIterator,
        I::Item: Into<E>,
    {
        let data: Vec<E> = data.into_iter().map(Into::into).collect();
        assert_eq!(data.len(), rows * cols, "SymMat::new: data size mismatch");
        SymMat { rows, cols, data }
    }

    /// Builds a `rows` x `cols` matrix whose entries are all `constant(0.0)`.
    pub fn zeros(rows: usize, cols: usize) -> Self {
        SymMat {
            rows,
            cols,
            data: vec![constant(0.0); rows * cols],
        }
    }

    /// Builds the `n` x `n` matrix with `constant(1.0)` on the diagonal and
    /// `constant(0.0)` elsewhere.
    pub fn identity(n: usize) -> Self {
        let mut data = vec![constant(0.0); n * n];
        for i in 0..n {
            data[i * n + i] = constant(1.0);
        }
        SymMat { rows: n, cols: n, data }
    }

    /// Converts `(i, j)` into an offset into `data`, checking each axis.
    ///
    /// Checking only the flat offset would let an out-of-range column
    /// (e.g. `j == cols`) silently alias an entry of the following row.
    fn flat_index(&self, i: usize, j: usize) -> usize {
        assert!(
            i < self.rows && j < self.cols,
            "SymMat index ({}, {}) out of bounds for {}x{} matrix",
            i,
            j,
            self.rows,
            self.cols
        );
        i * self.cols + j
    }

    /// Applies `f` to every entry and returns a matrix of the same shape.
    fn map_entries(&self, f: impl FnMut(&E) -> E) -> SymMat {
        SymMat {
            rows: self.rows,
            cols: self.cols,
            data: self.data.iter().map(f).collect(),
        }
    }

    /// Borrows the entry at row `i`, column `j`.
    ///
    /// # Panics
    /// Panics if `i >= rows` or `j >= cols`.
    pub fn get(&self, i: usize, j: usize) -> &E {
        &self.data[self.flat_index(i, j)]
    }

    /// Replaces the entry at row `i`, column `j` with `val`.
    ///
    /// # Panics
    /// Panics if `i >= rows` or `j >= cols`.
    pub fn set(&mut self, i: usize, j: usize, val: E) {
        let at = self.flat_index(i, j);
        self.data[at] = val;
    }

    /// Returns the transposed matrix; entry `(i, j)` of `self` becomes
    /// entry `(j, i)` of the result.
    pub fn transpose(&self) -> SymMat {
        let mut data = Vec::with_capacity(self.rows * self.cols);
        // Walking the source column by column yields the transpose's rows in order.
        for j in 0..self.cols {
            for i in 0..self.rows {
                data.push(self.get(i, j).clone());
            }
        }
        SymMat { rows: self.cols, cols: self.rows, data }
    }

    /// Applies `E::diff` with the name of `var` to every entry.
    pub fn diff(&self, var: impl AsVarName) -> SymMat {
        let v = var.var_name();
        self.map_entries(|e| e.diff(v))
    }

    /// Evaluates every entry against `vars`, returning one inner `Vec` per row.
    ///
    /// # Errors
    /// Returns the first error reported by `E::eval`, scanning row by row;
    /// later entries are not evaluated.
    pub fn eval(&self, vars: &HashMap<&str, f64>) -> Result<Vec<Vec<f64>>, String> {
        (0..self.rows)
            .map(|i| (0..self.cols).map(|j| self.get(i, j).eval(vars)).collect())
            .collect()
    }

    /// Applies `E::simplify` to every entry.
    pub fn simplify(&self) -> SymMat {
        self.map_entries(|e| e.simplify())
    }

    /// Applies `E::expand` to every entry.
    pub fn expand(&self) -> SymMat {
        self.map_entries(|e| e.expand())
    }

    /// Applies `E::subs` with the name of `var` and `replacement` to every entry.
    pub fn subs(&self, var: impl AsVarName, replacement: &E) -> SymMat {
        let name = var.var_name();
        self.map_entries(|e| e.subs(name, replacement))
    }

    /// Renders the matrix as a LaTeX `pmatrix`: entries in a row are joined
    /// with ` & `, rows are joined with ` \\ `.
    pub fn to_latex(&self) -> String {
        let mut buf = String::from("\\begin{pmatrix} ");
        for i in 0..self.rows {
            if i > 0 {
                buf.push_str(" \\\\ ");
            }
            for j in 0..self.cols {
                if j > 0 {
                    buf.push_str(" & ");
                }
                buf.push_str(&self.get(i, j).to_latex());
            }
        }
        buf.push_str(" \\end{pmatrix}");
        buf
    }

    /// Renders the matrix as nested bracketed lists (`[[a, b], [c, d]]`),
    /// formatting each entry with `E::to_rust(ft)`.
    pub fn to_rust(&self, ft: &str) -> String {
        let mut buf = String::from("[");
        for i in 0..self.rows {
            if i > 0 {
                buf.push_str(", ");
            }
            buf.push('[');
            for j in 0..self.cols {
                if j > 0 {
                    buf.push_str(", ");
                }
                buf.push_str(&self.get(i, j).to_rust(ft));
            }
            buf.push(']');
        }
        buf.push(']');
        buf
    }
}
impl ops::Add for SymMat {
    type Output = SymMat;

    /// Entry-wise sum of two matrices.
    ///
    /// # Panics
    /// Panics if the two matrices differ in row or column count.
    fn add(self, rhs: SymMat) -> SymMat {
        assert_eq!((self.rows, self.cols), (rhs.rows, rhs.cols), "SymMat add: dimension mismatch");
        let SymMat { rows, cols, data: left } = self;
        let mut sums = Vec::with_capacity(left.len());
        for (a, b) in left.into_iter().zip(rhs.data) {
            sums.push(a + b);
        }
        SymMat { rows, cols, data: sums }
    }
}
impl ops::Mul for SymMat {
    type Output = SymMat;

    /// Matrix product `self * rhs`.
    ///
    /// Each output entry is the left-to-right sum of the products along the
    /// shared dimension; when that dimension is zero the entry is
    /// `constant(0.0)`.
    ///
    /// # Panics
    /// Panics if `self.cols` differs from `rhs.rows`.
    fn mul(self, rhs: SymMat) -> SymMat {
        assert_eq!(self.cols, rhs.rows, "SymMat mul: dimension mismatch");
        let (out_rows, inner, out_cols) = (self.rows, self.cols, rhs.cols);
        let mut entries = Vec::with_capacity(out_rows * out_cols);
        for r in 0..out_rows {
            for c in 0..out_cols {
                let entry = (0..inner)
                    .map(|k| self.get(r, k).clone() * rhs.get(k, c).clone())
                    .reduce(|acc, term| acc + term)
                    .unwrap_or_else(|| constant(0.0));
                entries.push(entry);
            }
        }
        SymMat { rows: out_rows, cols: out_cols, data: entries }
    }
}
impl ops::Mul<SymVec> for SymMat {
    type Output = SymVec;

    /// Matrix-vector product: element `r` of the result is the left-to-right
    /// sum of `self[(r, c)] * rhs[c]` over all columns `c`, or `constant(0.0)`
    /// when the matrix has no columns.
    ///
    /// # Panics
    /// Panics if `self.cols` differs from the length of `rhs`.
    fn mul(self, rhs: SymVec) -> SymVec {
        assert_eq!(self.cols, rhs.len(), "SymMat * SymVec: dimension mismatch");
        let row_sums: Vec<E> = (0..self.rows)
            .map(|r| {
                (0..self.cols)
                    .map(|c| self.get(r, c).clone() * rhs[c].clone())
                    .reduce(|acc, term| acc + term)
                    .unwrap_or_else(|| constant(0.0))
            })
            .collect();
        SymVec(row_sums)
    }
}
impl ops::Mul<E> for SymMat {
    type Output = SymMat;

    /// Multiplies every entry by `rhs`, with the entry on the left.
    fn mul(self, rhs: E) -> SymMat {
        let SymMat { rows, cols, data: entries } = self;
        let mut scaled = Vec::with_capacity(entries.len());
        for entry in entries {
            scaled.push(entry * rhs.clone());
        }
        SymMat { rows, cols, data: scaled }
    }
}
impl ops::Mul<SymMat> for E {
    type Output = SymMat;

    /// Multiplies every entry of `rhs` by `self`, with `self` on the left.
    fn mul(self, rhs: SymMat) -> SymMat {
        let SymMat { rows, cols, data: entries } = rhs;
        let mut scaled = Vec::with_capacity(entries.len());
        for entry in entries {
            scaled.push(self.clone() * entry);
        }
        SymMat { rows, cols, data: scaled }
    }
}
impl fmt::Display for SymMat {
    /// Writes the matrix on one line inside square brackets: entries in a row
    /// are separated by `", "`, rows by `"; "`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[")?;
        for r in 0..self.rows {
            if r != 0 {
                f.write_str("; ")?;
            }
            for c in 0..self.cols {
                if c != 0 {
                    f.write_str(", ")?;
                }
                // Each entry is formatted through the same formatter `f`.
                fmt::Display::fmt(self.get(r, c), f)?;
            }
        }
        f.write_str("]")
    }
}
/// Builds the matrix whose entry `(r, c)` is `exprs[r].diff(vars[c])`.
///
/// The result has `exprs.len()` rows and `vars.len()` columns.
pub fn jacobian(exprs: &[E], vars: &[&str]) -> SymMat {
    let mut partials = Vec::with_capacity(exprs.len() * vars.len());
    // One row per expression, filled in the order the variables are listed.
    for expr in exprs {
        partials.extend(vars.iter().map(|var| expr.diff(var)));
    }
    SymMat {
        rows: exprs.len(),
        cols: vars.len(),
        data: partials,
    }
}