use std::collections::HashMap;
use std::fmt;
/// Numeric identifier of a rewrite rule.
///
/// Rendered by `Display` as `Rule(<number>)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RuleId(pub u64);

impl fmt::Display for RuleId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let RuleId(raw) = self;
        write!(f, "Rule({raw})")
    }
}
/// The node shape a rewrite rule looks for.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PatternKind {
    /// One node whose filter type equals `filter_type`.
    SingleNode {
        filter_type: String,
    },
    /// A pair of node types: `first` followed by `second`.
    Chain {
        first: String,
        second: String,
    },
    /// One node of `filter_type` whose property `property_key` has the
    /// string value `property_value`.
    WithProperty {
        filter_type: String,
        property_key: String,
        property_value: String,
    },
}

impl fmt::Display for PatternKind {
    /// Renders `Single(t)`, `Chain(a -> b)` or `t[key=value]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PatternKind::SingleNode { filter_type: node } => write!(f, "Single({node})"),
            PatternKind::Chain {
                first: head,
                second: tail,
            } => write!(f, "Chain({head} -> {tail})"),
            PatternKind::WithProperty {
                filter_type: node,
                property_key: key,
                property_value: value,
            } => write!(f, "{node}[{key}={value}]"),
        }
    }
}
/// The rewrite associated with a rule.
///
/// The variant names describe the intended edit; this module only stores,
/// compares and displays actions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RewriteAction {
    Remove,
    /// Replacement by a node of `filter_type` carrying `properties`.
    ReplaceWith {
        filter_type: String,
        properties: HashMap<String, String>,
    },
    /// Fusion into a node of type `fused_type`.
    Fuse {
        fused_type: String,
    },
    Swap,
}

impl fmt::Display for RewriteAction {
    /// Renders the variant name, plus the target type for `ReplaceWith` and
    /// `Fuse` (properties are intentionally not printed).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RewriteAction::Remove => f.write_str("Remove"),
            RewriteAction::Swap => f.write_str("Swap"),
            RewriteAction::ReplaceWith {
                filter_type: target,
                ..
            } => write!(f, "ReplaceWith({target})"),
            RewriteAction::Fuse { fused_type: target } => write!(f, "Fuse({target})"),
        }
    }
}
/// A named rule: when `pattern` matches, `action` is the rewrite to perform.
///
/// Inside a `RewriteEngine` rules are kept ordered by descending `priority`.
/// A rule whose `enabled` flag is `false` never matches.
#[derive(Debug, Clone)]
pub struct RewriteRule {
    /// Identifier used by the engine for lookup and removal.
    pub id: RuleId,
    /// Human-readable name; copied into history events.
    pub name: String,
    /// The node (or node pair) this rule looks for.
    pub pattern: PatternKind,
    /// The rewrite associated with the pattern.
    pub action: RewriteAction,
    /// Ordering key: larger values come first in a `RewriteEngine`. `new` sets 0.
    pub priority: i32,
    /// When `false`, `matches_node` and `matches_chain` always return `false`.
    pub enabled: bool,
}

impl RewriteRule {
    /// Creates an enabled rule with priority `0`.
    #[must_use]
    pub fn new(id: RuleId, name: &str, pattern: PatternKind, action: RewriteAction) -> Self {
        Self {
            id,
            name: name.to_string(),
            pattern,
            action,
            priority: 0,
            enabled: true,
        }
    }

    /// Builder-style setter: returns the rule with `priority` replaced.
    #[must_use]
    pub fn with_priority(mut self, priority: i32) -> Self {
        self.priority = priority;
        self
    }

    /// Enables or disables the rule in place.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Returns `true` if this rule is enabled and matches a single node.
    ///
    /// * `SingleNode`: the pattern's filter type must equal `filter_type`.
    /// * `WithProperty`: the filter type must match, and `properties` must map
    ///   `property_key` to exactly `property_value`. This is a raw string
    ///   comparison, so `"1"` does not match `"1.0"`. A missing key never matches.
    /// * `Chain`: only the head (`first`) is compared, so a chain rule is a
    ///   candidate for any node of its first type. Use [`Self::matches_chain`]
    ///   to check both links.
    pub fn matches_node(&self, filter_type: &str, properties: &HashMap<String, String>) -> bool {
        if !self.enabled {
            return false;
        }
        match &self.pattern {
            PatternKind::SingleNode { filter_type: ft } => ft == filter_type,
            PatternKind::WithProperty {
                filter_type: ft,
                property_key,
                property_value,
            } => ft == filter_type && properties.get(property_key) == Some(property_value),
            PatternKind::Chain { first, .. } => first == filter_type,
        }
    }

    /// Returns `true` if this rule is an enabled `Chain` pattern whose links are
    /// exactly `first_type` followed by `second_type`.
    pub fn matches_chain(&self, first_type: &str, second_type: &str) -> bool {
        if !self.enabled {
            return false;
        }
        // Listed exhaustively so that adding a PatternKind variant forces a
        // decision here instead of silently not matching.
        match &self.pattern {
            PatternKind::Chain { first, second } => first == first_type && second == second_type,
            PatternKind::SingleNode { .. } | PatternKind::WithProperty { .. } => false,
        }
    }
}

impl fmt::Display for RewriteRule {
    /// Renders `name[Rule(id)]: pattern -> action`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}[{}]: {} -> {}",
            self.name, self.id, self.pattern, self.action
        )
    }
}
/// History entry created by `RewriteEngine::record_event`.
#[derive(Debug, Clone)]
pub struct RewriteEvent {
/// Id of the rule the event was recorded for.
pub rule_id: RuleId,
/// Copy of the rule's `name` at recording time.
pub rule_name: String,
/// Caller-supplied, free-form description of what the rule matched.
pub matched: String,
/// The rule's action rendered through its `Display` impl (e.g. `"Remove"`).
pub action: String,
}
/// Priority-ordered collection of rewrite rules plus a log of recorded events.
///
/// The engine stores, looks up and matches rules. Applying a rule's action is
/// left to the caller, which can log it with [`Self::record_event`].
#[derive(Debug)]
pub struct RewriteEngine {
    /// Rules, re-sorted by descending `priority` on every `add_rule`.
    rules: Vec<RewriteRule>,
    /// Events appended by `record_event`, oldest first.
    history: Vec<RewriteEvent>,
    /// Pass limit exposed to callers. No method of this type reads it
    /// besides the getter.
    max_passes: u32,
}

impl RewriteEngine {
    /// `max_passes` value of a freshly created engine.
    const DEFAULT_MAX_PASSES: u32 = 100;

    /// Creates an engine with no rules, an empty history and `max_passes == 100`.
    pub fn new() -> Self {
        Self {
            rules: Vec::new(),
            history: Vec::new(),
            max_passes: Self::DEFAULT_MAX_PASSES,
        }
    }

    /// Sets the value returned by [`Self::max_passes`].
    pub fn set_max_passes(&mut self, max: u32) {
        self.max_passes = max;
    }

    /// Returns the configured pass limit.
    pub fn max_passes(&self) -> u32 {
        self.max_passes
    }

    /// Adds `rule`, then re-sorts all rules by descending priority.
    ///
    /// The sort is stable, so rules of equal priority keep their insertion
    /// order and `rule` lands after existing rules of the same priority.
    /// Rules with duplicate ids are accepted as-is.
    pub fn add_rule(&mut self, rule: RewriteRule) {
        self.rules.push(rule);
        self.rules.sort_by(|a, b| b.priority.cmp(&a.priority));
    }

    /// Number of stored rules, enabled or not.
    pub fn rule_count(&self) -> usize {
        self.rules.len()
    }

    /// Returns the first rule with id `id`, in the engine's internal order.
    pub fn get_rule(&self, id: RuleId) -> Option<&RewriteRule> {
        self.rules.iter().find(|r| r.id == id)
    }

    /// Returns a mutable reference to the first rule with id `id`.
    ///
    /// Changing `priority` through this reference does not move the rule in
    /// the internal list. The `find_*` methods still return results in
    /// priority order because they sort their output.
    pub fn get_rule_mut(&mut self, id: RuleId) -> Option<&mut RewriteRule> {
        self.rules.iter_mut().find(|r| r.id == id)
    }

    /// Returns every enabled rule that matches a single node
    /// (see `RewriteRule::matches_node`), highest priority first.
    pub fn find_matches(
        &self,
        filter_type: &str,
        properties: &HashMap<String, String>,
    ) -> Vec<&RewriteRule> {
        self.collect_by_priority(|r| r.matches_node(filter_type, properties))
    }

    /// Returns every enabled `Chain` rule for `first_type -> second_type`,
    /// highest priority first.
    pub fn find_chain_matches(&self, first_type: &str, second_type: &str) -> Vec<&RewriteRule> {
        self.collect_by_priority(|r| r.matches_chain(first_type, second_type))
    }

    /// Collects the rules satisfying `pred`, ordered by descending priority.
    ///
    /// `rules` is normally already sorted, so this stable sort changes nothing.
    /// It matters only when a priority was edited via `get_rule_mut`, which
    /// cannot re-sort the list on its own.
    fn collect_by_priority<F>(&self, pred: F) -> Vec<&RewriteRule>
    where
        F: Fn(&RewriteRule) -> bool,
    {
        let mut hits: Vec<&RewriteRule> = self.rules.iter().filter(|&r| pred(r)).collect();
        hits.sort_by(|a, b| b.priority.cmp(&a.priority));
        hits
    }

    /// Appends a history entry recording that `rule` fired on `matched`.
    ///
    /// `rule` does not need to be registered with this engine.
    pub fn record_event(&mut self, rule: &RewriteRule, matched: &str) {
        self.history.push(RewriteEvent {
            rule_id: rule.id,
            rule_name: rule.name.clone(),
            matched: matched.to_string(),
            action: rule.action.to_string(),
        });
    }

    /// Recorded events, oldest first.
    pub fn history(&self) -> &[RewriteEvent] {
        &self.history
    }

    /// Discards all recorded events.
    pub fn clear_history(&mut self) {
        self.history.clear();
    }

    /// Removes every rule with id `id`. Returns `true` if at least one was removed.
    pub fn remove_rule(&mut self, id: RuleId) -> bool {
        let len_before = self.rules.len();
        self.rules.retain(|r| r.id != id);
        self.rules.len() < len_before
    }

    /// Enables every stored rule.
    pub fn enable_all(&mut self) {
        self.set_all_enabled(true);
    }

    /// Disables every stored rule.
    pub fn disable_all(&mut self) {
        self.set_all_enabled(false);
    }

    /// Sets the `enabled` flag of every stored rule.
    fn set_all_enabled(&mut self, enabled: bool) {
        for rule in &mut self.rules {
            rule.set_enabled(enabled);
        }
    }
}

impl Default for RewriteEngine {
    /// Same as [`RewriteEngine::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Returns the built-in rule set, in this order:
///
/// 1. `identity_scale` (priority 100): `Remove` a `scale` node whose `factor`
///    property is exactly the string `"1.0"`.
/// 2. `scale_fusion` (priority 90): `Fuse` a `scale -> scale` chain into `scale`.
/// 3. `crop_fusion` (priority 90): `Fuse` a `crop -> crop` chain into `crop`.
pub fn standard_rules() -> Vec<RewriteRule> {
    // Same-type chains that fuse into one node of that type:
    // (rule id, rule name, filter type).
    const FUSIBLE: [(u64, &str, &str); 2] =
        [(2, "scale_fusion", "scale"), (3, "crop_fusion", "crop")];

    let identity_scale = RewriteRule::new(
        RuleId(1),
        "identity_scale",
        PatternKind::WithProperty {
            filter_type: "scale".to_string(),
            property_key: "factor".to_string(),
            property_value: "1.0".to_string(),
        },
        RewriteAction::Remove,
    )
    .with_priority(100);

    let mut rules = Vec::with_capacity(1 + FUSIBLE.len());
    rules.push(identity_scale);
    rules.extend(FUSIBLE.iter().map(|&(raw_id, rule_name, kind)| {
        RewriteRule::new(
            RuleId(raw_id),
            rule_name,
            PatternKind::Chain {
                first: kind.to_string(),
                second: kind.to_string(),
            },
            RewriteAction::Fuse {
                fused_type: kind.to_string(),
            },
        )
        .with_priority(90)
    }));
    rules
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled `SingleNode` rule with `RewriteAction::Remove` and priority 0.
    fn single_rule(id: u64, name: &str, filter_type: &str) -> RewriteRule {
        RewriteRule::new(
            RuleId(id),
            name,
            PatternKind::SingleNode {
                filter_type: filter_type.to_string(),
            },
            RewriteAction::Remove,
        )
    }

    /// Builds the `identity_scale`-style `WithProperty` rule used by several tests.
    fn identity_scale_rule() -> RewriteRule {
        RewriteRule::new(
            RuleId(1),
            "identity_scale",
            PatternKind::WithProperty {
                filter_type: "scale".to_string(),
                property_key: "factor".to_string(),
                property_value: "1.0".to_string(),
            },
            RewriteAction::Remove,
        )
    }

    /// Names of the rules returned by a `find_*` call, in returned order.
    fn names<'a>(rules: Vec<&'a RewriteRule>) -> Vec<&'a str> {
        rules.into_iter().map(|r| r.name.as_str()).collect()
    }

    #[test]
    fn test_rule_id_display() {
        assert_eq!(format!("{}", RuleId(42)), "Rule(42)");
    }

    #[test]
    fn test_pattern_kind_display() {
        let p = PatternKind::SingleNode {
            filter_type: "scale".to_string(),
        };
        assert_eq!(format!("{p}"), "Single(scale)");
    }

    #[test]
    fn test_chain_pattern_display() {
        let p = PatternKind::Chain {
            first: "a".to_string(),
            second: "b".to_string(),
        };
        assert_eq!(format!("{p}"), "Chain(a -> b)");
    }

    #[test]
    fn test_with_property_pattern_display() {
        assert_eq!(
            format!("{}", identity_scale_rule().pattern),
            "scale[factor=1.0]"
        );
    }

    #[test]
    fn test_rewrite_action_display() {
        assert_eq!(format!("{}", RewriteAction::Remove), "Remove");
        assert_eq!(format!("{}", RewriteAction::Swap), "Swap");
        assert_eq!(
            format!(
                "{}",
                RewriteAction::Fuse {
                    fused_type: "x".to_string()
                }
            ),
            "Fuse(x)"
        );
        assert_eq!(
            format!(
                "{}",
                RewriteAction::ReplaceWith {
                    filter_type: "crop".to_string(),
                    properties: HashMap::new(),
                }
            ),
            "ReplaceWith(crop)"
        );
    }

    #[test]
    fn test_rewrite_rule_new() {
        let rule = single_rule(1, "test", "scale");
        assert_eq!(rule.id, RuleId(1));
        assert_eq!(rule.name, "test");
        assert_eq!(rule.priority, 0);
        assert!(rule.enabled);
    }

    #[test]
    fn test_rewrite_rule_display() {
        assert_eq!(
            format!("{}", single_rule(1, "test", "scale")),
            "test[Rule(1)]: Single(scale) -> Remove"
        );
    }

    #[test]
    fn test_rule_matches_single_node() {
        let rule = single_rule(1, "test", "scale");
        let props = HashMap::new();
        assert!(rule.matches_node("scale", &props));
        assert!(!rule.matches_node("crop", &props));
    }

    #[test]
    fn test_rule_matches_with_property() {
        let rule = identity_scale_rule();
        let mut props = HashMap::new();
        props.insert("factor".to_string(), "1.0".to_string());
        assert!(rule.matches_node("scale", &props));
        assert!(!rule.matches_node("crop", &props));
        props.insert("factor".to_string(), "2.0".to_string());
        assert!(!rule.matches_node("scale", &props));
    }

    #[test]
    fn test_rule_with_property_missing_key_no_match() {
        let rule = identity_scale_rule();
        assert!(!rule.matches_node("scale", &HashMap::new()));
    }

    #[test]
    fn test_rule_matches_chain() {
        let rule = RewriteRule::new(
            RuleId(2),
            "scale_fusion",
            PatternKind::Chain {
                first: "scale".to_string(),
                second: "scale".to_string(),
            },
            RewriteAction::Fuse {
                fused_type: "scale".to_string(),
            },
        );
        assert!(rule.matches_chain("scale", "scale"));
        assert!(!rule.matches_chain("scale", "crop"));
    }

    #[test]
    fn test_chain_rule_matches_node_on_head_only() {
        let rule = RewriteRule::new(
            RuleId(5),
            "scale_then_crop",
            PatternKind::Chain {
                first: "scale".to_string(),
                second: "crop".to_string(),
            },
            RewriteAction::Swap,
        );
        assert!(rule.matches_node("scale", &HashMap::new()));
        assert!(!rule.matches_node("crop", &HashMap::new()));
    }

    #[test]
    fn test_disabled_rule_no_match() {
        let mut rule = single_rule(1, "test", "scale");
        rule.set_enabled(false);
        assert!(!rule.matches_node("scale", &HashMap::new()));
        assert!(!rule.matches_chain("scale", "scale"));
    }

    #[test]
    fn test_engine_add_and_count() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(single_rule(1, "r1", "a"));
        assert_eq!(engine.rule_count(), 1);
    }

    #[test]
    fn test_engine_max_passes() {
        let mut engine = RewriteEngine::default();
        assert_eq!(engine.max_passes(), 100);
        engine.set_max_passes(5);
        assert_eq!(engine.max_passes(), 5);
    }

    #[test]
    fn test_engine_priority_ordering() {
        let mut engine = RewriteEngine::new();
        // Both rules match "a", so their relative order is observable; the
        // low-priority rule is added first to prove the engine reorders.
        engine.add_rule(single_rule(1, "low", "a").with_priority(10));
        engine.add_rule(single_rule(2, "high", "a").with_priority(100));
        assert_eq!(
            engine
                .get_rule(RuleId(2))
                .expect("rule 2 was just added")
                .name,
            "high"
        );
        assert_eq!(names(engine.find_matches("a", &HashMap::new())), ["high", "low"]);
    }

    #[test]
    fn test_engine_equal_priority_keeps_insertion_order() {
        let mut engine = RewriteEngine::new();
        for (id, name) in (1u64..).zip(["first", "second", "third"]) {
            engine.add_rule(single_rule(id, name, "a").with_priority(5));
        }
        assert_eq!(
            names(engine.find_matches("a", &HashMap::new())),
            ["first", "second", "third"]
        );
    }

    #[test]
    fn test_engine_find_matches() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(single_rule(1, "r1", "scale"));
        let matches = engine.find_matches("scale", &HashMap::new());
        assert_eq!(matches.len(), 1);
        assert!(engine.find_matches("crop", &HashMap::new()).is_empty());
    }

    #[test]
    fn test_engine_find_chain_matches() {
        let mut engine = RewriteEngine::new();
        for rule in standard_rules() {
            engine.add_rule(rule);
        }
        let hits = engine.find_chain_matches("crop", "crop");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, RuleId(3));
        assert!(engine.find_chain_matches("scale", "crop").is_empty());
    }

    #[test]
    fn test_engine_record_and_clear_history() {
        let mut engine = RewriteEngine::new();
        let rule = single_rule(1, "test_rule", "a");
        engine.record_event(&rule, "node_42");
        assert_eq!(engine.history().len(), 1);
        let event = &engine.history()[0];
        assert_eq!(event.rule_id, RuleId(1));
        assert_eq!(event.rule_name, "test_rule");
        assert_eq!(event.matched, "node_42");
        assert_eq!(event.action, "Remove");
        engine.clear_history();
        assert!(engine.history().is_empty());
    }

    #[test]
    fn test_engine_remove_rule() {
        let mut engine = RewriteEngine::new();
        engine.add_rule(single_rule(1, "r1", "a"));
        assert!(engine.remove_rule(RuleId(1)));
        assert!(!engine.remove_rule(RuleId(1)));
        assert_eq!(engine.rule_count(), 0);
    }

    #[test]
    fn test_standard_rules() {
        let rules = standard_rules();
        assert_eq!(rules.len(), 3);
        let rule_names: Vec<&str> = rules.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(rule_names, ["identity_scale", "scale_fusion", "crop_fusion"]);
        let priorities: Vec<i32> = rules.iter().map(|r| r.priority).collect();
        assert_eq!(priorities, [100, 90, 90]);
    }

    #[test]
    fn test_engine_enable_disable_all() {
        let mut engine = RewriteEngine::new();
        for i in 0..3 {
            engine.add_rule(single_rule(i, &format!("r{i}"), "a"));
        }
        engine.disable_all();
        assert!(engine.find_matches("a", &HashMap::new()).is_empty());
        engine.enable_all();
        assert_eq!(engine.find_matches("a", &HashMap::new()).len(), 3);
    }
}