use crate::QuantConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One operation slot inside a [`GraphPattern`].
///
/// A candidate graph node is tested against this description with
/// `PatternNode::matches`, which checks the operation type, the required
/// attributes and every constraint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternNode {
/// Operation type to match; the literal `"*"` matches any operation type.
pub op_type: String,
/// Attributes the candidate must carry, compared by exact string equality.
pub attributes: HashMap<String, String>,
/// Set by `PatternNode::optional`. `PatternNode::matches` does not read this
/// flag; it only adds a `?` suffix in `PatternNode::description`.
pub optional: bool,
/// Additional predicates; all of them must evaluate to `true` for a match.
pub constraints: Vec<PatternConstraint>,
}
impl PatternNode {
    /// Creates a non-optional pattern node for `op_type` with no attributes and
    /// no constraints.
    ///
    /// An `op_type` of `"*"` acts as a wildcard in [`PatternNode::matches`].
    pub fn new(op_type: String) -> Self {
        Self {
            op_type,
            attributes: HashMap::new(),
            optional: false,
            constraints: Vec::new(),
        }
    }

    /// Requires the candidate node to carry attribute `key` with exactly `value`.
    ///
    /// Adding the same key twice keeps only the last value.
    pub fn with_attribute(mut self, key: String, value: String) -> Self {
        self.attributes.insert(key, value);
        self
    }

    /// Marks this node as optional.
    ///
    /// [`PatternNode::matches`] does not consult the flag; it is surfaced as a
    /// trailing `?` by [`PatternNode::description`].
    pub fn optional(mut self) -> Self {
        self.optional = true;
        self
    }

    /// Appends a constraint that must hold for a candidate node to match.
    pub fn with_constraint(mut self, constraint: PatternConstraint) -> Self {
        self.constraints.push(constraint);
        self
    }

    /// Returns `true` when this pattern node requires attribute `key` to equal
    /// `value`, i.e. when that exact pair was registered via
    /// [`PatternNode::with_attribute`].
    pub fn has_attribute(&self, key: &str, value: &str) -> bool {
        self.attributes.get(key).map(String::as_str) == Some(value)
    }

    /// Tests a candidate node, given its operation type and attributes.
    ///
    /// The candidate matches when all of the following hold:
    /// 1. `op_type` is `"*"` or equals `node_op_type`;
    /// 2. every required attribute is present on the candidate with an equal value
    ///    (extra attributes on the candidate are ignored);
    /// 3. every constraint evaluates to `true`.
    pub fn matches(&self, node_op_type: &str, node_attributes: &HashMap<String, String>) -> bool {
        if self.op_type != "*" && self.op_type != node_op_type {
            return false;
        }

        let attributes_match = self
            .attributes
            .iter()
            .all(|(key, expected)| node_attributes.get(key) == Some(expected));

        attributes_match
            && self
                .constraints
                .iter()
                .all(|constraint| constraint.evaluate(node_op_type, node_attributes))
    }

    /// Renders the node as `op_type[key=value, ...]?`.
    ///
    /// The bracketed list is omitted when there are no attributes and the `?`
    /// suffix is present only for optional nodes. Attributes are listed in
    /// ascending key order so the text is stable across calls and runs
    /// (`HashMap` iteration order is unspecified).
    pub fn description(&self) -> String {
        let mut desc = self.op_type.clone();

        if !self.attributes.is_empty() {
            // Keys are unique, so ordering the pairs orders them by key.
            let mut pairs: Vec<(&String, &String)> = self.attributes.iter().collect();
            pairs.sort_unstable();

            let rendered: Vec<String> = pairs
                .iter()
                .map(|(key, value)| format!("{}={}", key, value))
                .collect();

            desc.push('[');
            desc.push_str(&rendered.join(", "));
            desc.push(']');
        }

        if self.optional {
            desc.push('?');
        }

        desc
    }
}
/// A predicate over a candidate node's attributes, checked by
/// `PatternConstraint::evaluate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PatternConstraint {
/// Attribute `key` must be present and equal to `value`.
AttributeEquals { key: String, value: String },
/// Attribute `key` must be present; its value is not inspected.
AttributeExists { key: String },
/// Attribute `key` must be present and its value must contain `pattern` as a
/// plain substring (`evaluate` uses `str::contains`, not a regular expression).
AttributeMatches { key: String, pattern: String },
/// Named, described constraint that carries no predicate of its own;
/// `evaluate` always returns `true` for this variant.
Custom { name: String, description: String },
}
impl PatternConstraint {
pub fn evaluate(&self, _op_type: &str, attributes: &HashMap<String, String>) -> bool {
match self {
PatternConstraint::AttributeEquals { key, value } => {
attributes.get(key).map_or(false, |v| v == value)
}
PatternConstraint::AttributeExists { key } => attributes.contains_key(key),
PatternConstraint::AttributeMatches { key, pattern } => {
if let Some(attr_value) = attributes.get(key) {
attr_value.contains(pattern)
} else {
false
}
}
PatternConstraint::Custom { .. } => {
true
}
}
}
pub fn description(&self) -> String {
match self {
PatternConstraint::AttributeEquals { key, value } => {
format!("{} == {}", key, value)
}
PatternConstraint::AttributeExists { key } => {
format!("{} exists", key)
}
PatternConstraint::AttributeMatches { key, pattern } => {
format!("{} matches {}", key, pattern)
}
PatternConstraint::Custom { name, description } => {
format!("{}: {}", name, description)
}
}
}
}
/// A named subgraph pattern: a list of [`PatternNode`]s plus directed edges
/// between them, with optional quantization configuration and scheduling hints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphPattern {
/// Identifier used for lookup (see `PatternCollection::get_by_name`).
pub name: String,
/// Free-form human-readable description.
pub description: String,
/// Pattern nodes; edges refer to them by position in this vector.
pub nodes: Vec<PatternNode>,
/// Directed edges as `(from_index, to_index)` pairs of indices into `nodes`.
// `qconfig` (declared on the same line below): optional quantization
// configuration attached via `with_qconfig`; `None` until then.
pub edges: Vec<(usize, usize)>, pub qconfig: Option<QuantConfig>,
/// Ordering key; `PatternCollection::get_by_priority` lists higher values first.
pub priority: i32,
/// Flag set by `GraphPattern::iterative`. Presumably tells the consuming pass
/// to re-apply the pattern until it no longer matches — TODO confirm against
/// the pass that reads it.
pub iterative: bool,
}
impl GraphPattern {
    /// Creates an empty pattern: no nodes, no edges, no quantization config,
    /// priority `0`, not iterative.
    pub fn new(name: String, description: String) -> Self {
        Self {
            name,
            description,
            nodes: Vec::new(),
            edges: Vec::new(),
            qconfig: None,
            priority: 0,
            iterative: false,
        }
    }

    /// Appends `node`; its index is the number of nodes added before it.
    pub fn add_node(mut self, node: PatternNode) -> Self {
        self.nodes.push(node);
        self
    }

    /// Appends a directed edge between two node indices.
    ///
    /// Indices are not checked here; use [`GraphPattern::is_valid`] once the
    /// pattern is fully built.
    pub fn add_edge(mut self, from_index: usize, to_index: usize) -> Self {
        self.edges.push((from_index, to_index));
        self
    }

    /// Attaches a quantization configuration, replacing any previous one.
    pub fn with_qconfig(mut self, qconfig: QuantConfig) -> Self {
        self.qconfig = Some(qconfig);
        self
    }

    /// Sets the pattern's priority.
    pub fn with_priority(mut self, priority: i32) -> Self {
        self.priority = priority;
        self
    }

    /// Sets the `iterative` flag.
    pub fn iterative(mut self) -> Self {
        self.iterative = true;
        self
    }

    /// Number of nodes in the pattern.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges in the pattern.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Returns `true` when both endpoints of every edge are valid indices into
    /// `nodes`. A pattern without edges is always valid.
    pub fn is_valid(&self) -> bool {
        let node_count = self.nodes.len();
        self.edges
            .iter()
            .all(|&(from, to)| from < node_count && to < node_count)
    }
}

/// Multi-line, human-readable dump of the pattern.
///
/// Implemented as `Display` instead of an inherent `to_string` so the pattern
/// works with `{}` formatting; `pattern.to_string()` keeps working through the
/// standard `ToString` blanket implementation and produces the same text.
impl std::fmt::Display for GraphPattern {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Pattern: {} ({})", self.name, self.description)?;

        writeln!(f, "Nodes:")?;
        for (index, node) in self.nodes.iter().enumerate() {
            writeln!(f, " {}: {}", index, node.description())?;
        }

        // The "Edges:" header is emitted only when there is at least one edge.
        if !self.edges.is_empty() {
            writeln!(f, "Edges:")?;
            for (from, to) in &self.edges {
                writeln!(f, " {} -> {}", from, to)?;
            }
        }

        if let Some(qconfig) = &self.qconfig {
            writeln!(f, "Quantization Config: {:?}", qconfig)?;
        }

        writeln!(f, "Priority: {}", self.priority)?;
        writeln!(f, "Iterative: {}", self.iterative)
    }
}
pub struct CommonPatterns;
impl CommonPatterns {
pub fn conv_batch_norm() -> GraphPattern {
GraphPattern::new(
"conv_bn".to_string(),
"Conv2D followed by BatchNorm fusion".to_string(),
)
.add_node(PatternNode::new("conv2d".to_string()))
.add_node(PatternNode::new("batch_norm".to_string()))
.add_edge(0, 1)
.with_priority(10)
}
pub fn conv_batch_norm_relu() -> GraphPattern {
GraphPattern::new(
"conv_bn_relu".to_string(),
"Conv2D + BatchNorm + ReLU fusion".to_string(),
)
.add_node(PatternNode::new("conv2d".to_string()))
.add_node(PatternNode::new("batch_norm".to_string()))
.add_node(PatternNode::new("relu".to_string()))
.add_edge(0, 1)
.add_edge(1, 2)
.with_priority(15)
}
pub fn conv_relu() -> GraphPattern {
GraphPattern::new("conv_relu".to_string(), "Conv2D + ReLU fusion".to_string())
.add_node(PatternNode::new("conv2d".to_string()))
.add_node(PatternNode::new("relu".to_string()))
.add_edge(0, 1)
.with_priority(8)
}
pub fn linear_relu() -> GraphPattern {
GraphPattern::new(
"linear_relu".to_string(),
"Linear + ReLU fusion".to_string(),
)
.add_node(PatternNode::new("linear".to_string()))
.add_node(PatternNode::new("relu".to_string()))
.add_edge(0, 1)
.with_priority(8)
}
pub fn quant_dequant() -> GraphPattern {
GraphPattern::new(
"quant_dequant".to_string(),
"Quantize followed by Dequantize elimination".to_string(),
)
.add_node(PatternNode::new("quantize".to_string()))
.add_node(PatternNode::new("dequantize".to_string()))
.add_edge(0, 1)
.with_priority(20)
.iterative()
}
pub fn add_relu() -> GraphPattern {
GraphPattern::new("add_relu".to_string(), "Add + ReLU fusion".to_string())
.add_node(PatternNode::new("add".to_string()))
.add_node(PatternNode::new("relu".to_string()))
.add_edge(0, 1)
.with_priority(5)
}
pub fn transpose_transpose() -> GraphPattern {
GraphPattern::new(
"transpose_transpose".to_string(),
"Consecutive transpose operations elimination".to_string(),
)
.add_node(PatternNode::new("transpose".to_string()))
.add_node(PatternNode::new("transpose".to_string()))
.add_edge(0, 1)
.with_priority(15)
.iterative()
}
pub fn matmul_add() -> GraphPattern {
GraphPattern::new(
"matmul_add".to_string(),
"MatMul + Add (bias) fusion".to_string(),
)
.add_node(PatternNode::new("matmul".to_string()))
.add_node(PatternNode::new("add".to_string()))
.add_edge(0, 1)
.with_priority(12)
}
pub fn squeeze_unsqueeze() -> GraphPattern {
GraphPattern::new(
"squeeze_unsqueeze".to_string(),
"Squeeze followed by Unsqueeze elimination".to_string(),
)
.add_node(PatternNode::new("squeeze".to_string()))
.add_node(PatternNode::new("unsqueeze".to_string()))
.add_edge(0, 1)
.with_priority(10)
.iterative()
}
pub fn reshape_reshape() -> GraphPattern {
GraphPattern::new(
"reshape_reshape".to_string(),
"Consecutive reshape operations elimination".to_string(),
)
.add_node(PatternNode::new("reshape".to_string()))
.add_node(PatternNode::new("reshape".to_string()))
.add_edge(0, 1)
.with_priority(10)
.iterative()
}
pub fn all_patterns() -> Vec<GraphPattern> {
vec![
Self::conv_batch_norm_relu(),
Self::conv_batch_norm(),
Self::conv_relu(),
Self::linear_relu(),
Self::quant_dequant(),
Self::add_relu(),
Self::transpose_transpose(),
Self::matmul_add(),
Self::squeeze_unsqueeze(),
Self::reshape_reshape(),
]
}
pub fn fusion_patterns() -> Vec<GraphPattern> {
vec![
Self::conv_batch_norm_relu(),
Self::conv_batch_norm(),
Self::conv_relu(),
Self::linear_relu(),
Self::add_relu(),
Self::matmul_add(),
]
}
pub fn elimination_patterns() -> Vec<GraphPattern> {
vec![
Self::quant_dequant(),
Self::transpose_transpose(),
Self::squeeze_unsqueeze(),
Self::reshape_reshape(),
]
}
}
/// A named, described list of [`GraphPattern`]s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternCollection {
/// Patterns in insertion order.
pub patterns: Vec<GraphPattern>,
/// Display name of the collection.
pub name: String,
/// Free-form description of what the collection is for.
pub description: String,
}
impl PatternCollection {
    /// Creates an empty collection with the given name and description.
    pub fn new(name: String, description: String) -> Self {
        let patterns = Vec::new();
        Self {
            patterns,
            name,
            description,
        }
    }

    /// Appends `pattern` to the end of the collection.
    pub fn add_pattern(mut self, pattern: GraphPattern) -> Self {
        self.patterns.push(pattern);
        self
    }

    /// Returns references to all patterns, highest priority first.
    ///
    /// The sort is stable, so patterns with equal priority keep their
    /// insertion order.
    pub fn get_by_priority(&self) -> Vec<&GraphPattern> {
        let mut ordered: Vec<&GraphPattern> = self.patterns.iter().collect();
        ordered.sort_by(|lhs, rhs| lhs.priority.cmp(&rhs.priority).reverse());
        ordered
    }

    /// Returns the first pattern whose name equals `name`, if any.
    pub fn get_by_name(&self, name: &str) -> Option<&GraphPattern> {
        self.patterns
            .iter()
            .find(|candidate| candidate.name.as_str() == name)
    }

    /// Returns the patterns classified as fusion patterns, in insertion order.
    ///
    /// A pattern qualifies when its name contains `"fusion"` (case-sensitive),
    /// its description contains `"fusion"` (case-insensitive), or it has more
    /// than one node. The node-count clause alone is sufficient, so every
    /// multi-node pattern is returned regardless of its name or description.
    pub fn get_fusion_patterns(&self) -> Vec<&GraphPattern> {
        let mut selected = Vec::new();
        for pattern in &self.patterns {
            let is_fusion = pattern.nodes.len() > 1
                || pattern.name.contains("fusion")
                || pattern.description.to_lowercase().contains("fusion");
            if is_fusion {
                selected.push(pattern);
            }
        }
        selected
    }

    /// Returns the patterns classified as elimination patterns, in insertion order.
    ///
    /// A pattern qualifies when its name contains `"elimination"`
    /// (case-sensitive), its description contains `"elimination"`
    /// (case-insensitive), or its `iterative` flag is set.
    pub fn get_elimination_patterns(&self) -> Vec<&GraphPattern> {
        let mut selected = Vec::new();
        for pattern in &self.patterns {
            let is_elimination = pattern.iterative
                || pattern.name.contains("elimination")
                || pattern.description.to_lowercase().contains("elimination");
            if is_elimination {
                selected.push(pattern);
            }
        }
        selected
    }

    /// Collection holding everything from `CommonPatterns::all_patterns`.
    pub fn common() -> Self {
        CommonPatterns::all_patterns().into_iter().fold(
            Self::new(
                "Common Patterns".to_string(),
                "Common quantization and optimization patterns".to_string(),
            ),
            Self::add_pattern,
        )
    }

    /// Collection holding everything from `CommonPatterns::fusion_patterns`.
    pub fn fusion_only() -> Self {
        CommonPatterns::fusion_patterns().into_iter().fold(
            Self::new(
                "Fusion Patterns".to_string(),
                "Operation fusion patterns for optimization".to_string(),
            ),
            Self::add_pattern,
        )
    }

    /// Collection holding everything from `CommonPatterns::elimination_patterns`.
    pub fn elimination_only() -> Self {
        CommonPatterns::elimination_patterns().into_iter().fold(
            Self::new(
                "Elimination Patterns".to_string(),
                "Dead code and redundant operation elimination patterns".to_string(),
            ),
            Self::add_pattern,
        )
    }
}
impl Default for PatternCollection {
fn default() -> Self {
Self::common()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created node has only its operation type set.
    #[test]
    fn test_pattern_node_creation() {
        let node = PatternNode::new("conv2d".to_string());
        assert_eq!(node.op_type, "conv2d");
        assert!(node.attributes.is_empty());
        assert!(!node.optional);
        assert!(node.constraints.is_empty());
    }

    /// Builder methods record the attribute and the optional flag.
    #[test]
    fn test_pattern_node_with_attributes() {
        let node = PatternNode::new("conv2d".to_string())
            .with_attribute("kernel_size".to_string(), "3x3".to_string())
            .optional();
        assert!(node.optional);
        // Check the stored attribute through the public `attributes` map.
        assert_eq!(
            node.attributes.get("kernel_size").map(String::as_str),
            Some("3x3")
        );
    }

    /// Matching honours the operation type, required attributes and the wildcard.
    #[test]
    fn test_pattern_node_matching() {
        let mut attributes = HashMap::new();
        attributes.insert("kernel_size".to_string(), "3x3".to_string());

        let node = PatternNode::new("conv2d".to_string())
            .with_attribute("kernel_size".to_string(), "3x3".to_string());
        assert!(node.matches("conv2d", &attributes));
        assert!(!node.matches("relu", &attributes));

        let empty_attrs: HashMap<String, String> = HashMap::new();
        assert!(!node.matches("conv2d", &empty_attrs));

        let wildcard = PatternNode::new("*".to_string());
        assert!(wildcard.matches("conv2d", &attributes));
        assert!(wildcard.matches("relu", &attributes));
    }

    /// `AttributeExists` depends only on the presence of the key.
    #[test]
    fn test_pattern_constraints() {
        let constraint = PatternConstraint::AttributeExists {
            key: "kernel_size".to_string(),
        };

        let mut attrs = HashMap::new();
        attrs.insert("kernel_size".to_string(), "3x3".to_string());
        assert!(constraint.evaluate("conv2d", &attrs));

        let empty_attrs: HashMap<String, String> = HashMap::new();
        assert!(!constraint.evaluate("conv2d", &empty_attrs));
    }

    /// Builder calls accumulate nodes and edges.
    #[test]
    fn test_graph_pattern_creation() {
        let pattern = GraphPattern::new("test_pattern".to_string(), "Test pattern".to_string())
            .add_node(PatternNode::new("conv2d".to_string()))
            .add_node(PatternNode::new("relu".to_string()))
            .add_edge(0, 1);
        assert_eq!(pattern.name, "test_pattern");
        assert_eq!(pattern.node_count(), 2);
        assert_eq!(pattern.edge_count(), 1);
        assert!(pattern.is_valid());
    }

    /// The predefined patterns and their groupings are non-empty and well-formed.
    #[test]
    fn test_common_patterns() {
        let conv_bn = CommonPatterns::conv_batch_norm();
        assert_eq!(conv_bn.name, "conv_bn");
        assert_eq!(conv_bn.node_count(), 2);
        assert_eq!(conv_bn.edge_count(), 1);

        assert!(!CommonPatterns::all_patterns().is_empty());
        assert!(!CommonPatterns::fusion_patterns().is_empty());
        assert!(!CommonPatterns::elimination_patterns().is_empty());
    }

    /// Collections are populated and `get_by_priority` is sorted descending.
    #[test]
    fn test_pattern_collection() {
        let collection = PatternCollection::common();
        assert!(!collection.patterns.is_empty());

        let by_priority = collection.get_by_priority();
        assert!(!by_priority.is_empty());
        assert!(by_priority
            .windows(2)
            .all(|pair| pair[0].priority >= pair[1].priority));

        let fusion_collection = PatternCollection::fusion_only();
        let elimination_collection = PatternCollection::elimination_only();
        assert!(!fusion_collection.patterns.is_empty());
        assert!(!elimination_collection.patterns.is_empty());
    }

    /// An edge that points past the last node makes the pattern invalid.
    #[test]
    fn test_pattern_validation() {
        let pattern = GraphPattern::new(
            "invalid_pattern".to_string(),
            "Pattern with invalid edge".to_string(),
        )
        .add_node(PatternNode::new("conv2d".to_string()));
        assert!(pattern.is_valid());

        let pattern = pattern.add_edge(0, 5);
        assert!(!pattern.is_valid());
    }
}