1use super::{Point, Segment};
2use std::iter::Peekable;
3
4pub fn disaggregate<T: Iterator<Item = Point>>(
8 points: T,
9 min: f64,
10 max: f64,
11) -> Option<impl Iterator<Item = Result<Segment, Segment>>> {
12 if !(min <= 0.0 && 0.0 <= max) {
13 return None;
14 }
15
16 let mut points = points.peekable();
17
18 if let Some(point) = points.peek() {
19 let anchor = if point.quantity < min {
20 points.next()
21 } else {
22 Some(Point {
23 quantity: min,
24 price: point.price,
25 })
26 };
27
28 Some(
29 Disaggregation {
30 points,
31 anchor,
32 domain: (min, max),
33 }
34 .filter(|result| match result {
37 Ok(demand) => demand.q0 != demand.q1,
38 Err(_) => true,
39 }),
40 )
41 } else {
42 None
43 }
44}
45
46#[derive(Debug)]
48struct Disaggregation<T: Iterator<Item = Point>> {
49 points: Peekable<T>,
51 anchor: Option<Point>,
53 domain: (f64, f64),
55}
56
57impl<T: Iterator<Item = Point>> Iterator for Disaggregation<T> {
58 type Item = Result<Segment, Segment>;
60
61 fn next(&mut self) -> Option<Self::Item> {
63 while let Some(prev) = self.anchor.take() {
65 if self.domain.1 <= prev.quantity {
67 return None;
69 } else if let Some(mut next) = self.points.next() {
70 loop {
72 if let Some(extra) = self.points.peek() {
74 if next.is_collinear(&prev, extra) {
75 next = self.points.next().unwrap();
77 continue;
78 } else {
79 break;
80 }
81 } else {
82 if self.domain.1 > next.quantity {
83 let extra = Point {
84 quantity: self.domain.1,
85 price: next.price,
86 };
87 if next.is_collinear(&prev, &extra) {
88 next = extra;
89 }
90 }
91 break;
92 }
93 }
94
95 self.anchor = Some(next.clone());
96
97 let segment = Segment::new(prev, next)
98 .map(|(demand, translate)| {
99 demand.clip(self.domain.0 - translate, self.domain.1 - translate)
100 })
101 .map_err(|(demand, _)| demand)
102 .transpose();
103 if segment.is_some() {
104 return segment;
105 } else {
106 continue;
107 }
108 } else {
109 let next = Point {
112 quantity: self.domain.1,
113 price: prev.price,
114 };
115
116 return Segment::new(prev, next)
117 .map(|(demand, translate)| {
118 demand.clip(self.domain.0 - translate, self.domain.1 - translate)
119 })
120 .map_err(|(demand, _)| demand)
121 .transpose();
122 }
123 }
124
125 None
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Four points on the line `price = 2 - quantity`, for quantities -2 to 2.
    fn data() -> impl Iterator<Item = Point> {
        vec![
            Point {
                quantity: -2.0,
                price: 4.0,
            },
            Point {
                quantity: -1.0,
                price: 3.0,
            },
            Point {
                quantity: 1.0,
                price: 1.0,
            },
            Point {
                quantity: 2.0,
                price: 0.0,
            },
        ]
        .into_iter()
    }

    #[test]
    fn collinear_reduction() {
        let segments = disaggregate(data(), -2.0, 2.0)
            .unwrap()
            .map(|res| res.unwrap())
            .collect::<Vec<_>>();

        assert_eq!(
            segments,
            vec![Segment {
                q0: -2.0,
                q1: 2.0,
                p0: 4.0,
                p1: 0.0,
            }]
        )
    }

    #[test]
    fn extrapolate_bad() {
        assert!(disaggregate(data(), -10.0, -5.0).is_none());
        assert!(disaggregate(data(), 5.0, 10.0).is_none());
    }

    #[test]
    fn empty_input() {
        assert!(disaggregate(std::iter::empty::<Point>(), -1.0, 1.0).is_none());
    }

    #[test]
    fn reject_invalid_domain() {
        // Inverted bounds cannot contain zero.
        assert!(disaggregate(data(), 1.0, -1.0).is_none());
        // NaN bounds fail every comparison and must be rejected.
        assert!(disaggregate(data(), f64::NAN, 1.0).is_none());
        assert!(disaggregate(data(), -1.0, f64::NAN).is_none());
    }

    #[test]
    fn exhausted_iterator_stays_exhausted() {
        let mut segments = disaggregate(data(), -2.0, 2.0).unwrap();
        assert!(segments.next().is_some());
        assert!(segments.next().is_none());
        assert!(segments.next().is_none());
    }

    #[test]
    fn extrapolate_demand() {
        let segments = disaggregate(data(), 0.0, 5.0)
            .unwrap()
            .map(|res| res.unwrap())
            .collect::<Vec<_>>();

        let answer = vec![
            Segment {
                q0: 0.0,
                q1: 2.0,
                p0: 2.0,
                p1: 0.0,
            },
            Segment {
                q0: 0.0,
                q1: 3.0,
                p0: 0.0,
                p1: 0.0,
            },
        ];

        assert_eq!(segments, answer);
    }

    #[test]
    fn extrapolate_supply() {
        let segments = disaggregate(data(), -5.0, 0.0)
            .unwrap()
            .map(|res| res.unwrap())
            .collect::<Vec<_>>();

        let answer = vec![
            Segment {
                q0: -3.0,
                q1: 0.0,
                p0: 4.0,
                p1: 4.0,
            },
            Segment {
                q0: -2.0,
                q1: 0.0,
                p0: 4.0,
                p1: 2.0,
            },
        ];

        assert_eq!(segments, answer);
    }

    #[test]
    fn extrapolate_arbitrage() {
        let segments = disaggregate(data(), -5.0, 5.0)
            .unwrap()
            .map(|res| res.unwrap())
            .collect::<Vec<_>>();

        let answer = vec![
            Segment {
                q0: -3.0,
                q1: 0.0,
                p0: 4.0,
                p1: 4.0,
            },
            Segment {
                q0: -2.0,
                q1: 2.0,
                p0: 4.0,
                p1: 0.0,
            },
            Segment {
                q0: 0.0,
                q1: 3.0,
                p0: 0.0,
                p1: 0.0,
            },
        ];

        assert_eq!(segments, answer);
    }

    #[test]
    fn extrapolate_simple() {
        let segments = disaggregate(
            std::iter::once(Point {
                quantity: 0.0,
                price: 5.0,
            }),
            -5.0,
            5.0,
        )
        .unwrap()
        .map(|res| res.unwrap())
        .collect::<Vec<_>>();

        let answer = vec![Segment {
            q0: -5.0,
            q1: 5.0,
            p0: 5.0,
            p1: 5.0,
        }];

        assert_eq!(segments, answer);
    }

    #[test]
    fn check_slope() {
        // Price rising with quantity: `Segment::new` is expected to reject it.
        let points = vec![
            Point {
                quantity: -2.0,
                price: 4.0,
            },
            Point {
                quantity: 2.0,
                price: 5.0,
            },
        ];
        let segments = disaggregate(points.into_iter(), -1.0, 1.0)
            .unwrap()
            .collect::<Result<Vec<_>, _>>();
        assert!(segments.is_err());
    }
}