use std::borrow::Cow;
/// A single named model parameter together with its box constraints.
///
/// `value` is meant to lie in the closed interval `[lower, upper]`;
/// [`FitParameter::clamp`] projects it back into that interval. Parameters with
/// `fixed == true` are skipped by the "free" views of [`ParameterSet`].
///
/// Equality (`PartialEq`) is field-wise, so — as with any `f64` — a parameter
/// whose `value` is NaN does not compare equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct FitParameter {
    /// Identifier of the parameter (borrowed `'static` literal or owned string).
    pub name: Cow<'static, str>,
    /// Current value of the parameter.
    pub value: f64,
    /// Inclusive lower bound; may be `f64::NEG_INFINITY`.
    pub lower: f64,
    /// Inclusive upper bound; may be `f64::INFINITY`.
    pub upper: f64,
    /// When `true`, the parameter is excluded from the free-parameter views.
    pub fixed: bool,
}
impl FitParameter {
    /// Creates a free parameter constrained to `[0, +inf)`.
    ///
    /// `value` is stored as given; it is not clamped here.
    pub fn non_negative(name: impl Into<Cow<'static, str>>, value: f64) -> Self {
        Self {
            name: name.into(),
            value,
            lower: 0.0,
            upper: f64::INFINITY,
            fixed: false,
        }
    }

    /// Creates a free parameter with no bounds, i.e. `(-inf, +inf)`.
    pub fn unbounded(name: impl Into<Cow<'static, str>>, value: f64) -> Self {
        Self {
            name: name.into(),
            value,
            lower: f64::NEG_INFINITY,
            upper: f64::INFINITY,
            fixed: false,
        }
    }

    /// Creates a parameter marked as fixed (excluded from the free views).
    ///
    /// Its bounds are left unrestricted, `(-inf, +inf)`.
    pub fn fixed(name: impl Into<Cow<'static, str>>, value: f64) -> Self {
        Self {
            name: name.into(),
            value,
            lower: f64::NEG_INFINITY,
            upper: f64::INFINITY,
            fixed: true,
        }
    }

    /// Projects `value` into `[lower, upper]` in place.
    ///
    /// A NaN `value` stays NaN, which is how [`f64::clamp`] behaves.
    ///
    /// # Panics
    ///
    /// Panics if `lower > upper` or either bound is NaN. The message names the
    /// offending parameter.
    pub fn clamp(&mut self) {
        // `f64::clamp` asserts the same condition. It is checked here first so the
        // panic identifies which parameter carries inconsistent bounds.
        assert!(
            self.lower <= self.upper,
            "FitParameter `{}` has invalid bounds [{}, {}]: lower must be <= upper and neither may be NaN",
            self.name,
            self.lower,
            self.upper,
        );
        self.value = self.value.clamp(self.lower, self.upper);
    }
}
/// An ordered collection of [`FitParameter`]s.
///
/// The order of `params` fixes the layout of every value/index vector the set
/// produces. The "free" views skip entries whose `fixed` flag is set.
/// `ParameterSet::default()` yields an empty set.
#[derive(Debug, Clone, Default)]
pub struct ParameterSet {
    /// All parameters, fixed and free, in declaration order.
    pub params: Vec<FitParameter>,
}
impl ParameterSet {
    /// Wraps `params`, keeping their order, as a parameter set.
    pub fn new(params: Vec<FitParameter>) -> Self {
        Self { params }
    }

    /// Yields `(index, parameter)` for every non-fixed parameter, in declaration order.
    ///
    /// This is the single definition of "free" shared by all the free-parameter
    /// views below, so they cannot disagree with each other.
    fn free_entries<'a>(&'a self) -> impl Iterator<Item = (usize, &'a FitParameter)> + 'a {
        self.params.iter().enumerate().filter(|(_, p)| !p.fixed)
    }

    /// Number of parameters that are not fixed.
    pub fn n_free(&self) -> usize {
        self.free_entries().count()
    }

    /// Values of the non-fixed parameters, in declaration order.
    pub fn free_values(&self) -> Vec<f64> {
        self.free_entries().map(|(_, p)| p.value).collect()
    }

    /// Assigns `values` to the non-fixed parameters in declaration order.
    /// Each assigned value is then clamped into that parameter's bounds.
    ///
    /// Fixed parameters are left untouched.
    ///
    /// # Panics
    ///
    /// - Panics if `values.len()` differs from [`ParameterSet::n_free`].
    /// - May also panic from `FitParameter::clamp` if a free parameter has
    ///   `lower > upper` or a NaN bound.
    pub fn set_free_values(&mut self, values: &[f64]) {
        let n_free = self.n_free();
        assert_eq!(
            values.len(),
            n_free,
            "set_free_values: values length ({}) must match number of free parameters ({})",
            values.len(),
            n_free,
        );
        // The assert above guarantees the zip pairs every free parameter with
        // exactly one value, so nothing is silently dropped.
        for (p, &v) in self.params.iter_mut().filter(|p| !p.fixed).zip(values) {
            p.value = v;
            p.clamp();
        }
    }

    /// Values of all parameters (fixed and free), in declaration order.
    pub fn all_values(&self) -> Vec<f64> {
        self.params.iter().map(|p| p.value).collect()
    }

    /// Like [`ParameterSet::all_values`], but overwrites `buf` instead of allocating.
    ///
    /// Any previous contents of `buf` are discarded; its capacity is reused.
    pub fn all_values_into(&self, buf: &mut Vec<f64>) {
        buf.clear();
        buf.extend(self.params.iter().map(|p| p.value));
    }

    /// Like [`ParameterSet::free_values`], but overwrites `buf` instead of allocating.
    ///
    /// Any previous contents of `buf` are discarded; its capacity is reused.
    pub fn free_values_into(&self, buf: &mut Vec<f64>) {
        buf.clear();
        buf.extend(self.free_entries().map(|(_, p)| p.value));
    }

    /// Positions in `params` of the non-fixed parameters, ascending.
    pub fn free_indices(&self) -> Vec<usize> {
        self.free_entries().map(|(i, _)| i).collect()
    }

    /// Like [`ParameterSet::free_indices`], but overwrites `buf` instead of allocating.
    ///
    /// Any previous contents of `buf` are discarded; its capacity is reused.
    pub fn free_indices_into(&self, buf: &mut Vec<usize>) {
        buf.clear();
        buf.extend(self.free_entries().map(|(i, _)| i));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: three parameters with the middle one fixed.
    fn mixed_set() -> ParameterSet {
        ParameterSet::new(vec![
            FitParameter::non_negative("a", 1.0),
            FitParameter::fixed("b", 2.0),
            FitParameter::non_negative("c", 3.0),
        ])
    }

    #[test]
    fn test_parameter_set_free_values() {
        let params = mixed_set();
        assert_eq!(params.n_free(), 2);
        assert_eq!(params.free_values(), vec![1.0, 3.0]);
        assert_eq!(params.all_values(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    #[should_panic(expected = "set_free_values: values length")]
    fn test_set_free_values_length_mismatch() {
        let mut params = ParameterSet::new(vec![
            FitParameter::non_negative("a", 1.0),
            FitParameter::non_negative("b", 2.0),
        ]);
        params.set_free_values(&[1.0]);
    }

    #[test]
    fn test_set_free_values_with_clamping() {
        let mut params = mixed_set();
        params.set_free_values(&[-0.5, 5.0]);
        assert_eq!(params.params[0].value, 0.0);
        assert_eq!(params.params[1].value, 2.0);
        assert_eq!(params.params[2].value, 5.0);
    }

    #[test]
    fn test_free_indices_and_into_variants_reuse_buffers() {
        let params = mixed_set();
        assert_eq!(params.free_indices(), vec![0, 2]);

        // Pre-filled buffers must be cleared, not appended to.
        let mut idx = vec![42usize, 43, 44];
        params.free_indices_into(&mut idx);
        assert_eq!(idx, vec![0, 2]);

        let mut vals = vec![-1.0; 5];
        params.free_values_into(&mut vals);
        assert_eq!(vals, params.free_values());

        params.all_values_into(&mut vals);
        assert_eq!(vals, params.all_values());
    }

    #[test]
    fn test_unbounded_parameter_is_not_clamped() {
        let mut params = ParameterSet::new(vec![FitParameter::unbounded("x", 0.0)]);
        params.set_free_values(&[-1e300]);
        assert_eq!(params.all_values(), vec![-1e300]);
    }

    #[test]
    fn test_all_fixed_set_accepts_empty_update() {
        let mut params = ParameterSet::new(vec![FitParameter::fixed("k", 4.0)]);
        assert_eq!(params.n_free(), 0);
        params.set_free_values(&[]);
        assert_eq!(params.all_values(), vec![4.0]);
        assert!(params.free_indices().is_empty());
    }

    #[test]
    #[should_panic]
    fn test_clamp_with_inverted_bounds_panics() {
        let mut p = FitParameter::unbounded("bad", 1.0);
        p.lower = 2.0;
        p.upper = 1.0;
        p.clamp();
    }
}