use crate::{DriftDerivResult, HyperOperator};
use ndarray::{Array1, Array2};
use std::sync::Arc;
/// First-order terms for a single ψ coordinate, as produced by a
/// `ExactNewtonJointPsiWorkspace` implementation.
///
/// `Debug` is implemented by hand (not derived) so that the opaque
/// `hessian_psi_operator` trait object is printed only as a placeholder.
#[derive(Clone)]
pub struct ExactNewtonJointPsiTerms {
    /// Scalar objective term for this ψ coordinate
    /// (presumably ∂objective/∂ψ — confirm against the producer).
    pub objective_psi: f64,
    /// Score-vector term; `zeros(total)` gives it length `total`.
    pub score_psi: Array1<f64>,
    /// Dense Hessian term; `zeros(total)` gives it shape `total × total`.
    pub hessian_psi: Array2<f64>,
    /// Optional operator representation of the Hessian term, shared via `Arc`
    /// so that cloning this struct does not copy the operator.
    /// NOTE(review): presumably used in place of `hessian_psi` when `Some` — confirm.
    pub hessian_psi_operator: Option<Arc<dyn HyperOperator>>,
}
impl std::fmt::Debug for ExactNewtonJointPsiTerms {
    /// Renders every field. The Hessian operator is an opaque trait object,
    /// so only its presence is reported: `Some("<operator>")` or `None`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let operator_placeholder: Option<&'static str> =
            self.hessian_psi_operator.is_some().then_some("<operator>");

        let mut builder = f.debug_struct("ExactNewtonJointPsiTerms");
        builder.field("objective_psi", &self.objective_psi);
        builder.field("score_psi", &self.score_psi);
        builder.field("hessian_psi", &self.hessian_psi);
        builder.field("hessian_psi_operator", &operator_placeholder);
        builder.finish()
    }
}
impl ExactNewtonJointPsiTerms {
    /// Returns an all-zero set of first-order terms for `total` coordinates.
    ///
    /// The result has:
    /// - `objective_psi` set to `0.0`;
    /// - `score_psi` as a zero vector of length `total`;
    /// - `hessian_psi` as a `total × total` zero matrix;
    /// - no Hessian operator attached.
    ///
    /// `total == 0` is allowed and yields empty arrays.
    pub fn zeros(total: usize) -> Self {
        let score_psi = Array1::zeros(total);
        let hessian_psi = Array2::zeros((total, total));
        Self {
            hessian_psi_operator: None,
            objective_psi: 0.0,
            score_psi,
            hessian_psi,
        }
    }
}
/// Second-order terms for a ψ-coordinate pair, as returned by
/// [`ExactNewtonJointPsiWorkspace::second_order_terms`].
///
/// Unlike the first-order terms, the optional operator is held in a `Box`,
/// so this type is not `Clone`.
pub struct ExactNewtonJointPsiSecondOrderTerms {
    /// Scalar objective term for the pair
    /// (presumably ∂²objective/∂ψ_i∂ψ_j — confirm against the producer).
    pub objective_psi_psi: f64,
    /// Score-vector term for the pair.
    pub score_psi_psi: Array1<f64>,
    /// Dense Hessian term for the pair.
    pub hessian_psi_psi: Array2<f64>,
    /// Optional operator representation of the Hessian term.
    /// NOTE(review): presumably used in place of `hessian_psi_psi` when `Some` — confirm.
    pub hessian_psi_psi_operator: Option<Box<dyn HyperOperator>>,
}

// Hand-written rather than derived, mirroring `ExactNewtonJointPsiTerms`:
// the operator is an opaque trait object, so only its presence is printed.
impl std::fmt::Debug for ExactNewtonJointPsiSecondOrderTerms {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExactNewtonJointPsiSecondOrderTerms")
            .field("objective_psi_psi", &self.objective_psi_psi)
            .field("score_psi_psi", &self.score_psi_psi)
            .field("hessian_psi_psi", &self.hessian_psi_psi)
            .field(
                "hessian_psi_psi_operator",
                &self.hessian_psi_psi_operator.as_ref().map(|_| "<operator>"),
            )
            .finish()
    }
}
/// Second-order terms already contracted against a caller-supplied slice of
/// weights, as returned by `second_order_terms_contracted`.
///
/// NOTE(review): the contraction semantics are not visible in this file;
/// presumably each row/entry corresponds to one ψ coordinate — confirm.
pub struct ExactNewtonJointPsiSecondOrderContracted {
    /// Contracted objective terms (one-dimensional).
    pub objective: Array1<f64>,
    /// Contracted score terms (two-dimensional).
    pub score: Array2<f64>,
    /// Contracted Hessian terms, one `DriftDerivResult` per entry.
    pub hessian: Vec<DriftDerivResult>,
}
/// Workspace that supplies ψ-derivative terms: first-order terms,
/// second-order terms, and Hessian directional derivatives.
///
/// Implementors must be `Send + Sync`. Every method reports failure as an
/// `Err(String)`. `Ok(None)` means the workspace provides no value for that
/// request; this is what the optional methods return by default.
/// NOTE(review): presumably callers fall back to another path on `Ok(None)` —
/// confirm.
pub trait ExactNewtonJointPsiWorkspace: Send + Sync {
    /// First-order terms for the single ψ coordinate at `_psi_index`.
    ///
    /// Default: `Ok(None)`.
    fn first_order_terms(
        &self,
        _psi_index: usize,
    ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
        Ok(None)
    }

    /// First-order terms for every ψ coordinate in one call.
    ///
    /// Default: `Ok(None)`.
    fn first_order_terms_all(&self) -> Result<Option<Vec<ExactNewtonJointPsiTerms>>, String> {
        Ok(None)
    }

    /// Second-order terms for the ψ-coordinate pair (`psi_i`, `psi_j`).
    ///
    /// Required; there is no default implementation.
    fn second_order_terms(
        &self,
        psi_i: usize,
        psi_j: usize,
    ) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String>;

    /// Second-order terms contracted against the weight slice `_weights`.
    ///
    /// Default: `Ok(None)`.
    fn second_order_terms_contracted(
        &self,
        _weights: &[f64],
    ) -> Result<Option<ExactNewtonJointPsiSecondOrderContracted>, String> {
        Ok(None)
    }

    /// Directional derivative of the Hessian for ψ coordinate `psi_index`
    /// along `d_beta_flat`.
    /// NOTE(review): `d_beta_flat` is presumably a flattened coefficient-space
    /// direction — confirm.
    ///
    /// Required; there is no default implementation.
    fn hessian_directional_derivative(
        &self,
        psi_index: usize,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<DriftDerivResult>, String>;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros_has_zero_objective() {
        let t = ExactNewtonJointPsiTerms::zeros(3);
        assert_eq!(t.objective_psi, 0.0);
    }

    #[test]
    fn zeros_has_correct_score_dimension() {
        let t = ExactNewtonJointPsiTerms::zeros(5);
        assert_eq!(t.score_psi.len(), 5);
        assert!(t.score_psi.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn zeros_has_square_hessian_of_correct_size() {
        let t = ExactNewtonJointPsiTerms::zeros(4);
        assert_eq!(t.hessian_psi.nrows(), 4);
        assert_eq!(t.hessian_psi.ncols(), 4);
        assert!(t.hessian_psi.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn zeros_has_no_operator() {
        let t = ExactNewtonJointPsiTerms::zeros(2);
        assert!(t.hessian_psi_operator.is_none());
    }

    #[test]
    fn zeros_with_dimension_zero_does_not_panic() {
        let t = ExactNewtonJointPsiTerms::zeros(0);
        assert_eq!(t.score_psi.len(), 0);
        assert_eq!(t.hessian_psi.nrows(), 0);
    }

    /// The hand-written `Debug` impl must name the struct, render the scalar
    /// field, and report an absent operator as `None`.
    #[test]
    fn debug_output_names_struct_and_reports_missing_operator() {
        let rendered = format!("{:?}", ExactNewtonJointPsiTerms::zeros(2));
        assert!(rendered.starts_with("ExactNewtonJointPsiTerms {"));
        assert!(rendered.contains("objective_psi: 0.0"));
        assert!(rendered.contains("hessian_psi_operator: None"));
    }

    /// `Clone` must deep-copy the arrays, so mutating a clone leaves the
    /// original untouched.
    #[test]
    fn clone_does_not_alias_arrays() {
        let original = ExactNewtonJointPsiTerms::zeros(3);
        let mut copy = original.clone();
        copy.objective_psi = 3.0;
        copy.score_psi[0] = 1.0;
        copy.hessian_psi[[1, 2]] = 2.0;
        assert_eq!(original.objective_psi, 0.0);
        assert_eq!(original.score_psi[0], 0.0);
        assert_eq!(original.hessian_psi[[1, 2]], 0.0);
    }
}