pcw_fn/
partial.rs

//! Variant for partially ordered domains that panics on incomparability.
2
3// #![feature(generic_const_exprs)]
4use crate::{
5    functional_hackery::{Functor, FunctorRef, Kind1To1},
6    PcwFnError,
7};
8use is_sorted::IsSorted;
9use itertools::{EitherOrBoth, Itertools};
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Serialize};
12use std::cmp::Ordering;
13use std::iter;
14/// A piecewise function given by
15///        ╭ f₁(x)   if      x < x₀
16///        │ f₂(x)   if x₀ ≤ x < x₁
17/// f(x) = ┤ f₃(x)   if x₁ ≤ x < x₂
18///        │  ⋮               ⋮
19///        ╰ fₙ(x)   if xₙ ≤ x
20/// for all x ∈ X where
21///     f₁,...,fₙ : X -> Y, and
22///     x₀ < x₁ < ... < xₙ
23/// from some strictly partially ordered set X (so X is `PartialOrd`). Note that the fᵢ are not
24/// necessarily distinct. Panics if two x aren't comparable.
25///
26/// We'll call the collection of all xᵢ the jump positions, or simply jumps of the piecewise
27/// function.
28pub trait PcwFn<X: PartialOrd, F>: Functor<F> + Sized {
29    type JmpIter: Iterator<Item = X>;
30    type FncIter: Iterator<Item = F>;
31
32    /// Get a reference to the jumps.
33    fn jumps(&self) -> &[X];
34
35    /// Get a reference to the funcs in order.
36    fn funcs(&self) -> &[F];
37
38    /// Get a mutable reference to the jumps.
39    fn funcs_mut(&mut self) -> &mut [F];
40
41    /// Try constructing a new piecewise function from iterators over jumps and functions.
42    fn try_from_iters<Jmp: IntoIterator<Item = X>, Fnc: IntoIterator<Item = F>>(
43        jumps: Jmp,
44        funcs: Fnc,
45    ) -> Result<Self, PcwFnError>;
46
47    /// A function that's globally given by a single function `f`.
48    fn global(f: F) -> Self {
49        Self::try_from_iters(iter::empty(), iter::once(f)).unwrap()
50    }
51
52    /// Add another segment to the piecewise function at the back.
53    fn add_segment(&mut self, jump: X, func: F);
54
55    /// Deconstruct a piecewise function into sequences of functions and jumps.
56    fn into_jumps_and_funcs(self) -> (Self::JmpIter, Self::FncIter);
57
58    /// Turn the function into an owned iterator over the jumps.
59    fn into_jumps(self) -> Self::JmpIter {
60        self.into_jumps_and_funcs().0
61    }
62
63    /// Turn the function into an owned iterator over the functions.
64    fn into_funcs(self) -> Self::FncIter {
65        self.into_jumps_and_funcs().1
66    }
67
68    /// How many segments the function consists of.
69    fn segment_count(&self) -> usize {
70        self.funcs().len()
71    }
72
73    /// Combine two piecewise functions using a pointwise action to obtain another piecewise function.
74    fn combine<Rhs, G, Out, H>(self, rhs: Rhs, mut action: impl FnMut(F, G) -> H) -> Out
75    where
76        X: PartialOrd,
77        F: Clone,
78        G: Clone,
79        Rhs: PcwFn<X, G>,
80        Out: PcwFn<X, H>,
81    {
82        match (self.segment_count(), rhs.segment_count()) {
83            (0, _) => panic!("Empty function is invalid"),
84            (_, 0) => panic!("Empty function is invalid"),
85            (1, _) => {
86                let l = self.into_funcs().next().unwrap();
87                let (jr, fr) = rhs.into_jumps_and_funcs();
88                Out::try_from_iters(jr, fr.map(|r| action(l.clone(), r))).unwrap()
89            }
90            (_, 1) => {
91                let r = rhs.into_funcs().next().unwrap();
92                let (jl, fl) = self.into_jumps_and_funcs();
93                Out::try_from_iters(jl, fl.map(|l| action(l, r.clone()))).unwrap()
94            }
95            (n, m) => {
96                let (jl, mut fl) = self.into_jumps_and_funcs();
97                let (jr, mut fr) = rhs.into_jumps_and_funcs();
98                // there'll be no more than n+m segments in the combined function
99                let mut funcs = Vec::with_capacity(n + m);
100                // -1 is always valid since we know fl and fr have at least one element
101                let mut jumps = Vec::with_capacity(funcs.capacity() - 1);
102                let mut l = fl.next().unwrap();
103                let mut r = fr.next().unwrap();
104                funcs.push(action(l.clone(), r.clone())); // value of result sufficiently far left in the domain
105                for c in jl.merge_join_by(jr.into_iter(), |x, y| x.partial_cmp(y).unwrap()) {
106                    match c {
107                        EitherOrBoth::Left(jump) => {
108                            jumps.push(jump);
109                            l = fl.next().unwrap_or(l);
110                        }
111                        EitherOrBoth::Right(jump) => {
112                            jumps.push(jump);
113                            r = fr.next().unwrap_or(r);
114                        }
115                        EitherOrBoth::Both(jump, _) => {
116                            jumps.push(jump);
117                            l = fl.next().unwrap_or(l);
118                            r = fr.next().unwrap_or(r);
119                        }
120                    }
121                    funcs.push(action(l.clone(), r.clone()));
122                }
123                jumps.shrink_to_fit();
124                funcs.shrink_to_fit();
125                Out::try_from_iters(jumps.into_iter(), funcs.into_iter()).unwrap()
126            }
127        }
128    }
129
130    /// Resample `self` to the segments of `other`: replace the jumps of `self` with those of
131    /// other and if that leaves multiple functions on a single segment combine them using the
132    /// provided `combine` function.
133    fn resample_to<PcwOut, G>(
134        self,
135        other: impl PcwFn<X, G>,
136        mut combine: impl FnMut(F, F) -> F,
137    ) -> PcwOut
138    where
139        F: Clone,
140        PcwOut: PcwFn<X, F>,
141    {
142        match (self.segment_count(), other.segment_count()) {
143            (0, _) => panic!("Empty function is invalid"),
144            (_, 0) => panic!("Empty function is invalid"),
145            (1, n) => PcwOut::try_from_iters(
146                other.into_jumps(),
147                iter::repeat(self.into_funcs().next().unwrap()).take(n),
148            )
149            .unwrap(),
150            (_, 1) => PcwOut::try_from_iters(
151                other.into_jumps(),
152                iter::once(self.into_funcs().reduce(combine).unwrap()),
153            )
154            .unwrap(),
155            (_, n) => {
156                let (jl, mut fl) = self.into_jumps_and_funcs();
157                let mut funcs = Vec::with_capacity(n);
158                let mut active_f = fl.next().unwrap();
159                funcs.push(active_f.clone()); // value of result sufficiently far left in the domain
160                for c in jl.merge_join_by(other.jumps(), |x, y| x.partial_cmp(y).unwrap()) {
161                    match c {
162                        EitherOrBoth::Left(_) => {
163                            if let Some(new_f) = fl.next() {
164                                active_f = new_f.clone();
165                                let f = combine(funcs.pop().unwrap(), new_f);
166                                funcs.push(f);
167                            }
168                        }
169                        EitherOrBoth::Right(_) => funcs.push(active_f.clone()),
170                        EitherOrBoth::Both(_, _) => {
171                            if let Some(new_f) = fl.next() {
172                                active_f = new_f.clone();
173                                funcs.push(new_f);
174                            } else {
175                                funcs.push(active_f.clone())
176                            }
177                        }
178                    }
179                }
180                PcwOut::try_from_iters(other.into_jumps(), funcs).unwrap()
181            }
182        }
183    }
184
185    /// Find the function that locally defines the piecewise function at some point `x` of
186    /// the domain.
187    fn func_at(&self, x: &X) -> &F {
188        match self.segment_count() {
189            0 => panic!("Empty function is invalid"),
190            1 => &self.funcs()[0],
191            _ => match self.jumps().binary_search_by(|y| y.partial_cmp(x).unwrap()) {
192                Ok(jump_idx) => &self.funcs()[jump_idx + 1],
193                Err(insertion_idx) => &self.funcs()[insertion_idx],
194            },
195        }
196    }
197
198    /// Find the function that locally defines the piecewise function at some point `x` of
199    /// the domain.
200    fn func_at_mut(&mut self, x: &X) -> &mut F {
201        match self.segment_count() {
202            0 => panic!("Empty function is invalid"),
203            1 => &mut self.funcs_mut()[0],
204            _ => match self.jumps().binary_search_by(|y| y.partial_cmp(x).unwrap()) {
205                Ok(jump_idx) => &mut self.funcs_mut()[jump_idx + 1],
206                Err(insertion_idx) => &mut self.funcs_mut()[insertion_idx],
207            },
208        }
209    }
210
211    /// Evaluate the function at some point `x` of the domain.
212    fn eval<Y>(&self, x: X) -> Y
213    where
214        F: Fn(X) -> Y,
215    {
216        self.func_at(&x)(x)
217    }
218
219    /// Mutably evaluate the function at some point `x` of the domain.
220    fn eval_mut<Y>(&mut self, x: X) -> Y
221    where
222        F: FnMut(X) -> Y,
223    {
224        self.func_at_mut(&x)(x)
225    }
226}
227
/// A piecewise function internally backed by `Vec`s.
///
/// `funcs[0]` applies left of `jumps[0]`, `funcs[i]` applies on `jumps[i - 1] ≤ x < jumps[i]`,
/// and the last function applies from the last jump onwards.
///
/// Values built through `PcwFn::try_from_iters` have strictly increasing `jumps` and exactly
/// one more function than jumps.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` fills the fields
/// directly and does not re-run those checks — confirm whether that is acceptable.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct VecPcwFn<X, F> {
    /// Jump positions, strictly increasing.
    jumps: Vec<X>,
    /// Segment functions from left to right; one more than there are jumps.
    funcs: Vec<F>,
}
235
236impl<X, F> Kind1To1 for VecPcwFn<X, F> {
237    type Constructor<S> = VecPcwFn<X, S>;
238}
239
240impl<X, F> Functor<F> for VecPcwFn<X, F> {
241    fn fmap<S>(self, f: impl FnMut(F) -> S) -> Self::Constructor<S> {
242        VecPcwFn {
243            jumps: self.jumps,
244            funcs: self.funcs.into_iter().map(f).collect(),
245        }
246    }
247}
248
249impl<X, F> FunctorRef<F> for VecPcwFn<X, F>
250where
251    X: Clone,
252{
253    fn fmap_ref<S>(&self, f: impl FnMut(&F) -> S) -> Self::Constructor<S> {
254        VecPcwFn {
255            jumps: self.jumps.clone(),
256            funcs: self.funcs.iter().map(f).collect(),
257        }
258    }
259}
260
/// Comparator that only accepts strictly ascending pairs.
///
/// Yields `Some(Ordering::Less)` when `x.partial_cmp(y)` says `x < y`, and `None` for every
/// other outcome (equal, greater, or incomparable).
fn strictly_less<T: PartialOrd>(x: &T, y: &T) -> Option<Ordering> {
    x.partial_cmp(y).filter(|ord| *ord == Ordering::Less)
}
268
269impl<X: PartialOrd, F> PcwFn<X, F> for VecPcwFn<X, F> {
270    type JmpIter = <Vec<X> as IntoIterator>::IntoIter;
271    type FncIter = <Vec<F> as IntoIterator>::IntoIter;
272
273    fn jumps(&self) -> &[X] {
274        &self.jumps
275    }
276
277    fn funcs(&self) -> &[F] {
278        &self.funcs
279    }
280
281    fn funcs_mut(&mut self) -> &mut [F] {
282        &mut self.funcs
283    }
284
285    fn try_from_iters<Jmp: IntoIterator<Item = X>, Fnc: IntoIterator<Item = F>>(
286        jumps: Jmp,
287        funcs: Fnc,
288    ) -> Result<Self, PcwFnError> {
289        use std::cmp::Ordering::*;
290        let jumps = jumps.into_iter().collect_vec();
291        let funcs = funcs.into_iter().collect_vec();
292        //if !jumps.iter().is_strictly_sorted() {
293        if !IsSorted::is_sorted_by(&mut jumps.iter(), strictly_less) {
294            Err(PcwFnError::JumpsNotStrictlySorted)
295        } else {
296            match (jumps.iter().len() + 1).cmp(&funcs.iter().len()) {
297                Greater => Err(PcwFnError::TooManyJumpsForFuncs),
298                Less => Err(PcwFnError::TooManyJumpsForFuncs),
299                Equal => Ok(VecPcwFn { jumps, funcs }),
300            }
301        }
302    }
303
304    fn add_segment(&mut self, jump: X, func: F) {
305        self.jumps.push(jump);
306        self.funcs.push(func);
307    }
308
309    fn into_jumps_and_funcs(self) -> (Self::JmpIter, Self::FncIter) {
310        (self.jumps.into_iter(), self.funcs.into_iter())
311    }
312}
313
314pub use num_impls::*;
315mod num_impls {
316    use super::*;
317    use num_traits::{One, Pow, Zero};
318    use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Not, Rem, Shl, Shr, Sub};
319
320    /// Lifts a basic binary operation from functions to piecewise functions.
321    macro_rules! pointwise_owned_binop_impl {
322        ( $trait_to_impl:ident, $method_name:ident, $for_type:ident ) => {
323            impl<Rhs, X, F> $trait_to_impl<Rhs> for $for_type<X, F>
324            where
325                X: PartialOrd,
326                Rhs: PcwFn<X, F>,
327                F: $trait_to_impl<F> + Clone,
328            {
329                type Output = VecPcwFn<X, F::Output>;
330                fn $method_name(self, rhs: Rhs) -> Self::Output {
331                    self.combine(rhs, $trait_to_impl::$method_name)
332                }
333            }
334        };
335    }
336
337    /* the above macro produces impls like
338
339    impl<R, X, F> Add<R> for VecPcwFn<X, F>
340    where
341        X: Ord,
342        R: PcwFn<X, F>,
343        F: Add<F> + Clone,
344    {
345        type Output = VecPcwFn<X, F::Output>;
346
347        fn add(self, rhs: R) -> Self::Output {
348            self.combine(rhs, Add::add)
349        }
350    }
351
352    */
353
354    pointwise_owned_binop_impl!(Add, add, VecPcwFn);
355    pointwise_owned_binop_impl!(Sub, sub, VecPcwFn);
356    pointwise_owned_binop_impl!(Mul, mul, VecPcwFn);
357    pointwise_owned_binop_impl!(Div, div, VecPcwFn);
358    pointwise_owned_binop_impl!(Pow, pow, VecPcwFn);
359    pointwise_owned_binop_impl!(Rem, rem, VecPcwFn);
360    pointwise_owned_binop_impl!(BitAnd, bitand, VecPcwFn);
361    pointwise_owned_binop_impl!(BitOr, bitor, VecPcwFn);
362    pointwise_owned_binop_impl!(BitXor, bitxor, VecPcwFn);
363    pointwise_owned_binop_impl!(Shl, shl, VecPcwFn);
364    pointwise_owned_binop_impl!(Shr, shr, VecPcwFn);
365
366    impl<X, F> Zero for VecPcwFn<X, F>
367    where
368        X: PartialOrd,
369        F: Zero + Clone,
370    {
371        fn zero() -> Self {
372            Self::global(F::zero())
373        }
374
375        fn is_zero(&self) -> bool {
376            self.segment_count() == 1 && self.funcs[0].is_zero()
377        }
378    }
379
380    impl<X, F> One for VecPcwFn<X, F>
381    where
382        X: PartialOrd,
383        F: One + Clone + PartialEq,
384    {
385        fn one() -> Self {
386            Self::global(F::one())
387        }
388        fn is_one(&self) -> bool
389        where
390            Self: PartialEq,
391        {
392            self.segment_count() == 1 && self.funcs[0].is_one()
393        }
394    }
395
396    /// Lifts a basic unary operation from functions to piecewise functions
397    macro_rules! pointwise_owned_unop_impl {
398        ( $trait_to_impl:ident, $method_name:ident, $for_type:ident ) => {
399            impl<X, F> $trait_to_impl for $for_type<X, F>
400            where
401                X: PartialOrd,
402                F: $trait_to_impl,
403            {
404                type Output = VecPcwFn<X, F::Output>;
405                fn $method_name(self) -> Self::Output {
406                    self.fmap($trait_to_impl::$method_name)
407                }
408            }
409        };
410    }
411
412    pointwise_owned_unop_impl!(Not, not, VecPcwFn);
413    pointwise_owned_unop_impl!(Neg, neg, VecPcwFn);
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    mod add {
        use super::*;

        #[test]
        fn same_domains() {
            let f = VecPcwFn::try_from_iters(vec![5, 10, 15], vec![1, 2, 3, 4]).unwrap();
            let g = VecPcwFn::try_from_iters(vec![5, 10, 15], vec![2, 4, 6, 8]).unwrap();
            assert_eq!(
                f + g,
                VecPcwFn::try_from_iters(vec![5, 10, 15], vec![3, 6, 9, 12]).unwrap()
            )
        }

        #[test]
        fn left_domain_larger() {
            let f = VecPcwFn::try_from_iters(vec![5, 10, 15], vec![1, 2, 3, 4]).unwrap();
            let g = VecPcwFn::try_from_iters(vec![10], vec![4, 6]).unwrap();
            assert_eq!(
                VecPcwFn::try_from_iters(vec![5, 10, 15], vec![5, 6, 9, 10]).unwrap(),
                f + g
            )
        }

        #[test]
        fn right_domain_larger() {
            let g = VecPcwFn::try_from_iters(vec![5, 10, 15], vec![1, 2, 3, 4]).unwrap();
            let f = VecPcwFn::try_from_iters(vec![10], vec![4, 6]).unwrap();
            assert_eq!(
                VecPcwFn::try_from_iters(vec![5, 10, 15], vec![5, 6, 9, 10]).unwrap(),
                f + g
            )
        }

        #[test]
        fn unaligned_domains() {
            let f = VecPcwFn::try_from_iters(vec![0], vec![1, 3]).unwrap();
            let g = VecPcwFn::try_from_iters(vec![1], vec![2, 4]).unwrap();
            assert_eq!(
                VecPcwFn::try_from_iters(vec![0, 1], vec![3, 5, 7]).unwrap(),
                f + g
            );
            // same thing, just a bigger example
            let f = VecPcwFn::try_from_iters(vec![5, 10, 15], vec![1, 2, 3, 4]).unwrap();
            let g = VecPcwFn::try_from_iters(vec![7, 12], vec![4, 6, -5]).unwrap();
            assert_eq!(
                VecPcwFn::try_from_iters(vec![5, 7, 10, 12, 15], vec![5, 6, 8, 9, -2, -1]).unwrap(),
                f + g,
            )
        }

        #[test]
        fn with_global_function() {
            // exercises the constant-operand fast paths of `combine` on both sides
            let c: VecPcwFn<i32, i32> = VecPcwFn::global(10);
            let g: VecPcwFn<i32, i32> =
                VecPcwFn::try_from_iters(vec![5, 10], vec![1, 2, 3]).unwrap();
            let expected: VecPcwFn<i32, i32> =
                VecPcwFn::try_from_iters(vec![5, 10], vec![11, 12, 13]).unwrap();
            assert_eq!(c.clone() + g.clone(), expected);
            assert_eq!(g + c, expected);
        }
    }

    mod resample {
        use super::*;
        #[test]
        fn same_domains() {
            let f = VecPcwFn::try_from_iters(vec![5, 10, 15], vec![1, 2, 3, 4]).unwrap();
            let g = VecPcwFn::try_from_iters(vec![5, 10, 15], vec![2, 4, 6, 8]).unwrap();
            let h: VecPcwFn<_, _> = f.resample_to(g, std::cmp::min);
            assert_eq!(
                h,
                VecPcwFn::try_from_iters(vec![5, 10, 15], vec![1, 2, 3, 4]).unwrap()
            )
        }

        #[test]
        fn left_domain_larger() {
            let f = VecPcwFn::try_from_iters(vec![5, 10, 15], vec![1, 2, 3, 4]).unwrap();
            let g = VecPcwFn::try_from_iters(vec![10], vec![4, 6]).unwrap();
            let h: VecPcwFn<_, _> = f.resample_to(g, std::cmp::min);
            assert_eq!(h, VecPcwFn::try_from_iters(vec![10], vec![1, 3]).unwrap())
        }

        #[test]
        fn right_domain_larger() {
            let f = VecPcwFn::try_from_iters(vec![10], vec![4, 6]).unwrap();
            let g = VecPcwFn::try_from_iters(vec![5, 10, 15], vec![1, 2, 3, 4]).unwrap();
            let h: VecPcwFn<_, _> = f.resample_to(g, std::cmp::min);
            assert_eq!(
                h,
                VecPcwFn::try_from_iters(vec![5, 10, 15], vec![4, 4, 6, 6]).unwrap()
            )
        }

        #[test]
        fn unaligned_domains() {
            let f = VecPcwFn::try_from_iters(vec![0], vec![1, 3]).unwrap();
            let g = VecPcwFn::try_from_iters(vec![1], vec![2, 4]).unwrap();
            let h: VecPcwFn<_, _> = f.resample_to(g, std::cmp::min);
            assert_eq!(h, VecPcwFn::try_from_iters(vec![1], vec![1, 3]).unwrap());
            // same thing, just a bigger example
            let f = VecPcwFn::try_from_iters(vec![5, 10, 15], vec![1, 2, 3, 4]).unwrap();
            let g = VecPcwFn::try_from_iters(vec![7, 12], vec![4, 6, -5]).unwrap();
            let h: VecPcwFn<_, _> = f.resample_to(g, std::cmp::min);
            assert_eq!(
                h,
                VecPcwFn::try_from_iters(vec![7, 12], vec![1, 2, 3]).unwrap()
            )
        }

        #[test]
        fn unaligned_domains_big() {
            let f = VecPcwFn::try_from_iters(
                vec![5, 10, 15, 16, 17, 20],
                vec![2, 1, 6, 4, 5, -10, 100],
            )
            .unwrap();
            let g = VecPcwFn::try_from_iters(vec![7, 12, 30], vec![4, 6, -5, -20]).unwrap();
            let h: VecPcwFn<_, _> = f.resample_to(g, std::cmp::min);
            assert_eq!(
                h,
                VecPcwFn::try_from_iters(vec![7, 12, 30], vec![1, 1, -10, 100]).unwrap()
            )
        }

        #[test]
        fn with_global_functions() {
            // resampling onto a single segment folds every function together
            let f: VecPcwFn<i32, i32> =
                VecPcwFn::try_from_iters(vec![5, 10], vec![3, 1, 2]).unwrap();
            let g: VecPcwFn<i32, ()> = VecPcwFn::global(());
            let h: VecPcwFn<i32, i32> = f.resample_to(g, std::cmp::min);
            assert_eq!(h, VecPcwFn::global(1));

            // a constant function is repeated on every segment of the target
            let f: VecPcwFn<i32, i32> = VecPcwFn::global(5);
            let g: VecPcwFn<i32, char> =
                VecPcwFn::try_from_iters(vec![1, 2], vec!['a', 'b', 'c']).unwrap();
            let h: VecPcwFn<i32, i32> = f.resample_to(g, std::cmp::min);
            assert_eq!(h, VecPcwFn::try_from_iters(vec![1, 2], vec![5, 5, 5]).unwrap());
        }
    }

    #[test]
    fn eval() {
        let f: VecPcwFn<_, &dyn Fn(i32) -> i32> = VecPcwFn::try_from_iters(
            vec![5, 10, 15],
            vec![
                &(|_| -10) as &dyn Fn(i32) -> i32,
                &(|_| 10) as &dyn Fn(i32) -> i32,
                &(|_| 5) as &dyn Fn(i32) -> i32,
                &(|_| -5) as &dyn Fn(i32) -> i32,
            ],
        )
        .unwrap();
        assert_eq!(f.eval(0), -10);
        assert_eq!(f.eval(4), -10);
        assert_eq!(f.eval(5), 10);
        assert_eq!(f.eval(6), 10);
        assert_eq!(f.eval(10), 5);
        assert_eq!(f.eval(200), -5);
    }

    #[test]
    fn global_has_single_segment() {
        let f: VecPcwFn<i32, i32> = VecPcwFn::global(7);
        assert_eq!(f.segment_count(), 1);
        assert!(f.jumps().is_empty());
        assert_eq!(*f.func_at(&i32::MIN), 7);
        assert_eq!(*f.func_at(&i32::MAX), 7);
    }

    #[test]
    fn func_at_assigns_jumps_to_right_segment() {
        let f: VecPcwFn<i32, char> =
            VecPcwFn::try_from_iters(vec![5, 10], vec!['a', 'b', 'c']).unwrap();
        assert_eq!(*f.func_at(&4), 'a');
        assert_eq!(*f.func_at(&5), 'b');
        assert_eq!(*f.func_at(&9), 'b');
        assert_eq!(*f.func_at(&10), 'c');
        assert_eq!(*f.func_at(&11), 'c');
    }

    #[test]
    fn func_at_mut_edits_segment() {
        let mut f: VecPcwFn<i32, i32> =
            VecPcwFn::try_from_iters(vec![5, 10], vec![1, 2, 3]).unwrap();
        *f.func_at_mut(&7) = 20;
        assert_eq!(
            f,
            VecPcwFn::try_from_iters(vec![5, 10], vec![1, 20, 3]).unwrap()
        );
    }

    #[test]
    fn try_from_iters_rejects_invalid_input() {
        // jumps have to be strictly increasing
        assert!(matches!(
            VecPcwFn::<i32, i32>::try_from_iters(vec![1, 1], vec![0, 0, 0]),
            Err(crate::PcwFnError::JumpsNotStrictlySorted)
        ));
        assert!(matches!(
            VecPcwFn::<i32, i32>::try_from_iters(vec![2, 1], vec![0, 0, 0]),
            Err(crate::PcwFnError::JumpsNotStrictlySorted)
        ));
        // incomparable jumps are not strictly sorted either
        assert!(matches!(
            VecPcwFn::<f64, i32>::try_from_iters(vec![1.0, f64::NAN], vec![0, 0, 0]),
            Err(crate::PcwFnError::JumpsNotStrictlySorted)
        ));
        // there has to be exactly one more function than jumps
        assert!(VecPcwFn::<i32, i32>::try_from_iters(vec![1, 2], vec![0, 0]).is_err());
        assert!(VecPcwFn::<i32, i32>::try_from_iters(vec![1, 2], vec![0, 0, 0, 0]).is_err());
        assert!(VecPcwFn::<i32, i32>::try_from_iters(vec![], vec![0]).is_ok());
    }

    #[test]
    fn mul_and_neg() {
        let f: VecPcwFn<i32, i32> = VecPcwFn::try_from_iters(vec![0], vec![2, -3]).unwrap();
        let g: VecPcwFn<i32, i32> = VecPcwFn::try_from_iters(vec![0], vec![5, 7]).unwrap();
        assert_eq!(
            f.clone() * g,
            VecPcwFn::try_from_iters(vec![0], vec![10, -21]).unwrap()
        );
        assert_eq!(-f, VecPcwFn::try_from_iters(vec![0], vec![-2, 3]).unwrap());
    }

    #[test]
    fn zero_and_one_are_global() {
        use num_traits::{One, Zero};
        let z: VecPcwFn<i32, i32> = Zero::zero();
        assert!(z.is_zero());
        assert_eq!(z, VecPcwFn::global(0));
        let o: VecPcwFn<i32, i32> = One::one();
        assert!(o.is_one());
        assert_eq!(o, VecPcwFn::global(1));
    }
}