use crate::prelude::*;
use crate::problem::SeparableProblem;
use crate::util::Weights;
use nalgebra::{ComplexField, DMatrix, Dyn, OMatrix, OVector, RealField, Scalar};
use num_traits::{float::TotalOrder, Float, Zero};
use std::ops::Mul;
use thiserror::Error as ThisError;
use super::{MultiRhs, RhsType, SingleRhs};
/// Errors that [`SeparableProblemBuilder::build`] can report when the
/// collected inputs do not form a valid problem.
///
/// The human readable message of each variant is given by its `#[error(...)]`
/// attribute.
#[derive(Debug, Clone, ThisError, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum SeparableProblemBuilderError {
/// No observations were handed to the builder before `build` was called.
#[error("Right hand side(s) not provided")]
YDataMissing,
/// The output length of the model (`x_length`) differs from the number of
/// rows of the observations (`y_length`).
#[error(
"Vectors x and y must have same lengths. Given x length = {} and y length = {}",
x_length,
y_length
)]
InvalidLengthOfData { x_length: usize, y_length: usize },
/// The model reports an output length of zero or the observations are empty.
#[error("x or y must have nonzero number of elements.")]
ZeroLengthVector,
/// The number of initial guesses (`provided_count`) does not match the
/// number of model parameters (`model_count`).
///
/// NOTE(review): this variant is not produced anywhere in the visible part
/// of this file; presumably it is used elsewhere or kept for compatibility.
#[error(
"Initial guess vector must have same length as parameters. Model has {} parameters and {} initial guesses were provided.",
model_count,
provided_count
)]
InvalidParameterCount {
model_count: usize,
provided_count: usize,
},
/// The weights do not have a size that fits the number of rows of the
/// observations.
#[error("The weights must have the same length as the data y.")]
InvalidLengthOfWeights,
}
/// Builder that gathers a model, the observations and optional weights and
/// turns them into a [`SeparableProblem`] via `build`.
///
/// # Type parameters
///
/// * `Model`: the [`SeparableNonlinearModel`] the problem is built around.
/// * `Rhs`: a marker type implementing [`RhsType`]. In this file the builder
///   is constructed with [`SingleRhs`] (via `new`) or [`MultiRhs`] (via
///   `mrhs`), which selects the `observations` method that is available.
#[derive(Clone)]
#[allow(non_snake_case)]
pub struct SeparableProblemBuilder<Model, Rhs: RhsType>
where
Model::ScalarType: Scalar + ComplexField + Copy,
<Model::ScalarType as ComplexField>::RealField: Float,
Model: SeparableNonlinearModel,
{
// The observations as a matrix. `None` until one of the `observations`
// methods was called; `build` fails with `YDataMissing` in that case.
Y: Option<DMatrix<Model::ScalarType>>,
// The model that is moved into the finished problem.
separable_model: Model,
// Weights that `build` multiplies onto the observations. Initialised with
// `Weights::default()` and replaced by the `weights` method.
weights: Weights<Model::ScalarType, Dyn>,
// Zero-sized marker that carries the `Rhs` type parameter.
phantom: std::marker::PhantomData<Rhs>,
}
impl<Model> SeparableProblemBuilder<Model, SingleRhs>
where
    Model::ScalarType: Scalar + ComplexField + Zero + Copy,
    <Model::ScalarType as ComplexField>::RealField: Float,
    <<Model as SeparableNonlinearModel>::ScalarType as ComplexField>::RealField:
        Mul<Model::ScalarType, Output = Model::ScalarType> + Float,
    Model: SeparableNonlinearModel,
{
    /// Start building a problem with a single right hand side around `model`.
    ///
    /// The returned builder holds no observations yet and carries
    /// `Weights::default()` as its weights. Observations must be supplied
    /// before `build` can succeed.
    pub fn new(model: Model) -> Self {
        let weights = Weights::default();
        Self {
            Y: None,
            separable_model: model,
            weights,
            phantom: std::marker::PhantomData,
        }
    }
}
impl<Model> SeparableProblemBuilder<Model, SingleRhs>
where
    Model::ScalarType: Scalar + ComplexField + Zero + Copy,
    <Model::ScalarType as ComplexField>::RealField: Float,
    <<Model as SeparableNonlinearModel>::ScalarType as ComplexField>::RealField:
        Mul<Model::ScalarType, Output = Model::ScalarType> + Float,
    Model: SeparableNonlinearModel,
{
    /// Set the observations for a problem with a single right hand side.
    ///
    /// The vector `observed` is reshaped into a matrix with one column and
    /// as many rows as the vector has elements, and stored in the builder.
    /// Any observations set previously are replaced.
    pub fn observations(self, observed: OVector<Model::ScalarType, Dyn>) -> Self {
        let row_count = Dyn(observed.nrows());
        let single_column = observed.reshape_generic(row_count, Dyn(1));

        let mut builder = self;
        builder.Y = Some(single_column);
        builder
    }
}
impl<Model> SeparableProblemBuilder<Model, MultiRhs>
where
    Model::ScalarType: Scalar + ComplexField + Zero + Copy,
    <Model::ScalarType as ComplexField>::RealField: Float,
    <<Model as SeparableNonlinearModel>::ScalarType as ComplexField>::RealField:
        Mul<Model::ScalarType, Output = Model::ScalarType> + Float,
    Model: SeparableNonlinearModel,
{
    /// Start building a problem with multiple right hand sides around `model`.
    ///
    /// The returned builder holds no observations yet and carries
    /// `Weights::default()` as its weights. Observations must be supplied
    /// before `build` can succeed.
    pub fn mrhs(model: Model) -> Self {
        let weights = Weights::default();
        Self {
            Y: None,
            separable_model: model,
            weights,
            phantom: std::marker::PhantomData,
        }
    }
}
impl<Model> SeparableProblemBuilder<Model, MultiRhs>
where
    Model::ScalarType: Scalar + ComplexField + Zero + Copy,
    <Model::ScalarType as ComplexField>::RealField: Float,
    <<Model as SeparableNonlinearModel>::ScalarType as ComplexField>::RealField:
        Mul<Model::ScalarType, Output = Model::ScalarType> + Float,
    Model: SeparableNonlinearModel,
{
    /// Set the observations for a problem with multiple right hand sides.
    ///
    /// The matrix `observed` is stored in the builder unchanged. Any
    /// observations set previously are replaced.
    pub fn observations(self, observed: OMatrix<Model::ScalarType, Dyn, Dyn>) -> Self {
        let mut builder = self;
        builder.Y = Some(observed);
        builder
    }
}
impl<Model, Rhs: RhsType> SeparableProblemBuilder<Model, Rhs>
where
    Model::ScalarType: Scalar + ComplexField + Zero + Copy,
    <Model::ScalarType as ComplexField>::RealField: Float,
    <<Model as SeparableNonlinearModel>::ScalarType as ComplexField>::RealField:
        Mul<Model::ScalarType, Output = Model::ScalarType> + Float,
    Model: SeparableNonlinearModel,
{
    /// Replace the builder's weights with diagonal weights made from `weights`.
    ///
    /// The length of `weights` is not checked here; `build` verifies it
    /// against the number of rows of the observations.
    pub fn weights(self, weights: OVector<Model::ScalarType, Dyn>) -> Self {
        let mut builder = self;
        builder.weights = Weights::diagonal(weights);
        builder
    }

    /// Validate the collected inputs and assemble the [`SeparableProblem`].
    ///
    /// On success the problem holds the model, the weights and the product
    /// of the weights with the observations.
    ///
    /// # Errors
    ///
    /// The checks run in this order and the first failure is returned:
    ///
    /// * [`SeparableProblemBuilderError::YDataMissing`] if no observations
    ///   were set.
    /// * [`SeparableProblemBuilderError::ZeroLengthVector`] if the model's
    ///   output length is zero or the observations are empty.
    /// * [`SeparableProblemBuilderError::InvalidLengthOfData`] if the model's
    ///   output length differs from the number of observation rows.
    /// * [`SeparableProblemBuilderError::InvalidLengthOfWeights`] if the
    ///   weights do not fit the number of observation rows.
    #[allow(non_snake_case)]
    pub fn build(self) -> Result<SeparableProblem<Model, Rhs>, SeparableProblemBuilderError>
    where
        Model::ScalarType: Float + RealField + TotalOrder,
    {
        let Self {
            Y,
            separable_model: model,
            weights,
            ..
        } = self;

        // The presence of observations is checked before the model is queried.
        let Y = match Y {
            Some(observations) => observations,
            None => return Err(SeparableProblemBuilderError::YDataMissing),
        };

        let model_rows: usize = model.output_len();
        let observed_rows = Y.nrows();

        if model_rows == 0 || Y.is_empty() {
            return Err(SeparableProblemBuilderError::ZeroLengthVector);
        }

        if model_rows != observed_rows {
            return Err(SeparableProblemBuilderError::InvalidLengthOfData {
                x_length: model_rows,
                y_length: observed_rows,
            });
        }

        if !weights.is_size_correct_for_data_length(observed_rows) {
            return Err(SeparableProblemBuilderError::InvalidLengthOfWeights);
        }

        // The weights are only borrowed for the product because they are
        // moved into the problem afterwards.
        let Y_w = &weights * Y;

        Ok(SeparableProblem {
            Y_w,
            model,
            weights,
            phantom: Default::default(),
        })
    }
}
#[cfg(any(test, doctest))]
mod test;