use crate::fit::FitResult;
use crate::prelude::*;
use crate::problem::{RhsType, SeparableProblem};
use levenberg_marquardt::LeastSquaresProblem;
use levenberg_marquardt::LevenbergMarquardt;
use nalgebra::{ComplexField, Dyn, RealField, Scalar};
#[cfg(feature = "__lapack")]
use nalgebra_lapack::qr::{QrReal, QrScalar};
#[cfg(feature = "__lapack")]
use num_traits::float::TotalOrder;
#[cfg(feature = "__lapack")]
use num_traits::ConstOne;
#[cfg(feature = "__lapack")]
use num_traits::ConstZero;
use num_traits::{Float, FromPrimitive};
#[cfg(any(test, doctest))]
mod test;
mod levmar_problem;
#[cfg(feature = "__lapack")]
pub use levmar_problem::GeneralQrLinearSolver;
/// A [`GeneralQrLinearSolver`] that is specialized to the
/// [`nalgebra_lapack::ColPivQR`] decomposition (QR with column pivoting) of
/// dynamically sized matrices.
///
/// This alias is only compiled when the internal `__lapack` feature is
/// active. On docs.rs it is annotated as requiring one of the listed
/// `lapack-*` backend features.
#[cfg(feature = "__lapack")]
#[cfg_attr(
docsrs,
doc(cfg(any(
feature = "lapack-netlib",
feature = "lapack-mkl",
feature = "lapack-mkl-static-seq",
feature = "lapack-mkl-static-par",
feature = "lapack-mkl-dynamic-seq",
feature = "lapack-mkl-dynamic-par",
feature = "lapack-openblas",
feature = "lapack-accelerate",
feature = "lapack-custom"
)))
)]
pub type CpqrLinearSolver<ScalarType> =
GeneralQrLinearSolver<ScalarType, nalgebra_lapack::ColPivQR<ScalarType, Dyn, Dyn>>;
/// A [`GeneralQrLinearSolver`] that is specialized to the
/// [`nalgebra_lapack::QR`] decomposition of dynamically sized matrices.
///
/// This alias is only compiled when the internal `__lapack` feature is
/// active. On docs.rs it is annotated as requiring one of the listed
/// `lapack-*` backend features.
#[cfg(feature = "__lapack")]
#[cfg_attr(
docsrs,
doc(cfg(any(
feature = "lapack-netlib",
feature = "lapack-mkl",
feature = "lapack-mkl-static-seq",
feature = "lapack-mkl-static-par",
feature = "lapack-mkl-dynamic-seq",
feature = "lapack-mkl-dynamic-par",
feature = "lapack-openblas",
feature = "lapack-accelerate",
feature = "lapack-custom"
)))
)]
pub type QrLinearSolver<ScalarType> =
GeneralQrLinearSolver<ScalarType, nalgebra_lapack::QR<ScalarType, Dyn, Dyn>>;
pub use levmar_problem::LevMarProblem;
#[cfg(feature = "__lapack")]
pub use levmar_problem::LevMarProblemCpQr;
#[cfg(feature = "__lapack")]
pub use levmar_problem::LevMarProblemQr;
pub use levmar_problem::LevMarProblemSvd;
pub use levmar_problem::LinearSolver;
pub use levmar_problem::SvdLinearSolver;
/// A solver for separable problems that performs the minimization with a
/// [`LevenbergMarquardt`] instance from the `levenberg_marquardt` crate.
///
/// The wrapped minimizer works on the scalar type of the model, i.e.
/// `Model::ScalarType`. Use [`LevMarSolver::with_solver`] to supply a
/// customized minimizer, or the `Default` implementation for a
/// default-constructed one.
#[derive(Debug)]
pub struct LevMarSolver<Model>
where
Model: SeparableNonlinearModel,
{
/// The underlying minimizer that every `solve*` method delegates to.
solver: LevenbergMarquardt<Model::ScalarType>,
}
impl<Model> LevMarSolver<Model>
where
Model: SeparableNonlinearModel,
{
pub fn with_solver(solver: LevenbergMarquardt<Model::ScalarType>) -> Self {
Self { solver }
}
#[allow(clippy::result_large_err)]
pub fn solve_generic<Rhs: RhsType, Solver: LinearSolver<ScalarType = Model::ScalarType>>(
&self,
problem: LevMarProblem<Model, Rhs, Solver>,
) -> Result<FitResult<Model, Rhs>, FitResult<Model, Rhs>>
where
Model: SeparableNonlinearModel,
Model::ScalarType: Scalar + ComplexField + RealField + Float + FromPrimitive,
LevMarProblem<Model, Rhs, Solver>: LeastSquaresProblem<Model::ScalarType, Dyn, Dyn>,
{
let (problem, report) = self.solver.minimize(problem);
let LevMarProblem {
separable_problem,
cached,
} = problem;
let linear_coefficients = cached.map(|cached| cached.linear_coefficients_matrix());
let result = FitResult::new(separable_problem, linear_coefficients, report);
if result.was_successful() {
Ok(result)
} else {
Err(result)
}
}
#[cfg(feature = "__lapack")]
#[cfg_attr(
docsrs,
doc(cfg(any(
feature = "lapack-netlib",
feature = "lapack-mkl",
feature = "lapack-mkl-static-seq",
feature = "lapack-mkl-static-par",
feature = "lapack-mkl-dynamic-seq",
feature = "lapack-mkl-dynamic-par",
feature = "lapack-openblas",
feature = "lapack-accelerate",
feature = "lapack-custom"
)))
)]
#[allow(clippy::result_large_err)]
pub fn solve_with_qr<Rhs: RhsType>(
&self,
problem: SeparableProblem<Model, Rhs>,
) -> Result<FitResult<Model, Rhs>, FitResult<Model, Rhs>>
where
Model: SeparableNonlinearModel,
Model::ScalarType: QrReal
+ QrScalar
+ Scalar
+ ComplexField
+ RealField
+ Float
+ FromPrimitive
+ TotalOrder
+ ConstOne
+ ConstZero,
{
let levmar_problem = LevMarProblemQr::from(problem);
self.solve_generic(levmar_problem)
}
#[cfg(feature = "__lapack")]
#[cfg_attr(
docsrs,
doc(cfg(any(
feature = "lapack-netlib",
feature = "lapack-mkl",
feature = "lapack-mkl-static-seq",
feature = "lapack-mkl-static-par",
feature = "lapack-mkl-dynamic-seq",
feature = "lapack-mkl-dynamic-par",
feature = "lapack-openblas",
feature = "lapack-accelerate",
feature = "lapack-custom"
)))
)]
#[allow(clippy::result_large_err)]
pub fn solve_with_cpqr<Rhs: RhsType>(
&self,
problem: SeparableProblem<Model, Rhs>,
) -> Result<FitResult<Model, Rhs>, FitResult<Model, Rhs>>
where
Model: SeparableNonlinearModel,
Model::ScalarType: QrReal
+ QrScalar
+ Scalar
+ ComplexField
+ RealField
+ Float
+ FromPrimitive
+ TotalOrder
+ ConstOne
+ ConstZero,
{
let levmar_problem = LevMarProblemCpQr::from(problem);
self.solve_generic(levmar_problem)
}
#[allow(clippy::result_large_err)]
pub fn solve_with_svd<Rhs: RhsType>(
&self,
problem: SeparableProblem<Model, Rhs>,
) -> Result<FitResult<Model, Rhs>, FitResult<Model, Rhs>>
where
Model: SeparableNonlinearModel,
Model::ScalarType: Scalar + ComplexField + RealField + Float + FromPrimitive,
{
let levmar_problem = LevMarProblemSvd::from(problem);
self.solve_generic(levmar_problem)
}
#[allow(clippy::result_large_err)]
pub fn solve<Rhs: RhsType>(
&self,
problem: SeparableProblem<Model, Rhs>,
) -> Result<FitResult<Model, Rhs>, FitResult<Model, Rhs>>
where
Model: SeparableNonlinearModel,
Model::ScalarType: Scalar + ComplexField + RealField + Float + FromPrimitive,
{
self.solve_with_svd(problem)
}
#[allow(clippy::result_large_err)]
#[deprecated(since = "0.14.0", note = "use the solve method instead")]
pub fn fit<Rhs: RhsType>(
&self,
problem: SeparableProblem<Model, Rhs>,
) -> Result<FitResult<Model, Rhs>, FitResult<Model, Rhs>>
where
Model: SeparableNonlinearModel,
Model::ScalarType: Scalar + ComplexField + RealField + Float + FromPrimitive,
{
self.solve_with_svd(problem)
}
}
impl<Model> Default for LevMarSolver<Model>
where
    Model: SeparableNonlinearModel,
    Model::ScalarType: RealField + Float,
{
    /// Build a solver around a default-constructed Levenberg-Marquardt
    /// minimizer, i.e. one that has not been customized in any way.
    fn default() -> Self {
        Self {
            solver: Default::default(),
        }
    }
}