mod constant;
mod pwl;
pub use constant::*;
pub use pwl::*;
/// A demand curve, held in one of its supported representations.
///
/// With the `serde` feature enabled, (de)serialization goes through [`DemandCurveDto`]:
/// deserializing converts the DTO via `TryFrom<DemandCurveDto>` (so malformed curves are
/// rejected with a [`DemandCurveError`]), and serializing converts the curve back into the DTO.
// With the `schemars` feature, the JSON schema is the untagged union of the two DTO schemas.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema), schemars(untagged))]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(try_from = "DemandCurveDto", into = "DemandCurveDto")
)]
#[derive(Clone, Debug)]
pub enum DemandCurve {
    /// Piecewise-linear curve; its schema is that of [`PwlCurveDto`].
    Pwl(#[cfg_attr(feature = "schemars", schemars(with = "PwlCurveDto"))] PwlCurve),
    /// Constant-price curve; its schema is that of [`ConstantCurveDto`].
    Constant(#[cfg_attr(feature = "schemars", schemars(with = "ConstantCurveDto"))] ConstantCurve),
}
/// Unvalidated wire representation of a [`DemandCurve`].
///
/// Serialization is untagged: each variant is written as its inner DTO with no wrapper.
/// Deserialization is implemented by hand below and dispatches on the input's shape.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize), serde(untagged))]
pub enum DemandCurveDto {
    /// Piecewise-linear curve data.
    Pwl(PwlCurveDto),
    /// Constant-price curve data.
    Constant(ConstantCurveDto),
}
// Picks the variant from the shape of the input: a sequence is read as a piecewise-linear
// curve, a map as a constant curve. Inputs of any other shape are rejected by the visitor.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for DemandCurveDto {
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = serde_untagged::UntaggedEnumVisitor::new()
            .seq(|elements| elements.deserialize().map(Self::Pwl))
            .map(|fields| fields.deserialize().map(Self::Constant));
        visitor.deserialize(deserializer)
    }
}
impl TryFrom<DemandCurveDto> for DemandCurve {
type Error = DemandCurveError;
fn try_from(value: DemandCurveDto) -> Result<Self, Self::Error> {
match value {
DemandCurveDto::Pwl(curve) => Ok(curve.try_into()?),
DemandCurveDto::Constant(constant) => Ok(constant.try_into()?),
}
}
}
impl Into<DemandCurveDto> for DemandCurve {
fn into(self) -> DemandCurveDto {
match self {
Self::Pwl(curve) => DemandCurveDto::Pwl(curve.into()),
Self::Constant(constant) => DemandCurveDto::Constant(constant.into()),
}
}
}
impl From<PwlCurve> for DemandCurve {
fn from(value: PwlCurve) -> Self {
Self::Pwl(value)
}
}
impl From<ConstantCurve> for DemandCurve {
fn from(value: ConstantCurve) -> Self {
Self::Constant(value)
}
}
impl TryFrom<PwlCurveDto> for DemandCurve {
type Error = PwlCurveError;
fn try_from(value: PwlCurveDto) -> Result<Self, Self::Error> {
Ok(Self::Pwl(value.try_into()?))
}
}
impl TryFrom<ConstantCurveDto> for DemandCurve {
type Error = ConstantCurveError;
fn try_from(value: ConstantCurveDto) -> Result<Self, Self::Error> {
Ok(Self::Constant(value.try_into()?))
}
}
#[derive(Debug, thiserror::Error)]
pub enum DemandCurveError {
#[error("invalid pwl curve: {0}")]
Pwl(#[from] PwlCurveError),
#[error("invalid constant curve: {0}")]
Constant(#[from] ConstantCurveError),
}
impl DemandCurve {
    /// Builds a [`DemandCurve`] from its DTO without going through the validating
    /// `TryFrom<DemandCurveDto>` conversion.
    ///
    /// For the `Constant` variant, a missing `min_rate` becomes `f64::NEG_INFINITY` and a
    /// missing `max_rate` becomes `f64::INFINITY`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the safety contract of the unchecked constructor for the
    /// variant being built:
    /// - `Pwl`: the contract of `PwlCurve::new_unchecked` for the DTO's points;
    /// - `Constant`: the contract of `ConstantCurve::new_unchecked` for `price` and for the
    ///   bounds *after* the infinite defaults above have been applied.
    pub unsafe fn new_unchecked(value: DemandCurveDto) -> Self {
        match value {
            DemandCurveDto::Pwl(curve) => {
                // SAFETY: the caller guarantees `PwlCurve::new_unchecked`'s contract for these
                // points, as required by this function's `# Safety` section.
                let pwl = unsafe { PwlCurve::new_unchecked(curve.0) };
                pwl.into()
            }
            DemandCurveDto::Constant(ConstantCurveDto {
                min_rate,
                max_rate,
                price,
            }) => {
                // An absent bound means the curve is unbounded on that side.
                let min_rate = min_rate.unwrap_or(f64::NEG_INFINITY);
                let max_rate = max_rate.unwrap_or(f64::INFINITY);
                // SAFETY: the caller guarantees `ConstantCurve::new_unchecked`'s contract for
                // these (defaulted) bounds and price, per this function's `# Safety` section.
                let constant = unsafe { ConstantCurve::new_unchecked(min_rate, max_rate, price) };
                constant.into()
            }
        }
    }

    /// Returns the domain reported by the underlying curve representation.
    // NOTE(review): presumably the `(min_rate, max_rate)` interval — confirm in `pwl`/`constant`.
    pub fn domain(&self) -> (f64, f64) {
        match self {
            DemandCurve::Pwl(curve) => curve.domain(),
            DemandCurve::Constant(curve) => curve.domain(),
        }
    }

    /// Consumes the curve and returns the points produced by the underlying representation.
    pub fn points(self) -> Vec<Point> {
        match self {
            DemandCurve::Pwl(curve) => curve.points(),
            DemandCurve::Constant(curve) => curve.points(),
        }
    }
}
// These tests go through `Deserialize for DemandCurve`, which exists only with the `serde`
// feature; gating on it keeps `cargo test --no-default-features` building.
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;

    #[test]
    fn test_deserialize_pwl() {
        let raw = r#"[
            {
                "rate": 0.0,
                "price": 10.0
            },
            {
                "rate": 1.0,
                "price": 5.0
            }
        ]"#;
        let curve = serde_json::from_str::<DemandCurve>(raw)
            .expect("a valid pwl curve should deserialize");
        // A JSON array must be dispatched to the piecewise-linear variant.
        assert!(matches!(curve, DemandCurve::Pwl(_)));
    }

    #[test]
    fn test_deserialize_constant() {
        let raw = r#"{
            "min_rate": -1.0,
            "max_rate": 1.0,
            "price": 10.0
        }"#;
        let curve = serde_json::from_str::<DemandCurve>(raw)
            .expect("a valid constant curve should deserialize");
        // A JSON object must be dispatched to the constant variant.
        assert!(matches!(curve, DemandCurve::Constant(_)));
    }

    #[test]
    fn test_deserialize_rejects_scalar() {
        // Only sequences and maps are accepted; a bare number matches neither representation.
        assert!(serde_json::from_str::<DemandCurve>("10.0").is_err());
    }
}