use crate::error::AutogradError;
use crate::{Float, NdArray, Result};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::collections::{HashMap, HashSet};
use std::fmt;
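/// Primitive operations tracked by the JIT fusion graph.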
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitOp {
Input,
Add,
Sub,
Mul,
Div,
MatMul,
Relu,
Gelu,
Sigmoid,
Tanh,
Softmax,
Exp,
Log,
Square,
Sqrt,
Neg,
Scale,
BiasAdd,
Transpose,
ReduceSum,
ReduceMean,
}
impl JitOp {
pub fn is_elementwise(self) -> bool {
matches!(
self,
JitOp::Add
| JitOp::Sub
| JitOp::Mul
| JitOp::Div
| JitOp::Relu
| JitOp::Gelu
| JitOp::Sigmoid
| JitOp::Tanh
| JitOp::Exp
| JitOp::Log
| JitOp::Square
| JitOp::Sqrt
| JitOp::Neg
| JitOp::Scale
| JitOp::BiasAdd
)
}
pub fn is_activation(self) -> bool {
matches!(
self,
JitOp::Relu | JitOp::Gelu | JitOp::Sigmoid | JitOp::Tanh
)
}
pub fn is_unary(self) -> bool {
matches!(
self,
JitOp::Relu
| JitOp::Gelu
| JitOp::Sigmoid
| JitOp::Tanh
| JitOp::Exp
| JitOp::Log
| JitOp::Square
| JitOp::Sqrt
| JitOp::Neg
| JitOp::Scale
| JitOp::Transpose
)
}
}
impl fmt::Display for JitOp {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
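/// A node in the JIT graph: its id, op, input node ids, optional output
/// shape, and consumer count.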
#[derive(Debug, Clone)]
pub struct JitNode {
pub id: usize,
pub op: JitOp,
pub inputs: Vec<usize>,
pub shape: Option<Vec<usize>>,
pub num_consumers: usize,
}
impl JitNode {
pub fn new(id: usize, op: JitOp, inputs: Vec<usize>) -> Self {
Self {
id,
op,
inputs,
shape: None,
num_consumers: 0,
}
}
pub fn with_shape(id: usize, op: JitOp, inputs: Vec<usize>, shape: Vec<usize>) -> Self {
Self {
id,
op,
inputs,
shape: Some(shape),
num_consumers: 0,
}
}
}
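/// The kinds of fused kernels the engine can recognize.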
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FusionKindJit {
Fma,
LinearBias,
LinearBiasActivation,
FusedAttention,
ElementWiseChain,
ResidualAdd,
Affine,
ReduceMean,
}
impl fmt::Display for FusionKindJit {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
FusionKindJit::Fma => write!(f, "FMA"),
FusionKindJit::LinearBias => write!(f, "LinearBias"),
FusionKindJit::LinearBiasActivation => write!(f, "LinearBiasActivation"),
FusionKindJit::FusedAttention => write!(f, "FusedAttention"),
FusionKindJit::ElementWiseChain => write!(f, "ElementWiseChain"),
FusionKindJit::ResidualAdd => write!(f, "ResidualAdd"),
FusionKindJit::Affine => write!(f, "Affine"),
FusionKindJit::ReduceMean => write!(f, "ReduceMean"),
}
}
}
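/// A set of graph nodes that can be fused into one kernel: the fusion kind,
/// the participating node ids, the estimated benefit, and the id of the root
/// (output-producing) node.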
#[derive(Debug, Clone)]
pub struct JitFusionCandidate {
pub kind: FusionKindJit,
pub node_ids: Vec<usize>,
pub benefit: FusionBenefit,
pub root_id: usize,
}
impl JitFusionCandidate {
pub fn is_beneficial(&self) -> bool {
self.benefit.is_beneficial()
}
}
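/// Estimated payoff of a fusion: bytes of memory traffic saved, a speedup
/// factor, and the number of intermediate buffers and kernel launches
/// eliminated.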
#[derive(Debug, Clone)]
pub struct FusionBenefit {
pub memory_saved_bytes: usize,
pub speedup_estimate: f64,
pub intermediates_eliminated: usize,
pub kernel_launches_saved: usize,
}
impl FusionBenefit {
pub fn is_beneficial(&self) -> bool {
self.speedup_estimate > 1.05 || self.intermediates_eliminated > 0
}
pub fn score(&self) -> f64 {
self.speedup_estimate * (1.0 + self.intermediates_eliminated as f64 * 0.5)
}
}
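/// Configuration for the fusion search: which pattern families are enabled,
/// chain-length bounds, the minimum benefit score to keep a candidate, and
/// the assumed element size in bytes.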
#[derive(Debug, Clone)]
pub struct FusionConfig {
pub enable_fma: bool,
pub enable_linear_fusion: bool,
pub enable_attention_fusion: bool,
pub enable_elementwise_chain: bool,
pub min_chain_length: usize,
pub max_chain_length: usize,
pub min_benefit_score: f64,
pub bytes_per_element: usize,
}
impl Default for FusionConfig {
fn default() -> Self {
Self {
enable_fma: true,
enable_linear_fusion: true,
enable_attention_fusion: true,
enable_elementwise_chain: true,
min_chain_length: 2,
max_chain_length: 16,
min_benefit_score: 1.0,
bytes_per_element: 4,
}
}
}
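/// Pattern-matching engine that detects fusion opportunities in a slice of
/// [`JitNode`]s according to a [`FusionConfig`].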
pub struct JitFusionEngine {
config: FusionConfig,
}
impl JitFusionEngine {
pub fn new(config: FusionConfig) -> Self {
Self { config }
}
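/// Scans `graph` for fusion candidates: FMA (a mul feeding a single add),
/// linear layers (matmul + bias, optionally + activation), attention blocks
/// (matmul, optional scale/div, softmax, matmul), and unary element-wise
/// chains. Candidates scoring below `min_benefit_score` are dropped, the
/// rest are sorted by score, and overlapping candidates are resolved in
/// favor of the higher score.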
pub fn detect_fusions(&self, graph: &[JitNode]) -> Vec<JitFusionCandidate> {
let mut candidates = Vec::new();
let consumer_map = build_consumer_map(graph);
if self.config.enable_fma {
self.detect_fma(graph, &consumer_map, &mut candidates);
}
if self.config.enable_linear_fusion {
self.detect_linear_bias_activation(graph, &consumer_map, &mut candidates);
}
if self.config.enable_attention_fusion {
self.detect_fused_attention(graph, &consumer_map, &mut candidates);
}
if self.config.enable_elementwise_chain {
self.detect_elementwise_chains(graph, &consumer_map, &mut candidates);
}
candidates.retain(|c| c.benefit.score() >= self.config.min_benefit_score);
candidates.sort_by(|a, b| {
b.benefit
.score()
.partial_cmp(&a.benefit.score())
.unwrap_or(std::cmp::Ordering::Equal)
});
remove_overlapping(&mut candidates);
candidates
}
fn detect_fma(
&self,
graph: &[JitNode],
consumer_map: &HashMap<usize, Vec<usize>>,
candidates: &mut Vec<JitFusionCandidate>,
) {
let node_map: HashMap<usize, &JitNode> = graph.iter().map(|n| (n.id, n)).collect();
for node in graph {
if node.op == JitOp::Add && node.inputs.len() == 2 {
for (mul_idx, _) in [(0, 1), (1, 0)] {
let mul_input_id = node.inputs[mul_idx];
if let Some(mul_node) = node_map.get(&mul_input_id) {
if mul_node.op == JitOp::Mul && mul_node.inputs.len() == 2 {
let consumers = consumer_map.get(&mul_input_id);
let single_consumer = consumers.map_or(false, |c| c.len() <= 1);
if single_consumer {
let shape_elements = node
.shape
.as_ref()
.map(|s| s.iter().product::<usize>())
.unwrap_or(1024);
let benefit = FusionBenefit {
memory_saved_bytes: shape_elements
* self.config.bytes_per_element,
speedup_estimate: 1.3,
intermediates_eliminated: 1,
kernel_launches_saved: 1,
};
candidates.push(JitFusionCandidate {
kind: FusionKindJit::Fma,
node_ids: vec![mul_node.id, node.id],
benefit,
root_id: node.id,
});
}
}
}
}
}
}
}
fn detect_linear_bias_activation(
&self,
graph: &[JitNode],
consumer_map: &HashMap<usize, Vec<usize>>,
candidates: &mut Vec<JitFusionCandidate>,
) {
let node_map: HashMap<usize, &JitNode> = graph.iter().map(|n| (n.id, n)).collect();
for node in graph {
if node.op != JitOp::MatMul {
continue;
}
let matmul_id = node.id;
let matmul_consumers = consumer_map.get(&matmul_id);
if let Some(consumers) = matmul_consumers {
if consumers.len() != 1 {
continue;
}
let bias_add_id = consumers[0];
if let Some(bias_node) = node_map.get(&bias_add_id) {
if bias_node.op != JitOp::Add && bias_node.op != JitOp::BiasAdd {
continue;
}
let bias_consumers = consumer_map.get(&bias_add_id);
if let Some(bias_cons) = bias_consumers {
if bias_cons.len() == 1 {
if let Some(act_node) = node_map.get(&bias_cons[0]) {
if act_node.op.is_activation() {
let shape_elements = node
.shape
.as_ref()
.map(|s| s.iter().product::<usize>())
.unwrap_or(4096);
let benefit = FusionBenefit {
memory_saved_bytes: shape_elements
* self.config.bytes_per_element
* 2,
speedup_estimate: 1.5,
intermediates_eliminated: 2,
kernel_launches_saved: 2,
};
candidates.push(JitFusionCandidate {
kind: FusionKindJit::LinearBiasActivation,
node_ids: vec![matmul_id, bias_add_id, act_node.id],
benefit,
root_id: act_node.id,
});
continue;
}
}
}
}
let shape_elements = node
.shape
.as_ref()
.map(|s| s.iter().product::<usize>())
.unwrap_or(4096);
let benefit = FusionBenefit {
memory_saved_bytes: shape_elements * self.config.bytes_per_element,
speedup_estimate: 1.2,
intermediates_eliminated: 1,
kernel_launches_saved: 1,
};
candidates.push(JitFusionCandidate {
kind: FusionKindJit::LinearBias,
node_ids: vec![matmul_id, bias_add_id],
benefit,
root_id: bias_add_id,
});
}
}
}
}
fn detect_fused_attention(
&self,
graph: &[JitNode],
consumer_map: &HashMap<usize, Vec<usize>>,
candidates: &mut Vec<JitFusionCandidate>,
) {
let node_map: HashMap<usize, &JitNode> = graph.iter().map(|n| (n.id, n)).collect();
for node in graph {
if node.op != JitOp::Softmax {
continue;
}
let softmax_id = node.id;
if node.inputs.is_empty() {
continue;
}
let pre_softmax_id = node.inputs[0];
let pre_softmax_node = match node_map.get(&pre_softmax_id) {
Some(n) => n,
None => continue,
};
let (scale_or_div, qk_matmul_id) =
if pre_softmax_node.op == JitOp::Scale || pre_softmax_node.op == JitOp::Div {
if pre_softmax_node.inputs.is_empty() {
continue;
}
(true, pre_softmax_node.inputs[0])
} else if pre_softmax_node.op == JitOp::MatMul {
(false, pre_softmax_id)
} else {
continue;
};
let qk_node = match node_map.get(&qk_matmul_id) {
Some(n) if n.op == JitOp::MatMul => n,
_ => continue,
};
let softmax_consumers = consumer_map.get(&softmax_id);
if let Some(consumers) = softmax_consumers {
for &consumer_id in consumers {
if let Some(final_matmul) = node_map.get(&consumer_id) {
if final_matmul.op == JitOp::MatMul {
let mut node_ids = vec![qk_matmul_id];
if scale_or_div {
node_ids.push(pre_softmax_id);
}
node_ids.push(softmax_id);
node_ids.push(consumer_id);
let shape_elements = qk_node
.shape
.as_ref()
.map(|s| s.iter().product::<usize>())
.unwrap_or(65536);
let benefit = FusionBenefit {
memory_saved_bytes: shape_elements
* self.config.bytes_per_element
* 3,
speedup_estimate: 2.0,
intermediates_eliminated: 3,
kernel_launches_saved: 3,
};
candidates.push(JitFusionCandidate {
kind: FusionKindJit::FusedAttention,
node_ids,
benefit,
root_id: consumer_id,
});
}
}
}
}
}
}
fn detect_elementwise_chains(
&self,
graph: &[JitNode],
consumer_map: &HashMap<usize, Vec<usize>>,
candidates: &mut Vec<JitFusionCandidate>,
) {
let node_map: HashMap<usize, &JitNode> = graph.iter().map(|n| (n.id, n)).collect();
let mut visited = HashSet::new();
for node in graph {
// Chains may only start at unary element-wise ops; binary ops such as Add
// and Mul (and non-element-wise ops) cannot be folded by the chain executor.
if !(node.op.is_elementwise() && node.op.is_unary()) {
continue;
}
if visited.contains(&node.id) {
continue;
}
let mut chain = vec![node.id];
let mut current = node;
loop {
let consumers = consumer_map.get(&current.id);
if let Some(cons) = consumers {
if cons.len() == 1 {
if let Some(&next) = node_map.get(&cons[0]) {
if next.op.is_elementwise()
&& next.op.is_unary()
&& chain.len() < self.config.max_chain_length
{
chain.push(next.id);
current = next;
continue;
}
}
}
}
break;
}
if chain.len() >= self.config.min_chain_length {
for &id in &chain {
visited.insert(id);
}
let shape_elements = node
.shape
.as_ref()
.map(|s| s.iter().product::<usize>())
.unwrap_or(1024);
let benefit = FusionBenefit {
memory_saved_bytes: shape_elements
* self.config.bytes_per_element
* (chain.len() - 1),
speedup_estimate: 1.0 + (chain.len() - 1) as f64 * 0.15,
intermediates_eliminated: chain.len() - 1,
kernel_launches_saved: chain.len() - 1,
};
let root_id = *chain.last().unwrap_or(&node.id);
candidates.push(JitFusionCandidate {
kind: FusionKindJit::ElementWiseChain,
node_ids: chain,
benefit,
root_id,
});
}
}
}
pub fn config(&self) -> &FusionConfig {
&self.config
}
}
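/// A fused op lowered from a detected candidate: the fusion kind, the op
/// sequence in candidate order, and a generated name. Includes simple CPU
/// executors for the supported patterns.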
#[derive(Debug, Clone)]
pub struct JitCompiledOp {
pub kind: FusionKindJit,
pub ops: Vec<JitOp>,
pub name: String,
}
impl JitCompiledOp {
pub fn from_candidate(candidate: &JitFusionCandidate, graph: &[JitNode]) -> Self {
let node_map: HashMap<usize, &JitNode> = graph.iter().map(|n| (n.id, n)).collect();
let ops: Vec<JitOp> = candidate
.node_ids
.iter()
.filter_map(|id| node_map.get(id).map(|n| n.op))
.collect();
let name = format!("fused_{}", candidate.kind);
Self {
kind: candidate.kind,
ops,
name,
}
}
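/// Fused multiply-add: computes `a * b + c` element-wise. A length-1 `c` is
/// broadcast as a scalar over `a * b`; any other shape mismatch is an error.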
pub fn execute_fma<F: Float>(
a: &NdArray<F>,
b: &NdArray<F>,
c: &NdArray<F>,
) -> Result<NdArray<F>> {
if a.shape() != b.shape() || a.shape() != c.shape() {
if c.len() == 1 {
let cv = *c
.iter()
.next()
.ok_or_else(|| AutogradError::compute_error("Empty scalar".into()))?;
let mut result = a.clone();
result.zip_mut_with(b, |av, bv| {
*av = *av * *bv + cv;
});
return Ok(result);
}
// Non-scalar shape mismatch: bail out before the element-wise loop below
// would index out of bounds.
return Err(AutogradError::ShapeMismatch(format!(
"FMA shape mismatch: {:?} vs {:?} vs {:?}",
a.shape(),
b.shape(),
c.shape()
)));
}
let mut result = a.clone();
let b_slice = b
.as_slice()
.ok_or_else(|| AutogradError::compute_error("FMA: non-contiguous tensor b".into()))?;
let c_slice = c
.as_slice()
.ok_or_else(|| AutogradError::compute_error("FMA: non-contiguous tensor c".into()))?;
let result_slice = result
.as_slice_mut()
.ok_or_else(|| AutogradError::compute_error("FMA: non-contiguous result".into()))?;
for i in 0..result_slice.len() {
result_slice[i] = result_slice[i] * b_slice[i] + c_slice[i];
}
Ok(result)
}
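/// Fused linear layer: computes `x (m x k) @ w (k x n) + bias`, applying the
/// optional activation to each output element in the same pass. Implemented
/// as a naive row-major triple loop.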
pub fn execute_linear_bias_activation<F: Float>(
x: &NdArray<F>,
w: &NdArray<F>,
bias: &NdArray<F>,
activation: Option<JitOp>,
) -> Result<NdArray<F>> {
let x_shape = x.shape();
let w_shape = w.shape();
// The naive kernel below assumes exactly 2-D, row-major contiguous layouts.
if x_shape.len() != 2 || w_shape.len() != 2 {
return Err(AutogradError::ShapeMismatch(format!(
"Linear requires 2-D inputs, got {}-D x and {}-D w",
x_shape.len(),
w_shape.len()
)));
}
let m = x_shape[0];
let k = x_shape[1];
let n = w_shape[1];
if k != w_shape[0] {
return Err(AutogradError::ShapeMismatch(format!(
"Linear inner dim mismatch: {} vs {}",
k, w_shape[0]
)));
}
let x_slice = x
.as_slice()
.ok_or_else(|| AutogradError::compute_error("Non-contiguous input x".into()))?;
let w_slice = w
.as_slice()
.ok_or_else(|| AutogradError::compute_error("Non-contiguous weight w".into()))?;
let bias_flat: Vec<F> = bias.iter().copied().collect();
if bias_flat.len() != n {
return Err(AutogradError::ShapeMismatch(format!(
"Bias length {} != out_features {}",
bias_flat.len(),
n
)));
}
let mut result = ArrayD::<F>::zeros(IxDyn(&[m, n]));
let result_slice = result
.as_slice_mut()
.ok_or_else(|| AutogradError::compute_error("Non-contiguous result".into()))?;
for i in 0..m {
for j in 0..n {
let mut acc = F::zero();
for p in 0..k {
acc = acc + x_slice[i * k + p] * w_slice[p * n + j];
}
acc = acc + bias_flat[j];
acc = match activation {
Some(JitOp::Relu) => {
if acc > F::zero() {
acc
} else {
F::zero()
}
}
Some(JitOp::Sigmoid) => F::one() / (F::one() + (-acc).exp()),
Some(JitOp::Tanh) => acc.tanh(),
Some(JitOp::Gelu) => {
let sqrt_2_pi = F::from(0.7978845608028654).unwrap_or(F::one());
let coeff = F::from(0.044715).unwrap_or(F::zero());
let half = F::from(0.5).unwrap_or(F::one());
let inner = sqrt_2_pi * (acc + coeff * acc * acc * acc);
half * acc * (F::one() + inner.tanh())
}
_ => acc,
};
result_slice[i * n + j] = acc;
}
}
Ok(result)
}
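/// Applies a sequence of unary element-wise ops to `input` in order; ops the
/// executor does not handle are passed through unchanged.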
pub fn execute_elementwise_chain<F: Float>(
input: &NdArray<F>,
ops: &[JitOp],
) -> Result<NdArray<F>> {
let mut current = input.clone();
for op in ops {
current = match op {
JitOp::Relu => current.mapv(|v| if v > F::zero() { v } else { F::zero() }),
JitOp::Sigmoid => current.mapv(|v| F::one() / (F::one() + (-v).exp())),
JitOp::Tanh => current.mapv(|v| v.tanh()),
JitOp::Gelu => {
let sqrt_2_pi = F::from(0.7978845608028654).unwrap_or(F::one());
let coeff = F::from(0.044715).unwrap_or(F::zero());
let half = F::from(0.5).unwrap_or(F::one());
current.mapv(|x| {
let inner = sqrt_2_pi * (x + coeff * x * x * x);
half * x * (F::one() + inner.tanh())
})
}
JitOp::Exp => current.mapv(|v| v.exp()),
JitOp::Log => current.mapv(|v| v.ln()),
JitOp::Neg => current.mapv(|v| -v),
JitOp::Square => current.mapv(|v| v * v),
JitOp::Sqrt => current.mapv(|v| v.sqrt()),
_ => current,
};
}
Ok(current)
}
}
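/// Standalone estimator of fusion benefits from element counts, parameterized
/// by element size, memory bandwidth, and compute throughput.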
pub struct FusionBenefitEstimator {
bytes_per_element: usize,
memory_bandwidth: f64,
compute_throughput: f64,
}
impl FusionBenefitEstimator {
pub fn new() -> Self {
Self {
bytes_per_element: 4,
memory_bandwidth: 50.0e9,
compute_throughput: 1.0e12,
}
}
pub fn with_params(
bytes_per_element: usize,
memory_bandwidth: f64,
compute_throughput: f64,
) -> Self {
Self {
bytes_per_element,
memory_bandwidth,
compute_throughput,
}
}
pub fn elementwise_chain_savings(
&self,
num_elements: usize,
chain_length: usize,
) -> FusionBenefit {
let unfused_bytes = num_elements * self.bytes_per_element * 2 * chain_length;
let fused_bytes = num_elements * self.bytes_per_element * 2;
let saved = unfused_bytes.saturating_sub(fused_bytes);
let speedup = if fused_bytes > 0 {
unfused_bytes as f64 / fused_bytes as f64
} else {
1.0
};
FusionBenefit {
memory_saved_bytes: saved,
speedup_estimate: speedup.min(chain_length as f64),
intermediates_eliminated: chain_length - 1,
kernel_launches_saved: chain_length - 1,
}
}
pub fn linear_bias_savings(&self, m: usize, n: usize) -> FusionBenefit {
let output_bytes = m * n * self.bytes_per_element;
FusionBenefit {
memory_saved_bytes: output_bytes,
speedup_estimate: 1.2,
intermediates_eliminated: 1,
kernel_launches_saved: 1,
}
}
pub fn linear_bias_activation_savings(&self, m: usize, n: usize) -> FusionBenefit {
let output_bytes = m * n * self.bytes_per_element;
FusionBenefit {
memory_saved_bytes: output_bytes * 2,
speedup_estimate: 1.4,
intermediates_eliminated: 2,
kernel_launches_saved: 2,
}
}
pub fn attention_savings(&self, seq_len: usize, d_model: usize) -> FusionBenefit {
let qk_size = seq_len * seq_len * self.bytes_per_element;
let attn_size = seq_len * d_model * self.bytes_per_element;
FusionBenefit {
memory_saved_bytes: qk_size * 2 + attn_size,
speedup_estimate: 2.0,
intermediates_eliminated: 3,
kernel_launches_saved: 3,
}
}
pub fn fma_savings(&self, num_elements: usize) -> FusionBenefit {
let bytes = num_elements * self.bytes_per_element;
FusionBenefit {
memory_saved_bytes: bytes,
speedup_estimate: 1.3,
intermediates_eliminated: 1,
kernel_launches_saved: 1,
}
}
}
impl Default for FusionBenefitEstimator {
fn default() -> Self {
Self::new()
}
}
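/// A declarative fusion pattern: an exact sequence of ops mapped to a fusion
/// kind and a nominal speedup estimate.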
#[derive(Debug, Clone)]
pub struct FusionRule {
pub name: String,
pub pattern: Vec<JitOp>,
pub kind: FusionKindJit,
pub speedup: f64,
}
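/// A collection of [`FusionRule`]s; `match_rule` finds a rule whose pattern
/// matches an op sequence exactly.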
pub struct PatternRegistry {
rules: Vec<FusionRule>,
}
impl PatternRegistry {
pub fn new() -> Self {
Self { rules: Vec::new() }
}
pub fn with_defaults() -> Self {
let mut reg = Self::new();
reg.add_rule(FusionRule {
name: "fma".into(),
pattern: vec![JitOp::Add, JitOp::Mul],
kind: FusionKindJit::Fma,
speedup: 1.3,
});
reg.add_rule(FusionRule {
name: "linear_bias".into(),
pattern: vec![JitOp::Add, JitOp::MatMul],
kind: FusionKindJit::LinearBias,
speedup: 1.2,
});
reg.add_rule(FusionRule {
name: "linear_bias_relu".into(),
pattern: vec![JitOp::Relu, JitOp::Add, JitOp::MatMul],
kind: FusionKindJit::LinearBiasActivation,
speedup: 1.5,
});
reg.add_rule(FusionRule {
name: "linear_bias_gelu".into(),
pattern: vec![JitOp::Gelu, JitOp::Add, JitOp::MatMul],
kind: FusionKindJit::LinearBiasActivation,
speedup: 1.5,
});
reg
}
pub fn add_rule(&mut self, rule: FusionRule) {
self.rules.push(rule);
}
pub fn rules(&self) -> &[FusionRule] {
&self.rules
}
pub fn len(&self) -> usize {
self.rules.len()
}
pub fn is_empty(&self) -> bool {
self.rules.is_empty()
}
pub fn match_rule(&self, ops: &[JitOp]) -> Option<&FusionRule> {
self.rules.iter().find(|rule| {
if rule.pattern.len() != ops.len() {
return false;
}
rule.pattern.iter().zip(ops).all(|(r, o)| r == o)
})
}
}
impl Default for PatternRegistry {
fn default() -> Self {
Self::with_defaults()
}
}
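/// Maps each node id to the ids of the nodes that consume it as an input.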
fn build_consumer_map(graph: &[JitNode]) -> HashMap<usize, Vec<usize>> {
let mut map: HashMap<usize, Vec<usize>> = HashMap::new();
for node in graph {
for &input_id in &node.inputs {
map.entry(input_id).or_default().push(node.id);
}
}
map
}
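/// Drops any candidate that shares a node with an earlier candidate in the
/// list; callers sort by score first so higher-scoring candidates win.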
fn remove_overlapping(candidates: &mut Vec<JitFusionCandidate>) {
let mut used_nodes = HashSet::new();
candidates.retain(|c| {
let overlaps = c.node_ids.iter().any(|id| used_nodes.contains(id));
if overlaps {
false
} else {
for &id in &c.node_ids {
used_nodes.insert(id);
}
true
}
});
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array1;
#[test]
fn test_jit_op_is_elementwise() {
assert!(JitOp::Add.is_elementwise());
assert!(JitOp::Relu.is_elementwise());
assert!(!JitOp::MatMul.is_elementwise());
assert!(!JitOp::Softmax.is_elementwise());
}
#[test]
fn test_jit_op_is_activation() {
assert!(JitOp::Relu.is_activation());
assert!(JitOp::Gelu.is_activation());
assert!(!JitOp::Add.is_activation());
}
#[test]
fn test_jit_op_display() {
assert_eq!(format!("{}", JitOp::MatMul), "MatMul");
}
#[test]
fn test_detect_fma() {
let engine = JitFusionEngine::new(FusionConfig::default());
let graph = vec![
JitNode::new(0, JitOp::Input, vec![]),
JitNode::new(1, JitOp::Input, vec![]),
JitNode::new(2, JitOp::Mul, vec![0, 1]),
JitNode::new(3, JitOp::Input, vec![]),
JitNode::new(4, JitOp::Add, vec![2, 3]),
];
let fusions = engine.detect_fusions(&graph);
assert!(fusions.iter().any(|f| f.kind == FusionKindJit::Fma));
}
#[test]
fn test_fma_not_detected_when_mul_has_multiple_consumers() {
let engine = JitFusionEngine::new(FusionConfig::default());
let graph = vec![
JitNode::new(0, JitOp::Input, vec![]),
JitNode::new(1, JitOp::Input, vec![]),
JitNode::new(2, JitOp::Mul, vec![0, 1]),
JitNode::new(3, JitOp::Input, vec![]),
JitNode::new(4, JitOp::Add, vec![2, 3]),
JitNode::new(5, JitOp::Neg, vec![2]),
];
let fusions = engine.detect_fusions(&graph);
assert!(!fusions.iter().any(|f| f.kind == FusionKindJit::Fma));
}
#[test]
fn test_detect_linear_bias() {
let engine = JitFusionEngine::new(FusionConfig::default());
let graph = vec![
JitNode::new(0, JitOp::Input, vec![]),
JitNode::new(1, JitOp::Input, vec![]),
JitNode::new(2, JitOp::MatMul, vec![0, 1]),
JitNode::new(3, JitOp::Input, vec![]),
JitNode::new(4, JitOp::Add, vec![2, 3]),
];
let fusions = engine.detect_fusions(&graph);
let has_linear = fusions.iter().any(|f| f.kind == FusionKindJit::LinearBias);
assert!(has_linear);
}
#[test]
fn test_detect_linear_bias_activation() {
let engine = JitFusionEngine::new(FusionConfig::default());
let graph = vec![
JitNode::new(0, JitOp::Input, vec![]),
JitNode::new(1, JitOp::Input, vec![]),
JitNode::new(2, JitOp::MatMul, vec![0, 1]),
JitNode::new(3, JitOp::Input, vec![]),
JitNode::new(4, JitOp::Add, vec![2, 3]),
JitNode::new(5, JitOp::Relu, vec![4]),
];
let fusions = engine.detect_fusions(&graph);
let has_lba = fusions
.iter()
.any(|f| f.kind == FusionKindJit::LinearBiasActivation);
assert!(has_lba);
}
#[test]
fn test_detect_fused_attention() {
let engine = JitFusionEngine::new(FusionConfig::default());
let graph = vec![
JitNode::new(0, JitOp::Input, vec![]),
JitNode::new(1, JitOp::Input, vec![]),
JitNode::new(2, JitOp::MatMul, vec![0, 1]),
JitNode::new(3, JitOp::Scale, vec![2]),
JitNode::new(4, JitOp::Softmax, vec![3]),
JitNode::new(5, JitOp::Input, vec![]),
JitNode::new(6, JitOp::MatMul, vec![4, 5]),
];
let fusions = engine.detect_fusions(&graph);
let has_attn = fusions
.iter()
.any(|f| f.kind == FusionKindJit::FusedAttention);
assert!(has_attn);
}
#[test]
fn test_detect_elementwise_chain() {
let engine = JitFusionEngine::new(FusionConfig::default());
let graph = vec![
JitNode::new(0, JitOp::Input, vec![]),
JitNode::new(1, JitOp::Relu, vec![0]),
JitNode::new(2, JitOp::Sigmoid, vec![1]),
JitNode::new(3, JitOp::Neg, vec![2]),
];
let fusions = engine.detect_fusions(&graph);
let has_chain = fusions
.iter()
.any(|f| f.kind == FusionKindJit::ElementWiseChain);
assert!(has_chain);
}
#[test]
fn test_no_chain_for_single_op() {
let engine = JitFusionEngine::new(FusionConfig::default());
let graph = vec![
JitNode::new(0, JitOp::Input, vec![]),
JitNode::new(1, JitOp::Relu, vec![0]),
];
let fusions = engine.detect_fusions(&graph);
let has_chain = fusions
.iter()
.any(|f| f.kind == FusionKindJit::ElementWiseChain);
assert!(!has_chain);
}
#[test]
fn test_execute_fma() {
let a = Array1::from(vec![2.0_f64, 3.0, 4.0]).into_dyn();
let b = Array1::from(vec![5.0_f64, 6.0, 7.0]).into_dyn();
let c = Array1::from(vec![1.0_f64, 1.0, 1.0]).into_dyn();
let result = JitCompiledOp::execute_fma(&a, &b, &c).expect("fma");
assert!((result[[0]] - 11.0).abs() < 1e-10);
assert!((result[[1]] - 19.0).abs() < 1e-10);
assert!((result[[2]] - 29.0).abs() < 1e-10);
}
#[test]
fn test_execute_linear_bias() {
let x = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0_f64, 0.0, 0.0, 0.0, 1.0, 0.0])
.expect("x");
let w =
ArrayD::from_shape_vec(IxDyn(&[3, 2]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("w");
let bias = ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.1, 0.2]).expect("bias");
let result =
JitCompiledOp::execute_linear_bias_activation(&x, &w, &bias, None).expect("linear");
assert!((result[IxDyn(&[0, 0])] - 1.1).abs() < 1e-10);
assert!((result[IxDyn(&[0, 1])] - 2.2).abs() < 1e-10);
}
#[test]
fn test_execute_linear_bias_relu() {
let x = ArrayD::from_shape_vec(IxDyn(&[1, 2]), vec![-1.0_f64, 1.0]).expect("x");
let w = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1.0, 0.0, 0.0, 1.0]).expect("w");
let bias = ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.0, 0.0]).expect("bias");
let result =
JitCompiledOp::execute_linear_bias_activation(&x, &w, &bias, Some(JitOp::Relu))
.expect("linear+relu");
assert!((result[IxDyn(&[0, 0])]).abs() < 1e-10);
assert!((result[IxDyn(&[0, 1])] - 1.0).abs() < 1e-10);
}
#[test]
fn test_execute_elementwise_chain() {
let input = Array1::from(vec![-2.0_f64, -1.0, 0.0, 1.0, 2.0]).into_dyn();
let ops = vec![JitOp::Relu, JitOp::Square];
let result = JitCompiledOp::execute_elementwise_chain(&input, &ops).expect("chain");
assert!((result[[0]]).abs() < 1e-10);
assert!((result[[3]] - 1.0).abs() < 1e-10);
assert!((result[[4]] - 4.0).abs() < 1e-10);
}
#[test]
fn test_estimator_elementwise_chain() {
let estimator = FusionBenefitEstimator::new();
let benefit = estimator.elementwise_chain_savings(1000, 5);
assert!(benefit.memory_saved_bytes > 0);
assert!(benefit.speedup_estimate > 1.0);
assert_eq!(benefit.intermediates_eliminated, 4);
assert_eq!(benefit.kernel_launches_saved, 4);
}
#[test]
fn test_estimator_linear_bias() {
let estimator = FusionBenefitEstimator::new();
let benefit = estimator.linear_bias_savings(32, 64);
assert!(benefit.memory_saved_bytes > 0);
}
#[test]
fn test_estimator_attention() {
let estimator = FusionBenefitEstimator::new();
let benefit = estimator.attention_savings(128, 512);
assert!(benefit.speedup_estimate >= 2.0);
assert_eq!(benefit.intermediates_eliminated, 3);
}
#[test]
fn test_estimator_fma() {
let estimator = FusionBenefitEstimator::new();
let benefit = estimator.fma_savings(1024);
assert!(benefit.memory_saved_bytes > 0);
}
#[test]
fn test_pattern_registry_defaults() {
let reg = PatternRegistry::with_defaults();
assert!(reg.len() >= 4);
}
#[test]
fn test_pattern_registry_match() {
let reg = PatternRegistry::with_defaults();
let matched = reg.match_rule(&[JitOp::Add, JitOp::Mul]);
assert!(matched.is_some());
assert_eq!(matched.map(|r| r.kind), Some(FusionKindJit::Fma));
}
#[test]
fn test_pattern_registry_no_match() {
let reg = PatternRegistry::with_defaults();
let matched = reg.match_rule(&[JitOp::Softmax, JitOp::Sqrt]);
assert!(matched.is_none());
}
#[test]
fn test_pattern_registry_add_custom() {
let mut reg = PatternRegistry::new();
assert!(reg.is_empty());
reg.add_rule(FusionRule {
name: "custom".into(),
pattern: vec![JitOp::Neg, JitOp::Exp],
kind: FusionKindJit::ElementWiseChain,
speedup: 1.1,
});
assert_eq!(reg.len(), 1);
}
#[test]
fn test_benefit_is_beneficial() {
let benefit = FusionBenefit {
memory_saved_bytes: 1000,
speedup_estimate: 1.1,
intermediates_eliminated: 1,
kernel_launches_saved: 1,
};
assert!(benefit.is_beneficial());
}
#[test]
fn test_benefit_not_beneficial() {
let benefit = FusionBenefit {
memory_saved_bytes: 0,
speedup_estimate: 1.0,
intermediates_eliminated: 0,
kernel_launches_saved: 0,
};
assert!(!benefit.is_beneficial());
}
#[test]
fn test_benefit_score() {
let benefit = FusionBenefit {
memory_saved_bytes: 100,
speedup_estimate: 2.0,
intermediates_eliminated: 2,
kernel_launches_saved: 2,
};
assert!((benefit.score() - 4.0).abs() < 1e-10);
}
#[test]
fn test_default_config() {
let config = FusionConfig::default();
assert!(config.enable_fma);
assert!(config.enable_linear_fusion);
assert!(config.enable_attention_fusion);
assert!(config.enable_elementwise_chain);
assert_eq!(config.min_chain_length, 2);
}
#[test]
fn test_compiled_op_from_candidate() {
let graph = vec![
JitNode::new(0, JitOp::Mul, vec![]),
JitNode::new(1, JitOp::Add, vec![0]),
];
let candidate = JitFusionCandidate {
kind: FusionKindJit::Fma,
node_ids: vec![0, 1],
benefit: FusionBenefit {
memory_saved_bytes: 100,
speedup_estimate: 1.3,
intermediates_eliminated: 1,
kernel_launches_saved: 1,
},
root_id: 1,
};
let compiled = JitCompiledOp::from_candidate(&candidate, &graph);
assert_eq!(compiled.kind, FusionKindJit::Fma);
assert_eq!(compiled.ops.len(), 2);
}
#[test]
fn test_overlapping_removal() {
let engine = JitFusionEngine::new(FusionConfig::default());
let graph = vec![
JitNode::new(0, JitOp::Input, vec![]),
JitNode::new(1, JitOp::Relu, vec![0]),
JitNode::new(2, JitOp::Sigmoid, vec![1]),
JitNode::new(3, JitOp::Neg, vec![2]),
];
let fusions = engine.detect_fusions(&graph);
let chain_count = fusions
.iter()
.filter(|f| f.kind == FusionKindJit::ElementWiseChain)
.count();
assert!(chain_count <= 1);
}
#[test]
fn test_linear_shape_mismatch() {
let x = ArrayD::from_shape_vec(IxDyn(&[2, 3]), vec![1.0_f64; 6]).expect("x");
let w = ArrayD::from_shape_vec(IxDyn(&[4, 2]), vec![1.0_f64; 8]).expect("w");
let bias = ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.0_f64; 2]).expect("bias");
let result = JitCompiledOp::execute_linear_bias_activation(&x, &w, &bias, None);
assert!(result.is_err());
}
#[test]
fn test_fma_scalar_broadcast() {
let a = Array1::from(vec![2.0_f64, 3.0]).into_dyn();
let b = Array1::from(vec![4.0_f64, 5.0]).into_dyn();
let c = ArrayD::from_elem(IxDyn(&[1]), 10.0_f64);
let result = JitCompiledOp::execute_fma(&a, &b, &c).expect("fma broadcast");
assert!((result[[0]] - 18.0).abs() < 1e-10);
assert!((result[[1]] - 25.0).abs() < 1e-10);
}
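#[test]
fn test_chain_detect_compile_execute_roundtrip() {
// End-to-end sketch using the module's own API: detect a unary element-wise
// chain, lower the candidate into a JitCompiledOp, and run the chain executor.
let engine = JitFusionEngine::new(FusionConfig::default());
let graph = vec![
JitNode::new(0, JitOp::Input, vec![]),
JitNode::new(1, JitOp::Relu, vec![0]),
JitNode::new(2, JitOp::Square, vec![1]),
];
let fusions = engine.detect_fusions(&graph);
let chain = fusions
.iter()
.find(|f| f.kind == FusionKindJit::ElementWiseChain)
.expect("chain candidate");
let compiled = JitCompiledOp::from_candidate(chain, &graph);
assert_eq!(compiled.ops, vec![JitOp::Relu, JitOp::Square]);
let input = Array1::from(vec![-2.0_f64, 3.0]).into_dyn();
let result =
JitCompiledOp::execute_elementwise_chain(&input, &compiled.ops).expect("chain exec");
assert!((result[[0]]).abs() < 1e-10);
assert!((result[[1]] - 9.0).abs() < 1e-10);
}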
}