use crate::vector::Vector;
use pounce_common::tagged::{Tag, TaggedCell, TaggedObject};
use pounce_common::types::{Index, Number};
use std::any::Any;
use std::cell::Cell;
use std::fmt::Debug;
/// Per-matrix bookkeeping that every [`Matrix`] implementation exposes via
/// [`Matrix::cache`].
///
/// It holds:
/// - the matrix's change tag;
/// - a memo of the last [`Matrix::has_valid_numbers`] answer, remembered together
///   with the tag value it was computed at.
#[derive(Debug)]
pub struct MatrixCache {
    /// Change-tracking tag; advanced by [`MatrixCache::bump`].
    tag: TaggedCell,
    /// Memoized `(tag, is_valid)` pair for `has_valid_numbers`; `None` until the
    /// first computation.
    valid: Cell<Option<(Tag, bool)>>,
}

impl MatrixCache {
    /// Creates a cache with a newly constructed tag cell and no memoized
    /// validity result.
    pub fn new() -> Self {
        let tag = TaggedCell::new();
        let valid = Cell::new(None);
        MatrixCache { tag, valid }
    }

    /// Current value of the change tag.
    pub fn tag(&self) -> Tag {
        self.tag.tag()
    }

    /// Advances the change tag.
    ///
    /// NOTE(review): presumably implementors call this after mutating their data
    /// so that tag-keyed memos (such as the validity memo) go stale. Verify
    /// against the callers.
    pub fn bump(&self) {
        self.tag.bump();
    }
}

impl Default for MatrixCache {
    /// Equivalent to [`MatrixCache::new`].
    fn default() -> Self {
        MatrixCache::new()
    }
}
/// Common interface for matrices. It covers dimensions, matrix-vector products,
/// row/column absolute-maximum norms, and a couple of fused operations built on
/// top of those products.
///
/// Methods ending in `_impl` are the hooks that implementors provide or override.
/// The un-suffixed methods are the entry points callers use, and by default they
/// forward to the matching hook. The amax entry points first zero the output when
/// `init` is set. `has_valid_numbers` memoizes its hook's answer against the tag
/// held in [`Matrix::cache`].
pub trait Matrix: TaggedObject + Debug + 'static {
    /// Number of rows.
    fn n_rows(&self) -> Index;

    /// Number of columns.
    fn n_cols(&self) -> Index;

    /// Bookkeeping storage (change tag plus validity memo) owned by this matrix.
    fn cache(&self) -> &MatrixCache;

    /// Upcast to `Any`, for downcasting to the concrete matrix type.
    fn as_any(&self) -> &dyn Any;

    /// Mutable counterpart of [`Matrix::as_any`].
    fn as_any_mut(&mut self) -> &mut dyn Any;

    /// Upcast to a `TaggedObject` trait object.
    fn as_tagged(&self) -> &dyn TaggedObject;

    /// Upcast to a `Matrix` trait object.
    fn as_dyn_matrix(&self) -> &dyn Matrix;

    /// Hook behind [`Matrix::mult_vector`].
    ///
    /// NOTE(review): the alpha/beta parameters suggest the BLAS-style update
    /// `y <- alpha * M * x + beta * y`. Confirm this against the implementors.
    fn mult_vector_impl(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector);

    /// Hook behind [`Matrix::trans_mult_vector`]. This is the transposed
    /// counterpart of [`Matrix::mult_vector_impl`] and presumably computes
    /// `y <- alpha * M^T * x + beta * y`.
    fn trans_mult_vector_impl(
        &self,
        alpha: Number,
        x: &dyn Vector,
        beta: Number,
        y: &mut dyn Vector,
    );

    /// Hook behind [`Matrix::compute_row_amax`]. If `init` is true, the entry
    /// point has already zeroed `rows_norms` before calling this.
    fn compute_row_amax_impl(&self, rows_norms: &mut dyn Vector, init: bool);

    /// Hook behind [`Matrix::compute_col_amax`]. If `init` is true, the entry
    /// point has already zeroed `cols_norms` before calling this.
    fn compute_col_amax_impl(&self, cols_norms: &mut dyn Vector, init: bool);

    /// Hook behind [`Matrix::has_valid_numbers`]. The default reports `true`
    /// unconditionally.
    fn has_valid_numbers_impl(&self) -> bool {
        true
    }

    /// Default hook behind [`Matrix::add_m_sinv_z`].
    ///
    /// It forms the quotient of `z` by `s` in a scratch vector shaped like `s`,
    /// then passes that vector to [`Matrix::mult_vector`] with `alpha` and
    /// `beta = 1.0`. Under the usual alpha/beta convention this accumulates
    /// `alpha * M * (z / s)` into `x`.
    fn add_m_sinv_z_impl(
        &self,
        alpha: Number,
        s: &dyn Vector,
        z: &dyn Vector,
        x: &mut dyn Vector,
    ) {
        // Clear the copied contents of `s` so the quotient is written onto zeros.
        let mut z_over_s = s.make_new_copy();
        z_over_s.set(0.0);
        z_over_s.add_vector_quotient(1.0, z, s, 0.0);
        // beta = 1.0: keep the caller's `x` and add the product onto it.
        self.mult_vector(alpha, z_over_s.as_dyn_vector(), 1.0, x);
    }

    /// Default hook behind [`Matrix::sinv_blrm_zmt_dbr`].
    ///
    /// It overwrites `x` through the following sequence:
    /// 1. transposed product with `d` (scaled by `alpha`, `beta = 0.0`);
    /// 2. element-wise multiply by `z`;
    /// 3. `axpy` with `r` (factor `1.0`);
    /// 4. element-wise divide by `s`.
    ///
    /// This is presumably `x = (z * (alpha * M^T d) + r) / s`, element-wise.
    fn sinv_blrm_zmt_dbr_impl(
        &self,
        alpha: Number,
        s: &dyn Vector,
        r: &dyn Vector,
        z: &dyn Vector,
        d: &dyn Vector,
        x: &mut dyn Vector,
    ) {
        self.trans_mult_vector(alpha, d, 0.0, x);
        x.element_wise_multiply(z);
        x.axpy(1.0, r);
        x.element_wise_divide(s);
    }

    /// Matrix-vector product entry point; forwards to
    /// [`Matrix::mult_vector_impl`].
    fn mult_vector(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector) {
        self.mult_vector_impl(alpha, x, beta, y);
    }

    /// Transposed matrix-vector product entry point; forwards to
    /// [`Matrix::trans_mult_vector_impl`].
    fn trans_mult_vector(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector) {
        self.trans_mult_vector_impl(alpha, x, beta, y);
    }

    /// Row-wise absolute-maximum entry point.
    ///
    /// If `init` is true, `rows_norms` is zeroed first. In either case the flag
    /// is forwarded to [`Matrix::compute_row_amax_impl`].
    fn compute_row_amax(&self, rows_norms: &mut dyn Vector, init: bool) {
        if init {
            rows_norms.set(0.0);
        }
        self.compute_row_amax_impl(rows_norms, init);
    }

    /// Column-wise absolute-maximum entry point.
    ///
    /// If `init` is true, `cols_norms` is zeroed first. In either case the flag
    /// is forwarded to [`Matrix::compute_col_amax_impl`].
    fn compute_col_amax(&self, cols_norms: &mut dyn Vector, init: bool) {
        if init {
            cols_norms.set(0.0);
        }
        self.compute_col_amax_impl(cols_norms, init);
    }

    /// Entry point for the fused `M * (z / s)` update; forwards to
    /// [`Matrix::add_m_sinv_z_impl`].
    fn add_m_sinv_z(&self, alpha: Number, s: &dyn Vector, z: &dyn Vector, x: &mut dyn Vector) {
        self.add_m_sinv_z_impl(alpha, s, z, x);
    }

    /// Entry point for the fused `(z * M^T d + r) / s` operation; forwards to
    /// [`Matrix::sinv_blrm_zmt_dbr_impl`].
    fn sinv_blrm_zmt_dbr(
        &self,
        alpha: Number,
        s: &dyn Vector,
        r: &dyn Vector,
        z: &dyn Vector,
        d: &dyn Vector,
        x: &mut dyn Vector,
    ) {
        self.sinv_blrm_zmt_dbr_impl(alpha, s, r, z, d, x);
    }

    /// Memoized wrapper around [`Matrix::has_valid_numbers_impl`].
    ///
    /// The answer is cached together with the tag value current at computation
    /// time. It is reused while the tag in [`Matrix::cache`] is unchanged and
    /// recomputed otherwise.
    fn has_valid_numbers(&self) -> bool {
        let cache = self.cache();
        let current = cache.tag();
        match cache.valid.get() {
            Some((seen_at, answer)) if seen_at == current => answer,
            _ => {
                let answer = self.has_valid_numbers_impl();
                cache.valid.set(Some((current, answer)));
                answer
            }
        }
    }
}
/// A [`Matrix`] that is square and symmetric.
pub trait SymMatrix: Matrix {
    /// Side length of the square matrix, i.e. its row count.
    ///
    /// Debug builds assert that the row and column counts agree. Release builds
    /// skip the check and never query `n_cols`.
    fn dim(&self) -> Index {
        let side = self.n_rows();
        debug_assert_eq!(side, self.n_cols());
        side
    }
}
/// Ready-made `trans_mult_vector_impl` for symmetric matrices.
///
/// For a symmetric matrix, `M^T == M`, so the transposed product is served by
/// the matrix's own `mult_vector_impl` with identical arguments.
#[inline]
pub fn sym_default_trans_mult_vector_impl<M>(
    m: &M,
    alpha: Number,
    x: &dyn Vector,
    beta: Number,
    y: &mut dyn Vector,
) where
    M: Matrix + ?Sized,
{
    <M as Matrix>::mult_vector_impl(m, alpha, x, beta, y);
}
/// Ready-made `compute_col_amax_impl` for symmetric matrices.
///
/// For a symmetric matrix, column `j` equals row `j`, so column-wise maxima are
/// obtained from the matrix's own `compute_row_amax_impl`. Both `init` and the
/// output vector are passed through unchanged.
#[inline]
pub fn sym_default_compute_col_amax_impl<M>(m: &M, cols_norms: &mut dyn Vector, init: bool)
where
    M: Matrix + ?Sized,
{
    <M as Matrix>::compute_row_amax_impl(m, cols_norms, init);
}