use std::collections::HashMap;
use crate::params::expr::{Expr, ExprError, parse};
/// A single named parameter.
///
/// A parameter is a free variable (`vary && expr.is_none()`), a fixed value
/// (`!vary`), or a value derived from a constraint expression (`expr` set).
#[derive(Debug, Clone)]
pub struct Param {
    /// Parameter name.
    pub name: String,
    /// Current external (user-facing) value.
    pub value: f64,
    /// Whether this parameter is free; only honored when `expr` is `None`.
    pub vary: bool,
    /// Lower bound; `f64::NEG_INFINITY` means unbounded below.
    pub min: f64,
    /// Upper bound; `f64::INFINITY` means unbounded above.
    pub max: f64,
    /// Source of the constraint expression, if this parameter is derived.
    pub expr: Option<String>,
}

/// Internal coordinates smaller than this in magnitude are snapped to exactly 0.
const TINY: f64 = 1.0e-15;

impl Param {
    /// Creates a varying, unbounded parameter with no expression.
    fn var(name: &str, value: f64) -> Param {
        Param {
            name: name.to_string(),
            value,
            vary: true,
            min: f64::NEG_INFINITY,
            max: f64::INFINITY,
            expr: None,
        }
    }

    /// Maps the current value (clamped into `[min, max]`) to the unconstrained
    /// internal coordinate.
    ///
    /// Transforms (MINUIT/lmfit style):
    /// - no bounds: identity;
    /// - lower bound only: `sqrt((v - min + 1)^2 - 1)`;
    /// - upper bound only: `sqrt((max - v + 1)^2 - 1)`;
    /// - both bounds: `asin(2 (v - min) / (max - min) - 1)`.
    ///
    /// # Panics
    /// Panics (via `f64::clamp`) if `min > max` or either bound is NaN.
    fn to_internal(&self) -> f64 {
        let (min, max) = (self.min, self.max);
        let v = self.value.clamp(min, max);
        let internal = if min == f64::NEG_INFINITY && max == f64::INFINITY {
            v
        } else if max == f64::INFINITY {
            ((v - min + 1.0).powi(2) - 1.0).sqrt()
        } else if min == f64::NEG_INFINITY {
            ((max - v + 1.0).powi(2) - 1.0).sqrt()
        } else if max == min {
            // Degenerate interval: the asin formula would evaluate asin(0/0) = NaN.
            // `to_external` maps every internal value back to `min` here, so 0 is exact.
            0.0
        } else {
            (2.0 * (v - min) / (max - min) - 1.0).asin()
        };
        if internal.abs() < TINY { 0.0 } else { internal }
    }

    /// Inverse of [`Param::to_internal`]: maps an internal coordinate back to an
    /// external value that always lies within `[min, max]`.
    fn to_external(&self, val: f64) -> f64 {
        let (min, max) = (self.min, self.max);
        if min == f64::NEG_INFINITY && max == f64::INFINITY {
            val
        } else if max == f64::INFINITY {
            min - 1.0 + (val * val + 1.0).sqrt()
        } else if min == f64::NEG_INFINITY {
            max + 1.0 - (val * val + 1.0).sqrt()
        } else {
            min + (val.sin() + 1.0) * (max - min) / 2.0
        }
    }

    /// Derivative of [`Param::to_external`] with respect to the internal
    /// coordinate, evaluated at `val` (chain-rule factor for gradients).
    fn scale_gradient(&self, val: f64) -> f64 {
        let (min, max) = (self.min, self.max);
        if min == f64::NEG_INFINITY && max == f64::INFINITY {
            1.0
        } else if max == f64::INFINITY {
            val / (val * val + 1.0).sqrt()
        } else if min == f64::NEG_INFINITY {
            -val / (val * val + 1.0).sqrt()
        } else {
            val.cos() * (max - min) / 2.0
        }
    }
}
/// Failure while resolving constraint expressions.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamError {
    /// The expression belonging to the named parameter failed to parse or evaluate.
    Expr(String, ExprError),
    /// These expression parameters cannot be ordered: each is part of, or
    /// depends on, a circular chain of references.
    Cycle(Vec<String>),
}

impl std::fmt::Display for ParamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Expr(param, cause) => write!(f, "parameter '{}': {}", param, cause),
            Self::Cycle(members) => {
                f.write_str("constraint cycle among: ")?;
                f.write_str(&members.join(", "))
            }
        }
    }
}

impl std::error::Error for ParamError {}
/// An insertion-ordered set of named parameters plus named numeric constants.
///
/// Constraint expressions may reference other parameters and constants by name.
#[derive(Debug, Clone, Default)]
pub struct Parameters {
    /// Parameter names in insertion order; each name appears once.
    order: Vec<String>,
    /// Parameter data keyed by name; holds exactly the names listed in `order`.
    map: HashMap<String, Param>,
    /// Named constants visible to expressions (a parameter of the same name wins).
    consts: HashMap<String, f64>,
}

impl Parameters {
    /// Creates an empty parameter set.
    pub fn new() -> Self {
        Parameters::default()
    }

    /// Adds (or replaces) an unbounded free variable.
    pub fn add_var(&mut self, name: &str, value: f64) {
        self.insert(Param::var(name, value));
    }

    /// Adds (or replaces) a free variable restricted to `[min, max]`.
    ///
    /// The bounds are normalized so that later clamping (which uses `f64::clamp`)
    /// cannot panic: a NaN bound means "no bound" on that side, and bounds
    /// given in reverse order are swapped.
    pub fn add_var_bounded(&mut self, name: &str, value: f64, min: f64, max: f64) {
        let min = if min.is_nan() { f64::NEG_INFINITY } else { min };
        let max = if max.is_nan() { f64::INFINITY } else { max };
        let (min, max) = if min <= max { (min, max) } else { (max, min) };
        let mut p = Param::var(name, value);
        p.min = min;
        p.max = max;
        self.insert(p);
    }

    /// Adds (or replaces) a parameter held at a fixed value.
    pub fn add_fixed(&mut self, name: &str, value: f64) {
        let mut p = Param::var(name, value);
        p.vary = false;
        self.insert(p);
    }

    /// Adds (or replaces) a parameter whose value is computed from `expr`.
    ///
    /// The expression is neither parsed nor evaluated here. The value stays NaN
    /// until [`Parameters::update_constraints`] runs.
    pub fn add_expr(&mut self, name: &str, expr: &str) {
        let mut p = Param::var(name, f64::NAN);
        p.vary = false;
        p.expr = Some(expr.to_string());
        self.insert(p);
    }

    /// Defines (or overwrites) a named constant usable from expressions.
    pub fn set_const(&mut self, name: &str, value: f64) {
        self.consts.insert(name.to_string(), value);
    }

    /// Stores `p`; a re-added name keeps its original position in the order.
    fn insert(&mut self, p: Param) {
        if !self.map.contains_key(&p.name) {
            self.order.push(p.name.clone());
        }
        self.map.insert(p.name.clone(), p);
    }

    /// A parameter is free (optimized) when it varies and has no expression.
    fn is_free(p: &Param) -> bool {
        p.vary && p.expr.is_none()
    }

    /// Free parameter names, in insertion order.
    fn free_names(&self) -> impl Iterator<Item = &String> + '_ {
        self.order.iter().filter(move |n| Self::is_free(&self.map[*n]))
    }

    /// Number of free parameters.
    pub fn n_vary(&self) -> usize {
        self.free_names().count()
    }

    /// Free parameter names in insertion order.
    ///
    /// This order defines the layout of every value, internal-coordinate and
    /// gradient vector used by this type.
    pub fn var_names(&self) -> Vec<String> {
        self.free_names().cloned().collect()
    }

    /// Names of expression (constrained) parameters, in insertion order.
    pub fn expr_names(&self) -> Vec<String> {
        self.order
            .iter()
            .filter(|n| self.map[*n].expr.is_some())
            .cloned()
            .collect()
    }

    /// Current stored value of the named parameter, if it exists.
    pub fn value(&self, name: &str) -> Option<f64> {
        self.map.get(name).map(|p| p.value)
    }

    /// The named parameter, if it exists.
    pub fn get(&self, name: &str) -> Option<&Param> {
        self.map.get(name)
    }

    /// All constants plus every parameter's stored value.
    ///
    /// Parameters override constants of the same name.
    pub fn symbols(&self) -> HashMap<String, f64> {
        let mut m = self.consts.clone();
        for n in &self.order {
            m.insert(n.clone(), self.map[n].value);
        }
        m
    }

    /// Sets free-parameter values in `var_names` order, clamping each into its bounds.
    ///
    /// Surplus values are ignored. If `vals` is short, the remaining parameters
    /// are left unchanged.
    pub fn set_var_values(&mut self, vals: &[f64]) {
        let names = self.var_names();
        for (n, &v) in names.iter().zip(vals) {
            if let Some(p) = self.map.get_mut(n) {
                p.value = v.clamp(p.min, p.max);
            }
        }
    }

    /// Internal (unbounded) coordinates of the free parameters, in `var_names` order.
    pub fn internal_x0(&self) -> Vec<f64> {
        self.free_names()
            .map(|n| self.map[n].to_internal())
            .collect()
    }

    /// Sets free parameters from internal coordinates, in `var_names` order.
    ///
    /// As with [`Parameters::set_var_values`], surplus or missing entries are
    /// tolerated.
    pub fn set_var_internal(&mut self, internal: &[f64]) {
        let names = self.var_names();
        for (n, &v) in names.iter().zip(internal) {
            if let Some(p) = self.map.get_mut(n) {
                p.value = p.to_external(v);
            }
        }
    }

    /// Per-free-parameter derivative of external value with respect to the
    /// internal coordinate, evaluated at `internal` (in `var_names` order).
    pub fn var_scale_gradients(&self, internal: &[f64]) -> Vec<f64> {
        self.free_names()
            .zip(internal)
            .map(|(n, &v)| self.map[n].scale_gradient(v))
            .collect()
    }

    /// Clamps free parameters into their bounds, then evaluates every
    /// expression parameter in dependency order and stores the results.
    ///
    /// # Errors
    /// - `ParamError::Expr` if an expression fails to parse or evaluate.
    /// - `ParamError::Cycle` if expressions reference each other circularly.
    ///
    /// When an error occurs, the free-parameter clamping and any expressions
    /// already evaluated remain applied.
    pub fn update_constraints(&mut self) -> Result<(), ParamError> {
        for p in self.map.values_mut() {
            if Self::is_free(p) {
                p.value = p.value.clamp(p.min, p.max);
            }
        }
        let asts = self.parse_exprs()?;
        let order = self.topo_order(&asts)?;
        // Seed the symbol table with constants and every non-expression parameter.
        let mut sym: HashMap<String, f64> = self.consts.clone();
        for n in &self.order {
            let p = &self.map[n];
            if p.expr.is_none() {
                sym.insert(n.clone(), p.value);
            }
        }
        for n in &order {
            let val = asts[n]
                .eval(&sym)
                .map_err(|e| ParamError::Expr(n.clone(), e))?;
            // Later expressions in `order` may reference this one.
            sym.insert(n.clone(), val);
            if let Some(p) = self.map.get_mut(n) {
                p.value = val;
            }
        }
        Ok(())
    }

    /// Stored value and gradient (with respect to the free parameters, in
    /// `var_names` order) for every parameter.
    ///
    /// A free parameter gets a unit gradient and a fixed parameter gets zeros.
    /// Expression gradients are propagated through `eval_dual`. Values are read
    /// from the stored parameter values (see [`Parameters::symbols`]), so call
    /// [`Parameters::update_constraints`] first if expressions may be stale.
    ///
    /// # Errors
    /// Same as [`Parameters::update_constraints`].
    pub fn value_grads(&self) -> Result<HashMap<String, (f64, Vec<f64>)>, ParamError> {
        // Column of each free parameter within the gradient vectors.
        let index: HashMap<&str, usize> = self
            .free_names()
            .enumerate()
            .map(|(i, n)| (n.as_str(), i))
            .collect();
        let nvar = index.len();
        let sym = self.symbols();
        let mut grads: HashMap<String, Vec<f64>> = HashMap::new();
        for n in &self.order {
            if self.map[n].expr.is_none() {
                let mut g = vec![0.0; nvar];
                if let Some(&i) = index.get(n.as_str()) {
                    g[i] = 1.0;
                }
                grads.insert(n.clone(), g);
            }
        }
        let asts = self.parse_exprs()?;
        for n in &self.topo_order(&asts)? {
            let (_v, g) = asts[n]
                .eval_dual(&sym, &grads, nvar)
                .map_err(|e| ParamError::Expr(n.clone(), e))?;
            grads.insert(n.clone(), g);
        }
        let mut out = HashMap::with_capacity(self.order.len());
        for n in &self.order {
            let g = grads.remove(n).unwrap_or_else(|| vec![0.0; nvar]);
            out.insert(n.clone(), (self.map[n].value, g));
        }
        Ok(out)
    }

    /// Parses the expression of every expression parameter, keyed by name.
    ///
    /// # Errors
    /// `ParamError::Expr` for the first expression, in insertion order, that
    /// fails to parse.
    fn parse_exprs(&self) -> Result<HashMap<String, Expr>, ParamError> {
        let mut asts: HashMap<String, Expr> = HashMap::new();
        for n in &self.order {
            if let Some(src) = &self.map[n].expr {
                let ast = parse(src).map_err(|e| ParamError::Expr(n.clone(), e))?;
                asts.insert(n.clone(), ast);
            }
        }
        Ok(asts)
    }

    /// Orders the expression parameters in `asts` so that each comes after every
    /// expression parameter it references. Ties are broken by insertion order.
    ///
    /// # Errors
    /// `ParamError::Cycle` lists, in insertion order, every expression parameter
    /// that can never be ordered because it is in or depends on a reference cycle.
    fn topo_order(&self, asts: &HashMap<String, Expr>) -> Result<Vec<String>, ParamError> {
        // Only references to other expression parameters matter. Plain
        // parameters and constants are always available.
        let deps: HashMap<&str, Vec<String>> = asts
            .iter()
            .map(|(n, ast)| {
                let mut vars: Vec<String> = Vec::new();
                ast.vars(&mut vars);
                vars.retain(|v| asts.contains_key(v));
                (n.as_str(), vars)
            })
            .collect();
        let mut pending: Vec<&String> = self
            .order
            .iter()
            .filter(|n| asts.contains_key(*n))
            .collect();
        let mut done: std::collections::HashSet<&str> = std::collections::HashSet::new();
        let mut resolved: Vec<String> = Vec::with_capacity(pending.len());
        // Repeated passes in insertion order. Items resolved earlier in a pass
        // can unblock later ones in the same pass. A pass that resolves nothing
        // means the remainder is cyclic.
        while !pending.is_empty() {
            let before = pending.len();
            pending.retain(|&n| {
                let ready = deps[n.as_str()].iter().all(|d| done.contains(d.as_str()));
                if ready {
                    done.insert(n.as_str());
                    resolved.push(n.clone());
                }
                !ready
            });
            if pending.len() == before {
                return Err(ParamError::Cycle(pending.into_iter().cloned().collect()));
            }
        }
        Ok(resolved)
    }
}