Skip to main content

morok_ir/pattern/
helpers.rs

1//! Helper functions for pattern matching.
2//!
3//! These functions are used by the generated pattern matching code to check
4//! common conditions like zero/one constants.
5
6use std::sync::Arc;
7
8use crate::ConstValue;
9
10use crate::{Op, UOp};
11
12/// Check if a UOp is a zero constant.
13#[inline]
14pub fn is_zero(uop: &Arc<UOp>) -> bool {
15    matches!(uop.op(), Op::Const(cv) if cv.0.is_zero())
16}
17
18/// Check if a UOp is a one constant.
19#[inline]
20pub fn is_one(uop: &Arc<UOp>) -> bool {
21    matches!(uop.op(), Op::Const(cv) if cv.0.is_one())
22}
23
24/// Check if a UOp is a negative one constant.
25#[inline]
26pub fn is_neg_one(uop: &Arc<UOp>) -> bool {
27    matches!(uop.op(), Op::Const(cv) if cv.0.is_neg_one())
28}
29
30/// Check if a UOp is a non-zero constant.
31#[inline]
32pub fn is_nonzero(uop: &Arc<UOp>) -> bool {
33    matches!(uop.op(), Op::Const(cv) if !cv.0.is_zero())
34}
35
36/// Extract const value if present.
37#[inline]
38pub fn try_const(uop: &Arc<UOp>) -> Option<&ConstValue> {
39    match uop.op() {
40        Op::Const(cv) => Some(&cv.0),
41        _ => None,
42    }
43}
44
45/// Check if a UOp is a VConst (vector constant).
46#[inline]
47pub fn is_vconst(uop: &Arc<UOp>) -> bool {
48    matches!(uop.op(), Op::VConst { .. })
49}
50
51/// Check if a UOp is a pure constant tree (no buffer references).
52///
53/// Returns true for bare CONST/VCONST, and also for unary transformations
54/// of constants (e.g., CAST(CONST), BITCAST(CONST), RESHAPE(CONST),
55/// EXPAND(CONST)). These trees have no buffer backing and need
56/// `.contiguous()` wrapping before realization.
57#[inline]
58pub fn is_any_const(uop: &Arc<UOp>) -> bool {
59    match uop.op() {
60        Op::Const(_) | Op::VConst { .. } => true,
61        Op::Cast { src, .. }
62        | Op::BitCast { src, .. }
63        | Op::Reshape { src, .. }
64        | Op::Expand { src, .. }
65        | Op::Shrink { src, .. }
66        | Op::Pad { src, .. }
67        | Op::Permute { src, .. }
68        | Op::Flip { src, .. } => is_any_const(src),
69        _ => false,
70    }
71}
72
73/// Extract VConst values if present.
74#[inline]
75pub fn try_vconst(uop: &Arc<UOp>) -> Option<&Vec<ConstValue>> {
76    match uop.op() {
77        Op::VConst { values } => Some(values),
78        _ => None,
79    }
80}
81
82/// Extract values from any constant (Const returns single-element slice, VConst returns full slice).
83#[inline]
84pub fn try_any_const_values(uop: &Arc<UOp>) -> Option<Vec<ConstValue>> {
85    match uop.op() {
86        Op::Const(cv) => Some(vec![cv.0]),
87        Op::VConst { values } => Some(values.clone()),
88        _ => None,
89    }
90}
91
92/// Check if a UOp matches a constant predicate.
93#[inline]
94pub fn const_matches<F>(uop: &Arc<UOp>, predicate: F) -> bool
95where
96    F: FnOnce(&ConstValue) -> bool,
97{
98    match uop.op() {
99        Op::Const(cv) => predicate(&cv.0),
100        _ => false,
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BinaryOp;
    use morok_dtype::DType;

    /// Builds a scalar `Int32` constant node holding `v`.
    fn const_int(v: i64) -> Arc<UOp> {
        UOp::const_(DType::Int32, ConstValue::Int(v))
    }

    /// Builds a non-constant node (`1 + 2`) for exercising the "not a const" paths.
    fn add_node() -> Arc<UOp> {
        UOp::new(
            Op::Binary(BinaryOp::Add, const_int(1), const_int(2)),
            DType::Int32,
        )
    }

    #[test]
    fn test_is_zero() {
        let zero = const_int(0);
        let one = const_int(1);
        let five = const_int(5);

        assert!(is_zero(&zero));
        assert!(!is_zero(&one));
        assert!(!is_zero(&five));
    }

    #[test]
    fn test_is_one() {
        let zero = const_int(0);
        let one = const_int(1);
        let five = const_int(5);

        assert!(!is_one(&zero));
        assert!(is_one(&one));
        assert!(!is_one(&five));
    }

    #[test]
    fn test_is_neg_one() {
        let zero = const_int(0);
        let neg_one = const_int(-1);
        let one = const_int(1);

        assert!(!is_neg_one(&zero));
        assert!(is_neg_one(&neg_one));
        assert!(!is_neg_one(&one));
    }

    #[test]
    fn test_is_nonzero() {
        let zero = const_int(0);
        let one = const_int(1);
        let neg_five = const_int(-5);

        assert!(!is_nonzero(&zero));
        assert!(is_nonzero(&one));
        assert!(is_nonzero(&neg_five));
    }

    #[test]
    fn test_scalar_predicates_reject_non_const() {
        // A computed node is neither zero nor non-zero: every predicate is false.
        let add = add_node();

        assert!(!is_zero(&add));
        assert!(!is_one(&add));
        assert!(!is_neg_one(&add));
        assert!(!is_nonzero(&add));
        assert!(!const_matches(&add, |_| true));
    }

    #[test]
    fn test_try_const() {
        let five = const_int(5);
        assert!(try_const(&five).is_some());

        let add = add_node();
        assert!(try_const(&add).is_none());
    }

    #[test]
    fn test_const_matches() {
        let five = const_int(5);
        let zero = const_int(0);

        assert!(const_matches(&five, |cv| !cv.is_zero()));
        assert!(const_matches(&zero, |cv| cv.is_zero()));
        assert!(!const_matches(&five, |cv| cv.is_zero()));
    }

    #[test]
    fn test_vconst_helpers_on_scalar_const() {
        // A scalar Const is not a vector constant.
        let five = const_int(5);

        assert!(!is_vconst(&five));
        assert!(try_vconst(&five).is_none());
    }

    #[test]
    fn test_is_any_const() {
        assert!(is_any_const(&const_int(5)));
        assert!(!is_any_const(&add_node()));
    }

    #[test]
    fn test_try_any_const_values() {
        let five = const_int(5);

        assert!(matches!(
            try_any_const_values(&five).as_deref(),
            Some([ConstValue::Int(5)])
        ));
        assert!(try_any_const_values(&add_node()).is_none());
    }
}