use std::{
error::Error,
fmt::Display,
ops::{Add, Deref, DerefMut, Div, Mul, SubAssign},
};
use nalgebra::DMatrix;
use serde::{Deserialize, Serialize};
use slopes::Slopes;
use crate::{
builders::SourceBuilder,
wavefrontsensor::{
segment_wise::data_processing::{
slopes, slopes_array::SlopesArrayError, TruncatedPseudoInverse,
},
SlopesArray,
},
};
#[derive(Debug)]
#[non_exhaustive]
pub enum CalibrationError {
SlopesArray(SlopesArrayError),
Collect,
}
impl Display for CalibrationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CalibrationError::SlopesArray(_) => f.write_str("error in SlopesArray"),
CalibrationError::Collect => {
f.write_str("failed to flatten Calibration because of DataRef mismatch")
}
}
}
}
impl Error for CalibrationError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
CalibrationError::SlopesArray(e) => Some(e),
_ => None,
}
}
}
impl From<SlopesArrayError> for CalibrationError {
fn from(value: SlopesArrayError) -> Self {
Self::SlopesArray(value)
}
}
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct Calibration {
pub(crate) data: Vec<SlopesArray>,
pub(crate) src: SourceBuilder,
}
impl Deref for Calibration {
type Target = Vec<SlopesArray>;
fn deref(&self) -> &Self::Target {
&self.data
}
}
impl DerefMut for Calibration {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.data
}
}
impl Display for Calibration {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
println!("Calibration:");
for s in self.iter() {
s.fmt(f)?;
}
Ok(())
}
}
impl From<(DMatrix<f32>, Calibration)> for Calibration {
fn from((value, mut cal): (DMatrix<f32>, Calibration)) -> Self {
assert_eq!(cal.data.len(), 1);
let sa = cal.data.pop().unwrap();
Self {
data: vec![(value, sa).into()],
..Default::default()
}
}
}
impl Calibration {
pub fn shape(&self) -> (usize, usize) {
(self.nrows(), self.ncols())
}
pub fn nrows(&self) -> usize {
self.iter().map(|x| x.nrows()).sum()
}
pub fn ncols(&self) -> usize {
self.iter().map(|x| x.ncols()).sum()
}
pub fn size(&self) -> usize {
self.data.len()
}
pub fn masks<'a>(&'a self) -> impl Iterator<Item = Option<&'a nalgebra::DMatrix<bool>>> {
self.data.iter().map(|s| s.data_ref.mask())
}
pub fn reference_slopes(&self) -> Vec<Option<&Vec<f32>>> {
self.data.iter().map(|sa| sa.reference_slopes()).collect()
}
pub fn condition_number(&self, lasts: Option<Vec<Option<usize>>>) -> Vec<f32> {
match lasts {
Some(lasts) => self
.iter()
.zip(lasts.into_iter())
.map(|(x, last)| x.condition_number(last))
.collect(),
None => self.iter().map(|x| x.condition_number(None)).collect(),
}
}
pub fn pseudo_inverse(
&mut self,
truncation: Option<Vec<Option<TruncatedPseudoInverse>>>,
) -> Result<&mut Self, CalibrationError> {
let n = self.size();
self.iter_mut()
.zip(truncation.unwrap_or(vec![None; n]).into_iter())
.map(|(x, t)| x.pseudo_inverse(t))
.collect::<Result<Vec<_>, SlopesArrayError>>()?;
Ok(self)
}
pub fn concat_pinv(&self) -> Vec<f64> {
self.iter()
.filter_map(|x| x.inverse.as_ref().map(|x| x.as_slice().to_vec()))
.flatten()
.map(|x| x as f64)
.collect()
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn interaction_matrices(&self) -> Vec<DMatrix<f32>> {
self.iter().map(|s| s.interaction_matrix()).collect()
}
pub fn trim(&mut self, indices: Vec<(usize, Option<Vec<usize>>)>) -> &mut Self {
let mut count = 0;
for (idx, maybe_idx) in indices.into_iter() {
let i = idx - count;
if let Some(idxs) = maybe_idx {
self.data.get_mut(i).map(|sa| sa.trim(idxs));
} else {
self.data.remove(i);
count += 1;
}
}
self
}
pub fn flatten(self) -> Result<Self, CalibrationError> {
let mut sa_iter = self.data.into_iter();
let SlopesArray {
mut slopes,
data_ref,
..
} = sa_iter.next().unwrap();
for mut sa in sa_iter {
if sa.data_ref == data_ref {
slopes.append(&mut sa.slopes);
} else {
return Err(CalibrationError::Collect);
}
}
Ok(Self {
data: vec![SlopesArray {
slopes,
data_ref,
inverse: None,
}],
..Default::default()
})
}
pub fn insert_rows(&mut self, indices: Vec<(usize, Vec<usize>)>) {
for (sa_idx, rows) in indices.into_iter() {
self.data.get_mut(sa_idx).map(|sa| sa.insert_rows(rows));
}
}
}
impl Add for Calibration {
type Output = Calibration;
fn add(self, rhs: Self) -> Self::Output {
Calibration {
data: self.data.into_iter().chain(rhs.data.into_iter()).collect(),
..Default::default()
}
}
}
impl Mul for Calibration {
type Output = Calibration;
fn mul(self, rhs: Self) -> Self::Output {
Self {
data: self
.interaction_matrices()
.into_iter()
.zip(rhs.interaction_matrices())
.map(|(a, b)| a * b)
.map(|c| SlopesArray::from(c))
.collect(),
..Default::default()
}
}
}
impl Div for Calibration {
type Output = Result<Calibration, Box<dyn Error>>;
fn div(self, rhs: Self) -> Self::Output {
let mut slopes_array: Vec<SlopesArray> = vec![];
for (a, b) in self
.interaction_matrices()
.into_iter()
.zip(rhs.interaction_matrices())
{
let c = a * b.pseudo_inverse(0.)?;
slopes_array.push(SlopesArray::from(c));
}
Ok(Self {
data: slopes_array,
..Default::default()
})
}
}
impl SubAssign for Calibration {
fn sub_assign(&mut self, rhs: Self) {
self.data
.iter_mut()
.zip(rhs.interaction_matrices())
.for_each(|(sa, b)| {
let a = sa.interaction_matrix();
let c = a - b;
let slopes: Vec<_> = c
.column_iter()
.map(|row| Slopes::from(row.as_slice().to_vec()))
.collect();
sa.slopes = slopes;
});
}
}