// lotus_extra/math/piecewise_linear_function.rs

//! # Piecewise Linear Function
//!
//! A Rust library for creating and evaluating piecewise linear functions.
//!
//! A piecewise linear function is defined by a series of connected line segments,
//! where each segment connects two consecutive points. This library provides
//! efficient storage, modification, and evaluation of such functions.
//!
//! ## Features
//!
//! - **Sorted storage**: Points are automatically kept sorted by x-coordinate
//! - **Linear interpolation**: Values between points are calculated using linear interpolation
//! - **Edge case handling**: Values outside the defined range return the nearest endpoint value
//! - **Duplicate handling**: Points with the same x-coordinate overwrite previous values
//! - **Error handling**: `add_point` and `get_value` reject NaN or infinite inputs with a `PiecewiseError`
//!
//! ## Quick Start
//!
//! ```rust
//! use piecewise_linear_function::PiecewiseLinearFunction;
//!
//! // Create a function from a vector of points
//! let points = vec![(0.0, 1.0), (2.0, 3.0), (4.0, 2.0)];
//! let function = PiecewiseLinearFunction::new(points);
//!
//! // Evaluate the function at different points
//! assert_eq!(function.get_value(0.0).unwrap(), 1.0);  // Exact match
//! assert_eq!(function.get_value(1.0).unwrap(), 2.0);  // Interpolated
//! assert_eq!(function.get_value(5.0).unwrap(), 2.0);  // Beyond range
//! ```
//!
//! ## Building Functions
//!
//! You can create piecewise linear functions in several ways:
//!
//! ```rust
//! use piecewise_linear_function::PiecewiseLinearFunction;
//!
//! // From a vector of points
//! let function1 = PiecewiseLinearFunction::new(vec![(0.0, 1.0), (1.0, 2.0)]);
//!
//! // Start empty and add points
//! let mut function2 = PiecewiseLinearFunction::empty();
//! function2.add_point(0.0, 1.0).unwrap();
//! function2.add_point(1.0, 2.0).unwrap();
//!
//! // From an iterator
//! let points = vec![(0.0, 1.0), (1.0, 2.0)];
//! let function3: PiecewiseLinearFunction = points.into_iter().collect();
//! ```

52use std::fmt;
53
/// Errors reported by piecewise linear function operations.
#[derive(Debug, Clone, PartialEq)]
pub enum PiecewiseError {
    /// Evaluation was requested on a function that has no points.
    EmptyFunction,
    /// A coordinate was NaN or infinite.
    InvalidPoint,
}

impl fmt::Display for PiecewiseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the fixed human-readable message for each variant, then emit it verbatim.
        let message = match self {
            Self::EmptyFunction => "Function has no points defined",
            Self::InvalidPoint => "Invalid point coordinates",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PiecewiseError {}
73
/// A piecewise linear function defined by a series of connected line segments.
///
/// The function is represented by a collection of points (x, y), where consecutive
/// points are connected by straight lines. The points are automatically sorted by
/// their x-coordinates for efficient evaluation.
///
/// Functions can be cloned, and two functions compare equal when they store
/// exactly the same points.
///
/// # Behavior
///
/// - **Interpolation**: For x-values between defined points, the function uses linear interpolation
/// - **Extrapolation**: For x-values outside the defined range, the function returns the y-value of the nearest endpoint
/// - **Duplicate x-values**: Adding a point with an existing x-coordinate overwrites the previous y-value
///
/// # Examples
///
/// ```rust
/// use piecewise_linear_function::PiecewiseLinearFunction;
///
/// let mut function = PiecewiseLinearFunction::empty();
/// function.add_point(0.0, 0.0).unwrap();
/// function.add_point(1.0, 2.0).unwrap();
/// function.add_point(2.0, 1.0).unwrap();
///
/// // Evaluate at defined points
/// assert_eq!(function.get_value(0.0).unwrap(), 0.0);
/// assert_eq!(function.get_value(1.0).unwrap(), 2.0);
///
/// // Interpolation between points
/// assert_eq!(function.get_value(0.5).unwrap(), 1.0);
/// assert_eq!(function.get_value(1.5).unwrap(), 1.5);
///
/// // Extrapolation beyond range
/// assert_eq!(function.get_value(-1.0).unwrap(), 0.0);  // Returns first point's y
/// assert_eq!(function.get_value(3.0).unwrap(), 1.0);   // Returns last point's y
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct PiecewiseLinearFunction {
    /// Points sorted by ascending x-coordinate, with at most one point per x.
    points: Vec<(f32, f32)>,
}
113
114impl PiecewiseLinearFunction {
115    /// Creates a new piecewise linear function from a vector of points.
116    ///
117    /// The points will be automatically sorted by their x-coordinates, and any
118    /// duplicate x-coordinates will result in the last y-value being kept.
119    ///
120    /// # Arguments
121    ///
122    /// * `points` - A vector of (x, y) coordinate pairs
123    ///
124    /// # Examples
125    ///
126    /// ```rust
127    /// use piecewise_linear_function::PiecewiseLinearFunction;
128    ///
129    /// // Points can be provided in any order
130    /// let points = vec![(2.0, 4.0), (0.0, 0.0), (1.0, 2.0)];
131    /// let function = PiecewiseLinearFunction::new(points);
132    ///
133    /// assert_eq!(function.len(), 3);
134    /// assert_eq!(function.get_value(0.5).unwrap(), 1.0);
135    /// ```
136    #[must_use]
137    pub fn new(points: Vec<(f32, f32)>) -> Self {
138        let mut fun = Self { points: Vec::new() };
139
140        for (x, y) in points {
141            fun.add_point_unchecked(x, y);
142        }
143
144        fun
145    }
146
147    /// Creates an empty piecewise linear function with no points.
148    ///
149    /// Points can be added later using [`add_point`](Self::add_point).
150    ///
151    /// # Examples
152    ///
153    /// ```rust
154    /// use piecewise_linear_function::PiecewiseLinearFunction;
155    ///
156    /// let mut function = PiecewiseLinearFunction::empty();
157    /// assert!(function.is_empty());
158    ///
159    /// function.add_point(1.0, 1.0).unwrap();
160    /// assert!(!function.is_empty());
161    /// ```
162    #[must_use]
163    pub fn empty() -> Self {
164        Self { points: Vec::new() }
165    }
166
167    /// Adds a point to the function.
168    ///
169    /// The point will be inserted in the correct position to maintain sorted order
170    /// by x-coordinate. If a point with the same x-coordinate already exists,
171    /// its y-value will be updated.
172    ///
173    /// # Arguments
174    ///
175    /// * `x` - The x-coordinate of the point (must be finite)
176    /// * `y` - The y-coordinate of the point (must be finite)
177    ///
178    /// # Errors
179    ///
180    /// Returns [`PiecewiseError::InvalidPoint`] if either coordinate is NaN or infinite.
181    ///
182    /// # Examples
183    ///
184    /// ```rust
185    /// use piecewise_linear_function::PiecewiseLinearFunction;
186    ///
187    /// let mut function = PiecewiseLinearFunction::empty();
188    ///
189    /// // Add points in any order
190    /// function.add_point(2.0, 4.0).unwrap();
191    /// function.add_point(0.0, 0.0).unwrap();
192    /// function.add_point(1.0, 2.0).unwrap();
193    ///
194    /// // Update existing point
195    /// function.add_point(1.0, 3.0).unwrap();
196    /// assert_eq!(function.get_value(1.0).unwrap(), 3.0);
197    ///
198    /// // Invalid coordinates are rejected
199    /// assert!(function.add_point(f32::NAN, 1.0).is_err());
200    /// assert!(function.add_point(1.0, f32::INFINITY).is_err());
201    /// ```
202    pub fn add_point(&mut self, x: f32, y: f32) -> Result<(), PiecewiseError> {
203        if !x.is_finite() || !y.is_finite() {
204            return Err(PiecewiseError::InvalidPoint);
205        }
206
207        self.add_point_unchecked(x, y);
208        Ok(())
209    }
210
211    /// Adds a point without validation (internal use only).
212    ///
213    /// This method assumes the coordinates are valid and finite.
214    fn add_point_unchecked(&mut self, x: f32, y: f32) {
215        match self
216            .points
217            .binary_search_by(|&(px, _)| px.partial_cmp(&x).unwrap())
218        {
219            Ok(pos) => {
220                self.points[pos].1 = y;
221            }
222            Err(pos) => {
223                self.points.insert(pos, (x, y));
224            }
225        }
226    }
227
228    /// Returns the number of points in the function.
229    ///
230    /// # Examples
231    ///
232    /// ```rust
233    /// use piecewise_linear_function::PiecewiseLinearFunction;
234    ///
235    /// let function = PiecewiseLinearFunction::new(vec![(0.0, 1.0), (1.0, 2.0)]);
236    /// assert_eq!(function.len(), 2);
237    /// ```
238    #[must_use]
239    pub fn len(&self) -> usize {
240        self.points.len()
241    }
242
243    /// Returns `true` if the function contains no points.
244    ///
245    /// # Examples
246    ///
247    /// ```rust
248    /// use piecewise_linear_function::PiecewiseLinearFunction;
249    ///
250    /// let empty_function = PiecewiseLinearFunction::empty();
251    /// assert!(empty_function.is_empty());
252    ///
253    /// let function = PiecewiseLinearFunction::new(vec![(0.0, 1.0)]);
254    /// assert!(!function.is_empty());
255    /// ```
256    #[must_use]
257    pub fn is_empty(&self) -> bool {
258        self.points.is_empty()
259    }
260
261    /// Returns the range of y-values in the function.
262    ///
263    /// This returns the minimum and maximum y-values among all defined points.
264    /// Note that interpolated values between points might fall outside this range.
265    ///
266    /// # Returns
267    ///
268    /// * `Some((min_y, max_y))` if the function has points
269    /// * `None` if the function is empty
270    ///
271    /// # Examples
272    ///
273    /// ```rust
274    /// use piecewise_linear_function::PiecewiseLinearFunction;
275    ///
276    /// let function = PiecewiseLinearFunction::new(vec![
277    ///     (0.0, 1.0),
278    ///     (1.0, 5.0),
279    ///     (2.0, 2.0)
280    /// ]);
281    ///
282    /// assert_eq!(function.range(), Some((1.0, 5.0)));
283    ///
284    /// let empty_function = PiecewiseLinearFunction::empty();
285    /// assert_eq!(empty_function.range(), None);
286    /// ```
287    #[must_use]
288    pub fn range(&self) -> Option<(f32, f32)> {
289        if self.points.is_empty() {
290            return None;
291        }
292
293        let mut min_y = self.points[0].1;
294        let mut max_y = self.points[0].1;
295
296        for &(_, y) in &self.points {
297            min_y = min_y.min(y);
298            max_y = max_y.max(y);
299        }
300
301        Some((min_y, max_y))
302    }
303
304    /// Evaluates the function at the given x-coordinate, returning 0.0 for empty functions.
305    ///
306    /// This is a convenience method that returns a default value instead of an error
307    /// when the function is empty.
308    ///
309    /// # Arguments
310    ///
311    /// * `x` - The x-coordinate to evaluate
312    ///
313    /// # Returns
314    ///
315    /// The function value at x, or 0.0 if the function is empty or x is invalid.
316    ///
317    /// # Examples
318    ///
319    /// ```rust
320    /// use piecewise_linear_function::PiecewiseLinearFunction;
321    ///
322    /// let function = PiecewiseLinearFunction::new(vec![(0.0, 1.0), (1.0, 2.0)]);
323    /// assert_eq!(function.get_value_or_default(0.5), 1.5);
324    ///
325    /// let empty_function = PiecewiseLinearFunction::empty();
326    /// assert_eq!(empty_function.get_value_or_default(0.5), 0.0);
327    /// ```
328    pub fn get_value_or_default(&self, x: f32) -> f32 {
329        self.get_value(x).unwrap_or(0.0)
330    }
331
332    /// Evaluates the function at the given x-coordinate.
333    ///
334    /// The function uses linear interpolation between defined points. For x-values
335    /// outside the defined range, it returns the y-value of the nearest endpoint.
336    ///
337    /// # Arguments
338    ///
339    /// * `x` - The x-coordinate to evaluate (must be finite)
340    ///
341    /// # Returns
342    ///
343    /// * `Ok(y)` - The function value at x
344    /// * `Err(PiecewiseError::EmptyFunction)` - If the function has no points
345    /// * `Err(PiecewiseError::InvalidPoint)` - If x is NaN or infinite
346    ///
347    /// # Examples
348    ///
349    /// ```rust
350    /// use piecewise_linear_function::PiecewiseLinearFunction;
351    ///
352    /// let function = PiecewiseLinearFunction::new(vec![
353    ///     (0.0, 0.0),
354    ///     (2.0, 4.0),
355    ///     (4.0, 2.0)
356    /// ]);
357    ///
358    /// // Exact matches
359    /// assert_eq!(function.get_value(0.0).unwrap(), 0.0);
360    /// assert_eq!(function.get_value(2.0).unwrap(), 4.0);
361    ///
362    /// // Linear interpolation
363    /// assert_eq!(function.get_value(1.0).unwrap(), 2.0);
364    /// assert_eq!(function.get_value(3.0).unwrap(), 3.0);
365    ///
366    /// // Extrapolation (returns nearest endpoint)
367    /// assert_eq!(function.get_value(-1.0).unwrap(), 0.0);
368    /// assert_eq!(function.get_value(5.0).unwrap(), 2.0);
369    ///
370    /// // Error cases
371    /// let empty = PiecewiseLinearFunction::empty();
372    /// assert!(empty.get_value(1.0).is_err());
373    /// assert!(function.get_value(f32::NAN).is_err());
374    /// ```
375    pub fn get_value(&self, x: f32) -> Result<f32, PiecewiseError> {
376        if self.points.is_empty() {
377            return Err(PiecewiseError::EmptyFunction);
378        }
379
380        if !x.is_finite() {
381            return Err(PiecewiseError::InvalidPoint);
382        }
383
384        // edge case, smaller than the existing values
385        if x <= self.points[0].0 {
386            return Ok(self.points[0].1);
387        }
388
389        // edge case, greater than the existing values
390        if x >= self.points[self.points.len() - 1].0 {
391            return Ok(self.points[self.points.len() - 1].1);
392        }
393
394        let pos = self
395            .points
396            .binary_search_by(|&(px, _)| px.partial_cmp(&x).unwrap());
397
398        match pos {
399            Ok(index) => {
400                // Exact match found
401                Ok(self.points[index].1)
402            }
403            Err(index) => {
404                let (x0, y0) = self.points[index - 1];
405                let (x1, y1) = self.points[index];
406
407                // Linear interpolation
408                let interpolated_y = y0 + (x - x0) * (y1 - y0) / (x1 - x0);
409                Ok(interpolated_y)
410            }
411        }
412    }
413}
414
415impl Default for PiecewiseLinearFunction {
416    /// Creates an empty piecewise linear function.
417    ///
418    /// This is equivalent to calling [`PiecewiseLinearFunction::empty()`].
419    fn default() -> Self {
420        Self::empty()
421    }
422}
423
424impl FromIterator<(f32, f32)> for PiecewiseLinearFunction {
425    /// Creates a piecewise linear function from an iterator of points.
426    ///
427    /// # Examples
428    ///
429    /// ```rust
430    /// use piecewise_linear_function::PiecewiseLinearFunction;
431    ///
432    /// let points = vec![(0.0, 1.0), (1.0, 2.0), (2.0, 1.5)];
433    /// let function: PiecewiseLinearFunction = points.into_iter().collect();
434    ///
435    /// assert_eq!(function.len(), 3);
436    /// assert_eq!(function.get_value(0.5).unwrap(), 1.5);
437    /// ```
438    fn from_iter<T: IntoIterator<Item = (f32, f32)>>(iter: T) -> Self {
439        let points: Vec<_> = iter.into_iter().collect();
440        Self::new(points)
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_functionality() {
        let mut piecewise = PiecewiseLinearFunction::empty();

        piecewise.add_point(0.0, 1.0).unwrap();
        piecewise.add_point(2.0, 3.0).unwrap();
        piecewise.add_point(5.0, 2.0).unwrap();

        assert_eq!(piecewise.get_value(0.0).unwrap(), 1.0);
        assert_eq!(piecewise.get_value(1.0).unwrap(), 2.0);
        assert_eq!(piecewise.get_value(3.5).unwrap(), 2.5);
        assert_eq!(piecewise.get_value(6.0).unwrap(), 2.0);
    }

    #[test]
    fn test_from_vec() {
        let points = vec![(0.0, 1.0), (5.0, 2.0), (2.0, 3.0)]; // Unsorted
        let piecewise = PiecewiseLinearFunction::new(points);

        assert_eq!(piecewise.len(), 3);
        assert_eq!(piecewise.get_value(1.0).unwrap(), 2.0);
    }

    #[test]
    fn test_empty_function() {
        let piecewise = PiecewiseLinearFunction::empty();
        assert!(piecewise.get_value(1.0).is_err());
        assert!(piecewise.is_empty());
    }

    #[test]
    fn test_duplicate_x_values() {
        let mut piecewise = PiecewiseLinearFunction::empty();
        piecewise.add_point(1.0, 2.0).unwrap();
        piecewise.add_point(1.0, 3.0).unwrap(); // Overwrites the previous value

        assert_eq!(piecewise.len(), 1);
        assert_eq!(piecewise.get_value(1.0).unwrap(), 3.0);
    }

    #[test]
    fn test_range() {
        let piecewise = PiecewiseLinearFunction::new(vec![(1.0, 5.0), (3.0, 2.0), (5.0, 8.0)]);

        assert_eq!(piecewise.range(), Some((2.0, 8.0)));
    }

    #[test]
    fn test_invalid_values() {
        let mut piecewise = PiecewiseLinearFunction::empty();
        assert!(piecewise.add_point(f32::NAN, 1.0).is_err());
        assert!(piecewise.add_point(1.0, f32::INFINITY).is_err());

        piecewise.add_point(1.0, 1.0).unwrap();
        assert!(piecewise.get_value(f32::NAN).is_err());
    }

    #[test]
    fn test_from_iterator() {
        let points = vec![(0.0, 1.0), (2.0, 3.0), (5.0, 2.0)];
        let piecewise: PiecewiseLinearFunction = points.into_iter().collect();

        assert_eq!(piecewise.len(), 3);
        assert_eq!(piecewise.get_value(1.0).unwrap(), 2.0);
    }

    #[test]
    fn test_get_value_or_default() {
        let function = PiecewiseLinearFunction::new(vec![(0.0, 1.0), (1.0, 2.0)]);
        assert_eq!(function.get_value_or_default(0.5), 1.5);

        let empty_function = PiecewiseLinearFunction::empty();
        assert_eq!(empty_function.get_value_or_default(0.5), 0.0);
    }

    /// A single point defines a constant function everywhere.
    #[test]
    fn test_single_point_is_constant() {
        let piecewise = PiecewiseLinearFunction::new(vec![(1.0, 5.0)]);

        assert_eq!(piecewise.get_value(-10.0).unwrap(), 5.0);
        assert_eq!(piecewise.get_value(1.0).unwrap(), 5.0);
        assert_eq!(piecewise.get_value(10.0).unwrap(), 5.0);
        assert_eq!(piecewise.range(), Some((5.0, 5.0)));
    }

    /// Interior exact hits return the stored y; queries below the range clamp to the first point.
    #[test]
    fn test_exact_interior_point_and_extrapolation_below() {
        let piecewise = PiecewiseLinearFunction::new(vec![(0.0, 0.0), (2.0, 4.0), (4.0, 2.0)]);

        assert_eq!(piecewise.get_value(2.0).unwrap(), 4.0);
        assert_eq!(piecewise.get_value(3.0).unwrap(), 3.0);
        assert_eq!(piecewise.get_value(-1.0).unwrap(), 0.0);
    }

    /// Non-adjacent duplicate x-values passed to the constructor keep the last y.
    #[test]
    fn test_duplicates_in_constructor_keep_last_value() {
        let piecewise = PiecewiseLinearFunction::new(vec![(1.0, 2.0), (2.0, 0.0), (1.0, 3.0)]);

        assert_eq!(piecewise.len(), 2);
        assert_eq!(piecewise.get_value(1.0).unwrap(), 3.0);
    }

    /// Each failure mode reports its specific error variant.
    #[test]
    fn test_error_kinds() {
        let empty = PiecewiseLinearFunction::default();
        assert!(empty.is_empty());
        assert_eq!(empty.get_value(1.0), Err(PiecewiseError::EmptyFunction));
        assert_eq!(empty.range(), None);

        let function = PiecewiseLinearFunction::new(vec![(0.0, 1.0)]);
        assert_eq!(
            function.get_value(f32::INFINITY),
            Err(PiecewiseError::InvalidPoint)
        );
        assert_eq!(function.get_value_or_default(f32::NAN), 0.0);
    }
}