use crate::{BasisError, PenaltyMatrix};
use ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
use std::any::Any;
use std::ops::Range;
use std::sync::Arc;
/// Derivative data for one block with respect to a single `psi` coordinate.
///
/// NOTE(review): the `x_*` / `s_*` field names suggest derivatives of a design
/// matrix (`x`) and a penalty matrix (`s`) with respect to `psi`; the `*_psi_psi`
/// fields presumably hold second derivatives. Confirm against the code that
/// builds these values.
///
/// Besides the dense matrices, a block can carry an `implicit_operator`
/// (see [`CustomFamilyPsiDerivativeOperator`]) together with the `implicit_axis`
/// to pass to it. [`CustomFamilyBlockPsiDerivative::new`] leaves all
/// operator-related fields and both `*_penalty_components` fields unset.
#[derive(Clone)]
pub struct CustomFamilyBlockPsiDerivative {
/// Optional penalty index this block refers to (presumably an index into the
/// block's penalty list — TODO confirm).
pub penalty_index: Option<usize>,
/// Dense first-derivative matrix for the `x` side.
pub x_psi: Array2<f64>,
/// Dense first-derivative matrix for the `s` side.
pub s_psi: Array2<f64>,
/// Optional `(index, matrix)` pairs; presumably a per-penalty split of `s_psi`
/// keyed by penalty index — confirm.
pub s_psi_components: Option<Vec<(usize, Array2<f64>)>>,
/// Same `(index, _)` layout as `s_psi_components`, but holding `PenaltyMatrix`
/// values. `new` sets this to `None`.
pub s_psi_penalty_components: Option<Vec<(usize, PenaltyMatrix)>>,
/// Optional list of second-derivative matrices for the `x` side
/// (presumably one entry per second `psi` coordinate — confirm).
pub x_psi_psi: Option<Vec<Array2<f64>>>,
/// Optional list of second-derivative matrices for the `s` side.
pub s_psi_psi: Option<Vec<Array2<f64>>>,
/// Optional `(index, matrix)` component lists, one inner list per entry of the
/// outer `Vec` (presumably aligned with `s_psi_psi` — confirm).
pub s_psi_psi_components: Option<Vec<Vec<(usize, Array2<f64>)>>>,
/// `PenaltyMatrix` counterpart of `s_psi_psi_components`. `new` sets this to `None`.
pub s_psi_psi_penalty_components: Option<Vec<Vec<(usize, PenaltyMatrix)>>>,
/// Optional shared matrix-free operator; stored behind `Arc` so clones of this
/// struct share it. `new` sets this to `None`.
pub implicit_operator: Option<Arc<dyn CustomFamilyPsiDerivativeOperator>>,
/// Value presumably passed as the `axis` argument of `implicit_operator`'s
/// methods — confirm. `new` sets this to `0`.
pub implicit_axis: usize,
/// Optional group identifier for implicit operators.
/// NOTE(review): grouping semantics are not visible here. `new` sets this to `None`.
pub implicit_group_id: Option<usize>,
}
/// Reference-counted, immutable two-level table of derivative blocks, indexed as
/// `blocks[outer][inner]`; cloning the `Arc` shares the table without copying it.
/// NOTE(review): what the outer and inner levels index (e.g. model block vs. `psi`
/// coordinate) is not visible here — confirm with the producers of this value.
pub type SharedDerivativeBlocks = Arc<Vec<Vec<CustomFamilyBlockPsiDerivative>>>;
impl CustomFamilyBlockPsiDerivative {
    /// Creates a block from explicitly supplied (dense) derivative matrices.
    ///
    /// Only the arguments are stored. Every other field starts out "absent":
    /// - `s_psi_penalty_components` and `s_psi_psi_penalty_components` are `None`;
    /// - `implicit_operator` and `implicit_group_id` are `None`;
    /// - `implicit_axis` is `0`.
    ///
    /// All fields are `pub`, so callers can fill these in on the returned value.
    pub fn new(
        penalty_index: Option<usize>,
        x_psi: Array2<f64>,
        s_psi: Array2<f64>,
        s_psi_components: Option<Vec<(usize, Array2<f64>)>>,
        x_psi_psi: Option<Vec<Array2<f64>>>,
        s_psi_psi: Option<Vec<Array2<f64>>>,
        s_psi_psi_components: Option<Vec<Vec<(usize, Array2<f64>)>>>,
    ) -> Self {
        // Blocks built here never carry an implicit operator, so the axis is unused.
        let implicit_axis = 0;

        Self {
            penalty_index,
            // design-side derivatives
            x_psi,
            x_psi_psi,
            // penalty-side derivatives supplied by the caller
            s_psi,
            s_psi_psi,
            s_psi_components,
            s_psi_psi_components,
            // not provided through this constructor
            s_psi_penalty_components: None,
            s_psi_psi_penalty_components: None,
            implicit_operator: None,
            implicit_axis,
            implicit_group_id: None,
        }
    }
}
/// Matrix-free access to `psi`-derivatives, used in place of dense matrices.
///
/// NOTE(review): the method names (`transpose_mul`, `forward_mul`, `row_chunk_*`,
/// `*_second_diag`, `*_second_cross`) suggest products with first derivatives,
/// same-axis second derivatives and mixed second derivatives of a design matrix.
/// This reading is inferred, not established here — confirm with implementors.
///
/// Most methods select the derivative through `axis`, or through `axis_d` /
/// `axis_e` for mixed second derivatives. The `Send + Sync` supertraits allow an
/// operator to be shared across threads behind `Arc`. The `Any` supertrait, via
/// [`Self::as_any`], allows recovering the concrete type.
pub trait CustomFamilyPsiDerivativeOperator: Send + Sync + Any {
    /// Exposes the operator as `&dyn Any`, e.g. for `downcast_ref`.
    fn as_any(&self) -> &dyn Any;

    /// Number of data rows (presumably the row count addressed by `row_chunk_*`).
    fn n_data(&self) -> usize;

    /// Output dimension (presumably the column count of the derivative matrices).
    fn p_out(&self) -> usize;

    /// Transposed first-derivative product for `axis` applied to `v`.
    fn transpose_mul(
        &self,
        axis: usize,
        v: &ArrayView1<'_, f64>,
    ) -> Result<Array1<f64>, BasisError>;

    /// First-derivative product for `axis` applied to `u`.
    fn forward_mul(&self, axis: usize, u: &ArrayView1<'_, f64>) -> Result<Array1<f64>, BasisError>;

    /// Transposed same-axis second-derivative product for `axis` applied to `v`.
    fn transpose_mul_second_diag(
        &self,
        axis: usize,
        v: &ArrayView1<'_, f64>,
    ) -> Result<Array1<f64>, BasisError>;

    /// Transposed mixed second-derivative product for (`axis_d`, `axis_e`) applied to `v`.
    fn transpose_mul_second_cross(
        &self,
        axis_d: usize,
        axis_e: usize,
        v: &ArrayView1<'_, f64>,
    ) -> Result<Array1<f64>, BasisError>;

    /// Same-axis second-derivative product for `axis` applied to `u`.
    fn forward_mul_second_diag(
        &self,
        axis: usize,
        u: &ArrayView1<'_, f64>,
    ) -> Result<Array1<f64>, BasisError>;

    /// Mixed second-derivative product for (`axis_d`, `axis_e`) applied to `u`.
    fn forward_mul_second_cross(
        &self,
        axis_d: usize,
        axis_e: usize,
        u: &ArrayView1<'_, f64>,
    ) -> Result<Array1<f64>, BasisError>;

    /// Materializes the rows in `rows` of the first derivative for `axis`.
    fn row_chunk_first(&self, axis: usize, rows: Range<usize>) -> Result<Array2<f64>, BasisError>;

    /// Writes row `row` of the first derivative for `axis` into `out`.
    ///
    /// The default implementation asks [`Self::row_chunk_first`] for the one-row
    /// chunk `row..row + 1` and copies that chunk's first row into `out`. It
    /// allocates a temporary `Array2` per call, so implementors with direct row
    /// access may want to override it.
    ///
    /// # Errors
    /// Propagates any error from `row_chunk_first` unchanged; `out` is left
    /// untouched in that case.
    ///
    /// # Panics
    /// The default implementation panics if the returned chunk has no rows, or
    /// if `ndarray`'s `assign` cannot broadcast that row to `out`'s shape.
    /// `row + 1` also overflows when `row == usize::MAX`.
    fn row_vector_first_into(
        &self,
        axis: usize,
        row: usize,
        mut out: ArrayViewMut1<'_, f64>,
    ) -> Result<(), BasisError> {
        self.row_chunk_first(axis, row..row + 1)
            .map(|single_row| out.assign(&single_row.row(0)))
    }

    /// Materializes the rows in `rows` of the same-axis second derivative for `axis`.
    fn row_chunk_second_diag(
        &self,
        axis: usize,
        rows: Range<usize>,
    ) -> Result<Array2<f64>, BasisError>;

    /// Materializes the rows in `rows` of the mixed second derivative for
    /// (`axis_d`, `axis_e`).
    fn row_chunk_second_cross(
        &self,
        axis_d: usize,
        axis_e: usize,
        rows: Range<usize>,
    ) -> Result<Array2<f64>, BasisError>;

    /// Returns `Some` when this operator also implements
    /// [`MaterializablePsiDerivativeOperator`]. Trait-object upcasting has to be
    /// opted into explicitly, so the default answers `None`.
    fn as_materializable(&self) -> Option<&dyn MaterializablePsiDerivativeOperator> {
        Option::None
    }
}
/// A [`CustomFamilyPsiDerivativeOperator`] that can also produce a dense matrix.
///
/// Operators expose this capability through
/// [`CustomFamilyPsiDerivativeOperator::as_materializable`], whose default
/// implementation returns `None`.
pub trait MaterializablePsiDerivativeOperator: CustomFamilyPsiDerivativeOperator {
/// Builds the first derivative for `axis` as a dense `Array2`.
/// NOTE(review): presumably equivalent to `row_chunk_first(axis, 0..n_data())` —
/// confirm with implementors.
///
/// # Errors
/// Returns a `BasisError` if the matrix cannot be produced.
fn materialize_first(&self, axis: usize) -> Result<Array2<f64>, BasisError>;
}
/// Preferred source for the joint Hessian.
///
/// NOTE(review): inferred from the variant names. `Dense` presumably requests an
/// explicitly materialized matrix, and `Operator` a matrix-free operator form.
/// Confirm with the consumers of this enum.
///
/// `Hash` is derived, together with `Eq`, so the preference can be used as a
/// `HashMap` / `HashSet` key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JointHessianSourcePreference {
    /// Prefer an explicitly stored (dense) Hessian.
    Dense,
    /// Prefer an operator (matrix-free) Hessian.
    Operator,
}
/// Why a materialization is being requested.
///
/// NOTE(review): variant meanings are inferred from their names only (an inner
/// solve, a log-determinant factorization, and an outer-loop evaluation or
/// gradient). Confirm with the code that matches on this enum.
///
/// `Hash` is derived, together with `Eq`, so intents can key `HashMap` /
/// `HashSet`, e.g. for per-intent caches or counters.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MaterializationIntent {
    /// Requested for an inner solve.
    InnerSolve,
    /// Requested for a log-determinant factorization.
    LogdetFactorization,
    /// Requested for an outer evaluation.
    OuterEvaluation,
    /// Requested for an outer gradient.
    OuterGradient,
}