use super::Op;
use crate::{Matrix, Vector};
use num_traits::{One, Zero};
/// A linear operator `A(t)` that maps vectors of length `nstates()` to vectors of length
/// `nout()`.
///
/// Implementors only need to provide [`LinearOp::gemv_inplace`]. The remaining methods have
/// defaults; in particular, the matrix representation is assembled column by column from
/// operator applications.
pub trait LinearOp: Op {
    /// Apply the operator: `y = A(t) * x`.
    ///
    /// This forwards to [`LinearOp::gemv_inplace`] with `beta = 0`, so the previous contents of
    /// `y` should not contribute to the result.
    fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
        let beta = Self::T::zero();
        self.gemv_inplace(x, t, beta, y);
    }

    /// GEMV-style update: implementors should compute `y = A(t) * x + beta * y`.
    ///
    /// `x` has length `nstates()` and `y` has length `nout()`.
    fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V);

    /// Allocate and return the matrix representation of `A(t)` (`nout()` rows by `nstates()`
    /// columns).
    ///
    /// The matrix is created with the pattern from [`LinearOp::sparsity`] and filled by
    /// [`LinearOp::matrix_inplace`].
    fn matrix(&self, t: Self::T) -> Self::M {
        // Rows correspond to the output space and columns to the input space. This matches
        // `_default_matrix_inplace`, which writes `nstates()` columns, each of length `nout()`.
        let mut y = Self::M::new_from_sparsity(
            self.nout(),
            self.nstates(),
            self.sparsity(),
            self.context().clone(),
        );
        self.matrix_inplace(t, &mut y);
        y
    }

    /// Write the matrix representation of `A(t)` into `y`, which should be `nout() x nstates()`.
    fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
        self._default_matrix_inplace(t, y);
    }

    /// Default matrix assembly: column `j` of `y` is set to `A(t) * e_j`, where `e_j` is the
    /// `j`-th unit vector of length `nstates()`.
    fn _default_matrix_inplace(&self, t: Self::T, y: &mut Self::M) {
        let mut v = Self::V::zeros(self.nstates(), self.context().clone());
        let mut col = Self::V::zeros(self.nout(), self.context().clone());
        for j in 0..self.nstates() {
            v.set_index(j, Self::T::one());
            self.call_inplace(&v, t, &mut col);
            y.set_column(j, &col);
            // Reset the entry so `v` is a zero vector again before the next unit vector.
            v.set_index(j, Self::T::zero());
        }
    }

    /// Sparsity pattern used when allocating the matrix in [`LinearOp::matrix`].
    ///
    /// The default is `None` (no pattern supplied).
    fn sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
        None
    }
}
/// Adds access to the transpose `A(t)^T` of a [`LinearOp`].
///
/// `A(t)^T` maps vectors of length `nout()` to vectors of length `nstates()`. Implementors only
/// need to provide [`LinearOpTranspose::gemv_transpose_inplace`].
pub trait LinearOpTranspose: LinearOp {
    /// GEMV-style transposed update: implementors should compute
    /// `y = A(t)^T * x + beta * y`.
    ///
    /// `x` has length `nout()` and `y` has length `nstates()`.
    fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V);

    /// Apply the transpose: `y = A(t)^T * x`.
    ///
    /// This forwards to [`LinearOpTranspose::gemv_transpose_inplace`] with `beta = 0`.
    fn call_transpose_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) {
        let beta = Self::T::zero();
        self.gemv_transpose_inplace(x, t, beta, y);
    }

    /// Write the matrix representation of `A(t)^T` into `y`, which should be
    /// `nstates() x nout()`.
    fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) {
        self._default_transpose_inplace(t, y);
    }

    /// Default assembly of `A(t)^T`: column `j` of `y` is set to `A(t)^T * e_j`, where `e_j` is
    /// the `j`-th unit vector of length `nout()`.
    fn _default_transpose_inplace(&self, t: Self::T, y: &mut Self::M) {
        // The transpose consumes vectors from the output space (length `nout`) and produces
        // vectors in the state space (length `nstates`). For square operators the two sizes
        // coincide.
        let mut unit = Self::V::zeros(self.nout(), self.context().clone());
        let mut col = Self::V::zeros(self.nstates(), self.context().clone());
        for j in 0..self.nout() {
            unit.set_index(j, Self::T::one());
            self.call_transpose_inplace(&unit, t, &mut col);
            y.set_column(j, &col);
            // Reset the entry so `unit` is a zero vector again before the next unit vector.
            unit.set_index(j, Self::T::zero());
        }
    }

    /// Sparsity pattern of the transposed matrix. The default is `None` (no pattern supplied).
    fn transpose_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
        None
    }
}
/// Parameter sensitivities of a [`LinearOp`].
///
/// For a given `x` and `t`, the sensitivity is an `nout() x nparams()` matrix `S(x, t)`.
/// NOTE(review): this is presumably `d(A(t) x)/dp` with respect to the parameters; confirm
/// against implementors. Implementors only need to provide
/// [`LinearOpSens::sens_mul_inplace`].
pub trait LinearOpSens: LinearOp {
    /// Compute `y = S(x, t) * v`, where `v` has length `nparams()` and `y` has length `nout()`.
    fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V);

    /// Allocating version of [`LinearOpSens::sens_mul_inplace`].
    ///
    /// Returns a new vector of length `nout()`.
    fn sens_mul(&self, x: &Self::V, t: Self::T, v: &Self::V) -> Self::V {
        // The product lives in the operator's output space. This matches the column length
        // used by `_default_sens_inplace`.
        let mut y = Self::V::zeros(self.nout(), self.context().clone());
        self.sens_mul_inplace(x, t, v, &mut y);
        y
    }

    /// Write `S(x, t)` into `y`, which should be `nout() x nparams()`.
    fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
        self._default_sens_inplace(x, t, y);
    }

    /// Default assembly of `S(x, t)`: column `j` of `y` is set to `S(x, t) * e_j`, where `e_j`
    /// is the `j`-th unit vector of length `nparams()`.
    fn _default_sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) {
        let mut v = Self::V::zeros(self.nparams(), self.context().clone());
        let mut col = Self::V::zeros(self.nout(), self.context().clone());
        for j in 0..self.nparams() {
            v.set_index(j, Self::T::one());
            self.sens_mul_inplace(x, t, &v, &mut col);
            y.set_column(j, &col);
            // Reset the entry so `v` is a zero vector again before the next unit vector.
            v.set_index(j, Self::T::zero());
        }
    }

    /// Allocate and return `S(x, t)` as an `nout() x nparams()` matrix.
    ///
    /// The matrix is created with the pattern from [`LinearOpSens::sens_sparsity`] and filled
    /// by [`LinearOpSens::sens_inplace`].
    fn sens(&self, x: &Self::V, t: Self::T) -> Self::M {
        let mut y = Self::M::new_from_sparsity(
            self.nout(),
            self.nparams(),
            self.sens_sparsity(),
            self.context().clone(),
        );
        self.sens_inplace(x, t, &mut y);
        y
    }

    /// Sparsity pattern of the sensitivity matrix. The default is `None` (no pattern supplied).
    fn sens_sparsity(&self) -> Option<<Self::M as Matrix>::Sparsity> {
        None
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        context::nalgebra::NalgebraContext, matrix::dense_nalgebra_serial::NalgebraMat,
        matrix::Matrix, DenseMatrix, LinearOp, LinearOpSens, LinearOpTranspose, Op, Vector,
    };

    type M = NalgebraMat<f64>;

    /// Build a length-2 test vector `[a, b]`.
    fn vec2(a: f64, b: f64) -> crate::NalgebraVec<f64> {
        crate::NalgebraVec::from_vec(vec![a, b], NalgebraContext)
    }

    /// Check every entry of a 2x2 matrix exactly; `expected` is given row by row.
    fn assert_entries(m: &M, expected: [[f64; 2]; 2]) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(m.get_index(i, j), value);
            }
        }
    }

    /// Square 2x2 operator with `A = [[2, 3], [-1, 4]]` and a constant sensitivity matrix
    /// `S = [[1, 2], [3, 4]]`, used to exercise the trait default methods.
    struct FakeLinearOp {
        ctx: NalgebraContext,
    }

    impl Op for FakeLinearOp {
        type T = f64;
        type V = crate::NalgebraVec<f64>;
        type M = M;
        type C = NalgebraContext;

        fn context(&self) -> &Self::C {
            &self.ctx
        }
        fn nstates(&self) -> usize {
            2
        }
        fn nout(&self) -> usize {
            2
        }
        fn nparams(&self) -> usize {
            2
        }
    }

    impl LinearOp for FakeLinearOp {
        fn gemv_inplace(&self, x: &Self::V, _t: Self::T, beta: Self::T, y: &mut Self::V) {
            let (x0, x1) = (x.get_index(0), x.get_index(1));
            let ax = vec2(2.0 * x0 + 3.0 * x1, -x0 + 4.0 * x1);
            y.axpy(1.0, &ax, beta);
        }
    }

    impl LinearOpTranspose for FakeLinearOp {
        fn gemv_transpose_inplace(&self, x: &Self::V, _t: Self::T, beta: Self::T, y: &mut Self::V) {
            let (x0, x1) = (x.get_index(0), x.get_index(1));
            let atx = vec2(2.0 * x0 - x1, 3.0 * x0 + 4.0 * x1);
            y.axpy(1.0, &atx, beta);
        }
    }

    impl LinearOpSens for FakeLinearOp {
        fn sens_mul_inplace(&self, _x: &Self::V, _t: Self::T, v: &Self::V, y: &mut Self::V) {
            let (v0, v1) = (v.get_index(0), v.get_index(1));
            y.copy_from(&vec2(v0 + 2.0 * v1, 3.0 * v0 + 4.0 * v1));
        }
    }

    #[test]
    fn linear_op_default_helpers_construct_expected_outputs() {
        let op = FakeLinearOp {
            ctx: NalgebraContext,
        };
        let x = vec2(1.0, 2.0);
        let v = vec2(3.0, -1.0);

        // `call_inplace` must overwrite, not accumulate into, a pre-filled output.
        let mut applied = vec2(1.0, 1.0);
        op.call_inplace(&x, 0.0, &mut applied);
        applied.assert_eq_st(&vec2(8.0, 7.0), 1e-12);

        assert_entries(&op.matrix(0.0), [[2.0, 3.0], [-1.0, 4.0]]);

        let mut transposed = M::zeros(2, 2, NalgebraContext);
        op.transpose_inplace(0.0, &mut transposed);
        assert_entries(&transposed, [[2.0, -1.0], [3.0, 4.0]]);

        let mut transposed_applied = vec2(0.0, 0.0);
        op.call_transpose_inplace(&x, 0.0, &mut transposed_applied);
        transposed_applied.assert_eq_st(&vec2(0.0, 11.0), 1e-12);

        assert_entries(&op.sens(&x, 0.0), [[1.0, 2.0], [3.0, 4.0]]);

        op.sens_mul(&x, 0.0, &v)
            .assert_eq_st(&vec2(1.0, 5.0), 1e-12);
    }
}