use std::cmp::Reverse;
use std::collections::{HashMap, HashSet};
use super::{EinsumGraph, EinsumNode, OpType};
use crate::error::IrError;
/// A structural pattern matched against the nodes of an [`EinsumGraph`].
///
/// Multi-node patterns walk the graph along data edges: once a node has been
/// matched, matching continues at the first node (by index) found consuming one
/// of its outputs.
#[derive(Debug, Clone, PartialEq)]
pub enum GraphPattern {
    /// Matches any single node.
    AnyNode,
    /// Matches a single node whose op is of the same kind as the given op
    /// (the matcher's op comparison decides which op details are checked).
    OpType(OpType),
    /// Matches every sub-pattern in order, following data edges between them.
    Sequence(Vec<GraphPattern>),
    /// Matches the first alternative that succeeds.
    Choice(Vec<GraphPattern>),
    /// Matches a single node with exactly this many inputs.
    WithInputs(usize),
    /// Matches a single node with exactly this many outputs.
    WithOutputs(usize),
    /// Matches the inner pattern and records the result under the given name.
    Capture(String, Box<GraphPattern>),
    /// Greedily repeats the inner pattern along data edges; may match nothing.
    ZeroOrMore(Box<GraphPattern>),
    /// Greedily repeats the inner pattern along data edges; needs at least one match.
    OneOrMore(Box<GraphPattern>),
}
/// The nodes (and named captures) produced by matching a `GraphPattern`.
#[derive(Debug, Clone, Default)]
pub struct PatternMatch {
    /// Indices of the matched nodes, in the order they were added.
    pub matched_nodes: Vec<usize>,
    /// Node indices recorded under each capture name, in insertion order.
    pub captures: HashMap<String, Vec<usize>>,
    /// Presumably the tensor indices touched by the match.
    /// NOTE(review): nothing alongside this type populates it — confirm its intended use.
    pub matched_tensors: HashSet<usize>,
}

impl PatternMatch {
    /// Creates an empty match: no nodes, no captures, no tensors.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `node_idx` to the list of matched nodes.
    pub fn add_node(&mut self, node_idx: usize) {
        self.matched_nodes.push(node_idx);
    }

    /// Records `node_idx` under the capture `name`, creating the capture on first use.
    pub fn add_capture(&mut self, name: String, node_idx: usize) {
        let slot = self.captures.entry(name).or_default();
        slot.push(node_idx);
    }

    /// Returns the nodes captured under `name`, or `None` if that name was never captured.
    pub fn get_capture(&self, name: &str) -> Option<&[usize]> {
        self.captures.get(name).map(Vec::as_slice)
    }
}
/// A named rewrite: wherever `pattern` matches, `rewriter` builds replacement nodes.
///
/// `PatternMatcher` keeps its rules sorted so that higher `priority` values run first.
#[derive(Debug, Clone)]
pub struct GraphRewriteRule {
    /// Human-readable rule name.
    pub name: String,
    /// Pattern selecting the subgraph this rule applies to.
    pub pattern: GraphPattern,
    /// Builds replacement nodes for a match; `PatternMatcher::apply_rules` skips
    /// matches for which this returns `Err`.
    pub rewriter: fn(&EinsumGraph, &PatternMatch) -> Result<Vec<EinsumNode>, IrError>,
    /// Ordering key; larger values are tried earlier. Defaults to `0`.
    pub priority: i32,
}

impl GraphRewriteRule {
    /// Builds a rule with the default priority of `0`.
    pub fn new(
        name: impl Into<String>,
        pattern: GraphPattern,
        rewriter: fn(&EinsumGraph, &PatternMatch) -> Result<Vec<EinsumNode>, IrError>,
    ) -> Self {
        let name: String = name.into();
        Self {
            name,
            pattern,
            rewriter,
            priority: 0,
        }
    }

    /// Returns this rule with its priority replaced by `priority` (builder style).
    pub fn with_priority(self, priority: i32) -> Self {
        Self { priority, ..self }
    }
}
/// Counters describing one run of `PatternMatcher::apply_rules`.
#[derive(Debug, Clone, Default)]
pub struct RewriteStats {
    /// Number of pattern matches examined.
    pub patterns_matched: usize,
    /// Number of matches that were actually rewritten.
    pub rewrites_applied: usize,
    /// Graph node count before rewriting.
    pub nodes_before: usize,
    /// Graph node count after rewriting.
    pub nodes_after: usize,
    /// Nodes removed by rewriting (`nodes_before - nodes_after`, saturating at zero).
    pub nodes_eliminated: usize,
}

impl RewriteStats {
    /// Returns all-zero statistics.
    pub fn new() -> Self {
        Default::default()
    }

    /// Percentage of the original nodes that were eliminated.
    ///
    /// Returns `0.0` when the graph started out empty, avoiding a division by zero.
    pub fn reduction_percentage(&self) -> f64 {
        let before = self.nodes_before;
        if before != 0 {
            let fraction = self.nodes_eliminated as f64 / before as f64;
            fraction * 100.0
        } else {
            0.0
        }
    }
}
/// Finds [`GraphPattern`] matches in an [`EinsumGraph`] and drives the registered
/// [`GraphRewriteRule`]s.
pub struct PatternMatcher {
    /// Registered rules, kept sorted by descending priority.
    rules: Vec<GraphRewriteRule>,
}
impl PatternMatcher {
    /// Creates a matcher with no rewrite rules.
    pub fn new() -> Self {
        Self { rules: Vec::new() }
    }

    /// Registers `rule`, keeping the rule list sorted by descending priority.
    ///
    /// The sort is stable, so rules with equal priority keep their insertion order.
    pub fn add_rule(&mut self, rule: GraphRewriteRule) {
        self.rules.push(rule);
        self.rules.sort_by_key(|r| Reverse(r.priority));
    }

    /// Tries to match `pattern` starting at every node of `graph`.
    ///
    /// Returns one match per start node at which the pattern succeeded, in ascending
    /// start-node order. Matches from different start nodes may overlap.
    pub fn find_matches(&self, graph: &EinsumGraph, pattern: &GraphPattern) -> Vec<PatternMatch> {
        let no_visited: HashSet<usize> = HashSet::new();
        (0..graph.nodes.len())
            .filter_map(|start_idx| self.try_match_from(graph, pattern, start_idx, &no_visited))
            .collect()
    }

    /// Matches `pattern` with its first node at `start_idx`.
    ///
    /// Nodes in `visited` were already consumed by an enclosing pattern and are never
    /// matched again. Node-level patterns fail on a visited or out-of-range index.
    /// Patterns that can match nothing (`ZeroOrMore`, an empty `Sequence`) still succeed
    /// there with an empty match.
    fn try_match_from(
        &self,
        graph: &EinsumGraph,
        pattern: &GraphPattern,
        start_idx: usize,
        visited: &HashSet<usize>,
    ) -> Option<PatternMatch> {
        match pattern {
            GraphPattern::AnyNode => Self::match_single(graph, start_idx, visited, |_| true),
            GraphPattern::OpType(expected_op) => {
                Self::match_single(graph, start_idx, visited, |node| {
                    Self::op_matches(&node.op, expected_op)
                })
            }
            GraphPattern::WithInputs(count) => {
                Self::match_single(graph, start_idx, visited, |node| node.inputs.len() == *count)
            }
            GraphPattern::WithOutputs(count) => {
                Self::match_single(graph, start_idx, visited, |node| node.outputs.len() == *count)
            }
            GraphPattern::Capture(name, sub_pattern) => {
                let mut m = self.try_match_from(graph, sub_pattern, start_idx, visited)?;
                // Record every node the sub-pattern consumed, not just the start node.
                // Multi-node sub-matches are then captured in full, and an empty
                // sub-match does not record a node that was never matched.
                m.captures
                    .entry(name.clone())
                    .or_default()
                    .extend(m.matched_nodes.iter().copied());
                Some(m)
            }
            GraphPattern::Sequence(patterns) => {
                self.match_sequence(graph, patterns, start_idx, visited)
            }
            GraphPattern::Choice(patterns) => patterns
                .iter()
                .find_map(|pat| self.try_match_from(graph, pat, start_idx, visited)),
            GraphPattern::OneOrMore(sub_pattern) => {
                self.match_one_or_more(graph, sub_pattern, start_idx, visited)
            }
            GraphPattern::ZeroOrMore(sub_pattern) => Some(
                self.match_one_or_more(graph, sub_pattern, start_idx, visited)
                    .unwrap_or_default(),
            ),
        }
    }

    /// Matches the single node at `idx` if it exists, is not in `visited`, and
    /// satisfies `accept`.
    fn match_single<F>(
        graph: &EinsumGraph,
        idx: usize,
        visited: &HashSet<usize>,
        accept: F,
    ) -> Option<PatternMatch>
    where
        F: FnOnce(&EinsumNode) -> bool,
    {
        if visited.contains(&idx) {
            return None;
        }
        let node = graph.nodes.get(idx)?;
        if !accept(node) {
            return None;
        }
        let mut m = PatternMatch::new();
        m.add_node(idx);
        Some(m)
    }

    /// Matches `patterns` one after another, following data edges between steps.
    ///
    /// Nodes consumed by earlier steps cannot be reused by later ones.
    fn match_sequence(
        &self,
        graph: &EinsumGraph,
        patterns: &[GraphPattern],
        start_idx: usize,
        visited: &HashSet<usize>,
    ) -> Option<PatternMatch> {
        let mut result = PatternMatch::new();
        let mut current_visited = visited.clone();
        let mut current_idx = start_idx;
        for pattern in patterns {
            let step = self.try_match_from(graph, pattern, current_idx, &current_visited)?;
            if let Some(&last_node) = step.matched_nodes.last() {
                // Follow the data edge out of the last matched node. If nothing consumes
                // it, park the cursor one past the last node. Later steps can then only
                // succeed by matching nothing, and a final step needs no successor at all.
                // (Previously this case failed the whole sequence, so sequences ending
                // at a graph output never matched.)
                current_idx = self
                    .find_successor(graph, last_node)
                    .unwrap_or(graph.nodes.len());
            }
            // A step that matched nothing leaves the cursor where it was.
            Self::absorb(&mut result, step, &mut current_visited);
        }
        Some(result)
    }

    /// Greedily matches `pattern` repeatedly along data edges.
    ///
    /// Returns `None` if the very first attempt fails. Captures from every repetition
    /// are kept.
    fn match_one_or_more(
        &self,
        graph: &EinsumGraph,
        pattern: &GraphPattern,
        start_idx: usize,
        visited: &HashSet<usize>,
    ) -> Option<PatternMatch> {
        let mut result = PatternMatch::new();
        let mut current_visited = visited.clone();
        let mut current_idx = start_idx;
        let mut matched_any = false;
        while let Some(m) = self.try_match_from(graph, pattern, current_idx, &current_visited) {
            matched_any = true;
            // An empty repetition (no last node) or a node without consumers ends the
            // loop. This also guarantees termination when the sub-pattern can match nothing.
            let next = m
                .matched_nodes
                .last()
                .and_then(|&last| self.find_successor(graph, last));
            Self::absorb(&mut result, m, &mut current_visited);
            match next {
                Some(idx) => current_idx = idx,
                None => break,
            }
        }
        if matched_any {
            Some(result)
        } else {
            None
        }
    }

    /// Folds one sub-match into `result` and marks its nodes as visited, so later steps
    /// cannot consume them again.
    fn absorb(result: &mut PatternMatch, step: PatternMatch, visited: &mut HashSet<usize>) {
        visited.extend(step.matched_nodes.iter().copied());
        result.matched_nodes.extend(step.matched_nodes);
        for (name, nodes) in step.captures {
            result.captures.entry(name).or_default().extend(nodes);
        }
        result.matched_tensors.extend(step.matched_tensors);
    }

    /// Returns the lowest-index node consuming the first consumed output of `node_idx`.
    ///
    /// Outputs are tried in order. Returns `None` if no node reads any of the outputs.
    fn find_successor(&self, graph: &EinsumGraph, node_idx: usize) -> Option<usize> {
        graph.nodes[node_idx].outputs.iter().find_map(|output| {
            graph
                .nodes
                .iter()
                .position(|other| other.inputs.contains(output))
        })
    }

    /// Returns whether `actual` satisfies the pattern op `expected`.
    ///
    /// - `Einsum` ops match regardless of spec.
    /// - `Reduce` ops ignore `axes`.
    /// - Element-wise and reduce ops compare op names. An empty expected name matches
    ///   any name; the `patterns` builders use `String::new()` as that placeholder.
    /// - Different op kinds never match.
    fn op_matches(actual: &OpType, expected: &OpType) -> bool {
        match (actual, expected) {
            (OpType::Einsum { .. }, OpType::Einsum { .. }) => true,
            (OpType::ElemUnary { op: a }, OpType::ElemUnary { op: b })
            | (OpType::ElemBinary { op: a }, OpType::ElemBinary { op: b })
            | (OpType::Reduce { op: a, .. }, OpType::Reduce { op: b, .. }) => {
                Self::op_name_matches(a, b)
            }
            _ => false,
        }
    }

    /// Compares op names, treating an empty `expected` name as a wildcard.
    fn op_name_matches(actual: &str, expected: &str) -> bool {
        expected.is_empty() || actual == expected
    }

    /// Applies the registered rules, in priority order, until a full pass changes
    /// nothing or 100 passes have run.
    ///
    /// A rewriter that returns `Err` is treated as declining that match. Errors from
    /// `apply_rewrite` are propagated.
    pub fn apply_rules(&self, graph: &mut EinsumGraph) -> Result<RewriteStats, IrError> {
        const MAX_ITERATIONS: usize = 100;
        let mut stats = RewriteStats::new();
        stats.nodes_before = graph.nodes.len();
        let mut modified = true;
        let mut iterations = 0;
        while modified && iterations < MAX_ITERATIONS {
            modified = false;
            iterations += 1;
            for rule in &self.rules {
                let matches = self.find_matches(graph, &rule.pattern);
                for m in matches {
                    stats.patterns_matched += 1;
                    if let Ok(new_nodes) = (rule.rewriter)(graph, &m) {
                        if self.apply_rewrite(graph, &m, new_nodes)? {
                            stats.rewrites_applied += 1;
                            modified = true;
                            // The remaining matches were found on the pre-rewrite graph
                            // and may hold stale node indices. Later rules and the next
                            // pass re-run `find_matches` on the updated graph instead.
                            break;
                        }
                    }
                }
            }
        }
        stats.nodes_after = graph.nodes.len();
        stats.nodes_eliminated = stats.nodes_before.saturating_sub(stats.nodes_after);
        Ok(stats)
    }

    /// Replaces the nodes of `_pattern_match` with `_new_nodes`.
    ///
    /// Currently a placeholder: it never modifies the graph and always returns
    /// `Ok(false)`, so `apply_rules` records no rewrites.
    fn apply_rewrite(
        &self,
        _graph: &mut EinsumGraph,
        _pattern_match: &PatternMatch,
        _new_nodes: Vec<EinsumNode>,
    ) -> Result<bool, IrError> {
        Ok(false)
    }
}
impl Default for PatternMatcher {
fn default() -> Self {
Self::new()
}
}
/// Ready-made [`GraphPattern`] builders for common fusion opportunities.
///
/// The builders use empty `op` strings as placeholders meant to stand for "any op of
/// this kind". The matcher ignores `Einsum` specs and `Reduce` axes.
/// NOTE(review): the placeholder reading relies on the matcher treating an empty
/// expected op name as a wildcard — confirm against `PatternMatcher::op_matches`.
/// NOTE(review): `#[allow(dead_code)]` presumably silences warnings while nothing in
/// the crate calls these builders yet — confirm, and drop the allows once they are used.
pub mod patterns {
    use super::*;

    /// Matches a chain of at least `min_length` consecutive element-wise (unary or
    /// binary) ops.
    ///
    /// - `0` yields a pattern that may also match nothing.
    /// - `1` yields a single greedy repetition.
    /// - Larger values yield `min_length - 1` required ops followed by a greedy tail.
    ///
    /// Previously `min_length >= 2` matched *exactly* that many ops, contradicting the
    /// `1` case.
    #[allow(dead_code)]
    pub fn elementwise_chain(min_length: usize) -> GraphPattern {
        let elem_op = GraphPattern::Choice(vec![
            GraphPattern::OpType(OpType::ElemUnary { op: String::new() }),
            GraphPattern::OpType(OpType::ElemBinary { op: String::new() }),
        ]);
        match min_length {
            0 => GraphPattern::ZeroOrMore(Box::new(elem_op)),
            1 => GraphPattern::OneOrMore(Box::new(elem_op)),
            n => {
                let mut sequence = vec![elem_op.clone(); n - 1];
                sequence.push(GraphPattern::OneOrMore(Box::new(elem_op)));
                GraphPattern::Sequence(sequence)
            }
        }
    }

    /// Matches an `Einsum` node whose output feeds a `Reduce` node.
    #[allow(dead_code)]
    pub fn einsum_reduce() -> GraphPattern {
        GraphPattern::Sequence(vec![
            GraphPattern::OpType(OpType::Einsum {
                spec: String::new(),
            }),
            GraphPattern::OpType(OpType::Reduce {
                op: String::new(),
                axes: Vec::new(),
            }),
        ])
    }

    /// Matches a unary element-wise op (captured as `"map"`) feeding a reduction
    /// (captured as `"reduce"`).
    #[allow(dead_code)]
    pub fn map_reduce() -> GraphPattern {
        GraphPattern::Sequence(vec![
            GraphPattern::Capture(
                "map".to_string(),
                Box::new(GraphPattern::OpType(OpType::ElemUnary {
                    op: String::new(),
                })),
            ),
            GraphPattern::Capture(
                "reduce".to_string(),
                Box::new(GraphPattern::OpType(OpType::Reduce {
                    op: String::new(),
                    axes: Vec::new(),
                })),
            ),
        ])
    }

    /// Matches a binary op named `"broadcast"` feeding another binary element-wise op.
    #[allow(dead_code)]
    pub fn broadcast_elementwise() -> GraphPattern {
        GraphPattern::Sequence(vec![
            GraphPattern::OpType(OpType::ElemBinary {
                op: "broadcast".to_string(),
            }),
            GraphPattern::OpType(OpType::ElemBinary { op: String::new() }),
        ])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Rewriter that produces no nodes; rules built with it are only inspected, never applied.
    fn noop_rewriter(
        _graph: &EinsumGraph,
        _m: &PatternMatch,
    ) -> Result<Vec<EinsumNode>, IrError> {
        Ok(Vec::new())
    }

    #[test]
    fn test_pattern_match_creation() {
        let empty = PatternMatch::new();
        assert!(empty.matched_nodes.is_empty());
        assert!(empty.captures.is_empty());
    }

    #[test]
    fn test_pattern_match_add_node() {
        let mut found = PatternMatch::new();
        for idx in 0..2 {
            found.add_node(idx);
        }
        assert_eq!(found.matched_nodes, vec![0, 1]);
    }

    #[test]
    fn test_pattern_match_capture() {
        let mut found = PatternMatch::new();
        found.add_capture(String::from("test"), 5);
        assert_eq!(found.get_capture("test"), Some(&[5][..]));
        assert_eq!(found.get_capture("nonexistent"), None);
    }

    #[test]
    fn test_rewrite_stats_default() {
        let stats = RewriteStats::default();
        assert_eq!((stats.patterns_matched, stats.rewrites_applied), (0, 0));
    }

    #[test]
    fn test_rewrite_stats_reduction() {
        let stats = RewriteStats {
            nodes_before: 100,
            nodes_after: 80,
            nodes_eliminated: 20,
            ..RewriteStats::default()
        };
        assert_eq!(stats.reduction_percentage(), 20.0);
    }

    #[test]
    fn test_pattern_matcher_creation() {
        assert!(PatternMatcher::new().rules.is_empty());
    }

    #[test]
    fn test_pattern_matcher_add_rule() {
        let mut matcher = PatternMatcher::new();
        matcher.add_rule(GraphRewriteRule::new(
            "test",
            GraphPattern::AnyNode,
            noop_rewriter,
        ));
        assert_eq!(matcher.rules.len(), 1);
    }

    #[test]
    fn test_rule_priority_ordering() {
        let mut matcher = PatternMatcher::new();
        for &(name, priority) in [("low", 1), ("high", 10)].iter() {
            let rule = GraphRewriteRule::new(name, GraphPattern::AnyNode, noop_rewriter)
                .with_priority(priority);
            matcher.add_rule(rule);
        }
        let order: Vec<&str> = matcher.rules.iter().map(|r| r.name.as_str()).collect();
        assert_eq!(order, vec!["high", "low"]);
    }

    #[test]
    fn test_op_matches_einsum() {
        let einsum = |spec: &str| OpType::Einsum {
            spec: spec.to_string(),
        };
        assert!(PatternMatcher::op_matches(
            &einsum("ij,jk->ik"),
            &einsum("ik,kl->il")
        ));
    }

    #[test]
    fn test_op_matches_elem_unary() {
        let relu = || OpType::ElemUnary {
            op: "relu".to_string(),
        };
        assert!(PatternMatcher::op_matches(&relu(), &relu()));
    }

    #[test]
    fn test_op_not_matches_different_types() {
        let unary = OpType::ElemUnary {
            op: "relu".to_string(),
        };
        let binary = OpType::ElemBinary {
            op: "add".to_string(),
        };
        assert!(!PatternMatcher::op_matches(&unary, &binary));
    }

    #[test]
    fn test_patterns_elementwise_chain() {
        if !matches!(patterns::elementwise_chain(1), GraphPattern::OneOrMore(_)) {
            panic!("Expected OneOrMore pattern");
        }
    }

    #[test]
    fn test_patterns_map_reduce() {
        if let GraphPattern::Sequence(steps) = patterns::map_reduce() {
            assert_eq!(steps.len(), 2);
        } else {
            panic!("Expected Sequence pattern");
        }
    }
}