use super::core::FusedOp;
/// Detects fusible operator sequences and decides whether fusion should run.
///
/// `enabled` and `fusion_threshold` are read only by [`OpFusionEngine::should_fuse`];
/// [`OpFusionEngine::analyze_sequence`] does not consult them.
#[derive(Debug, Clone)]
pub struct OpFusionEngine {
    /// Master switch checked by [`OpFusionEngine::should_fuse`].
    pub enabled: bool,
    /// Minimum number of ops a sequence must contain for
    /// [`OpFusionEngine::should_fuse`] to return `true`.
    pub fusion_threshold: usize,
}

impl OpFusionEngine {
    /// Creates an enabled engine with a fusion threshold of 2.
    pub fn new() -> Self {
        Self {
            enabled: true,
            fusion_threshold: 2,
        }
    }

    /// Returns the fused operators recognized in `ops`.
    ///
    /// Two-op patterns (`add→relu`, `mul→add`, `add→mul`, `sigmoid→mul`,
    /// `tanh→scale`) come first, in sequence order. They are followed by one
    /// [`FusedOp::AddReluMul`] per `add→relu→mul` triple, also in sequence order.
    ///
    /// When an `add→relu` pair begins such a triple, it is reported only as
    /// `AddReluMul`. It is not also reported as [`FusedOp::ReluAdd`]. `add→relu`
    /// pairs elsewhere in the sequence are still reported. All other overlapping
    /// matches are reported as well.
    ///
    /// This method does not consult `enabled` or `fusion_threshold`.
    pub fn analyze_sequence(&self, ops: &[&str]) -> Vec<FusedOp> {
        let mut fused_ops = Vec::new();
        for (i, window) in ops.windows(2).enumerate() {
            let op = match window {
                // Skip only the pair at the head of an add→relu→mul triple. Removing
                // every ReluAdd here would discard unrelated add→relu pairs.
                ["add", "relu"] if ops.get(i + 2).copied() == Some("mul") => continue,
                ["add", "relu"] => FusedOp::ReluAdd,
                ["mul", "add"] => FusedOp::MulAdd,
                ["add", "mul"] => FusedOp::AddMul,
                ["sigmoid", "mul"] => FusedOp::SigmoidMul,
                ["tanh", "scale"] => FusedOp::TanhScale,
                _ => continue,
            };
            fused_ops.push(op);
        }
        for window in ops.windows(3) {
            if let ["add", "relu", "mul"] = window {
                fused_ops.push(FusedOp::AddReluMul);
            }
        }
        fused_ops
    }

    /// Returns `true` when fusion is enabled and `ops` contains at least
    /// `fusion_threshold` elements.
    pub fn should_fuse(&self, ops: &[&str]) -> bool {
        self.enabled && ops.len() >= self.fusion_threshold
    }
}

impl Default for OpFusionEngine {
    /// Same as [`OpFusionEngine::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Finds fusible operator patterns in `operations`, each paired with its start index.
///
/// Two-op patterns (`add→relu`, `mul→add`, `add→mul`, `sigmoid→mul`, `tanh→scale`)
/// come first, in ascending index order. They are followed by every three-op
/// `add→relu→mul` pattern, tagged [`FusedOp::AddReluMul`], in ascending index order.
///
/// The `add→relu` pair at the head of such a triple is reported only as the triple.
/// Every other overlapping match is kept. The result therefore lists candidate
/// fusions, not a non-overlapping fusion plan.
pub fn detect_fusible_patterns(operations: &[&str]) -> Vec<(usize, FusedOp)> {
    let mut patterns = Vec::new();
    for (i, window) in operations.windows(2).enumerate() {
        let op = match window {
            // This pair is subsumed by the AddReluMul triple starting at the same index.
            // Looking ahead here avoids a retain pass per triple (O(n^2) overall).
            ["add", "relu"] if operations.get(i + 2).copied() == Some("mul") => continue,
            ["add", "relu"] => FusedOp::ReluAdd,
            ["mul", "add"] => FusedOp::MulAdd,
            ["add", "mul"] => FusedOp::AddMul,
            ["sigmoid", "mul"] => FusedOp::SigmoidMul,
            ["tanh", "scale"] => FusedOp::TanhScale,
            _ => continue,
        };
        patterns.push((i, op));
    }
    for (i, window) in operations.windows(3).enumerate() {
        if let ["add", "relu", "mul"] = window {
            patterns.push((i, FusedOp::AddReluMul));
        }
    }
    patterns
}