Skip to main content

oximedia_graph/
graph_rewrite.rs

1//! Graph rewriting and transformation rules.
2//!
3//! This module provides a rule-based system for transforming filter graphs.
4//! Rewrite rules can match patterns in the graph and replace them with
5//! optimized equivalents, enabling algebraic simplification, constant
6//! folding, and operator fusion.
7
8use std::collections::HashMap;
9use std::fmt;
10
/// Opaque numeric handle that identifies a [`RewriteRule`].
///
/// Rendered by `Display` as `Rule(<n>)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RuleId(pub u64);

impl fmt::Display for RuleId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure the newtype rather than reaching through `.0`.
        let Self(raw) = self;
        write!(f, "Rule({})", raw)
    }
}
20
/// The shape of graph fragment a [`RewriteRule`] looks for.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PatternKind {
    /// A lone node identified only by its filter type name.
    SingleNode {
        /// Filter type that must match exactly (e.g. "scale", "crop").
        filter_type: String,
    },
    /// Two consecutive nodes in a chain.
    Chain {
        /// Filter type of the first node of the pair.
        first: String,
        /// Filter type of the second node of the pair.
        second: String,
    },
    /// A node of a given filter type that also carries a specific property.
    WithProperty {
        /// Filter type that must match exactly.
        filter_type: String,
        /// Key that must be present in the node's properties.
        property_key: String,
        /// Value the key is expected to map to.
        property_value: String,
    },
}

impl fmt::Display for PatternKind {
    // Output forms: `Single(t)`, `Chain(a -> b)`, or `t[key=value]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PatternKind::SingleNode { filter_type } => {
                write!(f, "Single({filter_type})")
            }
            PatternKind::Chain { first, second } => {
                write!(f, "Chain({first} -> {second})")
            }
            PatternKind::WithProperty {
                filter_type,
                property_key,
                property_value,
            } => write!(f, "{filter_type}[{property_key}={property_value}]"),
        }
    }
}
62
/// What a [`RewriteRule`] does to the node(s) its pattern matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RewriteAction {
    /// Drop the matched node(s) entirely (identity elimination).
    Remove,
    /// Substitute a single node with the given filter type and properties.
    ReplaceWith {
        /// Filter type of the substituted node.
        filter_type: String,
        /// Properties assigned to the substituted node.
        properties: HashMap<String, String>,
    },
    /// Merge the two matched nodes into one with combined properties.
    Fuse {
        /// Filter type name of the merged node.
        fused_type: String,
    },
    /// Exchange the order of the two matched nodes (commutativity).
    Swap,
}

impl fmt::Display for RewriteAction {
    // Prints the variant name plus its type name; `ReplaceWith` properties
    // are not included in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RewriteAction::Remove => f.write_str("Remove"),
            RewriteAction::Swap => f.write_str("Swap"),
            RewriteAction::ReplaceWith { filter_type, .. } => {
                write!(f, "ReplaceWith({filter_type})")
            }
            RewriteAction::Fuse { fused_type } => write!(f, "Fuse({fused_type})"),
        }
    }
}
94
95/// A single graph rewrite rule.
96#[derive(Debug, Clone)]
97pub struct RewriteRule {
98    /// Unique identifier.
99    pub id: RuleId,
100    /// Human-readable name for this rule.
101    pub name: String,
102    /// The pattern to match.
103    pub pattern: PatternKind,
104    /// The action to take on match.
105    pub action: RewriteAction,
106    /// Priority (higher = applied first).
107    pub priority: i32,
108    /// Whether this rule is currently enabled.
109    pub enabled: bool,
110}
111
112impl RewriteRule {
113    /// Create a new rewrite rule.
114    pub fn new(id: RuleId, name: &str, pattern: PatternKind, action: RewriteAction) -> Self {
115        Self {
116            id,
117            name: name.to_string(),
118            pattern,
119            action,
120            priority: 0,
121            enabled: true,
122        }
123    }
124
125    /// Set the priority of this rule.
126    pub fn with_priority(mut self, priority: i32) -> Self {
127        self.priority = priority;
128        self
129    }
130
131    /// Enable or disable this rule.
132    pub fn set_enabled(&mut self, enabled: bool) {
133        self.enabled = enabled;
134    }
135
136    /// Check whether this rule matches a given node description.
137    pub fn matches_node(&self, filter_type: &str, properties: &HashMap<String, String>) -> bool {
138        if !self.enabled {
139            return false;
140        }
141        match &self.pattern {
142            PatternKind::SingleNode { filter_type: ft } => ft == filter_type,
143            PatternKind::WithProperty {
144                filter_type: ft,
145                property_key,
146                property_value,
147            } => {
148                ft == filter_type
149                    && properties
150                        .get(property_key)
151                        .map_or(false, |v| v == property_value)
152            }
153            PatternKind::Chain { first, .. } => first == filter_type,
154        }
155    }
156
157    /// Check whether this rule matches a chain of two nodes.
158    pub fn matches_chain(&self, first_type: &str, second_type: &str) -> bool {
159        if !self.enabled {
160            return false;
161        }
162        match &self.pattern {
163            PatternKind::Chain { first, second } => first == first_type && second == second_type,
164            _ => false,
165        }
166    }
167}
168
169impl fmt::Display for RewriteRule {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        write!(
172            f,
173            "{}[{}]: {} -> {}",
174            self.name, self.id, self.pattern, self.action
175        )
176    }
177}
178
179/// A record of a single rule application.
180#[derive(Debug, Clone)]
181pub struct RewriteEvent {
182    /// The rule that was applied.
183    pub rule_id: RuleId,
184    /// The rule name.
185    pub rule_name: String,
186    /// Description of what was matched.
187    pub matched: String,
188    /// The action that was taken.
189    pub action: String,
190}
191
192/// A collection of rewrite rules with application tracking.
193pub struct RewriteEngine {
194    /// Registered rules, sorted by priority.
195    rules: Vec<RewriteRule>,
196    /// History of applied rewrites.
197    history: Vec<RewriteEvent>,
198    /// Maximum number of rewrite passes to prevent infinite loops.
199    max_passes: u32,
200}
201
202impl RewriteEngine {
203    /// Create a new rewrite engine with default settings.
204    pub fn new() -> Self {
205        Self {
206            rules: Vec::new(),
207            history: Vec::new(),
208            max_passes: 100,
209        }
210    }
211
212    /// Set the maximum number of rewrite passes.
213    pub fn set_max_passes(&mut self, max: u32) {
214        self.max_passes = max;
215    }
216
217    /// Get the maximum number of rewrite passes.
218    pub fn max_passes(&self) -> u32 {
219        self.max_passes
220    }
221
222    /// Add a rewrite rule.
223    pub fn add_rule(&mut self, rule: RewriteRule) {
224        self.rules.push(rule);
225        self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
226    }
227
228    /// Get the number of registered rules.
229    pub fn rule_count(&self) -> usize {
230        self.rules.len()
231    }
232
233    /// Get a rule by its ID.
234    pub fn get_rule(&self, id: RuleId) -> Option<&RewriteRule> {
235        self.rules.iter().find(|r| r.id == id)
236    }
237
238    /// Get a mutable reference to a rule by its ID.
239    pub fn get_rule_mut(&mut self, id: RuleId) -> Option<&mut RewriteRule> {
240        self.rules.iter_mut().find(|r| r.id == id)
241    }
242
243    /// Find matching rules for a single node.
244    pub fn find_matches(
245        &self,
246        filter_type: &str,
247        properties: &HashMap<String, String>,
248    ) -> Vec<&RewriteRule> {
249        self.rules
250            .iter()
251            .filter(|r| r.matches_node(filter_type, properties))
252            .collect()
253    }
254
255    /// Find matching rules for a chain of two nodes.
256    pub fn find_chain_matches(&self, first_type: &str, second_type: &str) -> Vec<&RewriteRule> {
257        self.rules
258            .iter()
259            .filter(|r| r.matches_chain(first_type, second_type))
260            .collect()
261    }
262
263    /// Record a rewrite event.
264    pub fn record_event(&mut self, rule: &RewriteRule, matched: &str) {
265        self.history.push(RewriteEvent {
266            rule_id: rule.id,
267            rule_name: rule.name.clone(),
268            matched: matched.to_string(),
269            action: format!("{}", rule.action),
270        });
271    }
272
273    /// Get the history of applied rewrites.
274    pub fn history(&self) -> &[RewriteEvent] {
275        &self.history
276    }
277
278    /// Clear the rewrite history.
279    pub fn clear_history(&mut self) {
280        self.history.clear();
281    }
282
283    /// Remove a rule by its ID.
284    pub fn remove_rule(&mut self, id: RuleId) -> bool {
285        let len_before = self.rules.len();
286        self.rules.retain(|r| r.id != id);
287        self.rules.len() < len_before
288    }
289
290    /// Enable all rules.
291    pub fn enable_all(&mut self) {
292        for rule in &mut self.rules {
293            rule.enabled = true;
294        }
295    }
296
297    /// Disable all rules.
298    pub fn disable_all(&mut self) {
299        for rule in &mut self.rules {
300            rule.enabled = false;
301        }
302    }
303}
304
305impl Default for RewriteEngine {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311/// Create a standard set of common rewrite rules.
312pub fn standard_rules() -> Vec<RewriteRule> {
313    vec![
314        // Identity scale removal (scale 1:1 is a no-op)
315        RewriteRule::new(
316            RuleId(1),
317            "identity_scale",
318            PatternKind::WithProperty {
319                filter_type: "scale".to_string(),
320                property_key: "factor".to_string(),
321                property_value: "1.0".to_string(),
322            },
323            RewriteAction::Remove,
324        )
325        .with_priority(100),
326        // Consecutive scale fusion
327        RewriteRule::new(
328            RuleId(2),
329            "scale_fusion",
330            PatternKind::Chain {
331                first: "scale".to_string(),
332                second: "scale".to_string(),
333            },
334            RewriteAction::Fuse {
335                fused_type: "scale".to_string(),
336            },
337        )
338        .with_priority(90),
339        // Consecutive crop fusion
340        RewriteRule::new(
341            RuleId(3),
342            "crop_fusion",
343            PatternKind::Chain {
344                first: "crop".to_string(),
345                second: "crop".to_string(),
346            },
347            RewriteAction::Fuse {
348                fused_type: "crop".to_string(),
349            },
350        )
351        .with_priority(90),
352    ]
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an enabled, priority-0 `SingleNode` rule whose action is `Remove`.
    fn single_rule(id: u64, name: &str, filter_type: &str) -> RewriteRule {
        RewriteRule::new(
            RuleId(id),
            name,
            PatternKind::SingleNode {
                filter_type: filter_type.to_string(),
            },
            RewriteAction::Remove,
        )
    }

    #[test]
    fn test_rule_id_display() {
        assert_eq!(format!("{}", RuleId(42)), "Rule(42)");
    }

    #[test]
    fn test_pattern_kind_display() {
        let p = PatternKind::SingleNode {
            filter_type: "scale".to_string(),
        };
        assert_eq!(format!("{p}"), "Single(scale)");
    }

    #[test]
    fn test_chain_pattern_display() {
        let p = PatternKind::Chain {
            first: "a".to_string(),
            second: "b".to_string(),
        };
        assert_eq!(format!("{p}"), "Chain(a -> b)");
    }

    #[test]
    fn test_rewrite_action_display() {
        assert_eq!(format!("{}", RewriteAction::Remove), "Remove");
        assert_eq!(format!("{}", RewriteAction::Swap), "Swap");
        assert_eq!(
            format!(
                "{}",
                RewriteAction::Fuse {
                    fused_type: "x".to_string()
                }
            ),
            "Fuse(x)"
        );
    }

    #[test]
    fn test_rewrite_rule_display() {
        let rule = standard_rules().remove(0);
        assert_eq!(
            rule.to_string(),
            "identity_scale[Rule(1)]: scale[factor=1.0] -> Remove"
        );
    }

    #[test]
    fn test_rewrite_rule_new() {
        let rule = single_rule(1, "test", "scale");
        assert_eq!(rule.id, RuleId(1));
        assert_eq!(rule.name, "test");
        assert_eq!(rule.priority, 0);
        assert!(rule.enabled);
    }

    #[test]
    fn test_rule_matches_single_node() {
        let rule = single_rule(1, "test", "scale");
        let props = HashMap::new();
        assert!(rule.matches_node("scale", &props));
        assert!(!rule.matches_node("crop", &props));
    }

    #[test]
    fn test_rule_matches_with_property() {
        let rule = RewriteRule::new(
            RuleId(1),
            "identity_scale",
            PatternKind::WithProperty {
                filter_type: "scale".to_string(),
                property_key: "factor".to_string(),
                property_value: "1.0".to_string(),
            },
            RewriteAction::Remove,
        );
        let mut props = HashMap::new();
        props.insert("factor".to_string(), "1.0".to_string());
        assert!(rule.matches_node("scale", &props));

        props.insert("factor".to_string(), "2.0".to_string());
        assert!(!rule.matches_node("scale", &props));
    }

    #[test]
    fn test_with_property_missing_key_or_wrong_type() {
        let rule = standard_rules().remove(0); // identity_scale
        assert!(!rule.matches_node("scale", &HashMap::new()));

        let mut props = HashMap::new();
        props.insert("factor".to_string(), "1.0".to_string());
        assert!(!rule.matches_node("crop", &props));
        assert!(rule.matches_node("scale", &props));
    }

    #[test]
    fn test_rule_matches_chain() {
        let rule = RewriteRule::new(
            RuleId(2),
            "scale_fusion",
            PatternKind::Chain {
                first: "scale".to_string(),
                second: "scale".to_string(),
            },
            RewriteAction::Fuse {
                fused_type: "scale".to_string(),
            },
        );
        assert!(rule.matches_chain("scale", "scale"));
        assert!(!rule.matches_chain("scale", "crop"));
        // Non-chain patterns never match a chain query.
        assert!(!single_rule(9, "s", "scale").matches_chain("scale", "scale"));
    }

    #[test]
    fn test_disabled_rule_no_match() {
        let mut rule = single_rule(1, "test", "scale");
        rule.set_enabled(false);
        assert!(!rule.matches_node("scale", &HashMap::new()));
        assert!(!rule.matches_chain("scale", "scale"));
    }

    #[test]
    fn test_engine_add_and_count() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(single_rule(1, "r1", "a"));
        assert_eq!(engine.rule_count(), 1);
    }

    #[test]
    fn test_engine_priority_ordering() {
        let mut engine = RewriteEngine::new();
        // Both rules match the same node type, so the result order is observable.
        engine.add_rule(single_rule(1, "low", "a").with_priority(10));
        engine.add_rule(single_rule(2, "high", "a").with_priority(100));
        let names: Vec<&str> = engine
            .find_matches("a", &HashMap::new())
            .into_iter()
            .map(|r| r.name.as_str())
            .collect();
        assert_eq!(names, ["high", "low"]);
    }

    #[test]
    fn test_engine_equal_priority_keeps_insertion_order() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(single_rule(1, "first", "a").with_priority(5));
        engine.add_rule(single_rule(2, "second", "a").with_priority(5));
        engine.add_rule(single_rule(3, "third", "a").with_priority(5));
        let ids: Vec<RuleId> = engine
            .find_matches("a", &HashMap::new())
            .into_iter()
            .map(|r| r.id)
            .collect();
        assert_eq!(ids, [RuleId(1), RuleId(2), RuleId(3)]);
    }

    #[test]
    fn test_engine_find_matches() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(single_rule(1, "r1", "scale"));
        let matches = engine.find_matches("scale", &HashMap::new());
        assert_eq!(matches.len(), 1);
        assert!(engine.find_matches("crop", &HashMap::new()).is_empty());
    }

    #[test]
    fn test_engine_find_chain_matches() {
        let mut engine = RewriteEngine::new();
        for rule in standard_rules() {
            engine.add_rule(rule);
        }
        let scale = engine.find_chain_matches("scale", "scale");
        assert_eq!(scale.len(), 1);
        assert_eq!(scale[0].name, "scale_fusion");
        assert!(engine.find_chain_matches("scale", "crop").is_empty());
    }

    #[test]
    fn test_engine_get_rule_mut_can_disable() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(single_rule(7, "r7", "a"));
        engine
            .get_rule_mut(RuleId(7))
            .expect("rule 7 was just added")
            .set_enabled(false);
        assert!(engine.find_matches("a", &HashMap::new()).is_empty());
        assert!(engine.get_rule(RuleId(8)).is_none());
    }

    #[test]
    fn test_engine_record_and_clear_history() {
        let mut engine = RewriteEngine::new();
        let rule = single_rule(1, "test_rule", "a");
        engine.record_event(&rule, "node_42");
        assert_eq!(engine.history().len(), 1);
        assert_eq!(engine.history()[0].rule_name, "test_rule");
        engine.clear_history();
        assert!(engine.history().is_empty());
    }

    #[test]
    fn test_engine_record_event_fields() {
        let mut engine = RewriteEngine::new();
        let rule = RewriteRule::new(
            RuleId(3),
            "crop_fusion",
            PatternKind::Chain {
                first: "crop".to_string(),
                second: "crop".to_string(),
            },
            RewriteAction::Fuse {
                fused_type: "crop".to_string(),
            },
        );
        engine.record_event(&rule, "n1 -> n2");
        let event = &engine.history()[0];
        assert_eq!(event.rule_id, RuleId(3));
        assert_eq!(event.matched, "n1 -> n2");
        assert_eq!(event.action, "Fuse(crop)");
    }

    #[test]
    fn test_engine_remove_rule() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(single_rule(1, "r1", "a"));
        assert!(engine.remove_rule(RuleId(1)));
        assert!(!engine.remove_rule(RuleId(1)));
        assert_eq!(engine.rule_count(), 0);
    }

    #[test]
    fn test_engine_max_passes() {
        let mut engine = RewriteEngine::default();
        assert_eq!(engine.max_passes(), 100);
        engine.set_max_passes(5);
        assert_eq!(engine.max_passes(), 5);
    }

    #[test]
    fn test_standard_rules() {
        let rules = standard_rules();
        assert_eq!(rules.len(), 3);
        assert_eq!(rules[0].name, "identity_scale");
    }

    #[test]
    fn test_engine_enable_disable_all() {
        let mut engine = RewriteEngine::new();
        for i in 0..3 {
            engine.add_rule(single_rule(i, &format!("r{i}"), "a"));
        }
        engine.disable_all();
        assert!(engine.find_matches("a", &HashMap::new()).is_empty());
        engine.enable_all();
        assert_eq!(engine.find_matches("a", &HashMap::new()).len(), 3);
    }
}