piecewise_linear/
lib.rs

1// Copyright 2019 Matthieu Felix
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! This crate provides utilities to manipulate continuous
16//! [piecewise linear functions](https://en.wikipedia.org/wiki/Piecewise_linear_function).
17//!
18//! They are internally represented as a list of `(x, y)` pairs, each representing a point of
19//! inflection (or equivalently a limit between two linear pieces). The represented function is
20//! assumed to be linear between each of these points.
21//!
22//! ## Domains
23//!
24//! The domain of a function is the range over which it is defined, that is, the range between
25//! the smallest _x_ coordinate and the greatest one in the function's definition points.
26//!
27//! Most methods will refuse to operate on two (or more) functions that do not have the same
28//! domain. You can use `expand_domain()` and `shrink_domain()` to adapt domains.
29//!
30//! Domains over all real numbers should be possible by using ±inf _x_ values, but this has not
31//! been extensively tested.
32//!
33//! ## Numeric types
34//!
//! This crate should support functions using any `CoordFloat` (a `geo` numeric trait built on
//! `num_traits::Float`). However, it has not been tested with types other than `f32` and `f64`.
37
38extern crate geo;
39extern crate num_traits;
40
41#[cfg(feature = "serde")]
42#[macro_use]
43extern crate serde;
44
45use std::cmp::Ordering;
46use std::collections::BinaryHeap;
47use std::convert::{TryFrom, TryInto};
48
49pub use geo::{Coord, CoordFloat, Line, LineString, Point};
50use num_traits::Signed;
51
52/// A continuous piecewise linear function.
53///
54/// The function is represented as a list of `(x, y)` pairs, each representing a point of
55/// inflection (or equivalently a limit between two linear pieces). The represented function is
56/// assumed to be linear between each of these points.
57///
58/// ## Invariants
59///
60/// All methods defined on `PiecewiseLinearFunction` preserve the following invariants:
61///
62///   * There are at least two coordinates in the `coordinates` array
63///   * The coordinates are in strictly increasing order of `x` value.
64///
65/// However, two consecutive segments do not necessarily have different slopes. These methods
66/// will panic if invariants are broken by manually editing the `coordinates` vector.
67///
68/// This representation means that functions defined on an empty or singleton set, as well as
69/// discontinuous functions, are not supported.
70///
71/// ## Example
72///
73/// ```
74/// use piecewise_linear::PiecewiseLinearFunction;
75/// use std::convert::TryFrom;
76/// let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();
77/// assert_eq!(f.y_at_x(1.25), Some(1.125));
78/// ```
79#[derive(PartialEq, Clone, Debug)]
80#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
81pub struct PiecewiseLinearFunction<T: CoordFloat> {
82    /// Vector of points that make up the function.
83    pub coordinates: Vec<Coord<T>>,
84}
85
86impl<T: CoordFloat> PiecewiseLinearFunction<T> {
87    /// Creates a new `PiecewiseLinearFunction` from a vector of `Coordinates`.
88    ///
89    /// Returns a new PicewiseLinearFunction, or `None` if the invariants were not respected.
90    pub fn new(coordinates: Vec<Coord<T>>) -> Option<Self> {
91        if coordinates.len() >= 2 && coordinates.windows(2).all(|w| w[0].x < w[1].x) {
92            Some(PiecewiseLinearFunction { coordinates })
93        } else {
94            None
95        }
96    }
97
98    /// Returns a new constant `PiecewiseLinearFunction` with the specified domain and value.
99    ///
100    /// Returns `None` if the domain is not valid (i.e. `domain.1 <= domain.0`).
101    pub fn constant(domain: (T, T), value: T) -> Option<Self> {
102        if domain.0 < domain.1 {
103            let coordinates = vec![(domain.0, value).into(), (domain.1, value).into()];
104            Some(PiecewiseLinearFunction { coordinates })
105        } else {
106            None
107        }
108    }
109
110    /// Returns a function's domain, represented as its min and max.
111    pub fn domain(&self) -> (T, T) {
112        (self.coordinates[0].x, self.coordinates.last().unwrap().x)
113    }
114
115    /// Checks whether this function has the same domain as another one.
116    pub fn has_same_domain_as(&self, other: &PiecewiseLinearFunction<T>) -> bool {
117        self.domain() == other.domain()
118    }
119
120    /// Returns an iterator over the segments of f.
121    ///
122    /// This iterator is guaranteed to have at least one element.
123    pub fn segments_iter(&self) -> SegmentsIterator<T> {
124        SegmentsIterator(self.coordinates.iter().peekable())
125    }
126
127    /// Returns an iterator over the joint points of inflection of `self` and `other`.
128    ///
129    /// See `points_of_inflection_iter()` in this module for details.
130    pub fn points_of_inflection_iter<'a>(
131        &'a self,
132        other: &'a PiecewiseLinearFunction<T>,
133    ) -> Option<PointsOfInflectionIterator<T>> {
134        if !self.has_same_domain_as(other) {
135            None
136        } else {
137            Some(PointsOfInflectionIterator {
138                segment_iterators: vec![
139                    self.segments_iter().peekable(),
140                    other.segments_iter().peekable(),
141                ],
142                heap: BinaryHeap::new(),
143                initial: true,
144            })
145        }
146    }
147
148    /// Returns a segment `((x1, y1), (x2, y2))` of this function such that `x1 <= x <= x2`.
149    ///
150    /// Returns `None` if `x` is outside the domain of f.
151    pub fn segment_at_x(&self, x: T) -> Option<Line<T>> {
152        let idx = match self
153            .coordinates
154            .binary_search_by(|val| bogus_compare(&val.x, &x))
155        {
156            Ok(idx) => idx,
157            Err(idx) => {
158                if idx == 0 || idx == self.coordinates.len() {
159                    // Outside the function's domain
160                    return None;
161                } else {
162                    idx
163                }
164            }
165        };
166
167        if idx == 0 {
168            Some(Line::new(self.coordinates[idx], self.coordinates[idx + 1]))
169        } else {
170            Some(Line::new(self.coordinates[idx - 1], self.coordinates[idx]))
171        }
172    }
173
174    /// Computes the value f(x) for this piecewise linear function.
175    ///
176    /// Returns `None` if `x` is outside the domain of f.
177    pub fn y_at_x(&self, x: T) -> Option<T> {
178        self.segment_at_x(x).map(|line| y_at_x(&line, x))
179    }
180
181    /// Returns a new piecewise linear function that is the restriction of this function to the
182    /// specified domain.
183    ///
184    /// Returns `None` if `to_domain` is not a subset of the domain of `self`.
185    pub fn shrink_domain(&self, to_domain: (T, T)) -> Option<PiecewiseLinearFunction<T>> {
186        let order = compare_domains(self.domain(), to_domain);
187        match order {
188            Some(Ordering::Equal) => Some(self.clone()),
189            Some(Ordering::Greater) => {
190                let mut new_points = Vec::new();
191                for segment in self.segments_iter() {
192                    if let Some(restricted) = line_in_domain(&segment, to_domain) {
193                        // segment.start.x was segment.end.x at the last iteration; it it's less
194                        // than or equal to the domain's start, the previous segment was totally
195                        // discarded, but this point should still be added.
196                        if segment.start.x <= to_domain.0 {
197                            new_points.push(restricted.start);
198                        }
199                        new_points.push(restricted.end);
200                    }
201                }
202                Some(new_points.try_into().unwrap())
203            }
204            _ => None,
205        }
206    }
207
208    /// Returns a new piecewise linear function that is the expansion of this function to the
209    /// specified domain.
210    ///
211    /// At most one value is added on either side. See `ExpandDomainStrategy` for options
212    /// determining how these added values are picked.
213    pub fn expand_domain(
214        &self,
215        to_domain: (T, T),
216        strategy: ExpandDomainStrategy,
217    ) -> PiecewiseLinearFunction<T> {
218        if compare_domains(self.domain(), to_domain) == Some(Ordering::Equal) {
219            return self.clone();
220        }
221        let mut new_points = Vec::new();
222        if self.coordinates[0].x > to_domain.0 {
223            match &strategy {
224                ExpandDomainStrategy::ExtendSegment => new_points.push(Coord {
225                    x: to_domain.0,
226                    y: y_at_x(
227                        &Line::new(self.coordinates[0], self.coordinates[1]),
228                        to_domain.0,
229                    ),
230                }),
231                ExpandDomainStrategy::ExtendValue => {
232                    new_points.push((to_domain.0, self.coordinates[0].y).into());
233                    new_points.push(self.coordinates[0]);
234                }
235            }
236        } else {
237            new_points.push(self.coordinates[0]);
238        }
239
240        let last_index = self.coordinates.len() - 1;
241        new_points.extend_from_slice(&self.coordinates[1..last_index]);
242
243        if self.coordinates[last_index].x < to_domain.1 {
244            match &strategy {
245                ExpandDomainStrategy::ExtendSegment => new_points.push(Coord {
246                    x: to_domain.1,
247                    y: y_at_x(
248                        &Line::new(
249                            self.coordinates[last_index - 1],
250                            self.coordinates[last_index],
251                        ),
252                        to_domain.1,
253                    ),
254                }),
255                ExpandDomainStrategy::ExtendValue => {
256                    new_points.push(self.coordinates[last_index]);
257                    new_points.push((to_domain.1, self.coordinates[last_index].y).into());
258                }
259            }
260        } else {
261            new_points.push(self.coordinates[last_index])
262        }
263
264        new_points.try_into().unwrap()
265    }
266
267    /// Sums this method with another piecewise linear function.
268    ///
269    /// Both functions must have the same domain; returns `None` otherwise.
270    pub fn add(&self, other: &PiecewiseLinearFunction<T>) -> Option<PiecewiseLinearFunction<T>> {
271        self.points_of_inflection_iter(other).map(|poi| {
272            PiecewiseLinearFunction::new(
273                poi.map(|(x, coords)| Coord {
274                    x,
275                    y: coords[0] + coords[1],
276                })
277                .collect(),
278            )
279            // This unwrap is guaranteed to succeed as the starting POI has generates ordered x,
280            // which do not get modified.
281            .unwrap()
282        })
283    }
284
285    /// Returns a new piecewise linear function that is the maximum of `self` and `other`.
286    ///
287    /// Note that the resulting function may have more points of inflection than either function.
288    /// For instance,
289    ///
290    /// ## Example
291    ///
292    /// ```
293    /// use piecewise_linear::PiecewiseLinearFunction;
294    /// use std::convert::TryFrom;
295    /// let f = PiecewiseLinearFunction::try_from(vec![(0., 1.), (1., 0.)]).unwrap();
296    /// let g = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.)]).unwrap();
297    /// assert_eq!(
298    ///     f.max(&g).unwrap(),
299    ///     PiecewiseLinearFunction::try_from(vec![(0., 1.), (0.5, 0.5), (1., 1.)]).unwrap()
300    /// );
301    /// ```
302    ///
303    /// Returns `None` if the domains of `self` and `other` are not equal.
304    pub fn max(&self, other: &PiecewiseLinearFunction<T>) -> Option<PiecewiseLinearFunction<T>> {
305        let mut poi_iter = self.points_of_inflection_iter(other)?;
306        let mut new_values = Vec::new();
307
308        let (x, values) = poi_iter.next().unwrap();
309        let (i_largest, largest) = argmax(&values).unwrap();
310        new_values.push(Coord { x, y: *largest });
311
312        let mut prev_largest = i_largest;
313        let mut prev_x = x;
314        let mut prev_values = values;
315
316        for (x, values) in poi_iter {
317            let (i_largest, largest) = argmax(&values).unwrap();
318            if i_largest != prev_largest {
319                let (inter_x, inter_y) = line_intersect(
320                    &Line::new((prev_x, prev_values[0]), (x, values[0])),
321                    &Line::new((prev_x, prev_values[1]), (x, values[1])),
322                );
323                // This condition seems necessary as argmax() is likely unstable, so i_largest
324                // can change even if two lines remain equal.
325                if inter_x > prev_x && inter_x < x {
326                    new_values.push(Coord {
327                        x: inter_x,
328                        y: inter_y,
329                    });
330                }
331            }
332            new_values.push(Coord { x, y: *largest });
333            prev_largest = i_largest;
334            prev_x = x;
335            prev_values = values;
336        }
337
338        Some(PiecewiseLinearFunction::new(new_values).unwrap())
339    }
340}
341
/// Strategy used by `expand_domain()` on `PiecewiseLinearFunction` to decide which values the
/// function takes on the newly added part of its domain.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExpandDomainStrategy {
    /// Prolong the function's edge segment along its current slope.
    ExtendSegment,
    /// Append a flat segment whose value is that of the function's edge point.
    ExtendValue,
}
351
352impl<T: CoordFloat + Signed> PiecewiseLinearFunction<T> {
353    /// Returns -f.
354    pub fn negate(&self) -> PiecewiseLinearFunction<T> {
355        PiecewiseLinearFunction::new(
356            self.coordinates
357                .iter()
358                .map(|Coord { x, y }| Coord { x: *x, y: -(*y) })
359                .collect(),
360        )
361        // This unwrap is guaranteed to succeed because the coordinate's x values haven't changed.
362        .unwrap()
363    }
364
365    /// Computes the minimum of this function and `other`.
366    ///
367    /// Returns `None` in case of a domain error.
368    pub fn min(&self, other: &PiecewiseLinearFunction<T>) -> Option<PiecewiseLinearFunction<T>> {
369        Some(self.negate().max(&other.negate())?.negate())
370    }
371
372    /// Computes the absolute value of this function.
373    pub fn abs(&self) -> PiecewiseLinearFunction<T> {
374        self.max(&self.negate()).unwrap()
375    }
376}
377
378impl<T: CoordFloat + Signed> ::std::ops::Neg for PiecewiseLinearFunction<T> {
379    type Output = Self;
380
381    fn neg(self) -> Self::Output {
382        self.negate()
383    }
384}
385
386impl<T: CoordFloat + ::std::iter::Sum> PiecewiseLinearFunction<T> {
387    /// Returns the integral of the considered function over its entire domain.
388    pub fn integrate(&self) -> T {
389        self.segments_iter()
390            .map(|segment| {
391                let (min_y, max_y) = if segment.start.y < segment.end.y {
392                    (segment.start.y, segment.end.y)
393                } else {
394                    (segment.end.y, segment.start.y)
395                };
396                let x_span = segment.end.x - segment.start.x;
397                x_span * (min_y + max_y / T::from(2).unwrap())
398            })
399            .sum()
400    }
401}
402
403/**** Conversions ****/
404
405impl<T: CoordFloat> TryFrom<LineString<T>> for PiecewiseLinearFunction<T> {
406    type Error = ();
407
408    fn try_from(value: LineString<T>) -> Result<Self, Self::Error> {
409        PiecewiseLinearFunction::new(value.0).ok_or(())
410    }
411}
412
413impl<T: CoordFloat> TryFrom<Vec<Coord<T>>> for PiecewiseLinearFunction<T> {
414    type Error = ();
415
416    fn try_from(value: Vec<Coord<T>>) -> Result<Self, Self::Error> {
417        PiecewiseLinearFunction::new(value).ok_or(())
418    }
419}
420
421impl<T: CoordFloat> TryFrom<Vec<Point<T>>> for PiecewiseLinearFunction<T> {
422    type Error = ();
423
424    fn try_from(value: Vec<Point<T>>) -> Result<Self, Self::Error> {
425        PiecewiseLinearFunction::new(value.into_iter().map(|p| p.0).collect()).ok_or(())
426    }
427}
428
429impl<T: CoordFloat> TryFrom<Vec<(T, T)>> for PiecewiseLinearFunction<T> {
430    type Error = ();
431
432    fn try_from(value: Vec<(T, T)>) -> Result<Self, Self::Error> {
433        PiecewiseLinearFunction::new(value.into_iter().map(Coord::from).collect()).ok_or(())
434    }
435}
436
437impl<T: CoordFloat> From<PiecewiseLinearFunction<T>> for Vec<(T, T)> {
438    fn from(val: PiecewiseLinearFunction<T>) -> Self {
439        val.coordinates
440            .into_iter()
441            .map(|coord| coord.x_y())
442            .collect()
443    }
444}
445
446/**** Iterators ****/
447
448#[derive(Debug, Clone, Copy, PartialEq)]
449struct NextSegment<T: CoordFloat> {
450    x: T,
451    index: usize,
452}
453
454impl<T: CoordFloat> ::std::cmp::Eq for NextSegment<T> {}
455
456impl<T: CoordFloat> ::std::cmp::PartialOrd for NextSegment<T> {
457    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
458        self.x.partial_cmp(&other.x).map(|r| r.reverse())
459    }
460}
461
462impl<T: CoordFloat> ::std::cmp::Ord for NextSegment<T> {
463    fn cmp(&self, other: &Self) -> Ordering {
464        bogus_compare(self, other)
465    }
466}
467
468/// Structure returned by `points_of_inflection_iter()`
469///
470/// See that function's documentation for details.
471pub struct PointsOfInflectionIterator<'a, T: CoordFloat + 'a> {
472    segment_iterators: Vec<::std::iter::Peekable<SegmentsIterator<'a, T>>>,
473    heap: BinaryHeap<NextSegment<T>>,
474    initial: bool,
475}
476
477impl<'a, T: CoordFloat + 'a> PointsOfInflectionIterator<'a, T> {
478    /// Helper method to avoid having rust complain about mutably accessing the segment iterators
479    /// and heap at the same time.
480    fn initialize(
481        segment_iterators: &mut [::std::iter::Peekable<SegmentsIterator<'a, T>>],
482        heap: &mut BinaryHeap<NextSegment<T>>,
483    ) -> (T, Vec<T>) {
484        let values = segment_iterators
485            .iter_mut()
486            .enumerate()
487            .map(|(index, it)| {
488                let seg = it.peek().unwrap();
489                heap.push(NextSegment {
490                    x: seg.end.x,
491                    index,
492                });
493                seg.start.y
494            })
495            .collect();
496        let x = segment_iterators[0].peek().unwrap().start.x;
497        (x, values)
498    }
499}
500
501impl<'a, T: CoordFloat + 'a> Iterator for PointsOfInflectionIterator<'a, T> {
502    type Item = (T, Vec<T>);
503
504    fn next(&mut self) -> Option<Self::Item> {
505        if self.initial {
506            self.initial = false;
507            Some(Self::initialize(
508                &mut self.segment_iterators,
509                &mut self.heap,
510            ))
511        } else {
512            self.heap.peek().cloned().map(|next_segment| {
513                let x = next_segment.x;
514                let values = self
515                    .segment_iterators
516                    .iter_mut()
517                    .map(|segment_iterator| y_at_x(segment_iterator.peek().unwrap(), x))
518                    .collect();
519
520                while let Some(segt) = self.heap.peek().cloned() {
521                    if segt.x != x {
522                        break;
523                    }
524                    self.heap.pop();
525                    self.segment_iterators[segt.index].next();
526                    if let Some(segment) = self.segment_iterators[segt.index].peek().cloned() {
527                        self.heap.push(NextSegment {
528                            x: segment.end.x,
529                            index: segt.index,
530                        })
531                    }
532                }
533
534                (x, values)
535            })
536        }
537    }
538}
539
540/// Structure returned by `segments_iter()` on a `PiecewiseLinearFunction`.
541pub struct SegmentsIterator<'a, T: CoordFloat + 'a>(
542    ::std::iter::Peekable<::std::slice::Iter<'a, Coord<T>>>,
543);
544
545impl<'a, T: CoordFloat + 'a> Iterator for SegmentsIterator<'a, T> {
546    type Item = Line<T>;
547
548    fn next(&mut self) -> Option<Self::Item> {
549        self.0
550            .next()
551            .and_then(|first| self.0.peek().map(|second| Line::new(*first, **second)))
552    }
553}
554
555/**** General functions ****/
556
557/// Returns an iterator over pairs `(x, values)`, where `x` is the union of all points of
558/// inflection of `self` and `other`, and `values` is a vector of the values of all passed
559/// functions, in the same order, at the corresponding `x`.
560///
561/// ## Example
562///
563/// ```
564/// use std::convert::TryFrom;
565/// use piecewise_linear::{PiecewiseLinearFunction, points_of_inflection_iter};
566/// let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();
567/// let g = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1.5, 3.), (2., 10.)]).unwrap();
568/// assert_eq!(
569///     points_of_inflection_iter(vec![f, g].as_slice()).unwrap().collect::<Vec<_>>(),
570///     vec![(0., vec![0., 0.]), (1., vec![1., 2.]), (1.5, vec![1.25, 3.]), (2., vec![1.5, 10.])]
571/// );
572/// ```
573///
574/// ## Complexity
575///
576/// The complexity of this method is _O(k log(k) n)_, where _k_ is the number of functions passed,
577/// and _n_ is the number of points in each function.
578pub fn points_of_inflection_iter<'a, T: CoordFloat + 'a>(
579    funcs: &'a [PiecewiseLinearFunction<T>],
580) -> Option<PointsOfInflectionIterator<'a, T>> {
581    if funcs.is_empty() || !funcs.windows(2).all(|w| w[0].has_same_domain_as(&w[1])) {
582        return None;
583    }
584    Some(PointsOfInflectionIterator {
585        segment_iterators: funcs.iter().map(|f| f.segments_iter().peekable()).collect(),
586        heap: BinaryHeap::new(),
587        initial: true,
588    })
589}
590
591/// Sums the functions together. Returns `None` in case of domain error.
592///
593/// This is faster than calling .add() repeatedly by a factor of _k / log(k)_.
594pub fn sum<'a, T: CoordFloat + ::std::iter::Sum + 'a>(
595    funcs: &[PiecewiseLinearFunction<T>],
596) -> Option<PiecewiseLinearFunction<T>> {
597    points_of_inflection_iter(funcs).map(|poi| {
598        PiecewiseLinearFunction::new(
599            poi.map(|(x, values)| Coord {
600                x,
601                y: values.iter().cloned().sum(),
602            })
603            .collect(),
604        )
605        // This unwrap is guaranteed to succeed because the coordinate's x values haven't changed.
606        .unwrap()
607    })
608}
609
610/**** Helpers ****/
611
612/// Returns the restriction of segment `l` to the given domain, or `None` if the line's
613/// intersection with the domain is either a singleton or empty.
614fn line_in_domain<T: CoordFloat>(l: &Line<T>, domain: (T, T)) -> Option<Line<T>> {
615    if l.end.x <= domain.0 || l.start.x >= domain.1 {
616        None
617    } else {
618        let left_point = if l.start.x >= domain.0 {
619            l.start
620        } else {
621            (domain.0, y_at_x(l, domain.0)).into()
622        };
623        let right_point = if l.end.x <= domain.1 {
624            l.end
625        } else {
626            (domain.1, y_at_x(l, domain.1)).into()
627        };
628        Some(Line::new(left_point, right_point))
629    }
630}
631
632fn y_at_x<T: CoordFloat>(line: &Line<T>, x: T) -> T {
633    line.start.y + (x - line.start.x) * line.slope()
634}
635
636fn line_intersect<T: CoordFloat>(l1: &Line<T>, l2: &Line<T>) -> (T, T) {
637    let y_intercept_1 = l1.start.y - l1.start.x * l1.slope();
638    let y_intercept_2 = l2.start.y - l2.start.x * l2.slope();
639
640    let x_intersect = (y_intercept_2 - y_intercept_1) / (l1.slope() - l2.slope());
641    let y_intersect = y_at_x(l1, x_intersect);
642    (x_intersect, y_intersect)
643}
644
645fn compare_domains<T: CoordFloat>(d1: (T, T), d2: (T, T)) -> Option<Ordering> {
646    if d1 == d2 {
647        Some(Ordering::Equal)
648    } else if d1.0 <= d2.0 && d1.1 >= d2.1 {
649        Some(Ordering::Greater)
650    } else if d2.0 <= d1.0 && d2.1 >= d1.1 {
651        Some(Ordering::Less)
652    } else {
653        None
654    }
655}
656
/// Comparison helper around `partial_cmp` that never fails: incomparable values (e.g. NaN)
/// are reported as `Equal`.
fn bogus_compare<T: PartialOrd>(a: &T, b: &T) -> Ordering {
    match a.partial_cmp(b) {
        Some(order) => order,
        None => Ordering::Equal,
    }
}
660
661fn argmax<T: CoordFloat>(values: &[T]) -> Option<(usize, &T)> {
662    values
663        .iter()
664        .enumerate()
665        .max_by(|(_, a), (_, b)| bogus_compare(a, b))
666}
667
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use super::*;

    /// Function with extreme coordinates (huge, tiny and infinite values) to exercise edge
    /// cases.
    fn get_test_function() -> PiecewiseLinearFunction<f64> {
        PiecewiseLinearFunction::try_from(vec![
            (-5.25, f64::MIN),
            (-std::f64::consts::FRAC_PI_2, 0.1),
            (-std::f64::consts::FRAC_PI_3, 0.1 + f64::EPSILON),
            (0.1, 1.),
            (1., 2.),
            (2., 3.),
            (3., 4.),
            (f64::INFINITY, f64::NEG_INFINITY),
        ])
        .unwrap()
    }

    #[test]
    fn test_y_at_x() {
        assert_eq!(y_at_x(&Line::new((0., 0.), (1., 1.)), 0.25), 0.25);
        assert_eq!(y_at_x(&Line::new((1., 0.), (2., 1.)), 1.25), 0.25);
    }

    #[test]
    fn test_constant() {
        assert_eq!(PiecewiseLinearFunction::constant((0.5, 0.5), 1.), None);
        assert_eq!(PiecewiseLinearFunction::constant((0.5, -0.5), 1.), None);
        assert_eq!(
            PiecewiseLinearFunction::constant((-25., -13.), 1.).unwrap(),
            vec![(-25., 1.), (-13., 1.)].try_into().unwrap()
        );
    }

    #[test]
    fn test_domain() {
        assert_eq!(
            PiecewiseLinearFunction::constant((-4., 5.25), 8.2)
                .unwrap()
                .domain(),
            (-4., 5.25)
        );
        assert_eq!(
            PiecewiseLinearFunction::try_from(vec![
                (f64::NEG_INFINITY, -1.),
                (0., 0.),
                (f64::INFINITY, 0.)
            ])
            .unwrap()
            .domain(),
            (f64::NEG_INFINITY, f64::INFINITY)
        );
    }

    #[test]
    fn test_segment_at_x() {
        assert_eq!(
            get_test_function().segment_at_x(1.5).unwrap(),
            Line::new((1., 2.), (2., 3.))
        );
        assert_eq!(
            get_test_function().segment_at_x(1.).unwrap(),
            Line::new((0.1, 1.), (1., 2.))
        );
    }

    #[test]
    fn test_segments_iter() {
        let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();
        assert_eq!(
            f.segments_iter().collect::<Vec<_>>(),
            vec![
                Line::new((0., 0.), (1., 1.)),
                Line::new((1., 1.), (2., 1.5))
            ]
        );
    }

    #[test]
    fn test_points_of_inflection_iter() {
        let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();
        let g = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1.5, 3.), (2., 10.)]).unwrap();
        assert_eq!(
            f.points_of_inflection_iter(&g).unwrap().collect::<Vec<_>>(),
            vec![
                (0., vec![0., 0.]),
                (1., vec![1., 2.]),
                (1.5, vec![1.25, 3.]),
                (2., vec![1.5, 10.])
            ]
        );
    }

    #[test]
    fn test_add_and_sum() {
        let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();
        let g = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1.5, 3.), (2., 10.)]).unwrap();
        let expected: PiecewiseLinearFunction<f64> =
            vec![(0., 0.), (1., 3.), (1.5, 4.25), (2., 11.5)]
                .try_into()
                .unwrap();
        assert_eq!(f.add(&g), Some(expected.clone()));
        assert_eq!(sum(&[f.clone(), g]), Some(expected));

        // Functions with different domains cannot be combined.
        let h = PiecewiseLinearFunction::constant((0., 3.), 1.).unwrap();
        assert_eq!(f.add(&h), None);
        assert_eq!(sum(&[f, h]), None);
    }

    #[test]
    fn test_line_in_domain() {
        // Case 1 - fully outside
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (1., 2.)),
            None
        );
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (-3., -2.)),
            None
        );
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (0., 1.)),
            None
        );

        // Case 2 - fully inside
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (-2., 1.)),
            Some(Line::new((-1., 1.), (0., 2.)))
        );

        // Case 3 - overlap to the right
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (-0.5, 0.5)),
            Some(Line::new((-0.5, 1.5), (0., 2.)))
        );

        // Case 4 - overlap to the left
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (-1., -0.25)),
            Some(Line::new((-1., 1.), (-0.25, 1.75)))
        );

        // Case 5 - overlap on both sides
        assert_eq!(
            line_in_domain(&Line::new((-1., 1.), (0., 2.)), (-0.75, -0.25)),
            Some(Line::new((-0.75, 1.25), (-0.25, 1.75)))
        );
    }

    #[test]
    fn test_shrink_domain() {
        let first_val = y_at_x(
            &Line::new(
                (-std::f64::consts::FRAC_PI_3, 0.1 + f64::EPSILON),
                (0.1, 1.),
            ),
            0.,
        );
        assert_eq!(
            get_test_function()
                .shrink_domain((0.0, f64::INFINITY))
                .unwrap(),
            PiecewiseLinearFunction::try_from(vec![
                (0., first_val),
                (0.1, 1.),
                (1., 2.),
                (2., 3.),
                (3., 4.),
                (f64::INFINITY, f64::NEG_INFINITY),
            ])
            .unwrap()
        );
    }

    #[test]
    fn test_expand_domain() {
        let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();

        // Case 1: no expansion
        assert_eq!(
            f.expand_domain((0., 2.), ExpandDomainStrategy::ExtendSegment),
            f
        );

        // Case 2: left expansion
        assert_eq!(
            f.expand_domain((-1., 2.), ExpandDomainStrategy::ExtendSegment),
            vec![(-1., -1.), (1., 1.), (2., 1.5)].try_into().unwrap()
        );
        assert_eq!(
            f.expand_domain((-1., 2.), ExpandDomainStrategy::ExtendValue),
            vec![(-1., 0.), (0., 0.), (1., 1.), (2., 1.5)]
                .try_into()
                .unwrap()
        );

        // Case 3: right expansion
        assert_eq!(
            f.expand_domain((0., 4.), ExpandDomainStrategy::ExtendSegment),
            vec![(0., 0.), (1., 1.), (4., 2.5)].try_into().unwrap()
        );
        assert_eq!(
            f.expand_domain((0., 4.), ExpandDomainStrategy::ExtendValue),
            vec![(0., 0.), (1., 1.), (2., 1.5), (4., 1.5)]
                .try_into()
                .unwrap()
        );
    }

    #[test]
    fn test_negative() {
        let f = PiecewiseLinearFunction::try_from(vec![(0., 0.), (1., 1.), (2., 1.5)]).unwrap();
        assert_eq!(
            f.negate(),
            vec![(0., -0.), (1., -1.), (2., -1.5)].try_into().unwrap()
        )
    }

    #[test]
    fn test_line_intersect() {
        assert_eq!(
            line_intersect(
                &Line::new((-2., -1.), (5., 3.)),
                &Line::new((-1., 4.), (6., 2.))
            ),
            (4. + 1. / 6., 2. + 11. / 21.)
        );
    }
}