1use crate::ast::{AstNode, Node, NodeInner};
4use crate::pred::Predicate;
5use serde::{Deserialize, Deserializer};
6
7mod pred_spec;
8pub use pred_spec::PredSpec;
9
/// A self-test case for a rule, as written in a rule file.
///
/// Applying the rule to the expression `from` is expected to produce the
/// expression `to`. Both sides are kept as source text here and are parsed
/// into [`Node`]s when converted into a `(Node, Node)` pair.
#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct TestSpec {
    /// Input expression the rule is applied to.
    pub from: String,
    /// Expected expression after the rule has been applied.
    pub to: String,
}
18
19impl TryFrom<TestSpec> for (Node, Node) {
20 type Error = String;
21
22 fn try_from(s: TestSpec) -> Result<Self, Self::Error> {
23 Ok((s.from.as_str().try_into()?, s.to.as_str().try_into()?))
24 }
25}
26
27fn deserialize_ast_str<'de, D>(deserializer: D) -> Result<Node, D::Error>
28where
29 D: Deserializer<'de>,
30{
31 let buf = String::deserialize(deserializer)?;
32
33 Node::try_from(buf.as_str()).map_err(serde::de::Error::custom)
34}
35
/// The action a rule performs on a node once its predicate has matched.
///
/// Paths (`Vec<usize>`) are sequences of child indices walked from the node the
/// action is applied to; the empty path `[]` addresses that node itself.
///
/// The enum is `untagged`: serde tries the variants in declaration order and
/// keeps the first one whose shape fits the input, so the variant order below
/// is significant.
#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(deny_unknown_fields, untagged)]
pub enum RuleActionSpec {
    /// Copy the subtree at `from` over the subtree at `to` (both paths are
    /// relative to the matched node).
    Replace {
        from: Vec<usize>,
        to: Vec<usize>,
    },
    /// Exchange the subtrees at the two paths.
    Swap { swap: (Vec<usize>, Vec<usize>) },
    /// Build a brand-new node from the parsed `base` expression. Each
    /// `(from, to)` pair in `replace` copies the subtree at `from` in the
    /// matched node to `to` in the new node. Each path in `fold` (empty by
    /// default) is then passed to `crate::ast::fold`. The new node replaces the
    /// matched node.
    Rewrite {
        #[serde(deserialize_with = "deserialize_ast_str")]
        base: Node,
        replace: Vec<(Vec<usize>, Vec<usize>)>,
        #[serde(default)]
        fold: Vec<Vec<usize>>,
    },
    /// Apply each contained action in order.
    Multi(Vec<Self>),
}
65
66impl RuleActionSpec {
67 pub fn apply<N: AstNode + From<Node>>(&self, n: &mut N) -> Result<(), ()> {
70 match self {
71 Self::Replace { from, to } => {
72 let from = n.get(from.iter().map(|i| *i)).ok_or(())?.clone();
73 let to = n.get_mut(to.iter().map(|i| *i)).ok_or(())?;
74 *to = from;
75 }
76 Self::Swap { swap: (a, b) } => {
77 let mut tmp = NodeInner::new_const(false);
78 let a_mut = n.get_mut(a.iter().map(|i| *i)).ok_or(())?;
79 std::mem::swap(&mut tmp, a_mut);
80
81 let b_mut = n.get_mut(b.iter().map(|i| *i)).ok_or(())?;
82 std::mem::swap(&mut tmp, b_mut);
83
84 let a_mut = n.get_mut(a.iter().map(|i| *i)).ok_or(())?;
85 std::mem::swap(&mut tmp, a_mut);
86 }
87 Self::Rewrite {
88 base,
89 replace,
90 fold,
91 } => {
92 let mut out: N = N::from(base.clone());
93
94 for (from, to) in replace {
95 let from = n.get(from.iter().map(|i| *i)).ok_or(())?.clone();
96 let to = out.get_mut(to.iter().map(|i| *i)).ok_or(())?;
97 *to = from;
98 }
99 for path in fold {
100 let target = out.get_mut(path.iter().map(|i| *i)).ok_or(())?;
101 crate::ast::fold(target).map_err(|_| ())?;
102 }
103
104 *n = out;
105 }
106 Self::Multi(v) => {
107 for r in v.iter() {
108 r.apply(n)?;
109 }
110 }
111 }
112 Ok(())
113 }
114}
115
116#[derive(Deserialize, Debug, Default, Clone, PartialEq, Eq)]
118#[serde(deny_unknown_fields)]
119pub struct MetaSpec {
120 #[serde(default)]
121 pub is_simplify: bool,
122 #[serde(default)]
123 pub alt_form: bool,
127}
128
/// One rule as written in a rule file, before its predicate and test
/// expressions have been parsed.
///
/// Converted into a [`Rule`] with `Rule::try_from`.
#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct RuleSpec {
    /// Metadata flags; `MetaSpec::default()` when the key is absent.
    #[serde(default)]
    pub meta: MetaSpec,

    /// Pattern a node must satisfy; may also be written under the key `match`.
    #[serde(alias = "match")]
    pub predicate: PredSpec,

    /// What to do with a matching node. Besides `action`, the key may be
    /// spelled `replace`, `actions`, `swap` or `rewrite`. Because
    /// `RuleActionSpec` is untagged, the shape of the value (not the key
    /// used) decides which variant is produced.
    #[serde(
        alias = "replace",
        alias = "actions",
        alias = "swap",
        alias = "rewrite"
    )]
    pub action: RuleActionSpec,

    /// Example rewrites checked by `Rule::self_test`; empty when omitted.
    #[serde(default)]
    pub tests: Vec<TestSpec>,
}
150
/// Top-level contents of a rule file: the rules listed under the `rules` key.
#[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct RuleFileSpec {
    /// Rules in the order they appear in the file.
    pub rules: Vec<RuleSpec>,
}
157
/// A rule ready for evaluation, built from a [`RuleSpec`] via `TryFrom`.
///
/// It holds the parsed predicate, the action to apply on a match, and the
/// parsed `(input, expected)` test pairs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Rule {
    /// Metadata copied unchanged from the spec.
    pub meta: MetaSpec,
    // Compiled form of `RuleSpec::predicate`.
    pred: Predicate,
    // Action applied when `pred` matches.
    action: RuleActionSpec,
    // `(input, expected)` pairs checked by `self_test`.
    tests: Vec<(Node, Node)>,
}
166
167impl Rule {
168 pub fn eval<N: AstNode + From<Node>>(&self, n: &mut N) -> Result<bool, ()> {
171 if self.pred.matches(n) {
172 self.action.apply(n)?;
173 Ok(true)
174 } else {
175 Ok(false)
176 }
177 }
178
179 pub fn self_test(&self) -> Result<(), (usize, String)> {
181 for (i, (from, to)) in self.tests.iter().enumerate() {
182 let mut n = from.clone();
183 match self.eval(&mut n) {
184 Ok(true) => {}
185 Ok(false) => return Err((i, "predicate did not match".to_string())),
186 Err(()) => return Err((i, "evaluation failed".to_string())),
187 }
188 if &n != to {
189 return Err((i, format!("got result {}", n)));
190 }
191 }
192 Ok(())
193 }
194}
195
196impl TryFrom<RuleSpec> for Rule {
197 type Error = String;
198
199 fn try_from(s: RuleSpec) -> Result<Self, Self::Error> {
200 Ok(Self {
201 meta: s.meta,
202 pred: s.predicate.try_into()?,
203 action: s.action,
204 tests: s
205 .tests
206 .into_iter()
207 .map(|s| <TestSpec as TryInto<(Node, Node)>>::try_into(s))
208 .collect::<Result<Vec<(Node, Node)>, String>>()?,
209 })
210 }
211}
212
#[cfg(test)]
mod tests {
    // Unit tests covering rule-spec deserialization from TOML, conversion into
    // `Rule`, each `RuleActionSpec` variant, and `Rule::self_test`.
    use super::*;
    use crate::ast::BinaryOp;
    use crate::pred::PredicateOp;
    use crate::TyValue;
    use toml::de;

    // The `match` / `replace` keys (aliases of `predicate` / `action`)
    // deserialize into the expected `RuleSpec`.
    #[test]
    fn parse_spec() {
        assert_eq!(
            de::from_str::<RuleSpec>(
                r#"
            match.op = '/'
            match.rhs = {const = '1'}

            replace.from = [0] # using the numerator (first lhs)
            replace.to = [] # overwrite self

            tests = [
                {from = '12 / 1', to = '12'},
            ]
            "#
            ),
            Ok(RuleSpec {
                meta: MetaSpec::default(),
                predicate: PredSpec {
                    op: Some("/".to_string()),
                    rhs: Some(Box::new(PredSpec {
                        const_val: Some("1".to_string()),
                        ..PredSpec::default()
                    })),
                    ..PredSpec::default()
                },
                action: RuleActionSpec::Replace {
                    from: vec![0],
                    to: vec![],
                },
                tests: vec![TestSpec {
                    from: "12 / 1".to_string(),
                    to: "12".to_string(),
                }],
            })
        );
    }

    // The same spec converts into a `Rule` with a compiled predicate
    // (the `rhs` becomes child index 1) and parsed test nodes.
    #[test]
    fn convert_spec() {
        assert_eq!(
            de::from_str::<RuleSpec>(
                r#"
            match.op = '/'
            match.rhs = {const = '1'}

            replace.from = [0] # using the numerator (first lhs)
            replace.to = [] # overwrite self

            tests = [
                {from = '12 / 1', to = '12'},
            ]
            "#
            )
            .unwrap()
            .try_into(),
            Ok(Rule {
                meta: MetaSpec::default(),
                pred: Predicate {
                    op: Some(PredicateOp::Binary(BinaryOp::Div)),
                    children: vec![
                        None,
                        Some(Predicate {
                            const_value: Some(TyValue::from(1)),
                            ..Predicate::op(PredicateOp::Const)
                        })
                    ],
                    ..Predicate::default()
                },
                action: RuleActionSpec::Replace {
                    from: vec![0],
                    to: vec![],
                },
                tests: vec![(
                    Node::try_from("12 / 1").unwrap(),
                    Node::try_from("12").unwrap()
                ),],
            })
        );
    }

    // `Replace` with an empty `to` path overwrites the matched node itself.
    #[test]
    fn apply_replace() {
        let rule: Rule = de::from_str::<RuleSpec>(
            r#"
            match.op = '/'
            match.rhs = {const = '1'}

            replace.from = [0] # using the numerator (first lhs)
            replace.to = [] # overwrite self
            "#,
        )
        .unwrap()
        .try_into()
        .unwrap();

        let mut ast = Node::try_from("12 / 1").unwrap();
        assert_eq!(rule.eval(&mut ast), Ok(true));
        assert_eq!(ast, Node::try_from("12").unwrap());
    }

    // `Swap` exchanges the two operands of a multiplication.
    #[test]
    fn apply_swap() {
        let rule: Rule = de::from_str::<RuleSpec>(
            r#"
            match.op = '*'
            match.lhs = {not_op = 'const'}
            match.rhs = {op = 'const'}

            action.swap = [[0], [1]]
            "#,
        )
        .unwrap()
        .try_into()
        .unwrap();

        let mut ast = Node::try_from("x * 2").unwrap();
        assert_eq!(rule.eval(&mut ast), Ok(true));
        assert_eq!(ast, Node::try_from("2x").unwrap());
    }

    // `Rewrite` builds a new node from `base` and fills in subtrees copied
    // from the matched node.
    #[test]
    fn apply_rewrite() {
        let rule: Rule = de::from_str::<RuleSpec>(
            r#"
            match.op = '-'
            match.lhs = {op = 'neg'}
            match.rhs = {not_op = 'neg'}

            [rewrite]
            base = "-(1 + 1)"
            replace = [
                [[0, 0], [0, 0]],
                [[1], [0, 1]],
            ]
            "#,
        )
        .unwrap()
        .try_into()
        .unwrap();

        let mut ast = Node::try_from("-12 - 2").unwrap();
        assert_eq!(rule.eval(&mut ast), Ok(true));
        assert_eq!(ast, Node::try_from("-(12 + 2)").unwrap());
    }

    // `self_test` passes for a correct expectation and reports the actual
    // result once the expectation is changed to something unreachable.
    #[test]
    fn selftest() {
        let mut rule: Rule = de::from_str::<RuleSpec>(
            r#"
            match.op = '/'
            match.rhs = {const = '1'}

            actions = [
                {from = [0], to = []},
            ]

            tests = [
                {from = '12 / 1', to = '12'},
            ]
            "#,
        )
        .unwrap()
        .try_into()
        .unwrap();

        assert_eq!(rule.self_test(), Ok(()));

        rule.tests[0].1 = Node::try_from("4x").unwrap();
        assert_eq!(rule.self_test(), Err((0, "got result 12".to_string())));
    }
}