fts_solver/types/demand/
disaggregate.rs

1use super::Segment;
2use fts_core::models::Point;
3use std::iter::Peekable;
4
5/// If a demand curve is an aggregation of individual demand segments, then we
6/// can disaggregate a demand curve into these segments. This is useful for
7/// constructing the optimization program.
8pub fn disaggregate<T: Iterator<Item = Point>>(
9    points: T,
10    min: f64,
11    max: f64,
12) -> Option<impl Iterator<Item = Result<Segment, Segment>>> {
13    if !(min <= 0.0 && 0.0 <= max) {
14        return None;
15    }
16
17    let mut points = points.peekable();
18
19    if let Some(point) = points.peek() {
20        let anchor = if point.rate < min {
21            points.next()
22        } else {
23            Some(Point {
24                rate: min,
25                price: point.price,
26            })
27        };
28
29        Some(
30            Disaggregation {
31                points,
32                anchor,
33                domain: (min, max),
34            }
35            // We remove any demand segments which do not contribute, but we preserve
36            // any invalid segments in order to surface the error to the caller.
37            .filter(|result| match result {
38                Ok(demand) => demand.q0 != demand.q1,
39                Err(_) => true,
40            }),
41        )
42    } else {
43        None
44    }
45}
46
47// An iterator that disaggregates a demand curve into its simple segments
48#[derive(Debug)]
49struct Disaggregation<T: Iterator<Item = Point>> {
50    /// The raw, underlying iterator of points
51    points: Peekable<T>,
52    /// An anchoring point, representing the "left" point of a sliding window of points
53    anchor: Option<Point>,
54    // A clipping domain. Since we validate domain.0 <= domain.1 in the caller, and the constructor is private, we can rely on this invariant
55    domain: (f64, f64),
56}
57
58impl<T: Iterator<Item = Point>> Iterator for Disaggregation<T> {
59    // If an Err() is returned, the original demand curve was invalid
60    type Item = Result<Segment, Segment>;
61
62    // Iterate over the translated segments of a demand curve
63    fn next(&mut self) -> Option<Self::Item> {
64        // Are we anchored?
65        while let Some(prev) = self.anchor.take() {
66            // If so, contemplate the next point.
67            if self.domain.1 <= prev.rate {
68                // early exit condition
69                return None;
70            } else if let Some(mut next) = self.points.next() {
71                // If there is a point, try to generate a segment.
72                loop {
73                    // We remove any interior, collinear points to simplify the curve
74                    if let Some(extra) = self.points.peek() {
75                        if is_collinear(&next, &prev, extra) {
76                            // Safe, since self.points.peek().is_some()
77                            next = self.points.next().unwrap();
78                            continue;
79                        } else {
80                            break;
81                        }
82                    } else {
83                        if self.domain.1 > next.rate {
84                            let extra = Point {
85                                rate: self.domain.1,
86                                price: next.price,
87                            };
88                            if is_collinear(&next, &prev, &extra) {
89                                next = extra;
90                            }
91                        }
92                        break;
93                    }
94                }
95
96                self.anchor = Some(next.clone());
97
98                let segment = Segment::new(prev, next)
99                    .map(|(demand, translate)| {
100                        demand.clip(self.domain.0 - translate, self.domain.1 - translate)
101                    })
102                    .map_err(|(demand, _)| demand)
103                    .transpose();
104                if segment.is_some() {
105                    return segment;
106                } else {
107                    continue;
108                }
109            } else {
110                // If there are no more points, we are done iterating.
111                // However, we might need to extrapolate one additional point.
112                let next = Point {
113                    rate: self.domain.1,
114                    price: prev.price,
115                };
116
117                return Segment::new(prev, next)
118                    .map(|(demand, translate)| {
119                        demand.clip(self.domain.0 - translate, self.domain.1 - translate)
120                    })
121                    .map_err(|(demand, _)| demand)
122                    .transpose();
123            }
124        }
125
126        None
127    }
128}
129
130/// Is this point collinear with the other two?
131fn is_collinear(pt: &Point, lhs: &Point, rhs: &Point) -> bool {
132    let &Point {
133        rate: x0,
134        price: y0,
135    } = lhs;
136    let &Point {
137        rate: x1,
138        price: y1,
139    } = pt;
140    let &Point {
141        rate: x2,
142        price: y2,
143    } = rhs;
144
145    (x2 - x0) * (y1 - y0) == (x1 - x0) * (y2 - y0)
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a curve point.
    fn pt(rate: f64, price: f64) -> Point {
        Point { rate, price }
    }

    /// Four samples of the straight line `price = 2 - rate` on `[-2, 2]`.
    fn data() -> impl Iterator<Item = Point> {
        vec![pt(-2.0, 4.0), pt(-1.0, 3.0), pt(1.0, 1.0), pt(2.0, 0.0)].into_iter()
    }

    #[test]
    fn collinear_reduction() {
        let segments = disaggregate(data(), -2.0, 2.0)
            .unwrap()
            .map(|res| res.unwrap())
            .collect::<Vec<_>>();

        assert_eq!(
            segments,
            vec![Segment {
                q0: -2.0,
                q1: 2.0,
                p0: 4.0,
                p1: 0.0,
            }]
        )
    }

    #[test]
    fn extrapolate_bad() {
        assert!(disaggregate(data(), -10.0, -5.0).is_none());
        assert!(disaggregate(data(), 5.0, 10.0).is_none());
    }

    #[test]
    fn empty_curve() {
        // With no points there is nothing to anchor the curve on.
        assert!(disaggregate(std::iter::empty::<Point>(), -1.0, 1.0).is_none());
    }

    #[test]
    fn nan_domain() {
        // A NaN bound can never satisfy `min <= 0 <= max`.
        assert!(disaggregate(data(), f64::NAN, 1.0).is_none());
        assert!(disaggregate(data(), -1.0, f64::NAN).is_none());
    }

    #[test]
    fn extrapolate_demand() {
        let segments = disaggregate(data(), 0.0, 5.0)
            .unwrap()
            .map(|res| res.unwrap())
            .collect::<Vec<_>>();

        let answer = vec![
            Segment {
                q0: 0.0,
                q1: 2.0,
                p0: 2.0,
                p1: 0.0,
            },
            Segment {
                q0: 0.0,
                q1: 3.0,
                p0: 0.0,
                p1: 0.0,
            },
        ];

        assert_eq!(segments, answer);
    }

    #[test]
    fn extrapolate_supply() {
        let segments = disaggregate(data(), -5.0, 0.0)
            .unwrap()
            .map(|res| res.unwrap())
            .collect::<Vec<_>>();

        let answer = vec![
            Segment {
                q0: -3.0,
                q1: 0.0,
                p0: 4.0,
                p1: 4.0,
            },
            Segment {
                q0: -2.0,
                q1: 0.0,
                p0: 4.0,
                p1: 2.0,
            },
        ];

        assert_eq!(segments, answer);
    }

    #[test]
    fn extrapolate_arbitrage() {
        let segments = disaggregate(data(), -5.0, 5.0)
            .unwrap()
            .map(|res| res.unwrap())
            .collect::<Vec<_>>();

        let answer = vec![
            Segment {
                q0: -3.0,
                q1: 0.0,
                p0: 4.0,
                p1: 4.0,
            },
            Segment {
                q0: -2.0,
                q1: 2.0,
                p0: 4.0,
                p1: 0.0,
            },
            Segment {
                q0: 0.0,
                q1: 3.0,
                p0: 0.0,
                p1: 0.0,
            },
        ];

        assert_eq!(segments, answer);
    }

    #[test]
    fn extrapolate_simple() {
        let segments = disaggregate(std::iter::once(pt(0.0, 5.0)), -5.0, 5.0)
            .unwrap()
            .map(|res| res.unwrap())
            .collect::<Vec<_>>();

        let answer = vec![Segment {
            q0: -5.0,
            q1: 5.0,
            p0: 5.0,
            p1: 5.0,
        }];

        assert_eq!(segments, answer);
    }

    #[test]
    fn check_slope() {
        // An upward-sloping piece must surface as an error, not be dropped.
        let points = vec![pt(-2.0, 4.0), pt(2.0, 5.0)];
        let segments = disaggregate(points.into_iter(), -1.0, 1.0)
            .unwrap()
            .collect::<Result<Vec<_>, _>>();
        assert!(segments.is_err());
    }
}