// morok_ir/pattern/helpers.rs
use std::sync::Arc;

use crate::ConstValue;
use crate::{Op, UOp};

12#[inline]
14pub fn is_zero(uop: &Arc<UOp>) -> bool {
15 matches!(uop.op(), Op::Const(cv) if cv.0.is_zero())
16}
17
18#[inline]
20pub fn is_one(uop: &Arc<UOp>) -> bool {
21 matches!(uop.op(), Op::Const(cv) if cv.0.is_one())
22}
23
24#[inline]
26pub fn is_neg_one(uop: &Arc<UOp>) -> bool {
27 matches!(uop.op(), Op::Const(cv) if cv.0.is_neg_one())
28}
29
30#[inline]
32pub fn is_nonzero(uop: &Arc<UOp>) -> bool {
33 matches!(uop.op(), Op::Const(cv) if !cv.0.is_zero())
34}
35
36#[inline]
38pub fn try_const(uop: &Arc<UOp>) -> Option<&ConstValue> {
39 match uop.op() {
40 Op::Const(cv) => Some(&cv.0),
41 _ => None,
42 }
43}
44
45#[inline]
47pub fn is_vconst(uop: &Arc<UOp>) -> bool {
48 matches!(uop.op(), Op::VConst { .. })
49}
50
51#[inline]
58pub fn is_any_const(uop: &Arc<UOp>) -> bool {
59 match uop.op() {
60 Op::Const(_) | Op::VConst { .. } => true,
61 Op::Cast { src, .. }
62 | Op::BitCast { src, .. }
63 | Op::Reshape { src, .. }
64 | Op::Expand { src, .. }
65 | Op::Shrink { src, .. }
66 | Op::Pad { src, .. }
67 | Op::Permute { src, .. }
68 | Op::Flip { src, .. } => is_any_const(src),
69 _ => false,
70 }
71}
72
73#[inline]
75pub fn try_vconst(uop: &Arc<UOp>) -> Option<&Vec<ConstValue>> {
76 match uop.op() {
77 Op::VConst { values } => Some(values),
78 _ => None,
79 }
80}
81
82#[inline]
84pub fn try_any_const_values(uop: &Arc<UOp>) -> Option<Vec<ConstValue>> {
85 match uop.op() {
86 Op::Const(cv) => Some(vec![cv.0]),
87 Op::VConst { values } => Some(values.clone()),
88 _ => None,
89 }
90}
91
92#[inline]
94pub fn const_matches<F>(uop: &Arc<UOp>, predicate: F) -> bool
95where
96 F: FnOnce(&ConstValue) -> bool,
97{
98 match uop.op() {
99 Op::Const(cv) => predicate(&cv.0),
100 _ => false,
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::BinaryOp;
    use morok_dtype::DType;

    /// Builds an `Int32` constant UOp holding `v`.
    fn const_int(v: i64) -> Arc<UOp> {
        UOp::const_(DType::Int32, ConstValue::Int(v))
    }

    /// Builds the non-constant expression `1 + 2`, used to check that every helper
    /// rejects ops that are not constants.
    fn non_const() -> Arc<UOp> {
        UOp::new(Op::Binary(BinaryOp::Add, const_int(1), const_int(2)), DType::Int32)
    }

    #[test]
    fn test_is_zero() {
        assert!(is_zero(&const_int(0)));
        assert!(!is_zero(&const_int(1)));
        assert!(!is_zero(&const_int(5)));
        assert!(!is_zero(&non_const()));
    }

    #[test]
    fn test_is_one() {
        assert!(!is_one(&const_int(0)));
        assert!(is_one(&const_int(1)));
        assert!(!is_one(&const_int(5)));
        assert!(!is_one(&non_const()));
    }

    #[test]
    fn test_is_neg_one() {
        assert!(!is_neg_one(&const_int(0)));
        assert!(is_neg_one(&const_int(-1)));
        assert!(!is_neg_one(&const_int(1)));
        assert!(!is_neg_one(&non_const()));
    }

    #[test]
    fn test_is_nonzero() {
        assert!(!is_nonzero(&const_int(0)));
        assert!(is_nonzero(&const_int(1)));
        assert!(is_nonzero(&const_int(-5)));
        // A non-constant op is neither zero nor nonzero.
        assert!(!is_nonzero(&non_const()));
    }

    #[test]
    fn test_try_const() {
        assert!(matches!(try_const(&const_int(5)), Some(ConstValue::Int(5))));
        assert!(try_const(&non_const()).is_none());
    }

    #[test]
    fn test_vconst_helpers_reject_non_vconst() {
        let five = const_int(5);
        assert!(!is_vconst(&five));
        assert!(try_vconst(&five).is_none());

        let add = non_const();
        assert!(!is_vconst(&add));
        assert!(try_vconst(&add).is_none());
    }

    #[test]
    fn test_is_any_const() {
        assert!(is_any_const(&const_int(5)));
        assert!(!is_any_const(&non_const()));
    }

    #[test]
    fn test_try_any_const_values() {
        assert!(matches!(
            try_any_const_values(&const_int(5)).as_deref(),
            Some([ConstValue::Int(5)])
        ));
        assert!(try_any_const_values(&non_const()).is_none());
    }

    #[test]
    fn test_const_matches() {
        let five = const_int(5);
        let zero = const_int(0);

        assert!(const_matches(&five, |cv| !cv.is_zero()));
        assert!(const_matches(&zero, |cv| cv.is_zero()));
        assert!(!const_matches(&five, |cv| cv.is_zero()));
        // The predicate is never consulted for a non-constant op.
        assert!(!const_matches(&non_const(), |_| true));
    }
}