use std::{fmt::Debug, sync::Arc};
pub mod analytical;
pub mod metadata;
pub mod ode;
pub mod sde;
pub use analytical::*;
pub use metadata::*;
pub use ode::*;
pub use pharmsol_dsl::{AnalyticalKernel, ModelKind};
use pharmsol_dsl::{NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX};
pub use sde::*;
use crate::{
error_model::{AssayErrorModels, BoundAssayErrorModels},
simulator::{cache::BoundErrorModelCache, Fa, Lag},
Covariates, Event, Infusion, InputLabel, Observation, Occasion, OutputLabel, Parameters,
PharmsolError, Subject,
};
use super::likelihood::Prediction;
/// Mutable simulation state that can receive instantaneous doses.
pub trait State {
    /// Adds `amount` to the state for drug input `input` as an instantaneous (bolus) dose.
    fn add_bolus(&mut self, input: usize, amount: f64);
}
/// Container of model predictions produced while simulating a subject.
pub trait Predictions: Default {
    /// Creates an empty container.
    ///
    /// This default implementation ignores the particle count and returns
    /// `Self::default()`; implementors that track particles may override it.
    fn new(_nparticles: usize) -> Self {
        Self::default()
    }

    /// Squared-error summary of the stored predictions (exact definition is up to the
    /// implementor).
    fn squared_error(&self) -> f64;

    /// Returns the stored predictions as a flat list of [`Prediction`]s.
    fn get_predictions(&self) -> Vec<Prediction>;

    /// Log-likelihood of the stored predictions under `error_models`.
    ///
    /// # Errors
    /// Returns a [`PharmsolError`] when the likelihood cannot be evaluated.
    fn log_likelihood(&self, error_models: &AssayErrorModels) -> Result<f64, PharmsolError>;
}
/// Builder-style controls for a cache: the configuring methods consume `self` and return
/// the updated value.
pub trait Cache: Sized {
    /// Returns `self` with the cache capacity set to `size` (units are implementation-defined).
    fn with_cache_capacity(self, size: u64) -> Self;

    /// Returns `self` with caching switched on.
    fn enable_cache(self) -> Self;

    /// Empties the cache through a shared reference.
    fn clear_cache(&self);

    /// Returns `self` with caching switched off.
    fn disable_cache(self) -> Self;
}
/// Associated types that tie an equation to its state and prediction containers.
pub trait EquationTypes {
    /// Simulation state type; it must also implement `Debug`.
    type S: State + Debug;
    /// Container that collects the predictions produced during a simulation.
    type P: Predictions;
}
/// Crate-internal behaviour shared by all equation kinds.
///
/// Implementors supply the model description (lag, `Fa`, dimensions, optional metadata),
/// an integrator (`solve`), observation handling and the initial state. The provided methods
/// build on those to resolve event labels to numeric indices and to advance a simulation one
/// event at a time.
pub(crate) trait EquationPriv: EquationTypes {
    /// Lag definition passed to `Occasion::process_events` when resolving occasion events.
    fn lag(&self) -> &Lag;

    /// `Fa` definition passed to `Occasion::process_events` when resolving occasion events.
    fn fa(&self) -> &Fa;

    /// Number of model states.
    fn get_nstates(&self) -> usize;

    /// Number of drug inputs; resolved dose input indices must be strictly below this value.
    fn get_ndrugs(&self) -> usize;

    /// Number of output equations.
    fn get_nouteqs(&self) -> usize;

    /// Validated model metadata, if this equation carries any.
    ///
    /// When present, input and output labels are resolved by name through it; otherwise the
    /// labels must already be numeric indices.
    fn metadata(&self) -> Option<&ValidatedModelMetadata>;

    /// Advances `state` from `start_time` to `end_time`.
    ///
    /// `infusions` holds the infusion events collected so far by `simulate_event`.
    ///
    /// # Errors
    /// Implementation-defined solver failures.
    fn solve(
        &self,
        state: &mut Self::S,
        parameters: &[f64],
        covariates: &Covariates,
        infusions: &[Infusion],
        start_time: f64,
        end_time: f64,
    ) -> Result<(), PharmsolError>;

    /// Number of particles used to size the prediction container; `1` unless overridden.
    fn nparticles(&self) -> usize {
        1
    }

    /// Resolves a dose input label to a numeric input index.
    ///
    /// With metadata, the label is looked up as a route name. For a bare numeric label, the
    /// lookup falls back to its canonical alias (`NUMERIC_ROUTE_PREFIX` followed by the
    /// digits). The matched route must also be of `expected_kind`. Without metadata, the
    /// label must itself carry a numeric index.
    ///
    /// # Errors
    /// * `UnknownInputLabel` if no route (or numeric index) matches the label.
    /// * `UnsupportedInputRouteKind` if the route exists but is not of `expected_kind`; the
    ///   reported `kind` is the requested (event) kind.
    fn resolve_input_label(
        &self,
        label: &InputLabel,
        expected_kind: RouteKind,
    ) -> Result<usize, PharmsolError> {
        if let Some(metadata) = self.metadata() {
            let route = metadata
                .route(label.as_str())
                .or_else(|| {
                    canonical_numeric_alias(label.as_str(), NUMERIC_ROUTE_PREFIX)
                        .and_then(|alias| metadata.route(alias.as_str()))
                })
                .ok_or_else(|| PharmsolError::UnknownInputLabel {
                    label: label.to_string(),
                })?;
            if route.kind() != expected_kind {
                return Err(PharmsolError::UnsupportedInputRouteKind {
                    input: route.input_index(),
                    // Translate the crate-local route kind into the DSL's enum for the error.
                    kind: match expected_kind {
                        RouteKind::Bolus => pharmsol_dsl::RouteKind::Bolus,
                        RouteKind::Infusion => pharmsol_dsl::RouteKind::Infusion,
                    },
                });
            }
            return Ok(route.input_index());
        }
        label
            .index()
            .ok_or_else(|| PharmsolError::UnknownInputLabel {
                label: label.to_string(),
            })
    }

    /// Resolves an observation output label to a numeric output index.
    ///
    /// With metadata, the label is looked up as an output name. For a bare numeric label, the
    /// lookup falls back to its canonical alias (`NUMERIC_OUTPUT_PREFIX` followed by the
    /// digits). Without metadata, the label must itself carry a numeric index.
    ///
    /// # Errors
    /// `UnknownOutputLabel` if nothing matches the label.
    fn resolve_output_label(&self, label: &OutputLabel) -> Result<usize, PharmsolError> {
        if let Some(metadata) = self.metadata() {
            return metadata
                .output_index(label.as_str())
                .or_else(|| {
                    canonical_numeric_alias(label.as_str(), NUMERIC_OUTPUT_PREFIX)
                        .and_then(|alias| metadata.output_index(alias.as_str()))
                })
                .ok_or_else(|| PharmsolError::UnknownOutputLabel {
                    label: label.to_string(),
                });
        }
        label
            .index()
            .ok_or_else(|| PharmsolError::UnknownOutputLabel {
                label: label.to_string(),
            })
    }

    /// Returns the occasion's events with every label resolved to a numeric index.
    ///
    /// The occasion is cloned, and each bolus/infusion input and observation output is
    /// rewritten to its resolved index. The result is then passed through `process_events`
    /// together with this equation's `Fa`, lag, `parameters` and `covariates`.
    ///
    /// # Errors
    /// Any label-resolution error from `resolve_input_label` / `resolve_output_label`.
    fn resolve_occasion_events(
        &self,
        occasion: &Occasion,
        parameters: &[f64],
        covariates: &Covariates,
    ) -> Result<Vec<Event>, PharmsolError> {
        let mut resolved = occasion.clone();
        for event in resolved.events_iter_mut() {
            match event {
                Event::Bolus(bolus) => {
                    let input = self.resolve_input_label(bolus.input(), RouteKind::Bolus)?;
                    bolus.set_input(input);
                }
                Event::Infusion(infusion) => {
                    let input = self.resolve_input_label(infusion.input(), RouteKind::Infusion)?;
                    infusion.set_input(input);
                }
                Event::Observation(observation) => {
                    let outeq = self.resolve_output_label(observation.outeq())?;
                    observation.set_outeq(outeq);
                }
            }
        }
        Ok(resolved.process_events(Some((self.fa(), self.lag(), parameters, covariates)), true))
    }

    /// Whether this equation is stochastic; `false` unless overridden.
    // NOTE(review): presumably unused in some build configurations, hence the allowance.
    #[allow(dead_code)]
    fn is_sde(&self) -> bool {
        false
    }

    /// Handles a single observation event at `time`, writing into `output`.
    ///
    /// When `error_models` is supplied, implementors presumably push per-observation
    /// likelihood terms onto `likelihood`. `Equation::simulate_subject_dense` multiplies
    /// those terms together.
    #[allow(clippy::too_many_arguments)]
    fn process_observation(
        &self,
        parameters: &[f64],
        observation: &Observation,
        error_models: Option<&AssayErrorModels>,
        time: f64,
        covariates: &Covariates,
        x: &mut Self::S,
        likelihood: &mut Vec<f64>,
        output: &mut Self::P,
    ) -> Result<(), PharmsolError>;

    /// Builds the state at the start of occasion `occasion_index`.
    fn initial_state(
        &self,
        parameters: &[f64],
        covariates: &Covariates,
        occasion_index: usize,
    ) -> Self::S;

    /// Applies one (already resolved) event, then integrates up to `next_event` if there is one.
    ///
    /// Event handling:
    /// * A bolus is added to the state.
    /// * An infusion is appended to `infusions`.
    /// * An observation is forwarded to `process_observation`.
    ///
    /// Both dose kinds are checked to carry an input index below `get_ndrugs()`.
    ///
    /// # Errors
    /// * `UnknownInputLabel` if a dose has no numeric input index.
    /// * `InputOutOfRange` if a dose's input index is `>= get_ndrugs()`.
    /// * Any error from `process_observation` or `solve`.
    #[allow(clippy::too_many_arguments)]
    fn simulate_event(
        &self,
        parameters: &[f64],
        event: &Event,
        next_event: Option<&Event>,
        error_models: Option<&AssayErrorModels>,
        covariates: &Covariates,
        x: &mut Self::S,
        infusions: &mut Vec<Infusion>,
        likelihood: &mut Vec<f64>,
        output: &mut Self::P,
    ) -> Result<(), PharmsolError> {
        match event {
            Event::Bolus(bolus) => {
                let input =
                    checked_input_index(bolus.input_index(), bolus.input(), self.get_ndrugs())?;
                x.add_bolus(input, bolus.amount());
            }
            Event::Infusion(infusion) => {
                // Validate infusions exactly like boluses, so that a bad input index is
                // reported here instead of being handed to `solve` unchecked.
                checked_input_index(
                    infusion.input().index(),
                    infusion.input(),
                    self.get_ndrugs(),
                )?;
                infusions.push(infusion.clone());
            }
            Event::Observation(observation) => {
                self.process_observation(
                    parameters,
                    observation,
                    error_models,
                    event.time(),
                    covariates,
                    x,
                    likelihood,
                    output,
                )?;
            }
        }
        if let Some(next_event) = next_event {
            self.solve(
                x,
                parameters,
                covariates,
                infusions,
                event.time(),
                next_event.time(),
            )?;
        }
        Ok(())
    }
}

/// Checks that a dose's resolved input index exists and is below `ndrugs`.
///
/// `label` is only used to build the `UnknownInputLabel` error when `index` is `None`.
fn checked_input_index(
    index: Option<usize>,
    label: &InputLabel,
    ndrugs: usize,
) -> Result<usize, PharmsolError> {
    let input = index.ok_or_else(|| PharmsolError::UnknownInputLabel {
        label: label.to_string(),
    })?;
    if input >= ndrugs {
        return Err(PharmsolError::InputOutOfRange { input, ndrugs });
    }
    Ok(input)
}
/// Maps a purely numeric label such as `"2"` to its prefixed canonical form (`"{prefix}2"`).
///
/// Returns `None` for an empty label, or for any label containing a character that is not
/// an ASCII digit. Only bare numeric labels receive an alias.
fn canonical_numeric_alias(label: &str, prefix: &str) -> Option<String> {
    // Non-ASCII characters encode to bytes >= 0x80, so a byte-wise digit check is
    // equivalent to a char-wise one.
    let is_numeric = !label.is_empty() && label.bytes().all(|byte| byte.is_ascii_digit());
    is_numeric.then(|| format!("{prefix}{label}"))
}
/// Public simulation and likelihood interface implemented by every equation kind.
///
/// Most methods have default implementations built on the crate-internal
/// `EquationPriv` machinery. Implementors must provide `estimate_likelihood`,
/// `estimate_log_likelihood` and `kind`.
// `EquationPriv` is `pub(crate)` yet is a supertrait of this public trait.
#[allow(private_bounds)]
pub trait Equation: EquationPriv + 'static + Clone + Sync {
    /// Cache used to memoise bound error models; `None` (no caching) by default.
    #[doc(hidden)]
    fn bound_error_model_cache(&self) -> Option<&BoundErrorModelCache> {
        None
    }

    /// Binds `error_models` to this equation, reusing a cached binding when possible.
    ///
    /// Without a cache, the result of `bind_to` is returned as is. With a cache, the binding
    /// is looked up by `error_models.hash()`. On a miss, an owned binding is moved into an
    /// `Arc`, stored in the cache, and returned as a shared binding. Any other binding
    /// variant is returned unchanged.
    ///
    /// # Errors
    /// Propagates any error from `AssayErrorModels::bind_to`.
    #[doc(hidden)]
    fn bind_error_models<'a>(
        &'a self,
        error_models: &'a AssayErrorModels,
    ) -> Result<BoundAssayErrorModels<'a>, PharmsolError> {
        let cache = match self.bound_error_model_cache() {
            Some(cache) => cache,
            None => return Ok(error_models.bind_to(self)?),
        };
        let key = error_models.hash();
        if let Some(cached) = cache.get(&key) {
            return Ok(BoundAssayErrorModels::Shared(cached));
        }
        let fresh = error_models.bind_to(self)?;
        Ok(match fresh {
            BoundAssayErrorModels::Owned(owned) => {
                let shared = Arc::new(owned);
                cache.insert(key, Arc::clone(&shared));
                BoundAssayErrorModels::Shared(shared)
            }
            other => other,
        })
    }

    /// Likelihood of `subject`'s observations under `parameters` and `error_models`.
    #[deprecated(
        since = "0.23.0",
        note = "Use estimate_log_likelihood() instead for better numerical stability"
    )]
    fn estimate_likelihood(
        &self,
        subject: &Subject,
        parameters: &Parameters,
        error_models: &AssayErrorModels,
    ) -> Result<f64, PharmsolError>;

    /// Log-likelihood of `subject`'s observations under `parameters` and `error_models`.
    fn estimate_log_likelihood(
        &self,
        subject: &Subject,
        parameters: &Parameters,
        error_models: &AssayErrorModels,
    ) -> Result<f64, PharmsolError>;

    /// The [`EqnKind`] of this equation type.
    fn kind() -> EqnKind;

    /// Predictions for `subject` from a raw parameter slice, without a likelihood.
    #[doc(hidden)]
    fn estimate_predictions_dense(
        &self,
        subject: &Subject,
        parameters: &[f64],
    ) -> Result<Self::P, PharmsolError> {
        self.simulate_subject_dense(subject, parameters, None)
            .map(|(predictions, _)| predictions)
    }

    /// Log-likelihood from a raw parameter slice.
    ///
    /// The error models are bound first, then the predictions' `log_likelihood` is
    /// evaluated against them.
    #[doc(hidden)]
    fn estimate_log_likelihood_dense(
        &self,
        subject: &Subject,
        parameters: &[f64],
        error_models: &AssayErrorModels,
    ) -> Result<f64, PharmsolError> {
        let bound = self.bind_error_models(error_models)?;
        self.estimate_predictions_dense(subject, parameters)?
            .log_likelihood(&bound)
    }

    /// Simulates every occasion of `subject` from a raw parameter slice.
    ///
    /// Returns the collected predictions. When `error_models` is supplied, it also returns
    /// the product of all per-observation likelihood terms (`1.0` if there are none);
    /// otherwise it returns `None`.
    #[doc(hidden)]
    fn simulate_subject_dense(
        &self,
        subject: &Subject,
        parameters: &[f64],
        error_models: Option<&AssayErrorModels>,
    ) -> Result<(Self::P, Option<f64>), PharmsolError> {
        let bound = error_models
            .map(|models| self.bind_error_models(models))
            .transpose()?;
        let active_models = bound.as_deref();

        let mut predictions = Self::P::new(self.nparticles());
        let mut observation_likelihoods = Vec::new();
        for occasion in subject.occasions() {
            let covariates = occasion.covariates();
            // Each occasion starts from a fresh state with no active infusions.
            let mut state = self.initial_state(parameters, covariates, occasion.index());
            let mut infusions = Vec::new();
            let events = self.resolve_occasion_events(occasion, parameters, covariates)?;
            for (position, event) in events.iter().enumerate() {
                self.simulate_event(
                    parameters,
                    event,
                    events.get(position + 1),
                    active_models,
                    covariates,
                    &mut state,
                    &mut infusions,
                    &mut observation_likelihoods,
                    &mut predictions,
                )?;
            }
        }

        let total = active_models.map(|_| observation_likelihoods.iter().product::<f64>());
        Ok((predictions, total))
    }

    /// Predictions for `subject` under `parameters`.
    fn estimate_predictions(
        &self,
        subject: &Subject,
        parameters: &Parameters,
    ) -> Result<Self::P, PharmsolError> {
        let dense = parameters.as_slice();
        self.estimate_predictions_dense(subject, dense)
    }

    /// Number of output equations.
    fn nouteqs(&self) -> usize {
        self.get_nouteqs()
    }

    /// Number of model states.
    fn nstates(&self) -> usize {
        self.get_nstates()
    }

    /// Error-model collection keyed by this equation's output names.
    ///
    /// Uses the output names from the metadata when present; otherwise returns
    /// `AssayErrorModels::empty()`.
    #[doc(hidden)]
    fn assay_error_models(&self) -> AssayErrorModels {
        match self.metadata() {
            Some(metadata) => AssayErrorModels::with_output_names(
                metadata.outputs().iter().map(|output| output.name()),
            ),
            None => AssayErrorModels::empty(),
        }
    }

    /// Simulates `subject` under `parameters`, optionally computing the likelihood.
    fn simulate_subject(
        &self,
        subject: &Subject,
        parameters: &Parameters,
        error_models: Option<&AssayErrorModels>,
    ) -> Result<(Self::P, Option<f64>), PharmsolError> {
        let dense = parameters.as_slice();
        self.simulate_subject_dense(subject, dense, error_models)
    }
}
/// Identifies which family of equation a model belongs to.
///
/// `#[repr(C)]` together with the explicit discriminants fixes each variant's integer value.
#[repr(C)]
#[derive(Debug, Clone)]
pub enum EqnKind {
    /// Ordinary differential equation model.
    ODE = 0,
    /// Analytical model.
    Analytical = 1,
    /// Stochastic differential equation model.
    SDE = 2,
}
impl EqnKind {
pub fn to_str(&self) -> &'static str {
match self {
Self::ODE => "EqnKind::ODE",
Self::Analytical => "EqnKind::Analytical",
Self::SDE => "EqnKind::SDE",
}
}
}
/// Hashes a parameter vector into a `u64` key.
///
/// Each value is fed to the hasher as its IEEE-754 bit pattern. The one exception is that
/// `0.0` and `-0.0` (which compare equal) are both fed as `0`, so they hash identically.
#[inline(always)]
pub(crate) fn parameters_hash(parameters: &[f64]) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut hasher = ahash::AHasher::default();
    parameters
        .iter()
        .map(|&value| if value == 0.0 { 0u64 } else { value.to_bits() })
        .for_each(|bits| bits.hash(&mut hasher));
    hasher.finish()
}