1use super::Segment;
2use fts_core::models::Point;
3use std::iter::Peekable;
4
5pub fn disaggregate<T: Iterator<Item = Point>>(
9 points: T,
10 min: f64,
11 max: f64,
12) -> Option<impl Iterator<Item = Result<Segment, Segment>>> {
13 if !(min <= 0.0 && 0.0 <= max) {
14 return None;
15 }
16
17 let mut points = points.peekable();
18
19 if let Some(point) = points.peek() {
20 let anchor = if point.rate < min {
21 points.next()
22 } else {
23 Some(Point {
24 rate: min,
25 price: point.price,
26 })
27 };
28
29 Some(
30 Disaggregation {
31 points,
32 anchor,
33 domain: (min, max),
34 }
35 .filter(|result| match result {
38 Ok(demand) => demand.q0 != demand.q1,
39 Err(_) => true,
40 }),
41 )
42 } else {
43 None
44 }
45}
46
47#[derive(Debug)]
49struct Disaggregation<T: Iterator<Item = Point>> {
50 points: Peekable<T>,
52 anchor: Option<Point>,
54 domain: (f64, f64),
56}
57
58impl<T: Iterator<Item = Point>> Iterator for Disaggregation<T> {
59 type Item = Result<Segment, Segment>;
61
62 fn next(&mut self) -> Option<Self::Item> {
64 while let Some(prev) = self.anchor.take() {
66 if self.domain.1 <= prev.rate {
68 return None;
70 } else if let Some(mut next) = self.points.next() {
71 loop {
73 if let Some(extra) = self.points.peek() {
75 if is_collinear(&next, &prev, extra) {
76 next = self.points.next().unwrap();
78 continue;
79 } else {
80 break;
81 }
82 } else {
83 if self.domain.1 > next.rate {
84 let extra = Point {
85 rate: self.domain.1,
86 price: next.price,
87 };
88 if is_collinear(&next, &prev, &extra) {
89 next = extra;
90 }
91 }
92 break;
93 }
94 }
95
96 self.anchor = Some(next.clone());
97
98 let segment = Segment::new(prev, next)
99 .map(|(demand, translate)| {
100 demand.clip(self.domain.0 - translate, self.domain.1 - translate)
101 })
102 .map_err(|(demand, _)| demand)
103 .transpose();
104 if segment.is_some() {
105 return segment;
106 } else {
107 continue;
108 }
109 } else {
110 let next = Point {
113 rate: self.domain.1,
114 price: prev.price,
115 };
116
117 return Segment::new(prev, next)
118 .map(|(demand, translate)| {
119 demand.clip(self.domain.0 - translate, self.domain.1 - translate)
120 })
121 .map_err(|(demand, _)| demand)
122 .transpose();
123 }
124 }
125
126 None
127 }
128}
129
130fn is_collinear(pt: &Point, lhs: &Point, rhs: &Point) -> bool {
132 let &Point {
133 rate: x0,
134 price: y0,
135 } = lhs;
136 let &Point {
137 rate: x1,
138 price: y1,
139 } = pt;
140 let &Point {
141 rate: x2,
142 price: y2,
143 } = rhs;
144
145 (x2 - x0) * (y1 - y0) == (x1 - x0) * (y2 - y0)
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for an input point.
    fn pt(rate: f64, price: f64) -> Point {
        Point { rate, price }
    }

    /// Shorthand constructor for an expected output segment.
    fn seg(q0: f64, q1: f64, p0: f64, p1: f64) -> Segment {
        Segment { q0, q1, p0, p1 }
    }

    /// Four collinear points on a line of slope -1.
    fn data() -> impl Iterator<Item = Point> {
        vec![pt(-2.0, 4.0), pt(-1.0, 3.0), pt(1.0, 1.0), pt(2.0, 0.0)].into_iter()
    }

    /// Disaggregates `points` over `[min, max]`, panicking on any rejected segment.
    fn ok_segments(points: impl Iterator<Item = Point>, min: f64, max: f64) -> Vec<Segment> {
        disaggregate(points, min, max)
            .expect("domain contains zero and input is non-empty")
            .map(|res| res.unwrap())
            .collect()
    }

    #[test]
    fn collinear_reduction() {
        assert_eq!(
            ok_segments(data(), -2.0, 2.0),
            vec![seg(-2.0, 2.0, 4.0, 0.0)]
        );
    }

    #[test]
    fn extrapolate_bad() {
        assert!(disaggregate(data(), -10.0, -5.0).is_none());
        assert!(disaggregate(data(), 5.0, 10.0).is_none());
    }

    #[test]
    fn reject_nan_domain() {
        // NaN bounds fail the `min <= 0 <= max` check.
        assert!(disaggregate(data(), f64::NAN, 1.0).is_none());
        assert!(disaggregate(data(), -1.0, f64::NAN).is_none());
    }

    #[test]
    fn empty_input() {
        // With no points there is no curve to disaggregate.
        assert!(disaggregate(std::iter::empty::<Point>(), -1.0, 1.0).is_none());
    }

    #[test]
    fn extrapolate_demand() {
        assert_eq!(
            ok_segments(data(), 0.0, 5.0),
            vec![seg(0.0, 2.0, 2.0, 0.0), seg(0.0, 3.0, 0.0, 0.0)]
        );
    }

    #[test]
    fn extrapolate_supply() {
        assert_eq!(
            ok_segments(data(), -5.0, 0.0),
            vec![seg(-3.0, 0.0, 4.0, 4.0), seg(-2.0, 0.0, 4.0, 2.0)]
        );
    }

    #[test]
    fn extrapolate_arbitrage() {
        assert_eq!(
            ok_segments(data(), -5.0, 5.0),
            vec![
                seg(-3.0, 0.0, 4.0, 4.0),
                seg(-2.0, 2.0, 4.0, 0.0),
                seg(0.0, 3.0, 0.0, 0.0),
            ]
        );
    }

    #[test]
    fn extrapolate_simple() {
        assert_eq!(
            ok_segments(std::iter::once(pt(0.0, 5.0)), -5.0, 5.0),
            vec![seg(-5.0, 5.0, 5.0, 5.0)]
        );
    }

    #[test]
    fn check_slope() {
        // Price rising with rate must be reported as an error.
        let points = vec![pt(-2.0, 4.0), pt(2.0, 5.0)];
        let segments = disaggregate(points.into_iter(), -1.0, 1.0)
            .unwrap()
            .collect::<Result<Vec<_>, _>>();
        assert!(segments.is_err());
    }
}