calyx_ir/
guard.rs

1use crate::Printer;
2
3use super::{NumAttr, Port, RRC};
4use calyx_utils::Error;
5use std::fmt::Debug;
6use std::mem;
7use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, Not};
8use std::{cmp::Ordering, hash::Hash, rc::Rc};
9
/// Marker type used as the `T` parameter of [Guard] (as in `Guard<Nothing>`)
/// when the guard carries no extra payload.
///
/// Its textual form is the empty string.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub struct Nothing;

/// Formats as the empty string.
///
/// Implemented through `Display` rather than a hand-written `ToString` so that
/// `to_string()` comes from the standard blanket impl and the type can also be
/// used directly in `format!`/`write!`.
impl std::fmt::Display for Nothing {
    fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Nothing to write: the rendered form is empty.
        Ok(())
    }
}
19
/// Comparison operations that can be performed between ports by [Guard::CompOp].
///
/// In the variant docs below, `p1` is the first port stored in the
/// `CompOp` and `p2` is the second.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub enum PortComp {
    /// p1 == p2
    Eq,
    /// p1 != p2
    Neq,
    /// p1 > p2
    Gt,
    /// p1 < p2
    Lt,
    /// p1 >= p2
    Geq,
    /// p1 <= p2
    Leq,
}
37
/// An assignment guard which has pointers to the various ports from which it reads.
///
/// The type parameter `T` is the payload of the [Guard::Info] variant; this
/// file instantiates it with `Nothing` and `StaticTiming`.
/// The default value is [Guard::True].
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub enum Guard<T> {
    /// Represents `c1 || c2`.
    Or(Box<Guard<T>>, Box<Guard<T>>),
    /// Represents `c1 && c2`.
    And(Box<Guard<T>>, Box<Guard<T>>),
    /// Represents `!c1`
    Not(Box<Guard<T>>),
    #[default]
    /// The constant true
    True,
    /// Comparison operator applied to two ports (operator, left port, right port).
    CompOp(PortComp, RRC<Port>, RRC<Port>),
    /// Uses the value on a port as the condition. Same as `p1 == true`
    Port(RRC<Port>),
    /// Other types of information, supplied by the type parameter `T`.
    Info(T),
}
58
/// A pair of `u64` bounds used as the payload of a `Guard<StaticTiming>`.
///
/// NOTE(review): the rendering below treats `(n, n + 1)` as the single point
/// `%n`, which suggests a half-open `[start, end)` interval — confirm against
/// the users of this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub struct StaticTiming {
    interval: (u64, u64),
}

/// Renders the interval as `%start` when `end == start + 1`, and as
/// `%[start:end]` otherwise.
///
/// Implemented through `Display` so that `to_string()` comes from the standard
/// blanket impl.
impl std::fmt::Display for StaticTiming {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (start, end) = self.interval;
        // `checked_add` keeps `start == u64::MAX` from overflowing; in that
        // case there is no successor, so the long form is printed.
        if start.checked_add(1) == Some(end) {
            write!(f, "%{}", start)
        } else {
            write!(f, "%[{}:{}]", start, end)
        }
    }
}

impl StaticTiming {
    /// Creates a new `StaticTiming` holding `interval`.
    pub fn new(interval: (u64, u64)) -> Self {
        StaticTiming { interval }
    }

    /// Returns the stored `(u64, u64)` interval.
    pub fn get_interval(&self) -> (u64, u64) {
        self.interval
    }

    /// Overwrites the stored interval with `new_interval`.
    pub fn set_interval(&mut self, new_interval: (u64, u64)) {
        self.interval = new_interval;
    }
}
91
92impl<T> Hash for Guard<T>
93where
94    T: ToString,
95{
96    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
97        match self {
98            Guard::Or(l, r) | Guard::And(l, r) => {
99                l.hash(state);
100                r.hash(state)
101            }
102            Guard::CompOp(_, l, r) => {
103                l.borrow().name.hash(state);
104                l.borrow().get_parent_name().hash(state);
105                r.borrow().name.hash(state);
106                r.borrow().get_parent_name().hash(state);
107            }
108            Guard::Not(inner) => inner.hash(state),
109            Guard::Port(p) => {
110                p.borrow().name.hash(state);
111                p.borrow().get_parent_name().hash(state);
112            }
113            Guard::True => {}
114            Guard::Info(i) => i.to_string().hash(state),
115        }
116    }
117}
118
119impl From<Guard<Nothing>> for Guard<StaticTiming> {
120    /// Turns a normal guard into a static guard
121    fn from(g: Guard<Nothing>) -> Self {
122        match g {
123            Guard::Or(left, right) => {
124                let l = Self::from(*left);
125                let r = Self::from(*right);
126                Guard::Or(Box::new(l), Box::new(r))
127            }
128            Guard::And(left, right) => {
129                let l = Self::from(*left);
130                let r = Self::from(*right);
131                Guard::And(Box::new(l), Box::new(r))
132            }
133            Guard::Not(c) => {
134                let inside = Self::from(*c);
135                Guard::Not(Box::new(inside))
136            }
137            Guard::True => Guard::True,
138            Guard::CompOp(pc, left, right) => Guard::CompOp(pc, left, right),
139            Guard::Port(p) => Guard::Port(p),
140            Guard::Info(_) => {
141                unreachable!(
142                    "{:?}: Guard<Nothing> should not be of the
143                info variant type",
144                    g
145                )
146            }
147        }
148    }
149}
150
151impl<T> Guard<T> {
152    /// Returns true definitely `Guard::True`.
153    /// Returning false does not mean that the guard is not true.
154    pub fn is_true(&self) -> bool {
155        match self {
156            Guard::True => true,
157            Guard::Port(p) => p.borrow().is_constant(1, 1),
158            _ => false,
159        }
160    }
161
162    /// Checks if the guard is always false.
163    /// Returning false does not mean that the guard is not false.
164    pub fn is_false(&self) -> bool {
165        match self {
166            Guard::Not(g) => g.is_true(),
167            _ => false,
168        }
169    }
170
171    /// returns true if the self is !cell_name, false otherwise.
172    pub fn is_not_done(&self, cell_name: &crate::Id) -> bool {
173        if let Guard::Not(g) = self {
174            if let Guard::Port(port) = &(**g) {
175                return port.borrow().attributes.has(NumAttr::Done)
176                    && port.borrow().get_parent_name() == cell_name;
177            }
178        }
179        false
180    }
181
182    /// Update the guard in place. Replaces this guard with `upd(self)`.
183    /// Uses `std::mem::take` for the in-place update.
184    #[inline(always)]
185    pub fn update<F>(&mut self, upd: F)
186    where
187        F: FnOnce(Guard<T>) -> Guard<T>,
188    {
189        let old = mem::take(self);
190        let new = upd(old);
191        *self = new;
192    }
193
194    /// Return the string corresponding to the guard operation.
195    pub fn op_str(&self) -> String {
196        match self {
197            Guard::And(..) => "&".to_string(),
198            Guard::Or(..) => "|".to_string(),
199            Guard::CompOp(op, _, _) => match op {
200                PortComp::Eq => "==".to_string(),
201                PortComp::Neq => "!=".to_string(),
202                PortComp::Gt => ">".to_string(),
203                PortComp::Lt => "<".to_string(),
204                PortComp::Geq => ">=".to_string(),
205                PortComp::Leq => "<=".to_string(),
206            },
207            Guard::Not(..) => "!".to_string(),
208            Guard::Port(..) | Guard::True | Guard::Info(_) => {
209                panic!("No operator string for Guard::Port/True/Info")
210            }
211        }
212    }
213
214    pub fn port(p: RRC<Port>) -> Self {
215        if p.borrow().is_constant(1, 1) {
216            Guard::True
217        } else {
218            Guard::Port(p)
219        }
220    }
221
222    pub fn and(self, rhs: Guard<T>) -> Self
223    where
224        T: Eq,
225    {
226        if rhs == Guard::True {
227            self
228        } else if self == Guard::True {
229            rhs
230        } else if self == rhs {
231            self
232        } else {
233            Guard::And(Box::new(self), Box::new(rhs))
234        }
235    }
236
237    pub fn or(self, rhs: Guard<T>) -> Self
238    where
239        T: Eq,
240    {
241        match (self, rhs) {
242            (Guard::True, _) | (_, Guard::True) => Guard::True,
243            (Guard::Not(n), g) | (g, Guard::Not(n)) => {
244                if *n == Guard::True {
245                    g
246                } else {
247                    Guard::Or(Box::new(Guard::Not(n)), Box::new(g))
248                }
249            }
250            (l, r) => {
251                if l == r {
252                    l
253                } else {
254                    Guard::Or(Box::new(l), Box::new(r))
255                }
256            }
257        }
258    }
259
260    pub fn eq(self, other: Guard<T>) -> Self
261    where
262        T: Debug + Eq + ToString,
263    {
264        match (self, other) {
265            (Guard::Port(l), Guard::Port(r)) => {
266                Guard::CompOp(PortComp::Eq, l, r)
267            }
268            (l, r) => {
269                unreachable!(
270                    "Cannot build Guard::Eq using `{}' and `{}'",
271                    Printer::guard_str(&l),
272                    Printer::guard_str(&r),
273                )
274            }
275        }
276    }
277
278    pub fn neq(self, other: Guard<T>) -> Self
279    where
280        T: Debug + Eq + ToString,
281    {
282        match (self, other) {
283            (Guard::Port(l), Guard::Port(r)) => {
284                Guard::CompOp(PortComp::Neq, l, r)
285            }
286            (l, r) => {
287                unreachable!(
288                    "Cannot build Guard::Eq using `{}' and `{}'",
289                    Printer::guard_str(&l),
290                    Printer::guard_str(&r),
291                )
292            }
293        }
294    }
295
296    pub fn le(self, other: Guard<T>) -> Self
297    where
298        T: Debug + Eq + ToString,
299    {
300        match (self, other) {
301            (Guard::Port(l), Guard::Port(r)) => {
302                Guard::CompOp(PortComp::Leq, l, r)
303            }
304            (l, r) => {
305                unreachable!(
306                    "Cannot build Guard::Eq using `{}' and `{}'",
307                    Printer::guard_str(&l),
308                    Printer::guard_str(&r),
309                )
310            }
311        }
312    }
313
314    pub fn lt(self, other: Guard<T>) -> Self
315    where
316        T: Debug + Eq + ToString,
317    {
318        match (self, other) {
319            (Guard::Port(l), Guard::Port(r)) => {
320                Guard::CompOp(PortComp::Lt, l, r)
321            }
322            (l, r) => {
323                unreachable!(
324                    "Cannot build Guard::Eq using `{}' and `{}'",
325                    Printer::guard_str(&l),
326                    Printer::guard_str(&r),
327                )
328            }
329        }
330    }
331
332    pub fn ge(self, other: Guard<T>) -> Self
333    where
334        T: Debug + Eq + ToString,
335    {
336        match (self, other) {
337            (Guard::Port(l), Guard::Port(r)) => {
338                Guard::CompOp(PortComp::Geq, l, r)
339            }
340            (l, r) => {
341                unreachable!(
342                    "Cannot build Guard::Eq using `{}' and `{}'",
343                    Printer::guard_str(&l),
344                    Printer::guard_str(&r),
345                )
346            }
347        }
348    }
349
350    pub fn gt(self, other: Guard<T>) -> Self
351    where
352        T: Debug + Eq + ToString,
353    {
354        match (self, other) {
355            (Guard::Port(l), Guard::Port(r)) => {
356                Guard::CompOp(PortComp::Gt, l, r)
357            }
358            (l, r) => {
359                unreachable!(
360                    "Cannot build Guard::Eq using `{}' and `{}'",
361                    Printer::guard_str(&l),
362                    Printer::guard_str(&r),
363                )
364            }
365        }
366    }
367
368    /// Returns all the ports used by this guard.
369    pub fn all_ports(&self) -> Vec<RRC<Port>> {
370        match self {
371            Guard::Port(a) => vec![Rc::clone(a)],
372            Guard::And(l, r) | Guard::Or(l, r) => {
373                let mut atoms = l.all_ports();
374                atoms.append(&mut r.all_ports());
375                atoms
376            }
377            Guard::CompOp(_, l, r) => {
378                vec![Rc::clone(l), Rc::clone(r)]
379            }
380            Guard::Not(g) => g.all_ports(),
381            Guard::True => vec![],
382            Guard::Info(_) => vec![],
383        }
384    }
385}
386
387/// Helper functions for the guard.
388impl<T> Guard<T> {
389    /// Mutates a guard by calling `f` on every leaf in the
390    /// guard tree and replacing the leaf with the guard that `f`
391    /// returns.
392    pub fn for_each<F>(&mut self, f: &mut F)
393    where
394        F: FnMut(RRC<Port>) -> Option<Guard<T>>,
395    {
396        match self {
397            Guard::And(l, r) | Guard::Or(l, r) => {
398                l.for_each(f);
399                r.for_each(f);
400            }
401            Guard::Not(inner) => {
402                inner.for_each(f);
403            }
404            Guard::CompOp(_, l, r) => {
405                match f(Rc::clone(l)) {
406                    Some(Guard::Port(p)) => *l = p,
407                    Some(_) => unreachable!(
408                        "Cannot replace port inside comparison operator"
409                    ),
410                    None => {}
411                }
412                match f(Rc::clone(r)) {
413                    Some(Guard::Port(p)) => *r = p,
414                    Some(_) => unreachable!(
415                        "Cannot replace port inside comparison operator"
416                    ),
417                    None => {}
418                }
419            }
420            Guard::Port(port) => {
421                let guard = f(Rc::clone(port))
422                    .unwrap_or_else(|| Guard::port(Rc::clone(port)));
423                *self = guard;
424            }
425            Guard::True => {}
426            Guard::Info(_) =>
427                // Info shouldn't count as port
428                {}
429        }
430    }
431
432    /// runs f(info) on each Guard::Info in `guard`.
433    /// if `f(info)` = Some(result)` replaces interval with result.
434    /// if `f(info)` = None` does nothing.
435    pub fn for_each_info<F>(&mut self, f: &mut F)
436    where
437        F: FnMut(&mut T) -> Option<Guard<T>>,
438    {
439        match self {
440            Guard::And(l, r) | Guard::Or(l, r) => {
441                l.for_each_info(f);
442                r.for_each_info(f);
443            }
444            Guard::Not(inner) => {
445                inner.for_each_info(f);
446            }
447            Guard::True | Guard::Port(_) | Guard::CompOp(_, _, _) => {}
448            Guard::Info(timing_interval) => {
449                if let Some(new_interval) = f(timing_interval) {
450                    *self = new_interval
451                }
452            }
453        }
454    }
455
456    /// runs f(info) on each info in `guard`.
457    /// f should return Result<(), Error>, meaning that it essentially does
458    /// nothing if the `f` returns OK(()), but returns an appropraite error otherwise
459    pub fn check_for_each_info<F>(&self, f: &mut F) -> Result<(), Error>
460    where
461        F: Fn(&T) -> Result<(), Error>,
462    {
463        match self {
464            Guard::And(l, r) | Guard::Or(l, r) => {
465                let l_result = l.check_for_each_info(f);
466                if l_result.is_err() {
467                    l_result
468                } else {
469                    r.check_for_each_info(f)
470                }
471            }
472            Guard::Not(inner) => inner.check_for_each_info(f),
473            Guard::True | Guard::Port(_) | Guard::CompOp(_, _, _) => Ok(()),
474            Guard::Info(timing_interval) => f(timing_interval),
475        }
476    }
477}
478
479impl Guard<StaticTiming> {
480    /// updates self -> self & interval
481    pub fn add_interval(&mut self, timing_interval: StaticTiming) {
482        self.update(|g| g.and(Guard::Info(timing_interval)));
483    }
484}
485
486/// Construct guards from ports
487impl<T> From<RRC<Port>> for Guard<T> {
488    fn from(port: RRC<Port>) -> Self {
489        Guard::Port(Rc::clone(&port))
490    }
491}
492
493impl<T> PartialEq for Guard<T>
494where
495    T: Eq,
496{
497    fn eq(&self, other: &Self) -> bool {
498        match (self, other) {
499            (Guard::Or(la, ra), Guard::Or(lb, rb))
500            | (Guard::And(la, ra), Guard::And(lb, rb)) => la == lb && ra == rb,
501            (Guard::CompOp(opa, la, ra), Guard::CompOp(opb, lb, rb)) => {
502                (opa == opb)
503                    && (la.borrow().get_parent_name(), &la.borrow().name)
504                        == (lb.borrow().get_parent_name(), &lb.borrow().name)
505                    && (ra.borrow().get_parent_name(), &ra.borrow().name)
506                        == (rb.borrow().get_parent_name(), &rb.borrow().name)
507            }
508            (Guard::Not(a), Guard::Not(b)) => a == b,
509            (Guard::Port(a), Guard::Port(b)) => {
510                (a.borrow().get_parent_name(), &a.borrow().name)
511                    == (b.borrow().get_parent_name(), &b.borrow().name)
512            }
513            (Guard::True, Guard::True) => true,
514            (Guard::Info(i1), Guard::Info(i2)) => i1 == i2,
515            _ => false,
516        }
517    }
518}
519
/// Marks the structural equality defined by `PartialEq` as a full
/// equivalence relation whenever the `Info` payload type is `Eq`.
impl<T> Eq for Guard<T> where T: Eq {}
521
/// Define order on guards.
///
/// Delegates to the `Ord` impl, so the result is always `Some(..)`.
/// Note that this ordering is coarser than `==`: the `Ord` impl returns
/// `Ordering::Equal` for any two guards of the same variant, even when
/// `==` says they differ.
impl<T> PartialOrd for Guard<T>
where
    T: Eq,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
531
532/// Define an ordering on the precedence of guards. Guards are
533/// considered equal when they have the same precedence.
534impl<T> Ord for Guard<T>
535where
536    T: Eq,
537{
538    fn cmp(&self, other: &Self) -> Ordering {
539        match (self, other) {
540            (Guard::Or(..), Guard::Or(..))
541            | (Guard::And(..), Guard::And(..))
542            | (Guard::CompOp(..), Guard::CompOp(..))
543            | (Guard::Not(..), Guard::Not(..))
544            | (Guard::Port(..), Guard::Port(..))
545            | (Guard::Info(_), Guard::Info(_))
546            | (Guard::True, Guard::True) => Ordering::Equal,
547            (Guard::Or(..), _) => Ordering::Greater,
548            (_, Guard::Or(..)) => Ordering::Less,
549            (Guard::And(..), _) => Ordering::Greater,
550            (_, Guard::And(..)) => Ordering::Less,
551            (Guard::CompOp(PortComp::Leq, ..), _) => Ordering::Greater,
552            (_, Guard::CompOp(PortComp::Leq, ..)) => Ordering::Less,
553            (Guard::CompOp(PortComp::Geq, ..), _) => Ordering::Greater,
554            (_, Guard::CompOp(PortComp::Geq, ..)) => Ordering::Less,
555            (Guard::CompOp(PortComp::Lt, ..), _) => Ordering::Greater,
556            (_, Guard::CompOp(PortComp::Lt, ..)) => Ordering::Less,
557            (Guard::CompOp(PortComp::Gt, ..), _) => Ordering::Greater,
558            (_, Guard::CompOp(PortComp::Gt, ..)) => Ordering::Less,
559            (Guard::CompOp(PortComp::Eq, ..), _) => Ordering::Greater,
560            (_, Guard::CompOp(PortComp::Eq, ..)) => Ordering::Less,
561            (Guard::CompOp(PortComp::Neq, ..), _) => Ordering::Greater,
562            (_, Guard::CompOp(PortComp::Neq, ..)) => Ordering::Less,
563            (Guard::Not(..), _) => Ordering::Greater,
564            (_, Guard::Not(..)) => Ordering::Less,
565            (Guard::Port(..), _) => Ordering::Greater,
566            (_, Guard::Port(..)) => Ordering::Less,
567            // maybe we should change this?
568            (Guard::Info(..), _) => Ordering::Greater,
569            (_, Guard::Info(..)) => Ordering::Less,
570        }
571    }
572}
573
/////////////// Sugar for convenience constructors /////////////
575
576/// Construct a Guard::And:
577/// ```
578/// let and_guard = g1 & g2;
579/// ```
580impl<T> BitAnd for Guard<T>
581where
582    T: Eq,
583{
584    type Output = Self;
585
586    fn bitand(self, other: Self) -> Self::Output {
587        self.and(other)
588    }
589}
590
591/// Construct a Guard::Or:
592/// ```
593/// let or_guard = g1 | g2;
594/// ```
595impl<T> BitOr for Guard<T>
596where
597    T: Eq,
598{
599    type Output = Self;
600
601    fn bitor(self, other: Self) -> Self::Output {
602        self.or(other)
603    }
604}
605
606/// Construct a Guard::Or:
607/// ```
608/// let not_guard = !g1;
609/// ```
610impl<T> Not for Guard<T> {
611    type Output = Self;
612
613    fn not(self) -> Self {
614        match self {
615            Guard::CompOp(PortComp::Eq, lhs, rhs) => {
616                Guard::CompOp(PortComp::Neq, lhs, rhs)
617            }
618            Guard::CompOp(PortComp::Neq, lhs, rhs) => {
619                Guard::CompOp(PortComp::Eq, lhs, rhs)
620            }
621            Guard::CompOp(PortComp::Gt, lhs, rhs) => {
622                Guard::CompOp(PortComp::Leq, lhs, rhs)
623            }
624            Guard::CompOp(PortComp::Lt, lhs, rhs) => {
625                Guard::CompOp(PortComp::Geq, lhs, rhs)
626            }
627            Guard::CompOp(PortComp::Geq, lhs, rhs) => {
628                Guard::CompOp(PortComp::Lt, lhs, rhs)
629            }
630            Guard::CompOp(PortComp::Leq, lhs, rhs) => {
631                Guard::CompOp(PortComp::Gt, lhs, rhs)
632            }
633            Guard::Not(expr) => *expr,
634            _ => Guard::Not(Box::new(self)),
635        }
636    }
637}
638
639/// Update a Guard with Or.
640/// ```
641/// g1 |= g2;
642/// ```
643impl<T> BitOrAssign for Guard<T>
644where
645    T: Eq,
646{
647    fn bitor_assign(&mut self, other: Self) {
648        self.update(|old| old | other)
649    }
650}
651
652/// Update a Guard with Or.
653/// ```
654/// g1 &= g2;
655/// ```
656impl<T> BitAndAssign for Guard<T>
657where
658    T: Eq,
659{
660    fn bitand_assign(&mut self, other: Self) {
661        self.update(|old| old & other)
662    }
663}