#![forbid(unsafe_code)]
use std::{any::Any, sync::Arc};
use sim_kernel::{
ClassId, ClassRef, Cx, DefaultFactory, Expr, Factory, Object, ObjectEncode, ObjectEncoding,
Result as KernelResult, Symbol, Value,
};
use sim_lib_femm_core::{FemmError, FemmLimits, FemmResult, ParamSet, StableId, value_as_f64};
use sim_lib_femm_field::{Field, Projection, field_as_func};
use sim_lib_femm_mesh::FemmModel;
use sim_lib_femm_post::{FemmSolution, QuantitySpec, quantity};
use sim_lib_femm_solve::{GradientTrust, SolveCertificate, SteadySolve, solve_steady};
use sim_lib_numbers_func::{Func, FuncMetadata};
/// Request handed to [`FemmCallable::eval`]: the parameter values, the output
/// to produce, and the solver limits to apply.
#[derive(Clone, Debug)]
pub struct FemmCall {
/// Parameter values supplied by the caller. `ModelCallable` fills in any model
/// input missing from this set with that input's declared default.
pub params: ParamSet,
/// Which output the evaluation should produce.
pub query: OutputQuery,
/// Presumably the parameters a gradient is wanted for.
/// NOTE(review): `ModelCallable::eval` in this file never reads this field and
/// always returns `gradient: None`.
pub want_grad: Option<Vec<Symbol>>,
/// Limits forwarded to `solve_steady` when the query requires a solve.
pub limits: FemmLimits,
}
/// Selects what a FEMM evaluation returns.
#[derive(Clone, Debug)]
pub enum OutputQuery {
/// A scalar quantity. `ModelCallable::eval` evaluates `QuantitySpec::Custom`
/// directly from its expression without solving; every other spec is computed
/// from a solved model via `quantity`.
Quantity(QuantitySpec),
/// A `Field` built from the solved model and the given projection.
Field(Projection),
/// The solved model itself, returned as an opaque value.
Solution,
}
/// Result of [`FemmCallable::eval`].
#[derive(Clone, Debug)]
pub struct FemmEval {
/// The computed output. `ModelCallable::eval` produces a `numbers/f64` number
/// literal for quantity queries and an opaque value for field/solution queries.
pub value: Value,
/// Optional `(parameter, derivative)` pairs.
/// NOTE(review): `ModelCallable::eval` in this file always sets this to `None`.
pub gradient: Option<Vec<(Symbol, f64)>>,
/// Diagnostics attached to the evaluation; `ModelCallable::eval` always
/// returns an empty list.
pub diagnostics: Vec<sim_kernel::Diagnostic>,
}
/// Result of [`quality`]: a scalar quantity together with the solve certificate
/// and, when requested, a gradient.
#[derive(Clone, Debug)]
pub struct QualityAnswer {
/// The quantity evaluated on the supplied solve's solution.
pub value: f64,
/// Clone of the solve's certificate; its gradient trust is updated when a
/// gradient was computed.
pub certificate: SolveCertificate,
/// Derivatives in the same order as the requested `wrt` symbols, paired with
/// the trust level of the method that produced them. `None` when no gradient
/// was requested.
pub gradient: Option<(Vec<f64>, GradientTrust)>,
}
/// Payload that [`femm_as_func`] attaches to the resulting function's metadata,
/// recording what the function was built from.
#[derive(Clone)]
pub struct FemmFuncPayload {
/// The model the function evaluates.
pub model: FemmModel,
/// The function's variables, in positional order.
pub vars: Vec<Symbol>,
/// The output the function produces.
pub query: OutputQuery,
}
impl Object for FemmFuncPayload {
    /// Renders the payload as `#<femm-payload model=<id> query=<description>>`,
    /// where the description comes from `describe_query`.
    fn display(&self, _cx: &mut Cx) -> KernelResult<String> {
        let model_id = &self.model.id.0;
        let query = describe_query(&self.query);
        Ok(format!("#<femm-payload model={model_id} query={query}>"))
    }

    /// Exposes the payload for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
impl sim_kernel::ObjectCompat for FemmFuncPayload {
    /// Returns the class registered under `femm/FuncPayload`, or a stub class
    /// with id 33 from `DefaultFactory` when the registry has no such class.
    fn class(&self, cx: &mut Cx) -> KernelResult<ClassRef> {
        let symbol = Symbol::qualified("femm", "FuncPayload");
        match cx.registry().class_by_symbol(&symbol) {
            Some(class) => Ok(class.clone()),
            None => DefaultFactory.class_stub(ClassId(33), symbol),
        }
    }

    /// Builds the constructor expression for this payload via `sim_citizen`.
    fn as_expr(&self, cx: &mut Cx) -> KernelResult<Expr> {
        sim_citizen::constructor_expr(cx, self)
    }

    /// The payload is its own object encoder.
    fn as_object_encoder(&self) -> Option<&dyn ObjectEncode> {
        Some(self)
    }
}
impl ObjectEncode for FemmFuncPayload {
    /// Encodes the payload as a constructor call on the `femm/FuncPayload`
    /// class with the arguments produced by `payload_constructor_args`.
    fn object_encoding(&self, _cx: &mut Cx) -> KernelResult<ObjectEncoding> {
        let class = func_payload_class_symbol();
        let args = payload_constructor_args(self);
        Ok(ObjectEncoding::Constructor { class, args })
    }
}
impl sim_citizen::Citizen for FemmFuncPayload {
    /// Class symbol under which the payload is known (`femm/FuncPayload`).
    fn citizen_symbol() -> Symbol {
        func_payload_class_symbol()
    }

    /// Encoding version of the payload.
    fn citizen_version() -> u32 {
        const VERSION: u32 = 1;
        VERSION
    }

    /// Number of declared fields.
    ///
    /// NOTE(review): `payload_constructor_args` emits four expressions (a
    /// leading `v1` symbol followed by the three fields); confirm that the
    /// citizen convention does not count that tag towards the arity.
    fn citizen_arity() -> usize {
        const ARITY: usize = 3;
        ARITY
    }

    /// Field names, in constructor order.
    fn citizen_fields() -> &'static [&'static str] {
        const FIELDS: &[&str] = &["model_id", "query", "vars"];
        FIELDS
    }
}
/// Returns the qualified class symbol `femm/FuncPayload`.
fn func_payload_class_symbol() -> Symbol {
    const NAMESPACE: &str = "femm";
    const NAME: &str = "FuncPayload";
    Symbol::qualified(NAMESPACE, NAME)
}
/// Builds the constructor argument list for a payload: the symbol `v1`, the
/// model id as an integer literal, the query description as a string, and the
/// variable names as a list of strings.
fn payload_constructor_args(payload: &FemmFuncPayload) -> Vec<Expr> {
    let version_tag = Expr::Symbol(Symbol::new("v1"));
    let model_id = int_expr(payload.model.id.0);
    let query = Expr::String(describe_query(&payload.query));
    let var_names = payload
        .vars
        .iter()
        .map(|name| Expr::String(name.to_string()))
        .collect();
    vec![version_tag, model_id, query, Expr::List(var_names)]
}
/// Wraps `value` as a number literal in the `citizen/int` domain, using its
/// `to_string()` output as the canonical text.
fn int_expr(value: impl ToString) -> Expr {
    let domain = Symbol::qualified("citizen", "int");
    let canonical = value.to_string();
    Expr::Number(sim_kernel::NumberLiteral { domain, canonical })
}
/// Something that can evaluate a [`FemmCall`] into a [`FemmEval`].
pub trait FemmCallable {
/// Evaluates `call` in the context `cx`, returning the requested output or a
/// `FemmError`.
fn eval(&self, cx: &mut Cx, call: FemmCall) -> FemmResult<FemmEval>;
}
/// [`FemmCallable`] backed by a single [`FemmModel`]; queries that need a
/// solution run `solve_steady` on this model.
#[derive(Clone)]
pub struct ModelCallable {
/// The model that is solved and whose inputs supply parameter defaults.
pub model: FemmModel,
}
impl ModelCallable {
    /// Completes `params` with the defaults of every model input that the
    /// caller did not bind.
    ///
    /// # Errors
    /// Returns `FemmError::UnknownFemmParameter` naming the first model input
    /// that is neither bound in `params` nor has a default.
    fn resolve_params(&self, params: &ParamSet) -> FemmResult<ParamSet> {
        let mut resolved = params.entries.clone();
        for input in &self.model.inputs {
            // Defaults pushed earlier in this loop count as bound too.
            let already_bound = resolved.iter().any(|(name, _)| name == &input.name);
            if already_bound {
                continue;
            }
            match &input.default {
                Some(default) => resolved.push((input.name.clone(), default.clone())),
                None => return Err(FemmError::UnknownFemmParameter(input.name.to_string())),
            }
        }
        Ok(ParamSet::new(resolved))
    }

    /// Resolves `params` against the model inputs, runs `solve_steady` with the
    /// given `limits`, and returns the resulting solution.
    fn solve_solution(
        &self,
        cx: &mut Cx,
        params: &ParamSet,
        limits: &FemmLimits,
    ) -> FemmResult<Arc<FemmSolution>> {
        let resolved = self.resolve_params(params)?;
        let outcome = solve_steady(cx, &self.model, &resolved, limits, None)?;
        Ok(outcome.solution)
    }
}
impl FemmCallable for ModelCallable {
    /// Evaluates `call` against this model.
    ///
    /// * `Quantity(Custom { .. })` is evaluated straight from its expression
    ///   and the resolved parameters; no solve is run.
    /// * Any other `Quantity` solves the model and evaluates the quantity on
    ///   the solution.
    /// * `Field` solves the model and returns an opaque `Field` value.
    /// * `Solution` solves the model and returns the solution as an opaque
    ///   value.
    ///
    /// The returned `gradient` is always `None` and `diagnostics` is always
    /// empty; `call.want_grad` is not consulted.
    ///
    /// # Errors
    /// Propagates parameter-resolution, expression, solve and quantity errors.
    /// Failures of the value factory are reported as
    /// `FemmError::SensitivityUnavailable` carrying the factory's message.
    fn eval(&self, cx: &mut Cx, call: FemmCall) -> FemmResult<FemmEval> {
        let resolved = self.resolve_params(&call.params)?;
        let value = match call.query {
            OutputQuery::Quantity(QuantitySpec::Custom { expr, .. }) => {
                let scalar = sim_lib_femm_geometry::eval_expr_f64(cx, &expr, &resolved, &[])?;
                cx.factory()
                    .number_literal(Symbol::qualified("numbers", "f64"), scalar.to_string())
                    .map_err(|err| FemmError::SensitivityUnavailable(err.to_string()))?
            }
            OutputQuery::Quantity(spec) => {
                let solution = self.solve_solution(cx, &resolved, &call.limits)?;
                let scalar = quantity(&solution, &spec)?;
                cx.factory()
                    .number_literal(Symbol::qualified("numbers", "f64"), scalar.to_string())
                    .map_err(|err| FemmError::SensitivityUnavailable(err.to_string()))?
            }
            OutputQuery::Field(projection) => {
                let solution = self.solve_solution(cx, &resolved, &call.limits)?;
                let field = Field::new(solution, projection);
                cx.factory()
                    .opaque(Arc::new(field))
                    .map_err(|err| FemmError::SensitivityUnavailable(err.to_string()))?
            }
            OutputQuery::Solution => {
                let solution = self.solve_solution(cx, &resolved, &call.limits)?;
                cx.factory()
                    .opaque(solution)
                    .map_err(|err| FemmError::SensitivityUnavailable(err.to_string()))?
            }
        };
        Ok(FemmEval {
            value,
            gradient: None,
            diagnostics: Vec::new(),
        })
    }
}
/// Evaluates `quantity_spec` on an existing solve and, when `wrt` is given,
/// its finite-difference gradient with respect to those parameters.
///
/// The returned certificate is a clone of `solve.certificate`; when a gradient
/// is computed its gradient trust is set to the trust reported for that
/// gradient.
///
/// # Errors
/// Propagates errors from evaluating the quantity and from the
/// finite-difference gradient (unknown or non-finite parameters, failed
/// solves).
pub fn quality(
    cx: &mut Cx,
    solve: &SteadySolve,
    quantity_spec: &QuantitySpec,
    wrt: Option<&[Symbol]>,
) -> FemmResult<QualityAnswer> {
    let value = quantity(&solve.solution, quantity_spec)?;
    let mut certificate = solve.certificate.clone();
    let gradient = if let Some(params) = wrt {
        let (values, trust) =
            finite_difference_quality_gradient(cx, solve, quantity_spec, params)?;
        certificate.set_gradient_trust(trust.clone());
        Some((values, trust))
    } else {
        None
    };
    Ok(QualityAnswer {
        value,
        certificate,
        gradient,
    })
}
/// Central-difference gradient of `quantity_spec` with respect to each symbol
/// in `wrt`, around the parameters of `solve.solution`.
///
/// For every symbol the model is re-solved twice (at `x + h` and `x - h`, with
/// `h` from `fd_step`), so the cost is two solves per requested parameter.
/// The result is always tagged `GradientTrust::FiniteDifferenceOnly`.
///
/// # Errors
/// * `FemmError::UnknownFemmParameter` if a symbol is absent from the resolved
///   parameters.
/// * `FemmError::SensitivityUnavailable` if a parameter's current value is not
///   finite.
/// * Any error from value conversion or from the perturbed solves.
fn finite_difference_quality_gradient(
    cx: &mut Cx,
    solve: &SteadySolve,
    quantity_spec: &QuantitySpec,
    wrt: &[Symbol],
) -> FemmResult<(Vec<f64>, GradientTrust)> {
    let callable = ModelCallable {
        model: solve.model.clone(),
    };
    let base_params = callable.resolve_params(&solve.solution.params)?;
    let gradient = wrt
        .iter()
        .map(|symbol| -> FemmResult<f64> {
            let x = match base_params.get(symbol) {
                Some(base_value) => value_as_f64(cx, base_value)?,
                None => return Err(FemmError::UnknownFemmParameter(symbol.to_string())),
            };
            if !x.is_finite() {
                return Err(FemmError::SensitivityUnavailable(format!(
                    "non-finite FEMM parameter {symbol}"
                )));
            }
            let h = fd_step(x);
            let plus = replace_param_value(cx, &base_params, symbol, x + h)?;
            let minus = replace_param_value(cx, &base_params, symbol, x - h)?;
            let plus_value = quality_at_params(cx, &solve.model, plus, quantity_spec)?;
            let minus_value = quality_at_params(cx, &solve.model, minus, quantity_spec)?;
            Ok((plus_value - minus_value) / (2.0 * h))
        })
        // Collecting into a `Result` stops at the first failing parameter.
        .collect::<FemmResult<Vec<f64>>>()?;
    Ok((gradient, GradientTrust::FiniteDifferenceOnly))
}
/// Solves `model` at `params` and evaluates `quantity_spec` on the result.
///
/// The solve uses `FemmLimits::default()` and passes `None` as the final
/// `solve_steady` argument; `params` is used as given, without filling in
/// model defaults.
///
/// # Errors
/// Propagates any error from `solve_steady` or `quantity`.
fn quality_at_params(
    cx: &mut Cx,
    model: &FemmModel,
    params: ParamSet,
    quantity_spec: &QuantitySpec,
) -> FemmResult<f64> {
    // The parameter set is passed by reference (`&params`).
    let solved = solve_steady(cx, model, &params, &FemmLimits::default(), None)?;
    quantity(&solved.solution, quantity_spec)
}
/// Returns a copy of `params` in which every entry named `name` holds `value`
/// as a `numbers/f64` number literal.
///
/// # Errors
/// * `FemmError::UnknownFemmParameter` if no entry is named `name`.
/// * `FemmError::SensitivityUnavailable` if the value factory fails to build
///   the literal.
fn replace_param_value(
    cx: &mut Cx,
    params: &ParamSet,
    name: &Symbol,
    value: f64,
) -> FemmResult<ParamSet> {
    let mut entries = params.entries.clone();
    let mut replaced = false;
    for entry in entries.iter_mut() {
        if entry.0 != *name {
            continue;
        }
        entry.1 = cx
            .factory()
            .number_literal(Symbol::qualified("numbers", "f64"), value.to_string())
            .map_err(|err| FemmError::SensitivityUnavailable(err.to_string()))?;
        replaced = true;
    }
    if !replaced {
        return Err(FemmError::UnknownFemmParameter(name.to_string()));
    }
    Ok(ParamSet::new(entries))
}
/// Finite-difference step for a parameter currently at `value`: `1e-6` times
/// the larger of `|value|` and `1.0`, so the step is relative for large
/// magnitudes and never smaller than `1e-6`.
fn fd_step(value: f64) -> f64 {
    const RELATIVE_STEP: f64 = 1.0e-6;
    let scale = value.abs().max(1.0);
    RELATIVE_STEP * scale
}
/// Wraps a FEMM model as a [`Func`] over `vars` that produces `query`.
///
/// Calling the function pairs `vars` with the positional arguments, evaluates
/// the model through [`ModelCallable`] with `FemmLimits::default()` and no
/// gradient request, and returns the resulting value. Arguments are paired
/// with `zip`, so surplus arguments are ignored and missing ones fall back to
/// the model input defaults (or fail in parameter resolution).
///
/// The metadata records the source `femm/model`, the differentiator hint
/// `femm-adjoint`, and a [`FemmFuncPayload`] describing the model, variables
/// and query. If the factory cannot wrap the payload, the payload is `None`.
pub fn femm_as_func(model: FemmModel, vars: Vec<Symbol>, query: OutputQuery) -> Func {
    let callable = ModelCallable {
        model: model.clone(),
    };
    let closure_vars = vars.clone();
    let closure_query = query.clone();
    // `model` and `query` are not needed after this point, so they are moved
    // into the payload instead of being cloned a second time.
    let payload = DefaultFactory
        .opaque(Arc::new(FemmFuncPayload {
            model,
            vars: vars.clone(),
            query,
        }))
        .ok();
    Func {
        vars,
        body_cas: None,
        body_native: Some(Arc::new(move |cx, args| {
            let params = ParamSet::new(
                closure_vars
                    .iter()
                    .cloned()
                    .zip(args.iter().cloned())
                    .collect::<Vec<_>>(),
            );
            callable
                .eval(
                    cx,
                    FemmCall {
                        params,
                        query: closure_query.clone(),
                        want_grad: None,
                        limits: FemmLimits::default(),
                    },
                )
                .map(|out| out.value)
                .map_err(sim_kernel::Error::from)
        })),
        metadata: FuncMetadata {
            source: Some(Symbol::qualified("femm", "model")),
            differentiator_hint: Some(Symbol::new("femm-adjoint")),
            payload,
        },
    }
}
/// Builds a potential-field [`Func`] for `model` from a hard-coded solution.
///
/// No solve is run: the solution is a single triangle with vertices `(0,0)`,
/// `(1,0)`, `(0,1)` in region `air`, nodal values `[0, 1, 1]`, and diagnostics
/// that report a converged one-iteration `femm-ptc` run. Only the model's id,
/// physics and formulation are taken from `model`; the solution id is the
/// model id plus one.
///
/// NOTE(review): `model.id.0 + 1` is unchecked arithmetic; confirm ids can
/// never sit at the type's maximum.
pub fn femm_field_func(model: FemmModel) -> Func {
    let mesh = sim_lib_femm_mesh::FemMesh2 {
        xy: vec![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
        tri: vec![[0, 1, 2]],
        elem_region: vec![Symbol::new("air")],
        edge_boundary: Vec::new(),
    };
    let diagnostics = sim_lib_femm_flow::SolveDiagnostics {
        method: Symbol::new("femm-ptc"),
        converged: true,
        iterations: 1,
        final_residual: 0.0,
        events: Vec::new(),
        diagnostics: Vec::new(),
    };
    let solution = FemmSolution {
        id: StableId(model.id.0 + 1),
        model_id: model.id,
        physics: model.physics.clone(),
        formulation: model.formulation.clone(),
        params: ParamSet::default(),
        mesh,
        u: vec![0.0, 1.0, 1.0],
        diagnostics,
    };
    field_as_func(Field::new(Arc::new(solution), Projection::Potential))
}
/// Short textual description of a query: `quantity:<name>` for a custom
/// quantity, `quantity` for any other quantity, `field:<projection debug>` for
/// a field, and `solution` for the full solution.
pub(crate) fn describe_query(query: &OutputQuery) -> String {
    match query {
        OutputQuery::Solution => String::from("solution"),
        OutputQuery::Field(projection) => format!("field:{projection:?}"),
        OutputQuery::Quantity(spec) => match spec {
            QuantitySpec::Custom { name, .. } => format!("quantity:{name}"),
            // Non-custom specs are not distinguished in the description.
            _ => String::from("quantity"),
        },
    }
}