/// A single fused-operator pattern that can appear in an [`OpSequence`].
///
/// The enum carries no data, so it derives `Copy`, `Eq` and `Hash`. It can be
/// passed by value, compared, and used as a key in hash maps and sets.
///
/// NOTE(review): variant names appear to list the fused primitives
/// (e.g. `ReluAdd` = ReLU combined with an add). This is inferred from the
/// names only; confirm against the kernels that consume these values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FusedOp {
    ReluAdd,
    AddMul,
    MulAdd,
    SigmoidMul,
    TanhScale,
    AddReluMul,
    BatchNormFused,
    LayerNormFused,
}
/// An ordered list of [`FusedOp`]s considered together as one fusion candidate.
///
/// `Clone` is derived because every field is cloneable. `Default` is
/// implemented by hand and yields the same empty value as `OpSequence::new`.
#[derive(Debug, Clone)]
pub struct OpSequence {
    /// Operations in the order they were appended.
    pub operations: Vec<FusedOp>,
    /// Number of inputs to the sequence.
    /// NOTE(review): not updated when operations are appended; presumably
    /// maintained by the caller — confirm.
    pub input_count: usize,
    /// Number of outputs from the sequence.
    /// NOTE(review): not updated when operations are appended; presumably
    /// maintained by the caller — confirm.
    pub output_count: usize,
}
impl OpSequence {
    /// Largest number of operations for which [`OpSequence::is_fusible`]
    /// returns `true`.
    pub const MAX_FUSIBLE_OPS: usize = 4;

    /// Heuristic benefit credited for every operation after the first in
    /// [`OpSequence::fusion_benefit_estimate`].
    const MEMORY_BENEFIT_PER_EXTRA_OP: f32 = 0.3;

    /// Creates an empty sequence with `input_count` and `output_count` set to 0.
    pub fn new() -> Self {
        Self {
            operations: Vec::new(),
            input_count: 0,
            output_count: 0,
        }
    }

    /// Appends `op` to the end of the sequence.
    ///
    /// Does not change `input_count` or `output_count`.
    pub fn add_operation(&mut self, op: FusedOp) {
        self.operations.push(op);
    }

    /// Returns `true` when the sequence holds between 1 and
    /// [`Self::MAX_FUSIBLE_OPS`] operations, inclusive.
    #[must_use]
    pub fn is_fusible(&self) -> bool {
        (1..=Self::MAX_FUSIBLE_OPS).contains(&self.operations.len())
    }

    /// Returns a heuristic score for fusing this sequence; higher is better.
    ///
    /// The score is the sum of two parts:
    /// * a "memory" term of [`Self::MEMORY_BENEFIT_PER_EXTRA_OP`] for each
    ///   operation beyond the first, and
    /// * a per-operation "SIMD" term (see `simd_benefit`).
    ///
    /// An empty sequence scores `0.0`. The constants are fixed heuristics, not
    /// measured values.
    #[must_use]
    pub fn fusion_benefit_estimate(&self) -> f32 {
        // `checked_sub` handles the empty case and rules out usize underflow.
        let Some(extra_ops) = self.operations.len().checked_sub(1) else {
            return 0.0;
        };
        let memory_benefit = extra_ops as f32 * Self::MEMORY_BENEFIT_PER_EXTRA_OP;
        let simd_benefit: f32 = self.operations.iter().map(Self::simd_benefit).sum();
        memory_benefit + simd_benefit
    }

    /// Per-operation heuristic SIMD benefit used by `fusion_benefit_estimate`.
    ///
    /// The match is deliberately exhaustive (no `_` arm). Adding a `FusedOp`
    /// variant then fails to compile until it is given an explicit weight.
    fn simd_benefit(op: &FusedOp) -> f32 {
        match op {
            FusedOp::ReluAdd | FusedOp::AddMul | FusedOp::MulAdd => 0.2,
            FusedOp::SigmoidMul | FusedOp::TanhScale => 0.15,
            FusedOp::AddReluMul | FusedOp::BatchNormFused | FusedOp::LayerNormFused => 0.1,
        }
    }
}
impl Default for OpSequence {
fn default() -> Self {
Self::new()
}
}