use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::sync::Arc;
/// Serializable mirror of [`TemporalInjection`], used by its `Serialize` and
/// `Deserialize` impls.
///
/// It has every variant of `TemporalInjection` except `Custom`, which wraps a
/// closure and therefore has no data representation. `#[serde(tag = "type")]`
/// makes the encoding internally tagged: the variant name is written into a
/// `"type"` field next to the variant's own fields, e.g.
/// `{"type":"Dirac","time":5.0,"amount":0.1}`.
///
/// The field names here define the wire format and their order sets the
/// serialized key order, so keep them in sync with `TemporalInjection`.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum TemporalInjectionSnapshot {
/// Mirrors `TemporalInjection::Dirac`.
Dirac {
time: f64,
amount: f64,
},
/// Mirrors `TemporalInjection::Gaussian`.
Gaussian {
center: f64,
width: f64,
peak_concentration: f64,
},
/// Mirrors `TemporalInjection::Rectangle`.
Rectangle {
start: f64,
end: f64,
concentration: f64,
},
/// Mirrors `TemporalInjection::None`.
None,
}
/// Time-dependent injection profile, evaluated at a time `t` with
/// [`TemporalInjection::evaluate`].
///
/// All variants except `Custom` can be serialized. The encoding is internally
/// tagged by a `"type"` field. Serializing `Custom` returns an error.
pub enum TemporalInjection {
/// Instantaneous injection: `evaluate` returns `amount` only when `t == time`
/// holds exactly, and `0.0` for every other `t`.
Dirac { time: f64, amount: f64 },
/// Gaussian pulse evaluating to
/// `peak_concentration * exp(-((t - center) / width)^2 / 2)`.
/// A zero `width` makes the value NaN at `t == center`.
Gaussian {
center: f64,
width: f64,
peak_concentration: f64,
},
/// Constant `concentration` on the half-open interval `[start, end)` and `0.0`
/// elsewhere. The [`TemporalInjection::rectangle`] constructor asserts
/// `end > start`.
Rectangle {
start: f64,
end: f64,
concentration: f64,
},
/// User-supplied profile `t -> value`. Cloning shares the same closure through
/// the `Arc`. This variant cannot be serialized.
Custom(Arc<dyn Fn(f64) -> f64 + Send + Sync>),
/// No injection: always evaluates to `0.0`.
None,
}
impl Clone for TemporalInjection {
    /// Duplicates the injection profile.
    ///
    /// The parametric variants copy their `f64` fields. `Custom` shares the
    /// same closure by bumping the `Arc` reference count instead of copying it.
    fn clone(&self) -> Self {
        // Matching on `*self` binds the `Copy` fields by value; only the `Arc`
        // needs to be borrowed with `ref`.
        match *self {
            Self::Dirac { time, amount } => Self::Dirac { time, amount },
            Self::Gaussian {
                center,
                width,
                peak_concentration,
            } => Self::Gaussian {
                center,
                width,
                peak_concentration,
            },
            Self::Rectangle {
                start,
                end,
                concentration,
            } => Self::Rectangle {
                start,
                end,
                concentration,
            },
            Self::Custom(ref func) => Self::Custom(Arc::clone(func)),
            Self::None => Self::None,
        }
    }
}
impl std::fmt::Debug for TemporalInjection {
    /// Formats the injection the way a derived `Debug` impl would.
    ///
    /// A `Custom` closure has no inspectable contents, so a `"<user-defined>"`
    /// placeholder is printed in its `function` field instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => f.debug_struct("None").finish(),
            Self::Custom(_) => {
                let placeholder = "<user-defined>";
                f.debug_struct("Custom")
                    .field("function", &placeholder)
                    .finish()
            }
            Self::Dirac { time, amount } => {
                let mut builder = f.debug_struct("Dirac");
                builder.field("time", time);
                builder.field("amount", amount);
                builder.finish()
            }
            Self::Gaussian {
                center,
                width,
                peak_concentration,
            } => {
                let mut builder = f.debug_struct("Gaussian");
                builder.field("center", center);
                builder.field("width", width);
                builder.field("peak_concentration", peak_concentration);
                builder.finish()
            }
            Self::Rectangle {
                start,
                end,
                concentration,
            } => {
                let mut builder = f.debug_struct("Rectangle");
                builder.field("start", start);
                builder.field("end", end);
                builder.field("concentration", concentration);
                builder.finish()
            }
        }
    }
}
impl Serialize for TemporalInjection {
    /// Serializes the injection through the private `TemporalInjectionSnapshot`
    /// mirror, which produces an internally tagged representation: a `"type"`
    /// field naming the variant, followed by that variant's fields.
    ///
    /// # Errors
    ///
    /// Returns a serializer error for [`TemporalInjection::Custom`], because an
    /// arbitrary closure has no data representation.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let snapshot = match *self {
            Self::Dirac { time, amount } => TemporalInjectionSnapshot::Dirac { time, amount },
            Self::Gaussian {
                center,
                width,
                peak_concentration,
            } => TemporalInjectionSnapshot::Gaussian {
                center,
                width,
                peak_concentration,
            },
            Self::Rectangle {
                start,
                end,
                concentration,
            } => TemporalInjectionSnapshot::Rectangle {
                start,
                end,
                concentration,
            },
            Self::None => TemporalInjectionSnapshot::None,
            Self::Custom(_) => {
                // Name the variant by its real Rust path so the error is searchable.
                return Err(serde::ser::Error::custom(
                    "TemporalInjection::Custom cannot be serialized",
                ));
            }
        };
        snapshot.serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for TemporalInjection {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let snapshot = match TemporalInjectionSnapshot::deserialize(deserializer)? {
TemporalInjectionSnapshot::Dirac { time, amount } => Self::Dirac { time, amount },
TemporalInjectionSnapshot::Gaussian {
center,
width,
peak_concentration,
} => Self::Gaussian {
center,
width,
peak_concentration,
},
TemporalInjectionSnapshot::Rectangle {
start,
end,
concentration,
} => Self::Rectangle {
start,
end,
concentration,
},
TemporalInjectionSnapshot::None => Self::None,
};
Ok(snapshot)
}
}
impl TemporalInjection {
    /// Creates an instantaneous injection of `amount` at exactly `time`.
    pub fn dirac(time: f64, amount: f64) -> Self {
        Self::Dirac { time, amount }
    }

    /// Creates a Gaussian pulse centred on `center` with spread `width` and
    /// maximum value `peak_concentration` (reached at `t == center`).
    ///
    /// `width` is not validated. A zero width yields NaN at `t == center`.
    pub fn gaussian(center: f64, width: f64, peak_concentration: f64) -> Self {
        Self::Gaussian {
            center,
            width,
            peak_concentration,
        }
    }

    /// Creates a constant `concentration` over the half-open interval
    /// `[start, end)`.
    ///
    /// # Panics
    ///
    /// Panics with `"Rectangle end must be > start"` unless `end > start`.
    /// This includes the case where either bound is NaN.
    pub fn rectangle(start: f64, end: f64, concentration: f64) -> Self {
        assert!(end > start, "Rectangle end must be > start");
        Self::Rectangle {
            start,
            end,
            concentration,
        }
    }

    /// Wraps an arbitrary `t -> value` closure as an injection profile.
    ///
    /// The resulting value can be cloned cheaply (the closure is shared via
    /// `Arc`) but cannot be serialized.
    pub fn custom<F>(f: F) -> Self
    where
        F: Fn(f64) -> f64 + Send + Sync + 'static,
    {
        Self::Custom(Arc::new(f))
    }

    /// Creates the "no injection" profile, which is `0.0` at every time.
    pub fn none() -> Self {
        Self::None
    }

    /// Returns the injection value at time `t`.
    ///
    /// - `Dirac`: `amount` when `t == time` holds exactly, `0.0` otherwise.
    /// - `Gaussian`: `peak_concentration * exp(-((t - center) / width)^2 / 2)`.
    /// - `Rectangle`: `concentration` for `start <= t < end`, `0.0` otherwise.
    /// - `Custom`: the closure applied to `t`.
    /// - `None`: always `0.0`.
    pub fn evaluate(&self, t: f64) -> f64 {
        match *self {
            Self::None => 0.0,
            Self::Dirac { time, amount } => {
                // Exact comparison by design: only the precise instant matches.
                if t == time {
                    amount
                } else {
                    0.0
                }
            }
            Self::Gaussian {
                center,
                width,
                peak_concentration,
            } => {
                let z = (t - center) / width;
                let exponent = -(z * z) / 2.0;
                peak_concentration * exponent.exp()
            }
            Self::Rectangle {
                start,
                end,
                concentration,
            } => {
                // `Range::contains` is `start <= t && t < end`; NaN never matches.
                if (start..end).contains(&t) {
                    concentration
                } else {
                    0.0
                }
            }
            Self::Custom(ref profile) => profile(t),
        }
    }

    /// Evaluates the profile at each entry of `times`, preserving order.
    ///
    /// Returns an empty vector for an empty slice.
    pub fn evaluate_series(&self, times: &[f64]) -> Vec<f64> {
        times
            .iter()
            .copied()
            .map(|instant| self.evaluate(instant))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dirac_injection() {
        let injection = TemporalInjection::dirac(10.0, 1.0);
        assert_eq!(injection.evaluate(10.0), 1.0);
        assert_eq!(injection.evaluate(0.0), 0.0);
        assert_eq!(injection.evaluate(20.0), 0.0);
        // Dirac matching is exact: even a tiny offset misses the spike.
        assert_eq!(injection.evaluate(10.0 + 1e-9), 0.0);
    }

    #[test]
    fn test_gaussian_injection() {
        let injection = TemporalInjection::gaussian(10.0, 2.0, 0.1);
        assert!((injection.evaluate(10.0) - 0.1).abs() < 1e-10);
        // One width away from the centre the value is peak * exp(-1/2) ~= 0.606 * peak.
        let expected = 0.1 * 0.606;
        assert!((injection.evaluate(8.0) - expected).abs() < 0.01);
        assert!((injection.evaluate(12.0) - expected).abs() < 0.01);
        assert!(injection.evaluate(0.0) < 0.001);
        assert!(injection.evaluate(20.0) < 0.001);
    }

    #[test]
    fn test_rectangle_injection() {
        let injection = TemporalInjection::rectangle(5.0, 15.0, 0.05);
        assert_eq!(injection.evaluate(4.0), 0.0);
        assert_eq!(injection.evaluate(5.0), 0.05);
        assert_eq!(injection.evaluate(10.0), 0.05);
        assert_eq!(injection.evaluate(14.9), 0.05);
        // The interval is half-open: `end` itself is excluded.
        assert_eq!(injection.evaluate(15.0), 0.0);
    }

    #[test]
    fn test_custom_injection() {
        let injection = TemporalInjection::custom(|t| if t < 10.0 { 0.01 * t } else { 0.0 });
        assert_eq!(injection.evaluate(0.0), 0.0);
        assert_eq!(injection.evaluate(5.0), 0.05);
        assert_eq!(injection.evaluate(10.0), 0.0);
    }

    #[test]
    fn test_none_injection() {
        let injection = TemporalInjection::none();
        assert_eq!(injection.evaluate(0.0), 0.0);
        assert_eq!(injection.evaluate(100.0), 0.0);
    }

    #[test]
    fn test_evaluate_series() {
        let injection = TemporalInjection::gaussian(10.0, 2.0, 0.1);
        let times = [0.0, 5.0, 10.0, 15.0, 20.0];
        let values = injection.evaluate_series(&times);
        assert_eq!(values.len(), 5);
        assert!((values[2] - 0.1).abs() < 1e-10);
        assert!(injection.evaluate_series(&[]).is_empty());
    }

    #[test]
    #[should_panic(expected = "Rectangle end must be > start")]
    fn test_rectangle_invalid() {
        TemporalInjection::rectangle(10.0, 10.0, 0.05);
    }

    #[test]
    fn test_debug_temporal_injection_dirac() {
        let injection = TemporalInjection::dirac(0.0, 10.0);
        let debug = format!("{:?}", injection);
        assert_eq!(debug, "Dirac { time: 0.0, amount: 10.0 }");
    }

    #[test]
    fn test_debug_temporal_injection_gaussian() {
        let injection = TemporalInjection::gaussian(10.0, 2.0, 0.1);
        let debug = format!("{:?}", injection);
        assert_eq!(
            debug,
            "Gaussian { center: 10.0, width: 2.0, peak_concentration: 0.1 }"
        );
    }

    #[test]
    fn test_debug_temporal_injection_rectangle() {
        let injection = TemporalInjection::rectangle(5.0, 9.0, 0.25);
        let debug = format!("{:?}", injection);
        assert_eq!(
            debug,
            "Rectangle { start: 5.0, end: 9.0, concentration: 0.25 }"
        );
    }

    #[test]
    fn test_debug_temporal_injection_custom() {
        let injection = TemporalInjection::custom(|t| t.exp());
        let debug = format!("{:?}", injection);
        assert_eq!(debug, "Custom { function: \"<user-defined>\" }");
    }

    #[test]
    fn test_debug_temporal_injection_none() {
        let injection = TemporalInjection::none();
        let debug = format!("{:?}", injection);
        assert_eq!(debug, "None");
    }

    #[test]
    fn test_clone_temporal_injection_dirac() {
        let injection = TemporalInjection::dirac(10.0, 1.0);
        let clone = injection.clone();
        assert_eq!(injection.evaluate(0.0), clone.evaluate(0.0));
        assert_eq!(injection.evaluate(10.0), clone.evaluate(10.0));
        assert_eq!(injection.evaluate(20.0), clone.evaluate(20.0));
        assert_eq!(injection.evaluate(10.0), 1.0);
        assert_eq!(injection.evaluate(0.0), 0.0);
        assert_eq!(injection.evaluate(20.0), 0.0);
    }

    #[test]
    fn test_clone_temporal_injection_gaussian() {
        let injection = TemporalInjection::gaussian(10.0, 2.0, 0.1);
        let clone = injection.clone();
        assert!((injection.evaluate(10.0) - 0.1).abs() < 1e-10);
        assert!((clone.evaluate(10.0) - 0.1).abs() < 1e-10);
        let expected = 0.1 * 0.606;
        assert!((injection.evaluate(8.0) - expected).abs() < 0.01);
        assert!((injection.evaluate(12.0) - expected).abs() < 0.01);
        assert!((clone.evaluate(8.0) - expected).abs() < 0.01);
        assert!((clone.evaluate(12.0) - expected).abs() < 0.01);
        assert!(injection.evaluate(0.0) < 0.001);
        assert!(injection.evaluate(20.0) < 0.001);
        assert!(clone.evaluate(0.0) < 0.001);
        assert!(clone.evaluate(20.0) < 0.001);
    }

    #[test]
    fn test_clone_temporal_injection_rectangle() {
        let injection = TemporalInjection::rectangle(5.0, 15.0, 0.05);
        let clone = injection.clone();
        assert_eq!(injection.evaluate(4.0), 0.0);
        assert_eq!(clone.evaluate(4.0), 0.0);
        assert_eq!(injection.evaluate(5.0), 0.05);
        assert_eq!(injection.evaluate(10.0), 0.05);
        assert_eq!(injection.evaluate(14.9), 0.05);
        assert_eq!(clone.evaluate(5.0), 0.05);
        assert_eq!(clone.evaluate(10.0), 0.05);
        assert_eq!(clone.evaluate(14.9), 0.05);
        assert_eq!(injection.evaluate(15.0), 0.0);
        assert_eq!(clone.evaluate(15.0), 0.0);
    }

    #[test]
    fn test_clone_temporal_injection_custom() {
        let injection = TemporalInjection::custom(|t| if t < 10.0 { 0.01 * t } else { 0.0 });
        let clone = injection.clone();
        assert_eq!(injection.evaluate(0.0), 0.0);
        assert_eq!(clone.evaluate(0.0), 0.0);
        assert_eq!(injection.evaluate(5.0), 0.05);
        assert_eq!(clone.evaluate(5.0), 0.05);
        assert_eq!(injection.evaluate(10.0), 0.0);
        assert_eq!(clone.evaluate(10.0), 0.0);
    }

    #[test]
    fn test_clone_temporal_injection_none() {
        let injection = TemporalInjection::none();
        let clone = injection.clone();
        assert_eq!(clone.evaluate(0.0), 0.0);
        assert_eq!(clone.evaluate(100.0), 0.0);
    }

    #[test]
    fn test_serialize_dirac() {
        let injection = TemporalInjection::dirac(5.0, 0.1);
        let json = serde_json::to_string(&injection).unwrap();
        assert!(json.contains("\"type\":\"Dirac\""));
        assert!(json.contains("\"time\":5.0"));
        assert!(json.contains("\"amount\":0.1"));
    }

    #[test]
    fn test_serialize_gaussian() {
        let injection = TemporalInjection::gaussian(10.0, 3.0, 0.1);
        let json = serde_json::to_string(&injection).unwrap();
        assert!(json.contains("\"type\":\"Gaussian\""));
        assert!(json.contains("\"center\":10.0"));
    }

    #[test]
    fn test_serialize_rectangle() {
        let injection = TemporalInjection::rectangle(5.0, 15.0, 0.05);
        let json = serde_json::to_string(&injection).unwrap();
        assert!(json.contains("\"type\":\"Rectangle\""));
        assert!(json.contains("\"start\":5.0"));
        assert!(json.contains("\"end\":15.0"));
    }

    #[test]
    fn test_serialize_none() {
        let injection = TemporalInjection::none();
        let json = serde_json::to_string(&injection).unwrap();
        assert!(json.contains("\"type\":\"None\""));
    }

    #[test]
    fn test_serialize_custom_fails() {
        let injection = TemporalInjection::custom(|t| t);
        let err = serde_json::to_string(&injection).unwrap_err();
        assert!(err.to_string().contains("cannot be serialized"));
    }

    #[test]
    fn test_round_trip_dirac() {
        let original = TemporalInjection::dirac(5.0, 0.1);
        let json = serde_json::to_string(&original).unwrap();
        let restored: TemporalInjection = serde_json::from_str(&json).unwrap();
        assert!((restored.evaluate(5.0) - original.evaluate(5.0)).abs() < 1e-10);
    }

    #[test]
    fn test_round_trip_gaussian() {
        let original = TemporalInjection::gaussian(10.0, 3.0, 0.1);
        let json = serde_json::to_string(&original).unwrap();
        let restored: TemporalInjection = serde_json::from_str(&json).unwrap();
        assert!((restored.evaluate(10.0) - original.evaluate(10.0)).abs() < 1e-10);
    }

    #[test]
    fn test_round_trip_rectangle() {
        let original = TemporalInjection::rectangle(5.0, 15.0, 0.05);
        let json = serde_json::to_string(&original).unwrap();
        let restored: TemporalInjection = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.evaluate(10.0), original.evaluate(10.0));
    }

    #[test]
    fn test_round_trip_none() {
        let original = TemporalInjection::none();
        let json = serde_json::to_string(&original).unwrap();
        let restored: TemporalInjection = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.evaluate(0.0), 0.0);
    }

    #[test]
    fn test_round_trip_yaml() {
        let original = TemporalInjection::dirac(5.0, 0.1);
        let yaml = serde_yaml::to_string(&original).unwrap();
        let restored: TemporalInjection = serde_yaml::from_str(&yaml).unwrap();
        assert!((restored.evaluate(5.0) - original.evaluate(5.0)).abs() < 1e-10);
    }
}