use crate::data::*;
use std::collections::HashMap;
use thiserror::Error;
/// One record of an event table, before it is converted into [`Event`]s.
///
/// `evid` selects the kind of record: `0` is an observation, `1` and `4` are
/// doses (`4` is additionally treated as an occasion reset). Which optional
/// fields are required depends on that kind; see [`DataRow::into_events`].
#[derive(Debug, Clone, Default)]
pub struct DataRow {
/// Subject identifier; `build_data` groups rows by this value.
pub id: String,
/// Time of the event.
pub time: f64,
/// Event kind: `0` observation, `1` dose, `4` dose that starts a new occasion.
/// Any other value makes `into_events` fail with `DataError::UnknownEvid`.
pub evid: i32,
/// Dose amount; required on dose rows (`evid` 1 or 4).
pub dose: Option<f64>,
/// Infusion duration. A dose row with `dur > 0` becomes an infusion,
/// otherwise it becomes a bolus.
pub dur: Option<f64>,
/// Number of additional doses to generate; a negative count generates them
/// backwards in time. Only used when `ii` is also set and positive.
pub addl: Option<i64>,
/// Interval between the additional doses requested by `addl`.
pub ii: Option<f64>,
/// Input (compartment) index that receives the dose; required on dose rows.
pub input: Option<usize>,
/// Observed value; passed through to the observation as-is, so it may be absent.
pub out: Option<f64>,
/// Output equation index; required on observation rows (`evid` 0).
pub outeq: Option<usize>,
/// Censoring of the observation; treated as `Censor::None` when absent.
pub cens: Option<Censor>,
/// Error polynomial coefficient `c0`. The polynomial is only built when all
/// of `c0`..`c3` are present.
pub c0: Option<f64>,
/// Error polynomial coefficient `c1`.
pub c1: Option<f64>,
/// Error polynomial coefficient `c2`.
pub c2: Option<f64>,
/// Error polynomial coefficient `c3`.
pub c3: Option<f64>,
/// Covariate values by name, recorded at this row's `time`.
pub covariates: HashMap<String, f64>,
}
impl DataRow {
    /// Starts a [`DataRowBuilder`] for subject `id` at `time`.
    pub fn builder(id: impl Into<String>, time: f64) -> DataRowBuilder {
        DataRowBuilder::new(id, time)
    }

    /// Builds the error polynomial for this row.
    ///
    /// Returns `None` unless all four coefficients `c0`..`c3` are set.
    fn get_errorpoly(&self) -> Option<ErrorPoly> {
        match (self.c0, self.c1, self.c2, self.c3) {
            (Some(c0), Some(c1), Some(c2), Some(c3)) => Some(ErrorPoly::new(c0, c1, c2, c3)),
            _ => None,
        }
    }

    /// Converts this row into the event(s) it describes.
    ///
    /// * `evid == 0` yields a single observation.
    /// * `evid == 1` or `4` yields a dose: an infusion when `dur > 0`, a bolus
    ///   otherwise. When `addl` is non-zero and `ii` is positive, `|addl|`
    ///   extra copies of the dose are generated, spaced `ii` apart (forwards
    ///   in time for positive `addl`, backwards for negative). The extra doses
    ///   come first in the returned vector and the dose at `time` comes last.
    ///
    /// # Errors
    ///
    /// * [`DataError::MissingObservationOuteq`] – observation without `outeq`.
    /// * [`DataError::MissingInfusionInput`] / [`DataError::MissingBolusInput`]
    ///   – dose without `input`.
    /// * [`DataError::MissingInfusionDose`] / [`DataError::MissingBolusDose`]
    ///   – dose without `dose`.
    /// * [`DataError::UnknownEvid`] – any other `evid`.
    ///
    /// A missing `input` is reported before a missing `dose`.
    pub fn into_events(self) -> Result<Vec<Event>, DataError> {
        let mut events: Vec<Event> = Vec::new();
        match self.evid {
            0 => {
                let outeq = self
                    .outeq
                    .ok_or_else(|| DataError::MissingObservationOuteq {
                        id: self.id.clone(),
                        time: self.time,
                    })?;
                // NOTE(review): the literal `0` looks like a placeholder occasion
                // index (`build_data` sets the occasion afterwards) — confirm
                // against `Observation::new`.
                events.push(Event::Observation(Observation::new(
                    self.time,
                    self.out,
                    outeq,
                    self.get_errorpoly(),
                    0,
                    self.cens.unwrap_or(Censor::None),
                )));
            }
            1 | 4 => {
                // A strictly positive duration makes the dose an infusion; a
                // missing, zero, negative or NaN duration makes it a bolus.
                let infusion_dur = self.dur.filter(|dur| *dur > 0.0);

                // Report the missing INPUT with the variant that matches the
                // kind of dose the row actually describes.
                let input = self.input.ok_or_else(|| {
                    let id = self.id.clone();
                    let time = self.time;
                    if infusion_dur.is_some() {
                        DataError::MissingInfusionInput { id, time }
                    } else {
                        DataError::MissingBolusInput { id, time }
                    }
                })?;

                let event = match infusion_dur {
                    Some(dur) => Event::Infusion(Infusion::new(
                        self.time,
                        self.dose.ok_or_else(|| DataError::MissingInfusionDose {
                            id: self.id.clone(),
                            time: self.time,
                        })?,
                        input,
                        dur,
                        0,
                    )),
                    None => Event::Bolus(Bolus::new(
                        self.time,
                        self.dose.ok_or_else(|| DataError::MissingBolusDose {
                            id: self.id.clone(),
                            time: self.time,
                        })?,
                        input,
                        0,
                    )),
                };

                if let (Some(addl), Some(ii)) = (self.addl, self.ii) {
                    if addl != 0 && ii > 0.0 {
                        // `ii` is known to be positive here, so only the sign of
                        // `addl` decides the direction of the extra doses.
                        let step = if addl < 0 { -ii } else { ii };
                        let mut extra = event.clone();
                        // `unsigned_abs` cannot overflow, unlike `abs` on `i64::MIN`.
                        for _ in 0..addl.unsigned_abs() {
                            extra.inc_time(step);
                            events.push(extra.clone());
                        }
                    }
                }
                // The dose at the row's own time is pushed after its repetitions.
                events.push(event);
            }
            _ => {
                return Err(DataError::UnknownEvid {
                    evid: self.evid as isize,
                    id: self.id.clone(),
                    time: self.time,
                });
            }
        }
        Ok(events)
    }

    /// Covariate values recorded on this row, keyed by name.
    pub fn covariates(&self) -> &HashMap<String, f64> {
        &self.covariates
    }

    /// Returns `true` when this row has `evid == 4`.
    pub fn is_occasion_reset(&self) -> bool {
        self.evid == 4
    }

    /// Subject identifier of this row.
    pub fn id(&self) -> &str {
        &self.id
    }

    /// Time of this row.
    pub fn time(&self) -> f64 {
        self.time
    }
}
/// Chainable builder for a [`DataRow`].
///
/// Each setter consumes the builder and returns it, and `build` hands back
/// the finished row.
#[derive(Debug, Clone)]
pub struct DataRowBuilder {
// Row being assembled; returned unchanged by `build`.
row: DataRow,
}
impl DataRowBuilder {
pub fn new(id: impl Into<String>, time: f64) -> Self {
Self {
row: DataRow {
id: id.into(),
time,
evid: 0, ..Default::default()
},
}
}
pub fn evid(mut self, evid: i32) -> Self {
self.row.evid = evid;
self
}
pub fn dose(mut self, dose: f64) -> Self {
self.row.dose = Some(dose);
self
}
pub fn dur(mut self, dur: f64) -> Self {
self.row.dur = Some(dur);
self
}
pub fn addl(mut self, addl: i64) -> Self {
self.row.addl = Some(addl);
self
}
pub fn ii(mut self, ii: f64) -> Self {
self.row.ii = Some(ii);
self
}
pub fn input(mut self, input: usize) -> Self {
self.row.input = Some(input);
self
}
pub fn out(mut self, out: f64) -> Self {
self.row.out = Some(out);
self
}
pub fn outeq(mut self, outeq: usize) -> Self {
self.row.outeq = Some(outeq);
self
}
pub fn cens(mut self, cens: Censor) -> Self {
self.row.cens = Some(cens);
self
}
pub fn error_poly(mut self, c0: f64, c1: f64, c2: f64, c3: f64) -> Self {
self.row.c0 = Some(c0);
self.row.c1 = Some(c1);
self.row.c2 = Some(c2);
self.row.c3 = Some(c3);
self
}
pub fn covariate(mut self, name: impl Into<String>, value: f64) -> Self {
self.row.covariates.insert(name.into(), value);
self
}
pub fn build(self) -> DataRow {
self.row
}
}
pub fn build_data(rows: impl IntoIterator<Item = DataRow>) -> Result<Data, DataError> {
let mut rows_map: std::collections::HashMap<String, Vec<DataRow>> =
std::collections::HashMap::new();
for row in rows {
rows_map.entry(row.id.clone()).or_default().push(row);
}
let mut subjects: Vec<Subject> = Vec::new();
for (id, rows) in rows_map {
let split_indices: Vec<usize> = rows
.iter()
.enumerate()
.filter_map(|(i, row)| if row.evid == 4 { Some(i) } else { None })
.collect();
let mut block_rows_vec: Vec<&[DataRow]> = Vec::new();
let mut start = 0;
for &split_index in &split_indices {
if start < split_index {
block_rows_vec.push(&rows[start..split_index]);
}
start = split_index;
}
if start < rows.len() {
block_rows_vec.push(&rows[start..]);
}
let mut occasions: Vec<Occasion> = Vec::new();
for (block_index, block) in block_rows_vec.iter().enumerate() {
let mut events: Vec<Event> = Vec::new();
let mut observed_covariates: std::collections::HashMap<
String,
Vec<(f64, Option<f64>)>,
> = std::collections::HashMap::new();
for row in *block {
let row_events = row.clone().into_events()?;
events.extend(row_events);
for (name, value) in &row.covariates {
observed_covariates
.entry(name.clone())
.or_default()
.push((row.time, Some(*value)));
}
}
events.iter_mut().for_each(|e| e.set_occasion(block_index));
let covariates = Covariates::from_pmetrics_observations(&observed_covariates);
let mut occasion = Occasion::new(block_index);
occasion.events = events;
occasion.covariates = covariates;
occasion.sort();
occasions.push(occasion);
}
subjects.push(Subject::new(id, occasions));
}
subjects.sort_by(|a, b| a.id().cmp(b.id()));
Ok(Data::new(subjects))
}
#[allow(private_interfaces)]
#[derive(Error, Debug, Clone)]
pub enum DataError {
#[error("CSV error: {0}")]
CSVError(String),
#[error("Parse error: {0}")]
SerdeError(String),
#[error("Unknown EVID: {evid} for ID {id} at time {time}")]
UnknownEvid { evid: isize, id: String, time: f64 },
#[error("Observation OUT is missing for {id} at time {time}")]
MissingObservationOut { id: String, time: f64 },
#[error("Observation OUTEQ is missing in for {id} at time {time}")]
MissingObservationOuteq { id: String, time: f64 },
#[error("Infusion amount (DOSE) is missing for {id} at time {time}")]
MissingInfusionDose { id: String, time: f64 },
#[error("Infusion compartment (INPUT) is missing for {id} at time {time}")]
MissingInfusionInput { id: String, time: f64 },
#[error("Infusion duration (DUR) is missing for {id} at time {time}")]
MissingInfusionDur { id: String, time: f64 },
#[error("Bolus amount (DOSE) is missing for {id} at time {time}")]
MissingBolusDose { id: String, time: f64 },
#[error("Bolus compartment (INPUT) is missing for {id} at time {time}")]
MissingBolusInput { id: String, time: f64 },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An `evid` 0 row becomes exactly one observation carrying its values.
    #[test]
    fn test_observation_row() {
        let row = DataRow::builder("pt1", 1.0)
            .evid(0)
            .out(25.5)
            .outeq(1)
            .build();
        let events = row.into_events().unwrap();
        assert_eq!(events.len(), 1);
        match &events[0] {
            Event::Observation(obs) => {
                assert_eq!(obs.time(), 1.0);
                assert_eq!(obs.value(), Some(25.5));
                assert_eq!(obs.outeq(), 1);
            }
            _ => panic!("Expected observation event"),
        }
    }

    /// A dose row without a duration becomes a bolus.
    #[test]
    fn test_bolus_row() {
        let row = DataRow::builder("pt1", 0.0)
            .evid(1)
            .dose(100.0)
            .input(1)
            .build();
        let events = row.into_events().unwrap();
        assert_eq!(events.len(), 1);
        match &events[0] {
            Event::Bolus(bolus) => {
                assert_eq!(bolus.time(), 0.0);
                assert_eq!(bolus.amount(), 100.0);
                assert_eq!(bolus.input(), 1);
            }
            _ => panic!("Expected bolus event"),
        }
    }

    /// A dose row with a positive duration becomes an infusion.
    #[test]
    fn test_infusion_row() {
        let row = DataRow::builder("pt1", 0.0)
            .evid(1)
            .dose(100.0)
            .dur(2.0)
            .input(1)
            .build();
        let events = row.into_events().unwrap();
        assert_eq!(events.len(), 1);
        match &events[0] {
            Event::Infusion(inf) => {
                assert_eq!(inf.time(), 0.0);
                assert_eq!(inf.amount(), 100.0);
                assert_eq!(inf.duration(), 2.0);
                assert_eq!(inf.input(), 1);
            }
            _ => panic!("Expected infusion event"),
        }
    }

    /// Positive ADDL adds doses forwards in time; the original dose comes last.
    #[test]
    fn test_positive_addl() {
        let row = DataRow::builder("pt1", 0.0)
            .evid(1)
            .dose(100.0)
            .input(1)
            .addl(3)
            .ii(12.0)
            .build();
        let events = row.into_events().unwrap();
        assert_eq!(events.len(), 4);
        let times: Vec<f64> = events.iter().map(|e| e.time()).collect();
        assert_eq!(times, vec![12.0, 24.0, 36.0, 0.0]);
    }

    /// Negative ADDL adds doses backwards in time; the original dose comes last.
    #[test]
    fn test_negative_addl() {
        let row = DataRow::builder("pt1", 0.0)
            .evid(1)
            .dose(100.0)
            .input(1)
            .addl(-3)
            .ii(12.0)
            .build();
        let events = row.into_events().unwrap();
        assert_eq!(events.len(), 4);
        let times: Vec<f64> = events.iter().map(|e| e.time()).collect();
        assert_eq!(times, vec![-12.0, -24.0, -36.0, 0.0]);
    }

    /// A larger negative ADDL keeps stepping back by one interval per dose.
    #[test]
    fn test_large_negative_addl() {
        let row = DataRow::builder("pt1", 0.0)
            .evid(1)
            .dose(100.0)
            .input(1)
            .addl(-10)
            .ii(12.0)
            .build();
        let events = row.into_events().unwrap();
        assert_eq!(events.len(), 11);
        let times: Vec<f64> = events.iter().map(|e| e.time()).collect();
        assert_eq!(
            times,
            vec![-12.0, -24.0, -36.0, -48.0, -60.0, -72.0, -84.0, -96.0, -108.0, -120.0, 0.0]
        );
    }

    /// ADDL applies to infusions too; every copy keeps amount and duration.
    #[test]
    fn test_infusion_with_addl() {
        let row = DataRow::builder("pt1", 0.0)
            .evid(1)
            .dose(100.0)
            .dur(1.0)
            .input(1)
            .addl(2)
            .ii(24.0)
            .build();
        let events = row.into_events().unwrap();
        assert_eq!(events.len(), 3);
        for event in &events {
            match event {
                Event::Infusion(inf) => {
                    assert_eq!(inf.amount(), 100.0);
                    assert_eq!(inf.duration(), 1.0);
                }
                _ => panic!("Expected infusion event"),
            }
        }
        let times: Vec<f64> = events.iter().map(|e| e.time()).collect();
        assert_eq!(times, vec![24.0, 48.0, 0.0]);
    }

    /// Covariates set through the builder are readable from the row.
    #[test]
    fn test_covariates() {
        let row = DataRow::builder("pt1", 0.0)
            .evid(0)
            .out(25.0)
            .outeq(1)
            .covariate("weight", 70.0)
            .covariate("age", 45.0)
            .build();
        assert_eq!(row.covariates().len(), 2);
        assert_eq!(row.covariates().get("weight"), Some(&70.0));
        assert_eq!(row.covariates().get("age"), Some(&45.0));
    }

    /// All four coefficients set means the observation carries the polynomial.
    #[test]
    fn test_error_poly() {
        let row = DataRow::builder("pt1", 1.0)
            .evid(0)
            .out(25.0)
            .outeq(1)
            .error_poly(0.1, 0.2, 0.0, 0.0)
            .build();
        let events = row.into_events().unwrap();
        match &events[0] {
            Event::Observation(obs) => {
                let ep = obs.errorpoly().unwrap();
                assert_eq!(ep.coefficients(), (0.1, 0.2, 0.0, 0.0));
            }
            _ => panic!("Expected observation"),
        }
    }

    /// The censoring flag is carried over to the observation.
    #[test]
    fn test_censoring() {
        let row = DataRow::builder("pt1", 1.0)
            .evid(0)
            .out(0.5)
            .outeq(1)
            .cens(Censor::BLOQ)
            .build();
        let events = row.into_events().unwrap();
        match &events[0] {
            Event::Observation(obs) => {
                assert!(obs.censored());
                assert_eq!(obs.censoring(), Censor::BLOQ);
            }
            _ => panic!("Expected observation"),
        }
    }

    /// An observation without OUTEQ is rejected with the matching variant.
    #[test]
    fn test_missing_outeq_error() {
        let row = DataRow::builder("pt1", 1.0).evid(0).out(25.0).build();
        let result = row.into_events();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            DataError::MissingObservationOuteq { .. }
        ));
    }

    /// A bolus without DOSE is rejected with the bolus-dose variant.
    #[test]
    fn test_missing_dose_error() {
        let row = DataRow::builder("pt1", 0.0).evid(1).input(1).build();
        let result = row.into_events();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            DataError::MissingBolusDose { .. }
        ));
    }

    /// A bolus without INPUT is rejected with the bolus-input variant.
    #[test]
    fn test_missing_input_error() {
        let row = DataRow::builder("pt1", 0.0).evid(1).dose(100.0).build();
        let result = row.into_events();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            DataError::MissingBolusInput { .. }
        ));
    }

    /// An unsupported EVID is rejected and echoed back in the error.
    #[test]
    fn test_unknown_evid_error() {
        let row = DataRow::builder("pt1", 0.0).evid(99).build();
        let result = row.into_events();
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            DataError::UnknownEvid { evid: 99, .. }
        ));
    }

    /// ADDL of zero generates no extra doses.
    #[test]
    fn test_addl_zero_has_no_effect() {
        let row = DataRow::builder("pt1", 0.0)
            .evid(1)
            .dose(100.0)
            .input(1)
            .addl(0)
            .ii(12.0)
            .build();
        let events = row.into_events().unwrap();
        assert_eq!(events.len(), 1);
    }

    /// ADDL without an interval generates no extra doses.
    #[test]
    fn test_addl_without_ii_has_no_effect() {
        let row = DataRow::builder("pt1", 0.0)
            .evid(1)
            .dose(100.0)
            .input(1)
            .addl(5)
            .build();
        let events = row.into_events().unwrap();
        assert_eq!(events.len(), 1);
    }

    /// EVID 4 is flagged as an occasion reset and still produces its dose.
    #[test]
    fn test_evid_4_reset() {
        let row = DataRow::builder("pt1", 24.0)
            .evid(4)
            .dose(100.0)
            .input(1)
            .build();
        assert!(row.is_occasion_reset());
        let events = row.into_events().unwrap();
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], Event::Bolus(_)));
    }
}