use crate::matrix::{Matrix, MatrixCache};
use crate::vector::Vector;
use pounce_common::tagged::{Tag, TaggedObject};
use pounce_common::types::{Index, Number};
use std::any::Any;
use std::rc::Rc;
/// Implicit transpose `Aᵀ` of a shared matrix `A`.
///
/// No entries are copied or stored here: the wrapper only keeps a shared
/// handle to `A` and answers every `Matrix` query by forwarding it to `A`
/// with the roles of rows and columns exchanged.
#[derive(Debug)]
pub struct TransposeMatrix {
    /// The matrix `A` whose transpose this value stands for.
    orig: Rc<dyn Matrix>,
    /// Cache (and change tag) belonging to this wrapper itself, created fresh
    /// in [`TransposeMatrix::new`]; it is separate from `orig`'s own cache.
    cache: MatrixCache,
}

impl TransposeMatrix {
    /// Wraps `orig` so it can be used wherever its transpose is needed.
    ///
    /// The `Rc` handle is stored as given (no deep copy of the matrix), and the
    /// wrapper starts out with a brand-new `MatrixCache`.
    pub fn new(orig: Rc<dyn Matrix>) -> Self {
        let cache = MatrixCache::new();
        TransposeMatrix { orig, cache }
    }

    /// Borrows the wrapped, untransposed matrix `A`.
    ///
    /// Callers that need their own handle can `Rc::clone` the result.
    pub fn orig(&self) -> &Rc<dyn Matrix> {
        let TransposeMatrix { orig, .. } = self;
        orig
    }
}
impl TaggedObject for TransposeMatrix {
    /// Reports the change tag kept in this wrapper's own `MatrixCache`.
    ///
    /// NOTE(review): the tag comes from the wrapper's cache, not from `orig`,
    /// so changes to the wrapped matrix are not reflected here. Confirm that
    /// `orig` is never modified after being wrapped.
    fn get_tag(&self) -> Tag {
        let Self { cache, .. } = self;
        cache.tag()
    }
}
impl Matrix for TransposeMatrix {
    // ---- Shape: Aᵀ has as many rows as A has columns, and vice versa. ----

    fn n_rows(&self) -> Index {
        let Self { orig, .. } = self;
        orig.n_cols()
    }

    fn n_cols(&self) -> Index {
        let Self { orig, .. } = self;
        orig.n_rows()
    }

    // ---- Plumbing: expose this wrapper through the trait-object views. ----

    fn cache(&self) -> &MatrixCache {
        let Self { cache, .. } = self;
        cache
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    fn as_tagged(&self) -> &dyn TaggedObject {
        self
    }

    fn as_dyn_matrix(&self) -> &dyn Matrix {
        self
    }

    // ---- Products: because (Aᵀ)ᵀ = A, each product on the wrapper is the
    //      opposite product on `orig`, with all arguments passed unchanged. ----

    /// Product with `Aᵀ`, answered by `orig.trans_mult_vector`.
    fn mult_vector_impl(&self, alpha: Number, x: &dyn Vector, beta: Number, y: &mut dyn Vector) {
        let Self { orig, .. } = self;
        orig.trans_mult_vector(alpha, x, beta, y);
    }

    /// Product with `(Aᵀ)ᵀ = A`, answered by `orig.mult_vector`.
    fn trans_mult_vector_impl(
        &self,
        alpha: Number,
        x: &dyn Vector,
        beta: Number,
        y: &mut dyn Vector,
    ) {
        let Self { orig, .. } = self;
        orig.mult_vector(alpha, x, beta, y);
    }

    // ---- Validity: transposing does not change the stored numbers. ----

    fn has_valid_numbers_impl(&self) -> bool {
        let Self { orig, .. } = self;
        orig.has_valid_numbers()
    }

    // ---- Abs-max reductions: a row of Aᵀ is a column of A. ----
    //
    // NOTE(review): these forward to `orig`'s `_impl` hooks directly rather
    // than to any public wrapper. Presumably the wrapper that called us has
    // already done its own `init` handling. Confirm against the `Matrix` trait.

    fn compute_row_amax_impl(&self, rows_norms: &mut dyn Vector, init: bool) {
        let Self { orig, .. } = self;
        orig.compute_col_amax_impl(rows_norms, init);
    }

    fn compute_col_amax_impl(&self, cols_norms: &mut dyn Vector, init: bool) {
        let Self { orig, .. } = self;
        orig.compute_row_amax_impl(cols_norms, init);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dense_vector::DenseVectorSpace;
    use crate::expansion_matrix::{ExpansionMatrix, ExpansionMatrixSpace};
    use crate::DenseVector;

    /// Builds a boxed dense vector whose entries are exactly `entries`.
    fn dvec_box(entries: &[Number]) -> Box<dyn Vector> {
        let vec_space = DenseVectorSpace::new(entries.len() as Index);
        let mut dense = vec_space.make_new_dense();
        dense.set_values(entries);
        Box::new(dense)
    }

    /// Reads back the entries of a vector that was built as a `DenseVector`.
    fn dense_entries(v: &dyn Vector) -> Vec<Number> {
        v.as_any()
            .downcast_ref::<DenseVector>()
            .unwrap()
            .expanded_values()
            .to_vec()
    }

    #[test]
    fn transpose_swaps_mult_and_trans_mult() {
        // P is 5x2. Per the expectations below, it places the two short-vector
        // entries at rows 1 and 3.
        let exp_space = ExpansionMatrixSpace::new(5, 2, &[1, 3], 0);
        let p: Rc<dyn Matrix> = Rc::new(ExpansionMatrix::new(exp_space));
        let p_t = TransposeMatrix::new(Rc::clone(&p));

        // The shape of Pᵀ is the reverse of P's.
        assert_eq!(p_t.n_rows(), 2);
        assert_eq!(p_t.n_cols(), 5);

        // Pᵀ · x gathers entries 1 and 3 of the long vector.
        let long_in = dvec_box(&[10.0, 20.0, 30.0, 40.0, 50.0]);
        let mut short_out = dvec_box(&[0.0, 0.0]);
        p_t.mult_vector(1.0, long_in.as_dyn_vector(), 0.0, short_out.as_mut());
        assert_eq!(dense_entries(short_out.as_dyn_vector()), vec![20.0, 40.0]);

        // (Pᵀ)ᵀ · x = P · x scatters the short vector back into rows 1 and 3.
        let short_in = dvec_box(&[7.0, -2.0]);
        let mut long_out = dvec_box(&[0.0; 5]);
        p_t.trans_mult_vector(1.0, short_in.as_dyn_vector(), 0.0, long_out.as_mut());
        assert_eq!(
            dense_entries(long_out.as_dyn_vector()),
            vec![0.0, 7.0, 0.0, -2.0, 0.0]
        );
    }
}