1use std::fmt;
2use std::ops::{Add, AddAssign, Sub, Mul, MulAssign, Div};
3use crate::context::{Set, Ctx};
4use crate::traits::{Specializable, Normalizable};
5use pretty::{DocAllocator, DocBuilder, BoxAllocator, Pretty};
6
7#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
11pub struct Lin<Id>(Ctx<Id, u8>, u8);
12
13#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
15pub struct Bin<Id> { pub exp: Lin<Id> }
16
17impl<T : Ord> Default for Lin<T> {
19 fn default() -> Self {
20 Lin::lit(0)
21 }
22}
23
24impl<T: Ord> Default for Bin<T> {
26 fn default() -> Self {
27 Bin { exp: Lin::default() }
28 }
29}
30
31impl<T: Ord> Lin<T> {
32 pub fn new (terms: Ctx<T, u8>, v: u8) -> Self {
33 Lin(terms, v)
34 }
35 pub fn lit(a: u8) -> Self {
37 Lin(Ctx::new(), a)
38 }
39
40 pub fn var(v: T) -> Self {
42 Lin(Ctx::from([(v, 1)]), 0)
43 }
44
45 pub fn term(v: T, a: u8) -> Self {
47 if a == 0 {
48 Lin::default()
49 } else {
50 Lin(Ctx::from([(v, a)]), 0)
51 }
52 }
53
54 pub fn leq(&self, other: &Self) -> bool {
60 let mut le = true;
61 for (k, v) in self.0.iter() {
62 if let Some(vr) = other.0.get(&k) {
63 if v > vr {
64 le = false;
65 }
66 } else {
67 le = false;
68 }
69 }
70 le && self.1 <= other.1
71 }
72}
73
74impl<T: Ord> AddAssign for Lin<T> {
75 fn add_assign(&mut self, other: Self) {
77 self.0.append_with(other.0.into_iter(), &|a, b| a + b);
78 self.1 += other.1;
79 }
80}
81
82impl<T: Ord + Clone> Add for Lin<T> {
83 type Output = Lin<T>;
84 fn add(self, other: Self) -> Self::Output {
86 let mut c = self.clone();
87 c += other;
88 c
89 }
90}
91
92impl<T: Ord + Clone> Add for &Lin<T> {
93 type Output = Lin<T>;
94 fn add(self, other: Self) -> Self::Output {
96 let mut c = self.clone();
97 c += other.clone();
98 c
99 }
100}
101
102impl<T: Ord + Clone> Sub for Lin<T> {
103 type Output = (Lin<T>, Lin<T>);
104 fn sub(self, other: Self) -> Self::Output {
106 let mut n: u8 = self.1;
107 let mut m: u8 = other.1;
108 if n < m { m -= n;
110 n = 0;
111 } else { n -= m;
113 m = 0;
114 }
115 let mut nvars = self.0.clone();
116 let mut mvars = other.0.clone();
117 for (k, mx) in mvars.iter_mut() {
118 if let Some(nx) = nvars.get_mut(k) {
119 if *nx < *mx {
120 *mx -= *nx;
121 *nx = 0;
122 } else {
123 *nx -= *mx;
124 *mx = 0;
125 }
126 }
127 }
128 nvars.retain(|_, v| *v > 0);
129 mvars.retain(|_, v| *v > 0);
130 (Lin(nvars, n), Lin(mvars, m))
131 }
132}
133
134impl<T: Ord + Clone> Sub for &Lin<T> {
135 type Output = (Lin<T>, Lin<T>);
136 fn sub(self, other: Self) -> Self::Output {
138 self.clone().sub(other.clone())
139 }
140}
141
142impl<T: Ord + Clone> Normalizable for Lin<T> {
144 fn normalize(&mut self) {
145 self.0.retain(|_, v| *v > 0);
146 }
147}
148
149impl<T: Ord + fmt::Display + Clone> Specializable<T, u8> for Lin<T> {
151 fn specialize(&mut self, id: &T, val: u8) {
152 if let Some(v) = self.0.remove(id) {
153 self.1 += v * val;
154 }
155 }
156
157 fn free_vars(&self) -> Set<&T> {
158 self.0.keys()
159 }
160}
161
162impl<T: Ord> Bin<T> {
166 pub fn lit(a: u8) -> Self {
167 Bin { exp: Lin::lit(a) }
168 }
169 pub fn var(v: T) -> Self {
170 Bin{ exp: Lin::var(v) }
171 }
172 pub fn double(self) -> Self where T: Clone {
173 Bin { exp: self.exp + Lin::lit(1) }
174 }
175 pub fn half(self) -> Option<Self> {
177 if self.exp.1 > 0 {
178 Some(Bin { exp: Lin(self.exp.0, self.exp.1 - 1) })
179 } else {
180 None
181 }
182 }
183 pub fn leq(&self, other: &Self) -> bool {
185 self.exp.leq(&other.exp)
186 }
187 pub fn log2(u: i32) -> (Bin<T>, i32) {
191 let mut exp = 0;
192 let mut um = u.abs();
193
194 while um % 2 == 0 && um > 0 {
195 exp += 1;
196 um /= 2;
197 }
198 (Bin { exp: Lin::lit(exp) }, if u > 0 { um } else { -um })
199 }
200
201 pub fn lcm(&self, other: &Self) -> Self where T: Clone {
203 Bin { exp: Lin(
204 self.exp.0.union_with(other.exp.0.clone(), &|a, b| std::cmp::max(a, b)),
205 std::cmp::max(self.exp.1, other.exp.1)
206 )}
207 }
208
209 pub fn gcd(&self, other: &Self) -> Self where T: Clone {
211 Bin { exp : Lin(
212 self.exp.0.intersection_with(other.exp.0.clone(), &|a, b| std::cmp::min(a, b)),
213 std::cmp::min(self.exp.1, other.exp.1)
214 )}
215 }
216}
217
218impl<T: Ord> MulAssign for Bin<T> {
220 fn mul_assign(&mut self, other: Self) {
221 self.exp += other.exp;
222 }
223}
224
225impl<T: Ord + Clone> Mul for Bin<T> {
226 type Output = Bin<T>;
227 fn mul(self, a: Self) -> Self::Output {
228 Bin { exp: self.exp + a.exp }
229 }
230}
231
232impl<T: Ord + Clone> Mul for &Bin<T> {
233 type Output = Bin<T>;
234 fn mul(self, a: Self) -> Self::Output {
235 self.clone() * a.clone()
236 }
237}
238
239impl<T: Ord + Clone> Div for Bin<T> {
241 type Output = (Bin<T>, Bin<T>);
242
243 fn div(self, a: Self) -> Self::Output {
244 let (q, r) = self.exp - a.exp;
245 (Bin { exp: q }, Bin { exp: r })
246 }
247}
248
249impl<T: Ord + Clone> Div for &Bin<T> {
251 type Output = (Bin<T>, Bin<T>);
252
253 fn div(self, a: Self) -> Self::Output {
254 self.clone() / a.clone()
255 }
256}
257
258impl<T: Ord + fmt::Display + Clone> Specializable<T, u8> for Bin<T> {
260 fn specialize(&mut self, id: &T, val: u8) {
261 self.exp.specialize(id, val)
262 }
263 fn free_vars(&self) -> Set<&T> {
264 self.exp.0.keys()
265 }
266}
267
268impl<T: Ord + Clone> Normalizable for Bin<T> {
270 fn normalize(&mut self) {
271 self.exp.normalize();
272 }
273}
274impl<'a, D, A, T> Pretty<'a, D, A> for Lin<T>
278where
279 D: DocAllocator<'a, A>,
280 D::Doc: Clone,
281 A: 'a + Clone,
282 T: Pretty<'a, D, A> + Clone + Ord
283{
284 fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
285 if self.0.is_empty() {
286 allocator.text(format!("{}", self.1))
287 } else {
288 allocator.intersperse(
289 self.0.into_iter()
290 .map(|(k, v)|
291 if v == 0 {
292 allocator.nil()
293 } else if v == 1 {
294 k.pretty(allocator)
295 } else {
296 allocator.text(v.to_string()).append(k.pretty(allocator))
297 }), "+")
298 .append(
299 if self.1 == 0 {
300 allocator.nil()
301 } else {
302 allocator.text(format!("+{}", self.1))
303 })
304 }
305 }
306}
307
308impl<'a, T> fmt::Display for Lin<T>
310where
311 T: Pretty<'a, BoxAllocator, ()> + Clone + Ord
312{
313 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314 <Lin<T> as Pretty<'_, BoxAllocator, ()>>::pretty(self.clone(), &BoxAllocator)
315 .1
316 .render_fmt(100, f)
317 }
318}
319
320#[cfg(test)] use arbitrary::{Arbitrary, Unstructured};
#[cfg(test)]
impl<'a, T: Ord + Clone + Arbitrary<'a>> Arbitrary<'a> for Lin<T> {
    /// Draws a random term map and a constant in `0..=9`, then normalizes so
    /// that no zero-coefficient entries are produced.
    ///
    /// NOTE(review): the coefficient range comes from `Ctx`'s own `Arbitrary`
    /// impl, which is not visible here. If it can yield large `u8`s, the
    /// additive property tests may overflow; confirm.
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        // Draw order matters: it fixes which input bytes feed which part.
        let terms: Ctx<T, u8> = Ctx::arbitrary(u)?;
        let constant: u8 = u.int_in_range(0..=9)?;
        let mut lin = Lin(terms, constant);
        lin.normalize();
        Ok(lin)
    }
}
330
331impl<'a, D, A, T> Pretty<'a, D, A> for Bin<T>
332where
333 D: DocAllocator<'a, A>,
334 D::Doc: Clone,
335 A: 'a + Clone,
336 T: Pretty<'a, D, A> + Clone + Ord
337{
338 fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
339 allocator.text("2^(")
340 .append(self.exp.pretty(allocator))
341 .append(allocator.text(")"))
342 }
343}
344
345impl<'a, T> fmt::Display for Bin<T>
347where
348 T: Pretty<'a, BoxAllocator, ()> + Clone + Ord
349{
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 <Bin<T> as Pretty<'_, BoxAllocator, ()>>::pretty(self.clone(), &BoxAllocator)
352 .1
353 .render_fmt(100, f)
354 }
355}
356
#[cfg(test)]
impl<'a, T: Ord + Clone + Arbitrary<'a>> Arbitrary<'a> for Bin<T> {
    /// A random power of two whose exponent is an arbitrary (normalized) `Lin`.
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        let exp: Lin<T> = Lin::arbitrary(u)?;
        Ok(Bin { exp })
    }
}
364
#[test]
fn test_lin_add() {
    // Constants fold together; the variable term is carried along.
    let lhs = Lin::lit(1) + Lin::lit(2) + Lin::var("x");
    let rhs = Lin::var("x") + Lin::lit(3);
    assert_eq!(lhs, rhs);
}
375
#[test]
fn test_lin_sub() {
    // (5 + x) - (2 + y + x): x and 2 cancel, leaving 3 on the left and y on the right.
    let lhs = Lin::lit(3) + Lin::lit(2) + Lin::var("x");
    let rhs = Lin::lit(2) + Lin::var("y") + Lin::var("x");
    assert_eq!(lhs - rhs, (Lin::lit(3), Lin::var("y")));
}
382
#[test]
fn test_leq_lin() {
    // Shared right-hand side: 2a + b + 4.
    let upper = Lin::term("a", 2) + Lin::var("b") + Lin::lit(4);

    // Constant 2 <= 4 and coefficient of a 1 <= 2.
    assert!(Lin::leq(&(Lin::lit(2) + Lin::var("a")), &upper));
    // `c` does not occur on the right-hand side.
    assert!(!Lin::leq(&(Lin::lit(2) + Lin::var("c")), &upper));
    // Coefficient of `a` is 3 > 2.
    assert!(!Lin::leq(&(Lin::term("a", 3) + Lin::var("b")), &upper));
}
407
#[test]
fn test_lin_specialize() {
    // x + y + 1
    let base = Lin::var("x") + Lin::var("y") + Lin::lit(1);

    // Substituting 2 folds into the constant: 1 + 2 = 3.
    let mut x_fixed = base.clone();
    x_fixed.specialize(&"x", 2);
    assert_eq!(x_fixed, Lin::var("y") + Lin::lit(3));

    let mut y_fixed = base.clone();
    y_fixed.specialize(&"y", 2);
    assert_eq!(y_fixed, Lin::var("x") + Lin::lit(3));

    // A variable that does not occur leaves the expression unchanged.
    let mut unchanged = base.clone();
    unchanged.specialize(&"z", 2);
    assert_eq!(base, unchanged);
}
424
#[test]
fn test_bin_mul() {
    // 2^1 * 2^2 * 2^x == 2^x * 2^3
    let product = Bin::lit(1) * Bin::lit(2) * Bin::var("x");
    assert_eq!(product, Bin::var("x") * Bin::lit(3));
}
436
#[test]
fn test_bin_div() {
    // 2^(5 + x) / 2^(2 + y + x) reduces to 2^3 / 2^y.
    let num = Bin::lit(3) * Bin::lit(2) * Bin::var("x");
    let den = Bin::lit(2) * Bin::var("y") * Bin::var("x");
    let expected = (Bin::lit(3), Bin::var("y"));
    assert_eq!(num / den, expected);
}
443
#[test]
fn test_bin_lcm() {
    let lhs = Bin::lit(3) * Bin::lit(2) * Bin::var("x");
    let rhs = Bin::lit(2) * Bin::var("y") * Bin::var("x");
    // Exponent-wise maximum: constant 5, x once, y once.
    let expected = Bin::lit(3) * Bin::lit(2) * Bin::var("x") * Bin::var("y");
    assert_eq!(lhs.lcm(&rhs), expected);
}
450
#[test]
fn test_bin_log2() {
    // 12 = 2^2 * 3 and -96 = 2^5 * -3.
    for (input, power, odd) in [(12, 2, 3), (-96, 5, -3)] {
        assert_eq!(Bin::<&str>::log2(input), (Bin::lit(power), odd));
    }
}
456
#[test]
fn test_bin_specialize() {
    // 2^(x + y + 1)
    let base = Bin::var("x") * Bin::var("y") * Bin::lit(1);

    // Substituting 2 moves into the exponent's constant: 2^(other + 3).
    let mut x_fixed = base.clone();
    x_fixed.specialize(&"x", 2);
    assert_eq!(x_fixed, Bin::var("y") * Bin::lit(3));

    let mut y_fixed = base.clone();
    y_fixed.specialize(&"y", 2);
    assert_eq!(y_fixed, Bin::var("x") * Bin::lit(3));

    // A variable that does not occur leaves the value unchanged.
    let mut unchanged = base.clone();
    unchanged.specialize(&"z", 2);
    assert_eq!(base, unchanged);
}
476
477#[cfg(test)] use arbtest::arbtest;
478#[cfg(test)] use crate::id::Id;
479#[cfg(test)] use crate::assert_eqn;
480
#[test]
fn test_lin_add_prop() {
    // Associativity: a + (b + c) == (a + b) + c.
    arbtest(|u| {
        let a: Lin<Id> = u.arbitrary()?;
        let b: Lin<Id> = u.arbitrary()?;
        let c: Lin<Id> = u.arbitrary()?;
        let grouped_right = &a + &(&b + &c);
        let grouped_left = &(&a + &b) + &c;
        assert_eq!(grouped_right, grouped_left);
        Ok(())
    });

    // Commutativity: a + b == b + a.
    arbtest(|u| {
        let a: Lin<Id> = u.arbitrary()?;
        let b: Lin<Id> = u.arbitrary()?;
        assert_eq!(&a + &b, &b + &a);
        Ok(())
    });

    // The default (zero) expression is neutral on both sides.
    arbtest(|u| {
        let a: Lin<Id> = u.arbitrary()?;
        assert_eq!(&a + &Lin::default(), a);
        assert_eq!(&Lin::default() + &a, a);
        Ok(())
    });
}
508
#[test]
fn test_lin_sub_prop() {
    // a - a cancels completely.
    arbtest(|u| {
        let a: Lin<Id> = u.arbitrary()?;
        assert_eq!(&a - &a, (Lin::default(), Lin::default()));
        Ok(())
    });
    // (a + b) - a leaves exactly b on the left.
    arbtest(|u| {
        let a: Lin<Id> = u.arbitrary()?;
        let b: Lin<Id> = u.arbitrary()?;
        let sum = &a + &b;
        assert_eq!(sum - a, (b, Lin::default()));
        Ok(())
    });
    // Subtracting zero, and subtracting from zero.
    arbtest(|u| {
        let a: Lin<Id> = u.arbitrary()?;
        let zero: Lin<Id> = Lin::default();
        assert_eq!(&a - &zero, (a.clone(), Lin::default()));
        assert_eq!(&zero - &a, (Lin::default(), a));
        Ok(())
    });
}
532
#[test]
fn test_lin_leq_prop() {
    // Reflexivity.
    arbtest(|u| {
        let a: Lin<Id> = u.arbitrary()?;
        assert!(a.leq(&a));
        Ok(())
    });
    // Each summand is bounded by the sum.
    arbtest(|u| {
        let a: Lin<Id> = u.arbitrary()?;
        let b: Lin<Id> = u.arbitrary()?;
        let sum = &a + &b;
        assert!(a.leq(&sum));
        assert!(b.leq(&sum));
        Ok(())
    });
}
551
#[test]
fn test_bin_mul_prop() {
    // Commutativity.
    arbtest(|u| {
        let a = u.arbitrary::<Bin<Id>>()?;
        let b = u.arbitrary::<Bin<Id>>()?;
        assert_eqn!(&a * &b, &b * &a);
        Ok(())
    });
    // Associativity.
    arbtest(|u| {
        let a = u.arbitrary::<Bin<Id>>()?;
        let b = u.arbitrary::<Bin<Id>>()?;
        let c = u.arbitrary::<Bin<Id>>()?;
        assert_eqn!(&a * &(&b * &c), &(&a * &b) * &c);
        Ok(())
    });
    // Identity: `Bin::default()` (2^0) is neutral on both sides. Compare
    // against `a` itself; `a * 1 == 1 * a` alone is implied by commutativity
    // and would not catch a broken identity.
    arbtest(|u| {
        let a = u.arbitrary::<Bin<Id>>()?;
        assert_eqn!(&a * &Bin::default(), a.clone());
        assert_eqn!(&Bin::default() * &a, a.clone());
        Ok(())
    });

    // `double` then dividing by `a` leaves exactly 2^1; `half` undoes `double`.
    arbtest(|u| {
        let a = u.arbitrary::<Bin<Id>>()?;
        assert_eq!(&a.clone().double() / &a, (Bin::lit(1), Bin::default()));
        assert_eq!(&a.clone().double().half(), &Some(a));
        Ok(())
    });
}
584
#[test]
fn test_bin_div_prop() {
    // a / a cancels to (1, 1).
    arbtest(|u| {
        let a: Bin<Id> = u.arbitrary()?;
        assert_eq!(&a / &a, (Bin::default(), Bin::default()));
        Ok(())
    });
    // Dividing one by a, and a by one.
    arbtest(|u| {
        let a: Bin<Id> = u.arbitrary()?;
        let one: Bin<Id> = Bin::default();
        assert_eq!(&one / &a, (Bin::default(), a.clone()));
        assert_eq!(&a / &one, (a, Bin::default()));
        Ok(())
    });
    // Both arguments divide their lcm: no leftover divisor.
    arbtest(|u| {
        let a: Bin<Id> = u.arbitrary()?;
        let b: Bin<Id> = u.arbitrary()?;
        let (_, leftover_a) = &a.lcm(&b) / &a;
        assert_eqn!(leftover_a, Bin::<Id>::default());
        let (_, leftover_b) = &b.lcm(&a) / &b;
        assert_eqn!(leftover_b, Bin::<Id>::default());
        Ok(())
    });
}
609
#[test]
fn test_bin_leq_prop() {
    // Reflexivity.
    arbtest(|u| {
        let a: Bin<Id> = u.arbitrary()?;
        assert!(a.leq(&a));
        Ok(())
    });
    // Each factor is bounded by the product.
    arbtest(|u| {
        let a: Bin<Id> = u.arbitrary()?;
        let b: Bin<Id> = u.arbitrary()?;
        let product = &a * &b;
        assert!(a.leq(&product));
        assert!(b.leq(&product));
        Ok(())
    });
    // Cancelling division only ever shrinks both sides.
    arbtest(|u| {
        let a: Bin<Id> = u.arbitrary()?;
        let b: Bin<Id> = u.arbitrary()?;
        let (num, den) = &a / &b;
        assert!(num.leq(&a));
        assert!(den.leq(&b));
        Ok(())
    });
}