1extern crate geo;
39extern crate num_traits;
40
41#[cfg(feature = "serde")]
42#[macro_use]
43extern crate serde;
44
45use std::cmp::Ordering;
46use std::collections::BinaryHeap;
47use std::convert::{TryFrom, TryInto};
48
49pub use geo::{Coord, CoordFloat, Line, LineString, Point};
50use num_traits::Signed;
51
52#[derive(PartialEq, Clone, Debug)]
80#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
81pub struct PiecewiseLinearFunction<T: CoordFloat> {
82 pub coordinates: Vec<Coord<T>>,
84}
85
86impl<T: CoordFloat> PiecewiseLinearFunction<T> {
87 pub fn new(coordinates: Vec<Coord<T>>) -> Option<Self> {
91 if coordinates.len() >= 2 && coordinates.windows(2).all(|w| w[0].x < w[1].x) {
92 Some(PiecewiseLinearFunction { coordinates })
93 } else {
94 None
95 }
96 }
97
98 pub fn constant(domain: (T, T), value: T) -> Option<Self> {
102 if domain.0 < domain.1 {
103 let coordinates = vec![(domain.0, value).into(), (domain.1, value).into()];
104 Some(PiecewiseLinearFunction { coordinates })
105 } else {
106 None
107 }
108 }
109
110 pub fn domain(&self) -> (T, T) {
112 (self.coordinates[0].x, self.coordinates.last().unwrap().x)
113 }
114
115 pub fn has_same_domain_as(&self, other: &PiecewiseLinearFunction<T>) -> bool {
117 self.domain() == other.domain()
118 }
119
120 pub fn segments_iter(&self) -> SegmentsIterator<T> {
124 SegmentsIterator(self.coordinates.iter().peekable())
125 }
126
127 pub fn points_of_inflection_iter<'a>(
131 &'a self,
132 other: &'a PiecewiseLinearFunction<T>,
133 ) -> Option<PointsOfInflectionIterator<T>> {
134 if !self.has_same_domain_as(other) {
135 None
136 } else {
137 Some(PointsOfInflectionIterator {
138 segment_iterators: vec![
139 self.segments_iter().peekable(),
140 other.segments_iter().peekable(),
141 ],
142 heap: BinaryHeap::new(),
143 initial: true,
144 })
145 }
146 }
147
148 pub fn segment_at_x(&self, x: T) -> Option<Line<T>> {
152 let idx = match self
153 .coordinates
154 .binary_search_by(|val| bogus_compare(&val.x, &x))
155 {
156 Ok(idx) => idx,
157 Err(idx) => {
158 if idx == 0 || idx == self.coordinates.len() {
159 return None;
161 } else {
162 idx
163 }
164 }
165 };
166
167 if idx == 0 {
168 Some(Line::new(self.coordinates[idx], self.coordinates[idx + 1]))
169 } else {
170 Some(Line::new(self.coordinates[idx - 1], self.coordinates[idx]))
171 }
172 }
173
174 pub fn y_at_x(&self, x: T) -> Option<T> {
178 self.segment_at_x(x).map(|line| y_at_x(&line, x))
179 }
180
181 pub fn shrink_domain(&self, to_domain: (T, T)) -> Option<PiecewiseLinearFunction<T>> {
186 let order = compare_domains(self.domain(), to_domain);
187 match order {
188 Some(Ordering::Equal) => Some(self.clone()),
189 Some(Ordering::Greater) => {
190 let mut new_points = Vec::new();
191 for segment in self.segments_iter() {
192 if let Some(restricted) = line_in_domain(&segment, to_domain) {
193 if segment.start.x <= to_domain.0 {
197 new_points.push(restricted.start);
198 }
199 new_points.push(restricted.end);
200 }
201 }
202 Some(new_points.try_into().unwrap())
203 }
204 _ => None,
205 }
206 }
207
208 pub fn expand_domain(
214 &self,
215 to_domain: (T, T),
216 strategy: ExpandDomainStrategy,
217 ) -> PiecewiseLinearFunction<T> {
218 if compare_domains(self.domain(), to_domain) == Some(Ordering::Equal) {
219 return self.clone();
220 }
221 let mut new_points = Vec::new();
222 if self.coordinates[0].x > to_domain.0 {
223 match &strategy {
224 ExpandDomainStrategy::ExtendSegment => new_points.push(Coord {
225 x: to_domain.0,
226 y: y_at_x(
227 &Line::new(self.coordinates[0], self.coordinates[1]),
228 to_domain.0,
229 ),
230 }),
231 ExpandDomainStrategy::ExtendValue => {
232 new_points.push((to_domain.0, self.coordinates[0].y).into());
233 new_points.push(self.coordinates[0]);
234 }
235 }
236 } else {
237 new_points.push(self.coordinates[0]);
238 }
239
240 let last_index = self.coordinates.len() - 1;
241 new_points.extend_from_slice(&self.coordinates[1..last_index]);
242
243 if self.coordinates[last_index].x < to_domain.1 {
244 match &strategy {
245 ExpandDomainStrategy::ExtendSegment => new_points.push(Coord {
246 x: to_domain.1,
247 y: y_at_x(
248 &Line::new(
249 self.coordinates[last_index - 1],
250 self.coordinates[last_index],
251 ),
252 to_domain.1,
253 ),
254 }),
255 ExpandDomainStrategy::ExtendValue => {
256 new_points.push(self.coordinates[last_index]);
257 new_points.push((to_domain.1, self.coordinates[last_index].y).into());
258 }
259 }
260 } else {
261 new_points.push(self.coordinates[last_index])
262 }
263
264 new_points.try_into().unwrap()
265 }
266
267 pub fn add(&self, other: &PiecewiseLinearFunction<T>) -> Option<PiecewiseLinearFunction<T>> {
271 self.points_of_inflection_iter(other).map(|poi| {
272 PiecewiseLinearFunction::new(
273 poi.map(|(x, coords)| Coord {
274 x,
275 y: coords[0] + coords[1],
276 })
277 .collect(),
278 )
279 .unwrap()
282 })
283 }
284
285 pub fn max(&self, other: &PiecewiseLinearFunction<T>) -> Option<PiecewiseLinearFunction<T>> {
305 let mut poi_iter = self.points_of_inflection_iter(other)?;
306 let mut new_values = Vec::new();
307
308 let (x, values) = poi_iter.next().unwrap();
309 let (i_largest, largest) = argmax(&values).unwrap();
310 new_values.push(Coord { x, y: *largest });
311
312 let mut prev_largest = i_largest;
313 let mut prev_x = x;
314 let mut prev_values = values;
315
316 for (x, values) in poi_iter {
317 let (i_largest, largest) = argmax(&values).unwrap();
318 if i_largest != prev_largest {
319 let (inter_x, inter_y) = line_intersect(
320 &Line::new((prev_x, prev_values[0]), (x, values[0])),
321 &Line::new((prev_x, prev_values[1]), (x, values[1])),
322 );
323 if inter_x > prev_x && inter_x < x {
326 new_values.push(Coord {
327 x: inter_x,
328 y: inter_y,
329 });
330 }
331 }
332 new_values.push(Coord { x, y: *largest });
333 prev_largest = i_largest;
334 prev_x = x;
335 prev_values = values;
336 }
337
338 Some(PiecewiseLinearFunction::new(new_values).unwrap())
339 }
340}
341
/// How [`PiecewiseLinearFunction::expand_domain`] fills in the function outside its
/// original domain.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExpandDomainStrategy {
    /// Extrapolate the first/last segment linearly up to the new domain bound.
    ExtendSegment,
    /// Hold the first/last value constant up to the new domain bound.
    ExtendValue,
}
351
352impl<T: CoordFloat + Signed> PiecewiseLinearFunction<T> {
353 pub fn negate(&self) -> PiecewiseLinearFunction<T> {
355 PiecewiseLinearFunction::new(
356 self.coordinates
357 .iter()
358 .map(|Coord { x, y }| Coord { x: *x, y: -(*y) })
359 .collect(),
360 )
361 .unwrap()
363 }
364
365 pub fn min(&self, other: &PiecewiseLinearFunction<T>) -> Option<PiecewiseLinearFunction<T>> {
369 Some(self.negate().max(&other.negate())?.negate())
370 }
371
372 pub fn abs(&self) -> PiecewiseLinearFunction<T> {
374 self.max(&self.negate()).unwrap()
375 }
376}
377
378impl<T: CoordFloat + Signed> ::std::ops::Neg for PiecewiseLinearFunction<T> {
379 type Output = Self;
380
381 fn neg(self) -> Self::Output {
382 self.negate()
383 }
384}
385
386impl<T: CoordFloat + ::std::iter::Sum> PiecewiseLinearFunction<T> {
387 pub fn integrate(&self) -> T {
389 self.segments_iter()
390 .map(|segment| {
391 let (min_y, max_y) = if segment.start.y < segment.end.y {
392 (segment.start.y, segment.end.y)
393 } else {
394 (segment.end.y, segment.start.y)
395 };
396 let x_span = segment.end.x - segment.start.x;
397 x_span * (min_y + max_y / T::from(2).unwrap())
398 })
399 .sum()
400 }
401}
402
403impl<T: CoordFloat> TryFrom<LineString<T>> for PiecewiseLinearFunction<T> {
406 type Error = ();
407
408 fn try_from(value: LineString<T>) -> Result<Self, Self::Error> {
409 PiecewiseLinearFunction::new(value.0).ok_or(())
410 }
411}
412
413impl<T: CoordFloat> TryFrom<Vec<Coord<T>>> for PiecewiseLinearFunction<T> {
414 type Error = ();
415
416 fn try_from(value: Vec<Coord<T>>) -> Result<Self, Self::Error> {
417 PiecewiseLinearFunction::new(value).ok_or(())
418 }
419}
420
421impl<T: CoordFloat> TryFrom<Vec<Point<T>>> for PiecewiseLinearFunction<T> {
422 type Error = ();
423
424 fn try_from(value: Vec<Point<T>>) -> Result<Self, Self::Error> {
425 PiecewiseLinearFunction::new(value.into_iter().map(|p| p.0).collect()).ok_or(())
426 }
427}
428
429impl<T: CoordFloat> TryFrom<Vec<(T, T)>> for PiecewiseLinearFunction<T> {
430 type Error = ();
431
432 fn try_from(value: Vec<(T, T)>) -> Result<Self, Self::Error> {
433 PiecewiseLinearFunction::new(value.into_iter().map(Coord::from).collect()).ok_or(())
434 }
435}
436
437impl<T: CoordFloat> From<PiecewiseLinearFunction<T>> for Vec<(T, T)> {
438 fn from(val: PiecewiseLinearFunction<T>) -> Self {
439 val.coordinates
440 .into_iter()
441 .map(|coord| coord.x_y())
442 .collect()
443 }
444}
445
446#[derive(Debug, Clone, Copy, PartialEq)]
449struct NextSegment<T: CoordFloat> {
450 x: T,
451 index: usize,
452}
453
454impl<T: CoordFloat> ::std::cmp::Eq for NextSegment<T> {}
455
456impl<T: CoordFloat> ::std::cmp::PartialOrd for NextSegment<T> {
457 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
458 self.x.partial_cmp(&other.x).map(|r| r.reverse())
459 }
460}
461
462impl<T: CoordFloat> ::std::cmp::Ord for NextSegment<T> {
463 fn cmp(&self, other: &Self) -> Ordering {
464 bogus_compare(self, other)
465 }
466}
467
468pub struct PointsOfInflectionIterator<'a, T: CoordFloat + 'a> {
472 segment_iterators: Vec<::std::iter::Peekable<SegmentsIterator<'a, T>>>,
473 heap: BinaryHeap<NextSegment<T>>,
474 initial: bool,
475}
476
477impl<'a, T: CoordFloat + 'a> PointsOfInflectionIterator<'a, T> {
478 fn initialize(
481 segment_iterators: &mut [::std::iter::Peekable<SegmentsIterator<'a, T>>],
482 heap: &mut BinaryHeap<NextSegment<T>>,
483 ) -> (T, Vec<T>) {
484 let values = segment_iterators
485 .iter_mut()
486 .enumerate()
487 .map(|(index, it)| {
488 let seg = it.peek().unwrap();
489 heap.push(NextSegment {
490 x: seg.end.x,
491 index,
492 });
493 seg.start.y
494 })
495 .collect();
496 let x = segment_iterators[0].peek().unwrap().start.x;
497 (x, values)
498 }
499}
500
impl<'a, T: CoordFloat + 'a> Iterator for PointsOfInflectionIterator<'a, T> {
    /// `(x, values)`: an x-coordinate and the value of every function at that `x`.
    type Item = (T, Vec<T>);

    /// Yields the next breakpoint of any of the functions, in increasing `x` order.
    ///
    /// The first call returns the start of the domain (via `initialize`). Each later call
    /// takes the smallest pending segment end from the heap, evaluates every function's
    /// current segment at that `x`, and then advances every function whose segment ends
    /// exactly there. Returns `None` once the heap is empty, i.e. all segments are consumed.
    fn next(&mut self) -> Option<Self::Item> {
        if self.initial {
            self.initial = false;
            Some(Self::initialize(
                &mut self.segment_iterators,
                &mut self.heap,
            ))
        } else {
            self.heap.peek().cloned().map(|next_segment| {
                let x = next_segment.x;
                // Evaluate *before* advancing: every current segment still reaches `x`.
                // NOTE(review): the `unwrap` relies on all functions sharing the same domain
                // end (checked where this iterator is constructed).
                let values = self
                    .segment_iterators
                    .iter_mut()
                    .map(|segment_iterator| y_at_x(segment_iterator.peek().unwrap(), x))
                    .collect();

                // Advance every function whose current segment ends at `x`, and schedule the
                // end of its following segment, if any.
                while let Some(segt) = self.heap.peek().cloned() {
                    if segt.x != x {
                        break;
                    }
                    self.heap.pop();
                    self.segment_iterators[segt.index].next();
                    if let Some(segment) = self.segment_iterators[segt.index].peek().cloned() {
                        self.heap.push(NextSegment {
                            x: segment.end.x,
                            index: segt.index,
                        })
                    }
                }

                (x, values)
            })
        }
    }
}
539
540pub struct SegmentsIterator<'a, T: CoordFloat + 'a>(
542 ::std::iter::Peekable<::std::slice::Iter<'a, Coord<T>>>,
543);
544
545impl<'a, T: CoordFloat + 'a> Iterator for SegmentsIterator<'a, T> {
546 type Item = Line<T>;
547
548 fn next(&mut self) -> Option<Self::Item> {
549 self.0
550 .next()
551 .and_then(|first| self.0.peek().map(|second| Line::new(*first, **second)))
552 }
553}
554
555pub fn points_of_inflection_iter<'a, T: CoordFloat + 'a>(
579 funcs: &'a [PiecewiseLinearFunction<T>],
580) -> Option<PointsOfInflectionIterator<'a, T>> {
581 if funcs.is_empty() || !funcs.windows(2).all(|w| w[0].has_same_domain_as(&w[1])) {
582 return None;
583 }
584 Some(PointsOfInflectionIterator {
585 segment_iterators: funcs.iter().map(|f| f.segments_iter().peekable()).collect(),
586 heap: BinaryHeap::new(),
587 initial: true,
588 })
589}
590
591pub fn sum<'a, T: CoordFloat + ::std::iter::Sum + 'a>(
595 funcs: &[PiecewiseLinearFunction<T>],
596) -> Option<PiecewiseLinearFunction<T>> {
597 points_of_inflection_iter(funcs).map(|poi| {
598 PiecewiseLinearFunction::new(
599 poi.map(|(x, values)| Coord {
600 x,
601 y: values.iter().cloned().sum(),
602 })
603 .collect(),
604 )
605 .unwrap()
607 })
608}
609
610fn line_in_domain<T: CoordFloat>(l: &Line<T>, domain: (T, T)) -> Option<Line<T>> {
615 if l.end.x <= domain.0 || l.start.x >= domain.1 {
616 None
617 } else {
618 let left_point = if l.start.x >= domain.0 {
619 l.start
620 } else {
621 (domain.0, y_at_x(l, domain.0)).into()
622 };
623 let right_point = if l.end.x <= domain.1 {
624 l.end
625 } else {
626 (domain.1, y_at_x(l, domain.1)).into()
627 };
628 Some(Line::new(left_point, right_point))
629 }
630}
631
632fn y_at_x<T: CoordFloat>(line: &Line<T>, x: T) -> T {
633 line.start.y + (x - line.start.x) * line.slope()
634}
635
636fn line_intersect<T: CoordFloat>(l1: &Line<T>, l2: &Line<T>) -> (T, T) {
637 let y_intercept_1 = l1.start.y - l1.start.x * l1.slope();
638 let y_intercept_2 = l2.start.y - l2.start.x * l2.slope();
639
640 let x_intersect = (y_intercept_2 - y_intercept_1) / (l1.slope() - l2.slope());
641 let y_intersect = y_at_x(l1, x_intersect);
642 (x_intersect, y_intersect)
643}
644
645fn compare_domains<T: CoordFloat>(d1: (T, T), d2: (T, T)) -> Option<Ordering> {
646 if d1 == d2 {
647 Some(Ordering::Equal)
648 } else if d1.0 <= d2.0 && d1.1 >= d2.1 {
649 Some(Ordering::Greater)
650 } else if d2.0 <= d1.0 && d2.1 >= d1.1 {
651 Some(Ordering::Less)
652 } else {
653 None
654 }
655}
656
/// Total-order adapter for partially ordered values: incomparable pairs (e.g. involving NaN)
/// are reported as `Ordering::Equal`.
fn bogus_compare<T: PartialOrd>(a: &T, b: &T) -> Ordering {
    match a.partial_cmp(b) {
        Some(ordering) => ordering,
        None => Ordering::Equal,
    }
}
660
661fn argmax<T: CoordFloat>(values: &[T]) -> Option<(usize, &T)> {
662 values
663 .iter()
664 .enumerate()
665 .max_by(|(_, a), (_, b)| bogus_compare(a, b))
666}
667
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use super::*;

    /// A function with awkward values (extreme and infinite `y`s, irrational and infinite
    /// `x`s) used to exercise edge cases.
    fn get_test_function() -> PiecewiseLinearFunction<f64> {
        PiecewiseLinearFunction::try_from(vec![
            (-5.25, std::f64::MIN),
            (-std::f64::consts::FRAC_PI_2, 0.1),
            (-std::f64::consts::FRAC_PI_3, 0.1 + std::f64::EPSILON),
            (0.1, 1.),
            (1., 2.),
            (2., 3.),
            (3., 4.),
            (std::f64::INFINITY, std::f64::NEG_INFINITY),
        ])
        .unwrap()
    }

    #[test]
    fn test_y_at_x() {
        assert_eq!(y_at_x(&Line::new((0., 0.), (1., 1.)), 0.25), 0.25);
        assert_eq!(y_at_x(&Line::new((1., 0.), (2., 1.)), 1.25), 0.25);
    }

    #[test]
    fn test_constant() {
        // Empty and inverted domains are rejected.
        assert_eq!(PiecewiseLinearFunction::constant((0.5, 0.5), 1.), None);
        assert_eq!(PiecewiseLinearFunction::constant((0.5, -0.5), 1.), None);
        assert_eq!(
            PiecewiseLinearFunction::constant((-25., -13.), 1.).unwrap(),
            vec![(-25., 1.), (-13., 1.)].try_into().unwrap()
        );
    }

    #[test]
    fn test_domain() {
        assert_eq!(
            PiecewiseLinearFunction::constant((-4., 5.25), 8.2)
                .unwrap()
                .domain(),
            (-4., 5.25)
        );
        // Infinite bounds are allowed.
        assert_eq!(
            PiecewiseLinearFunction::try_from(vec![
                (std::f64::NEG_INFINITY, -1.),
                (0., 0.),
                (std::f64::INFINITY, 0.)
            ])
            .unwrap()
            .domain(),
            (std::f64::NEG_INFINITY, std::f64::INFINITY)
        );
    }

    #[test]
    fn test_segment_at_x() {
        // Strictly inside a segment.
        assert_eq!(
            get_test_function().segment_at_x(1.5).unwrap(),
            Line::new((1., 2.), (2., 3.))
        );
        // Exactly on an interior breakpoint: the segment ending there is selected.
        assert_eq!(
            get_test_function().segment_at_x(1.).unwrap(),
            Line::new((0.1, 1.), (1., 2.))
        );
    }

    #[test]
    fn test_segments_iter() {
        let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();
        assert_eq!(
            f.segments_iter().collect::<Vec<_>>(),
            vec![
                Line::new((0., 0.), (1., 1.)),
                Line::new((1., 1.), (2., 1.5))
            ]
        );
    }

    #[test]
    fn test_points_of_inflection_iter() {
        let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();
        let g = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1.5, 3.), (2., 10.)]).unwrap();
        // Breakpoints of both functions, merged, with each function evaluated at every one.
        assert_eq!(
            f.points_of_inflection_iter(&g).unwrap().collect::<Vec<_>>(),
            vec![
                (0., vec![0., 0.]),
                (1., vec![1., 2.]),
                (1.5, vec![1.25, 3.]),
                (2., vec![1.5, 10.])
            ]
        );
    }

    #[test]
    fn test_line_in_domain() {
        // Domain entirely to the right, entirely to the left, or touching at a single point.
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (1., 2.)),
            None
        );
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (-3., -2.)),
            None
        );
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (0., 1.)),
            None
        );

        // Domain covers the whole line: unchanged.
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (-2., 1.)),
            Some(Line::new((-1., 1.), (0., 2.)))
        );

        // Domain clips the start of the line.
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (-0.5, 0.5)),
            Some(Line::new((-0.5, 1.5), (0., 2.)))
        );

        // Domain clips the end of the line.
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (-1., -0.25)),
            Some(Line::new((-1., 1.), (-0.25, 1.75)))
        );

        // Domain clips both ends.
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (-0.75, -0.25)),
            Some(Line::new((-0.75, 1.25), (-0.25, 1.75)))
        );
    }

    #[test]
    fn test_shrink_domain() {
        // Expected new first point: the segment containing 0 evaluated at 0.
        let first_val = y_at_x(
            &Line::new(
                (-std::f64::consts::FRAC_PI_3, 0.1 + std::f64::EPSILON),
                (0.1, 1.),
            ),
            0.,
        );
        assert_eq!(
            get_test_function()
                .shrink_domain((0.0, std::f64::INFINITY))
                .unwrap(),
            PiecewiseLinearFunction::try_from(vec![
                (0., first_val),
                (0.1, 1.),
                (1., 2.),
                (2., 3.),
                (3., 4.),
                (std::f64::INFINITY, std::f64::NEG_INFINITY),
            ])
            .unwrap()
        );
    }

    #[test]
    fn test_expand_domain() {
        let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();

        // Same domain: unchanged.
        assert_eq!(
            f.expand_domain((0., 2.), ExpandDomainStrategy::ExtendSegment),
            f
        );

        // Extending to the left.
        assert_eq!(
            f.expand_domain((-1., 2.), ExpandDomainStrategy::ExtendSegment),
            vec![(-1., -1.), (1., 1.), (2., 1.5)].try_into().unwrap()
        );
        assert_eq!(
            f.expand_domain((-1., 2.), ExpandDomainStrategy::ExtendValue),
            vec![(-1., 0.), (0., 0.), (1., 1.), (2., 1.5)]
                .try_into()
                .unwrap()
        );

        // Extending to the right.
        assert_eq!(
            f.expand_domain((0., 4.), ExpandDomainStrategy::ExtendSegment),
            vec![(0., 0.), (1., 1.), (4., 2.5)].try_into().unwrap()
        );
        assert_eq!(
            f.expand_domain((0., 4.), ExpandDomainStrategy::ExtendValue),
            vec![(0., 0.), (1., 1.), (2., 1.5), (4., 1.5)]
                .try_into()
                .unwrap()
        );
    }

    #[test]
    fn test_negative() {
        let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();
        assert_eq!(
            f.negate(),
            vec![(0., -0.), (1., -1.), (2., -1.5)].try_into().unwrap()
        )
    }

    #[test]
    fn test_line_intersect() {
        assert_eq!(
            line_intersect(
                &Line::new((-2., -1.), (5., 3.)),
                &Line::new((-1., 4.), (6., 2.))
            ),
            (4. + 1. / 6., 2. + 11. / 21.)
        );
    }
}