1mod constant;
8mod pwl;
9
10pub use constant::*;
11pub use pwl::*;
12
13#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema), schemars(untagged))]
16#[cfg_attr(
17 feature = "serde",
18 derive(serde::Serialize, serde::Deserialize),
19 serde(try_from = "DemandCurveDto", into = "DemandCurveDto")
20)]
21#[derive(Clone, Debug)]
22pub enum DemandCurve {
27 Pwl(#[cfg_attr(feature = "schemars", schemars(with = "PwlCurveDto"))] PwlCurve),
29 Constant(#[cfg_attr(feature = "schemars", schemars(with = "ConstantCurveDto"))] ConstantCurve),
31}
32
33#[cfg_attr(feature = "serde", derive(serde::Serialize), serde(untagged))]
35#[derive(Debug)]
36pub enum DemandCurveDto {
37 Pwl(PwlCurveDto),
39 Constant(ConstantCurveDto),
41}
42
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for DemandCurveDto {
    /// Picks the variant from the shape of the input rather than trying each
    /// variant in turn: a sequence is parsed as a PWL point list and a map as a
    /// constant curve. Any other input shape is rejected by the visitor.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let visitor = serde_untagged::UntaggedEnumVisitor::new()
            .seq(|points| points.deserialize().map(Self::Pwl))
            .map(|fields| fields.deserialize().map(Self::Constant));
        visitor.deserialize(deserializer)
    }
}
55
56impl TryFrom<DemandCurveDto> for DemandCurve {
57 type Error = DemandCurveError;
58
59 fn try_from(value: DemandCurveDto) -> Result<Self, Self::Error> {
61 match value {
62 DemandCurveDto::Pwl(curve) => Ok(curve.try_into()?),
63 DemandCurveDto::Constant(constant) => Ok(constant.try_into()?),
64 }
65 }
66}
67
68impl Into<DemandCurveDto> for DemandCurve {
69 fn into(self) -> DemandCurveDto {
70 match self {
71 Self::Pwl(curve) => DemandCurveDto::Pwl(curve.into()),
72 Self::Constant(constant) => DemandCurveDto::Constant(constant.into()),
73 }
74 }
75}
76
77impl From<PwlCurve> for DemandCurve {
78 fn from(value: PwlCurve) -> Self {
79 Self::Pwl(value)
80 }
81}
82
83impl From<ConstantCurve> for DemandCurve {
84 fn from(value: ConstantCurve) -> Self {
85 Self::Constant(value)
86 }
87}
88
89impl TryFrom<PwlCurveDto> for DemandCurve {
90 type Error = PwlCurveError;
91 fn try_from(value: PwlCurveDto) -> Result<Self, Self::Error> {
92 Ok(Self::Pwl(value.try_into()?))
93 }
94}
95
96impl TryFrom<ConstantCurveDto> for DemandCurve {
97 type Error = ConstantCurveError;
98 fn try_from(value: ConstantCurveDto) -> Result<Self, Self::Error> {
99 Ok(Self::Constant(value.try_into()?))
100 }
101}
102
103#[derive(Debug, thiserror::Error)]
105pub enum DemandCurveError {
106 #[error("invalid pwl curve: {0}")]
108 Pwl(#[from] PwlCurveError),
109 #[error("invalid constant curve: {0}")]
111 Constant(#[from] ConstantCurveError),
112}
113
114impl DemandCurve {
115 pub unsafe fn new_unchecked(value: DemandCurveDto) -> Self {
121 unsafe {
122 match value {
123 DemandCurveDto::Pwl(curve) => PwlCurve::new_unchecked(curve.0).into(),
124 DemandCurveDto::Constant(ConstantCurveDto {
125 min_rate,
126 max_rate,
127 price,
128 }) => ConstantCurve::new_unchecked(
129 min_rate.unwrap_or(f64::NEG_INFINITY),
130 max_rate.unwrap_or(f64::INFINITY),
131 price,
132 )
133 .into(),
134 }
135 }
136 }
137
138 pub fn domain(&self) -> (f64, f64) {
143 match self {
144 DemandCurve::Pwl(curve) => curve.domain(),
145 DemandCurve::Constant(curve) => curve.domain(),
146 }
147 }
148
149 pub fn points(self) -> Vec<Point> {
154 match self {
155 DemandCurve::Pwl(curve) => curve.points(),
156 DemandCurve::Constant(curve) => curve.points(),
157 }
158 }
159}
160
// Gated on `serde`: these tests deserialize `DemandCurve`, whose `Deserialize`
// impl only exists when that feature is enabled.
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;

    /// A JSON array is routed to the PWL variant.
    #[test]
    fn test_deserialize_pwl() {
        let raw = r#"[
            {
                "rate": 0.0,
                "price": 10.0
            },
            {
                "rate": 1.0,
                "price": 5.0
            }
        ]"#;

        let curve = serde_json::from_str::<DemandCurve>(raw)
            .expect("a JSON array of points should deserialize into a demand curve");
        assert!(
            matches!(curve, DemandCurve::Pwl(_)),
            "expected the Pwl variant, got {curve:?}"
        );
    }

    /// A JSON object is routed to the constant variant.
    #[test]
    fn test_deserialize_constant() {
        let raw = r#"{
            "min_rate": -1.0,
            "max_rate": 1.0,
            "price": 10.0
        }"#;

        let curve = serde_json::from_str::<DemandCurve>(raw)
            .expect("a JSON object should deserialize into a demand curve");
        assert!(
            matches!(curve, DemandCurve::Constant(_)),
            "expected the Constant variant, got {curve:?}"
        );
    }

    /// Only sequences and maps are accepted; a bare scalar must be rejected.
    #[test]
    fn test_deserialize_rejects_scalar() {
        assert!(serde_json::from_str::<DemandCurve>("1.0").is_err());
    }
}