use crate::{Value, VectorF64, View};
use ffi::FFI;
use sys;
use sys::libc::{c_int, c_void};
// Declares `MultiRootFSolverType` through the crate's `ffi_wrapper!` macro as a
// wrapper around a `*const sys::gsl_multiroot_fsolver_type` pointer. The trailing
// string literal is handed to the macro as the description of the type
// (presumably emitted as its rustdoc -- confirm against the macro definition).
// NOTE(review): no free function is passed here, unlike the `MultiRootFSolver`
// declaration further down in this file -- presumably because the algorithm
// descriptors are static data owned by GSL; confirm against `ffi_wrapper!`.
ffi_wrapper!(
MultiRootFSolverType,
*const sys::gsl_multiroot_fsolver_type,
"The multiroot algorithms described in this section do not require any derivative information to be
supplied by the user. Any derivatives needed are approximated by finite differences.
Note that if the finite-differencing step size chosen by these routines is inappropriate,
an explicit user-supplied numerical derivative can always be used with
derivative-based algorithms."
);
impl MultiRootFSolverType {
    /// Wraps GSL's `gsl_multiroot_fsolver_hybrids` algorithm descriptor.
    ///
    /// The GSL manual describes this as the Hybrid algorithm with the Jacobian
    /// replaced by a finite-difference approximation.
    #[doc(alias = "gsl_multiroot_fsolver_hybrids")]
    pub fn hybrids() -> Self {
        ffi_wrap!(gsl_multiroot_fsolver_hybrids)
    }

    /// Wraps GSL's `gsl_multiroot_fsolver_hybrid` algorithm descriptor.
    ///
    /// The GSL manual describes this as the finite-difference Hybrid algorithm
    /// without internal scaling.
    #[doc(alias = "gsl_multiroot_fsolver_hybrid")]
    pub fn hybrid() -> Self {
        ffi_wrap!(gsl_multiroot_fsolver_hybrid)
    }

    /// Wraps GSL's `gsl_multiroot_fsolver_dnewton` algorithm descriptor.
    ///
    /// The GSL manual describes this as the discrete Newton method, with the
    /// Jacobian approximated by finite differences.
    #[doc(alias = "gsl_multiroot_fsolver_dnewton")]
    pub fn dnewton() -> Self {
        ffi_wrap!(gsl_multiroot_fsolver_dnewton)
    }

    /// Wraps GSL's `gsl_multiroot_fsolver_broyden` algorithm descriptor.
    ///
    /// The GSL manual describes this as a discrete Newton variant that avoids
    /// recomputing the Jacobian on every iteration.
    #[doc(alias = "gsl_multiroot_fsolver_broyden")]
    pub fn broyden() -> Self {
        ffi_wrap!(gsl_multiroot_fsolver_broyden)
    }
}
// Declares `MultiRootFSolver<'a>` through the crate's `ffi_wrapper!` macro as a
// wrapper around a `*mut sys::gsl_multiroot_fsolver` pointer.
// `gsl_multiroot_fsolver_free` is passed alongside the pointer type (presumably
// the function the generated type uses to release the solver -- confirm against
// the macro definition). The `;name: Type => initial value;` entries add two
// extra fields that `set` fills in.
ffi_wrapper!(
MultiRootFSolver<'a>,
*mut sys::gsl_multiroot_fsolver,
gsl_multiroot_fsolver_free
// `inner_call`: the C function descriptor whose address `set` passes to GSL.
// It starts out empty: no callback, dimension 0 and null `params`.
;inner_call: sys::gsl_multiroot_function_struct => sys::gsl_multiroot_function_struct{ f: None, n: 0, params: std::ptr::null_mut() };
// `inner_closure`: holds the boxed user callback assigned in `set`; `None` until then.
// The `'a` bound ties the solver's lifetime parameter to whatever the closure borrows.
;inner_closure: Option<Box<dyn Fn(&VectorF64, &mut VectorF64) -> Value + 'a>> => None;,
"This is a workspace for multidimensional root-finding without derivatives."
);
impl<'a> MultiRootFSolver<'a> {
    /// Allocates a solver of algorithm type `t` for a system of dimension `n`.
    ///
    /// Returns `None` when `gsl_multiroot_fsolver_alloc` hands back a null pointer.
    #[doc(alias = "gsl_multiroot_fsolver_alloc")]
    pub fn new(t: &MultiRootFSolverType, n: usize) -> Option<MultiRootFSolver<'a>> {
        // SAFETY: `t` wraps an algorithm descriptor pointer obtained from GSL; the
        // returned pointer is checked for null before it is wrapped.
        let ptr = unsafe { sys::gsl_multiroot_fsolver_alloc(t.unwrap_shared(), n) };
        if ptr.is_null() {
            None
        } else {
            Some(MultiRootFSolver::wrap(ptr))
        }
    }

    /// Installs the function `f` and the starting point `x` in the solver.
    ///
    /// * `f` - callback invoked as `f(x, out)`; it reads the current point from `x`,
    ///   writes the function values into `out`, and returns a status [`Value`] that
    ///   is converted to the C status code handed back to GSL.
    /// * `n` - dimension stored in the function descriptor given to GSL.
    /// * `x` - vector passed to `gsl_multiroot_fsolver_set` (the initial guess,
    ///   per the GSL manual).
    ///
    /// The closure is moved onto the heap and kept alive by the solver, so it may
    /// capture state for as long as `'a` allows.
    ///
    /// The status returned by `gsl_multiroot_fsolver_set` is mapped through
    /// `result_handler!` into the returned `Result`.
    #[doc(alias = "gsl_multiroot_fsolver_set")]
    pub fn set<F: Fn(&VectorF64, &mut VectorF64) -> Value + 'a>(
        &mut self,
        f: F,
        n: usize,
        x: &VectorF64,
    ) -> Result<(), Value> {
        // C-ABI trampoline: recovers the Rust closure from `params` and forwards the
        // GSL vectors to it as borrowed (non-owning) wrappers.
        unsafe extern "C" fn inner_f<A: Fn(&VectorF64, &mut VectorF64) -> Value>(
            x: *const sys::gsl_vector,
            params: *mut c_void,
            f: *mut sys::gsl_vector,
        ) -> c_int {
            // SAFETY: `params` is the heap address of the closure of type `A` that
            // `set` boxed and stored in `inner_closure` before registering this
            // trampoline with the same `A`.
            let g: &A = &*(params as *const A);
            let x_new = VectorF64::soft_wrap(x as *const _ as *mut _);
            Value::into(g(&x_new, &mut VectorF64::soft_wrap(f)))
        }

        // Box the closure *before* taking its address. Pointing `params` at the
        // by-value argument would leave GSL with a dangling stack address as soon as
        // the closure is moved into the box / this function returns. The heap
        // allocation keeps its address when the box itself is moved into
        // `inner_closure`.
        let closure: Box<F> = Box::new(f);
        let params = &*closure as *const F as *mut c_void;

        self.inner_call = sys::gsl_multiroot_function_struct {
            f: Some(inner_f::<F>),
            n,
            params,
        };
        // Replacing the previous closure (if any) is fine here: `inner_call` no
        // longer refers to it.
        self.inner_closure = Some(closure);

        // NOTE(review): GSL is given the address of `self.inner_call`; if the C side
        // retains that pointer (the GSL sources appear to), moving this
        // `MultiRootFSolver` after `set` would invalidate it. Confirm how
        // `ffi_wrapper!` lays out `inner_call` before relying on moves.
        //
        // SAFETY: the solver pointer is owned by `self`, `inner_call` was fully
        // initialised above, and `x` is a live vector borrowed for the call.
        let ret = unsafe {
            sys::gsl_multiroot_fsolver_set(
                self.unwrap_unique(),
                &mut self.inner_call,
                x.unwrap_shared(),
            )
        };
        result_handler!(ret, ())
    }

    /// Performs a single solver iteration via `gsl_multiroot_fsolver_iterate`.
    ///
    /// The returned status code is mapped through `result_handler!` into the
    /// returned `Result`.
    #[doc(alias = "gsl_multiroot_fsolver_iterate")]
    pub fn iterate(&mut self) -> Result<(), Value> {
        // SAFETY: the solver pointer is owned by `self` and uniquely borrowed here.
        let ret = unsafe { sys::gsl_multiroot_fsolver_iterate(self.unwrap_unique()) };
        result_handler!(ret, ())
    }

    /// Returns a view of the vector reported by `gsl_multiroot_fsolver_root`
    /// (the current estimate of the root, per the GSL manual).
    ///
    /// The view borrows the solver, so it cannot outlive it.
    #[doc(alias = "gsl_multiroot_fsolver_root")]
    pub fn root(&self) -> View<'_, VectorF64> {
        // SAFETY: the solver pointer is owned by `self`; the view's lifetime is tied
        // to the borrow of `self`.
        unsafe { View::new(sys::gsl_multiroot_fsolver_root(self.unwrap_shared())) }
    }

    /// Returns a view of the vector reported by `gsl_multiroot_fsolver_dx`
    /// (the last step taken by the solver, per the GSL manual).
    ///
    /// The view borrows the solver, so it cannot outlive it.
    #[doc(alias = "gsl_multiroot_fsolver_dx")]
    pub fn dx(&self) -> View<'_, VectorF64> {
        // SAFETY: the solver pointer is owned by `self`; the view's lifetime is tied
        // to the borrow of `self`.
        unsafe { View::new(sys::gsl_multiroot_fsolver_dx(self.unwrap_shared())) }
    }

    /// Returns a view of the vector reported by `gsl_multiroot_fsolver_f`
    /// (the function values at the current root estimate, per the GSL manual).
    ///
    /// The view borrows the solver, so it cannot outlive it.
    #[doc(alias = "gsl_multiroot_fsolver_f")]
    pub fn f(&self) -> View<'_, VectorF64> {
        // SAFETY: the solver pointer is owned by `self`; the view's lifetime is tied
        // to the borrow of `self`.
        unsafe { View::new(sys::gsl_multiroot_fsolver_f(self.unwrap_shared())) }
    }
}
#[cfg(any(test, doctest))]
mod tests {
    use super::*;
    use multiroot::test_residual;
    use VectorF64;

    /// Two-equation test system: `f0 = 1 - x0`, `f1 = x0 - x1^2`.
    ///
    /// Both components vanish at `x0 = 1`, `x1 = ±1`.
    fn rosenbrock_f(x: &VectorF64, f: &mut VectorF64) -> Value {
        f.set(0, 1.0 - x.get(0));
        f.set(1, x.get(0) - x.get(1).powf(2.0));
        Value::Success
    }

    /// Prints the solver's current function values and root estimate.
    ///
    /// Only reads from the solver, so a shared borrow is sufficient.
    fn print_state(solver: &MultiRootFSolver<'_>, iteration: usize) {
        let f = solver.f();
        let x = solver.root();
        println!(
            "iter: {}, f = [{:+.2e}, {:+.2e}], x = [{:+.5}, {:+.5}]",
            iteration,
            f.get(0),
            f.get(1),
            x.get(0),
            x.get(1)
        )
    }

    /// Runs the `hybrid` solver on `rosenbrock_f` and checks that the residual
    /// test reports success within `max_iter` iterations.
    #[test]
    fn test_multiroot_fsolver() {
        let mut multi_root = MultiRootFSolver::new(&MultiRootFSolverType::hybrid(), 2).unwrap();
        let array_size: usize = 2;
        let guess_value = VectorF64::from_slice(&[-10.0, -5.0]).unwrap();
        multi_root
            .set(rosenbrock_f, array_size, &guess_value)
            .unwrap();

        let max_iter: usize = 100;
        let epsabs = 1e-6;
        let mut iter = 0;
        let mut status = Value::Continue;

        // State before any iteration is labelled 0.
        print_state(&multi_root, iter);

        while matches!(status, Value::Continue) && iter < max_iter {
            // Count the iteration first so the printed label matches the number of
            // `iterate` calls performed (and does not repeat the initial 0).
            iter += 1;
            multi_root.iterate().unwrap();
            print_state(&multi_root, iter);

            let f_value = multi_root.f();
            status = test_residual(&f_value, epsabs);
            if matches!(status, Value::Success) {
                println!("Converged");
            }
        }

        assert!(matches!(status, Value::Success));
    }
}