use std::cell::OnceCell;
use arael_sym::{E, FuncKind, FunctionBag, constant, parse_with_functions, symbol};
/// A user-defined function collected through `inventory` and added to the function registry.
///
/// NOTE(review): entries are presumably submitted by code generated from
/// `#[arael::function]` (the registry's panic messages name that attribute) — confirm.
#[derive(Debug)]
pub struct UserFnEntry {
    /// Name the function is registered under in the `FunctionBag`.
    pub sym_name: &'static str,
    /// Parameter names in positional order; for extern functions, position `i` is rewritten to
    /// the placeholder `__p{i}` when the registry is built.
    pub param_names: &'static [&'static str],
    /// How the function is defined: a parsed symbolic body, or an extern Rust function.
    pub kind: UserFnKind,
}

/// Definition of a [`UserFnEntry`].
#[derive(Debug)]
pub enum UserFnKind {
    /// A function whose body is an expression parsed from source text.
    Symbolic {
        /// Source text of the body, parsed with `parse_with_functions` when the registry is built.
        body_src: &'static str,
        /// Optional hand-written derivative sources. When present, the function is registered
        /// with these derivatives. NOTE(review): presumably one per parameter, in
        /// `param_names` order — confirm.
        deriv_srcs: Option<&'static [&'static str]>,
    },
    /// A function evaluated by native Rust code, with derivatives given as source text.
    Extern {
        /// Derivative sources written in terms of `param_names`. They are rewritten onto the
        /// positional placeholders `__p0`, `__p1`, … when the registry is built.
        deriv_srcs: &'static [&'static str],
        /// Numeric evaluation function. NOTE(review): presumably takes argument values in
        /// `param_names` order — confirm.
        eval_fn: fn(&[f64]) -> f64,
        /// Stored verbatim in the registered `FuncKind::Extern`. NOTE(review): presumably the
        /// Rust path used to call the function from generated code — confirm.
        call_path: &'static str,
    },
}
// Declares `UserFnEntry` as an `inventory`-collected type, so that
// `inventory::iter::<UserFnEntry>` (used by `build_registry_bag`) yields every submitted entry.
// NOTE(review): entries are presumably submitted by code generated from `#[arael::function]` — confirm.
inventory::collect!(UserFnEntry);
// Per-thread cache of the function registry. It starts as an empty `OnceCell` (via a `const`
// initializer, so the thread-local itself needs no lazy initialization). `with_registry_bag`
// fills it on first use, so each thread builds its own `FunctionBag`.
thread_local! {
static REGISTRY_BAG: OnceCell<FunctionBag> = const { OnceCell::new() };
}
/// Calls `f` with a shared borrow of this thread's function registry and returns its result.
///
/// The registry is built by `build_registry_bag` the first time this runs on the current
/// thread. Later calls on the same thread reuse it, and every thread keeps its own copy.
///
/// # Panics
///
/// Panics if building the registry panics, i.e. when a registered body or derivative source
/// fails to parse. `OnceCell::get_or_init` also panics if the build re-enters this function
/// on the same thread.
pub fn with_registry_bag<R>(f: impl FnOnce(&FunctionBag) -> R) -> R {
    REGISTRY_BAG.with(|cell| {
        let bag = cell.get_or_init(build_registry_bag);
        f(bag)
    })
}
pub fn registry_bag() -> FunctionBag {
with_registry_bag(|b| b.clone())
}
/// Builds a `FunctionBag` containing every `UserFnEntry` collected through `inventory`.
///
/// Registration makes two passes over the collected entries.
///
/// 1. Each entry is registered with a stand-in definition: symbolic functions get the bare
///    symbol of their own name as body, and extern functions get all-zero derivatives.
///    This makes every function name and parameter list known before any source is parsed.
/// 2. Each entry's body and derivative sources are parsed with `parse_with_functions`
///    against the bag. The entry is then re-registered with its real definition via
///    `add_with_kind`. NOTE(review): this assumes `add_with_kind` replaces an existing
///    entry of the same name — confirm in `arael_sym`.
///
/// In pass 2, entries are re-registered one at a time. An entry parsed later therefore sees
/// the real definitions of entries already processed and the stand-ins of the rest.
/// NOTE(review): `inventory` does not guarantee iteration order. That only matters if parsing
/// depends on a callee's definition rather than just its name — confirm.
///
/// Extern functions are re-registered with positional parameter names `__p0`, `__p1`, ….
/// Their derivative sources are parsed using the user-facing parameter names and then
/// rewritten onto those placeholders.
///
/// # Panics
///
/// Panics if any body or derivative source fails to parse.
fn build_registry_bag() -> FunctionBag {
    /// `arity` zero constants, used as stand-in derivatives for extern functions in pass 1.
    fn zero_derivs(arity: usize) -> Vec<E> {
        (0..arity).map(|_| constant(0.0)).collect()
    }

    /// Copies borrowed parameter names into the owned `Vec<String>` that registration takes.
    fn owned_params(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| s.to_string()).collect()
    }

    /// Parses `src` against `bag`. On failure, panics with a message naming the function
    /// (`sym_name`) and which kind of source (`what`, e.g. "body" or "deriv") was rejected.
    fn parse_or_panic(src: &str, bag: &FunctionBag, sym_name: &str, what: &str) -> E {
        parse_with_functions(src, bag).unwrap_or_else(|err| {
            panic!("#[arael::function `{sym_name}`] {what} parse failed at runtime: {err}")
        })
    }

    let entries: Vec<&'static UserFnEntry> =
        inventory::iter::<UserFnEntry>.into_iter().collect();
    let mut bag = FunctionBag::new();

    // Pass 1: register stand-ins so every name and parameter list is known before parsing.
    for e in &entries {
        let params = owned_params(e.param_names);
        match &e.kind {
            UserFnKind::Symbolic { .. } => {
                bag.add_symbolic(e.sym_name.to_string(), params, symbol(e.sym_name));
            }
            UserFnKind::Extern { eval_fn, call_path, deriv_srcs } => {
                bag.add_with_kind(
                    e.sym_name.to_string(),
                    params,
                    FuncKind::Extern {
                        derivs: zero_derivs(deriv_srcs.len()),
                        eval_fn: *eval_fn,
                        call_path: call_path.to_string(),
                    },
                );
            }
        }
    }

    // Pass 2: parse real definitions and re-register each entry in turn.
    for e in &entries {
        match &e.kind {
            UserFnKind::Symbolic { body_src, deriv_srcs } => {
                let body = parse_or_panic(body_src, &bag, e.sym_name, "body");
                let kind = match deriv_srcs {
                    None => FuncKind::Symbolic { body },
                    Some(ds) => {
                        let derivs: Vec<E> = ds
                            .iter()
                            .map(|s| parse_or_panic(s, &bag, e.sym_name, "deriv"))
                            .collect();
                        FuncKind::SymbolicDerivs { body, derivs }
                    }
                };
                bag.add_with_kind(e.sym_name.to_string(), owned_params(e.param_names), kind);
            }
            UserFnKind::Extern { deriv_srcs, eval_fn, call_path } => {
                // Derivative sources use the user's parameter names. The registered extern
                // function uses positional placeholders, so rewrite each name to `__p{i}`.
                let placeholder_params: Vec<String> =
                    (0..e.param_names.len()).map(|i| format!("__p{i}")).collect();
                let user_to_placeholder: Vec<(E, E)> = e
                    .param_names
                    .iter()
                    .zip(&placeholder_params)
                    .map(|(p, ph)| (symbol(p), symbol(ph)))
                    .collect();
                let derivs: Vec<E> = deriv_srcs
                    .iter()
                    .map(|s| {
                        parse_or_panic(s, &bag, e.sym_name, "deriv")
                            .substitute(&user_to_placeholder)
                    })
                    .collect();
                bag.add_with_kind(
                    e.sym_name.to_string(),
                    placeholder_params,
                    FuncKind::Extern {
                        derivs,
                        eval_fn: *eval_fn,
                        call_path: call_path.to_string(),
                    },
                );
            }
        }
    }

    bag
}