dyadic_rationals/
bin.rs

1use std::fmt;
2use std::ops::{Add, AddAssign, Sub, Mul, MulAssign, Div};
3use crate::context::{Set, Ctx};
4use crate::traits::{Specializable, Normalizable};
5use pretty::{DocAllocator, DocBuilder, BoxAllocator, Pretty};
6
7/// Represents a type-level positive, linear expression
8/// ex: 2a + 3b + 4c + 5
9/// The invariant is that multipliers are non-zero
10#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
11pub struct Lin<Id>(Ctx<Id, u8>, u8);
12
13/// Represents a type-level power of two (2^Lin)
14#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
15pub struct Bin<Id> { pub exp: Lin<Id> }
16
17/// Identity element of addition is 0
18impl<T : Ord> Default for Lin<T> {
19    fn default() -> Self {
20        Lin::lit(0)
21    }
22}
23
24/// Identity element of multiplication is 2^0
25impl<T: Ord> Default for Bin<T> {
26    fn default() -> Self {
27        Bin { exp: Lin::default() }
28    }
29}
30
31impl<T: Ord> Lin<T> {
32    pub fn new (terms: Ctx<T, u8>, v: u8) -> Self {
33        Lin(terms, v)
34    }
35    /// Create a linear constant
36    pub fn lit(a: u8) -> Self {
37        Lin(Ctx::new(), a)
38    }
39
40    /// Create a linear variable
41    pub fn var(v: T) -> Self {
42        Lin(Ctx::from([(v, 1)]), 0)
43    }
44
45    /// Create a linear variable with multiplier
46    pub fn term(v: T, a: u8) -> Self {
47        if a == 0 {
48            Lin::default()
49        } else {
50            Lin(Ctx::from([(v, a)]), 0)
51        }
52    }
53
54    /// Define a partial order for linear, positive expressions
55    /// true:  2*a + b <= 3*a + b + c
56    ///        {} <= a
57    /// false: 4*a <= 3*a + b + c
58    ///        b <= c
59    pub fn leq(&self, other: &Self) -> bool {
60        let mut le = true;
61        for (k, v) in self.0.iter() {
62            if let Some(vr) = other.0.get(&k) {
63                if v > vr {
64                    le = false;
65                }
66            } else {
67                le = false;
68            }
69        }
70        le && self.1 <= other.1
71    }
72}
73
74impl<T: Ord> AddAssign for Lin<T> {
75    /// Add two linear terms
76    fn add_assign(&mut self, other: Self) {
77        self.0.append_with(other.0.into_iter(), &|a, b| a + b);
78        self.1 += other.1;
79    }
80}
81
82impl<T: Ord + Clone> Add for Lin<T> {
83    type Output = Lin<T>;
84    /// Add two linear terms
85    fn add(self, other: Self) -> Self::Output {
86        let mut c = self.clone();
87        c += other;
88        c
89    }
90}
91
92impl<T: Ord + Clone> Add for &Lin<T> {
93    type Output = Lin<T>;
94    /// Add two linear terms
95    fn add(self, other: Self) -> Self::Output {
96        let mut c = self.clone();
97        c += other.clone();
98        c
99    }
100}
101
102impl<T: Ord + Clone> Sub for Lin<T> {
103    type Output = (Lin<T>, Lin<T>);
104    /// Subtract two linear terms (with remainder)
105    fn sub(self, other: Self) -> Self::Output {
106        let mut n: u8 = self.1;
107        let mut m: u8 = other.1;
108        if n < m { // 3 - 4 = (0 ,1)
109            m -= n;
110            n = 0;
111        } else {   // 4 - 1 = (3, 0)
112            n -= m;
113            m = 0;
114        }
115        let mut nvars = self.0.clone();
116        let mut mvars = other.0.clone();
117        for (k, mx) in mvars.iter_mut() {
118            if let Some(nx) = nvars.get_mut(k) {
119                if *nx < *mx {
120                    *mx -= *nx;
121                    *nx = 0;
122                } else {
123                    *nx -= *mx;
124                    *mx = 0;
125                }
126            }
127        }
128        nvars.retain(|_, v| *v > 0);
129        mvars.retain(|_, v| *v > 0);
130       (Lin(nvars, n), Lin(mvars, m))
131    }
132}
133
134impl<T: Ord + Clone> Sub for &Lin<T> {
135    type Output = (Lin<T>, Lin<T>);
136    /// Subtract two linear terms (with remainder)
137    fn sub(self, other: Self) -> Self::Output {
138        self.clone().sub(other.clone())
139    }
140}
141
142/// Remove all zero elements (0*a + c = c)
143impl<T: Ord + Clone> Normalizable for Lin<T> {
144    fn normalize(&mut self) {
145        self.0.retain(|_, v| *v > 0);
146    }
147}
148
149/// Specialize a linear term by substituting a variable
150impl<T: Ord + fmt::Display + Clone> Specializable<T, u8> for Lin<T> {
151    fn specialize(&mut self, id: &T, val: u8) {
152        if let Some(v) = self.0.remove(id) {
153            self.1 += v * val;
154        }
155    }
156
157    fn free_vars(&self) -> Set<&T> {
158        self.0.keys()
159    }
160}
161
162////////////////////////////////////////////////////////////////////////////////////////
163// Implementing Bin (2^Lin) operations
164////////////////////////////////////////////////////////////////////////////////////////
165impl<T: Ord> Bin<T> {
166    pub fn lit(a: u8) -> Self {
167        Bin { exp: Lin::lit(a) }
168    }
169    pub fn var(v: T) -> Self {
170        Bin{ exp: Lin::var(v) }
171    }
172    pub fn double(self) -> Self where T: Clone {
173        Bin { exp: self.exp + Lin::lit(1) }
174    }
175    /// Halving could fail, ex: 2^a / 2 = None
176    pub fn half(self) -> Option<Self> {
177        if self.exp.1 > 0 {
178            Some(Bin { exp: Lin(self.exp.0, self.exp.1 - 1) })
179        } else {
180            None
181        }
182    }
183    /// Partial order extends to [Bin] as [2^-] is monotone
184    pub fn leq(&self, other: &Self) -> bool {
185        self.exp.leq(&other.exp)
186    }
187    /// Logarithm with remainder
188    /// ex: log2(9) = (3, 1)
189    ///     log(-129) = (7, -1)
190    pub fn log2(u: i32) -> (Bin<T>, i32) {
191        let mut exp = 0;
192        let mut um = u.abs();
193
194        while um % 2 == 0 && um > 0 {
195            exp += 1;
196            um /= 2;
197        }
198        (Bin { exp: Lin::lit(exp) }, if u > 0 { um } else { -um })
199    }
200
201    /// Least common multiple of two exponents of two
202    pub fn lcm(&self, other: &Self) -> Self where T: Clone {
203        Bin { exp: Lin(
204            self.exp.0.union_with(other.exp.0.clone(), &|a, b| std::cmp::max(a, b)),
205            std::cmp::max(self.exp.1, other.exp.1)
206        )}
207    }
208
209    /// Greatest common divisor of two exponents of two
210    pub fn gcd(&self, other: &Self) -> Self where T: Clone {
211        Bin { exp : Lin(
212            self.exp.0.intersection_with(other.exp.0.clone(), &|a, b| std::cmp::min(a, b)),
213            std::cmp::min(self.exp.1, other.exp.1)
214        )}
215    }
216}
217
218/// Multiplication of powers of two is equivalent to adding the exponents
219impl<T: Ord> MulAssign for Bin<T> {
220    fn mul_assign(&mut self, other: Self) {
221        self.exp += other.exp;
222    }
223}
224
225impl<T: Ord + Clone> Mul for Bin<T> {
226    type Output = Bin<T>;
227    fn mul(self, a: Self) -> Self::Output {
228        Bin { exp: self.exp + a.exp }
229    }
230}
231
232impl<T: Ord + Clone> Mul for &Bin<T> {
233    type Output = Bin<T>;
234    fn mul(self, a: Self) -> Self::Output {
235        self.clone() * a.clone()
236    }
237}
238
239/// Division of powers of two is equivalent to subtracting the exponents
240impl<T: Ord + Clone> Div for Bin<T> {
241    type Output = (Bin<T>, Bin<T>);
242
243    fn div(self, a: Self) -> Self::Output {
244        let (q, r) = self.exp - a.exp;
245        (Bin { exp: q }, Bin { exp: r })
246    }
247}
248
249/// Division of powers of two is equivalent to subtracting the exponents
250impl<T: Ord + Clone> Div for &Bin<T> {
251    type Output = (Bin<T>, Bin<T>);
252
253    fn div(self, a: Self) -> Self::Output {
254        self.clone() / a.clone()
255    }
256}
257
258/// Specialize a bin power by substituting a variable with a literal
259impl<T: Ord + fmt::Display + Clone> Specializable<T, u8> for Bin<T> {
260    fn specialize(&mut self, id: &T, val: u8) {
261        self.exp.specialize(id, val)
262    }
263    fn free_vars(&self) -> Set<&T> {
264        self.exp.0.keys()
265    }
266}
267
268/// Remove all zero elements (0*a + c = c)
269impl<T: Ord + Clone> Normalizable for Bin<T> {
270    fn normalize(&mut self) {
271        self.exp.normalize();
272    }
273}
274////////////////////////////////////////////////////////////////////////////////////////
275// Pretty Formatting, Display & Arbitrary for Lin and  Bin
276////////////////////////////////////////////////////////////////////////////////////////
277impl<'a, D, A, T> Pretty<'a, D, A> for Lin<T>
278where
279    D: DocAllocator<'a, A>,
280    D::Doc: Clone,
281    A: 'a + Clone,
282    T: Pretty<'a, D, A> + Clone + Ord
283{
284    fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
285        if self.0.is_empty() {
286            allocator.text(format!("{}", self.1))
287        } else {
288            allocator.intersperse(
289                self.0.into_iter()
290                    .map(|(k, v)|
291                        if v == 0 {
292                            allocator.nil()
293                        } else if v == 1 {
294                            k.pretty(allocator)
295                        } else {
296                            allocator.text(v.to_string()).append(k.pretty(allocator))
297                        }), "+")
298                .append(
299            if self.1 == 0 {
300                    allocator.nil()
301                  } else {
302                    allocator.text(format!("+{}", self.1))
303                  })
304        }
305    }
306}
307
308/// Display instance calls the pretty printer
309impl<'a, T> fmt::Display for Lin<T>
310where
311    T: Pretty<'a, BoxAllocator, ()> + Clone + Ord
312{
313    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314        <Lin<T> as Pretty<'_, BoxAllocator, ()>>::pretty(self.clone(), &BoxAllocator)
315            .1
316            .render_fmt(100, f)
317    }
318}
319
320/// Arbitrary instance for Lin
321#[cfg(test)] use arbitrary::{Arbitrary, Unstructured};
#[cfg(test)]
impl<'a, T: Ord + Clone + Arbitrary<'a>> Arbitrary<'a> for Lin<T> {
    /// Generate arbitrary terms and a constant in `0..=9`, then normalize so
    /// the generated value carries no zero multipliers.
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        let terms = Ctx::arbitrary(u)?;
        let constant = u.int_in_range(0..=9)?;
        let mut generated = Lin(terms, constant);
        generated.normalize();
        Ok(generated)
    }
}
330
331impl<'a, D, A, T> Pretty<'a, D, A> for Bin<T>
332where
333    D: DocAllocator<'a, A>,
334    D::Doc: Clone,
335    A: 'a + Clone,
336    T: Pretty<'a, D, A> + Clone + Ord
337{
338    fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
339        allocator.text("2^(")
340            .append(self.exp.pretty(allocator))
341            .append(allocator.text(")"))
342    }
343}
344
345/// Display instance calls the pretty printer
346impl<'a, T> fmt::Display for Bin<T>
347where
348    T: Pretty<'a, BoxAllocator, ()> + Clone + Ord
349{
350    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351        <Bin<T> as Pretty<'_, BoxAllocator, ()>>::pretty(self.clone(), &BoxAllocator)
352            .1
353            .render_fmt(100, f)
354    }
355}
356
/// An arbitrary `Bin` is a power of two over an arbitrary `Lin` exponent
#[cfg(test)]
impl<'a, T: Ord + Clone + Arbitrary<'a>> Arbitrary<'a> for Bin<T> {
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        Lin::arbitrary(u).map(|exp| Bin { exp })
    }
}
364
365////////////////////////////////////////////////////////////////////////////////////////
366// Unit Tests for Lin
367////////////////////////////////////////////////////////////////////////////////////////
#[test]
fn test_lin_add() {
    // 1 + 2 + x = x + 3: constants fold, the variable is carried through
    let lhs = Lin::lit(1) + Lin::lit(2) + Lin::var("x");
    let rhs = Lin::var("x") + Lin::lit(3);
    assert_eq!(lhs, rhs);
}
375
#[test]
fn test_lin_sub() {
    // (5 + x) - (2 + y + x) leaves 3, with y as the part that could not be removed
    let minuend = Lin::lit(3) + Lin::lit(2) + Lin::var("x");
    let subtrahend = Lin::lit(2) + Lin::var("y") + Lin::var("x");
    let expected = (Lin::lit(3), Lin::var("y"));
    assert_eq!(minuend - subtrahend, expected);
}
382
#[test]
fn test_leq_lin() {
    // Common right-hand side: 2a + b + 4
    let upper = Lin::term("a", 2) + Lin::var("b") + Lin::lit(4);

    // a + 2 is bounded term by term
    let below = Lin::lit(2) + Lin::var("a");
    assert!(below.leq(&upper));

    // c does not occur on the right-hand side
    let foreign_var = Lin::lit(2) + Lin::var("c");
    assert!(!foreign_var.leq(&upper));

    // the multiplier of a (3) exceeds the one on the right (2)
    let too_many_a = Lin::term("a", 3) + Lin::var("b");
    assert!(!too_many_a.leq(&upper));
}
407
#[test]
fn test_lin_specialize() {
    let base = Lin::var("x") + Lin::var("y") + Lin::lit(1);
    // Substitute 2 for the given variable in a fresh copy of `base`
    let substitute = |name: &'static str| {
        let mut copy = base.clone();
        copy.specialize(&name, 2);
        copy
    };

    assert_eq!(substitute("x"), Lin::var("y") + Lin::lit(3));
    assert_eq!(substitute("y"), Lin::var("x") + Lin::lit(3));
    // A variable that does not occur leaves the expression untouched
    assert_eq!(base, substitute("z"));
}
424
425////////////////////////////////////////////////////////////////////////////////////////
426// Unit Tests for Bin
427////////////////////////////////////////////////////////////////////////////////////////
#[test]
fn test_bin_mul() {
    // 2^1 * 2^2 * 2^x = 2^(3 + x)
    let product = Bin::lit(1) * Bin::lit(2) * Bin::var("x");
    let expected = Bin::var("x") * Bin::lit(3);
    assert_eq!(product, expected);
}
436
#[test]
fn test_bin_div() {
    // 2^(5 + x) / 2^(2 + y + x) = 2^3 with 2^y left over
    let numerator = Bin::lit(3) * Bin::lit(2) * Bin::var("x");
    let denominator = Bin::lit(2) * Bin::var("y") * Bin::var("x");
    let expected = (Bin::lit(3), Bin::var("y"));
    assert_eq!(numerator / denominator, expected);
}
443
#[test]
fn test_bin_lcm() {
    // lcm(2^(5 + x), 2^(2 + y + x)) = 2^(5 + x + y)
    let left = Bin::lit(3) * Bin::lit(2) * Bin::var("x");
    let right = Bin::lit(2) * Bin::var("y") * Bin::var("x");
    let expected = Bin::lit(3) * Bin::lit(2) * Bin::var("x") * Bin::var("y");
    assert_eq!(left.lcm(&right), expected);
}
450
#[test]
fn test_bin_log2() {
    // (input, expected exponent, expected signed odd remainder)
    let cases = [(12, 2, 3), (-96, 5, -3)];
    for &(input, exponent, remainder) in cases.iter() {
        assert_eq!(Bin::<&str>::log2(input), (Bin::lit(exponent), remainder));
    }
}
456
#[test]
fn test_bin_specialize() {
    let base = Bin::var("x") * Bin::var("y") * Bin::lit(1);
    // Substitute 2 for the given variable in a fresh copy of `base`
    let substitute = |name: &'static str| {
        let mut copy = base.clone();
        copy.specialize(&name, 2);
        copy
    };

    assert_eq!(substitute("x"), Bin::var("y") * Bin::lit(3));
    assert_eq!(substitute("y"), Bin::var("x") * Bin::lit(3));
    // A variable that does not occur leaves the value untouched
    assert_eq!(base, substitute("z"));
}
476
477#[cfg(test)] use arbtest::arbtest;
478#[cfg(test)] use crate::id::Id;
479#[cfg(test)] use crate::assert_eqn;
480
#[test]
fn test_lin_add_prop() {
    // Associativity: x + (y + z) = (x + y) + z
    arbtest(|u| {
        let x = u.arbitrary::<Lin<Id>>()?;
        let y = u.arbitrary::<Lin<Id>>()?;
        let z = u.arbitrary::<Lin<Id>>()?;
        let right_nested = &x + &(&y + &z);
        let left_nested = &(&x + &y) + &z;
        assert_eq!(right_nested, left_nested);
        Ok(())
    });

    // Commutativity: x + y = y + x
    arbtest(|u| {
        let x = u.arbitrary::<Lin<Id>>()?;
        let y = u.arbitrary::<Lin<Id>>()?;
        assert_eq!(&x + &y, &y + &x);
        Ok(())
    });

    // Unit: adding zero on either side changes nothing
    arbtest(|u| {
        let x = u.arbitrary::<Lin<Id>>()?;
        let zero = Lin::default();
        assert_eq!(&x + &zero, x);
        assert_eq!(&zero + &x, x);
        Ok(())
    });
}
508
#[test]
fn test_lin_sub_prop() {
    // Cancellativity: x - x = (0, 0)
    arbtest(|u| {
        let x = u.arbitrary::<Lin<Id>>()?;
        assert_eq!(&x - &x, (Lin::default(), Lin::default()));
        Ok(())
    });
    // Subtraction is the inverse of addition: (x + y) - x = (y, 0)
    arbtest(|u| {
        let x = u.arbitrary::<Lin<Id>>()?;
        let y = u.arbitrary::<Lin<Id>>()?;
        let sum = &x + &y;
        assert_eq!(sum - x, (y, Lin::default()));
        Ok(())
    });
    // Unit with subtraction: x - 0 = (x, 0) and 0 - x = (0, x)
    arbtest(|u| {
        let x = u.arbitrary::<Lin<Id>>()?;
        let zero = Lin::<Id>::default();
        assert_eq!(&x - &zero, (x.clone(), zero.clone()));
        assert_eq!(&zero - &x, (zero.clone(), x));
        Ok(())
    });
}
532
#[test]
fn test_lin_leq_prop() {
    // Reflexivity: x <= x
    arbtest(|u| {
        let x = u.arbitrary::<Lin<Id>>()?;
        assert!(x.leq(&x));
        Ok(())
    });
    // Leq and addition: each summand is below the sum
    arbtest(|u| {
        let x = u.arbitrary::<Lin<Id>>()?;
        let y = u.arbitrary::<Lin<Id>>()?;
        let sum = &x + &y;
        assert!(x.leq(&sum));
        assert!(y.leq(&sum));
        Ok(())
    });
}
551
#[test]
fn test_bin_mul_prop() {
    // Commutativity: x * y = y * x
    arbtest(|u| {
        let x = u.arbitrary::<Bin<Id>>()?;
        let y = u.arbitrary::<Bin<Id>>()?;
        assert_eqn!(&x * &y, &y * &x);
        Ok(())
    });
    // Associativity: x * (y * z) = (x * y) * z
    arbtest(|u| {
        let x = u.arbitrary::<Bin<Id>>()?;
        let y = u.arbitrary::<Bin<Id>>()?;
        let z = u.arbitrary::<Bin<Id>>()?;
        assert_eqn!(&x * &(&y * &z), &(&x * &y) * &z);
        Ok(())
    });
    // Units: multiplying by 2^0 on the left and on the right agree
    arbtest(|u| {
        let x = u.arbitrary::<Bin<Id>>()?;
        assert_eqn!(&x * &Bin::default(), &Bin::default() * &x);
        Ok(())
    });

    // Double and half: (2x) / x = 2 exactly, and halving undoes doubling
    arbtest(|u| {
        let x = u.arbitrary::<Bin<Id>>()?;
        let doubled = x.clone().double();
        assert_eq!(&doubled / &x, (Bin::lit(1), Bin::default()));
        assert_eq!(doubled.half(), Some(x));
        Ok(())
    });
}
584
#[test]
fn test_bin_div_prop() {
    // Cancellativity: x / x = (1, 1)
    arbtest(|u| {
        let x = u.arbitrary::<Bin<Id>>()?;
        assert_eq!(&x / &x, (Bin::default(), Bin::default()));
        Ok(())
    });
    // Unit and Division: 1 / x = (1, x) and x / 1 = (x, 1)
    arbtest(|u| {
        let x = u.arbitrary::<Bin<Id>>()?;
        let one = Bin::<Id>::default();
        assert_eq!(&one / &x, (one.clone(), x.clone()));
        assert_eq!(&x / &one, (x, one.clone()));
        Ok(())
    });
    // Least-common multiple divides evenly: the remainder is always 1
    arbtest(|u| {
        let x = u.arbitrary::<Bin<Id>>()?;
        let y = u.arbitrary::<Bin<Id>>()?;
        assert_eqn!((&(x.lcm(&y)) / &x).1, Bin::<Id>::default());
        assert_eqn!((&(y.lcm(&x)) / &y).1, Bin::<Id>::default());
        Ok(())
    });
}
609
#[test]
fn test_bin_leq_prop() {
    // Reflexivity: x <= x
    arbtest(|u| {
        let x = u.arbitrary::<Bin<Id>>()?;
        assert!(x.leq(&x));
        Ok(())
    });
    // Terms less than their product
    arbtest(|u| {
        let x = u.arbitrary::<Bin<Id>>()?;
        let y = u.arbitrary::<Bin<Id>>()?;
        let product = &x * &y;
        assert!(x.leq(&product));
        assert!(y.leq(&product));
        Ok(())
    });
    // Div less than terms: quotient <= dividend and remainder <= divisor
    arbtest(|u| {
        let x = u.arbitrary::<Bin<Id>>()?;
        let y = u.arbitrary::<Bin<Id>>()?;
        let (quotient, remainder) = &x / &y;
        assert!(quotient.leq(&x));
        assert!(remainder.leq(&y));
        Ok(())
    });
}
635}