1use std::collections::HashMap;
9use std::fmt;
10
/// Numeric identifier of a rewrite rule.
///
/// `Display` renders it as `Rule(<n>)`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RuleId(pub u64);

impl fmt::Display for RuleId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure the newtype instead of reaching through `.0`.
        let RuleId(raw) = *self;
        write!(f, "Rule({})", raw)
    }
}
20
/// The graph fragment a rewrite rule looks for.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PatternKind {
    /// A single node whose type equals `filter_type`.
    SingleNode {
        filter_type: String,
    },
    /// A pair of node types, displayed as `first -> second`.
    Chain {
        first: String,
        second: String,
    },
    /// A node of `filter_type` whose `property_key` property equals
    /// `property_value` (compared as strings).
    WithProperty {
        filter_type: String,
        property_key: String,
        property_value: String,
    },
}

impl fmt::Display for PatternKind {
    /// Renders `Single(t)`, `Chain(a -> b)` or `t[key=value]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::SingleNode { filter_type } => {
                write!(f, "Single({filter_type})")
            }
            Self::Chain { first, second } => {
                write!(f, "Chain({first} -> {second})")
            }
            Self::WithProperty {
                filter_type,
                property_key,
                property_value,
            } => write!(f, "{filter_type}[{property_key}={property_value}]"),
        }
    }
}
62
/// The action a rewrite rule is associated with.
///
/// Intended effects are described per variant; the code that applies them is
/// not part of this type.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RewriteAction {
    /// Intended to remove the matched node.
    Remove,
    /// Intended to substitute a node of `filter_type` configured with `properties`.
    ReplaceWith {
        filter_type: String,
        properties: HashMap<String, String>,
    },
    /// Intended to merge the matched nodes into one node of `fused_type`.
    Fuse {
        fused_type: String,
    },
    /// Intended to exchange the order of the matched nodes.
    Swap,
}

impl fmt::Display for RewriteAction {
    /// Short summary of the action; `ReplaceWith` properties are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Remove => f.write_str("Remove"),
            Self::Swap => f.write_str("Swap"),
            Self::Fuse { fused_type } => write!(f, "Fuse({fused_type})"),
            Self::ReplaceWith {
                filter_type,
                properties: _,
            } => write!(f, "ReplaceWith({filter_type})"),
        }
    }
}
94
95#[derive(Debug, Clone)]
97pub struct RewriteRule {
98 pub id: RuleId,
100 pub name: String,
102 pub pattern: PatternKind,
104 pub action: RewriteAction,
106 pub priority: i32,
108 pub enabled: bool,
110}
111
112impl RewriteRule {
113 pub fn new(id: RuleId, name: &str, pattern: PatternKind, action: RewriteAction) -> Self {
115 Self {
116 id,
117 name: name.to_string(),
118 pattern,
119 action,
120 priority: 0,
121 enabled: true,
122 }
123 }
124
125 pub fn with_priority(mut self, priority: i32) -> Self {
127 self.priority = priority;
128 self
129 }
130
131 pub fn set_enabled(&mut self, enabled: bool) {
133 self.enabled = enabled;
134 }
135
136 pub fn matches_node(&self, filter_type: &str, properties: &HashMap<String, String>) -> bool {
138 if !self.enabled {
139 return false;
140 }
141 match &self.pattern {
142 PatternKind::SingleNode { filter_type: ft } => ft == filter_type,
143 PatternKind::WithProperty {
144 filter_type: ft,
145 property_key,
146 property_value,
147 } => {
148 ft == filter_type
149 && properties
150 .get(property_key)
151 .map_or(false, |v| v == property_value)
152 }
153 PatternKind::Chain { first, .. } => first == filter_type,
154 }
155 }
156
157 pub fn matches_chain(&self, first_type: &str, second_type: &str) -> bool {
159 if !self.enabled {
160 return false;
161 }
162 match &self.pattern {
163 PatternKind::Chain { first, second } => first == first_type && second == second_type,
164 _ => false,
165 }
166 }
167}
168
169impl fmt::Display for RewriteRule {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 write!(
172 f,
173 "{}[{}]: {} -> {}",
174 self.name, self.id, self.pattern, self.action
175 )
176 }
177}
178
179#[derive(Debug, Clone)]
181pub struct RewriteEvent {
182 pub rule_id: RuleId,
184 pub rule_name: String,
186 pub matched: String,
188 pub action: String,
190}
191
192pub struct RewriteEngine {
194 rules: Vec<RewriteRule>,
196 history: Vec<RewriteEvent>,
198 max_passes: u32,
200}
201
202impl RewriteEngine {
203 pub fn new() -> Self {
205 Self {
206 rules: Vec::new(),
207 history: Vec::new(),
208 max_passes: 100,
209 }
210 }
211
212 pub fn set_max_passes(&mut self, max: u32) {
214 self.max_passes = max;
215 }
216
217 pub fn max_passes(&self) -> u32 {
219 self.max_passes
220 }
221
222 pub fn add_rule(&mut self, rule: RewriteRule) {
224 self.rules.push(rule);
225 self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
226 }
227
228 pub fn rule_count(&self) -> usize {
230 self.rules.len()
231 }
232
233 pub fn get_rule(&self, id: RuleId) -> Option<&RewriteRule> {
235 self.rules.iter().find(|r| r.id == id)
236 }
237
238 pub fn get_rule_mut(&mut self, id: RuleId) -> Option<&mut RewriteRule> {
240 self.rules.iter_mut().find(|r| r.id == id)
241 }
242
243 pub fn find_matches(
245 &self,
246 filter_type: &str,
247 properties: &HashMap<String, String>,
248 ) -> Vec<&RewriteRule> {
249 self.rules
250 .iter()
251 .filter(|r| r.matches_node(filter_type, properties))
252 .collect()
253 }
254
255 pub fn find_chain_matches(&self, first_type: &str, second_type: &str) -> Vec<&RewriteRule> {
257 self.rules
258 .iter()
259 .filter(|r| r.matches_chain(first_type, second_type))
260 .collect()
261 }
262
263 pub fn record_event(&mut self, rule: &RewriteRule, matched: &str) {
265 self.history.push(RewriteEvent {
266 rule_id: rule.id,
267 rule_name: rule.name.clone(),
268 matched: matched.to_string(),
269 action: format!("{}", rule.action),
270 });
271 }
272
273 pub fn history(&self) -> &[RewriteEvent] {
275 &self.history
276 }
277
278 pub fn clear_history(&mut self) {
280 self.history.clear();
281 }
282
283 pub fn remove_rule(&mut self, id: RuleId) -> bool {
285 let len_before = self.rules.len();
286 self.rules.retain(|r| r.id != id);
287 self.rules.len() < len_before
288 }
289
290 pub fn enable_all(&mut self) {
292 for rule in &mut self.rules {
293 rule.enabled = true;
294 }
295 }
296
297 pub fn disable_all(&mut self) {
299 for rule in &mut self.rules {
300 rule.enabled = false;
301 }
302 }
303}
304
305impl Default for RewriteEngine {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
311pub fn standard_rules() -> Vec<RewriteRule> {
313 vec![
314 RewriteRule::new(
316 RuleId(1),
317 "identity_scale",
318 PatternKind::WithProperty {
319 filter_type: "scale".to_string(),
320 property_key: "factor".to_string(),
321 property_value: "1.0".to_string(),
322 },
323 RewriteAction::Remove,
324 )
325 .with_priority(100),
326 RewriteRule::new(
328 RuleId(2),
329 "scale_fusion",
330 PatternKind::Chain {
331 first: "scale".to_string(),
332 second: "scale".to_string(),
333 },
334 RewriteAction::Fuse {
335 fused_type: "scale".to_string(),
336 },
337 )
338 .with_priority(90),
339 RewriteRule::new(
341 RuleId(3),
342 "crop_fusion",
343 PatternKind::Chain {
344 first: "crop".to_string(),
345 second: "crop".to_string(),
346 },
347 RewriteAction::Fuse {
348 fused_type: "crop".to_string(),
349 },
350 )
351 .with_priority(90),
352 ]
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `SingleNode` pattern on `filter_type`.
    fn single(filter_type: &str) -> PatternKind {
        PatternKind::SingleNode {
            filter_type: filter_type.to_string(),
        }
    }

    /// Shorthand for a `Chain` pattern `first -> second`.
    fn chain(first: &str, second: &str) -> PatternKind {
        PatternKind::Chain {
            first: first.to_string(),
            second: second.to_string(),
        }
    }

    /// Pattern matching `scale` nodes whose `factor` property is `"1.0"`.
    fn identity_scale_pattern() -> PatternKind {
        PatternKind::WithProperty {
            filter_type: "scale".to_string(),
            property_key: "factor".to_string(),
            property_value: "1.0".to_string(),
        }
    }

    /// Names of the rules returned by `find_matches`, in returned order.
    fn match_names<'a>(engine: &'a RewriteEngine, filter_type: &str) -> Vec<&'a str> {
        engine
            .find_matches(filter_type, &HashMap::new())
            .into_iter()
            .map(|r| r.name.as_str())
            .collect()
    }

    #[test]
    fn test_rule_id_display() {
        assert_eq!(format!("{}", RuleId(42)), "Rule(42)");
    }

    #[test]
    fn test_pattern_kind_display() {
        assert_eq!(format!("{}", single("scale")), "Single(scale)");
    }

    #[test]
    fn test_chain_pattern_display() {
        assert_eq!(format!("{}", chain("a", "b")), "Chain(a -> b)");
    }

    #[test]
    fn test_with_property_pattern_display() {
        assert_eq!(format!("{}", identity_scale_pattern()), "scale[factor=1.0]");
    }

    #[test]
    fn test_rewrite_action_display() {
        assert_eq!(format!("{}", RewriteAction::Remove), "Remove");
        assert_eq!(format!("{}", RewriteAction::Swap), "Swap");
        assert_eq!(
            format!(
                "{}",
                RewriteAction::Fuse {
                    fused_type: "x".to_string()
                }
            ),
            "Fuse(x)"
        );
        let replace = RewriteAction::ReplaceWith {
            filter_type: "blur".to_string(),
            properties: HashMap::new(),
        };
        assert_eq!(format!("{replace}"), "ReplaceWith(blur)");
    }

    #[test]
    fn test_rewrite_rule_display() {
        let rule = RewriteRule::new(RuleId(7), "drop", single("a"), RewriteAction::Remove);
        assert_eq!(rule.to_string(), "drop[Rule(7)]: Single(a) -> Remove");
    }

    #[test]
    fn test_rewrite_rule_new() {
        let rule = RewriteRule::new(RuleId(1), "test", single("scale"), RewriteAction::Remove);
        assert_eq!(rule.id, RuleId(1));
        assert_eq!(rule.name, "test");
        assert_eq!(rule.priority, 0);
        assert!(rule.enabled);
    }

    #[test]
    fn test_rule_matches_single_node() {
        let rule = RewriteRule::new(RuleId(1), "test", single("scale"), RewriteAction::Remove);
        let props = HashMap::new();
        assert!(rule.matches_node("scale", &props));
        assert!(!rule.matches_node("crop", &props));
    }

    #[test]
    fn test_rule_matches_with_property() {
        let rule = RewriteRule::new(
            RuleId(1),
            "identity_scale",
            identity_scale_pattern(),
            RewriteAction::Remove,
        );
        let mut props = HashMap::new();
        assert!(!rule.matches_node("scale", &props), "missing key must not match");

        props.insert("factor".to_string(), "1.0".to_string());
        assert!(rule.matches_node("scale", &props));
        assert!(!rule.matches_node("crop", &props), "wrong type must not match");

        props.insert("factor".to_string(), "2.0".to_string());
        assert!(!rule.matches_node("scale", &props));
    }

    #[test]
    fn test_rule_matches_chain() {
        let rule = RewriteRule::new(
            RuleId(2),
            "scale_fusion",
            chain("scale", "scale"),
            RewriteAction::Fuse {
                fused_type: "scale".to_string(),
            },
        );
        assert!(rule.matches_chain("scale", "scale"));
        assert!(!rule.matches_chain("scale", "crop"));
    }

    #[test]
    fn test_disabled_rule_no_match() {
        let mut rule = RewriteRule::new(RuleId(1), "test", single("scale"), RewriteAction::Remove);
        rule.set_enabled(false);
        assert!(!rule.matches_node("scale", &HashMap::new()));
        assert!(!rule.matches_chain("scale", "scale"));
    }

    #[test]
    fn test_engine_add_and_count() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(RewriteRule::new(RuleId(1), "r1", single("a"), RewriteAction::Remove));
        assert_eq!(engine.rule_count(), 1);
    }

    #[test]
    fn test_engine_priority_ordering() {
        // Both rules match the same node so the returned order is observable.
        let mut engine = RewriteEngine::new();
        engine.add_rule(
            RewriteRule::new(RuleId(1), "low", single("a"), RewriteAction::Remove)
                .with_priority(10),
        );
        engine.add_rule(
            RewriteRule::new(RuleId(2), "high", single("a"), RewriteAction::Remove)
                .with_priority(100),
        );
        assert_eq!(
            engine
                .get_rule(RuleId(2))
                .expect("rule 2 was just added")
                .name,
            "high"
        );
        assert_eq!(match_names(&engine, "a"), ["high", "low"]);
    }

    #[test]
    fn test_engine_equal_priority_keeps_insertion_order() {
        let mut engine = RewriteEngine::new();
        for (id, name, priority) in [(1, "first", 5), (2, "second", 5), (3, "top", 9)] {
            engine.add_rule(
                RewriteRule::new(RuleId(id), name, single("a"), RewriteAction::Remove)
                    .with_priority(priority),
            );
        }
        assert_eq!(match_names(&engine, "a"), ["top", "first", "second"]);
    }

    #[test]
    fn test_engine_find_matches() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(RewriteRule::new(RuleId(1), "r1", single("scale"), RewriteAction::Remove));
        let matches = engine.find_matches("scale", &HashMap::new());
        assert_eq!(matches.len(), 1);
        assert!(engine.find_matches("crop", &HashMap::new()).is_empty());
    }

    #[test]
    fn test_engine_find_chain_matches() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(RewriteRule::new(
            RuleId(2),
            "scale_fusion",
            chain("scale", "scale"),
            RewriteAction::Fuse {
                fused_type: "scale".to_string(),
            },
        ));
        assert_eq!(engine.find_chain_matches("scale", "scale").len(), 1);
        assert!(engine.find_chain_matches("scale", "crop").is_empty());
    }

    #[test]
    fn test_engine_record_and_clear_history() {
        let mut engine = RewriteEngine::new();
        let rule = RewriteRule::new(RuleId(1), "test_rule", single("a"), RewriteAction::Remove);
        engine.record_event(&rule, "node_42");
        assert_eq!(engine.history().len(), 1);
        let event = &engine.history()[0];
        assert_eq!(event.rule_id, RuleId(1));
        assert_eq!(event.rule_name, "test_rule");
        assert_eq!(event.matched, "node_42");
        assert_eq!(event.action, "Remove");
        engine.clear_history();
        assert!(engine.history().is_empty());
    }

    #[test]
    fn test_engine_remove_rule() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(RewriteRule::new(RuleId(1), "r1", single("a"), RewriteAction::Remove));
        assert!(engine.remove_rule(RuleId(1)));
        assert!(!engine.remove_rule(RuleId(1)));
        assert_eq!(engine.rule_count(), 0);
    }

    #[test]
    fn test_standard_rules() {
        let rules = standard_rules();
        assert_eq!(rules.len(), 3);
        assert_eq!(rules[0].name, "identity_scale");
        let priorities: Vec<i32> = rules.iter().map(|r| r.priority).collect();
        assert_eq!(priorities, [100, 90, 90]);
        assert!(rules[1].matches_chain("scale", "scale"));
        assert!(rules[2].matches_chain("crop", "crop"));
    }

    #[test]
    fn test_engine_enable_disable_all() {
        let mut engine = RewriteEngine::new();
        for i in 0..3 {
            engine.add_rule(RewriteRule::new(
                RuleId(i),
                &format!("r{i}"),
                single("a"),
                RewriteAction::Remove,
            ));
        }
        engine.disable_all();
        assert!(engine.find_matches("a", &HashMap::new()).is_empty());
        engine.enable_all();
        assert_eq!(engine.find_matches("a", &HashMap::new()).len(), 3);
    }
}