fts_core/models/
curve.rs

//! Demand curve implementations for flow trading.
//!
//! This module provides different curve types for expressing bidders' pricing preferences:
//! - [`PwlCurve`]: piecewise-linear curves for complex pricing strategies
//! - [`ConstantCurve`]: fixed-price curves for simple trading strategies
6
7mod constant;
8mod pwl;
9
10pub use constant::*;
11pub use pwl::*;
12
13// `schemars` does not support serde's try_from/into (https://github.com/GREsau/schemars/issues/210).
14// Thus, the "parse" path necessarily diverges a bit between serde and schemars, which is unfortunate.
15#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema), schemars(untagged))]
16#[cfg_attr(
17    feature = "serde",
18    derive(serde::Serialize, serde::Deserialize),
19    serde(try_from = "DemandCurveDto", into = "DemandCurveDto")
20)]
21#[derive(Clone, Debug)]
22/// A demand curve expressing a bidder's willingness to pay at different rates.
23///
24/// The solver uses these curves to find optimal allocations that maximize total welfare.
25/// All curves must include rate=0 in their domain to allow for zero trade scenarios.
26pub enum DemandCurve {
27    /// Piecewise linear curve defined by a series of points
28    Pwl(#[cfg_attr(feature = "schemars", schemars(with = "PwlCurveDto"))] PwlCurve),
29    /// Constant price curve over a rate interval
30    Constant(#[cfg_attr(feature = "schemars", schemars(with = "ConstantCurveDto"))] ConstantCurve),
31}
32
33/// DTO for demand curves to enable validation during deserialization
34#[cfg_attr(feature = "serde", derive(serde::Serialize), serde(untagged))]
35#[derive(Debug)]
36pub enum DemandCurveDto {
37    /// Piecewise linear curve DTO
38    Pwl(PwlCurveDto),
39    /// Constant price curve DTO
40    Constant(ConstantCurveDto),
41}
42
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for DemandCurveDto {
    /// Picks the variant from the shape of the input: a sequence is parsed as a
    /// piecewise-linear curve, a map as a constant curve. Any other shape is rejected.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = serde_untagged::UntaggedEnumVisitor::new()
            .seq(|points| points.deserialize().map(Self::Pwl))
            .map(|fields| fields.deserialize().map(Self::Constant));
        visitor.deserialize(deserializer)
    }
}
55
56impl TryFrom<DemandCurveDto> for DemandCurve {
57    type Error = DemandCurveError;
58
59    /// Creates a demand curve from a DTO, validating all constraints
60    fn try_from(value: DemandCurveDto) -> Result<Self, Self::Error> {
61        match value {
62            DemandCurveDto::Pwl(curve) => Ok(curve.try_into()?),
63            DemandCurveDto::Constant(constant) => Ok(constant.try_into()?),
64        }
65    }
66}
67
68impl Into<DemandCurveDto> for DemandCurve {
69    fn into(self) -> DemandCurveDto {
70        match self {
71            Self::Pwl(curve) => DemandCurveDto::Pwl(curve.into()),
72            Self::Constant(constant) => DemandCurveDto::Constant(constant.into()),
73        }
74    }
75}
76
77impl From<PwlCurve> for DemandCurve {
78    fn from(value: PwlCurve) -> Self {
79        Self::Pwl(value)
80    }
81}
82
83impl From<ConstantCurve> for DemandCurve {
84    fn from(value: ConstantCurve) -> Self {
85        Self::Constant(value)
86    }
87}
88
89impl TryFrom<PwlCurveDto> for DemandCurve {
90    type Error = PwlCurveError;
91    fn try_from(value: PwlCurveDto) -> Result<Self, Self::Error> {
92        Ok(Self::Pwl(value.try_into()?))
93    }
94}
95
96impl TryFrom<ConstantCurveDto> for DemandCurve {
97    type Error = ConstantCurveError;
98    fn try_from(value: ConstantCurveDto) -> Result<Self, Self::Error> {
99        Ok(Self::Constant(value.try_into()?))
100    }
101}
102
103/// Errors that can occur when constructing demand curves
104#[derive(Debug, thiserror::Error)]
105pub enum DemandCurveError {
106    /// Error from constructing a piecewise linear curve
107    #[error("invalid pwl curve: {0}")]
108    Pwl(#[from] PwlCurveError),
109    /// Error from constructing a constant curve
110    #[error("invalid constant curve: {0}")]
111    Constant(#[from] ConstantCurveError),
112}
113
114impl DemandCurve {
115    /// Creates a demand curve without validation
116    ///
117    /// # Safety
118    /// The caller must ensure the data represents a valid curve.
119    /// Invalid curves may cause undefined behavior in the solver.
120    pub unsafe fn new_unchecked(value: DemandCurveDto) -> Self {
121        unsafe {
122            match value {
123                DemandCurveDto::Pwl(curve) => PwlCurve::new_unchecked(curve.0).into(),
124                DemandCurveDto::Constant(ConstantCurveDto {
125                    min_rate,
126                    max_rate,
127                    price,
128                }) => ConstantCurve::new_unchecked(
129                    min_rate.unwrap_or(f64::NEG_INFINITY),
130                    max_rate.unwrap_or(f64::INFINITY),
131                    price,
132                )
133                .into(),
134            }
135        }
136    }
137
138    /// Returns the rate interval over which this curve is defined
139    ///
140    /// # Returns
141    /// A tuple `(min_rate, max_rate)` defining the valid rate range for this curve.
142    pub fn domain(&self) -> (f64, f64) {
143        match self {
144            DemandCurve::Pwl(curve) => curve.domain(),
145            DemandCurve::Constant(curve) => curve.domain(),
146        }
147    }
148
149    /// Converts the curve into a vector of points
150    ///
151    /// For PWL curves, returns all defining points. For constant curves,
152    /// returns two points representing the endpoints of the constant price segment.
153    pub fn points(self) -> Vec<Point> {
154        match self {
155            DemandCurve::Pwl(curve) => curve.points(),
156            DemandCurve::Constant(curve) => curve.points(),
157        }
158    }
159}
160
// `DemandCurve` only implements `Deserialize` when the `serde` feature is enabled, so
// these tests are gated on it as well; otherwise `cargo test` without the feature would fail
// to compile.
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;

    #[test]
    fn test_deserialize_pwl() {
        let raw = r#"[
            {
                "rate": 0.0,
                "price": 10.0
            },
            {
                "rate": 1.0,
                "price": 5.0
            }
        ]"#;

        // A JSON array is routed to the PWL variant by `DemandCurveDto`'s deserializer.
        let curve = serde_json::from_str::<DemandCurve>(raw)
            .expect("a JSON array of points should deserialize as a PWL curve");
        assert!(matches!(curve, DemandCurve::Pwl(_)));
    }

    #[test]
    fn test_deserialize_constant() {
        let raw = r#"{
            "min_rate": -1.0,
            "max_rate": 1.0,
            "price": 10.0
        }"#;

        // A JSON object is routed to the constant variant by `DemandCurveDto`'s deserializer.
        let curve = serde_json::from_str::<DemandCurve>(raw)
            .expect("a JSON object should deserialize as a constant curve");
        assert!(matches!(curve, DemandCurve::Constant(_)));
    }

    #[test]
    fn test_deserialize_rejects_scalar() {
        // The untagged visitor only accepts sequences and maps, so a bare number must fail.
        assert!(serde_json::from_str::<DemandCurve>("10.0").is_err());
    }
}