use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use pharmsol_dsl::execution::{
ExecutionExpr, ExecutionExprKind, ExecutionLoad, ExecutionModel, ExecutionStmt,
ExecutionStmtKind, KernelImplementation, KernelRole,
};
use pharmsol_dsl::{AnalyticalKernel, CovariateInterpolation, ModelKind, RouteKind};
/// Serializable summary of a lowered `ExecutionModel`: names, layout offsets and ABI buffer
/// lengths needed by native consumers. Built by `NativeModelInfo::from_execution_model`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct NativeModelInfo {
/// Model name, copied from the execution model.
pub name: String,
/// Model kind (ODE, SDE or analytical).
pub kind: ModelKind,
/// Parameter names, in the order of `model.metadata.parameters`.
pub parameters: Vec<String>,
/// Derived-value names, in the order of `model.metadata.derived`.
/// Deserializes to an empty list when the field is absent.
#[serde(default)]
pub derived: Vec<String>,
/// Covariate descriptors, in metadata order.
pub covariates: Vec<NativeCovariateInfo>,
/// State descriptors, in metadata order.
pub states: Vec<NativeStateInfo>,
/// One entry per route declaration. Several declarations may share one route-buffer slot
/// (see `NativeRouteInfo::index`).
pub routes: Vec<NativeRouteInfo>,
/// Output descriptors, in metadata order.
pub outputs: Vec<NativeOutputInfo>,
/// Length of the ABI state buffer.
pub state_len: usize,
/// Length of the ABI derived buffer.
/// NOTE(review): unlike `derived`, this has no `#[serde(default)]`, so payloads missing both
/// fields still fail to deserialize — confirm that is intended.
pub derived_len: usize,
/// Length of the ABI output buffer.
pub output_len: usize,
/// Length of the ABI route buffer; may be smaller than `routes.len()` when declarations
/// share an input slot.
pub route_len: usize,
/// Analytical kernel recorded in the model metadata, if any.
pub analytical: Option<AnalyticalKernel>,
/// Particle setting from the model metadata (presumably the SDE particle count — confirm).
pub particles: Option<usize>,
}
/// Native view of one declared covariate.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct NativeCovariateInfo {
/// Covariate name as declared.
pub name: String,
/// Covariate index copied from the model metadata.
pub index: usize,
/// Declared interpolation mode (e.g. `wt@linear` yields `Linear`), or `None` if unspecified.
pub interpolation: Option<CovariateInterpolation>,
}
/// Native view of one declared state.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct NativeStateInfo {
/// State name as declared.
pub name: String,
/// State offset copied from the model metadata (presumably into the state buffer — confirm).
pub offset: usize,
}
/// Native view of one route declaration.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct NativeRouteInfo {
/// Route name as declared (e.g. `iv`, `oral`, `input_10`).
pub name: String,
/// Position of this declaration among all route declarations.
pub declaration_index: usize,
/// Route-buffer slot; distinct declarations (e.g. a bolus and an infusion) may share it.
pub index: usize,
/// Declared route kind (bolus/infusion), or `None` when the declaration does not state one.
pub kind: Option<RouteKind>,
/// State offset of the route's destination state.
pub destination_offset: usize,
/// Name of the route's destination state.
pub destination_name: String,
/// Whether a lag time is declared for this route (e.g. `lag(oral) = ...`).
pub has_lag: bool,
/// Whether a bioavailability is declared for this route (e.g. `fa(oral) = ...`).
pub has_bioavailability: bool,
/// `true` unless the model's dynamics (ODE) or drift (SDE) kernel reads this route's input
/// explicitly, in which case the input must not also be injected automatically.
pub inject_input_to_destination: bool,
}
/// Native view of one declared output.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct NativeOutputInfo {
/// Output name as declared (e.g. `cp`, `outeq_2`).
pub name: String,
/// Output index copied from the model metadata.
pub index: usize,
}
impl NativeModelInfo {
pub fn from_execution_model(model: &ExecutionModel) -> Self {
let explicit_route_input_usage = explicit_route_input_usage(model);
Self {
name: model.name.clone(),
kind: model.kind,
parameters: model
.metadata
.parameters
.iter()
.map(|parameter| parameter.name.clone())
.collect(),
derived: model
.metadata
.derived
.iter()
.map(|derived| derived.name.clone())
.collect(),
covariates: model
.metadata
.covariates
.iter()
.map(|covariate| NativeCovariateInfo {
name: covariate.name.clone(),
index: covariate.index,
interpolation: covariate.interpolation,
})
.collect(),
states: model
.metadata
.states
.iter()
.map(|state| NativeStateInfo {
name: state.name.clone(),
offset: state.offset,
})
.collect(),
routes: model
.metadata
.routes
.iter()
.map(|route| NativeRouteInfo {
name: route.name.clone(),
declaration_index: route.declaration_index,
index: route.index,
kind: route.kind,
destination_offset: route.destination.state_offset,
destination_name: route.destination.state_name.clone(),
has_lag: route.has_lag,
has_bioavailability: route.has_bioavailability,
inject_input_to_destination: !explicit_route_input_usage
.get(route.declaration_index)
.copied()
.unwrap_or(false),
})
.collect(),
outputs: model
.metadata
.outputs
.iter()
.map(|output| NativeOutputInfo {
name: output.name.clone(),
index: output.index,
})
.collect(),
state_len: model.abi.state_buffer.len,
derived_len: model.abi.derived_buffer.len,
output_len: model.abi.output_buffer.len,
route_len: model.abi.route_buffer.len,
analytical: model.metadata.analytical,
particles: model.metadata.particles,
}
}
}
/// Flags, per route declaration, whether the model's kernel reads that route's input explicitly.
///
/// The result has one entry per element of `model.metadata.routes`, addressed by
/// `declaration_index`. An entry is `true` when a `RouteInput` load for that route appears in
/// the ODE dynamics kernel or the SDE drift kernel. Analytical models, models lacking that
/// kernel, and kernels not implemented as statement programs yield all-`false`.
fn explicit_route_input_usage(model: &ExecutionModel) -> Vec<bool> {
    let mut usage = vec![false; model.metadata.routes.len()];

    // Only the kernel that integrates route inputs is inspected.
    let scanned_role = match model.kind {
        ModelKind::Ode => Some(KernelRole::Dynamics),
        ModelKind::Sde => Some(KernelRole::Drift),
        ModelKind::Analytical => None,
    };
    let Some(kernel) = scanned_role.and_then(|role| model.kernel(role)) else {
        return usage;
    };
    let KernelImplementation::Statements(program) = &kernel.implementation else {
        return usage;
    };

    // Route symbol -> declaration index, i.e. the slot in `usage` to flag.
    let declaration_slots: BTreeMap<usize, usize> = model
        .metadata
        .routes
        .iter()
        .map(|route| (route.symbol, route.declaration_index))
        .collect();

    mark_route_inputs_in_statements(&program.body.statements, &declaration_slots, &mut usage);
    usage
}
/// Flags in `usage` every route whose input is loaded anywhere inside `statements`.
///
/// Every expression reachable from the statements is handed to `mark_route_inputs_in_expr`:
/// `let`/assignment values, `if` conditions and both branches, and `for` range bounds and
/// bodies. `declaration_slots` maps a route symbol to its declaration index, which is the
/// position flagged in `usage`. Nested blocks are visited through an explicit work list.
fn mark_route_inputs_in_statements(
    statements: &[ExecutionStmt],
    declaration_slots: &BTreeMap<usize, usize>,
    usage: &mut [bool],
) {
    let mut pending: Vec<&[ExecutionStmt]> = vec![statements];
    while let Some(block) = pending.pop() {
        for stmt in block {
            match &stmt.kind {
                ExecutionStmtKind::Let(binding) => {
                    mark_route_inputs_in_expr(&binding.value, declaration_slots, usage);
                }
                ExecutionStmtKind::Assign(assignment) => {
                    mark_route_inputs_in_expr(&assignment.value, declaration_slots, usage);
                }
                ExecutionStmtKind::If(branch) => {
                    mark_route_inputs_in_expr(&branch.condition, declaration_slots, usage);
                    pending.push(&branch.then_branch);
                    if let Some(otherwise) = &branch.else_branch {
                        pending.push(otherwise);
                    }
                }
                ExecutionStmtKind::For(looped) => {
                    mark_route_inputs_in_expr(&looped.range.start, declaration_slots, usage);
                    mark_route_inputs_in_expr(&looped.range.end, declaration_slots, usage);
                    pending.push(&looped.body);
                }
            }
        }
    }
}
/// Flags in `usage` the declaration slot of every route whose input is loaded within `expr`.
///
/// Recurses through unary operands, both binary operands and call arguments. A `RouteInput`
/// load whose symbol is missing from `declaration_slots`, or whose declaration index is out
/// of range for `usage`, is ignored. Literals and all other loads contribute nothing.
fn mark_route_inputs_in_expr(
    expr: &ExecutionExpr,
    declaration_slots: &BTreeMap<usize, usize>,
    usage: &mut [bool],
) {
    match &expr.kind {
        ExecutionExprKind::Load(ExecutionLoad::RouteInput { route, .. }) => {
            let flag = declaration_slots
                .get(route)
                .copied()
                .and_then(|slot| usage.get_mut(slot));
            if let Some(flag) = flag {
                *flag = true;
            }
        }
        ExecutionExprKind::Literal(_) | ExecutionExprKind::Load(_) => {}
        ExecutionExprKind::Unary { expr: operand, .. } => {
            mark_route_inputs_in_expr(operand, declaration_slots, usage);
        }
        ExecutionExprKind::Binary { lhs, rhs, .. } => {
            for operand in [lhs, rhs] {
                mark_route_inputs_in_expr(operand, declaration_slots, usage);
            }
        }
        ExecutionExprKind::Call { args, .. } => {
            args.iter()
                .for_each(|arg| mark_route_inputs_in_expr(arg, declaration_slots, usage));
        }
    }
}
// Unit tests: each parses a DSL model, analyzes and lowers it, then checks the resulting
// `NativeModelInfo`.
#[cfg(test)]
mod tests {
use super::*;
use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model};
// Runs the full parse -> analyze -> lower pipeline on `src` and projects the result;
// panics (failing the test) if any stage rejects the model.
fn load_model_info(src: &str) -> NativeModelInfo {
let model = parse_model(src).expect("model parses");
let typed = analyze_model(&model).expect("model analyzes");
let lowered = lower_typed_model(&typed).expect("model lowers");
NativeModelInfo::from_execution_model(&lowered)
}
// A route never read by the dynamics keeps automatic injection into its destination.
#[test]
fn declaration_first_routes_inject_by_default() {
let info = load_model_info(
r#"
model implicit_route_injection {
kind ode
states { central }
routes { iv -> central }
dynamics {
ddt(central) = 0
}
outputs {
cp = central
}
}
"#,
);
assert_eq!(info.routes.len(), 1);
assert!(info.routes[0].inject_input_to_destination);
}
// Reading `rate(iv)` in the dynamics marks the route as explicitly used, so it is not
// injected automatically.
#[test]
fn explicit_rate_usage_disables_automatic_injection() {
let info = load_model_info(
r#"
model explicit_route_usage {
kind ode
states { central }
routes { iv -> central }
dynamics {
ddt(central) = rate(iv)
}
outputs {
cp = central
}
}
"#,
);
assert_eq!(info.routes.len(), 1);
assert!(!info.routes[0].inject_input_to_destination);
}
// A bolus and an infusion share one route-buffer slot (`route_len == 1`, both `index == 0`)
// yet keep separate injection flags per declaration.
#[test]
fn authoring_shared_input_routes_keep_declaration_specific_injection() {
let info = load_model_info(
r#"
name = shared_authoring
kind = ode
params = ka, ke, v
states = depot, central
outputs = cp
bolus(oral) -> depot
infusion(iv) -> central
dx(depot) = -ka * depot
dx(central) = ka * depot - ke * central
out(cp) = central / v ~ continuous()
"#,
);
assert_eq!(info.route_len, 1);
assert_eq!(info.routes.len(), 2);
assert_eq!(info.routes[0].kind, Some(RouteKind::Bolus));
assert_eq!(info.routes[1].kind, Some(RouteKind::Infusion));
assert_eq!(info.routes[0].index, 0);
assert_eq!(info.routes[1].index, 0);
assert!(info.routes[0].inject_input_to_destination);
assert!(!info.routes[1].inject_input_to_destination);
}
// State order, covariate interpolation, route destinations and lag/bioavailability flags
// survive the projection.
#[test]
fn native_model_info_preserves_state_covariate_and_route_metadata() {
let info = load_model_info(
r#"
name = metadata_surface
kind = ode
params = ke, v
covariates = wt@linear
states = depot, central
outputs = cp
bolus(oral) -> depot
infusion(iv) -> central
lag(oral) = 1.0
fa(oral) = 0.8
dx(depot) = -ke * depot
dx(central) = ke * depot - rate(iv)
out(cp) = central / v
"#,
);
assert_eq!(info.states.len(), 2);
assert_eq!(info.states[0].name, "depot");
assert_eq!(info.states[1].name, "central");
assert_eq!(
info.covariates[0].interpolation,
Some(CovariateInterpolation::Linear)
);
assert_eq!(info.routes[0].destination_name, "depot");
assert!(info.routes[0].has_lag);
assert!(info.routes[0].has_bioavailability);
assert_eq!(info.routes[1].destination_name, "central");
assert!(!info.routes[1].has_lag);
assert!(!info.routes[1].has_bioavailability);
}
// Numeric-looking route/output names (`input_10`, `outeq_2`) are kept verbatim, in
// declaration order.
#[test]
fn native_model_info_preserves_canonical_numeric_channel_names() {
let info = load_model_info(
r#"
name = canonical_numeric_channels
kind = ode
params = ke, v
states = depot, central
outputs = cp, outeq_2
bolus(input_10) -> depot
infusion(iv) -> central
dx(depot) = -ke * depot
dx(central) = rate(input_10) - ke * central
out(cp) = central / v
out(outeq_2) = depot / v
"#,
);
assert_eq!(
info.routes
.iter()
.map(|route| route.name.as_str())
.collect::<Vec<_>>(),
vec!["input_10", "iv"]
);
assert_eq!(
info.outputs
.iter()
.map(|output| output.name.as_str())
.collect::<Vec<_>>(),
vec!["cp", "outeq_2"]
);
}
// Derived names follow the order of the `derive` block, and `derived_len` matches their count
// for this analytical model.
#[test]
fn native_model_info_preserves_declared_derived_order() {
let info = load_model_info(
r#"
model analytical_projection {
kind analytical
parameters { ka, ke0, v }
states { depot, central }
routes { oral -> depot }
derive {
ke = ke0
scale = v / 10
}
analytical {
structure = one_compartment_with_absorption
}
outputs {
cp = central / scale
}
}
"#,
);
assert_eq!(info.derived, vec!["ke".to_string(), "scale".to_string()]);
assert_eq!(info.derived_len, 2);
}
}