runmat_accelerate/fusion.rs

use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, OnceLock, RwLock, Weak};

use once_cell::sync::Lazy;
use runmat_accelerate_api::{CovNormalization, ReductionFlavor};
use runmat_builtins::Value;
use serde::{Deserialize, Serialize};

use crate::graph::{
    AccelGraph, AccelNode, AccelNodeLabel, AccelOpCategory, InstrSpan, NodeId, PrimitiveOp,
    ShapeInfo, ValueId, ValueInfo, ValueOrigin,
};
use crate::reduction_meta::{detect_reduction_signature, ReductionAxes, ReductionBehavior};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum FusionKind {
    ElementwiseChain,
    Reduction,
    MatmulEpilogue,
    CenteredGram,
    ImageNormalize,
    PowerStepNormalize,
    ExplainedVariance,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionGroup {
    pub id: usize,
    pub kind: FusionKind,
    pub nodes: Vec<NodeId>,
    pub shape: ShapeInfo,
    pub span: InstrSpan,
    pub pattern: Option<FusionPattern>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionPattern {
    CenteredGram {
        matrix: ValueId,
        normalization: CovNormalization,
    },
    ImageNormalize(ImageNormalizePattern),
    PowerStepNormalize {
        lhs: ValueId,
        rhs: ValueId,
        epsilon: f64,
    },
    ExplainedVariance {
        q: ValueId,
        g: ValueId,
    },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageNormalizePattern {
    pub input: ValueId,
    pub epsilon: ImageScalar,
    pub gain: Option<ImageScalar>,
    pub bias: Option<ImageScalar>,
    pub gamma: Option<ImageScalar>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImageScalar {
    Constant(f64),
    Value(ValueId),
}

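/// Walks the acceleration graph and groups fusible nodes.
///
/// Pattern-specific passes run first (image normalize, explained variance,
/// power-step normalize, centered Gram); elementwise chains, reduction
/// singletons, and matmul epilogues are then collected from whatever remains
/// unassigned. A minimal usage sketch (the `graph` value is hypothetical; the
/// graph builder lives outside this file):
///
/// ```ignore
/// let groups = detect_fusion_groups(&graph);
/// for group in &groups {
///     println!("group {} kind={:?} nodes={}", group.id, group.kind, group.nodes.len());
/// }
/// ```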
pub fn detect_fusion_groups(graph: &AccelGraph) -> Vec<FusionGroup> {
    if graph.nodes.is_empty() {
        return Vec::new();
    }

    let consumer_map = build_consumer_map(graph);
    let mut assigned: HashSet<NodeId> = HashSet::new();
    let mut groups = Vec::new();
    let mut group_id = 0usize;

    detect_image_normalize(graph, &mut assigned, &mut groups, &mut group_id);
    detect_explained_variance(graph, &mut assigned, &mut groups, &mut group_id);
    detect_power_step_normalize(graph, &mut assigned, &mut groups, &mut group_id);
    detect_centered_gram(graph, &mut assigned, &mut groups, &mut group_id);

    // Elementwise chains
    for node in &graph.nodes {
        if assigned.contains(&node.id) {
            continue;
        }
        let elementwise_like = node.is_elementwise() || is_elementwise_max_min(graph, node);
        if !elementwise_like {
            continue;
        }
        if node.outputs.is_empty() {
            continue;
        }
        let mut current_shape = node_output_shape(graph, node);
        if matches!(current_shape, ShapeInfo::Unknown) {
            continue;
        }
        let mut chain: Vec<NodeId> = Vec::new();
        let mut frontier = node.id;
        let mut local_seen: HashSet<NodeId> = HashSet::new();

        loop {
            if !local_seen.insert(frontier) {
                break;
            }
            chain.push(frontier);
            let next = find_next_elementwise(
                graph,
                frontier,
                &assigned,
                &local_seen,
                &consumer_map,
                &current_shape,
            );
            match next {
                Some((next_id, next_shape)) => {
                    frontier = next_id;
                    current_shape = next_shape;
                }
                None => break,
            }
        }

        if chain.len() > 1 {
            expand_group_with_fanout(graph, &mut chain, &assigned, &consumer_map);
            chain.sort_unstable_by_key(|id| {
                graph
                    .node(*id)
                    .map(|node| node.span.start)
                    .unwrap_or_default()
            });
            chain.dedup();
            for id in &chain {
                assigned.insert(*id);
            }
            let span = group_span(graph, &chain);
            groups.push(FusionGroup {
                id: group_id,
                kind: FusionKind::ElementwiseChain,
                nodes: chain,
                shape: current_shape.clone(),
                span,
                pattern: None,
            });
            group_id += 1;
        }
    }

    // Reduction singletons (basic grouping; future: include eligible producers)
    for node in &graph.nodes {
        if assigned.contains(&node.id) {
            continue;
        }
        if !node.is_reduction() || is_elementwise_max_min(graph, node) {
            continue;
        }
        let span = InstrSpan {
            start: node.span.start,
            end: node.span.end,
        };
        groups.push(FusionGroup {
            id: group_id,
            kind: FusionKind::Reduction,
            nodes: vec![node.id],
            shape: node_output_shape(graph, node),
            span,
            pattern: None,
        });
        group_id += 1;
    }

    // Matmul + simple epilogue (alpha/beta/row/col scale) chains
    for node in &graph.nodes {
        if node.category != AccelOpCategory::MatMul || assigned.contains(&node.id) {
            continue;
        }
        if node.outputs.is_empty() {
            continue;
        }
        // Require exactly one consumer chain and only elementwise ops we can fold
        let mut chain: Vec<NodeId> = vec![node.id];
        let mut frontier = node.id;
        let mut ok = false;
        loop {
            // Find the single consumer of the current frontier's output
            let mut next_id_opt: Option<NodeId> = None;
            for &out in &graph.node(frontier).unwrap().outputs {
                if let Some(cons) = consumer_map.get(&out) {
                    if cons.len() == 1 {
                        next_id_opt = cons.iter().copied().next();
                    } else {
                        next_id_opt = None;
                    }
                }
            }
            let Some(next_id) = next_id_opt else { break };
            let next = graph.node(next_id).unwrap();
            if !next.is_elementwise() {
                break;
            }
            // Allow only primitive elementwise ops we can fold: add/sub/mul/div/elem variants
            let allowed = matches!(
                next.label,
                AccelNodeLabel::Primitive(PrimitiveOp::Add)
                    | AccelNodeLabel::Primitive(PrimitiveOp::Sub)
                    | AccelNodeLabel::Primitive(PrimitiveOp::Mul)
                    | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul)
                    | AccelNodeLabel::Primitive(PrimitiveOp::Div)
                    | AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv)
            );
            if !allowed {
                break;
            }
            chain.push(next_id);
            frontier = next_id;
            ok = true;
        }
        if ok {
            for id in &chain {
                assigned.insert(*id);
            }
            let span = group_span(graph, &chain);
            groups.push(FusionGroup {
                id: group_id,
                kind: FusionKind::MatmulEpilogue,
                nodes: chain,
                shape: node_output_shape(graph, node),
                span,
                pattern: None,
            });
            group_id += 1;
        }
    }

    merge_downstream_fanout(graph, &mut groups, &consumer_map);
    groups
}

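/// Grows an elementwise chain with side-branch nodes whose every consumer
/// already lives inside the group, so the fused kernel can absorb small
/// fan-in subtrees without materializing their intermediates. Runs to a
/// fixed point, since each admitted node can unlock further candidates.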
fn expand_group_with_fanout(
    graph: &AccelGraph,
    chain: &mut Vec<NodeId>,
    assigned: &HashSet<NodeId>,
    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
) {
    let base_start = chain
        .iter()
        .filter_map(|id| graph.node(*id).map(|node| node.span.start))
        .min()
        .unwrap_or(0);
    let mut node_set: HashSet<NodeId> = chain.iter().copied().collect();
    let mut changed = true;
    while changed {
        changed = false;
        for node in &graph.nodes {
            if node_set.contains(&node.id) {
                continue;
            }
            if node.span.start < base_start {
                continue;
            }
            if assigned.contains(&node.id) {
                continue;
            }
            if !(node.is_elementwise() || is_elementwise_max_min(graph, node)) {
                continue;
            }
            if node.outputs.is_empty() {
                continue;
            }
            let mut feeds_group = false;
            let mut all_consumers_ok = true;
            for &out in &node.outputs {
                if let Some(consumers) = consumer_map.get(&out) {
                    let mut consumer_in_group = false;
                    for consumer in consumers {
                        if node_set.contains(consumer) {
                            consumer_in_group = true;
                        } else {
                            all_consumers_ok = false;
                            break;
                        }
                    }
                    if !all_consumers_ok {
                        break;
                    }
                    if consumer_in_group {
                        feeds_group = true;
                    }
                } else {
                    all_consumers_ok = false;
                    break;
                }
            }
            if !feeds_group || !all_consumers_ok {
                continue;
            }
            let mut inputs_ok = true;
            for &input in &node.inputs {
                if let Some(info) = graph.value(input) {
                    if let ValueOrigin::NodeOutput { node: producer, .. } = info.origin {
                        if !node_set.contains(&producer) {
                            if let Some(prod_node) = graph.node(producer) {
                                if prod_node.span.start >= base_start {
                                    inputs_ok = false;
                                    break;
                                }
                            } else {
                                inputs_ok = false;
                                break;
                            }
                        }
                    }
                }
            }
            if inputs_ok {
                node_set.insert(node.id);
                chain.push(node.id);
                changed = true;
            }
        }
    }
}

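/// Maps each node-produced value to the set of nodes consuming it. Values
/// originating from variables or constants are deliberately excluded, so the
/// map only describes intra-graph dataflow edges.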
fn build_consumer_map(graph: &AccelGraph) -> HashMap<ValueId, HashSet<NodeId>> {
    let mut map: HashMap<ValueId, HashSet<NodeId>> = HashMap::new();
    for node in &graph.nodes {
        for &input in &node.inputs {
            if let Some(value) = graph.value(input) {
                if matches!(value.origin, ValueOrigin::NodeOutput { .. }) {
                    map.entry(input).or_default().insert(node.id);
                }
            }
        }
    }
    map
}

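/// Folds a later elementwise group into an earlier one when every output of
/// the later group is consumed within the pair, then drops the emptied source
/// groups. At most one merge happens per pass; the outer loop re-runs until a
/// fixed point is reached.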
fn merge_downstream_fanout(
    graph: &AccelGraph,
    groups: &mut Vec<FusionGroup>,
    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
) {
    let mut changed = true;
    while changed {
        changed = false;
        let mut node_group: HashMap<NodeId, usize> = HashMap::new();
        for (idx, group) in groups.iter().enumerate() {
            if group.kind.is_elementwise() {
                for &node in &group.nodes {
                    node_group.insert(node, idx);
                }
            }
        }
        'outer: for target_idx in 0..groups.len() {
            if !groups[target_idx].kind.is_elementwise() {
                continue;
            }
            let base_start = groups[target_idx].span.start;
            let mut merge_indices: Vec<usize> = Vec::new();
            for &node_id in &groups[target_idx].nodes {
                let Some(node) = graph.node(node_id) else {
                    continue;
                };
                for &input in &node.inputs {
                    if let Some(info) = graph.value(input) {
                        if let ValueOrigin::NodeOutput { node: producer, .. } = info.origin {
                            if let Some(&source_idx) = node_group.get(&producer) {
                                if source_idx == target_idx {
                                    continue;
                                }
                                let source_group = &groups[source_idx];
                                if !source_group.kind.is_elementwise() {
                                    continue;
                                }
                                if source_group.span.start < base_start {
                                    continue;
                                }
                                if !group_consumers_subset(
                                    source_group,
                                    target_idx,
                                    groups,
                                    consumer_map,
                                    graph,
                                ) {
                                    continue;
                                }
                                merge_indices.push(source_idx);
                            }
                        }
                    }
                }
            }
            if merge_indices.is_empty() {
                continue;
            }
            merge_indices.sort_unstable();
            merge_indices.dedup();
            for idx in &merge_indices {
                let nodes = groups[*idx].nodes.clone();
                groups[target_idx].nodes.extend(nodes);
                groups[*idx].nodes.clear();
            }
            groups[target_idx]
                .nodes
                .sort_unstable_by_key(|id| graph.node(*id).map(|n| n.span.start).unwrap_or(0));
            groups[target_idx].nodes.dedup();
            groups[target_idx].span = group_span(graph, &groups[target_idx].nodes);
            changed = true;
            break 'outer;
        }
        if changed {
            groups.retain(|group| !group.nodes.is_empty());
        }
    }
}

fn group_consumers_subset(
    source_group: &FusionGroup,
    target_idx: usize,
    groups: &[FusionGroup],
    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
    graph: &AccelGraph,
) -> bool {
    let target_nodes: HashSet<NodeId> = groups[target_idx].nodes.iter().copied().collect();
    let source_nodes: HashSet<NodeId> = source_group.nodes.iter().copied().collect();
    for &node_id in &source_group.nodes {
        let Some(node) = graph.node(node_id) else {
            continue;
        };
        for &out in &node.outputs {
            if let Some(consumers) = consumer_map.get(&out) {
                for consumer in consumers {
                    if !source_nodes.contains(consumer) && !target_nodes.contains(consumer) {
                        return false;
                    }
                }
            }
        }
    }
    true
}

fn node_output_shape(graph: &AccelGraph, node: &AccelNode) -> ShapeInfo {
    let mut shape = ShapeInfo::Scalar;
    for &output in &node.outputs {
        if let Some(info) = graph.value(output) {
            shape = shape.unify(&info.shape);
        }
    }
    shape
}

fn find_next_elementwise(
    graph: &AccelGraph,
    node_id: NodeId,
    assigned: &HashSet<NodeId>,
    local_seen: &HashSet<NodeId>,
    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
    current_shape: &ShapeInfo,
) -> Option<(NodeId, ShapeInfo)> {
    let node = graph.node(node_id)?;
    let mut candidate: Option<(NodeId, ShapeInfo)> = None;

    for &output in &node.outputs {
        let consumers = consumer_map.get(&output)?;
        if consumers.len() != 1 {
            return None;
        }
        let next_id = *consumers.iter().next()?;
        if next_id <= node_id || assigned.contains(&next_id) || local_seen.contains(&next_id) {
            return None;
        }
        let next_node = graph.node(next_id)?;
        if !(next_node.is_elementwise() || is_elementwise_max_min(graph, next_node)) {
            return None;
        }
        // Ensure the edge we follow is actually used by the next node
        if !next_node.inputs.contains(&output) {
            continue;
        }
        let next_shape = node_output_shape(graph, next_node);
        if matches!(next_shape, ShapeInfo::Unknown) {
            return None;
        }
        let unified = current_shape.unify(&next_shape);
        if matches!(unified, ShapeInfo::Unknown) {
            return None;
        }
        candidate = Some((next_id, unified));
        break;
    }

    candidate
}

fn is_elementwise_max_min(graph: &AccelGraph, node: &AccelNode) -> bool {
    match &node.label {
        AccelNodeLabel::Builtin { name }
            if name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min") =>
        {
            if node.inputs.len() < 2 {
                return false;
            }
            !value_is_placeholder(graph, node.inputs[1])
        }
        _ => false,
    }
}

fn value_is_placeholder(graph: &AccelGraph, vid: ValueId) -> bool {
    let Some(info) = graph.value(vid) else {
        return false;
    };
    let Some(constant) = &info.constant else {
        return false;
    };
    match constant {
        Value::Tensor(t) => t.data.is_empty(),
        Value::LogicalArray(l) => l.data.is_empty(),
        Value::StringArray(sa) => sa.data.is_empty(),
        Value::CharArray(ca) => ca.data.is_empty(),
        Value::Cell(cell) => cell.data.is_empty(),
        Value::String(s) => s.is_empty(),
        _ => false,
    }
}

fn group_span(graph: &AccelGraph, nodes: &[NodeId]) -> InstrSpan {
    let mut start = usize::MAX;
    let mut end = 0usize;
    for &id in nodes {
        if let Some(node) = graph.node(id) {
            start = start.min(node.span.start);
            end = end.max(node.span.end);
        }
    }
    if start == usize::MAX {
        start = 0;
    }
    InstrSpan { start, end }
}

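/// Per-group execution plans derived from a single `AccelGraph`.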
#[derive(Debug, Clone)]
pub struct FusionPlan {
    pub groups: Vec<FusionGroupPlan>,
}

#[derive(Debug, Clone)]
pub struct FusionGroupPlan {
    pub index: usize,
    pub group: FusionGroup,
    pub operations: Vec<FusionOp>,
    pub inputs: Vec<ValueId>,
    pub stack_pattern: Vec<usize>,
    pub constants: HashMap<usize, Value>,
    pub const_values: HashMap<ValueId, Value>,
    pub output: Option<ValueId>,
    pub kernel: FusionKernelSpec,
    /// For reductions: the ValueId of the data tensor being reduced, if identifiable
    pub reduction_data: Option<ValueId>,
    /// For reductions: the ValueId of the dim argument, when identifiable
    pub reduction_dim: Option<ValueId>,
    /// For reductions: flavor metadata (e.g., sum vs mean scaling)
    pub reduction_flavor: Option<ReductionFlavor>,
    /// For reductions: axis selection metadata (e.g., explicit dims vs 'all')
    pub reduction_axes: Option<ReductionAxes>,
    pub pattern: Option<FusionPattern>,
}

#[derive(Debug, Clone)]
pub enum FusionOp {
    Primitive {
        op: PrimitiveOp,
        inputs: Vec<ValueId>,
        output: Option<ValueId>,
    },
    Builtin {
        name: String,
        inputs: Vec<ValueId>,
        output: Option<ValueId>,
    },
}

#[derive(Debug, Clone)]
pub struct FusionKernelSpec {
    pub kind: FusionKind,
    pub supported: bool,
}

impl FusionKernelSpec {
    fn new(kind: FusionKind, supported: bool) -> Self {
        Self { kind, supported }
    }
}

#[derive(Clone, Debug)]
pub struct ActiveFusion {
    pub kind: FusionKind,
    pub span: InstrSpan,
    pub element_count: Option<usize>,
    pub supported: bool,
}

struct ActiveContext {
    plan: Arc<FusionPlan>,
    active_group: Option<usize>,
}

static PLAN_CACHE: Lazy<RwLock<HashMap<usize, Weak<FusionPlan>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

thread_local! {
    static ACTIVE_PLAN: RefCell<Option<ActiveContext>> = const { RefCell::new(None) };
}

fn fusion_debug_enabled() -> bool {
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| match std::env::var("RUNMAT_DEBUG_FUSION") {
        Ok(v) => v == "1" || v.eq_ignore_ascii_case("true") || v.eq_ignore_ascii_case("yes"),
        Err(_) => false,
    })
}

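/// Returns the memoized `FusionPlan` for `graph`, building it on first use.
/// `PLAN_CACHE` holds `Weak` entries keyed by the graph's address, so a plan
/// is rebuilt only once every strong `Arc` to it has been dropped. Sketch
/// (hypothetical `graph` and `groups`):
///
/// ```ignore
/// let first = prepare_fusion_plan(Some(&graph), &groups).unwrap();
/// let second = prepare_fusion_plan(Some(&graph), &groups).unwrap();
/// assert!(Arc::ptr_eq(&first, &second)); // second call is a cache hit
/// ```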
pub fn prepare_fusion_plan(
    graph: Option<&AccelGraph>,
    groups: &[FusionGroup],
) -> Option<Arc<FusionPlan>> {
    let graph = graph?;
    if groups.is_empty() {
        return None;
    }
    let key = graph as *const AccelGraph as usize;
    if let Some(plan) = PLAN_CACHE
        .read()
        .ok()
        .and_then(|guard| guard.get(&key).and_then(|weak| weak.upgrade()))
    {
        return Some(plan);
    }

    let plan = FusionPlan::from_graph(graph, groups);
    let plan = Arc::new(plan);
    if let Ok(mut guard) = PLAN_CACHE.write() {
        guard.insert(key, Arc::downgrade(&plan));
    }
    Some(plan)
}

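/// Installs `plan` as this thread's active fusion plan. The expected driver
/// shape, sketched with a hypothetical interpreter loop, is:
///
/// ```ignore
/// activate_fusion_plan(prepare_fusion_plan(Some(&graph), &groups));
/// for pc in 0..program.len() {
///     set_current_pc(pc);
///     if let Some(active) = active_fusion() {
///         // dispatch the fused kernel covering `active.span` instead of
///         // interpreting the scalar instruction at `pc`
///     }
/// }
/// deactivate_fusion_plan();
/// ```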
pub fn activate_fusion_plan(plan: Option<Arc<FusionPlan>>) {
    ACTIVE_PLAN.with(|ctx| {
        let mut slot = ctx.borrow_mut();
        *slot = plan.map(|plan| ActiveContext {
            plan,
            active_group: None,
        });
    });
}

pub fn deactivate_fusion_plan() {
    ACTIVE_PLAN.with(|ctx| {
        ctx.borrow_mut().take();
    });
}

pub fn set_current_pc(pc: usize) {
    ACTIVE_PLAN.with(|ctx| {
        if let Some(context) = ctx.borrow_mut().as_mut() {
            context.active_group = context.plan.group_for_pc(pc);
        }
    });
}

pub fn active_fusion() -> Option<ActiveFusion> {
    ACTIVE_PLAN.with(|ctx| {
        ctx.borrow()
            .as_ref()
            .and_then(|context| {
                context
                    .active_group
                    .and_then(|idx| context.plan.groups.get(idx))
            })
            .map(|plan| ActiveFusion {
                kind: plan.group.kind.clone(),
                span: plan.group.span.clone(),
                element_count: plan.element_count(),
                supported: plan.kernel.supported,
            })
    })
}

pub fn active_group_plan_clone() -> Option<FusionGroupPlan> {
    ACTIVE_PLAN.with(|ctx| {
        ctx.borrow().as_ref().and_then(|context| {
            context
                .active_group
                .and_then(|idx| context.plan.groups.get(idx).cloned())
        })
    })
}

impl FusionPlan {
    pub fn from_graph(graph: &AccelGraph, groups: &[FusionGroup]) -> Self {
        let plans = groups
            .iter()
            .enumerate()
            .map(|(idx, group)| FusionGroupPlan::new(idx, group.clone(), graph))
            .collect();
        Self { groups: plans }
    }

    fn group_for_pc(&self, pc: usize) -> Option<usize> {
        self.groups
            .iter()
            .find(|plan| pc >= plan.group.span.start && pc <= plan.group.span.end)
            .map(|plan| plan.index)
    }
}

impl From<Vec<FusionGroupPlan>> for FusionPlan {
    fn from(groups: Vec<FusionGroupPlan>) -> Self {
        Self { groups }
    }
}

fn log_plan_stack_pattern(stage: &str, plan: &FusionGroupPlan, graph: &AccelGraph) {
    if !fusion_debug_enabled() || plan.stack_pattern.is_empty() {
        return;
    }
    let mut pattern_meta: Vec<String> = Vec::with_capacity(plan.stack_pattern.len());
    for (pos, input_idx) in plan.stack_pattern.iter().enumerate() {
        let value_id = plan.inputs.get(*input_idx).copied();
        if let Some(vid) = value_id {
            if let Some(info) = graph.value(vid) {
                let node_label = match info.origin {
                    ValueOrigin::NodeOutput { node, .. } => graph
                        .node(node)
                        .map(|n| format!("{:?}", n.label))
                        .unwrap_or_else(|| "<missing-node>".to_string()),
                    _ => String::new(),
                };
                pattern_meta.push(format!(
                    "#{}:input_idx={} vid={} origin={:?} label={}",
                    pos, input_idx, vid, info.origin, node_label
                ));
            } else {
                pattern_meta.push(format!(
                    "#{}:input_idx={} vid={} origin=<missing>",
                    pos, input_idx, vid
                ));
            }
        } else {
            pattern_meta.push(format!("#{}:input_idx={} vid=<missing>", pos, input_idx));
        }
    }
    log::debug!(
        "fusion plan {} {} stack_pattern={:?} meta={:?}",
        plan.index,
        stage,
        plan.stack_pattern,
        pattern_meta
    );
}

impl FusionGroupPlan {
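    /// Lowers a detected `FusionGroup` into an executable plan: collect
    /// external inputs and constants, capture the reduction signature when
    /// present, externalize a reduction's real tensor dependencies while
    /// keeping constants out of the input list, and finally verify that a
    /// WGSL kernel can actually be generated for the group's kind.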
    fn new(index: usize, group: FusionGroup, graph: &AccelGraph) -> Self {
        let node_set: HashSet<NodeId> = group.nodes.iter().copied().collect();
        let mut seen_inputs: HashMap<ValueId, usize> = HashMap::new();
        let mut inputs: Vec<ValueId> = Vec::new();
        let mut stack_pattern: Vec<usize> = Vec::new();
        let mut constants: HashMap<usize, Value> = HashMap::new();
        let const_values: HashMap<ValueId, Value> = HashMap::new();
        let mut operations = Vec::new();
        let mut reduction_flavor: Option<ReductionFlavor> = None;
        let mut reduction_axes: Option<ReductionAxes> = None;
        let mut reduction_data: Option<ValueId> = None;
        let mut reduction_dim: Option<ValueId> = None;
        let mut output: Option<ValueId> = None;

        let is_reduction_group = group.kind.is_reduction();
        for node_id in &group.nodes {
            let Some(node) = graph.node(*node_id) else {
                continue;
            };
            for input in &node.inputs {
                let binding = graph.var_binding(*input);
                let (external, is_variable, maybe_constant) = match graph.value(*input) {
                    Some(info) => match &info.origin {
                        ValueOrigin::NodeOutput { node: origin, .. }
                            if node_set.contains(origin) =>
                        {
                            (false, false, None)
                        }
                        ValueOrigin::Variable { .. } => (true, true, None),
                        ValueOrigin::NodeOutput { .. } if binding.is_some() => (true, true, None),
                        ValueOrigin::Constant => (true, false, info.constant.clone()),
                        _ => (true, false, None),
                    },
                    None => (true, false, None),
                };
                if external {
                    // Special handling for reductions: do NOT include constants in inputs;
                    // only the data tensor should be an input. Constants are recorded separately.
                    if is_reduction_group {
                        if let Some(constant) = maybe_constant.clone() {
                            // Assign a synthetic key for constants; keys are not positional for reductions
                            let key = constants.len() + 1000;
                            constants.insert(key, constant);
                            continue;
                        }
                        // Only include the reduction data operand as an input
                        if let Some(data_id) = reduction_data {
                            if *input != data_id {
                                // Skip non-data external inputs for reduction groups
                                continue;
                            }
                        }
                    }

                    let mut newly_added = false;
                    let input_idx = if let Some(idx) = seen_inputs.get(input) {
                        *idx
                    } else {
                        let idx = inputs.len();
                        inputs.push(*input);
                        seen_inputs.insert(*input, idx);
                        newly_added = true;
                        idx
                    };

                    if fusion_debug_enabled() {
                        let origin = graph.value(*input).map(|v| v.origin.clone());
                        log::debug!(
                            "fusion plan #{} consider input vid={} origin={:?} binding={:?} newly_added={} is_variable={} stack_candidate={}",
                            index,
                            input,
                            origin,
                            binding,
                            newly_added,
                            is_variable,
                            !is_variable && newly_added
                        );
                    }
                    if let Some(constant) = maybe_constant.clone() {
                        constants.insert(input_idx, constant);
                    } else if !is_variable && newly_added {
                        let allow_stack = match graph.value(*input) {
                            Some(info) => match info.origin {
                                ValueOrigin::NodeOutput { node, .. } => graph
                                    .node(node)
                                    .map(|n| n.span.start <= group.span.start)
                                    .unwrap_or(false),
                                _ => true,
                            },
                            None => true,
                        };
                        if allow_stack {
                            stack_pattern.push(input_idx);
                        } else if fusion_debug_enabled() {
                            log::debug!(
                                "fusion plan {} skipping stack candidate vid={} origin_after_span",
                                index,
                                input
                            );
                        }
                    }
                }
            }

            let op = match &node.label {
                AccelNodeLabel::Primitive(p) => FusionOp::Primitive {
                    op: *p,
                    inputs: node.inputs.clone(),
                    output: node.outputs.first().copied(),
                },
                AccelNodeLabel::Builtin { name } => FusionOp::Builtin {
                    name: name.clone(),
                    inputs: node.inputs.clone(),
                    output: node.outputs.first().copied(),
                },
                AccelNodeLabel::Unknown => FusionOp::Primitive {
                    op: PrimitiveOp::UPlus,
                    inputs: node.inputs.clone(),
                    output: node.outputs.first().copied(),
                },
            };
            operations.push(op);

            if let Some(out) = node.outputs.first().copied() {
                output = Some(out);
            }
            // Generic reduction signature (no name checks)
            if node.is_reduction() {
                if let Some(sig) = detect_reduction_signature(graph, node) {
                    reduction_data = Some(sig.data_input);
                    reduction_dim = sig.dim_arg;
                    reduction_flavor = Some(match sig.behavior {
                        ReductionBehavior::MeanLike => ReductionFlavor::Mean,
                        _ => ReductionFlavor::Sum,
                    });
                    reduction_axes = Some(sig.axes.clone());
                }
            }
        }

        let kind = group.kind.clone();
        let pattern = group.pattern.clone();
        let mut plan = Self {
            index,
            group,
            operations,
            stack_pattern,
            constants,
            const_values,
            inputs,
            output,
            kernel: FusionKernelSpec::new(kind, true),
            reduction_data,
            reduction_dim,
            reduction_flavor,
            reduction_axes,
            pattern,
        };

        log_plan_stack_pattern("initial", &plan, graph);

        // Record constant ValueIds for all groups for easier downstream analysis
        for node_id in &plan.group.nodes {
            if let Some(node) = graph.node(*node_id) {
                for &inp in &node.inputs {
                    if let Some(info) = graph.value(inp) {
                        if let Some(cv) = info.constant.clone() {
                            plan.const_values.insert(inp, cv);
                        }
                    }
                }
            }
        }

        // For reduction groups, externalize only real tensor dependencies; keep constants separate
        if plan.group.kind.is_reduction() {
            if let Some(data_vid) = plan.reduction_data {
                let original_inputs = plan.inputs.clone();
                let original_stack_pattern = plan.stack_pattern.clone();
                // Build a dependency map from op outputs to inputs
                let mut prod: HashMap<ValueId, Vec<ValueId>> = HashMap::new();
                for op in &plan.operations {
                    match op {
                        FusionOp::Primitive {
                            inputs,
                            output,
                            op: _,
                        } => {
                            if let Some(out) = output {
                                prod.insert(*out, inputs.clone());
                            }
                        }
                        FusionOp::Builtin {
                            name: _,
                            inputs,
                            output,
                        } => {
                            if let Some(out) = output {
                                prod.insert(*out, inputs.clone());
                            }
                        }
                    }
                }
                let mut deps: Vec<ValueId> = Vec::new();
                let mut visited: HashSet<ValueId> = HashSet::new();
                let mut stack: Vec<ValueId> = vec![data_vid];
                // Track extra ops we discover outside the original group that are safe to inline
                let mut extra_ops: Vec<FusionOp> = Vec::new();
                let mut added_nodes: HashSet<ValueId> = HashSet::new();
                while let Some(cur) = stack.pop() {
                    if !visited.insert(cur) {
                        continue;
                    }
                    if graph.var_binding(cur).is_some() {
                        if !deps.contains(&cur) {
                            deps.push(cur);
                        }
                        continue;
                    }
                    if let Some(info) = graph.value(cur) {
                        if matches!(info.origin, ValueOrigin::Variable { .. }) {
                            if !deps.contains(&cur) {
                                deps.push(cur);
                            }
                            continue;
                        }
                    }
                    // Do not short-circuit on the reduction_data itself; expand through its producers first.
                    if original_inputs.contains(&cur) && cur != data_vid {
                        if !deps.contains(&cur) {
                            deps.push(cur);
                        }
                        continue;
                    }
                    if let Some(parents) = prod.get(&cur) {
                        for p in parents {
                            stack.push(*p);
                        }
                        continue;
                    }
                    // If not produced by an op in this group, try to expand through safe producer nodes
                    if let Some((_, node)) = node_from_value(graph, cur) {
                        // Only consider simple arithmetic producers we know how to fold
                        match &node.label {
                            AccelNodeLabel::Primitive(PrimitiveOp::Mul)
                            | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul)
                            | AccelNodeLabel::Primitive(PrimitiveOp::Div)
                            | AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv)
                            | AccelNodeLabel::Primitive(PrimitiveOp::ElemLeftDiv)
                            | AccelNodeLabel::Primitive(PrimitiveOp::Add)
                            | AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
                                // Record the op for codegen and traverse its inputs
                                if added_nodes.insert(cur) {
                                    extra_ops.push(FusionOp::Primitive {
                                        op: match node.label {
                                            AccelNodeLabel::Primitive(op) => op,
                                            _ => PrimitiveOp::UPlus,
                                        },
                                        inputs: node.inputs.clone(),
                                        output: node.outputs.first().copied(),
                                    });
                                }
                                for &p in &node.inputs {
                                    stack.push(p);
                                }
                                continue;
                            }
                            AccelNodeLabel::Primitive(PrimitiveOp::ElemPow) => {
                                // Only accept a power with a constant exponent (typically 2 for squares)
                                if node.inputs.len() == 2 {
                                    if let Some(exp) = value_constant_f64(graph, node.inputs[1]) {
                                        if exp.is_finite() {
                                            if added_nodes.insert(cur) {
                                                extra_ops.push(FusionOp::Primitive {
                                                    op: PrimitiveOp::ElemPow,
                                                    inputs: node.inputs.clone(),
                                                    output: node.outputs.first().copied(),
                                                });
                                            }
                                            stack.push(node.inputs[0]);
                                            // Treat the exponent as a constant dependency for codegen
                                            stack.push(node.inputs[1]);
                                            continue;
                                        }
                                    }
                                }
                                // Fallback: treat as a leaf dependency
                            }
                            AccelNodeLabel::Builtin { name } => {
                                // Allow simple casts to flow through (single/double)
                                if (name.eq_ignore_ascii_case("single")
                                    || name.eq_ignore_ascii_case("double"))
                                    && node.inputs.len() == 1
                                {
                                    stack.push(node.inputs[0]);
                                    continue;
                                }
                                // Unknown builtin: treat as a leaf
                            }
                            _ => {
                                // Unknown producer: treat as a leaf
                            }
                        }
                    }
                }
                // Ensure direct parents of the reduction data are materialized as inputs
                if let Some(parents) = prod.get(&data_vid) {
                    for &p in parents {
                        if !deps.contains(&p) {
                            // Skip trivial constants embedded in const_values; those are handled separately
                            let is_const = plan.const_values.contains_key(&p)
                                || graph.value(p).and_then(|vi| vi.constant.as_ref()).is_some();
                            if !is_const {
                                deps.push(p);
                            }
                        }
                    }
                }
                // Prepend the newly discovered ops so they are available to codegen,
                // keeping the original operations (the reduction op itself) as well.
                if !extra_ops.is_empty() {
                    // Ensure a stable order: extra ops first
                    let mut new_ops = Vec::with_capacity(extra_ops.len() + plan.operations.len());
                    new_ops.extend(extra_ops);
                    new_ops.append(&mut plan.operations);
                    plan.operations = new_ops;
                }
                plan.inputs = deps;
                // Ensure constants referenced by any newly added operations are recorded.
                for op in &plan.operations {
                    let inputs = match op {
                        FusionOp::Primitive { inputs, .. } => inputs,
                        FusionOp::Builtin { inputs, .. } => inputs,
                    };
                    for vid in inputs {
                        if plan.const_values.contains_key(vid) {
                            continue;
                        }
                        if let Some(info) = graph.value(*vid) {
                            if let Some(cv) = info.constant.clone() {
                                plan.const_values.insert(*vid, cv);
                            }
                        }
                    }
                }

                // Rebuild the stack pattern based on the dependencies that were previously sourced
                // from the execution stack.
                let mut new_stack_pattern: Vec<usize> = Vec::new();
                for (new_idx, vid) in plan.inputs.iter().enumerate() {
                    if let Some(old_idx) = original_inputs.iter().position(|v| v == vid) {
                        if original_stack_pattern.contains(&old_idx) {
                            new_stack_pattern.push(new_idx);
                        }
                    }
                }

                // Rebuild the constants map using the new input ordering.
                let mut new_constants: HashMap<usize, Value> = HashMap::new();
                for (idx, vid) in plan.inputs.iter().enumerate() {
                    if let Some(value) = plan.const_values.get(vid) {
                        new_constants.insert(idx, value.clone());
                    } else if let Some(info) = graph.value(*vid) {
                        if let Some(cv) = info.constant.clone() {
                            new_constants.insert(idx, cv);
                        }
                    }
                }
                plan.constants = new_constants;

                if new_stack_pattern.is_empty() {
                    for (idx, vid) in plan.inputs.iter().enumerate() {
                        if plan.constants.contains_key(&idx) {
                            continue;
                        }
                        if let Some(info) = graph.value(*vid) {
                            if matches!(
                                info.origin,
                                ValueOrigin::Variable { .. } | ValueOrigin::Constant
                            ) {
                                continue;
                            }
                        }
                        new_stack_pattern.push(idx);
                    }
                }
                plan.stack_pattern = new_stack_pattern;
            }
        }

        // Final sanitize: for reduction groups, ensure inputs contain no constants
        if plan.group.kind.is_reduction() {
            let original_inputs = plan.inputs.clone();
            plan.inputs.retain(|vid| {
                if let Some(info) = graph.value(*vid) {
                    !matches!(info.origin, ValueOrigin::Constant)
                        && !plan.const_values.contains_key(vid)
                } else {
                    true
                }
            });
            if plan.inputs.len() != original_inputs.len() {
                let mut new_stack: Vec<usize> = Vec::new();
                for old_idx in &plan.stack_pattern {
                    if *old_idx < original_inputs.len() {
                        let vid = original_inputs[*old_idx];
                        if let Some(new_idx) = plan.inputs.iter().position(|v| *v == vid) {
                            new_stack.push(new_idx);
                        }
                    }
                }
                plan.stack_pattern = new_stack;
            }
        }

        // Determine kernel support:
        // - Elementwise: require WGSL generation at plan time.
        // - Reduction: require WGSL generation at plan time as well.
        // - Other kinds: executed via provider paths.
        let supported = if plan.kernel.kind.is_elementwise() {
            plan.generate_wgsl("f32").is_some()
        } else if plan.kernel.kind.is_reduction() {
            plan.generate_reduction_wgsl("f32").is_some()
        } else {
            true
        };
        plan.kernel.supported = plan.kernel.supported && supported;
        if !plan.kernel.supported && fusion_debug_enabled() {
            let const_ids: Vec<ValueId> = plan.const_values.keys().copied().collect();
            log::debug!(
                "fusion plan {} unsupported: kind={:?} group_kind={:?} inputs={:?} reduction_data={:?} reduction_dim={:?} const_ids={:?}",
                plan.index,
                plan.kernel.kind,
                plan.group.kind,
                plan.inputs,
                plan.reduction_data,
                plan.reduction_dim,
                const_ids
            );
            if plan.kernel.kind.is_reduction() {
                let mut seen: HashSet<ValueId> = HashSet::new();
                let mut value_info: Vec<String> = Vec::new();
                for op in &plan.operations {
                    let inputs = match op {
                        FusionOp::Primitive { inputs, .. } => inputs,
                        FusionOp::Builtin { inputs, .. } => inputs,
                    };
                    for vid in inputs {
                        if seen.insert(*vid) {
                            if let Some(info) = graph.value(*vid) {
                                value_info.push(format!(
                                    "vid={} origin={:?} constant={}",
                                    vid,
                                    info.origin,
                                    info.constant.is_some()
                                ));
                            } else {
                                value_info.push(format!("vid={} origin=<missing>", vid));
                            }
                        }
                    }
                }
                log::debug!(
                    "fusion reduction plan {} value summary: [{}]",
                    plan.index,
                    value_info.join(", ")
                );
            }
        }

        if matches!(plan.group.kind, FusionKind::CenteredGram) && plan.stack_pattern.is_empty() {
            let mut centered_stack_idxs: Vec<usize> = Vec::new();
            for (idx, vid) in plan.inputs.iter().enumerate() {
                if plan.constants.contains_key(&idx) {
                    continue;
                }
                if let Some(info) = graph.value(*vid) {
                    if matches!(info.origin, ValueOrigin::NodeOutput { .. }) {
                        centered_stack_idxs.push(idx);
                        continue;
                    }
                    if matches!(info.origin, ValueOrigin::Variable { .. }) {
                        continue;
                    }
                }
                centered_stack_idxs.push(idx);
            }
            if centered_stack_idxs.is_empty() && !plan.inputs.is_empty() {
                centered_stack_idxs.push(0);
            }
            plan.stack_pattern = centered_stack_idxs;
        }

        log_plan_stack_pattern("final", &plan, graph);

        plan
    }

    pub fn reduction_data_shape(&self, graph: &AccelGraph) -> Option<Vec<usize>> {
        let vid = self.reduction_data?;
        let info = graph.value(vid)?;
        match &info.shape {
            ShapeInfo::Tensor(dims) if !dims.is_empty() && dims.iter().all(|d| d.is_some()) => {
                Some(dims.iter().map(|d| d.unwrap()).collect())
            }
            _ => None,
        }
    }

    pub fn element_count(&self) -> Option<usize> {
        self.group.element_count()
    }

    pub fn constant_shape(&self, len: usize) -> Vec<usize> {
        match &self.group.shape {
            ShapeInfo::Tensor(dims) if !dims.is_empty() && dims.iter().all(|dim| dim.is_some()) => {
                dims.iter().map(|dim| dim.unwrap()).collect()
            }
            _ => vec![len],
        }
    }

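    /// Emits a broadcast-aware elementwise compute shader: one storage buffer
    /// per input plus packed shape/stride arrays in `Params`, with a `@WG@`
    /// placeholder standing in for the workgroup size. Sketch (hypothetical
    /// `plan` for something like `y = a .* b + c`):
    ///
    /// ```ignore
    /// if let Some(src) = plan.generate_wgsl("f32") {
    ///     assert!(src.contains("@compute @workgroup_size("));
    /// }
    /// ```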
1301    pub fn generate_wgsl(&self, scalar_ty: &str) -> Option<String> {
1302        if !self.kernel.kind.is_elementwise() {
1303            return None;
1304        }
1305        if !self.kernel.supported {
1306            return None;
1307        }
1308        let output_id = self.output?;
1309        let mut exprs: HashMap<ValueId, String> = HashMap::new();
1310        for (idx, input_id) in self.inputs.iter().enumerate() {
1311            // Placeholder; will be replaced by broadcasted index variable i{idx}
1312            exprs.insert(*input_id, format!("input{idx}.data[i{idx}]"));
1313        }
1314
1315        let mut body = String::new();
1316        for (node_idx, op) in self.operations.iter().enumerate() {
1317            let tmp_name = format!("tmp{node_idx}");
1318            match op {
1319                FusionOp::Primitive { op, inputs, output } => {
1320                    let expr = primitive_expr(*op, inputs, &exprs)?;
1321                    body.push_str(&format!("    let {tmp_name}: {scalar_ty} = {expr};\n"));
1322                    if let Some(out) = output {
1323                        exprs.insert(*out, tmp_name.clone());
1324                    }
1325                }
1326                FusionOp::Builtin {
1327                    name,
1328                    inputs,
1329                    output,
1330                } => {
1331                    let expr = builtin_expr(name, inputs, &exprs, scalar_ty)?;
1332                    body.push_str(&format!("    let {tmp_name}: {scalar_ty} = {expr};\n"));
1333                    if let Some(out) = output {
1334                        exprs.insert(*out, tmp_name.clone());
1335                    }
1336                }
1337            }
1338        }
1339
1340        let final_expr = exprs.get(&output_id)?.clone();
1341
1342        let mut shader = String::new();
1343        shader.push_str("const MAX_RANK: u32 = 128u;\n");
1344        shader.push_str("struct PackedValue { value: u32, _pad0: u32, _pad1: u32, _pad2: u32 };\n");
1345        shader.push_str("alias PackedArray = array<PackedValue, MAX_RANK>;\n\n");
1346        shader.push_str(&format!("struct Tensor {{ data: array<{scalar_ty}>, }};\n"));
1347        // Broadcast-aware Params: len, offset, rank, pad, out_shape and per-input shape/stride
1348        shader.push_str("struct Params {\n    len: u32,\n    offset: u32,\n    rank: u32,\n    _pad: u32,\n    out_shape: PackedArray,\n");
1349        for idx in 0..self.inputs.len() {
1350            shader.push_str(&format!("    in{}_shape: PackedArray,\n", idx));
1351            shader.push_str(&format!("    in{}_stride: PackedArray,\n", idx));
1352        }
1353        shader.push_str("}\n\n");
1354        // Provide portable stubs; avoid relying on backend builtins that may be missing
1355        if scalar_ty == "f32" {
1356            shader.push_str("fn isNan(x: f32) -> bool { return x != x; }\n");
1357            shader.push_str("fn isFinite(x: f32) -> bool { return (x == x) && (abs(x) < 3.4028234663852886e38); }\n");
1358            shader.push_str("fn isInf(x: f32) -> bool { return (x == x) && !(abs(x) < 3.4028234663852886e38); }\n\n");
1359        } else {
1360            shader.push_str("fn isNan(x: f64) -> bool { return x != x; }\n");
1361            shader.push_str("fn isFinite(x: f64) -> bool { return (x == x) && (abs(x) < f64(1.7976931348623157e308)); }\n");
1362            shader.push_str("fn isInf(x: f64) -> bool { return (x == x) && !(abs(x) < f64(1.7976931348623157e308)); }\n\n");
1363        }
1364        for (idx, _) in self.inputs.iter().enumerate() {
1365            shader.push_str(&format!(
1366                "@group(0) @binding({}) var<storage, read> input{}: Tensor;\n",
1367                idx, idx
1368            ));
1369        }
1370        shader.push_str(&format!(
1371            "@group(0) @binding({}) var<storage, read_write> output: Tensor;\n",
1372            self.inputs.len()
1373        ));
1374        shader.push_str(&format!(
1375            "@group(0) @binding({}) var<uniform> params: Params;\n\n",
1376            self.inputs.len() + 1
1377        ));
1378        shader.push_str("@compute @workgroup_size(@WG@)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n");
1379        shader.push_str("    let idx = gid.x;\n    if (idx >= params.len) { return; }\n");
1380        shader.push_str("    let g = idx + params.offset;\n");
1381        shader.push_str("    // Compute N-D coordinates from global index (with chunk offset)\n    var coord: array<u32, MAX_RANK>;\n    var tmp: u32 = g;\n    var d: u32 = 0u;\n    loop { if d >= params.rank { break; } let dim = params.out_shape[d].value; if dim == 0u { coord[d] = 0u; } else { coord[d] = tmp % dim; tmp = tmp / dim; } d = d + 1u; }\n");
1382        // Compute broadcasted indices per input
1383        for (idx, _) in self.inputs.iter().enumerate() {
1384            shader.push_str(&format!(
1385                "    var i{}: u32 = 0u; d = 0u; loop {{ if d >= params.rank {{ break; }} let sd = params.in{}_shape[d].value; let st = params.in{}_stride[d].value; let c = select(coord[d], 0u, sd == 1u); i{} = i{} + c * st; d = d + 1u; }}\n",
1386                idx, idx, idx, idx, idx
1387            ));
1388        }
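        // Illustrative note: for an input of shape [1, N] broadcast against an
        // [M, N] output, the select() above pins the coordinate to 0 on the
        // size-1 dimension, so every row reads the same source element.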
1389        shader.push_str(&body);
1390        shader.push_str(&format!("    output.data[g] = {final_expr};\n}}\n"));
1391        Some(shader)
1392    }
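
    // Sketch of what generate_wgsl above emits (illustrative, assuming a
    // hypothetical two-input chain like sqrt(A .* B) with f32 scalars); the
    // generated body is roughly:
    //
    //     let tmp0: f32 = (input0.data[i0] * input1.data[i1]);
    //     let tmp1: f32 = sqrt(tmp0);
    //     output.data[g] = tmp1;
    //
    // with i0/i1 being the broadcast-aware indices computed per invocation.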
1393
1394    pub fn generate_reduction_wgsl(&self, scalar_ty: &str) -> Option<String> {
1395        if !self.kernel.kind.is_reduction() {
1396            return None;
1397        }
1398        // Minimal column-major reduction kernel template (single workgroup per slice).
1399        // Supports folding simple producer expressions over multiple inputs (e.g., sum(A.*B, dim)).
1400        if self.inputs.is_empty() {
1401            return None;
1402        }
1403        // Determine axis from the reduction builtin's explicit dim argument when available.
1404        // MATLAB dim is 1-based: dim=1 reduces rows (axis=0), dim=2 reduces cols (axis=1).
1405        let mut axis = 0usize;
1406        // Support 'all' via either index-keyed constants or value-id keyed const_values
1407        let reduce_all = self
1408            .constants
1409            .values()
1410            .any(|v| matches!(v, Value::String(s) if s.eq_ignore_ascii_case("all")))
1411            || self
1412                .const_values
1413                .values()
1414                .any(|v| matches!(v, Value::String(s) if s.eq_ignore_ascii_case("all")));
1415        if reduce_all {
1416            // We'll flatten in VM by setting nrows = total and ncols = 1; axis=0 works with that.
1417            axis = 0;
1418        } else if let Some(dim_vid) = self.reduction_dim {
1419            if let Some(v) = self.const_values.get(&dim_vid) {
1420                match v {
1421                    Value::Num(n) if *n >= 1.0 => {
1422                        axis = (*n as usize).saturating_sub(1);
1423                    }
1424                    Value::Int(i) => {
1425                        let val = i.to_f64();
1426                        if val >= 1.0 {
1427                            axis = (val as usize).saturating_sub(1);
1428                        }
1429                    }
1430                    _ => {}
1431                }
1432            }
1433        } else {
1434            // Fallback: scan constant table for a plausible dim
1435            for v in self.constants.values() {
1436                match v {
1437                    Value::Num(n) if *n >= 1.0 => {
1438                        axis = (*n as usize).saturating_sub(1);
1439                        break;
1440                    }
1441                    Value::Int(i) => {
1442                        let val = i.to_f64();
1443                        if val >= 1.0 {
1444                            axis = (val as usize).saturating_sub(1);
1445                            break;
1446                        }
1447                    }
1448                    _ => {}
1449                }
1450            }
1451        }
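        // Illustrative mapping: sum(A, 1) arrives as Value::Num(1.0) and maps
        // to axis = 0 (reduce down rows); sum(A, 2) maps to axis = 1.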
1452
        // Detect the 'omitnan' flag (compile-time selection); like the 'all'
        // handling above, scan both the index-keyed and value-id-keyed tables.
        let omitnan = self
            .constants
            .values()
            .chain(self.const_values.values())
            .any(|v| matches!(v, Value::String(s) if s.eq_ignore_ascii_case("omitnan")));
1458
1459        // Build reduction operand expression by folding the producer chain
1460        let data_vid = self.reduction_data?;
1461        let ext_input = self.inputs[0];
1462        let mut exprs: HashMap<ValueId, String> = HashMap::new();
1463        exprs.insert(ext_input, "v".to_string());
1464        // Map additional external inputs to v1, v2, ...
1465        for (idx, &vid) in self.inputs.iter().enumerate().skip(1) {
1466            exprs.insert(vid, format!("v{idx}"));
1467        }
1468        for (vid, val) in &self.const_values {
1469            let lit = match val {
1470                Value::Num(n) => {
1471                    if scalar_ty == "f64" {
1472                        format!("f64({})", n)
1473                    } else {
1474                        format!("{:?}", *n as f32)
1475                    }
1476                }
1477                Value::Int(i) => {
1478                    let f = i.to_f64();
1479                    if scalar_ty == "f64" {
1480                        format!("f64({})", f)
1481                    } else {
1482                        format!("{:?}", f as f32)
1483                    }
1484                }
1485                Value::Tensor(t) if t.data.len() == 1 => {
1486                    let scalar = t.data[0];
1487                    if scalar_ty == "f64" {
1488                        format!("f64({})", scalar)
1489                    } else {
1490                        format!("{:?}", scalar as f32)
1491                    }
1492                }
1493                _ => {
1494                    if scalar_ty == "f64" {
1495                        "f64(0.0)".to_string()
1496                    } else {
1497                        "0.0".to_string()
1498                    }
1499                }
1500            };
1501            exprs.insert(*vid, lit);
1502        }
1503        let mut progressed = true;
1504        while progressed {
1505            progressed = false;
1506            for op in &self.operations {
1507                match op {
1508                    FusionOp::Primitive { op, inputs, output } => {
1509                        if let Some(out) = output {
1510                            if exprs.contains_key(out) {
1511                                continue;
1512                            }
1513                            if let Some(code) = primitive_expr(*op, inputs, &exprs) {
1514                                exprs.insert(*out, code);
1515                                progressed = true;
1516                            }
1517                        }
1518                    }
1519                    FusionOp::Builtin {
1520                        name,
1521                        inputs,
1522                        output,
1523                    } => {
1524                        if let Some(out) = output {
1525                            if exprs.contains_key(out) {
1526                                continue;
1527                            }
1528                            if let Some(code) = builtin_expr(name, inputs, &exprs, scalar_ty) {
1529                                exprs.insert(*out, code);
1530                                progressed = true;
1531                            }
1532                        }
1533                    }
1534                }
1535            }
1536            if exprs.contains_key(&data_vid) {
1537                break;
1538            }
1539        }
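        // Illustrative fold: for sum(A .* B, dim) the fixed-point loop above
        // resolves the reduction operand to "(v * v1)", where v and v1 are the
        // per-iteration loads of input0 and input1 in the emitted loop body.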
1540        // Require a folded expression for the reduction operand; if missing, defer (no WGSL).
1541        let val_expr = match exprs.get(&data_vid) {
1542            Some(s) => s.clone(),
1543            None => {
1544                if fusion_debug_enabled() {
1545                    let expr_keys: Vec<ValueId> = exprs.keys().copied().collect();
1546                    log::debug!(
1547                        "fusion reduction WGSL: missing expression for data {:?}; inputs={:?} expr_keys={:?} ops={:?}",
1548                        data_vid,
1549                        self.inputs,
1550                        expr_keys,
1551                        self.operations
1552                    );
1553                }
1554                return None;
1555            }
1556        };
1557
1558        let mut shader = String::new();
1559        shader.push_str(&format!("struct Tensor {{ data: array<{scalar_ty}>, }};\n"));
1560        shader.push_str("struct MParams { nrows: u32, ncols: u32, ld: u32, flags: u32 }\n\n");
1561        // Bind all input tensors dynamically, followed by output and params
1562        for (idx, _) in self.inputs.iter().enumerate() {
1563            shader.push_str(&format!(
1564                "@group(0) @binding({}) var<storage, read> input{}: Tensor;\n",
1565                idx, idx
1566            ));
1567        }
1568        shader.push_str(&format!(
1569            "@group(0) @binding({}) var<storage, read_write> output: Tensor;\n",
1570            self.inputs.len()
1571        ));
1572        shader.push_str(&format!(
1573            "@group(0) @binding({}) var<uniform> params: MParams;\n\n",
1574            self.inputs.len() + 1
1575        ));
1576        // Use a small fixed workgroup tile size to avoid driver stalls on some backends
1577        shader.push_str(&format!(
1578            "var<workgroup> tile: array<{scalar_ty}, @WG@u>;\n\n"
1579        ));
1580        shader.push_str(&format!(
1581            "const OMITNAN: bool = {};\n\n",
1582            if omitnan { "true" } else { "false" }
1583        ));
1584        // Determine mean semantics from planner-populated reduction flavor
1585        let is_mean = matches!(self.reduction_flavor, Some(ReductionFlavor::Mean));
1586        let post_scale = if is_mean {
1587            let dim = if axis == 0 {
1588                "params.nrows"
1589            } else {
1590                "params.ncols"
1591            };
1592            if scalar_ty == "f64" {
1593                format!("(1.0 / f64(f32({dim})))")
1594            } else {
1595                format!("(1.0 / f32({dim}))")
1596            }
1597        } else if scalar_ty == "f64" {
1598            "f64(1.0)".to_string()
1599        } else {
1600            "1.0".to_string()
1601        };
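        // e.g. a mean along axis 0 with f32 scalars yields the post-scale
        // "(1.0 / f32(params.nrows))", applied once at write-out; sums use 1.0.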
1602        // Helper(s) at module scope
1603        shader.push_str(&format!(
1604            "fn isNanF(x: {scalar}) -> bool {{ return x != x; }}\n\n",
1605            scalar = scalar_ty
1606        ));
1607        shader.push_str("@compute @workgroup_size(@WG@)\n");
1608        if axis == 0 {
1609            // Column-wise: reduce over rows; one output per column (ncols)
1610            shader.push_str(
1611                "fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {\n",
1612            );
1613            shader.push_str("  let col = wid.x;\n  if (col >= params.ncols) { return; }\n");
            // Use a fully-formed literal for the accumulator init; emitting an
            // unclosed "f64(" cast here previously produced invalid WGSL.
            shader.push_str(&format!(
                "  var acc: {scalar_ty} = {};\n",
                if scalar_ty == "f64" { "f64(0.0)" } else { "0.0" }
            ));
1621            // helpers are declared at module scope
1622            shader.push_str("  var saw_nan: bool = false;\n  var r = lid.x;\n");
1623            // Load row-wise values from each input and fold into expression
1624            {
1625                // Build the per-iteration loads
1626                let mut loop_body = String::new();
1627                // input0 as 'v'
1628                loop_body.push_str("    let v = input0.data[ (col * params.nrows) + r ];\n");
1629                // additional inputs as v1, v2, ...
1630                for (idx, _) in self.inputs.iter().enumerate().skip(1) {
1631                    loop_body.push_str(&format!(
1632                        "    let v{idx} = input{idx}.data[ (col * params.nrows) + r ];\n"
1633                    ));
1634                }
1635                // compute val and accumulate
1636                loop_body.push_str(&format!(
1637                    "    let val: {scalar} = {val};\n    if (OMITNAN) {{ if (!isNanF(val)) {{ acc = acc + val; }} }} else {{ if (isNanF(val)) {{ saw_nan = true; }} else {{ acc = acc + val; }} }}\n",
1638                scalar = scalar_ty,
1639                val = val_expr
1640            ));
1641                shader.push_str("  while (r < params.nrows) {\n");
1642                shader.push_str(&loop_body);
1643                shader.push_str("    r += @WG@u;\n  }\n");
1644            }
1645            if scalar_ty == "f64" {
1646                shader.push_str(
1647                    "  if (!OMITNAN && saw_nan) { acc = bitcast<f64>(0x7ff8000000000000u); }\n",
1648                );
1649            } else {
1650                shader
1651                    .push_str("  if (!OMITNAN && saw_nan) { acc = bitcast<f32>(0x7fc00000u); }\n");
1652            }
1653            shader.push_str("  tile[lid.x] = acc;\n  workgroupBarrier();\n");
1654            shader.push_str(
1655                "  var off = (@WG@u) / 2u;\n  loop { if (off == 0u) { break; } if (lid.x < off) {\n    let a = tile[lid.x]; let b = tile[lid.x + off];\n    tile[lid.x] = a + b;\n  } workgroupBarrier(); off = off / 2u; }\n",
1656            );
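            // The halving loop above is a standard shared-memory tree
            // reduction: with @WG@ substituted to 8, offsets 4, 2, 1
            // successively combine tile[i] with tile[i + off] until tile[0]
            // holds the slice total.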
1657            // Final write: apply post-scale (sum=1, mean=1/rows)
1658            shader.push_str(&format!(
1659                "  if (lid.x == 0u) {{ output.data[col] = tile[0u] * {}; }}\n}}\n",
1660                post_scale
1661            ));
1662        } else {
1663            // Row-wise: reduce over cols; one output per row (nrows)
1664            shader.push_str(
1665                "fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {\n",
1666            );
            shader.push_str("  let row = wid.x;\n  // For axis=1 the provider swaps dims: params.ncols carries the row count (one output slice per row)\n  if (row >= params.ncols) { return; }\n");
            // Use a fully-formed literal for the accumulator init; emitting an
            // unclosed "f64(" cast here previously produced invalid WGSL.
            shader.push_str(&format!(
                "  var acc: {scalar_ty} = {};\n",
                if scalar_ty == "f64" { "f64(0.0)" } else { "0.0" }
            ));
1675            // helpers are declared at module scope
1676            shader.push_str("  var saw_nan: bool = false;\n  var c = lid.x;\n");
1677            {
1678                let mut loop_body = String::new();
1679                // input0 as 'v' — provider encodes rows in params.ncols for axis=1
1680                loop_body.push_str("    let v = input0.data[ row + (c * params.ncols) ];\n");
1681                // additional inputs as v1, v2, ...
1682                for (idx, _) in self.inputs.iter().enumerate().skip(1) {
1683                    loop_body.push_str(&format!(
1684                        "    let v{idx} = input{idx}.data[ row + (c * params.ncols) ];\n"
1685                    ));
1686                }
1687                loop_body.push_str(&format!(
1688                    "    let val: {scalar} = {val};\n    if (OMITNAN) {{ if (!isNanF(val)) {{ acc = acc + val; }} }} else {{ if (isNanF(val)) {{ saw_nan = true; }} else {{ acc = acc + val; }} }}\n",
1689                scalar = scalar_ty,
1690                val = val_expr
1691            ));
1692                // Iterate over reduce_len, which arrives as params.nrows when axis=1
1693                shader.push_str("  while (c < params.nrows) {\n");
1694                shader.push_str(&loop_body);
1695                shader.push_str("    c += @WG@u;\n  }\n");
1696            }
1697            if scalar_ty == "f64" {
1698                shader.push_str(
1699                    "  if (!OMITNAN && saw_nan) { acc = bitcast<f64>(0x7ff8000000000000u); }\n",
1700                );
1701            } else {
1702                shader
1703                    .push_str("  if (!OMITNAN && saw_nan) { acc = bitcast<f32>(0x7fc00000u); }\n");
1704            }
1705            shader.push_str("  tile[lid.x] = acc;\n  workgroupBarrier();\n");
1706            shader.push_str(
1707                "  var off = (@WG@u) / 2u;\n  loop { if (off == 0u) { break; } if (lid.x < off) {\n    let a = tile[lid.x]; let b = tile[lid.x + off];\n    tile[lid.x] = a + b;\n  } workgroupBarrier(); off = off / 2u; }\n",
1708            );
1709            shader.push_str(&format!(
1710                "  if (lid.x == 0u) {{ output.data[row] = tile[0u] * {}; }}\n}}\n",
1711                post_scale
1712            ));
1713        }
1714        Some(shader)
1715    }
1716}
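
// Sketch of intended use (hypothetical driver code; the host is assumed to
// substitute the @WG@ placeholder before compiling the module):
//
//     if let Some(src) = plan.generate_reduction_wgsl("f32") {
//         let src = src.replace("@WG@", "256");
//         // hand `src` to the wgpu backend for pipeline creation
//     }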
1717
1718impl FusionGroup {
1719    pub fn element_count(&self) -> Option<usize> {
1720        match &self.shape {
1721            ShapeInfo::Scalar => Some(1),
1722            ShapeInfo::Tensor(dims) => dims
1723                .iter()
1724                .try_fold(1usize, |acc, dim| dim.and_then(|d| acc.checked_mul(d))),
1725            ShapeInfo::Unknown => None,
1726        }
1727    }
1728}
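
// Example: a group shaped Tensor([Some(3), Some(4)]) reports element_count()
// == Some(12); an unknown dimension (None) or ShapeInfo::Unknown yields None,
// and checked_mul guards against overflow for very large dims.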
1729
1730impl FusionKind {
1731    pub fn is_elementwise(&self) -> bool {
1732        matches!(self, FusionKind::ElementwiseChain)
1733    }
1734
1735    pub fn is_reduction(&self) -> bool {
1736        matches!(self, FusionKind::Reduction)
1737    }
1738}
1739
1740fn detect_centered_gram(
1741    graph: &AccelGraph,
1742    assigned: &mut HashSet<NodeId>,
1743    groups: &mut Vec<FusionGroup>,
1744    next_group_id: &mut usize,
1745) {
1746    for div_node in &graph.nodes {
1747        if assigned.contains(&div_node.id) {
1748            continue;
1749        }
1750        let div_op = match div_node.label {
1751            AccelNodeLabel::Primitive(op) => op,
1752            _ => continue,
1753        };
1754        if div_op != PrimitiveOp::Div && div_op != PrimitiveOp::ElemDiv {
1755            continue;
1756        }
1757        if div_node.inputs.len() != 2 {
1758            continue;
1759        }
1760        let (numerator_id, denom_id) = (div_node.inputs[0], div_node.inputs[1]);
1761        let denom_info = match graph.value(denom_id) {
1762            Some(info) => info,
1763            None => continue,
1764        };
1765        let denom_const = match &denom_info.constant {
1766            Some(Value::Num(v)) => Some(*v),
1767            Some(Value::Int(i)) => Some(i.to_f64()),
1768            _ => None,
1769        };
1770        if denom_const.is_some_and(|v| v == 0.0) {
1771            continue;
1772        }
1773
1774        let mul_node_id = match graph
1775            .value(numerator_id)
1776            .and_then(|info| match &info.origin {
1777                ValueOrigin::NodeOutput { node, .. } => Some(*node),
1778                _ => None,
1779            }) {
1780            Some(id) => id,
1781            None => continue,
1782        };
1783        if assigned.contains(&mul_node_id) {
1784            continue;
1785        }
1786        let mul_node = match graph.node(mul_node_id) {
1787            Some(node) => node,
1788            None => continue,
1789        };
1790        let mul_op = match mul_node.label {
1791            AccelNodeLabel::Primitive(op) => op,
1792            _ => continue,
1793        };
1794        if mul_op != PrimitiveOp::Mul && mul_op != PrimitiveOp::ElemMul {
1795            continue;
1796        }
1797        if mul_node.inputs.len() != 2 {
1798            continue;
1799        }
1800
1801        let mut transpose_node_id: Option<NodeId> = None;
1802        let mut centered_val_id: Option<ValueId> = None;
1803        for input_vid in &mul_node.inputs {
1804            let candidate_node_id =
1805                match graph.value(*input_vid).and_then(|info| match &info.origin {
1806                    ValueOrigin::NodeOutput { node, .. } => Some(*node),
1807                    _ => None,
1808                }) {
1809                    Some(id) => id,
1810                    None => continue,
1811                };
1812            if let Some(trans_node) = graph.node(candidate_node_id) {
1813                if matches!(
1814                    trans_node.label,
1815                    AccelNodeLabel::Primitive(PrimitiveOp::Transpose)
1816                ) {
1817                    if let Some(centered) = trans_node.inputs.first().copied() {
1818                        transpose_node_id = Some(candidate_node_id);
1819                        centered_val_id = Some(centered);
1820                        break;
1821                    }
1822                }
1823            }
1824        }
1825
        // The guard on the match below already ensures the transpose node is
        // unassigned (and it was located via graph.node above), so no extra
        // existence or assignment checks are needed afterwards.
        let transpose_node_id = match transpose_node_id {
            Some(id) if !assigned.contains(&id) => id,
            _ => continue,
        };
        let centered_val_id = match centered_val_id {
            Some(id) => id,
            None => continue,
        };

        // Require the true Gram shape Xc' * Xc: the multiplicand that is not
        // the transpose must be the centered value itself.
        if !mul_node.inputs.contains(&centered_val_id) {
            continue;
        }
1841
1842        let centered_node_id =
1843            match graph
1844                .value(centered_val_id)
1845                .and_then(|info| match &info.origin {
1846                    ValueOrigin::NodeOutput { node, .. } => Some(*node),
1847                    _ => None,
1848                }) {
1849                Some(id) => id,
1850                None => continue,
1851            };
1852        if assigned.contains(&centered_node_id) {
1853            continue;
1854        }
1855        let centered_node = match graph.node(centered_node_id) {
1856            Some(node) => node,
1857            None => continue,
1858        };
1859        if !matches!(
1860            centered_node.label,
1861            AccelNodeLabel::Primitive(PrimitiveOp::Sub)
1862        ) {
1863            continue;
1864        }
1865        if centered_node.inputs.len() != 2 {
1866            continue;
1867        }
1868        let matrix_val_id = centered_node.inputs[0];
1869        let mean_val_id = centered_node.inputs[1];
1870
1871        let mean_node_id = match graph
1872            .value(mean_val_id)
1873            .and_then(|info| match &info.origin {
1874                ValueOrigin::NodeOutput { node, .. } => Some(*node),
1875                _ => None,
1876            }) {
1877            Some(id) => id,
1878            None => continue,
1879        };
1880        if assigned.contains(&mean_node_id) {
1881            continue;
1882        }
1883        let mean_node = match graph.node(mean_node_id) {
1884            Some(node) => node,
1885            None => continue,
1886        };
1887        match &mean_node.label {
1888            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mean") => {}
1889            _ => continue,
1890        }
1891        if mean_node.inputs.is_empty() || mean_node.inputs[0] != matrix_val_id {
1892            continue;
1893        }
1894
1895        let matrix_info = match graph.value(matrix_val_id) {
1896            Some(info) => info,
1897            None => continue,
1898        };
1899        let matrix_rows = match &matrix_info.shape {
1900            ShapeInfo::Tensor(dims) if !dims.is_empty() => dims[0].unwrap_or(0),
1901            _ => 0,
1902        };
        // Default to the unbiased (n-1) convention; only an explicit divisor
        // equal to n selects the biased form.
        let normalization = match denom_const {
            Some(value) if matrix_rows > 1 && approx_eq(value, matrix_rows as f64) => {
                CovNormalization::Biased
            }
            _ => CovNormalization::Unbiased,
        };
1920
1921        let mut nodes = vec![
1922            mean_node_id,
1923            centered_node_id,
1924            transpose_node_id,
1925            mul_node_id,
1926            div_node.id,
1927        ];
1928        nodes.sort_by_key(|node_id| {
1929            graph
1930                .node(*node_id)
1931                .map(|node| node.span.start)
1932                .unwrap_or(usize::MAX)
1933        });
1934        let span = group_span(graph, &nodes);
1935        let shape = node_output_shape(graph, div_node);
1936
1937        groups.push(FusionGroup {
1938            id: *next_group_id,
1939            kind: FusionKind::CenteredGram,
1940            nodes: nodes.clone(),
1941            shape,
1942            span,
1943            pattern: Some(FusionPattern::CenteredGram {
1944                matrix: matrix_val_id,
1945                normalization,
1946            }),
1947        });
1948        *next_group_id += 1;
1949        for id in nodes {
1950            assigned.insert(id);
1951        }
1952    }
1953}
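
// Illustrative MATLAB-level shape of the pattern matched above (hypothetical
// variable names; covariance via a centered Gram matrix):
//
//     Xc = X - mean(X);
//     C  = (Xc' * Xc) / (n - 1);   % or / n for the biased form
//
// mean -> sub -> transpose -> mul -> div is the node chain recorded in `nodes`.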
1954
1955fn detect_image_normalize(
1956    graph: &AccelGraph,
1957    assigned: &mut HashSet<NodeId>,
1958    groups: &mut Vec<FusionGroup>,
1959    next_group_id: &mut usize,
1960) {
1961    for pow_node in &graph.nodes {
1962        if assigned.contains(&pow_node.id) {
1963            continue;
1964        }
1965        let Some(match_info) = analyze_image_normalize(graph, pow_node.id, assigned) else {
1966            continue;
1967        };
1968
        // `pow_node` is already the node reference from this iteration; no
        // second graph lookup is needed.
        let shape = node_output_shape(graph, pow_node);
        let span = group_span(graph, &match_info.nodes);
1976
1977        let pattern = ImageNormalizePattern {
1978            input: match_info.input,
1979            epsilon: match_info.epsilon.clone(),
1980            gain: match_info.gain.clone(),
1981            bias: match_info.bias.clone(),
1982            gamma: match_info.gamma.clone(),
1983        };
1984
1985        groups.push(FusionGroup {
1986            id: *next_group_id,
1987            kind: FusionKind::ImageNormalize,
1988            nodes: match_info.nodes.clone(),
1989            shape,
1990            span: span.clone(),
1991            pattern: Some(FusionPattern::ImageNormalize(pattern)),
1992        });
1993        if fusion_debug_enabled() {
1994            log::debug!(
1995                "fusion: detected image normalize group id={} span={:?} nodes={:?}",
1996                next_group_id,
1997                span,
1998                match_info.nodes
1999            );
2000        }
2001        *next_group_id += 1;
2002        for node_id in match_info.nodes {
2003            assigned.insert(node_id);
2004        }
2005    }
2006}
2007
2008fn approx_eq(a: f64, b: f64) -> bool {
2009    let scale = a.abs().max(b.abs()).max(1.0);
2010    (a - b).abs() <= scale * 1e-6
2011}
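
// Example: approx_eq(1.0e6, 1.0e6 + 0.5) is true (the relative tolerance 1e-6
// at scale ~1e6 allows a difference of about 1.0), while approx_eq(1.0, 1.5)
// is false.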
2012
2013fn detect_power_step_normalize(
2014    graph: &AccelGraph,
2015    assigned: &mut HashSet<NodeId>,
2016    groups: &mut Vec<FusionGroup>,
2017    next_group_id: &mut usize,
2018) {
2019    'outer: for div_node in &graph.nodes {
2020        if assigned.contains(&div_node.id) {
2021            continue;
2022        }
2023        let div_op = match div_node.label {
2024            AccelNodeLabel::Primitive(op) => op,
2025            _ => continue,
2026        };
2027        if div_op != PrimitiveOp::Div && div_op != PrimitiveOp::ElemDiv {
2028            continue;
2029        }
2030        if div_node.inputs.len() != 2 {
2031            continue;
2032        }
2033        let numerator_vid = div_node.inputs[0];
2034        let denom_vid = div_node.inputs[1];
2035
2036        let (matmul_id, matmul_node) = match node_from_value(graph, numerator_vid) {
2037            Some((id, node)) => (id, node),
2038            None => continue,
2039        };
2040        if assigned.contains(&matmul_id) {
2041            continue;
2042        }
2043        match &matmul_node.label {
2044            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mtimes") => {}
2045            _ => continue,
2046        }
2047        if matmul_node.inputs.len() != 2 {
2048            continue;
2049        }
2050
2051        let Some(denom_info) = analyze_power_step_denominator(graph, denom_vid, numerator_vid)
2052        else {
2053            continue;
2054        };
        // analyze_power_step_denominator already verified that pow_input is
        // the matmul numerator, and every candidate node (pow, sum, optional
        // add, sqrt, matmul, div) gets a single assignment-conflict check
        // in the loop below.
2072
2073        let mut nodes = vec![matmul_id, denom_info.pow_node, denom_info.sum_node];
2074        if let Some(add_id) = denom_info.add_node {
2075            nodes.push(add_id);
2076        }
2077        nodes.push(denom_info.sqrt_node);
2078        nodes.push(div_node.id);
2079
2080        for node_id in &nodes {
2081            if assigned.contains(node_id) {
2082                continue 'outer;
2083            }
2084        }
2085
2086        nodes.sort_by_key(|node_id| {
2087            graph
2088                .node(*node_id)
2089                .map(|node| node.span.start)
2090                .unwrap_or(usize::MAX)
2091        });
2092
2093        let span = group_span(graph, &nodes);
2094        let shape = node_output_shape(graph, div_node);
2095
2096        groups.push(FusionGroup {
2097            id: *next_group_id,
2098            kind: FusionKind::PowerStepNormalize,
2099            nodes: nodes.clone(),
2100            shape,
2101            span,
2102            pattern: Some(FusionPattern::PowerStepNormalize {
2103                lhs: matmul_node.inputs[0],
2104                rhs: matmul_node.inputs[1],
2105                epsilon: denom_info.epsilon,
2106            }),
2107        });
2108        *next_group_id += 1;
2109        for id in nodes {
2110            assigned.insert(id);
2111        }
2112    }
2113}
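
// Illustrative MATLAB-level shape of the power-step pattern above
// (hypothetical names; the epsilon add may sit inside or outside the sqrt):
//
//     Y = A * B;
//     Q = Y ./ sqrt(sum(Y.^2) + eps);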
2114
2115fn detect_explained_variance(
2116    graph: &AccelGraph,
2117    assigned: &mut HashSet<NodeId>,
2118    groups: &mut Vec<FusionGroup>,
2119    next_group_id: &mut usize,
2120) {
2121    for diag_node in &graph.nodes {
2122        if assigned.contains(&diag_node.id) {
2123            continue;
2124        }
2125        match &diag_node.label {
2126            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("diag") => {}
2127            _ => continue,
2128        }
2129        if diag_node.inputs.len() != 1 {
2130            continue;
2131        }
2132        let matmul2_vid = diag_node.inputs[0];
2133        let (matmul2_id, matmul2_node) = match node_from_value(graph, matmul2_vid) {
2134            Some(pair) => pair,
2135            None => continue,
2136        };
2137        if assigned.contains(&matmul2_id) {
2138            continue;
2139        }
2140        match &matmul2_node.label {
2141            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mtimes") => {}
2142            _ => continue,
2143        }
2144        if matmul2_node.inputs.len() != 2 {
2145            continue;
2146        }
2147
2148        let (matmul1_id, matmul1_node, q_vid) = if let Some((mm_id, mm_node)) =
2149            node_from_value(graph, matmul2_node.inputs[0])
2150        {
2151            if matches!(mm_node.label, AccelNodeLabel::Builtin { ref name } if name.eq_ignore_ascii_case("mtimes"))
2152            {
2153                (mm_id, mm_node, matmul2_node.inputs[1])
2154            } else {
2155                continue;
2156            }
2157        } else if let Some((mm_id, mm_node)) = node_from_value(graph, matmul2_node.inputs[1]) {
2158            if matches!(mm_node.label, AccelNodeLabel::Builtin { ref name } if name.eq_ignore_ascii_case("mtimes"))
2159            {
2160                (mm_id, mm_node, matmul2_node.inputs[0])
2161            } else {
2162                continue;
2163            }
2164        } else {
2165            continue;
2166        };
2167
2168        if assigned.contains(&matmul1_id) {
2169            continue;
2170        }
2171
2172        if matmul1_node.inputs.len() != 2 {
2173            continue;
2174        }
2175
2176        let (transpose_id, transpose_input_vid, g_vid) =
2177            if let Some((t_id, src_vid)) = is_transpose_node(graph, matmul1_node.inputs[0]) {
2178                (t_id, src_vid, matmul1_node.inputs[1])
2179            } else if let Some((t_id, src_vid)) = is_transpose_node(graph, matmul1_node.inputs[1]) {
2180                (t_id, src_vid, matmul1_node.inputs[0])
2181            } else {
2182                continue;
2183            };
2184
2185        if assigned.contains(&transpose_id) {
2186            continue;
2187        }
2188
2189        if transpose_input_vid != q_vid {
2190            continue;
2191        }
2192
2193        let mut nodes = vec![diag_node.id, matmul2_id, matmul1_id, transpose_id];
2194        nodes.sort_by_key(|node_id| {
2195            graph
2196                .node(*node_id)
2197                .map(|node| node.span.start)
2198                .unwrap_or(usize::MAX)
2199        });
2200        let span = group_span(graph, &nodes);
2201        let shape = node_output_shape(graph, diag_node);
2202        groups.push(FusionGroup {
2203            id: *next_group_id,
2204            kind: FusionKind::ExplainedVariance,
2205            nodes: nodes.clone(),
2206            shape,
2207            span,
2208            pattern: Some(FusionPattern::ExplainedVariance { q: q_vid, g: g_vid }),
2209        });
2210        *next_group_id += 1;
2211        for id in nodes {
2212            assigned.insert(id);
2213        }
2214    }
2215}
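
// Illustrative MATLAB-level shape of the explained-variance pattern above
// (hypothetical names): ev = diag(Q' * G * Q), where the transposed operand
// must be the same Q that feeds the outer matmul.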
2216
2217struct PowerStepDenominatorInfo {
2218    sqrt_node: NodeId,
2219    add_node: Option<NodeId>,
2220    sum_node: NodeId,
2221    pow_node: NodeId,
2222    pow_input: ValueId,
2223    epsilon: f64,
2224}
2225
2226fn analyze_power_step_denominator(
2227    graph: &AccelGraph,
2228    denom_vid: ValueId,
2229    expected_source_vid: ValueId,
2230) -> Option<PowerStepDenominatorInfo> {
2231    let (sqrt_node_id, sqrt_input_vid, add_node_opt, epsilon_from_outer) =
2232        if let Some((sqrt_id, sqrt_in)) = is_sqrt_node(graph, denom_vid) {
2233            if let Some((add_node, sum_vid, epsilon_inner)) =
2234                extract_add_with_constant(graph, sqrt_in)
2235            {
2236                (sqrt_id, sum_vid, Some(add_node), epsilon_inner)
2237            } else {
2238                (sqrt_id, sqrt_in, None, 0.0)
2239            }
2240        } else if let Some((add_node, other_vid, epsilon_inner)) =
2241            extract_add_with_constant(graph, denom_vid)
2242        {
2243            let (sqrt_id, sqrt_in) = is_sqrt_node(graph, other_vid)?;
2244            (sqrt_id, sqrt_in, Some(add_node), epsilon_inner)
2245        } else {
2246            return None;
2247        };
2248
2249    let (sum_node_id, sum_node) = node_from_value(graph, sqrt_input_vid)?;
2250    match &sum_node.label {
2251        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("sum") => {}
2252        _ => return None,
2253    }
2254    if sum_node.inputs.is_empty() {
2255        return None;
2256    }
2257    let pow_vid = sum_node.inputs[0];
2258    let (pow_node_id, pow_node) = node_from_value(graph, pow_vid)?;
2259    let pow_input = match pow_node.label {
2260        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow) => {
2261            if pow_node.inputs.len() != 2 {
2262                return None;
2263            }
2264            let base = pow_node.inputs[0];
2265            let exponent_vid = pow_node.inputs[1];
2266            let exponent = value_constant_f64(graph, exponent_vid)?;
2267            if !approx_eq(exponent, 2.0) {
2268                return None;
2269            }
2270            base
2271        }
2272        _ => return None,
2273    };
2274
2275    if pow_input != expected_source_vid {
2276        return None;
2277    }
2278
2279    let epsilon = epsilon_from_outer;
2280    Some(PowerStepDenominatorInfo {
2281        sqrt_node: sqrt_node_id,
2282        add_node: add_node_opt,
2283        sum_node: sum_node_id,
2284        pow_node: pow_node_id,
2285        pow_input,
2286        epsilon,
2287    })
2288}
2289
2290fn node_from_value(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, &AccelNode)> {
2291    let info = graph.value(vid)?;
2292    match info.origin {
2293        ValueOrigin::NodeOutput { node, .. } => graph.node(node).map(|n| (node, n)),
2294        _ => None,
2295    }
2296}
2297
2298fn is_sqrt_node(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId)> {
2299    let (node_id, node) = node_from_value(graph, vid)?;
2300    match &node.label {
2301        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("sqrt") => {
2302            let input = node.inputs.first().copied()?;
2303            Some((node_id, input))
2304        }
2305        _ => None,
2306    }
2307}
2308
2309fn is_transpose_node(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId)> {
2310    let (node_id, node) = node_from_value(graph, vid)?;
2311    match &node.label {
2312        AccelNodeLabel::Primitive(PrimitiveOp::Transpose) => {
2313            let input = node.inputs.first().copied()?;
2314            Some((node_id, input))
2315        }
2316        _ => None,
2317    }
2318}
2319
2320fn extract_add_with_constant(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId, f64)> {
2321    let (node_id, node) = node_from_value(graph, vid)?;
2322    match node.label {
2323        AccelNodeLabel::Primitive(PrimitiveOp::Add) => {
2324            if node.inputs.len() != 2 {
2325                return None;
2326            }
2327            let lhs = node.inputs[0];
2328            let rhs = node.inputs[1];
2329            if let Some(eps) = value_constant_f64(graph, rhs) {
2330                return Some((node_id, lhs, eps));
2331            }
2332            if let Some(eps) = value_constant_f64(graph, lhs) {
2333                return Some((node_id, rhs, eps));
2334            }
2335            None
2336        }
2337        AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
2338            if node.inputs.len() != 2 {
2339                return None;
2340            }
2341            let lhs = node.inputs[0];
2342            let rhs = node.inputs[1];
2343            if let Some(eps) = value_constant_f64(graph, rhs) {
2344                return Some((node_id, lhs, -eps));
2345            }
2346            if let Some(eps) = value_constant_f64(graph, lhs) {
2347                return Some((node_id, rhs, eps));
2348            }
2349            None
2350        }
2351        _ => None,
2352    }
2353}
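
// Examples: x + 1e-6 and 1e-6 + x both yield (add_node, x, 1e-6); x - 1e-6
// yields epsilon -1e-6; 1e-6 - x returns the non-constant operand with +1e-6,
// mirroring the Sub arms above.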
2354
2355struct ConstantTrace {
2356    value: f64,
2357    nodes: Vec<NodeId>,
2358}
2359
2360fn collect_scalar_constant(graph: &AccelGraph, vid: ValueId) -> Option<ConstantTrace> {
2361    let mut current = vid;
2362    let mut nodes: Vec<NodeId> = Vec::new();
2363    let mut sign = 1.0f64;
2364    let mut visited: HashSet<NodeId> = HashSet::new();
2365
2366    loop {
2367        let info = graph.value(current)?;
2368        match &info.origin {
2369            ValueOrigin::Constant => {
2370                let base = value_info_scalar(info)?;
2371                return Some(ConstantTrace {
2372                    value: sign * base,
2373                    nodes,
2374                });
2375            }
2376            ValueOrigin::NodeOutput { node, .. } => {
2377                if !visited.insert(*node) {
2378                    return None;
2379                }
2380                let node_ref = graph.node(*node)?;
2381                match &node_ref.label {
2382                    AccelNodeLabel::Builtin { name }
2383                        if name.eq_ignore_ascii_case("single")
2384                            || name.eq_ignore_ascii_case("double")
2385                            || name.eq_ignore_ascii_case("gpuarray") =>
2386                    {
2387                        if node_ref.inputs.len() != 1 {
2388                            return None;
2389                        }
2390                        nodes.push(*node);
2391                        current = node_ref.inputs[0];
2392                    }
2393                    AccelNodeLabel::Primitive(PrimitiveOp::Neg) => {
2394                        if node_ref.inputs.len() != 1 {
2395                            return None;
2396                        }
2397                        nodes.push(*node);
2398                        sign = -sign;
2399                        current = node_ref.inputs[0];
2400                    }
2401                    AccelNodeLabel::Primitive(PrimitiveOp::UPlus) => {
2402                        if node_ref.inputs.len() != 1 {
2403                            return None;
2404                        }
2405                        nodes.push(*node);
2406                        current = node_ref.inputs[0];
2407                    }
2408                    _ => return None,
2409                }
2410            }
2411            _ => return None,
2412        }
2413    }
2414}
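
// Example (illustrative): a value built as single(-(1e-3)) traces through the
// cast and Neg nodes back to the constant, returning value = -1e-3 along with
// the peeled node ids so callers can detect assignment conflicts.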
2415
2416fn scalar_shape_known_one(shape: &ShapeInfo) -> bool {
2417    match shape {
2418        ShapeInfo::Scalar => true,
2419        ShapeInfo::Tensor(dims) => {
2420            if dims.is_empty() {
2421                return true;
2422            }
2423            dims.iter().all(|dim| matches!(dim, Some(1)))
2424        }
2425        ShapeInfo::Unknown => false,
2426    }
2427}
2428
2429fn capture_image_scalar(
2430    graph: &AccelGraph,
2431    vid: ValueId,
2432    assigned: &HashSet<NodeId>,
2433    _nodes: &mut Vec<NodeId>,
2434) -> Option<ImageScalar> {
2435    if let Some(trace) = collect_scalar_constant(graph, vid) {
2436        if trace.nodes.iter().any(|id| assigned.contains(id)) {
2437            return None;
2438        }
2439        return Some(ImageScalar::Constant(trace.value));
2440    }
2441    let info = graph.value(vid)?;
2442    if scalar_shape_known_one(&info.shape) {
2443        return Some(ImageScalar::Value(vid));
2444    }
2445    if log::log_enabled!(log::Level::Debug) {
2446        log::debug!(
2447            "capture_image_scalar: reject vid={vid:?} shape={:?} origin={:?}",
2448            info.shape,
2449            info.origin
2450        );
2451    }
2452    None
2453}
2454
2455fn peel_numeric_casts(
2456    graph: &AccelGraph,
2457    mut vid: ValueId,
2458    assigned: &HashSet<NodeId>,
2459    _nodes: &mut Vec<NodeId>,
2460) -> Option<ValueId> {
2461    loop {
2462        let info = graph.value(vid)?;
2463        match &info.origin {
2464            ValueOrigin::NodeOutput { node, .. } => {
2465                if assigned.contains(node) {
2466                    return None;
2467                }
2468                let node_ref = graph.node(*node)?;
2469                if let AccelNodeLabel::Builtin { name } = &node_ref.label {
2470                    if name.eq_ignore_ascii_case("single")
2471                        || name.eq_ignore_ascii_case("double")
2472                        || name.eq_ignore_ascii_case("gpuarray")
2473                    {
2474                        if node_ref.inputs.len() != 1 {
2475                            return None;
2476                        }
2477                        vid = node_ref.inputs[0];
2478                        continue;
2479                    }
2480                }
2481                return Some(vid);
2482            }
2483            _ => return Some(vid),
2484        }
2485    }
2486}
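
// Example (illustrative): a chain X -> double(X) -> gpuArray(...) peels back
// to X's value id; peeling aborts with None if any cast node is already
// claimed by another fusion group.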
2487
2488fn resolve_scalar_constant(graph: &AccelGraph, vid: ValueId) -> Option<f64> {
2489    collect_scalar_constant(graph, vid).map(|trace| trace.value)
2490}
2491
2492fn value_info_scalar(info: &ValueInfo) -> Option<f64> {
2493    match &info.constant {
2494        Some(Value::Num(v)) => Some(*v),
2495        Some(Value::Int(i)) => Some(i.to_f64()),
2496        Some(Value::Tensor(t)) if t.data.len() == 1 => Some(t.data[0]),
2497        Some(Value::LogicalArray(arr)) if arr.data.len() == 1 => Some(arr.data[0] as f64),
2498        Some(Value::Bool(flag)) => Some(if *flag { 1.0 } else { 0.0 }),
2499        _ => None,
2500    }
2501}
2502
2503fn value_constant_f64(graph: &AccelGraph, vid: ValueId) -> Option<f64> {
2504    resolve_scalar_constant(graph, vid)
2505}
2506
2507fn primitive_expr(
2508    op: PrimitiveOp,
2509    inputs: &[ValueId],
2510    exprs: &HashMap<ValueId, String>,
2511) -> Option<String> {
2512    let binary = |exprs: &HashMap<ValueId, String>| -> Option<(String, String)> {
2513        let lhs = exprs.get(inputs.first()?).cloned()?;
2514        let rhs = exprs.get(inputs.get(1)?).cloned()?;
2515        Some((lhs, rhs))
2516    };
2517    match op {
2518        PrimitiveOp::Add => {
2519            let (lhs, rhs) = binary(exprs)?;
2520            Some(format!("({lhs} + {rhs})"))
2521        }
2522        PrimitiveOp::Sub => {
2523            let (lhs, rhs) = binary(exprs)?;
2524            Some(format!("({lhs} - {rhs})"))
2525        }
2526        PrimitiveOp::Mul | PrimitiveOp::ElemMul => {
2527            let (lhs, rhs) = binary(exprs)?;
2528            Some(format!("({lhs} * {rhs})"))
2529        }
        PrimitiveOp::Div | PrimitiveOp::ElemDiv => {
            let (lhs, rhs) = binary(exprs)?;
            Some(format!("({lhs} / {rhs})"))
        }
        PrimitiveOp::ElemLeftDiv => {
            // MATLAB left division a .\ b computes b ./ a; swap the operands,
            // assuming graph inputs arrive in source order.
            let (lhs, rhs) = binary(exprs)?;
            Some(format!("({rhs} / {lhs})"))
        }
2534        PrimitiveOp::Pow | PrimitiveOp::ElemPow => {
2535            let (lhs, rhs) = binary(exprs)?;
2536            Some(format!("pow({lhs}, {rhs})"))
2537        }
2538        PrimitiveOp::Neg => {
2539            let arg = exprs.get(inputs.first()?).cloned()?;
2540            Some(format!("(-{arg})"))
2541        }
2542        PrimitiveOp::UPlus => {
2543            let arg = exprs.get(inputs.first()?).cloned()?;
2544            Some(format!("(+{arg})"))
2545        }
2546        _ => None,
2547    }
2548}
2549
2550fn builtin_expr(
2551    name: &str,
2552    inputs: &[ValueId],
2553    exprs: &HashMap<ValueId, String>,
2554    scalar_ty: &str,
2555) -> Option<String> {
2556    let func = match name.to_ascii_lowercase().as_str() {
2557        "isfinite" => return builtin_unary_call("isFinite", inputs, exprs),
2558        "isinf" => return builtin_unary_call("isInf", inputs, exprs),
2559        "isnan" => return builtin_unary_call("isNan", inputs, exprs),
2560        "single" | "double" | "gpuarray" => return builtin_identity(inputs, exprs),
2561        "sin" => "sin",
2562        "cos" => "cos",
2563        "tan" => "tan",
2564        "asin" => "asin",
2565        "acos" => "acos",
2566        "atan" => "atan",
2567        "atan2" => return builtin_binary("atan2", inputs, exprs),
2568        "sinh" => "sinh",
2569        "cosh" => "cosh",
2570        "tanh" => "tanh",
2571        "exp" => "exp",
2572        "log" => "log",
2573        "log2" => "log2",
2574        "sqrt" => "sqrt",
2575        "abs" => "abs",
2576        "exp2" => "exp2",
2577        "floor" => "floor",
2578        "ceil" => "ceil",
2579        "round" => "round",
2580        "trunc" => "trunc",
2581        "max" => return builtin_binary("max", inputs, exprs),
2582        "min" => return builtin_binary("min", inputs, exprs),
        "log10" => {
            // WGSL has no native log10; emit ln(x) * (1 / ln(10)).
            let arg = exprs.get(inputs.first()?).cloned()?;
            let constant = cast_literal(scalar_ty, "0.4342944819032518");
            return Some(format!("(log({arg}) * {constant})"));
        }
        "log1p" => {
            let arg = exprs.get(inputs.first()?).cloned()?;
            let one = cast_literal(scalar_ty, "1.0");
            return Some(format!("log({arg} + {one})"));
        }
        "expm1" => {
            let arg = exprs.get(inputs.first()?).cloned()?;
            let one = cast_literal(scalar_ty, "1.0");
            return Some(format!("(exp({arg}) - {one})"));
        }
        _ => return None,
2603    };
2604    let arg = exprs.get(inputs.first()?).cloned()?;
2605    Some(format!("{func}({arg})"))
2606}
2607
2608fn builtin_binary(
2609    func: &str,
2610    inputs: &[ValueId],
2611    exprs: &HashMap<ValueId, String>,
2612) -> Option<String> {
2613    let lhs = exprs.get(inputs.first()?).cloned()?;
2614    let rhs = exprs.get(inputs.get(1)?).cloned()?;
2615    Some(format!("{func}({lhs}, {rhs})"))
2616}
2617
2618fn builtin_unary_call(
2619    func: &str,
2620    inputs: &[ValueId],
2621    exprs: &HashMap<ValueId, String>,
2622) -> Option<String> {
2623    let arg = exprs.get(inputs.first()?).cloned()?;
2624    Some(format!("{func}({arg})"))
2625}
2626
2627fn builtin_identity(inputs: &[ValueId], exprs: &HashMap<ValueId, String>) -> Option<String> {
2628    exprs.get(inputs.first()?).cloned()
2629}
2630
2631fn cast_literal(scalar_ty: &str, literal: &str) -> String {
2632    if scalar_ty == "f64" {
2633        format!("{scalar_ty}({literal})")
2634    } else {
2635        literal.to_string()
2636    }
2637}
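
// Example: cast_literal("f64", "1.0") yields "f64(1.0)", while f32 literals
// are passed through unchanged as "1.0".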
2638
2639fn split_add_with_scalar(
2640    graph: &AccelGraph,
2641    vid: ValueId,
2642    assigned: &HashSet<NodeId>,
2643    nodes: &mut Vec<NodeId>,
2644) -> Option<(NodeId, ValueId, ImageScalar)> {
2645    let (node_id, node) = node_from_value(graph, vid)?;
2646    match node.label {
2647        AccelNodeLabel::Primitive(PrimitiveOp::Add) => {
2648            if node.inputs.len() != 2 {
2649                return None;
2650            }
2651            let lhs = node.inputs[0];
2652            let rhs = node.inputs[1];
2653            if let Some(scalar) = capture_image_scalar(graph, rhs, assigned, nodes) {
2654                return Some((node_id, lhs, scalar));
2655            }
2656            if let Some(scalar) = capture_image_scalar(graph, lhs, assigned, nodes) {
2657                return Some((node_id, rhs, scalar));
2658            }
2659            None
2660        }
2661        AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
2662            if node.inputs.len() != 2 {
2663                return None;
2664            }
2665            let lhs = node.inputs[0];
2666            let rhs = node.inputs[1];
2667            if let Some(ImageScalar::Constant(value)) =
2668                capture_image_scalar(graph, rhs, assigned, nodes)
2669            {
2670                return Some((node_id, lhs, ImageScalar::Constant(-value)));
2671            }
2672            None
2673        }
2674        _ => None,
2675    }
2676}
2677
2678fn split_mul_with_scalar(
2679    graph: &AccelGraph,
2680    vid: ValueId,
2681    assigned: &HashSet<NodeId>,
2682    nodes: &mut Vec<NodeId>,
2683) -> Option<(NodeId, ValueId, ImageScalar)> {
2684    let (node_id, node) = node_from_value(graph, vid)?;
2685    match node.label {
2686        AccelNodeLabel::Primitive(PrimitiveOp::Mul)
2687        | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul) => {
2688            if node.inputs.len() != 2 {
2689                return None;
2690            }
2691            let lhs = node.inputs[0];
2692            let rhs = node.inputs[1];
2693            if let Some(scalar) = capture_image_scalar(graph, rhs, assigned, nodes) {
2694                return Some((node_id, lhs, scalar));
2695            }
2696            if let Some(scalar) = capture_image_scalar(graph, lhs, assigned, nodes) {
2697                return Some((node_id, rhs, scalar));
2698            }
2699            None
2700        }
2701        _ => None,
2702    }
2703}
2704
2705fn split_max_with_zero_scalar(
2706    graph: &AccelGraph,
2707    vid: ValueId,
2708    assigned: &HashSet<NodeId>,
2709    nodes: &mut Vec<NodeId>,
2710) -> Option<(NodeId, ValueId)> {
2711    let (node_id, node) = node_from_value(graph, vid)?;
2712    match &node.label {
2713        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("max") => {
2714            if node.inputs.len() != 2 {
2715                if log::log_enabled!(log::Level::Debug) {
2716                    log::debug!(
2717                        "split_max_with_zero_scalar: node {node_id:?} has {} inputs",
2718                        node.inputs.len()
2719                    );
2720                }
2721                return None;
2722            }
2723            let lhs = node.inputs[0];
2724            let rhs = node.inputs[1];
2725            if let Some(ImageScalar::Constant(value)) =
2726                capture_image_scalar(graph, rhs, assigned, nodes)
2727            {
2728                if approx_eq(value, 0.0) {
2729                    if log::log_enabled!(log::Level::Debug) {
2730                        log::debug!(
2731                            "split_max_with_zero_scalar: rhs zero constant for node {node_id:?}"
2732                        );
2733                    }
2734                    return Some((node_id, lhs));
2735                }
2736            }
2737            if let Some(ImageScalar::Constant(value)) =
2738                capture_image_scalar(graph, lhs, assigned, nodes)
2739            {
2740                if approx_eq(value, 0.0) {
2741                    if log::log_enabled!(log::Level::Debug) {
2742                        log::debug!(
2743                            "split_max_with_zero_scalar: lhs zero constant for node {node_id:?}"
2744                        );
2745                    }
2746                    return Some((node_id, rhs));
2747                }
2748            }
2749            if log::log_enabled!(log::Level::Debug) {
2750                log::debug!(
2751                    "split_max_with_zero_scalar: node {node_id:?} inputs not zero constants"
2752                );
2753            }
2754            None
2755        }
2756        _ => None,
2757    }
2758}
2759
2760fn resolve_numeric_vector_constant(graph: &AccelGraph, vid: ValueId) -> Option<Vec<f64>> {
2761    if let Some(scalar) = resolve_scalar_constant(graph, vid) {
2762        return Some(vec![scalar]);
2763    }
2764    let info = graph.value(vid)?;
2765    match &info.constant {
2766        Some(Value::Tensor(tensor)) if !tensor.data.is_empty() => Some(tensor.data.clone()),
2767        Some(Value::LogicalArray(arr)) if !arr.data.is_empty() => Some(
2768            arr.data
2769                .iter()
2770                .map(|v| if *v == 0 { 0.0 } else { 1.0 })
2771                .collect(),
2772        ),
2773        Some(Value::Bool(flag)) => Some(vec![if *flag { 1.0 } else { 0.0 }]),
2774        Some(Value::Int(iv)) => Some(vec![iv.to_f64()]),
2775        Some(Value::Num(num)) => Some(vec![*num]),
2776        _ => None,
2777    }
2778}
2779
2780fn match_mean_axes(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId, Vec<f64>)> {
2781    let (node_id, node) = node_from_value(graph, vid)?;
2782    match &node.label {
2783        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mean") => {}
2784        _ => return None,
2785    }
2786    if node.inputs.len() < 2 {
2787        return None;
2788    }
2789    let data_vid = node.inputs[0];
2790    let dims_vid = node.inputs[1];
2791    let dims = resolve_numeric_vector_constant(graph, dims_vid)?;
2792    Some((node_id, data_vid, dims))
2793}
2794
fn dims_match_unordered(found: &[f64], expected: &[f64]) -> bool {
    if found.len() != expected.len() {
        return false;
    }
    let mut a: Vec<i64> = found.iter().map(|d| d.round() as i64).collect();
    let mut b: Vec<i64> = expected.iter().map(|d| d.round() as i64).collect();
    a.sort_unstable();
    b.sort_unstable();
    a == b
}

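/// Peels a chain of `mean` reductions covering exactly `expected_dims`,
/// accepting either one multi-dim call (in any order) or nested single-dim
/// calls in the listed order. Returns the innermost data operand and appends
/// the consumed nodes to `nodes`.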
fn peel_mean_dims(
    graph: &AccelGraph,
    vid: ValueId,
    expected_dims: &[f64],
    assigned: &HashSet<NodeId>,
    nodes: &mut Vec<NodeId>,
) -> Option<ValueId> {
    if expected_dims.is_empty() {
        return Some(vid);
    }
    let (node_id, data_vid, dims) = match_mean_axes(graph, vid)?;
    if assigned.contains(&node_id) {
        return None;
    }
    if dims_match_unordered(&dims, expected_dims) {
        nodes.push(node_id);
        return Some(data_vid);
    }
    if dims.len() == 1 && approx_eq(dims[0], expected_dims[0]) {
        nodes.push(node_id);
        return peel_mean_dims(graph, data_vid, &expected_dims[1..], assigned, nodes);
    }
    None
}

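/// Result of a successful `analyze_image_normalize` match: the graph nodes
/// consumed by the pattern plus the operands needed to emit the fused kernel.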
struct ImageNormalizeMatch {
    nodes: Vec<NodeId>,
    input: ValueId,
    epsilon: ImageScalar,
    gain: Option<ImageScalar>,
    bias: Option<ImageScalar>,
    gamma: Option<ImageScalar>,
}

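/// Tries to match an image-normalization expression rooted at an element-wise
/// power node. In MATLAB terms the recognized shape is roughly
/// `max(gain .* (x - mu) ./ sqrt(v + eps) + bias, 0) .^ gamma`, with
/// `mu = mean(x, [3 2])` and `v = mean((x - mu).^2, [3 2])`. The gain, bias,
/// and gamma stages are optional; identity constants are dropped. Rejected
/// candidates are reported through a debug log together with the reason.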
fn analyze_image_normalize(
    graph: &AccelGraph,
    pow_node_id: NodeId,
    assigned: &HashSet<NodeId>,
) -> Option<ImageNormalizeMatch> {
    let pow_node = graph.node(pow_node_id)?;
    if log::log_enabled!(log::Level::Debug) {
        log::debug!(
            "image_normalize: inspect pow candidate node={pow_node_id:?} label={:?}",
            pow_node.label
        );
    }
    macro_rules! img_norm_fail {
        ($reason:expr) => {{
            if log::log_enabled!(log::Level::Debug) {
                log::debug!(
                    "image_normalize: reject node {pow_node_id:?} reason={}",
                    $reason
                );
            }
            return None;
        }};
    }
    if !matches!(
        pow_node.label,
        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow)
    ) {
        img_norm_fail!("not elem pow");
    }
    if pow_node.inputs.len() != 2 || pow_node.outputs.len() != 1 {
        img_norm_fail!("unexpected pow arity");
    }

    let mut nodes: Vec<NodeId> = vec![pow_node_id];

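    // Exponent stage: the pow's second operand must resolve to a scalar; a
    // constant 1.0 is an identity gamma and is dropped from the pattern.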
    let gamma_scalar = capture_image_scalar(graph, pow_node.inputs[1], assigned, &mut nodes)?;
    if log::log_enabled!(log::Level::Debug) {
        log::debug!("image_normalize: node {pow_node_id:?} gamma scalar={gamma_scalar:?}");
    }
    let gamma_opt = match &gamma_scalar {
        ImageScalar::Constant(value) if approx_eq(*value, 1.0) => None,
        _ => Some(gamma_scalar),
    };

    let (clamp_node_id, clamp_input_vid) =
        split_max_with_zero_scalar(graph, pow_node.inputs[0], assigned, &mut nodes)?;
    if assigned.contains(&clamp_node_id) {
        img_norm_fail!("clamp node already assigned");
    }
    nodes.push(clamp_node_id);

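    // Affine tail beneath the relu clamp: optionally peel `+ bias`, then
    // `* gain`, skipping numeric casts in between; identity constants
    // (bias 0.0, gain 1.0) are elided so the fused kernel can omit them.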
    let pre_bias_vid = peel_numeric_casts(graph, clamp_input_vid, assigned, &mut nodes)?;
    let (pre_gain_vid, bias_opt) = if let Some((add_node_id, base_vid, bias_scalar)) =
        split_add_with_scalar(graph, pre_bias_vid, assigned, &mut nodes)
    {
        if assigned.contains(&add_node_id) {
            img_norm_fail!("bias add already assigned");
        }
        nodes.push(add_node_id);
        let bias = match &bias_scalar {
            ImageScalar::Constant(value) if approx_eq(*value, 0.0) => None,
            _ => Some(bias_scalar),
        };
        let base_vid = peel_numeric_casts(graph, base_vid, assigned, &mut nodes)?;
        (base_vid, bias)
    } else {
        (pre_bias_vid, None)
    };

    let (mut norm_vid, gain_opt) = if let Some((mul_node_id, base_vid, gain_scalar)) =
        split_mul_with_scalar(graph, pre_gain_vid, assigned, &mut nodes)
    {
        if assigned.contains(&mul_node_id) {
            img_norm_fail!("gain mul already assigned");
        }
        nodes.push(mul_node_id);
        let gain = match &gain_scalar {
            ImageScalar::Constant(value) if approx_eq(*value, 1.0) => None,
            _ => Some(gain_scalar),
        };
        let base_vid = peel_numeric_casts(graph, base_vid, assigned, &mut nodes)?;
        (base_vid, gain)
    } else {
        (pre_gain_vid, None)
    };

    norm_vid = peel_numeric_casts(graph, norm_vid, assigned, &mut nodes)?;

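    // Core normalization: expect `diff ./ sigma` where the denominator is a
    // sqrt node (possibly behind numeric casts).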
    let (div_node_id, div_node) = node_from_value(graph, norm_vid)?;
    if assigned.contains(&div_node_id) {
        img_norm_fail!("div node already assigned");
    }
    match div_node.label {
        AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv)
        | AccelNodeLabel::Primitive(PrimitiveOp::Div) => {}
        _ => img_norm_fail!("not div primitive"),
    }
    if div_node.inputs.len() != 2 {
        img_norm_fail!("div arity");
    }

    let diff_vid = div_node.inputs[0];
    let sigma_vid = peel_numeric_casts(graph, div_node.inputs[1], assigned, &mut nodes)?;
    let (sigma_node_id, sigma_input_vid) = match is_sqrt_node(graph, sigma_vid) {
        Some(pair) => pair,
        None => img_norm_fail!("sigma not sqrt"),
    };
    if assigned.contains(&sigma_node_id) {
        img_norm_fail!("sqrt node already assigned");
    }
    nodes.push(div_node_id);
    nodes.push(sigma_node_id);

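    // Inside the sqrt: an epsilon add on top of `mean((x - mu).^2, [3 2])`,
    // the variance estimate that pairs with the `mu` mean checked below.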
    let (add_node_id, mean_sq_vid, epsilon) =
        split_add_with_scalar(graph, sigma_input_vid, assigned, &mut nodes)?;
    if assigned.contains(&add_node_id) {
        img_norm_fail!("epsilon add already assigned");
    }
    nodes.push(add_node_id);
    let mean_sq_vid = peel_numeric_casts(graph, mean_sq_vid, assigned, &mut nodes)?;

    let squared_diff_vid = peel_mean_dims(graph, mean_sq_vid, &[3.0, 2.0], assigned, &mut nodes)?;

    let (square_pow_node_id, square_pow_node) = node_from_value(graph, squared_diff_vid)?;
    if assigned.contains(&square_pow_node_id) {
        img_norm_fail!("square pow already assigned");
    }
    if !matches!(
        square_pow_node.label,
        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow)
    ) {
        img_norm_fail!("variance pow not elem pow");
    }
    if square_pow_node.inputs.len() != 2 {
        img_norm_fail!("variance pow arity");
    }
    let exponent_trace = collect_scalar_constant(graph, square_pow_node.inputs[1])?;
    if !approx_eq(exponent_trace.value, 2.0) {
        img_norm_fail!("variance exponent != 2");
    }
    if exponent_trace.nodes.iter().any(|id| assigned.contains(id)) {
        img_norm_fail!("variance exponent nodes already assigned");
    }
    nodes.push(square_pow_node_id);
    nodes.extend(exponent_trace.nodes.iter().copied());

    let diff_var_vid = square_pow_node.inputs[0];
    let (diff_var_node_id, diff_var_node) = node_from_value(graph, diff_var_vid)?;
    if assigned.contains(&diff_var_node_id) {
        img_norm_fail!("diff variance node already assigned");
    }
    if !matches!(
        diff_var_node.label,
        AccelNodeLabel::Primitive(PrimitiveOp::Sub)
    ) {
        img_norm_fail!("diff variance node not sub");
    }
    if diff_var_node.inputs.len() != 2 {
        img_norm_fail!("diff variance arity");
    }
    let imgs_vid = diff_var_node.inputs[0];
    let mu_vid = peel_numeric_casts(graph, diff_var_node.inputs[1], assigned, &mut nodes)?;
    nodes.push(diff_var_node_id);

    let (diff_node_id, diff_node) = node_from_value(graph, diff_vid)?;
    if assigned.contains(&diff_node_id) {
        img_norm_fail!("diff node already assigned");
    }
    if !matches!(diff_node.label, AccelNodeLabel::Primitive(PrimitiveOp::Sub)) {
        img_norm_fail!("diff node not sub");
    }
    if diff_node.inputs.len() != 2 {
        img_norm_fail!("diff node arity");
    }
    let diff_mu_vid = peel_numeric_casts(graph, diff_node.inputs[1], assigned, &mut nodes)?;
    if diff_node.inputs[0] != imgs_vid || diff_mu_vid != mu_vid {
        img_norm_fail!("diff inputs mismatch with variance pair");
    }
    nodes.push(diff_node_id);

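    // The `mu` operand must itself be `mean(x, [3 2])` over the same input
    // value that feeds both subtractions; anything else is not this pattern.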
    let mean_mu_input_vid = peel_mean_dims(graph, mu_vid, &[3.0, 2.0], assigned, &mut nodes)?;
    if mean_mu_input_vid != imgs_vid {
        img_norm_fail!("mean mu input mismatch");
    }

    let input_info = graph.value(imgs_vid)?;
    match &input_info.shape {
        ShapeInfo::Tensor(dims) if dims.len() >= 2 => {}
        ShapeInfo::Unknown => {}
        other => {
            if log::log_enabled!(log::Level::Debug) {
                log::debug!(
                    "image_normalize: node {pow_node_id:?} input shape {:?}",
                    other
                );
            }
            img_norm_fail!("input shape not at least 2-d tensor");
        }
    }

    nodes.sort_unstable();
    nodes.dedup();

    Some(ImageNormalizeMatch {
        nodes,
        input: imgs_vid,
        epsilon,
        gain: gain_opt,
        bias: bias_opt,
        gamma: gamma_opt,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{
        AccelGraph, AccelGraphTag, AccelNode, AccelNodeLabel, AccelOpCategory, InstrSpan,
        PrimitiveOp, ValueId, ValueInfo, ValueOrigin, VarKind,
    };
    use runmat_builtins::{Type, Value};
    use std::collections::HashMap as StdHashMap;

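    /// Two-node elementwise chain over a 4x4 tensor (`x .* x`, then the
    /// result `.* x`), used as the minimal fusion fixture.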
    fn simple_elementwise_graph() -> AccelGraph {
        let values = vec![
            // Value 0: input tensor
            ValueInfo {
                id: 0,
                origin: ValueOrigin::Variable {
                    kind: VarKind::Global,
                    index: 0,
                },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
            // Node 0 output value (value id 1)
            ValueInfo {
                id: 1,
                origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
            // Node 1 output value (value id 2)
            ValueInfo {
                id: 2,
                origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
        ];

        let node0 = AccelNode {
            id: 0,
            label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
            category: AccelOpCategory::Elementwise,
            inputs: vec![0, 0],
            outputs: vec![1],
            span: InstrSpan { start: 10, end: 10 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let node1 = AccelNode {
            id: 1,
            label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
            category: AccelOpCategory::Elementwise,
            inputs: vec![1, 0],
            outputs: vec![2],
            span: InstrSpan { start: 11, end: 11 },
            tags: vec![AccelGraphTag::Elementwise],
        };

        AccelGraph {
            nodes: vec![node0, node1],
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        }
    }

    #[test]
    fn detects_chain() {
        let graph = simple_elementwise_graph();
        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);
        let group = &groups[0];
        assert_eq!(group.nodes, vec![0, 1]);
        assert_eq!(group.kind, FusionKind::ElementwiseChain);
    }

    #[test]
    fn builds_plan_and_template() {
        let graph = simple_elementwise_graph();
        let groups = detect_fusion_groups(&graph);
        let plan = FusionPlan::from_graph(&graph, &groups);
        assert_eq!(plan.groups.len(), 1);
        let group_plan = &plan.groups[0];
        assert!(group_plan.kernel.supported);
        let wgsl = group_plan.generate_wgsl("f32").expect("wgsl");
        assert!(wgsl.contains("@compute"));
        assert!(group_plan.group.element_count().is_some());
    }

    #[test]
    fn stack_pattern_tracks_repeated_constants() {
        let values = vec![
            ValueInfo {
                id: 0,
                origin: ValueOrigin::Variable {
                    kind: VarKind::Global,
                    index: 0,
                },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: None,
            },
            ValueInfo {
                id: 1,
                origin: ValueOrigin::Constant,
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: Some(Value::Num(1.0)),
            },
            ValueInfo {
                id: 2,
                origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: None,
            },
            ValueInfo {
                id: 3,
                origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: None,
            },
        ];

        let node0 = AccelNode {
            id: 0,
            label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
            category: AccelOpCategory::Elementwise,
            inputs: vec![0, 1],
            outputs: vec![2],
            span: InstrSpan { start: 5, end: 5 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let node1 = AccelNode {
            id: 1,
            label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
            category: AccelOpCategory::Elementwise,
            inputs: vec![2, 1],
            outputs: vec![3],
            span: InstrSpan { start: 6, end: 6 },
            tags: vec![AccelGraphTag::Elementwise],
        };

        let graph = AccelGraph {
            nodes: vec![node0, node1],
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        };

        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);
        let plan = FusionPlan::from_graph(&graph, &groups);
        let group_plan = &plan.groups[0];
        assert_eq!(group_plan.inputs.len(), 2);
        assert!(group_plan.stack_pattern.is_empty());
        assert!(group_plan.constants.contains_key(&1));
        assert!(group_plan.const_values.contains_key(&1));
    }

    #[test]
    fn builtin_expr_supports_extended_set() {
        let mut exprs: StdHashMap<ValueId, String> = StdHashMap::new();
        exprs.insert(0, "v0".to_string());
        exprs.insert(1, "v1".to_string());

        let log1p = super::builtin_expr("log1p", &[0], &exprs, "f32");
        assert!(log1p.is_some());

        let log10 = super::builtin_expr("log10", &[0], &exprs, "f64");
        assert!(log10.unwrap().contains("log"));

        let expm1 = super::builtin_expr("expm1", &[0], &exprs, "f32");
        assert!(expm1.unwrap().contains("exp"));

        let floor = super::builtin_expr("floor", &[0], &exprs, "f32");
        assert_eq!(floor.unwrap(), "floor(v0)");

        let atan2 = super::builtin_expr("atan2", &[0, 1], &exprs, "f32");
        assert_eq!(atan2.unwrap(), "atan2(v0, v1)");

        let single = super::builtin_expr("single", &[0], &exprs, "f32");
        assert_eq!(single.unwrap(), "v0");

        let double = super::builtin_expr("double", &[0], &exprs, "f64");
        assert_eq!(double.unwrap(), "v0");
    }

    #[test]
    fn fanout_chain_with_casts_supported() {
        let values = vec![
            // Base input tensor
            ValueInfo {
                id: 0,
                origin: ValueOrigin::Variable {
                    kind: VarKind::Global,
                    index: 0,
                },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
            // tanh(x) output
            ValueInfo {
                id: 1,
                origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
            // constant scale before casting
            ValueInfo {
                id: 2,
                origin: ValueOrigin::Constant,
                ty: Type::Num,
                shape: ShapeInfo::Scalar,
                constant: Some(Value::Num(0.1)),
            },
            // single(0.1) output
            ValueInfo {
                id: 3,
                origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
                ty: Type::Num,
                shape: ShapeInfo::Scalar,
                constant: None,
            },
            // scaled branch output
            ValueInfo {
                id: 4,
                origin: ValueOrigin::NodeOutput { node: 2, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
            // final add output
            ValueInfo {
                id: 5,
                origin: ValueOrigin::NodeOutput { node: 3, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
        ];

        let tanh_node = AccelNode {
            id: 0,
            label: AccelNodeLabel::Builtin {
                name: "tanh".to_string(),
            },
            category: AccelOpCategory::Elementwise,
            inputs: vec![0],
            outputs: vec![1],
            span: InstrSpan { start: 10, end: 10 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let single_node = AccelNode {
            id: 1,
            label: AccelNodeLabel::Builtin {
                name: "single".to_string(),
            },
            category: AccelOpCategory::Elementwise,
            inputs: vec![2],
            outputs: vec![3],
            span: InstrSpan { start: 11, end: 11 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let mul_node = AccelNode {
            id: 2,
            label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
            category: AccelOpCategory::Elementwise,
            inputs: vec![3, 0],
            outputs: vec![4],
            span: InstrSpan { start: 12, end: 12 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let add_node = AccelNode {
            id: 3,
            label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
            category: AccelOpCategory::Elementwise,
            inputs: vec![1, 4],
            outputs: vec![5],
            span: InstrSpan { start: 13, end: 13 },
            tags: vec![AccelGraphTag::Elementwise],
        };

        let graph = AccelGraph {
            nodes: vec![tanh_node, single_node, mul_node, add_node],
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        };

        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);

        let plan = FusionPlan::from_graph(&graph, &groups);
        let group_plan = &plan.groups[0];
        assert!(group_plan.kernel.supported);
        let shader = group_plan.generate_wgsl("f32");
        assert!(shader
            .as_ref()
            .map(|wgsl| wgsl.contains("tanh") && wgsl.contains("output.data"))
            .unwrap_or(false));
    }
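
    // The two tests below are minimal sketches added for illustration: they
    // exercise module-private helpers through `super::` and assume only the
    // behavior visible in this file.
    #[test]
    fn dims_match_unordered_ignores_order() {
        // Same dimension set in a different order matches; a length mismatch
        // never does.
        assert!(super::dims_match_unordered(&[3.0, 2.0], &[2.0, 3.0]));
        assert!(!super::dims_match_unordered(&[3.0], &[2.0, 3.0]));
    }

    #[test]
    fn resolves_scalar_num_constant_as_vector() {
        // A constant `Num` value should resolve to a one-element vector.
        let values = vec![ValueInfo {
            id: 0,
            origin: ValueOrigin::Constant,
            ty: Type::Num,
            shape: ShapeInfo::Scalar,
            constant: Some(Value::Num(2.5)),
        }];
        let graph = AccelGraph {
            nodes: Vec::new(),
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        };
        assert_eq!(
            super::resolve_numeric_vector_constant(&graph, 0),
            Some(vec![2.5])
        );
    }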
}