Skip to main content

runmat_accelerate/
fusion.rs

1#[cfg(not(target_arch = "wasm32"))]
2use std::cell::RefCell;
3use std::collections::{HashMap, HashSet};
4#[cfg(target_arch = "wasm32")]
5use std::sync::Mutex;
6use std::sync::{Arc, OnceLock, RwLock, Weak};
7
8use once_cell::sync::Lazy;
9use runmat_accelerate_api::ReductionFlavor;
10use runmat_builtins::Value;
11use serde::{Deserialize, Serialize};
12
13use crate::graph::{
14    AccelGraph, AccelNode, AccelNodeLabel, AccelOpCategory, InstrSpan, NodeId, PrimitiveOp,
15    ShapeInfo, ValueId, ValueInfo, ValueOrigin, VarBinding,
16};
17use crate::reduction_meta::{detect_reduction_signature, ReductionAxes, ReductionBehavior};
18use runmat_accelerate_api::CovNormalization;
19
/// Category of a detected fusion group.
///
/// The generic kinds (`ElementwiseChain`, `Reduction`, `MatmulEpilogue`) are
/// produced by the passes in `detect_fusion_groups`. The remaining kinds name
/// dedicated patterns, presumably emitted by the matching `detect_*` pattern
/// detectors (TODO confirm against those detectors).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum FusionKind {
    /// Two or more elementwise (or elementwise `max`/`min`) nodes fused together.
    ElementwiseChain,
    /// A single reduction node.
    Reduction,
    /// A matmul followed by foldable elementwise ops (add/sub/mul/elementwise divide).
    MatmulEpilogue,
    /// Pattern group; see `FusionPattern::CenteredGram`.
    CenteredGram,
    /// Pattern group; see `FusionPattern::ImageNormalize`.
    ImageNormalize,
    /// Pattern group; see `FusionPattern::PowerStepNormalize`.
    PowerStepNormalize,
    /// Pattern group; see `FusionPattern::ExplainedVariance`.
    ExplainedVariance,
}
30
/// A set of accel-graph nodes that should execute as one fused unit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionGroup {
    /// Sequential id assigned at detection time (may become non-contiguous
    /// after groups are merged).
    pub id: usize,
    /// Which fusion strategy this group uses.
    pub kind: FusionKind,
    /// Member node ids.
    pub nodes: Vec<NodeId>,
    /// Output shape recorded for the group when it was detected.
    pub shape: ShapeInfo,
    /// Instruction span covering all member nodes.
    pub span: InstrSpan,
    /// Operand bindings for the dedicated pattern kinds; `None` for generic groups.
    pub pattern: Option<FusionPattern>,
    /// Stack operand layout; defaults to `None` when absent from serialized data.
    #[serde(default)]
    pub stack_layout: Option<FusionStackLayout>,
}
42
/// Describes which group inputs are supplied from the interpreter stack.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FusionStackLayout {
    /// Minimum number of stack operands the group needs.
    pub required_stack_operands: usize,
    /// Value-to-stack-slot assignments (one binding per distinct value).
    pub bindings: Vec<FusionStackValueBinding>,
}
48
/// Binds one value to its position in the group's stack operand pattern.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FusionStackValueBinding {
    /// The bound value.
    pub value_id: ValueId,
    /// Position of the value within the stack pattern.
    pub stack_offset: usize,
}
54
/// Operand bindings captured for the dedicated (non-generic) fusion kinds.
///
/// Operand semantics below are inferred from field names — TODO confirm
/// against the corresponding detectors and kernels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionPattern {
    /// Centered Gram product of `matrix`, scaled per `normalization`
    /// (presumably covariance-style).
    CenteredGram {
        matrix: ValueId,
        normalization: CovNormalization,
    },
    /// Image normalization; see [`ImageNormalizePattern`].
    ImageNormalize(ImageNormalizePattern),
    /// Power-step-then-normalize over `lhs`/`rhs`, with `epsilon` presumably
    /// guarding against division by zero.
    PowerStepNormalize {
        lhs: ValueId,
        rhs: ValueId,
        epsilon: f64,
    },
    /// Explained-variance computation over operands `q` and `g`.
    ExplainedVariance {
        q: ValueId,
        g: ValueId,
    },
}
72
/// Operands of an image-normalization fusion pattern.
///
/// Each scalar may be a compile-time constant or a runtime value; the optional
/// terms are `None` when the pattern did not include them (TODO confirm the
/// exact formula against the kernel).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageNormalizePattern {
    /// The image tensor being normalized.
    pub input: ValueId,
    /// Stabilizing epsilon (always present).
    pub epsilon: ImageScalar,
    /// Optional gain term.
    pub gain: Option<ImageScalar>,
    /// Optional bias term.
    pub bias: Option<ImageScalar>,
    /// Optional gamma term.
    pub gamma: Option<ImageScalar>,
}
81
/// A scalar operand that is either a literal constant or a graph value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImageScalar {
    /// Constant known at plan time.
    Constant(f64),
    /// Value resolved at runtime.
    Value(ValueId),
}
87
88pub fn detect_fusion_groups(graph: &AccelGraph) -> Vec<FusionGroup> {
89    if graph.nodes.is_empty() {
90        return Vec::new();
91    }
92
93    let consumer_map = build_consumer_map(graph);
94    let mut assigned: HashSet<NodeId> = HashSet::new();
95    let mut groups = Vec::new();
96    let mut group_id = 0usize;
97
98    detect_image_normalize(graph, &mut assigned, &mut groups, &mut group_id);
99    detect_explained_variance(graph, &mut assigned, &mut groups, &mut group_id);
100    detect_power_step_normalize(graph, &mut assigned, &mut groups, &mut group_id);
101    detect_centered_gram(graph, &mut assigned, &mut groups, &mut group_id);
102
103    for node in &graph.nodes {
104        // Elementwise chains
105        if assigned.contains(&node.id) {
106            continue;
107        }
108        let elementwise_like = node.is_elementwise() || is_elementwise_max_min(graph, node);
109        if !elementwise_like {
110            continue;
111        }
112        if node.outputs.is_empty() {
113            continue;
114        }
115        let mut current_shape = node_output_shape(graph, node);
116        if matches!(current_shape, ShapeInfo::Unknown | ShapeInfo::Scalar) {
117            continue;
118        }
119        let mut chain: Vec<NodeId> = Vec::new();
120        let mut frontier = node.id;
121        let mut local_seen: HashSet<NodeId> = HashSet::new();
122
123        loop {
124            if !local_seen.insert(frontier) {
125                break;
126            }
127            chain.push(frontier);
128            let next = find_next_elementwise(
129                graph,
130                frontier,
131                &assigned,
132                &local_seen,
133                &consumer_map,
134                &current_shape,
135            );
136            match next {
137                Some((next_id, next_shape)) => {
138                    frontier = next_id;
139                    current_shape = next_shape;
140                }
141                None => break,
142            }
143        }
144
145        if chain.len() > 1 {
146            expand_group_with_fanout(graph, &mut chain, &assigned, &consumer_map);
147            chain.sort_unstable_by_key(|id| {
148                graph
149                    .node(*id)
150                    .map(|node| node.span.start)
151                    .unwrap_or_default()
152            });
153            chain.dedup();
154            for id in &chain {
155                assigned.insert(*id);
156            }
157            let span = group_span(graph, &chain);
158            groups.push(FusionGroup {
159                id: group_id,
160                kind: FusionKind::ElementwiseChain,
161                nodes: chain,
162                shape: current_shape.clone(),
163                span,
164                pattern: None,
165                stack_layout: None,
166            });
167            group_id += 1;
168        }
169    }
170
171    // Reduction singletons (basic grouping; future: include eligible producers)
172    for node in &graph.nodes {
173        if assigned.contains(&node.id) {
174            continue;
175        }
176        if !node.is_reduction() || is_elementwise_max_min(graph, node) {
177            continue;
178        }
179        let span = InstrSpan {
180            start: node.span.start,
181            end: node.span.end,
182        };
183        groups.push(FusionGroup {
184            id: group_id,
185            kind: FusionKind::Reduction,
186            nodes: vec![node.id],
187            shape: node_output_shape(graph, node),
188            span,
189            pattern: None,
190            stack_layout: None,
191        });
192        group_id += 1;
193    }
194
195    // Matmul + simple epilogue (alpha/beta/row/col scale) chains
196    for node in &graph.nodes {
197        if node.category != AccelOpCategory::MatMul || assigned.contains(&node.id) {
198            continue;
199        }
200        if node.outputs.is_empty() {
201            continue;
202        }
203        // Require exactly one consumer chain and only elementwise ops we can fold
204        let mut chain: Vec<NodeId> = vec![node.id];
205        let mut frontier = node.id;
206        let mut ok = false;
207        loop {
208            // Find single consumer of the current frontier's output
209            let mut next_id_opt: Option<NodeId> = None;
210            for &out in &graph.node(frontier).unwrap().outputs {
211                if let Some(cons) = consumer_map.get(&out) {
212                    if cons.len() == 1 {
213                        next_id_opt = cons.iter().copied().next();
214                    } else {
215                        next_id_opt = None;
216                    }
217                }
218            }
219            let Some(next_id) = next_id_opt else { break };
220            let next = graph.node(next_id).unwrap();
221            if !next.is_elementwise() {
222                break;
223            }
224            // Allow only primitive elementwise ops we can fold: add/sub/mul and elementwise divide
225            let allowed = matches!(
226                next.label,
227                AccelNodeLabel::Primitive(PrimitiveOp::Add)
228                    | AccelNodeLabel::Primitive(PrimitiveOp::Sub)
229                    | AccelNodeLabel::Primitive(PrimitiveOp::Mul)
230                    | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul)
231                    | AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv)
232            );
233            if !allowed {
234                break;
235            }
236            chain.push(next_id);
237            frontier = next_id;
238            ok = true;
239        }
240        if ok {
241            for id in &chain {
242                assigned.insert(*id);
243            }
244            let span = group_span(graph, &chain);
245            groups.push(FusionGroup {
246                id: group_id,
247                kind: FusionKind::MatmulEpilogue,
248                nodes: chain,
249                shape: node_output_shape(graph, node),
250                span,
251                pattern: None,
252                stack_layout: None,
253            });
254            group_id += 1;
255        }
256    }
257
258    merge_downstream_fanout(graph, &mut groups, &consumer_map);
259    groups
260}
261
262fn expand_group_with_fanout(
263    graph: &AccelGraph,
264    chain: &mut Vec<NodeId>,
265    assigned: &HashSet<NodeId>,
266    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
267) {
268    let base_start = chain
269        .iter()
270        .filter_map(|id| graph.node(*id).map(|node| node.span.start))
271        .min()
272        .unwrap_or(0);
273    let mut node_set: HashSet<NodeId> = chain.iter().copied().collect();
274    let mut changed = true;
275    while changed {
276        changed = false;
277        for node in &graph.nodes {
278            if node_set.contains(&node.id) {
279                continue;
280            }
281            if node.span.start < base_start {
282                continue;
283            }
284            if assigned.contains(&node.id) {
285                continue;
286            }
287            if !(node.is_elementwise() || is_elementwise_max_min(graph, node)) {
288                continue;
289            }
290            if node.outputs.is_empty() {
291                continue;
292            }
293            let mut feeds_group = false;
294            let mut all_consumers_ok = true;
295            for &out in &node.outputs {
296                if let Some(consumers) = consumer_map.get(&out) {
297                    let mut consumer_in_group = false;
298                    for consumer in consumers {
299                        if node_set.contains(consumer) {
300                            consumer_in_group = true;
301                        } else {
302                            all_consumers_ok = false;
303                            break;
304                        }
305                    }
306                    if !all_consumers_ok {
307                        break;
308                    }
309                    if consumer_in_group {
310                        feeds_group = true;
311                    }
312                } else {
313                    all_consumers_ok = false;
314                    break;
315                }
316            }
317            if !feeds_group || !all_consumers_ok {
318                continue;
319            }
320            let mut inputs_ok = true;
321            for &input in &node.inputs {
322                if let Some(info) = graph.value(input) {
323                    if let ValueOrigin::NodeOutput { node: producer, .. } = info.origin {
324                        if !node_set.contains(&producer) {
325                            if let Some(prod_node) = graph.node(producer) {
326                                if prod_node.span.start >= base_start {
327                                    inputs_ok = false;
328                                    break;
329                                }
330                            } else {
331                                inputs_ok = false;
332                                break;
333                            }
334                        }
335                    }
336                }
337            }
338            if inputs_ok {
339                node_set.insert(node.id);
340                chain.push(node.id);
341                changed = true;
342            }
343        }
344    }
345}
346
347fn build_consumer_map(graph: &AccelGraph) -> HashMap<ValueId, HashSet<NodeId>> {
348    let mut map: HashMap<ValueId, HashSet<NodeId>> = HashMap::new();
349    for node in &graph.nodes {
350        for &input in &node.inputs {
351            if let Some(value) = graph.value(input) {
352                if matches!(value.origin, crate::graph::ValueOrigin::NodeOutput { .. }) {
353                    map.entry(input).or_default().insert(node.id);
354                }
355            }
356        }
357    }
358    map
359}
360
/// Repeatedly folds elementwise groups into the elementwise groups that read
/// their outputs, until no further merge applies.
///
/// A source group is merged into a target group when:
/// - both are elementwise (`FusionKind::is_elementwise`);
/// - some target node reads a value produced by a source node;
/// - the source group's span does not start before the target's span;
/// - every consumer of every source output lies in the source or target group
///   (see [`group_consumers_subset`]).
///
/// Each pass merges into at most one target (the first eligible one). It then
/// drops groups emptied by the merge and rebuilds the node-to-group index on the
/// next pass. The target's nodes are re-sorted by span start and de-duplicated,
/// and its span is recomputed. Its `id`, `shape`, `pattern` and `stack_layout`
/// are left as they were, so group ids may be non-contiguous afterwards.
fn merge_downstream_fanout(
    graph: &AccelGraph,
    groups: &mut Vec<FusionGroup>,
    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
) {
    let mut changed = true;
    while changed {
        changed = false;
        // Index: node id -> position of the elementwise group that owns it.
        let mut node_group: HashMap<NodeId, usize> = HashMap::new();
        for (idx, group) in groups.iter().enumerate() {
            if group.kind.is_elementwise() {
                for &node in &group.nodes {
                    node_group.insert(node, idx);
                }
            }
        }
        'outer: for target_idx in 0..groups.len() {
            if !groups[target_idx].kind.is_elementwise() {
                continue;
            }
            let base_start = groups[target_idx].span.start;
            // Positions of source groups this target may absorb.
            let mut merge_indices: Vec<usize> = Vec::new();
            for &node_id in &groups[target_idx].nodes {
                let Some(node) = graph.node(node_id) else {
                    continue;
                };
                for &input in &node.inputs {
                    if let Some(info) = graph.value(input) {
                        if let ValueOrigin::NodeOutput { node: producer, .. } = info.origin {
                            if let Some(&source_idx) = node_group.get(&producer) {
                                if source_idx == target_idx {
                                    continue;
                                }
                                let source_group = &groups[source_idx];
                                if !source_group.kind.is_elementwise() {
                                    continue;
                                }
                                // Only absorb sources that start at or after the target.
                                if source_group.span.start < base_start {
                                    continue;
                                }
                                // Reject if any source output is read outside source+target.
                                if !group_consumers_subset(
                                    source_group,
                                    target_idx,
                                    groups,
                                    consumer_map,
                                    graph,
                                ) {
                                    continue;
                                }
                                merge_indices.push(source_idx);
                            }
                        }
                    }
                }
            }
            if merge_indices.is_empty() {
                continue;
            }
            merge_indices.sort_unstable();
            merge_indices.dedup();
            // Move source nodes into the target; emptied sources are dropped below.
            for idx in &merge_indices {
                let nodes = groups[*idx].nodes.clone();
                groups[target_idx].nodes.extend(nodes);
                groups[*idx].nodes.clear();
            }
            groups[target_idx]
                .nodes
                .sort_unstable_by_key(|id| graph.node(*id).map(|n| n.span.start).unwrap_or(0));
            groups[target_idx].nodes.dedup();
            groups[target_idx].span = group_span(graph, &groups[target_idx].nodes);
            changed = true;
            // node_group is now stale; restart with a rebuilt index.
            break 'outer;
        }
        if changed {
            groups.retain(|group| !group.nodes.is_empty());
        }
    }
}
439
440fn group_consumers_subset(
441    source_group: &FusionGroup,
442    target_idx: usize,
443    groups: &[FusionGroup],
444    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
445    graph: &AccelGraph,
446) -> bool {
447    let target_nodes: HashSet<NodeId> = groups[target_idx].nodes.iter().copied().collect();
448    let source_nodes: HashSet<NodeId> = source_group.nodes.iter().copied().collect();
449    for &node_id in &source_group.nodes {
450        let Some(node) = graph.node(node_id) else {
451            continue;
452        };
453        for &out in &node.outputs {
454            if let Some(consumers) = consumer_map.get(&out) {
455                for consumer in consumers {
456                    if !source_nodes.contains(consumer) && !target_nodes.contains(consumer) {
457                        return false;
458                    }
459                }
460            }
461        }
462    }
463    true
464}
465
466fn node_output_shape(graph: &AccelGraph, node: &AccelNode) -> ShapeInfo {
467    let mut shape = ShapeInfo::Scalar;
468    for &output in &node.outputs {
469        if let Some(info) = graph.value(output) {
470            shape = shape.unify(&info.shape);
471        }
472    }
473    shape
474}
475
476fn find_next_elementwise(
477    graph: &AccelGraph,
478    node_id: NodeId,
479    assigned: &HashSet<NodeId>,
480    local_seen: &HashSet<NodeId>,
481    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
482    current_shape: &ShapeInfo,
483) -> Option<(NodeId, ShapeInfo)> {
484    let node = graph.node(node_id)?;
485    let mut candidate: Option<(NodeId, ShapeInfo)> = None;
486
487    for &output in &node.outputs {
488        let consumers = consumer_map.get(&output)?;
489        if consumers.len() != 1 {
490            return None;
491        }
492        let next_id = *consumers.iter().next()?;
493        if next_id <= node_id || assigned.contains(&next_id) || local_seen.contains(&next_id) {
494            return None;
495        }
496        let next_node = graph.node(next_id)?;
497        if !(next_node.is_elementwise() || is_elementwise_max_min(graph, next_node)) {
498            return None;
499        }
500        // Ensure the edge we follow is actually used by next node
501        if !next_node.inputs.contains(&output) {
502            continue;
503        }
504        let next_shape = node_output_shape(graph, next_node);
505        if matches!(next_shape, ShapeInfo::Unknown) {
506            return None;
507        }
508        let unified = current_shape.unify(&next_shape);
509        if matches!(unified, ShapeInfo::Unknown) {
510            return None;
511        }
512        candidate = Some((next_id, unified));
513        break;
514    }
515
516    candidate
517}
518
519fn is_elementwise_max_min(graph: &AccelGraph, node: &AccelNode) -> bool {
520    match &node.label {
521        AccelNodeLabel::Builtin { name }
522            if name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min") =>
523        {
524            if node.inputs.len() < 2 {
525                return false;
526            }
527            !value_is_placeholder(graph, node.inputs[1])
528        }
529        _ => false,
530    }
531}
532
533fn value_is_placeholder(graph: &AccelGraph, vid: ValueId) -> bool {
534    let Some(info) = graph.value(vid) else {
535        return false;
536    };
537    let Some(constant) = &info.constant else {
538        return false;
539    };
540    match constant {
541        Value::Tensor(t) => t.data.is_empty(),
542        Value::LogicalArray(l) => l.data.is_empty(),
543        Value::StringArray(sa) => sa.data.is_empty(),
544        Value::CharArray(ca) => ca.data.is_empty(),
545        Value::Cell(cell) => cell.data.is_empty(),
546        Value::String(s) => s.is_empty(),
547        _ => false,
548    }
549}
550
551fn group_span(graph: &AccelGraph, nodes: &[NodeId]) -> InstrSpan {
552    let mut start = usize::MAX;
553    let mut end = 0usize;
554    for &id in nodes {
555        if let Some(node) = graph.node(id) {
556            start = start.min(node.span.start);
557            end = end.max(node.span.end);
558        }
559    }
560    if start == usize::MAX {
561        start = 0;
562    }
563    InstrSpan { start, end }
564}
565
566fn merge_stack_layout_with_stack_pattern(
567    existing: Option<&FusionStackLayout>,
568    inputs: &[ValueId],
569    stack_pattern: &[usize],
570) -> Option<FusionStackLayout> {
571    if existing.is_none() && stack_pattern.is_empty() {
572        return None;
573    }
574
575    let mut bindings = existing
576        .map(|layout| layout.bindings.clone())
577        .unwrap_or_default();
578    for (stack_offset, &input_idx) in stack_pattern.iter().enumerate() {
579        let &value_id = inputs.get(input_idx)?;
580        if bindings.iter().any(|binding| binding.value_id == value_id) {
581            continue;
582        }
583        bindings.push(FusionStackValueBinding {
584            value_id,
585            stack_offset,
586        });
587    }
588
589    let required_stack_operands = existing
590        .map(|layout| layout.required_stack_operands)
591        .unwrap_or(0)
592        .max(stack_pattern.len());
593
594    Some(FusionStackLayout {
595        required_stack_operands,
596        bindings,
597    })
598}
599
/// Executable fusion plan: one [`FusionGroupPlan`] per accepted fusion group.
#[derive(Debug, Clone)]
pub struct FusionPlan {
    /// Group plans, looked up by position (e.g. via `groups.get(idx)`).
    pub groups: Vec<FusionGroupPlan>,
}
604
/// Lowered, ready-to-execute description of a single fusion group.
#[derive(Debug, Clone)]
pub struct FusionGroupPlan {
    /// Group number assigned when the plan was built.
    pub index: usize,
    /// The (sanitized) fusion group this plan was built from.
    pub group: FusionGroup,
    /// Operations of the group, expressed as primitive or builtin calls.
    pub operations: Vec<FusionOp>,
    /// External values the fused kernel reads.
    pub inputs: Vec<ValueId>,
    /// Indices into `inputs` for operands supplied from the stack, in stack order.
    pub stack_pattern: Vec<usize>,
    /// Constant operands keyed by position (TODO confirm what the key indexes).
    pub constants: HashMap<usize, Value>,
    /// Constant operands keyed by value id.
    pub const_values: HashMap<ValueId, Value>,
    /// Intermediate values that must also be stored to variables.
    pub materialized_stores: Vec<FusionStoreMaterialization>,
    /// The group's final result value, if identified.
    pub output: Option<ValueId>,
    /// Kernel selection for the group.
    pub kernel: FusionKernelSpec,
    /// For reductions: the ValueId of the data tensor being reduced, if identifiable.
    pub reduction_data: Option<ValueId>,
    /// For reductions: the ValueId of the dim argument when identifiable.
    pub reduction_dim: Option<ValueId>,
    /// For reductions: flavor metadata (e.g., sum vs mean scaling).
    pub reduction_flavor: Option<ReductionFlavor>,
    /// For reductions: axis selection metadata (e.g., explicit dims vs 'all').
    pub reduction_axes: Option<ReductionAxes>,
    /// Pattern operands for the dedicated fusion kinds.
    pub pattern: Option<FusionPattern>,
}
627
/// A value produced inside a fused group that must also be written to a
/// variable binding.
#[derive(Debug, Clone)]
pub struct FusionStoreMaterialization {
    /// The value to materialize.
    pub value_id: ValueId,
    /// The variable it is stored to.
    pub binding: VarBinding,
}
633
/// One operation inside a fused group.
#[derive(Debug, Clone)]
pub enum FusionOp {
    /// A primitive operator applied to `inputs`, optionally producing `output`.
    Primitive {
        op: PrimitiveOp,
        inputs: Vec<ValueId>,
        output: Option<ValueId>,
    },
    /// A named builtin call applied to `inputs`, optionally producing `output`.
    Builtin {
        name: String,
        inputs: Vec<ValueId>,
        output: Option<ValueId>,
    },
}
647
/// Kernel selection metadata for a fusion group.
#[derive(Debug, Clone)]
pub struct FusionKernelSpec {
    /// Fusion kind the kernel implements.
    pub kind: FusionKind,
    /// Whether the group is marked as supported (surfaced via `ActiveFusion`).
    pub supported: bool,
}
653
654impl FusionKernelSpec {
655    fn new(kind: FusionKind, supported: bool) -> Self {
656        Self { kind, supported }
657    }
658}
659
/// Snapshot of the fusion group covering the current program counter, as
/// returned by `active_fusion`.
#[derive(Clone, Debug)]
pub struct ActiveFusion {
    /// Kind of the active group.
    pub kind: FusionKind,
    /// Instruction span of the active group.
    pub span: InstrSpan,
    /// Element count reported by the group plan, when known.
    pub element_count: Option<usize>,
    /// Copy of the group's kernel `supported` flag.
    pub supported: bool,
}
667
/// The installed plan plus the group selected by the most recent `set_current_pc`.
struct ActiveContext {
    /// Plan installed by `activate_fusion_plan`.
    plan: Arc<FusionPlan>,
    /// Position in `plan.groups` of the group covering the current pc, if any.
    active_group: Option<usize>,
}
672
/// Process-wide plan cache keyed by the `AccelGraph`'s address.
///
/// Entries are weak: a plan lives only while some caller holds its `Arc`.
static PLAN_CACHE: Lazy<RwLock<HashMap<usize, Weak<FusionPlan>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
675
// Slot holding the currently active fusion context. On native targets each
// thread has its own slot. On wasm32 a single mutex-guarded slot is shared
// process-wide.
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    static ACTIVE_PLAN: RefCell<Option<ActiveContext>> = const { RefCell::new(None) };
}
#[cfg(target_arch = "wasm32")]
static ACTIVE_PLAN: Lazy<Mutex<Option<ActiveContext>>> = Lazy::new(|| Mutex::new(None));
682
683#[cfg(not(target_arch = "wasm32"))]
684fn with_active_context<R>(f: impl FnOnce(&mut Option<ActiveContext>) -> R) -> R {
685    ACTIVE_PLAN.with(|ctx| {
686        let mut slot = ctx.borrow_mut();
687        f(&mut slot)
688    })
689}
690
/// Runs `f` with exclusive access to the shared active fusion context.
///
/// Panics if the mutex was poisoned by an earlier panic while it was held.
#[cfg(target_arch = "wasm32")]
fn with_active_context<R>(f: impl FnOnce(&mut Option<ActiveContext>) -> R) -> R {
    f(&mut ACTIVE_PLAN.lock().expect("active plan mutex poisoned"))
}
696
/// Returns whether fusion debug logging is enabled via `RUNMAT_DEBUG_FUSION`.
///
/// The flag is on when the variable equals `1`, or `true`/`yes` in any ASCII
/// case. It is read once per process and cached.
fn fusion_debug_enabled() -> bool {
    static DEBUG_FUSION: OnceLock<bool> = OnceLock::new();
    *DEBUG_FUSION.get_or_init(|| {
        std::env::var("RUNMAT_DEBUG_FUSION")
            .map(|raw| {
                raw == "1"
                    || ["true", "yes"]
                        .iter()
                        .any(|word| raw.eq_ignore_ascii_case(word))
            })
            .unwrap_or(false)
    })
}
704
705pub fn prepare_fusion_plan(
706    graph: Option<&AccelGraph>,
707    groups: &[FusionGroup],
708    candidate_group_count: usize,
709) -> Option<Arc<FusionPlan>> {
710    let graph = graph?;
711    if candidate_group_count == 0 {
712        if !groups.is_empty() && fusion_debug_enabled() {
713            log::debug!(
714                "fusion plan preparation: executable bytecode fusion groups present ({}) but semantic candidate groups are absent",
715                groups.len()
716            );
717        }
718        return None;
719    }
720    if groups.is_empty() {
721        if candidate_group_count > 0 && fusion_debug_enabled() {
722            log::debug!(
723                "fusion plan preparation: semantic candidate groups present ({}) but executable bytecode fusion groups are empty",
724                candidate_group_count
725            );
726        }
727        return None;
728    }
729    let groups = sanitize_runtime_groups(graph, groups);
730    if groups.is_empty() {
731        if fusion_debug_enabled() {
732            log::debug!(
733                "fusion plan preparation: semantic-gated bytecode groups could not be reconciled against runtime accel graph nodes"
734            );
735        }
736        return None;
737    }
738    let key = graph as *const AccelGraph as usize;
739    if let Some(plan) = PLAN_CACHE
740        .read()
741        .ok()
742        .and_then(|guard| guard.get(&key).and_then(|weak| weak.upgrade()))
743    {
744        return Some(plan);
745    }
746
747    let plan = FusionPlan::from_graph(graph, &groups);
748    let plan = Arc::new(plan);
749    if let Ok(mut guard) = PLAN_CACHE.write() {
750        guard.insert(key, Arc::downgrade(&plan));
751    }
752    Some(plan)
753}
754
755fn sanitize_runtime_groups(graph: &AccelGraph, groups: &[FusionGroup]) -> Vec<FusionGroup> {
756    groups
757        .iter()
758        .filter_map(|group| {
759            let had_explicit_mapped_nodes = !group.nodes.is_empty();
760            let mut sanitized = group.clone();
761            sanitized.nodes.retain(|id| {
762                graph
763                    .node(*id)
764                    .map(|node| {
765                        node_matches_runtime_group_kind(graph, node, &sanitized.kind)
766                            && node_within_group_span(node, &sanitized.span)
767                    })
768                    .unwrap_or(false)
769            });
770            if sanitized.nodes.is_empty() && !had_explicit_mapped_nodes {
771                sanitized.nodes = graph
772                    .nodes
773                    .iter()
774                    .filter(|node| {
775                        node_matches_runtime_group_kind(graph, node, &sanitized.kind)
776                            && node_within_group_span(node, &sanitized.span)
777                    })
778                    .map(|node| node.id)
779                    .collect();
780            }
781            sanitized.nodes.sort_unstable_by_key(|node_id| {
782                graph
783                    .node(*node_id)
784                    .map(|node| (node.span.start, node.span.end, node.id))
785                    .unwrap_or((usize::MAX, usize::MAX, *node_id))
786            });
787            sanitized.nodes.dedup();
788            if sanitized.nodes.is_empty() {
789                None
790            } else {
791                Some(sanitized)
792            }
793        })
794        .collect()
795}
796
797fn node_matches_runtime_group_kind(
798    graph: &AccelGraph,
799    node: &AccelNode,
800    kind: &FusionKind,
801) -> bool {
802    match kind {
803        FusionKind::ElementwiseChain => {
804            node.is_elementwise()
805                || node.category == AccelOpCategory::Transpose
806                || is_elementwise_max_min(graph, node)
807        }
808        FusionKind::Reduction => node.is_reduction(),
809        FusionKind::MatmulEpilogue => {
810            node.category == AccelOpCategory::MatMul
811                || node.is_elementwise()
812                || node.category == AccelOpCategory::Transpose
813        }
814        FusionKind::CenteredGram
815        | FusionKind::ImageNormalize
816        | FusionKind::PowerStepNormalize
817        | FusionKind::ExplainedVariance => true,
818    }
819}
820
821fn node_within_group_span(node: &AccelNode, span: &InstrSpan) -> bool {
822    node.span.start >= span.start && node.span.end <= span.end
823}
824
825pub fn activate_fusion_plan(plan: Option<Arc<FusionPlan>>) {
826    with_active_context(|slot| {
827        *slot = plan.map(|plan| ActiveContext {
828            plan,
829            active_group: None,
830        });
831    });
832}
833
834pub fn deactivate_fusion_plan() {
835    with_active_context(|slot| {
836        slot.take();
837    });
838}
839
840pub fn set_current_pc(pc: usize) {
841    with_active_context(|slot| {
842        if let Some(context) = slot.as_mut() {
843            context.active_group = context.plan.group_for_pc(pc);
844        }
845    });
846}
847
848pub fn active_fusion() -> Option<ActiveFusion> {
849    with_active_context(|slot| {
850        slot.as_ref()
851            .and_then(|context| {
852                context
853                    .active_group
854                    .and_then(|idx| context.plan.groups.get(idx))
855            })
856            .map(|plan| ActiveFusion {
857                kind: plan.group.kind.clone(),
858                span: plan.group.span.clone(),
859                element_count: plan.element_count(),
860                supported: plan.kernel.supported,
861            })
862    })
863}
864
865pub fn active_group_plan_clone() -> Option<FusionGroupPlan> {
866    with_active_context(|slot| {
867        slot.as_ref().and_then(|context| {
868            context
869                .active_group
870                .and_then(|idx| context.plan.groups.get(idx).cloned())
871        })
872    })
873}
874
875impl FusionPlan {
876    pub fn from_graph(graph: &AccelGraph, groups: &[FusionGroup]) -> Self {
877        let plans = groups
878            .iter()
879            .enumerate()
880            .map(|(idx, group)| FusionGroupPlan::new(idx, group.clone(), graph))
881            .collect();
882        Self { groups: plans }
883    }
884
885    fn group_for_pc(&self, pc: usize) -> Option<usize> {
886        self.groups
887            .iter()
888            .find(|plan| pc >= plan.group.span.start && pc <= plan.group.span.end)
889            .map(|plan| plan.index)
890    }
891}
892
893impl From<Vec<FusionGroupPlan>> for FusionPlan {
894    fn from(groups: Vec<FusionGroupPlan>) -> Self {
895        Self { groups }
896    }
897}
898
899fn log_plan_stack_pattern(stage: &str, plan: &FusionGroupPlan, graph: &AccelGraph) {
900    if !fusion_debug_enabled() || plan.stack_pattern.is_empty() {
901        return;
902    }
903    let mut pattern_meta: Vec<String> = Vec::with_capacity(plan.stack_pattern.len());
904    for (pos, input_idx) in plan.stack_pattern.iter().enumerate() {
905        let value_id = plan.inputs.get(*input_idx).copied();
906        if let Some(vid) = value_id {
907            if let Some(info) = graph.value(vid) {
908                let node_label = match info.origin {
909                    ValueOrigin::NodeOutput { node, .. } => graph
910                        .node(node)
911                        .map(|n| format!("{:?}", n.label))
912                        .unwrap_or_else(|| "<missing-node>".to_string()),
913                    _ => String::new(),
914                };
915                pattern_meta.push(format!(
916                    "#{}:input_idx={} vid={} origin={:?} label={}",
917                    pos, input_idx, vid, info.origin, node_label
918                ));
919            } else {
920                pattern_meta.push(format!(
921                    "#{}:input_idx={} vid={} origin=<missing>",
922                    pos, input_idx, vid
923                ));
924            }
925        } else {
926            pattern_meta.push(format!("#{}:input_idx={} vid=<missing>", pos, input_idx));
927        }
928    }
929    log::trace!(
930        "fusion plan {} {} stack_pattern={:?} meta={:?}",
931        plan.index,
932        stage,
933        plan.stack_pattern,
934        pattern_meta
935    );
936}
937
938impl FusionGroupPlan {
939    fn new(index: usize, group: FusionGroup, graph: &AccelGraph) -> Self {
940        let node_set: HashSet<NodeId> = group.nodes.iter().copied().collect();
941        let mut seen_inputs: HashMap<ValueId, usize> = HashMap::new();
942        let mut inputs: Vec<ValueId> = Vec::new();
943        let mut stack_pattern: Vec<usize> = Vec::new();
944        let mut constants: HashMap<usize, Value> = HashMap::new();
945        let const_values: HashMap<ValueId, Value> = HashMap::new();
946        let mut operations = Vec::new();
947        let mut reduction_flavor: Option<ReductionFlavor> = None;
948        let mut reduction_axes: Option<ReductionAxes> = None;
949        let mut reduction_data: Option<ValueId> = None;
950        let mut reduction_dim: Option<ValueId> = None;
951        let mut output: Option<ValueId> = None;
952
953        let is_reduction_group = group.kind.is_reduction();
954        for node_id in &group.nodes {
955            let Some(node) = graph.node(*node_id) else {
956                continue;
957            };
958            for input in &node.inputs {
959                let binding = graph.var_binding(*input);
960                let (external, is_variable, maybe_constant) = match graph.value(*input) {
961                    Some(info) => match &info.origin {
962                        ValueOrigin::NodeOutput { node: origin, .. }
963                            if node_set.contains(origin) =>
964                        {
965                            (false, false, None)
966                        }
967                        ValueOrigin::Variable { .. } => (true, true, None),
968                        ValueOrigin::NodeOutput { .. } if binding.is_some() => (true, true, None),
969                        ValueOrigin::Constant => (true, false, info.constant.clone()),
970                        _ => (true, false, None),
971                    },
972                    None => (true, false, None),
973                };
974                if external {
975                    // Special handling for reductions: do NOT include constants in inputs;
976                    // only the data tensor should be an input. Constants are recorded separately.
977                    if is_reduction_group {
978                        if let Some(constant) = maybe_constant.clone() {
979                            // Assign a synthetic key for constants; keys are not positional for reductions
980                            let key = constants.len() + 1000;
981                            constants.insert(key, constant);
982                            continue;
983                        }
984                        // Only include the reduction data operand as an input
985                        if let Some(data_id) = reduction_data {
986                            if *input != data_id {
987                                // Skip non-data external inputs for reduction groups
988                                continue;
989                            }
990                        }
991                    }
992
993                    let mut newly_added = false;
994                    let input_idx = if let Some(idx) = seen_inputs.get(input) {
995                        *idx
996                    } else {
997                        let idx = inputs.len();
998                        inputs.push(*input);
999                        seen_inputs.insert(*input, idx);
1000                        newly_added = true;
1001                        idx
1002                    };
1003
1004                    if fusion_debug_enabled() {
1005                        let origin = graph.value(*input).map(|v| v.origin.clone());
1006                        log::trace!(
1007                            "fusion plan #{:?} consider input vid={} origin={:?} binding={:?} newly_added={} is_variable={} stack_candidate={}",
1008                            index,
1009                            input,
1010                            origin,
1011                            binding,
1012                            newly_added,
1013                            is_variable,
1014                            !is_variable && newly_added
1015                        );
1016                    }
1017                    if let Some(constant) = maybe_constant.clone() {
1018                        constants.insert(input_idx, constant);
1019                    } else if !is_variable && newly_added {
1020                        let allow_stack = match graph.value(*input) {
1021                            Some(info) => match info.origin {
1022                                ValueOrigin::NodeOutput { node, .. } => graph
1023                                    .node(node)
1024                                    .map(|n| n.span.start <= group.span.start)
1025                                    .unwrap_or(false),
1026                                _ => true,
1027                            },
1028                            None => true,
1029                        };
1030                        if allow_stack {
1031                            stack_pattern.push(input_idx);
1032                        } else if fusion_debug_enabled() {
1033                            log::trace!(
1034                                "fusion plan {} skipping stack candidate vid={} origin_after_span",
1035                                index,
1036                                input
1037                            );
1038                        }
1039                    } else if !is_variable
1040                        && !newly_added
1041                        && matches!(
1042                            graph.value(*input).map(|v| &v.origin),
1043                            Some(ValueOrigin::Constant)
1044                        )
1045                    {
1046                    }
1047                }
1048            }
1049
1050            let op = match &node.label {
1051                AccelNodeLabel::Primitive(p) => FusionOp::Primitive {
1052                    op: *p,
1053                    inputs: node.inputs.clone(),
1054                    output: node.outputs.first().copied(),
1055                },
1056                AccelNodeLabel::Builtin { name } => FusionOp::Builtin {
1057                    name: name.clone(),
1058                    inputs: node.inputs.clone(),
1059                    output: node.outputs.first().copied(),
1060                },
1061                AccelNodeLabel::Unknown => FusionOp::Primitive {
1062                    op: PrimitiveOp::UPlus,
1063                    inputs: node.inputs.clone(),
1064                    output: node.outputs.first().copied(),
1065                },
1066            };
1067            operations.push(op);
1068
1069            if let Some(out) = node.outputs.first().copied() {
1070                output = Some(out);
1071            }
1072            // Generic reduction signature (no name checks)
1073            if node.is_reduction() {
1074                if let Some(sig) = detect_reduction_signature(graph, node) {
1075                    reduction_data = Some(sig.data_input);
1076                    reduction_dim = sig.dim_arg;
1077                    reduction_flavor = Some(match sig.behavior {
1078                        ReductionBehavior::MeanLike => ReductionFlavor::Mean,
1079                        _ => ReductionFlavor::Sum,
1080                    });
1081                    reduction_axes = Some(sig.axes.clone());
1082                }
1083            }
1084        }
1085
1086        let kind = group.kind.clone();
1087        let pattern = group.pattern.clone();
1088        let mut plan = Self {
1089            index,
1090            group,
1091            operations,
1092            stack_pattern,
1093            constants,
1094            const_values,
1095            materialized_stores: Vec::new(),
1096            inputs,
1097            output,
1098            kernel: FusionKernelSpec::new(kind, true),
1099            reduction_data,
1100            reduction_dim,
1101            reduction_flavor,
1102            reduction_axes,
1103            pattern,
1104        };
1105
1106        log_plan_stack_pattern("initial", &plan, graph);
1107
1108        // Record constant ValueIds for all groups for easier downstream analysis
1109        for node_id in &plan.group.nodes {
1110            if let Some(node) = graph.node(*node_id) {
1111                for &inp in &node.inputs {
1112                    if let Some(info) = graph.value(inp) {
1113                        if let Some(cv) = info.constant.clone() {
1114                            plan.const_values.insert(inp, cv);
1115                        }
1116                    }
1117                }
1118            }
1119        }
1120
1121        // For reduction groups, externalize only real tensor dependencies; keep constants separate
1122        if plan.group.kind.is_reduction() {
1123            if let Some(data_vid) = plan.reduction_data {
1124                let original_inputs = plan.inputs.clone();
1125                let original_stack_pattern = plan.stack_pattern.clone();
1126                // Record constant ValueIds for codegen
1127                // Build dependency map from op outputs to inputs
1128                let mut prod: HashMap<ValueId, Vec<ValueId>> = HashMap::new();
1129                for op in &plan.operations {
1130                    match op {
1131                        FusionOp::Primitive {
1132                            inputs,
1133                            output,
1134                            op: _,
1135                        } => {
1136                            if let Some(out) = output {
1137                                prod.insert(*out, inputs.clone());
1138                            }
1139                        }
1140                        FusionOp::Builtin {
1141                            name: _,
1142                            inputs,
1143                            output,
1144                        } => {
1145                            if let Some(out) = output {
1146                                prod.insert(*out, inputs.clone());
1147                            }
1148                        }
1149                    }
1150                }
1151                let mut deps: Vec<ValueId> = Vec::new();
1152                let mut visited: HashSet<ValueId> = HashSet::new();
1153                let mut stack: Vec<ValueId> = vec![data_vid];
1154                // Track extra ops we discover outside the original group that are safe to inline
1155                let mut extra_ops: Vec<FusionOp> = Vec::new();
1156                let mut added_nodes: HashSet<ValueId> = HashSet::new();
1157                while let Some(cur) = stack.pop() {
1158                    if !visited.insert(cur) {
1159                        continue;
1160                    }
1161                    if graph.var_binding(cur).is_some() {
1162                        if !deps.contains(&cur) {
1163                            deps.push(cur);
1164                        }
1165                        continue;
1166                    }
1167                    if let Some(info) = graph.value(cur) {
1168                        if matches!(info.origin, ValueOrigin::Variable { .. }) {
1169                            if !deps.contains(&cur) {
1170                                deps.push(cur);
1171                            }
1172                            continue;
1173                        }
1174                    }
1175                    // Do not short-circuit on the reduction_data itself; expand through its producers first.
1176                    if original_inputs.contains(&cur) && cur != data_vid {
1177                        if !deps.contains(&cur) {
1178                            deps.push(cur);
1179                        }
1180                        continue;
1181                    }
1182                    if let Some(parents) = prod.get(&cur) {
1183                        for p in parents {
1184                            stack.push(*p);
1185                        }
1186                        continue;
1187                    }
1188                    // If not produced by an op in this group, try to expand through safe producer nodes
1189                    if let Some((_, node)) = node_from_value(graph, cur) {
1190                        // Only consider simple arithmetic producers we know how to fold
1191                        match &node.label {
1192                            AccelNodeLabel::Primitive(PrimitiveOp::Mul)
1193                            | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul)
1194                            | AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv)
1195                            | AccelNodeLabel::Primitive(PrimitiveOp::ElemLeftDiv)
1196                            | AccelNodeLabel::Primitive(PrimitiveOp::Add)
1197                            | AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
1198                                // Record op for codegen and traverse inputs
1199                                if added_nodes.insert(cur) {
1200                                    extra_ops.push(FusionOp::Primitive {
1201                                        op: match node.label {
1202                                            AccelNodeLabel::Primitive(op) => op,
1203                                            _ => PrimitiveOp::UPlus,
1204                                        },
1205                                        inputs: node.inputs.clone(),
1206                                        output: node.outputs.first().copied(),
1207                                    });
1208                                }
1209                                for &p in &node.inputs {
1210                                    stack.push(p);
1211                                }
1212                                continue;
1213                            }
1214                            AccelNodeLabel::Primitive(PrimitiveOp::ElemPow) => {
1215                                // Only accept power with constant exponent (typically 2 for squares)
1216                                if node.inputs.len() == 2 {
1217                                    if let Some(exp) = value_constant_f64(graph, node.inputs[1]) {
1218                                        if exp.is_finite() {
1219                                            if added_nodes.insert(cur) {
1220                                                extra_ops.push(FusionOp::Primitive {
1221                                                    op: PrimitiveOp::ElemPow,
1222                                                    inputs: node.inputs.clone(),
1223                                                    output: node.outputs.first().copied(),
1224                                                });
1225                                            }
1226                                            stack.push(node.inputs[0]);
1227                                            // Treat exponent as constant dependency for codegen
1228                                            stack.push(node.inputs[1]);
1229                                            continue;
1230                                        }
1231                                    }
1232                                }
1233                                // Fallback: treat as leaf dependency
1234                            }
1235                            AccelNodeLabel::Builtin { name } => {
1236                                // Allow simple casts to flow through (single/double)
1237                                if (name.eq_ignore_ascii_case("single")
1238                                    || name.eq_ignore_ascii_case("double"))
1239                                    && node.inputs.len() == 1
1240                                {
1241                                    stack.push(node.inputs[0]);
1242                                    continue;
1243                                }
1244                                // Unknown builtin: treat as leaf
1245                            }
1246                            _ => {
1247                                // Unknown producer: treat as leaf
1248                            }
1249                        }
1250                    }
1251                }
1252                // Ensure direct parents of the reduction data are materialized as inputs
1253                if let Some(parents) = prod.get(&data_vid) {
1254                    for &p in parents {
1255                        if !deps.contains(&p) {
1256                            // Skip trivial constants embedded in const_values; those are handled separately
1257                            let is_const = plan.const_values.contains_key(&p)
1258                                || graph.value(p).and_then(|vi| vi.constant.as_ref()).is_some();
1259                            if !is_const {
1260                                deps.push(p);
1261                            }
1262                        }
1263                    }
1264                }
1265                // Prepend the newly discovered ops so they are available to codegen
1266                // Keep original operations as well (the reduction op itself)
1267                if !extra_ops.is_empty() {
1268                    // Ensure a stable order: extra ops first
1269                    let mut new_ops = Vec::with_capacity(extra_ops.len() + plan.operations.len());
1270                    new_ops.extend(extra_ops);
1271                    new_ops.append(&mut plan.operations);
1272                    plan.operations = new_ops;
1273                }
1274                plan.inputs = deps;
1275                // Ensure constants referenced by any newly added operations are recorded.
1276                for op in &plan.operations {
1277                    let inputs = match op {
1278                        FusionOp::Primitive { inputs, .. } => inputs,
1279                        FusionOp::Builtin { inputs, .. } => inputs,
1280                    };
1281                    for vid in inputs {
1282                        if plan.const_values.contains_key(vid) {
1283                            continue;
1284                        }
1285                        if let Some(info) = graph.value(*vid) {
1286                            if let Some(cv) = info.constant.clone() {
1287                                plan.const_values.insert(*vid, cv);
1288                            }
1289                        }
1290                    }
1291                }
1292
1293                // Rebuild stack pattern based on the dependencies that were previously sourced
1294                // from the execution stack.
1295                let mut new_stack_pattern: Vec<usize> = Vec::new();
1296                for (new_idx, vid) in plan.inputs.iter().enumerate() {
1297                    if let Some(old_idx) = original_inputs.iter().position(|v| v == vid) {
1298                        if original_stack_pattern.contains(&old_idx) {
1299                            new_stack_pattern.push(new_idx);
1300                        }
1301                    }
1302                }
1303
1304                // Rebuild constants map using the new input ordering.
1305                let mut new_constants: HashMap<usize, Value> = HashMap::new();
1306                for (idx, vid) in plan.inputs.iter().enumerate() {
1307                    if let Some(value) = plan.const_values.get(vid) {
1308                        new_constants.insert(idx, value.clone());
1309                    } else if let Some(info) = graph.value(*vid) {
1310                        if let Some(cv) = info.constant.clone() {
1311                            new_constants.insert(idx, cv);
1312                        }
1313                    }
1314                }
1315                plan.constants = new_constants;
1316
1317                if new_stack_pattern.is_empty() {
1318                    for (idx, vid) in plan.inputs.iter().enumerate() {
1319                        if plan.constants.contains_key(&idx) {
1320                            continue;
1321                        }
1322                        if let Some(info) = graph.value(*vid) {
1323                            if matches!(
1324                                info.origin,
1325                                ValueOrigin::Variable { .. } | ValueOrigin::Constant
1326                            ) {
1327                                continue;
1328                            }
1329                        }
1330                        new_stack_pattern.push(idx);
1331                    }
1332                }
1333                plan.stack_pattern = new_stack_pattern;
1334            }
1335        }
1336
1337        // Final sanitize: for reduction groups, ensure inputs contain no constants
1338        if plan.group.kind.is_reduction() {
1339            let original_inputs = plan.inputs.clone();
1340            plan.inputs.retain(|vid| {
1341                if let Some(info) = graph.value(*vid) {
1342                    !matches!(info.origin, ValueOrigin::Constant)
1343                        && !plan.const_values.contains_key(vid)
1344                } else {
1345                    true
1346                }
1347            });
1348            if plan.inputs.len() != original_inputs.len() {
1349                let mut new_stack: Vec<usize> = Vec::new();
1350                for old_idx in &plan.stack_pattern {
1351                    if *old_idx < original_inputs.len() {
1352                        let vid = original_inputs[*old_idx];
1353                        if let Some(new_idx) = plan.inputs.iter().position(|v| *v == vid) {
1354                            new_stack.push(new_idx);
1355                        }
1356                    }
1357                }
1358                plan.stack_pattern = new_stack;
1359            }
1360        }
1361
1362        // Determine kernel support:
1363        // - Elementwise: require WGSL generation at plan time.
1364        // - Reduction: require WGSL generation at plan time as well.
1365        // - Other kinds: executed via provider paths.
1366        let supported = if plan.kernel.kind.is_elementwise() {
1367            // Keep scalar ops on the VM/runtime scalar path. Fusing scalar elementwise
1368            // spans can materialize scalar GPU handles that later leak into scalar-only
1369            // VM coercion boundaries.
1370            if scalar_shape_known_one(&plan.group.shape) {
1371                false
1372            } else {
1373                plan.generate_wgsl("f32").is_some()
1374            }
1375        } else if plan.kernel.kind.is_reduction() {
1376            plan.generate_reduction_wgsl("f32").is_some()
1377        } else {
1378            true
1379        };
1380        plan.kernel.supported = plan.kernel.supported && supported;
1381        if !plan.kernel.supported && fusion_debug_enabled() {
1382            let const_ids: Vec<ValueId> = plan.const_values.keys().copied().collect();
1383            log::debug!(
1384                "fusion plan {} unsupported: kind={:?} group_kind={:?} inputs={:?} reduction_data={:?} reduction_dim={:?} const_ids={:?}",
1385                plan.index,
1386                plan.kernel.kind,
1387                plan.group.kind,
1388                plan.inputs,
1389                plan.reduction_data,
1390                plan.reduction_dim,
1391                const_ids
1392            );
1393            if plan.kernel.kind.is_reduction() {
1394                let mut seen: HashSet<ValueId> = HashSet::new();
1395                let mut value_info: Vec<String> = Vec::new();
1396                for op in &plan.operations {
1397                    let inputs = match op {
1398                        FusionOp::Primitive { inputs, .. } => inputs,
1399                        FusionOp::Builtin { inputs, .. } => inputs,
1400                    };
1401                    for vid in inputs {
1402                        if seen.insert(*vid) {
1403                            if let Some(info) = graph.value(*vid) {
1404                                value_info.push(format!(
1405                                    "vid={} origin={:?} constant={}",
1406                                    vid,
1407                                    info.origin,
1408                                    info.constant.is_some()
1409                                ));
1410                            } else {
1411                                value_info.push(format!("vid={} origin=<missing>", vid));
1412                            }
1413                        }
1414                    }
1415                }
1416                log::debug!(
1417                    "fusion reduction plan {} value summary: [{}]",
1418                    plan.index,
1419                    value_info.join(", ")
1420                );
1421            }
1422        }
1423
1424        if matches!(plan.group.kind, FusionKind::CenteredGram) && plan.stack_pattern.is_empty() {
1425            let mut centered_stack_idxs: Vec<usize> = Vec::new();
1426            for (idx, vid) in plan.inputs.iter().enumerate() {
1427                if plan.constants.contains_key(&idx) {
1428                    continue;
1429                }
1430                if let Some(info) = graph.value(*vid) {
1431                    if matches!(info.origin, ValueOrigin::NodeOutput { .. }) {
1432                        centered_stack_idxs.push(idx);
1433                        continue;
1434                    }
1435                    if matches!(info.origin, ValueOrigin::Variable { .. }) {
1436                        continue;
1437                    }
1438                }
1439                centered_stack_idxs.push(idx);
1440            }
1441            if centered_stack_idxs.is_empty() && !plan.inputs.is_empty() {
1442                centered_stack_idxs.push(0);
1443            }
1444            plan.stack_pattern = centered_stack_idxs;
1445        }
1446
1447        if !plan.stack_pattern.is_empty() || plan.group.stack_layout.is_some() {
1448            plan.group.stack_layout = merge_stack_layout_with_stack_pattern(
1449                plan.group.stack_layout.as_ref(),
1450                &plan.inputs,
1451                &plan.stack_pattern,
1452            );
1453        }
1454
1455        if plan.group.kind.is_elementwise() {
1456            let mut stores = Vec::new();
1457            for op in &plan.operations {
1458                let output = match op {
1459                    FusionOp::Primitive { output, .. } => *output,
1460                    FusionOp::Builtin { output, .. } => *output,
1461                };
1462                let Some(value_id) = output else {
1463                    continue;
1464                };
1465                let Some(binding) = graph.var_binding(value_id).cloned() else {
1466                    continue;
1467                };
1468                stores.push(FusionStoreMaterialization { value_id, binding });
1469            }
1470            plan.materialized_stores = stores;
1471        }
1472
1473        log_plan_stack_pattern("final", &plan, graph);
1474
1475        // If the plan requires any unsupported operations, mark kernel as unsupported
1476
1477        plan
1478    }
1479
1480    pub fn reduction_data_shape(&self, graph: &AccelGraph) -> Option<Vec<usize>> {
1481        let vid = self.reduction_data?;
1482        let info = graph.value(vid)?;
1483        match &info.shape {
1484            ShapeInfo::Tensor(dims) if !dims.is_empty() && dims.iter().all(|d| d.is_some()) => {
1485                Some(dims.iter().map(|d| d.unwrap()).collect())
1486            }
1487            _ => None,
1488        }
1489    }
1490
1491    pub fn element_count(&self) -> Option<usize> {
1492        self.group.element_count()
1493    }
1494
1495    pub fn constant_shape(&self, len: usize) -> Vec<usize> {
1496        match &self.group.shape {
1497            ShapeInfo::Tensor(dims) if !dims.is_empty() && dims.iter().all(|dim| dim.is_some()) => {
1498                dims.iter().map(|dim| dim.unwrap()).collect()
1499            }
1500            _ => vec![len],
1501        }
1502    }
1503
1504    pub fn generate_wgsl(&self, scalar_ty: &str) -> Option<String> {
1505        self.generate_wgsl_for_output(self.output?, scalar_ty)
1506    }
1507
1508    /// Build the complete WGSL elementwise shader.
1509    ///
1510    /// The caller is responsible for the output-specific parts:
1511    /// - `output_bindings`: one `@group(0) @binding(…) var<storage, read_write> …` line per output.
1512    /// - `params_binding_idx`: the binding index for the uniform `Params` block
1513    ///   (`inputs.len() + num_outputs`).
1514    /// - `body`: the sequence of `let tmpN` assignment statements.
1515    /// - `final_writes`: the `output….data[g] = …;` store statements.
1516    fn build_wgsl_shader(
1517        &self,
1518        scalar_ty: &str,
1519        output_bindings: &str,
1520        params_binding_idx: usize,
1521        body: &str,
1522        final_writes: &str,
1523    ) -> String {
1524        let mut shader = String::new();
1525
1526        // ── type definitions ──────────────────────────────────────────────────
1527        shader.push_str("const MAX_RANK: u32 = 128u;\n");
1528        shader.push_str("struct PackedValue { value: u32, _pad0: u32, _pad1: u32, _pad2: u32 };\n");
1529        shader.push_str("alias PackedArray = array<PackedValue, MAX_RANK>;\n\n");
1530        shader.push_str(&format!("struct Tensor {{ data: array<{scalar_ty}>, }};\n"));
1531
1532        // Broadcast-aware Params: len, offset, rank, pad, out_shape and per-input shape/stride
1533        shader.push_str(
1534            "struct Params {\n    len: u32,\n    offset: u32,\n    rank: u32,\n    _pad: u32,\n    out_shape: PackedArray,\n",
1535        );
1536        for idx in 0..self.inputs.len() {
1537            shader.push_str(&format!("    in{}_shape: PackedArray,\n", idx));
1538            shader.push_str(&format!("    in{}_stride: PackedArray,\n", idx));
1539        }
1540        shader.push_str("}\n\n");
1541
1542        // ── portable helper stubs ─────────────────────────────────────────────
1543        // Avoid relying on backend builtins that may be missing.
1544        // hypot is not a WGSL builtin; define it explicitly.
1545        // Use the scaling form max*sqrt(1+(min/max)²) to avoid overflow when
1546        // a² or b² exceeds the representable range.
1547        // Guard against Inf inputs: Inf/Inf = NaN, so return hi early when
1548        // it is already infinite (IEEE 754 requires hypot(Inf,*) = Inf).
1549        if scalar_ty == "f32" {
1550            shader.push_str("fn isNan(x: f32) -> bool { return x != x; }\n");
1551            shader.push_str("fn isFinite(x: f32) -> bool { return (x == x) && (abs(x) < 3.4028234663852886e38); }\n");
1552            shader.push_str("fn isInf(x: f32) -> bool { return (x == x) && !(abs(x) < 3.4028234663852886e38); }\n");
1553            shader.push_str(concat!(
1554                "fn hypot(a: f32, b: f32) -> f32 {\n",
1555                "    let lo = min(abs(a), abs(b));\n",
1556                "    let hi = max(abs(a), abs(b));\n",
1557                "    if hi == 0.0 { return 0.0; }\n",
1558                "    if isInf(hi) { return hi; }\n",
1559                "    let r = lo / hi;\n",
1560                "    return hi * sqrt(1.0 + r * r);\n",
1561                "}\n\n",
1562            ));
1563        } else {
1564            shader.push_str("fn isNan(x: f64) -> bool { return x != x; }\n");
1565            shader.push_str("fn isFinite(x: f64) -> bool { return (x == x) && (abs(x) < f64(1.7976931348623157e308)); }\n");
1566            shader.push_str("fn isInf(x: f64) -> bool { return (x == x) && !(abs(x) < f64(1.7976931348623157e308)); }\n");
1567            shader.push_str(concat!(
1568                "fn hypot(a: f64, b: f64) -> f64 {\n",
1569                "    let lo = min(abs(a), abs(b));\n",
1570                "    let hi = max(abs(a), abs(b));\n",
1571                "    if hi == f64(0.0) { return f64(0.0); }\n",
1572                "    if isInf(hi) { return hi; }\n",
1573                "    let r = lo / hi;\n",
1574                "    return hi * sqrt(f64(1.0) + r * r);\n",
1575                "}\n\n",
1576            ));
1577        }
1578
1579        // ── resource bindings ─────────────────────────────────────────────────
1580        for (idx, _) in self.inputs.iter().enumerate() {
1581            shader.push_str(&format!(
1582                "@group(0) @binding({idx}) var<storage, read> input{idx}: Tensor;\n",
1583            ));
1584        }
1585        shader.push_str(output_bindings);
1586        shader.push_str(&format!(
1587            "@group(0) @binding({params_binding_idx}) var<uniform> params: Params;\n\n",
1588        ));
1589
1590        // ── compute entry point ───────────────────────────────────────────────
1591        shader.push_str(
1592            "@compute @workgroup_size(@WG@)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n",
1593        );
1594        shader.push_str("    let idx = gid.x;\n    if (idx >= params.len) { return; }\n");
1595        shader.push_str("    let g = idx + params.offset;\n");
1596
1597        // Compute N-D coordinates from global index (with chunk offset)
1598        shader.push_str(
1599            "    var coord: array<u32, MAX_RANK>;\n    var tmp: u32 = g;\n    var d: u32 = 0u;\n    loop { if d >= params.rank { break; } let dim = params.out_shape[d].value; if dim == 0u { coord[d] = 0u; } else { coord[d] = tmp % dim; tmp = tmp / dim; } d = d + 1u; }\n",
1600        );
1601
1602        // Compute broadcasted flat indices per input
1603        for (idx, _) in self.inputs.iter().enumerate() {
1604            shader.push_str(&format!(
1605                "    var i{idx}: u32 = 0u; d = 0u; loop {{ if d >= params.rank {{ break; }} let sd = params.in{idx}_shape[d].value; let st = params.in{idx}_stride[d].value; let c = select(coord[d], 0u, sd == 1u); i{idx} = i{idx} + c * st; d = d + 1u; }}\n",
1606            ));
1607        }
1608
1609        shader.push_str(body);
1610        shader.push_str(final_writes);
1611        shader.push_str("}\n");
1612        shader
1613    }
1614
1615    /// Generate a single WGSL shader that writes all `output_ids` in one compute pass,
1616    /// eliminating the O(N²) redundant dispatches that arise from calling
1617    /// `generate_wgsl_for_output` N times.
1618    ///
1619    /// Binding layout:
1620    ///   0 .. inputs.len()-1          → read-only input tensors
1621    ///   inputs.len() + k             → read_write output tensor k
1622    ///   inputs.len() + output_ids.len() → uniform Params
1623    pub fn generate_wgsl_for_outputs(
1624        &self,
1625        output_ids: &[ValueId],
1626        scalar_ty: &str,
1627    ) -> Option<String> {
1628        if output_ids.is_empty() {
1629            return None;
1630        }
1631        if output_ids.len() == 1 {
1632            return self.generate_wgsl_for_output(output_ids[0], scalar_ty);
1633        }
1634        if !self.kernel.kind.is_elementwise() {
1635            return None;
1636        }
1637        if !self.kernel.supported {
1638            return None;
1639        }
1640
1641        let mut exprs: HashMap<ValueId, String> = HashMap::new();
1642        for (idx, input_id) in self.inputs.iter().enumerate() {
1643            exprs.insert(*input_id, format!("input{idx}.data[i{idx}]"));
1644        }
1645
1646        let mut body = String::new();
1647        for (node_idx, op) in self.operations.iter().enumerate() {
1648            let tmp_name = format!("tmp{node_idx}");
1649            match op {
1650                FusionOp::Primitive { op, inputs, output } => {
1651                    let expr = primitive_expr(*op, inputs, &exprs)?;
1652                    body.push_str(&format!("    let {tmp_name}: {scalar_ty} = {expr};\n"));
1653                    if let Some(out) = output {
1654                        exprs.insert(*out, tmp_name.clone());
1655                    }
1656                }
1657                FusionOp::Builtin {
1658                    name,
1659                    inputs,
1660                    output,
1661                } => {
1662                    let expr = builtin_expr(name, inputs, &exprs, scalar_ty)?;
1663                    body.push_str(&format!("    let {tmp_name}: {scalar_ty} = {expr};\n"));
1664                    if let Some(out) = output {
1665                        exprs.insert(*out, tmp_name.clone());
1666                    }
1667                }
1668            }
1669        }
1670
1671        let mut final_exprs = Vec::with_capacity(output_ids.len());
1672        for output_id in output_ids {
1673            final_exprs.push(exprs.get(output_id)?.clone());
1674        }
1675
1676        let num_outputs = output_ids.len();
1677        let n_inputs = self.inputs.len();
1678
1679        let mut output_bindings = String::new();
1680        for k in 0..num_outputs {
1681            output_bindings.push_str(&format!(
1682                "@group(0) @binding({}) var<storage, read_write> output{k}: Tensor;\n",
1683                n_inputs + k,
1684            ));
1685        }
1686
1687        let mut final_writes = String::new();
1688        for (k, expr) in final_exprs.iter().enumerate() {
1689            final_writes.push_str(&format!("    output{k}.data[g] = {expr};\n"));
1690        }
1691
1692        Some(self.build_wgsl_shader(
1693            scalar_ty,
1694            &output_bindings,
1695            n_inputs + num_outputs,
1696            &body,
1697            &final_writes,
1698        ))
1699    }
1700
1701    pub fn generate_wgsl_for_output(&self, output_id: ValueId, scalar_ty: &str) -> Option<String> {
1702        if !self.kernel.kind.is_elementwise() {
1703            return None;
1704        }
1705        if !self.kernel.supported {
1706            return None;
1707        }
1708
1709        let mut exprs: HashMap<ValueId, String> = HashMap::new();
1710        for (idx, input_id) in self.inputs.iter().enumerate() {
1711            // Placeholder; will be resolved to the broadcasted index variable i{idx}
1712            exprs.insert(*input_id, format!("input{idx}.data[i{idx}]"));
1713        }
1714
1715        let mut body = String::new();
1716        for (node_idx, op) in self.operations.iter().enumerate() {
1717            let tmp_name = format!("tmp{node_idx}");
1718            match op {
1719                FusionOp::Primitive { op, inputs, output } => {
1720                    let expr = primitive_expr(*op, inputs, &exprs)?;
1721                    body.push_str(&format!("    let {tmp_name}: {scalar_ty} = {expr};\n"));
1722                    if let Some(out) = output {
1723                        exprs.insert(*out, tmp_name.clone());
1724                    }
1725                }
1726                FusionOp::Builtin {
1727                    name,
1728                    inputs,
1729                    output,
1730                } => {
1731                    let expr = builtin_expr(name, inputs, &exprs, scalar_ty)?;
1732                    body.push_str(&format!("    let {tmp_name}: {scalar_ty} = {expr};\n"));
1733                    if let Some(out) = output {
1734                        exprs.insert(*out, tmp_name.clone());
1735                    }
1736                }
1737            }
1738        }
1739
1740        let final_expr = exprs.get(&output_id)?.clone();
1741        let n_inputs = self.inputs.len();
1742
1743        let output_bindings =
1744            format!("@group(0) @binding({n_inputs}) var<storage, read_write> output: Tensor;\n",);
1745        let final_writes = format!("    output.data[g] = {final_expr};\n");
1746
1747        Some(self.build_wgsl_shader(
1748            scalar_ty,
1749            &output_bindings,
1750            n_inputs + 1,
1751            &body,
1752            &final_writes,
1753        ))
1754    }
1755
1756    pub fn generate_reduction_wgsl(&self, scalar_ty: &str) -> Option<String> {
1757        if !self.kernel.kind.is_reduction() {
1758            return None;
1759        }
1760        // Minimal column-major reduction kernel template (single workgroup per slice).
1761        // Supports folding simple producer expressions over multiple inputs (e.g., sum(A.*B, dim)).
1762        if self.inputs.is_empty() {
1763            return None;
1764        }
1765        // Determine axis from the reduction builtin's explicit dim argument when available.
1766        // MATLAB dim is 1-based: dim=1 reduces rows (axis=0), dim=2 reduces cols (axis=1).
1767        let mut axis = 0usize;
1768        // Support 'all' via either index-keyed constants or value-id keyed const_values
1769        let reduce_all = self
1770            .constants
1771            .values()
1772            .any(|v| matches!(v, Value::String(s) if s.eq_ignore_ascii_case("all")))
1773            || self
1774                .const_values
1775                .values()
1776                .any(|v| matches!(v, Value::String(s) if s.eq_ignore_ascii_case("all")));
1777        if reduce_all {
1778            // We'll flatten in VM by setting nrows = total and ncols = 1; axis=0 works with that.
1779            axis = 0;
1780        } else if let Some(dim_vid) = self.reduction_dim {
1781            if let Some(v) = self.const_values.get(&dim_vid) {
1782                match v {
1783                    Value::Num(n) if *n >= 1.0 => {
1784                        axis = (*n as usize).saturating_sub(1);
1785                    }
1786                    Value::Int(i) => {
1787                        let val = i.to_f64();
1788                        if val >= 1.0 {
1789                            axis = (val as usize).saturating_sub(1);
1790                        }
1791                    }
1792                    _ => {}
1793                }
1794            }
1795        } else {
1796            // Fallback: scan constant table for a plausible dim
1797            for v in self.constants.values() {
1798                match v {
1799                    Value::Num(n) if *n >= 1.0 => {
1800                        axis = (*n as usize).saturating_sub(1);
1801                        break;
1802                    }
1803                    Value::Int(i) => {
1804                        let val = i.to_f64();
1805                        if val >= 1.0 {
1806                            axis = (val as usize).saturating_sub(1);
1807                            break;
1808                        }
1809                    }
1810                    _ => {}
1811                }
1812            }
1813        }
1814
1815        // Detect omitnan constant (compile-time selection)
1816        let omitnan = self.constants.values().any(|v| match v {
1817            Value::String(s) => s.eq_ignore_ascii_case("omitnan"),
1818            _ => false,
1819        });
1820
1821        // Build reduction operand expression by folding the producer chain
1822        let data_vid = self.reduction_data?;
1823        let ext_input = self.inputs[0];
1824        let mut exprs: HashMap<ValueId, String> = HashMap::new();
1825        exprs.insert(ext_input, "v".to_string());
1826        // Map additional external inputs to v1, v2, ...
1827        for (idx, &vid) in self.inputs.iter().enumerate().skip(1) {
1828            exprs.insert(vid, format!("v{idx}"));
1829        }
1830        for (vid, val) in &self.const_values {
1831            let lit = match val {
1832                Value::Num(n) => {
1833                    if scalar_ty == "f64" {
1834                        format!("f64({})", n)
1835                    } else {
1836                        format!("{:?}", *n as f32)
1837                    }
1838                }
1839                Value::Int(i) => {
1840                    let f = i.to_f64();
1841                    if scalar_ty == "f64" {
1842                        format!("f64({})", f)
1843                    } else {
1844                        format!("{:?}", f as f32)
1845                    }
1846                }
1847                Value::Tensor(t) if t.data.len() == 1 => {
1848                    let scalar = t.data[0];
1849                    if scalar_ty == "f64" {
1850                        format!("f64({})", scalar)
1851                    } else {
1852                        format!("{:?}", scalar as f32)
1853                    }
1854                }
1855                _ => {
1856                    if scalar_ty == "f64" {
1857                        "f64(0.0)".to_string()
1858                    } else {
1859                        "0.0".to_string()
1860                    }
1861                }
1862            };
1863            exprs.insert(*vid, lit);
1864        }
1865        let mut progressed = true;
1866        while progressed {
1867            progressed = false;
1868            for op in &self.operations {
1869                match op {
1870                    FusionOp::Primitive { op, inputs, output } => {
1871                        if let Some(out) = output {
1872                            if exprs.contains_key(out) {
1873                                continue;
1874                            }
1875                            if let Some(code) = primitive_expr(*op, inputs, &exprs) {
1876                                exprs.insert(*out, code);
1877                                progressed = true;
1878                            }
1879                        }
1880                    }
1881                    FusionOp::Builtin {
1882                        name,
1883                        inputs,
1884                        output,
1885                    } => {
1886                        if let Some(out) = output {
1887                            if exprs.contains_key(out) {
1888                                continue;
1889                            }
1890                            if let Some(code) = builtin_expr(name, inputs, &exprs, scalar_ty) {
1891                                exprs.insert(*out, code);
1892                                progressed = true;
1893                            }
1894                        }
1895                    }
1896                }
1897            }
1898            if exprs.contains_key(&data_vid) {
1899                break;
1900            }
1901        }
1902        // Require a folded expression for the reduction operand; if missing, defer (no WGSL).
1903        let val_expr = match exprs.get(&data_vid) {
1904            Some(s) => s.clone(),
1905            None => {
1906                if fusion_debug_enabled() {
1907                    let expr_keys: Vec<ValueId> = exprs.keys().copied().collect();
1908                    log::debug!(
1909                        "fusion reduction WGSL: missing expression for data {:?}; inputs={:?} expr_keys={:?} ops={:?}",
1910                        data_vid,
1911                        self.inputs,
1912                        expr_keys,
1913                        self.operations
1914                    );
1915                }
1916                return None;
1917            }
1918        };
1919
1920        let mut shader = String::new();
1921        shader.push_str(&format!("struct Tensor {{ data: array<{scalar_ty}>, }};\n"));
1922        shader.push_str("struct MParams { nrows: u32, ncols: u32, ld: u32, flags: u32 }\n\n");
1923        // Bind all input tensors dynamically, followed by output and params
1924        for (idx, _) in self.inputs.iter().enumerate() {
1925            shader.push_str(&format!(
1926                "@group(0) @binding({}) var<storage, read> input{}: Tensor;\n",
1927                idx, idx
1928            ));
1929        }
1930        shader.push_str(&format!(
1931            "@group(0) @binding({}) var<storage, read_write> output: Tensor;\n",
1932            self.inputs.len()
1933        ));
1934        shader.push_str(&format!(
1935            "@group(0) @binding({}) var<uniform> params: MParams;\n\n",
1936            self.inputs.len() + 1
1937        ));
1938        // Use a small fixed workgroup tile size to avoid driver stalls on some backends
1939        shader.push_str(&format!(
1940            "var<workgroup> tile: array<{scalar_ty}, @WG@u>;\n\n"
1941        ));
1942        shader.push_str(&format!(
1943            "const OMITNAN: bool = {};\n\n",
1944            if omitnan { "true" } else { "false" }
1945        ));
1946        // Determine mean semantics from planner-populated reduction flavor
1947        let is_mean = matches!(self.reduction_flavor, Some(ReductionFlavor::Mean));
1948        let post_scale = if is_mean {
1949            let dim = if axis == 0 {
1950                "params.nrows"
1951            } else {
1952                "params.ncols"
1953            };
1954            if scalar_ty == "f64" {
1955                format!("(1.0 / f64(f32({dim})))")
1956            } else {
1957                format!("(1.0 / f32({dim}))")
1958            }
1959        } else if scalar_ty == "f64" {
1960            "f64(1.0)".to_string()
1961        } else {
1962            "1.0".to_string()
1963        };
1964        // Helper(s) at module scope
1965        shader.push_str(&format!(
1966            "fn isNanF(x: {scalar}) -> bool {{ return x != x; }}\n",
1967            scalar = scalar_ty
1968        ));
1969        if scalar_ty == "f64" {
1970            shader.push_str("fn canonicalNan() -> f64 {\n  var bits: u64 = 0x7ff8000000000000u;\n  return bitcast<f64>(bits);\n}\n\n");
1971        } else {
1972            shader.push_str("fn canonicalNan() -> f32 {\n  var bits: u32 = 0x7fc00000u;\n  return bitcast<f32>(bits);\n}\n\n");
1973        }
1974        shader.push_str("@compute @workgroup_size(@WG@)\n");
1975        if axis == 0 {
1976            // Column-wise: reduce over rows; one output per column (ncols)
1977            shader.push_str(
1978                "fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {\n",
1979            );
1980            shader.push_str("  let col = wid.x;\n  if (col >= params.ncols) { return; }\n");
1981            shader.push_str(&format!(
1982                "  var acc: {scalar_ty} = {}0.0;\n",
1983                if scalar_ty == "f64" { "f64(" } else { "" }
1984            ));
1985            if scalar_ty == "f64" {
1986                shader.push_str("  // close cast for f64 literal\n");
1987            }
1988            // helpers are declared at module scope
1989            shader.push_str("  var saw_nan: bool = false;\n  var r = lid.x;\n");
1990            // Load row-wise values from each input and fold into expression
1991            {
1992                // Build the per-iteration loads
1993                let mut loop_body = String::new();
1994                // input0 as 'v'
1995                loop_body.push_str("    let v = input0.data[ (col * params.nrows) + r ];\n");
1996                // additional inputs as v1, v2, ...
1997                for (idx, _) in self.inputs.iter().enumerate().skip(1) {
1998                    loop_body.push_str(&format!(
1999                        "    let v{idx} = input{idx}.data[ (col * params.nrows) + r ];\n"
2000                    ));
2001                }
2002                // compute val and accumulate
2003                loop_body.push_str(&format!(
2004                    "    let val: {scalar} = {val};\n    if (OMITNAN) {{ if (!isNanF(val)) {{ acc = acc + val; }} }} else {{ if (isNanF(val)) {{ saw_nan = true; }} else {{ acc = acc + val; }} }}\n",
2005                scalar = scalar_ty,
2006                val = val_expr
2007            ));
2008                shader.push_str("  while (r < params.nrows) {\n");
2009                shader.push_str(&loop_body);
2010                shader.push_str("    r += @WG@u;\n  }\n");
2011            }
2012            shader.push_str("  if (!OMITNAN && saw_nan) { acc = canonicalNan(); }\n");
2013            shader.push_str("  tile[lid.x] = acc;\n  workgroupBarrier();\n");
2014            shader.push_str(
2015                "  var off = (@WG@u) / 2u;\n  loop { if (off == 0u) { break; } if (lid.x < off) {\n    let a = tile[lid.x]; let b = tile[lid.x + off];\n    tile[lid.x] = a + b;\n  } workgroupBarrier(); off = off / 2u; }\n",
2016            );
2017            // Final write: apply post-scale (sum=1, mean=1/rows)
2018            shader.push_str(&format!(
2019                "  if (lid.x == 0u) {{ output.data[col] = tile[0u] * {}; }}\n}}\n",
2020                post_scale
2021            ));
2022        } else {
2023            // Row-wise: reduce over cols; one output per row (nrows)
2024            shader.push_str(
2025                "fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {\n",
2026            );
2027            shader.push_str("  let row = wid.x;\n  // For axis=1, number of output slices equals rows (params.ncols)\n  if (row >= params.ncols) { return; }\n");
2028            shader.push_str(&format!(
2029                "  var acc: {scalar_ty} = {}0.0;\n",
2030                if scalar_ty == "f64" { "f64(" } else { "" }
2031            ));
2032            if scalar_ty == "f64" {
2033                shader.push_str("  // close cast for f64 literal\n");
2034            }
2035            // helpers are declared at module scope
2036            shader.push_str("  var saw_nan: bool = false;\n  var c = lid.x;\n");
2037            {
2038                let mut loop_body = String::new();
2039                // input0 as 'v' — provider encodes rows in params.ncols for axis=1
2040                loop_body.push_str("    let v = input0.data[ row + (c * params.ncols) ];\n");
2041                // additional inputs as v1, v2, ...
2042                for (idx, _) in self.inputs.iter().enumerate().skip(1) {
2043                    loop_body.push_str(&format!(
2044                        "    let v{idx} = input{idx}.data[ row + (c * params.ncols) ];\n"
2045                    ));
2046                }
2047                loop_body.push_str(&format!(
2048                    "    let val: {scalar} = {val};\n    if (OMITNAN) {{ if (!isNanF(val)) {{ acc = acc + val; }} }} else {{ if (isNanF(val)) {{ saw_nan = true; }} else {{ acc = acc + val; }} }}\n",
2049                scalar = scalar_ty,
2050                val = val_expr
2051            ));
2052                // Iterate over reduce_len, which arrives as params.nrows when axis=1
2053                shader.push_str("  while (c < params.nrows) {\n");
2054                shader.push_str(&loop_body);
2055                shader.push_str("    c += @WG@u;\n  }\n");
2056            }
2057            shader.push_str("  if (!OMITNAN && saw_nan) { acc = canonicalNan(); }\n");
2058            shader.push_str("  tile[lid.x] = acc;\n  workgroupBarrier();\n");
2059            shader.push_str(
2060                "  var off = (@WG@u) / 2u;\n  loop { if (off == 0u) { break; } if (lid.x < off) {\n    let a = tile[lid.x]; let b = tile[lid.x + off];\n    tile[lid.x] = a + b;\n  } workgroupBarrier(); off = off / 2u; }\n",
2061            );
2062            shader.push_str(&format!(
2063                "  if (lid.x == 0u) {{ output.data[row] = tile[0u] * {}; }}\n}}\n",
2064                post_scale
2065            ));
2066        }
2067        Some(shader)
2068    }
2069}
2070
2071impl FusionGroup {
2072    pub fn element_count(&self) -> Option<usize> {
2073        match &self.shape {
2074            ShapeInfo::Scalar => Some(1),
2075            ShapeInfo::Tensor(dims) => dims
2076                .iter()
2077                .try_fold(1usize, |acc, dim| dim.and_then(|d| acc.checked_mul(d))),
2078            ShapeInfo::Unknown => None,
2079        }
2080    }
2081}
2082
2083impl FusionKind {
2084    pub fn is_elementwise(&self) -> bool {
2085        matches!(self, FusionKind::ElementwiseChain)
2086    }
2087
2088    pub fn is_reduction(&self) -> bool {
2089        matches!(self, FusionKind::Reduction)
2090    }
2091}
2092
2093fn detect_centered_gram(
2094    graph: &AccelGraph,
2095    assigned: &mut HashSet<NodeId>,
2096    groups: &mut Vec<FusionGroup>,
2097    next_group_id: &mut usize,
2098) {
2099    for div_node in &graph.nodes {
2100        if assigned.contains(&div_node.id) {
2101            continue;
2102        }
2103        let div_op = match div_node.label {
2104            AccelNodeLabel::Primitive(op) => op,
2105            _ => continue,
2106        };
2107        if div_op != PrimitiveOp::ElemDiv {
2108            continue;
2109        }
2110        if div_node.inputs.len() != 2 {
2111            continue;
2112        }
2113        let (numerator_id, denom_id) = (div_node.inputs[0], div_node.inputs[1]);
2114        let denom_info = match graph.value(denom_id) {
2115            Some(info) => info,
2116            None => continue,
2117        };
2118        let denom_const = match &denom_info.constant {
2119            Some(Value::Num(v)) => Some(*v),
2120            Some(Value::Int(i)) => Some(i.to_f64()),
2121            _ => None,
2122        };
2123        if denom_const.is_some_and(|v| v == 0.0) {
2124            continue;
2125        }
2126
2127        let mul_node_id = match graph
2128            .value(numerator_id)
2129            .and_then(|info| match &info.origin {
2130                ValueOrigin::NodeOutput { node, .. } => Some(*node),
2131                _ => None,
2132            }) {
2133            Some(id) => id,
2134            None => continue,
2135        };
2136        if assigned.contains(&mul_node_id) {
2137            continue;
2138        }
2139        let mul_node = match graph.node(mul_node_id) {
2140            Some(node) => node,
2141            None => continue,
2142        };
2143        let mul_op = match mul_node.label {
2144            AccelNodeLabel::Primitive(op) => op,
2145            _ => continue,
2146        };
2147        if mul_op != PrimitiveOp::Mul && mul_op != PrimitiveOp::ElemMul {
2148            continue;
2149        }
2150        if mul_node.inputs.len() != 2 {
2151            continue;
2152        }
2153
2154        let mut transpose_node_id: Option<NodeId> = None;
2155        let mut centered_val_id: Option<ValueId> = None;
2156        for input_vid in &mul_node.inputs {
2157            let candidate_node_id =
2158                match graph.value(*input_vid).and_then(|info| match &info.origin {
2159                    ValueOrigin::NodeOutput { node, .. } => Some(*node),
2160                    _ => None,
2161                }) {
2162                    Some(id) => id,
2163                    None => continue,
2164                };
2165            if let Some(trans_node) = graph.node(candidate_node_id) {
2166                if matches!(
2167                    trans_node.label,
2168                    AccelNodeLabel::Primitive(PrimitiveOp::Transpose)
2169                ) {
2170                    if let Some(centered) = trans_node.inputs.first().copied() {
2171                        transpose_node_id = Some(candidate_node_id);
2172                        centered_val_id = Some(centered);
2173                        break;
2174                    }
2175                }
2176            }
2177        }
2178
2179        let transpose_node_id = match transpose_node_id {
2180            Some(id) if !assigned.contains(&id) => id,
2181            _ => continue,
2182        };
2183        let centered_val_id = match centered_val_id {
2184            Some(id) => id,
2185            None => continue,
2186        };
2187
2188        if assigned.contains(&transpose_node_id) {
2189            continue;
2190        }
2191        if graph.node(transpose_node_id).is_none() {
2192            continue;
2193        }
2194
2195        let centered_node_id =
2196            match graph
2197                .value(centered_val_id)
2198                .and_then(|info| match &info.origin {
2199                    ValueOrigin::NodeOutput { node, .. } => Some(*node),
2200                    _ => None,
2201                }) {
2202                Some(id) => id,
2203                None => continue,
2204            };
2205        if assigned.contains(&centered_node_id) {
2206            continue;
2207        }
2208        let centered_node = match graph.node(centered_node_id) {
2209            Some(node) => node,
2210            None => continue,
2211        };
2212        if !matches!(
2213            centered_node.label,
2214            AccelNodeLabel::Primitive(PrimitiveOp::Sub)
2215        ) {
2216            continue;
2217        }
2218        if centered_node.inputs.len() != 2 {
2219            continue;
2220        }
2221        let matrix_val_id = centered_node.inputs[0];
2222        let mean_val_id = centered_node.inputs[1];
2223
2224        let mean_node_id = match graph
2225            .value(mean_val_id)
2226            .and_then(|info| match &info.origin {
2227                ValueOrigin::NodeOutput { node, .. } => Some(*node),
2228                _ => None,
2229            }) {
2230            Some(id) => id,
2231            None => continue,
2232        };
2233        if assigned.contains(&mean_node_id) {
2234            continue;
2235        }
2236        let mean_node = match graph.node(mean_node_id) {
2237            Some(node) => node,
2238            None => continue,
2239        };
2240        match &mean_node.label {
2241            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mean") => {}
2242            _ => continue,
2243        }
2244        if mean_node.inputs.is_empty() || mean_node.inputs[0] != matrix_val_id {
2245            continue;
2246        }
2247
2248        let matrix_info = match graph.value(matrix_val_id) {
2249            Some(info) => info,
2250            None => continue,
2251        };
2252        let matrix_rows = match &matrix_info.shape {
2253            ShapeInfo::Tensor(dims) if !dims.is_empty() => dims[0].unwrap_or(0),
2254            _ => 0,
2255        };
2256        let normalization = if matrix_rows > 1 {
2257            if let Some(value) = denom_const {
2258                let unbiased = (matrix_rows as f64 - 1.0).max(1.0);
2259                let biased = matrix_rows as f64;
2260                if approx_eq(value, unbiased) {
2261                    CovNormalization::Unbiased
2262                } else if approx_eq(value, biased) {
2263                    CovNormalization::Biased
2264                } else {
2265                    CovNormalization::Unbiased
2266                }
2267            } else {
2268                CovNormalization::Unbiased
2269            }
2270        } else {
2271            CovNormalization::Unbiased
2272        };
2273
2274        let mut nodes = vec![
2275            mean_node_id,
2276            centered_node_id,
2277            transpose_node_id,
2278            mul_node_id,
2279            div_node.id,
2280        ];
2281        nodes.sort_by_key(|node_id| {
2282            graph
2283                .node(*node_id)
2284                .map(|node| node.span.start)
2285                .unwrap_or(usize::MAX)
2286        });
2287        let span = group_span(graph, &nodes);
2288        let shape = node_output_shape(graph, div_node);
2289
2290        groups.push(FusionGroup {
2291            id: *next_group_id,
2292            kind: FusionKind::CenteredGram,
2293            nodes: nodes.clone(),
2294            shape,
2295            span,
2296            pattern: Some(FusionPattern::CenteredGram {
2297                matrix: matrix_val_id,
2298                normalization,
2299            }),
2300            stack_layout: None,
2301        });
2302        *next_group_id += 1;
2303        for id in nodes {
2304            assigned.insert(id);
2305        }
2306    }
2307}
2308
2309fn detect_image_normalize(
2310    graph: &AccelGraph,
2311    assigned: &mut HashSet<NodeId>,
2312    groups: &mut Vec<FusionGroup>,
2313    next_group_id: &mut usize,
2314) {
2315    for pow_node in &graph.nodes {
2316        if assigned.contains(&pow_node.id) {
2317            continue;
2318        }
2319        let Some(match_info) = analyze_image_normalize(graph, pow_node.id, assigned) else {
2320            continue;
2321        };
2322
2323        let pow_node_ref = match graph.node(pow_node.id) {
2324            Some(node) => node,
2325            None => continue,
2326        };
2327
2328        let shape = node_output_shape(graph, pow_node_ref);
2329        let span = group_span(graph, &match_info.nodes);
2330
2331        let pattern = ImageNormalizePattern {
2332            input: match_info.input,
2333            epsilon: match_info.epsilon.clone(),
2334            gain: match_info.gain.clone(),
2335            bias: match_info.bias.clone(),
2336            gamma: match_info.gamma.clone(),
2337        };
2338
2339        groups.push(FusionGroup {
2340            id: *next_group_id,
2341            kind: FusionKind::ImageNormalize,
2342            nodes: match_info.nodes.clone(),
2343            shape,
2344            span: span.clone(),
2345            pattern: Some(FusionPattern::ImageNormalize(pattern)),
2346            stack_layout: None,
2347        });
2348        if fusion_debug_enabled() {
2349            log::debug!(
2350                "fusion: detected image normalize group id={} span={:?} nodes={:?}",
2351                next_group_id,
2352                span,
2353                match_info.nodes
2354            );
2355        }
2356        *next_group_id += 1;
2357        for node_id in match_info.nodes {
2358            assigned.insert(node_id);
2359        }
2360    }
2361}
2362
/// Relative floating-point comparison with an absolute floor.
///
/// `a` and `b` are considered equal when they differ by at most `1e-6` times the largest
/// of `|a|`, `|b|` and `1.0`. The `1.0` floor turns the test into an absolute tolerance
/// near zero. Any NaN operand makes the comparison false.
fn approx_eq(a: f64, b: f64) -> bool {
    let magnitude = a.abs().max(b.abs()).max(1.0);
    let difference = (a - b).abs();
    difference <= 1e-6 * magnitude
}
2367
2368fn detect_power_step_normalize(
2369    graph: &AccelGraph,
2370    assigned: &mut HashSet<NodeId>,
2371    groups: &mut Vec<FusionGroup>,
2372    next_group_id: &mut usize,
2373) {
2374    'outer: for div_node in &graph.nodes {
2375        if assigned.contains(&div_node.id) {
2376            continue;
2377        }
2378        let div_op = match div_node.label {
2379            AccelNodeLabel::Primitive(op) => op,
2380            _ => continue,
2381        };
2382        if div_op != PrimitiveOp::ElemDiv {
2383            continue;
2384        }
2385        if div_node.inputs.len() != 2 {
2386            continue;
2387        }
2388        let numerator_vid = div_node.inputs[0];
2389        let denom_vid = div_node.inputs[1];
2390
2391        let (matmul_id, matmul_node) = match node_from_value(graph, numerator_vid) {
2392            Some((id, node)) => (id, node),
2393            None => continue,
2394        };
2395        if assigned.contains(&matmul_id) {
2396            continue;
2397        }
2398        match &matmul_node.label {
2399            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mtimes") => {}
2400            _ => continue,
2401        }
2402        if matmul_node.inputs.len() != 2 {
2403            continue;
2404        }
2405
2406        let Some(denom_info) = analyze_power_step_denominator(graph, denom_vid, numerator_vid)
2407        else {
2408            continue;
2409        };
2410        if assigned.contains(&denom_info.sqrt_node) {
2411            continue;
2412        }
2413        if assigned.contains(&denom_info.sum_node) {
2414            continue;
2415        }
2416        if assigned.contains(&denom_info.pow_node) {
2417            continue;
2418        }
2419        if let Some(add_id) = denom_info.add_node {
2420            if assigned.contains(&add_id) {
2421                continue;
2422            }
2423        }
2424        if denom_info.pow_input != numerator_vid {
2425            continue;
2426        }
2427
2428        let mut nodes = vec![matmul_id, denom_info.pow_node, denom_info.sum_node];
2429        if let Some(add_id) = denom_info.add_node {
2430            nodes.push(add_id);
2431        }
2432        nodes.push(denom_info.sqrt_node);
2433        nodes.push(div_node.id);
2434
2435        for node_id in &nodes {
2436            if assigned.contains(node_id) {
2437                continue 'outer;
2438            }
2439        }
2440
2441        nodes.sort_by_key(|node_id| {
2442            graph
2443                .node(*node_id)
2444                .map(|node| node.span.start)
2445                .unwrap_or(usize::MAX)
2446        });
2447
2448        let span = group_span(graph, &nodes);
2449        let shape = node_output_shape(graph, div_node);
2450
2451        groups.push(FusionGroup {
2452            id: *next_group_id,
2453            kind: FusionKind::PowerStepNormalize,
2454            nodes: nodes.clone(),
2455            shape,
2456            span,
2457            pattern: Some(FusionPattern::PowerStepNormalize {
2458                lhs: matmul_node.inputs[0],
2459                rhs: matmul_node.inputs[1],
2460                epsilon: denom_info.epsilon,
2461            }),
2462            stack_layout: None,
2463        });
2464        *next_group_id += 1;
2465        for id in nodes {
2466            assigned.insert(id);
2467        }
2468    }
2469}
2470
2471fn detect_explained_variance(
2472    graph: &AccelGraph,
2473    assigned: &mut HashSet<NodeId>,
2474    groups: &mut Vec<FusionGroup>,
2475    next_group_id: &mut usize,
2476) {
2477    for diag_node in &graph.nodes {
2478        if assigned.contains(&diag_node.id) {
2479            continue;
2480        }
2481        match &diag_node.label {
2482            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("diag") => {}
2483            _ => continue,
2484        }
2485        if diag_node.inputs.len() != 1 {
2486            continue;
2487        }
2488        let matmul2_vid = diag_node.inputs[0];
2489        let (matmul2_id, matmul2_node) = match node_from_value(graph, matmul2_vid) {
2490            Some(pair) => pair,
2491            None => continue,
2492        };
2493        if assigned.contains(&matmul2_id) {
2494            continue;
2495        }
2496        match &matmul2_node.label {
2497            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mtimes") => {}
2498            _ => continue,
2499        }
2500        if matmul2_node.inputs.len() != 2 {
2501            continue;
2502        }
2503
2504        let (matmul1_id, matmul1_node, q_vid) = if let Some((mm_id, mm_node)) =
2505            node_from_value(graph, matmul2_node.inputs[0])
2506        {
2507            if matches!(mm_node.label, AccelNodeLabel::Builtin { ref name } if name.eq_ignore_ascii_case("mtimes"))
2508            {
2509                (mm_id, mm_node, matmul2_node.inputs[1])
2510            } else {
2511                continue;
2512            }
2513        } else if let Some((mm_id, mm_node)) = node_from_value(graph, matmul2_node.inputs[1]) {
2514            if matches!(mm_node.label, AccelNodeLabel::Builtin { ref name } if name.eq_ignore_ascii_case("mtimes"))
2515            {
2516                (mm_id, mm_node, matmul2_node.inputs[0])
2517            } else {
2518                continue;
2519            }
2520        } else {
2521            continue;
2522        };
2523
2524        if assigned.contains(&matmul1_id) {
2525            continue;
2526        }
2527
2528        if matmul1_node.inputs.len() != 2 {
2529            continue;
2530        }
2531
2532        let (transpose_id, transpose_input_vid, g_vid) =
2533            if let Some((t_id, src_vid)) = is_transpose_node(graph, matmul1_node.inputs[0]) {
2534                (t_id, src_vid, matmul1_node.inputs[1])
2535            } else if let Some((t_id, src_vid)) = is_transpose_node(graph, matmul1_node.inputs[1]) {
2536                (t_id, src_vid, matmul1_node.inputs[0])
2537            } else {
2538                continue;
2539            };
2540
2541        if assigned.contains(&transpose_id) {
2542            continue;
2543        }
2544
2545        if transpose_input_vid != q_vid {
2546            continue;
2547        }
2548
2549        let mut nodes = vec![diag_node.id, matmul2_id, matmul1_id, transpose_id];
2550        nodes.sort_by_key(|node_id| {
2551            graph
2552                .node(*node_id)
2553                .map(|node| node.span.start)
2554                .unwrap_or(usize::MAX)
2555        });
2556        let span = group_span(graph, &nodes);
2557        let shape = node_output_shape(graph, diag_node);
2558        groups.push(FusionGroup {
2559            id: *next_group_id,
2560            kind: FusionKind::ExplainedVariance,
2561            nodes: nodes.clone(),
2562            shape,
2563            span,
2564            pattern: Some(FusionPattern::ExplainedVariance { q: q_vid, g: g_vid }),
2565            stack_layout: None,
2566        });
2567        *next_group_id += 1;
2568        for id in nodes {
2569            assigned.insert(id);
2570        }
2571    }
2572}
2573
2574struct PowerStepDenominatorInfo {
2575    sqrt_node: NodeId,
2576    add_node: Option<NodeId>,
2577    sum_node: NodeId,
2578    pow_node: NodeId,
2579    pow_input: ValueId,
2580    epsilon: f64,
2581}
2582
2583fn analyze_power_step_denominator(
2584    graph: &AccelGraph,
2585    denom_vid: ValueId,
2586    expected_source_vid: ValueId,
2587) -> Option<PowerStepDenominatorInfo> {
2588    let (sqrt_node_id, sqrt_input_vid, add_node_opt, epsilon_from_outer) =
2589        if let Some((sqrt_id, sqrt_in)) = is_sqrt_node(graph, denom_vid) {
2590            if let Some((add_node, sum_vid, epsilon_inner)) =
2591                extract_add_with_constant(graph, sqrt_in)
2592            {
2593                (sqrt_id, sum_vid, Some(add_node), epsilon_inner)
2594            } else {
2595                (sqrt_id, sqrt_in, None, 0.0)
2596            }
2597        } else if let Some((add_node, other_vid, epsilon_inner)) =
2598            extract_add_with_constant(graph, denom_vid)
2599        {
2600            let (sqrt_id, sqrt_in) = is_sqrt_node(graph, other_vid)?;
2601            (sqrt_id, sqrt_in, Some(add_node), epsilon_inner)
2602        } else {
2603            return None;
2604        };
2605
2606    let (sum_node_id, sum_node) = node_from_value(graph, sqrt_input_vid)?;
2607    match &sum_node.label {
2608        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("sum") => {}
2609        _ => return None,
2610    }
2611    if sum_node.inputs.is_empty() {
2612        return None;
2613    }
2614    let pow_vid = sum_node.inputs[0];
2615    let (pow_node_id, pow_node) = node_from_value(graph, pow_vid)?;
2616    let pow_input = match pow_node.label {
2617        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow) => {
2618            if pow_node.inputs.len() != 2 {
2619                return None;
2620            }
2621            let base = pow_node.inputs[0];
2622            let exponent_vid = pow_node.inputs[1];
2623            let exponent = value_constant_f64(graph, exponent_vid)?;
2624            if !approx_eq(exponent, 2.0) {
2625                return None;
2626            }
2627            base
2628        }
2629        _ => return None,
2630    };
2631
2632    if pow_input != expected_source_vid {
2633        return None;
2634    }
2635
2636    let epsilon = epsilon_from_outer;
2637    Some(PowerStepDenominatorInfo {
2638        sqrt_node: sqrt_node_id,
2639        add_node: add_node_opt,
2640        sum_node: sum_node_id,
2641        pow_node: pow_node_id,
2642        pow_input,
2643        epsilon,
2644    })
2645}
2646
2647fn node_from_value(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, &AccelNode)> {
2648    let info = graph.value(vid)?;
2649    match info.origin {
2650        ValueOrigin::NodeOutput { node, .. } => graph.node(node).map(|n| (node, n)),
2651        _ => None,
2652    }
2653}
2654
2655fn is_sqrt_node(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId)> {
2656    let (node_id, node) = node_from_value(graph, vid)?;
2657    match &node.label {
2658        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("sqrt") => {
2659            let input = node.inputs.first().copied()?;
2660            Some((node_id, input))
2661        }
2662        _ => None,
2663    }
2664}
2665
2666fn is_transpose_node(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId)> {
2667    let (node_id, node) = node_from_value(graph, vid)?;
2668    match &node.label {
2669        AccelNodeLabel::Primitive(PrimitiveOp::Transpose) => {
2670            let input = node.inputs.first().copied()?;
2671            Some((node_id, input))
2672        }
2673        _ => None,
2674    }
2675}
2676
2677fn extract_add_with_constant(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId, f64)> {
2678    let (node_id, node) = node_from_value(graph, vid)?;
2679    match node.label {
2680        AccelNodeLabel::Primitive(PrimitiveOp::Add) => {
2681            if node.inputs.len() != 2 {
2682                return None;
2683            }
2684            let lhs = node.inputs[0];
2685            let rhs = node.inputs[1];
2686            if let Some(eps) = value_constant_f64(graph, rhs) {
2687                return Some((node_id, lhs, eps));
2688            }
2689            if let Some(eps) = value_constant_f64(graph, lhs) {
2690                return Some((node_id, rhs, eps));
2691            }
2692            None
2693        }
2694        AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
2695            if node.inputs.len() != 2 {
2696                return None;
2697            }
2698            let lhs = node.inputs[0];
2699            let rhs = node.inputs[1];
2700            if let Some(eps) = value_constant_f64(graph, rhs) {
2701                return Some((node_id, lhs, -eps));
2702            }
2703            if let Some(eps) = value_constant_f64(graph, lhs) {
2704                return Some((node_id, rhs, eps));
2705            }
2706            None
2707        }
2708        _ => None,
2709    }
2710}
2711
2712struct ConstantTrace {
2713    value: f64,
2714    nodes: Vec<NodeId>,
2715}
2716
2717fn collect_scalar_constant(graph: &AccelGraph, vid: ValueId) -> Option<ConstantTrace> {
2718    let mut current = vid;
2719    let mut nodes: Vec<NodeId> = Vec::new();
2720    let mut sign = 1.0f64;
2721    let mut visited: HashSet<NodeId> = HashSet::new();
2722
2723    loop {
2724        let info = graph.value(current)?;
2725        match &info.origin {
2726            ValueOrigin::Constant => {
2727                let base = value_info_scalar(info)?;
2728                return Some(ConstantTrace {
2729                    value: sign * base,
2730                    nodes,
2731                });
2732            }
2733            ValueOrigin::NodeOutput { node, .. } => {
2734                if !visited.insert(*node) {
2735                    return None;
2736                }
2737                let node_ref = graph.node(*node)?;
2738                match &node_ref.label {
2739                    AccelNodeLabel::Builtin { name }
2740                        if name.eq_ignore_ascii_case("single")
2741                            || name.eq_ignore_ascii_case("double")
2742                            || name.eq_ignore_ascii_case("gpuarray") =>
2743                    {
2744                        if node_ref.inputs.len() != 1 {
2745                            return None;
2746                        }
2747                        nodes.push(*node);
2748                        current = node_ref.inputs[0];
2749                    }
2750                    AccelNodeLabel::Primitive(PrimitiveOp::Neg) => {
2751                        if node_ref.inputs.len() != 1 {
2752                            return None;
2753                        }
2754                        nodes.push(*node);
2755                        sign = -sign;
2756                        current = node_ref.inputs[0];
2757                    }
2758                    AccelNodeLabel::Primitive(PrimitiveOp::UPlus) => {
2759                        if node_ref.inputs.len() != 1 {
2760                            return None;
2761                        }
2762                        nodes.push(*node);
2763                        current = node_ref.inputs[0];
2764                    }
2765                    _ => return None,
2766                }
2767            }
2768            _ => return None,
2769        }
2770    }
2771}
2772
2773fn scalar_shape_known_one(shape: &ShapeInfo) -> bool {
2774    match shape {
2775        ShapeInfo::Scalar => true,
2776        ShapeInfo::Tensor(dims) => {
2777            if dims.is_empty() {
2778                return true;
2779            }
2780            dims.iter().all(|dim| matches!(dim, Some(1)))
2781        }
2782        ShapeInfo::Unknown => false,
2783    }
2784}
2785
2786fn capture_image_scalar(
2787    graph: &AccelGraph,
2788    vid: ValueId,
2789    assigned: &HashSet<NodeId>,
2790    _nodes: &mut Vec<NodeId>,
2791) -> Option<ImageScalar> {
2792    if let Some(trace) = collect_scalar_constant(graph, vid) {
2793        if trace.nodes.iter().any(|id| assigned.contains(id)) {
2794            return None;
2795        }
2796        return Some(ImageScalar::Constant(trace.value));
2797    }
2798    let info = graph.value(vid)?;
2799    if scalar_shape_known_one(&info.shape) {
2800        return Some(ImageScalar::Value(vid));
2801    }
2802    if log::log_enabled!(log::Level::Debug) {
2803        log::debug!(
2804            "capture_image_scalar: reject vid={vid:?} shape={:?} origin={:?}",
2805            info.shape,
2806            info.origin
2807        );
2808    }
2809    None
2810}
2811
2812fn peel_numeric_casts(
2813    graph: &AccelGraph,
2814    mut vid: ValueId,
2815    assigned: &HashSet<NodeId>,
2816    _nodes: &mut Vec<NodeId>,
2817) -> Option<ValueId> {
2818    loop {
2819        let info = graph.value(vid)?;
2820        match &info.origin {
2821            ValueOrigin::NodeOutput { node, .. } => {
2822                if assigned.contains(node) {
2823                    return None;
2824                }
2825                let node_ref = graph.node(*node)?;
2826                if let AccelNodeLabel::Builtin { name } = &node_ref.label {
2827                    if name.eq_ignore_ascii_case("single")
2828                        || name.eq_ignore_ascii_case("double")
2829                        || name.eq_ignore_ascii_case("gpuarray")
2830                    {
2831                        if node_ref.inputs.len() != 1 {
2832                            return None;
2833                        }
2834                        vid = node_ref.inputs[0];
2835                        continue;
2836                    }
2837                }
2838                return Some(vid);
2839            }
2840            _ => return Some(vid),
2841        }
2842    }
2843}
2844
2845fn resolve_scalar_constant(graph: &AccelGraph, vid: ValueId) -> Option<f64> {
2846    collect_scalar_constant(graph, vid).map(|trace| trace.value)
2847}
2848
2849fn value_info_scalar(info: &ValueInfo) -> Option<f64> {
2850    match &info.constant {
2851        Some(Value::Num(v)) => Some(*v),
2852        Some(Value::Int(i)) => Some(i.to_f64()),
2853        Some(Value::Tensor(t)) if t.data.len() == 1 => Some(t.data[0]),
2854        Some(Value::LogicalArray(arr)) if arr.data.len() == 1 => Some(arr.data[0] as f64),
2855        Some(Value::Bool(flag)) => Some(if *flag { 1.0 } else { 0.0 }),
2856        _ => None,
2857    }
2858}
2859
2860fn value_constant_f64(graph: &AccelGraph, vid: ValueId) -> Option<f64> {
2861    resolve_scalar_constant(graph, vid)
2862}
2863
2864fn primitive_expr(
2865    op: PrimitiveOp,
2866    inputs: &[ValueId],
2867    exprs: &HashMap<ValueId, String>,
2868) -> Option<String> {
2869    let binary = |exprs: &HashMap<ValueId, String>| -> Option<(String, String)> {
2870        let lhs = exprs.get(inputs.first()?).cloned()?;
2871        let rhs = exprs.get(inputs.get(1)?).cloned()?;
2872        Some((lhs, rhs))
2873    };
2874    match op {
2875        PrimitiveOp::Add => {
2876            let (lhs, rhs) = binary(exprs)?;
2877            Some(format!("({lhs} + {rhs})"))
2878        }
2879        PrimitiveOp::Sub => {
2880            let (lhs, rhs) = binary(exprs)?;
2881            Some(format!("({lhs} - {rhs})"))
2882        }
2883        PrimitiveOp::Mul | PrimitiveOp::ElemMul => {
2884            let (lhs, rhs) = binary(exprs)?;
2885            Some(format!("({lhs} * {rhs})"))
2886        }
2887        PrimitiveOp::ElemDiv | PrimitiveOp::ElemLeftDiv => {
2888            let (lhs, rhs) = binary(exprs)?;
2889            Some(format!("({lhs} / {rhs})"))
2890        }
2891        PrimitiveOp::Pow | PrimitiveOp::ElemPow => {
2892            let (lhs, rhs) = binary(exprs)?;
2893            Some(format!("pow({lhs}, {rhs})"))
2894        }
2895        PrimitiveOp::Neg => {
2896            let arg = exprs.get(inputs.first()?).cloned()?;
2897            Some(format!("(-{arg})"))
2898        }
2899        PrimitiveOp::UPlus => {
2900            let arg = exprs.get(inputs.first()?).cloned()?;
2901            Some(format!("(+{arg})"))
2902        }
2903        _ => None,
2904    }
2905}
2906
2907fn builtin_expr(
2908    name: &str,
2909    inputs: &[ValueId],
2910    exprs: &HashMap<ValueId, String>,
2911    scalar_ty: &str,
2912) -> Option<String> {
2913    let func = match name.to_ascii_lowercase().as_str() {
2914        "isfinite" => return builtin_unary_call("isFinite", inputs, exprs),
2915        "isinf" => return builtin_unary_call("isInf", inputs, exprs),
2916        "isnan" => return builtin_unary_call("isNan", inputs, exprs),
2917        "single" | "double" | "gpuarray" => return builtin_identity(inputs, exprs),
2918        "fix" => return builtin_unary_call("trunc", inputs, exprs),
2919        "sign" => return builtin_unary_call("sign", inputs, exprs),
2920        "mod" => {
2921            let lhs = exprs.get(inputs.first()?).cloned()?;
2922            let rhs = exprs.get(inputs.get(1)?).cloned()?;
2923            // When rhs is infinite and lhs is finite, MATLAB sign-corrects: returns lhs when
2924            // signs match, rhs (±Inf) when they differ. The general formula produces NaN here
2925            // (inf * 0 = NaN), so we must short-circuit.
2926            return Some(format!(
2927                "select(({lhs} - {rhs} * floor({lhs} / {rhs})), select({rhs}, {lhs}, ({lhs} == 0.0 || sign({lhs}) == sign({rhs}))), (isInf({rhs}) && isFinite({lhs})))"
2928            ));
2929        }
2930        "rem" => {
2931            let lhs = exprs.get(inputs.first()?).cloned()?;
2932            let rhs = exprs.get(inputs.get(1)?).cloned()?;
2933            return Some(format!(
2934                "select(({lhs} - {rhs} * trunc({lhs} / {rhs})), {lhs}, (isInf({rhs}) && isFinite({lhs})))"
2935            ));
2936        }
2937        "sin" => "sin",
2938        "cos" => "cos",
2939        "tan" => "tan",
2940        "asin" => "asin",
2941        "acos" => "acos",
2942        "atan" => "atan",
2943        "atan2" => return builtin_binary("atan2", inputs, exprs),
2944        "hypot" => return builtin_binary("hypot", inputs, exprs),
2945        "pow2" => {
2946            if inputs.len() == 1 {
2947                return builtin_unary_call("exp2", inputs, exprs);
2948            }
2949            return None;
2950        }
2951        "sinh" => "sinh",
2952        "cosh" => "cosh",
2953        "tanh" => "tanh",
2954        "exp" => "exp",
2955        "log" => "log",
2956        "log2" => "log2",
2957        "sqrt" => "sqrt",
2958        "abs" => "abs",
2959        "exp2" => "exp2",
2960        "floor" => "floor",
2961        "ceil" => "ceil",
2962        "round" => "round",
2963        "trunc" => "trunc",
2964        "asinh" => return builtin_unary_call("asinh", inputs, exprs),
2965        "acosh" => return builtin_unary_call("acosh", inputs, exprs),
2966        "atanh" => return builtin_unary_call("atanh", inputs, exprs),
2967        "max" => return builtin_binary("max", inputs, exprs),
2968        "min" => return builtin_binary("min", inputs, exprs),
2969        _ => {
2970            return match name.to_ascii_lowercase().as_str() {
2971                "log10" => {
2972                    let arg = exprs.get(inputs.first()?).cloned()?;
2973                    let constant = cast_literal(scalar_ty, "0.4342944819032518");
2974                    Some(format!("(log({arg}) * {constant})"))
2975                }
2976                "log1p" => {
2977                    let arg = exprs.get(inputs.first()?).cloned()?;
2978                    let one = cast_literal(scalar_ty, "1.0");
2979                    Some(format!("log({arg} + {one})"))
2980                }
2981                "expm1" => {
2982                    let arg = exprs.get(inputs.first()?).cloned()?;
2983                    let one = cast_literal(scalar_ty, "1.0");
2984                    Some(format!("(exp({arg}) - {one})"))
2985                }
2986                _ => None,
2987            }
2988        }
2989    };
2990    let arg = exprs.get(inputs.first()?).cloned()?;
2991    Some(format!("{func}({arg})"))
2992}
2993
2994fn builtin_binary(
2995    func: &str,
2996    inputs: &[ValueId],
2997    exprs: &HashMap<ValueId, String>,
2998) -> Option<String> {
2999    let lhs = exprs.get(inputs.first()?).cloned()?;
3000    let rhs = exprs.get(inputs.get(1)?).cloned()?;
3001    Some(format!("{func}({lhs}, {rhs})"))
3002}
3003
3004fn builtin_unary_call(
3005    func: &str,
3006    inputs: &[ValueId],
3007    exprs: &HashMap<ValueId, String>,
3008) -> Option<String> {
3009    let arg = exprs.get(inputs.first()?).cloned()?;
3010    Some(format!("{func}({arg})"))
3011}
3012
3013fn builtin_identity(inputs: &[ValueId], exprs: &HashMap<ValueId, String>) -> Option<String> {
3014    exprs.get(inputs.first()?).cloned()
3015}
3016
/// Formats a numeric literal for the shader's scalar type.
///
/// For `f64` the literal is wrapped in an explicit `f64(...)` conversion. Every other type
/// uses the literal verbatim.
fn cast_literal(scalar_ty: &str, literal: &str) -> String {
    match scalar_ty {
        "f64" => format!("f64({literal})"),
        _ => literal.to_owned(),
    }
}
3024
3025fn split_add_with_scalar(
3026    graph: &AccelGraph,
3027    vid: ValueId,
3028    assigned: &HashSet<NodeId>,
3029    nodes: &mut Vec<NodeId>,
3030) -> Option<(NodeId, ValueId, ImageScalar)> {
3031    let (node_id, node) = node_from_value(graph, vid)?;
3032    match node.label {
3033        AccelNodeLabel::Primitive(PrimitiveOp::Add) => {
3034            if node.inputs.len() != 2 {
3035                return None;
3036            }
3037            let lhs = node.inputs[0];
3038            let rhs = node.inputs[1];
3039            if let Some(scalar) = capture_image_scalar(graph, rhs, assigned, nodes) {
3040                return Some((node_id, lhs, scalar));
3041            }
3042            if let Some(scalar) = capture_image_scalar(graph, lhs, assigned, nodes) {
3043                return Some((node_id, rhs, scalar));
3044            }
3045            None
3046        }
3047        AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
3048            if node.inputs.len() != 2 {
3049                return None;
3050            }
3051            let lhs = node.inputs[0];
3052            let rhs = node.inputs[1];
3053            if let Some(ImageScalar::Constant(value)) =
3054                capture_image_scalar(graph, rhs, assigned, nodes)
3055            {
3056                return Some((node_id, lhs, ImageScalar::Constant(-value)));
3057            }
3058            None
3059        }
3060        _ => None,
3061    }
3062}
3063
3064fn split_mul_with_scalar(
3065    graph: &AccelGraph,
3066    vid: ValueId,
3067    assigned: &HashSet<NodeId>,
3068    nodes: &mut Vec<NodeId>,
3069) -> Option<(NodeId, ValueId, ImageScalar)> {
3070    let (node_id, node) = node_from_value(graph, vid)?;
3071    match node.label {
3072        AccelNodeLabel::Primitive(PrimitiveOp::Mul)
3073        | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul) => {
3074            if node.inputs.len() != 2 {
3075                return None;
3076            }
3077            let lhs = node.inputs[0];
3078            let rhs = node.inputs[1];
3079            if let Some(scalar) = capture_image_scalar(graph, rhs, assigned, nodes) {
3080                return Some((node_id, lhs, scalar));
3081            }
3082            if let Some(scalar) = capture_image_scalar(graph, lhs, assigned, nodes) {
3083                return Some((node_id, rhs, scalar));
3084            }
3085            None
3086        }
3087        _ => None,
3088    }
3089}
3090
3091fn split_max_with_zero_scalar(
3092    graph: &AccelGraph,
3093    vid: ValueId,
3094    assigned: &HashSet<NodeId>,
3095    nodes: &mut Vec<NodeId>,
3096) -> Option<(NodeId, ValueId)> {
3097    let (node_id, node) = node_from_value(graph, vid)?;
3098    match &node.label {
3099        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("max") => {
3100            if node.inputs.len() != 2 {
3101                if log::log_enabled!(log::Level::Debug) {
3102                    log::debug!(
3103                        "split_max_with_zero_scalar: node {node_id:?} has {} inputs",
3104                        node.inputs.len()
3105                    );
3106                }
3107                return None;
3108            }
3109            let lhs = node.inputs[0];
3110            let rhs = node.inputs[1];
3111            if let Some(ImageScalar::Constant(value)) =
3112                capture_image_scalar(graph, rhs, assigned, nodes)
3113            {
3114                if approx_eq(value, 0.0) {
3115                    if log::log_enabled!(log::Level::Debug) {
3116                        log::debug!(
3117                            "split_max_with_zero_scalar: rhs zero constant for node {node_id:?}"
3118                        );
3119                    }
3120                    return Some((node_id, lhs));
3121                }
3122            }
3123            if let Some(ImageScalar::Constant(value)) =
3124                capture_image_scalar(graph, lhs, assigned, nodes)
3125            {
3126                if approx_eq(value, 0.0) {
3127                    if log::log_enabled!(log::Level::Debug) {
3128                        log::debug!(
3129                            "split_max_with_zero_scalar: lhs zero constant for node {node_id:?}"
3130                        );
3131                    }
3132                    return Some((node_id, rhs));
3133                }
3134            }
3135            if log::log_enabled!(log::Level::Debug) {
3136                log::debug!(
3137                    "split_max_with_zero_scalar: node {node_id:?} inputs not zero constants"
3138                );
3139            }
3140            None
3141        }
3142        _ => None,
3143    }
3144}
3145
3146fn resolve_numeric_vector_constant(graph: &AccelGraph, vid: ValueId) -> Option<Vec<f64>> {
3147    if let Some(scalar) = resolve_scalar_constant(graph, vid) {
3148        return Some(vec![scalar]);
3149    }
3150    let info = graph.value(vid)?;
3151    match &info.constant {
3152        Some(Value::Tensor(tensor)) if !tensor.data.is_empty() => Some(tensor.data.clone()),
3153        Some(Value::LogicalArray(arr)) if !arr.data.is_empty() => Some(
3154            arr.data
3155                .iter()
3156                .map(|v| if *v == 0 { 0.0 } else { 1.0 })
3157                .collect(),
3158        ),
3159        Some(Value::Bool(flag)) => Some(vec![if *flag { 1.0 } else { 0.0 }]),
3160        Some(Value::Int(iv)) => Some(vec![iv.to_f64()]),
3161        Some(Value::Num(num)) => Some(vec![*num]),
3162        _ => None,
3163    }
3164}
3165
3166fn match_mean_axes(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId, Vec<f64>)> {
3167    let (node_id, node) = node_from_value(graph, vid)?;
3168    match &node.label {
3169        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mean") => {}
3170        _ => return None,
3171    }
3172    if node.inputs.len() < 2 {
3173        return None;
3174    }
3175    let data_vid = node.inputs[0];
3176    let dims_vid = node.inputs[1];
3177    let dims = resolve_numeric_vector_constant(graph, dims_vid)?;
3178    Some((node_id, data_vid, dims))
3179}
3180
/// Returns `true` when `found` and `expected` contain the same dimensions as multisets.
///
/// Each entry is rounded to the nearest integer before comparison, and order is ignored.
fn dims_match_unordered(found: &[f64], expected: &[f64]) -> bool {
    // Rounds every dim to an integer and sorts, giving a canonical form for comparison.
    fn canonical(dims: &[f64]) -> Vec<i64> {
        let mut rounded: Vec<i64> = dims.iter().map(|d| d.round() as i64).collect();
        rounded.sort_unstable();
        rounded
    }
    found.len() == expected.len() && canonical(found) == canonical(expected)
}
3191
3192fn peel_mean_dims(
3193    graph: &AccelGraph,
3194    vid: ValueId,
3195    expected_dims: &[f64],
3196    assigned: &HashSet<NodeId>,
3197    nodes: &mut Vec<NodeId>,
3198) -> Option<ValueId> {
3199    if expected_dims.is_empty() {
3200        return Some(vid);
3201    }
3202    let (node_id, data_vid, dims) = match_mean_axes(graph, vid)?;
3203    if assigned.contains(&node_id) {
3204        return None;
3205    }
3206    if dims.len() == expected_dims.len() && dims_match_unordered(&dims, expected_dims) {
3207        nodes.push(node_id);
3208        return Some(data_vid);
3209    }
3210    if dims.len() == 1 && approx_eq(dims[0], expected_dims[0]) {
3211        nodes.push(node_id);
3212        return peel_mean_dims(graph, data_vid, &expected_dims[1..], assigned, nodes);
3213    }
3214    None
3215}
3216
3217struct ImageNormalizeMatch {
3218    nodes: Vec<NodeId>,
3219    input: ValueId,
3220    epsilon: ImageScalar,
3221    gain: Option<ImageScalar>,
3222    bias: Option<ImageScalar>,
3223    gamma: Option<ImageScalar>,
3224}
3225
3226fn analyze_image_normalize(
3227    graph: &AccelGraph,
3228    pow_node_id: NodeId,
3229    assigned: &HashSet<NodeId>,
3230) -> Option<ImageNormalizeMatch> {
3231    let pow_node = graph.node(pow_node_id)?;
3232    if log::log_enabled!(log::Level::Trace) {
3233        log::trace!(
3234            "image_normalize: inspect pow candidate node={pow_node_id:?} label={:?}",
3235            pow_node.label
3236        );
3237    }
3238    macro_rules! img_norm_fail {
3239        ($reason:expr) => {{
3240            if log::log_enabled!(log::Level::Trace) {
3241                log::trace!(
3242                    "image_normalize: reject node {pow_node_id:?} reason={}",
3243                    $reason
3244                );
3245            }
3246            return None;
3247        }};
3248    }
3249    if !matches!(
3250        pow_node.label,
3251        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow)
3252    ) {
3253        img_norm_fail!("not elem pow");
3254    }
3255    if pow_node.inputs.len() != 2 || pow_node.outputs.len() != 1 {
3256        img_norm_fail!("unexpected pow arity");
3257    }
3258
3259    let mut nodes: Vec<NodeId> = vec![pow_node_id];
3260
3261    let gamma_scalar = capture_image_scalar(graph, pow_node.inputs[1], assigned, &mut nodes)?;
3262    if log::log_enabled!(log::Level::Trace) {
3263        log::trace!("image_normalize: node {pow_node_id:?} gamma scalar={gamma_scalar:?}");
3264    }
3265    let gamma_opt = match &gamma_scalar {
3266        ImageScalar::Constant(value) if approx_eq(*value, 1.0) => None,
3267        _ => Some(gamma_scalar),
3268    };
3269
3270    let (clamp_node_id, clamp_input_vid) =
3271        split_max_with_zero_scalar(graph, pow_node.inputs[0], assigned, &mut nodes)?;
3272    if assigned.contains(&clamp_node_id) {
3273        img_norm_fail!("clamp node already assigned");
3274    }
3275    nodes.push(clamp_node_id);
3276
3277    let pre_bias_vid = peel_numeric_casts(graph, clamp_input_vid, assigned, &mut nodes)?;
3278    let (pre_gain_vid, bias_opt) = if let Some((add_node_id, base_vid, bias_scalar)) =
3279        split_add_with_scalar(graph, pre_bias_vid, assigned, &mut nodes)
3280    {
3281        if assigned.contains(&add_node_id) {
3282            img_norm_fail!("bias add already assigned");
3283        }
3284        nodes.push(add_node_id);
3285        let bias = match &bias_scalar {
3286            ImageScalar::Constant(value) if approx_eq(*value, 0.0) => None,
3287            _ => Some(bias_scalar),
3288        };
3289        let base_vid = peel_numeric_casts(graph, base_vid, assigned, &mut nodes)?;
3290        (base_vid, bias)
3291    } else {
3292        (pre_bias_vid, None)
3293    };
3294
3295    let (mut norm_vid, gain_opt) = if let Some((mul_node_id, base_vid, gain_scalar)) =
3296        split_mul_with_scalar(graph, pre_gain_vid, assigned, &mut nodes)
3297    {
3298        if assigned.contains(&mul_node_id) {
3299            img_norm_fail!("gain mul already assigned");
3300        }
3301        nodes.push(mul_node_id);
3302        let gain = match &gain_scalar {
3303            ImageScalar::Constant(value) if approx_eq(*value, 1.0) => None,
3304            _ => Some(gain_scalar),
3305        };
3306        let base_vid = peel_numeric_casts(graph, base_vid, assigned, &mut nodes)?;
3307        (base_vid, gain)
3308    } else {
3309        (pre_gain_vid, None)
3310    };
3311
3312    norm_vid = peel_numeric_casts(graph, norm_vid, assigned, &mut nodes)?;
3313
3314    let (div_node_id, div_node) = node_from_value(graph, norm_vid)?;
3315    if assigned.contains(&div_node_id) {
3316        img_norm_fail!("div node already assigned");
3317    }
3318    match div_node.label {
3319        AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv) => {}
3320        _ => img_norm_fail!("not div primitive"),
3321    }
3322    if div_node.inputs.len() != 2 {
3323        img_norm_fail!("div arity");
3324    }
3325
3326    let diff_vid = div_node.inputs[0];
3327    let sigma_vid = peel_numeric_casts(graph, div_node.inputs[1], assigned, &mut nodes)?;
3328    let (sigma_node_id, sigma_input_vid) = match is_sqrt_node(graph, sigma_vid) {
3329        Some(pair) => pair,
3330        None => img_norm_fail!("sigma not sqrt"),
3331    };
3332    if assigned.contains(&sigma_node_id) {
3333        img_norm_fail!("sqrt node already assigned");
3334    }
3335    nodes.push(div_node_id);
3336    nodes.push(sigma_node_id);
3337
3338    let (add_node_id, mean_sq_vid, epsilon_scalar) =
3339        split_add_with_scalar(graph, sigma_input_vid, assigned, &mut nodes)?;
3340    if assigned.contains(&add_node_id) {
3341        img_norm_fail!("epsilon add already assigned");
3342    }
3343    nodes.push(add_node_id);
3344    let epsilon = epsilon_scalar;
3345    let mean_sq_vid = peel_numeric_casts(graph, mean_sq_vid, assigned, &mut nodes)?;
3346
3347    let squared_diff_vid = peel_mean_dims(graph, mean_sq_vid, &[3.0, 2.0], assigned, &mut nodes)?;
3348
3349    let (square_pow_node_id, square_pow_node) = node_from_value(graph, squared_diff_vid)?;
3350    if assigned.contains(&square_pow_node_id) {
3351        img_norm_fail!("square pow already assigned");
3352    }
3353    if !matches!(
3354        square_pow_node.label,
3355        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow)
3356    ) {
3357        img_norm_fail!("variance pow not elem pow");
3358    }
3359    if square_pow_node.inputs.len() != 2 {
3360        img_norm_fail!("variance pow arity");
3361    }
3362    let exponent_trace = collect_scalar_constant(graph, square_pow_node.inputs[1])?;
3363    if !approx_eq(exponent_trace.value, 2.0) {
3364        img_norm_fail!("variance exponent != 2");
3365    }
3366    if exponent_trace.nodes.iter().any(|id| assigned.contains(id)) {
3367        img_norm_fail!("variance exponent nodes already assigned");
3368    }
3369    nodes.push(square_pow_node_id);
3370    nodes.extend(exponent_trace.nodes.iter().copied());
3371
3372    let diff_var_vid = square_pow_node.inputs[0];
3373    let (diff_var_node_id, diff_var_node) = node_from_value(graph, diff_var_vid)?;
3374    if assigned.contains(&diff_var_node_id) {
3375        img_norm_fail!("diff variance node already assigned");
3376    }
3377    if !matches!(
3378        diff_var_node.label,
3379        AccelNodeLabel::Primitive(PrimitiveOp::Sub)
3380    ) {
3381        img_norm_fail!("diff variance node not sub");
3382    }
3383    if diff_var_node.inputs.len() != 2 {
3384        img_norm_fail!("diff variance arity");
3385    }
3386    let imgs_vid = diff_var_node.inputs[0];
3387    let mu_vid = peel_numeric_casts(graph, diff_var_node.inputs[1], assigned, &mut nodes)?;
3388    nodes.push(diff_var_node_id);
3389
3390    let (diff_node_id, diff_node) = node_from_value(graph, diff_vid)?;
3391    if assigned.contains(&diff_node_id) {
3392        img_norm_fail!("diff node already assigned");
3393    }
3394    if !matches!(diff_node.label, AccelNodeLabel::Primitive(PrimitiveOp::Sub)) {
3395        img_norm_fail!("diff node not sub");
3396    }
3397    if diff_node.inputs.len() != 2 {
3398        img_norm_fail!("diff node arity");
3399    }
3400    let diff_mu_vid = peel_numeric_casts(graph, diff_node.inputs[1], assigned, &mut nodes)?;
3401    if diff_node.inputs[0] != imgs_vid || diff_mu_vid != mu_vid {
3402        img_norm_fail!("diff inputs mismatch with variance pair");
3403    }
3404    nodes.push(diff_node_id);
3405
3406    let mean_mu_input_vid = peel_mean_dims(graph, mu_vid, &[3.0, 2.0], assigned, &mut nodes)?;
3407    if mean_mu_input_vid != imgs_vid {
3408        img_norm_fail!("mean mu input mismatch");
3409    }
3410
3411    let input_info = graph.value(imgs_vid)?;
3412    match &input_info.shape {
3413        ShapeInfo::Tensor(dims) if dims.len() >= 2 => {}
3414        ShapeInfo::Unknown => {}
3415        other => {
3416            if log::log_enabled!(log::Level::Debug) {
3417                log::debug!(
3418                    "image_normalize: node {pow_node_id:?} input shape {:?}",
3419                    other
3420                );
3421            }
3422            img_norm_fail!("input not 3-d tensor");
3423        }
3424    }
3425
3426    nodes.sort_unstable();
3427    nodes.dedup();
3428
3429    Some(ImageNormalizeMatch {
3430        nodes,
3431        input: imgs_vid,
3432        epsilon,
3433        gain: gain_opt,
3434        bias: bias_opt,
3435        gamma: gamma_opt,
3436    })
3437}
3438
// Unit tests for fusion-group detection, fusion-plan preparation, WGSL generation and
// builtin expression lowering.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{
        AccelGraph, AccelGraphTag, AccelNode, AccelNodeLabel, AccelOpCategory, InstrSpan,
        PrimitiveOp, ValueId, ValueInfo, ValueOrigin, VarKind,
    };
    use runmat_builtins::{Type, Value};
    use std::collections::HashMap as StdHashMap;

    // Builds a two-node graph computing `(x .* x) .* x` for a 4x4 global tensor `x`
    // (value 0): node 0 (span 10) produces value 1, node 1 (span 11) produces value 2.
    fn simple_elementwise_graph() -> AccelGraph {
        let values = vec![
            // Value 0: input tensor
            ValueInfo {
                id: 0,
                origin: ValueOrigin::Variable {
                    kind: VarKind::Global,
                    index: 0,
                },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
            // Node 0 output value (value id 1)
            ValueInfo {
                id: 1,
                origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
            // Node 1 output value (value id 2)
            ValueInfo {
                id: 2,
                origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
        ];

        // v1 = v0 .* v0
        let node0 = AccelNode {
            id: 0,
            label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
            category: AccelOpCategory::Elementwise,
            inputs: vec![0, 0],
            outputs: vec![1],
            span: InstrSpan { start: 10, end: 10 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        // v2 = v1 .* v0
        let node1 = AccelNode {
            id: 1,
            label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
            category: AccelOpCategory::Elementwise,
            inputs: vec![1, 0],
            outputs: vec![2],
            span: InstrSpan { start: 11, end: 11 },
            tags: vec![AccelGraphTag::Elementwise],
        };

        AccelGraph {
            nodes: vec![node0, node1],
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        }
    }

    // The two back-to-back ElemMul nodes should form a single ElementwiseChain group.
    #[test]
    fn detects_chain() {
        let graph = simple_elementwise_graph();
        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);
        let group = &groups[0];
        assert_eq!(group.nodes, vec![0, 1]);
        assert_eq!(group.kind, FusionKind::ElementwiseChain);
    }

    // With the third argument set to 0 (no semantic candidate evidence, per the assertion
    // message) no executable plan is produced.
    #[test]
    fn prepare_fusion_plan_requires_semantic_candidate_groups() {
        let graph = simple_elementwise_graph();
        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);

        let plan = prepare_fusion_plan(Some(&graph), &groups, 0);
        assert!(
            plan.is_none(),
            "bytecode groups alone should not produce an executable fusion plan without semantic candidate evidence"
        );
    }

    // The same groups with the third argument set to 1 do produce a plan.
    #[test]
    fn prepare_fusion_plan_allows_semantic_gated_groups() {
        let graph = simple_elementwise_graph();
        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);

        let plan = prepare_fusion_plan(Some(&graph), &groups, 1);
        assert!(
            plan.is_some(),
            "semantic candidate evidence should allow executable fusion plan preparation"
        );
    }

    // A group with no nodes whose span (10..10) matches node 0's span gets node 0 recovered.
    #[test]
    fn prepare_fusion_plan_recovers_empty_group_nodes_from_contained_runtime_span() {
        let graph = simple_elementwise_graph();
        let groups = vec![FusionGroup {
            id: 0,
            kind: FusionKind::ElementwiseChain,
            nodes: Vec::new(),
            shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
            span: InstrSpan { start: 10, end: 10 },
            pattern: None,
            stack_layout: None,
        }];

        let plan = prepare_fusion_plan(Some(&graph), &groups, 1)
            .expect("runtime group sanitization should recover contained elementwise nodes");
        assert_eq!(plan.groups.len(), 1);
        assert_eq!(
            plan.groups[0].group.nodes,
            vec![0],
            "runtime sanitization should recover a compatible contained node for empty group mapping"
        );
    }

    // A group with no nodes whose span (20..20) is away from every runtime node is rejected.
    #[test]
    fn prepare_fusion_plan_rejects_empty_group_nodes_when_runtime_graph_is_too_far() {
        let graph = simple_elementwise_graph();
        let groups = vec![FusionGroup {
            id: 0,
            kind: FusionKind::ElementwiseChain,
            nodes: Vec::new(),
            shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
            span: InstrSpan { start: 20, end: 20 },
            pattern: None,
            stack_layout: None,
        }];

        let plan = prepare_fusion_plan(Some(&graph), &groups, 1);
        assert!(
            plan.is_none(),
            "runtime sanitization should reject empty group mapping when no compatible nearby nodes exist"
        );
    }

    // A runtime node spanning 10..12 covers, but is wider than, the group's span 11..11;
    // such a covering node must not be recovered into the empty group.
    #[test]
    fn prepare_fusion_plan_rejects_empty_group_nodes_when_runtime_node_covers_group_span() {
        let values = vec![
            ValueInfo {
                id: 0,
                origin: ValueOrigin::Variable {
                    kind: VarKind::Global,
                    index: 0,
                },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
            ValueInfo {
                id: 1,
                origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
        ];
        let graph = AccelGraph {
            nodes: vec![AccelNode {
                id: 0,
                label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
                category: AccelOpCategory::Elementwise,
                inputs: vec![0, 0],
                outputs: vec![1],
                span: InstrSpan { start: 10, end: 12 },
                tags: vec![AccelGraphTag::Elementwise],
            }],
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        };
        let groups = vec![FusionGroup {
            id: 0,
            kind: FusionKind::ElementwiseChain,
            nodes: Vec::new(),
            shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
            span: InstrSpan { start: 11, end: 11 },
            pattern: None,
            stack_layout: None,
        }];

        let plan = prepare_fusion_plan(Some(&graph), &groups, 1);
        assert!(
            plan.is_none(),
            "runtime sanitization should reject covering runtime-node spans when semantic group spans are narrower"
        );
    }

    // The group claims node 1 (span 11) while its own span is 10..10; the mismatched node
    // id must be rejected rather than silently remapped.
    #[test]
    fn prepare_fusion_plan_rejects_stale_mapped_nodes_without_runtime_remap() {
        let graph = simple_elementwise_graph();
        let groups = vec![FusionGroup {
            id: 0,
            kind: FusionKind::ElementwiseChain,
            nodes: vec![1],
            shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
            span: InstrSpan { start: 10, end: 10 },
            pattern: None,
            stack_layout: None,
        }];

        let plan = prepare_fusion_plan(Some(&graph), &groups, 1);
        assert!(
            plan.is_none(),
            "runtime sanitization should reject stale mapped nodes instead of remapping from runtime graph scan"
        );
    }

    // A detected chain yields a supported kernel whose generated WGSL has a compute entry
    // point, and the group reports a known element count.
    #[test]
    fn builds_plan_and_template() {
        let graph = simple_elementwise_graph();
        let groups = detect_fusion_groups(&graph);
        let plan = FusionPlan::from_graph(&graph, &groups);
        assert_eq!(plan.groups.len(), 1);
        let group_plan = &plan.groups[0];
        assert!(group_plan.kernel.supported);
        let wgsl = group_plan.generate_wgsl("f32").expect("wgsl");
        assert!(wgsl.contains("@compute"));
        assert!(group_plan.group.element_count().is_some());
    }

    // Constant value 1 feeds both Add nodes. It should be recorded once as a constant
    // (in both `constants` and `const_values`) and should not appear in the stack pattern.
    #[test]
    fn stack_pattern_tracks_repeated_constants() {
        let values = vec![
            ValueInfo {
                id: 0,
                origin: ValueOrigin::Variable {
                    kind: VarKind::Global,
                    index: 0,
                },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: None,
            },
            ValueInfo {
                id: 1,
                origin: ValueOrigin::Constant,
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: Some(Value::Num(1.0)),
            },
            ValueInfo {
                id: 2,
                origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: None,
            },
            ValueInfo {
                id: 3,
                origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: None,
            },
        ];

        // v2 = v0 + v1
        let node0 = AccelNode {
            id: 0,
            label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
            category: AccelOpCategory::Elementwise,
            inputs: vec![0, 1],
            outputs: vec![2],
            span: InstrSpan { start: 5, end: 5 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        // v3 = v2 + v1 (reuses the same constant)
        let node1 = AccelNode {
            id: 1,
            label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
            category: AccelOpCategory::Elementwise,
            inputs: vec![2, 1],
            outputs: vec![3],
            span: InstrSpan { start: 6, end: 6 },
            tags: vec![AccelGraphTag::Elementwise],
        };

        let graph = AccelGraph {
            nodes: vec![node0, node1],
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        };

        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);
        let plan = FusionPlan::from_graph(&graph, &groups);
        let group_plan = &plan.groups[0];
        assert_eq!(group_plan.inputs.len(), 2);
        assert!(group_plan.stack_pattern.is_empty());
        assert!(group_plan.constants.contains_key(&1));
        assert!(group_plan.const_values.contains_key(&1));
    }

    // Spot-checks the expression strings `builtin_expr` emits for a range of builtins.
    // Operands are bound to the names "v0" and "v1".
    #[test]
    fn builtin_expr_supports_extended_set() {
        let mut exprs: StdHashMap<ValueId, String> = StdHashMap::new();
        exprs.insert(0, "v0".to_string());
        exprs.insert(1, "v1".to_string());

        let log1p = super::builtin_expr("log1p", &[0], &exprs, "f32");
        assert!(log1p.is_some());

        let log10 = super::builtin_expr("log10", &[0], &exprs, "f64");
        assert!(log10.unwrap().contains("log"));

        let expm1 = super::builtin_expr("expm1", &[0], &exprs, "f32");
        assert!(expm1.unwrap().contains("exp"));

        let floor = super::builtin_expr("floor", &[0], &exprs, "f32");
        assert_eq!(floor.unwrap(), "floor(v0)");

        let atan2 = super::builtin_expr("atan2", &[0, 1], &exprs, "f32");
        assert_eq!(atan2.unwrap(), "atan2(v0, v1)");

        let asinh = super::builtin_expr("asinh", &[0], &exprs, "f32");
        assert_eq!(asinh.unwrap(), "asinh(v0)");

        let acosh = super::builtin_expr("acosh", &[0], &exprs, "f32");
        assert_eq!(acosh.unwrap(), "acosh(v0)");

        let atanh = super::builtin_expr("atanh", &[0], &exprs, "f32");
        assert_eq!(atanh.unwrap(), "atanh(v0)");

        let hypot = super::builtin_expr("hypot", &[0, 1], &exprs, "f32");
        assert_eq!(hypot.unwrap(), "hypot(v0, v1)");

        let sign = super::builtin_expr("sign", &[0], &exprs, "f32");
        assert_eq!(sign.unwrap(), "sign(v0)");

        let fix = super::builtin_expr("fix", &[0], &exprs, "f32");
        assert_eq!(fix.unwrap(), "trunc(v0)");

        let modulo = super::builtin_expr("mod", &[0, 1], &exprs, "f32");
        let modulo = modulo.unwrap();
        assert!(modulo.contains("floor"));
        assert!(modulo.contains("isInf"));

        let rem = super::builtin_expr("rem", &[0, 1], &exprs, "f32");
        let rem = rem.unwrap();
        assert!(rem.contains("trunc"));
        assert!(rem.contains("isInf"));

        let pow2 = super::builtin_expr("pow2", &[0], &exprs, "f32");
        assert_eq!(pow2.unwrap(), "exp2(v0)");

        // Precision casts lower to the bare operand.
        let single = super::builtin_expr("single", &[0], &exprs, "f32");
        assert_eq!(single.unwrap(), "v0");

        let double = super::builtin_expr("double", &[0], &exprs, "f64");
        assert_eq!(double.unwrap(), "v0");
    }

    // The input x fans out to tanh(x) and to single(0.1) .* x. The two branches are then
    // added. All four nodes should land in one supported group whose shader mentions `tanh`
    // and writes to `output.data`.
    #[test]
    fn fanout_chain_with_casts_supported() {
        let values = vec![
            // Base input tensor
            ValueInfo {
                id: 0,
                origin: ValueOrigin::Variable {
                    kind: VarKind::Global,
                    index: 0,
                },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
            // tanh(x) output
            ValueInfo {
                id: 1,
                origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
            // constant scale before casting
            ValueInfo {
                id: 2,
                origin: ValueOrigin::Constant,
                ty: Type::Num,
                shape: ShapeInfo::Scalar,
                constant: Some(Value::Num(0.1)),
            },
            // single(0.1) output
            ValueInfo {
                id: 3,
                origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
                ty: Type::Num,
                shape: ShapeInfo::Scalar,
                constant: None,
            },
            // scaled branch output
            ValueInfo {
                id: 4,
                origin: ValueOrigin::NodeOutput { node: 2, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
            // final add output
            ValueInfo {
                id: 5,
                origin: ValueOrigin::NodeOutput { node: 3, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
        ];

        let tanh_node = AccelNode {
            id: 0,
            label: AccelNodeLabel::Builtin {
                name: "tanh".to_string(),
            },
            category: AccelOpCategory::Elementwise,
            inputs: vec![0],
            outputs: vec![1],
            span: InstrSpan { start: 10, end: 10 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let single_node = AccelNode {
            id: 1,
            label: AccelNodeLabel::Builtin {
                name: "single".to_string(),
            },
            category: AccelOpCategory::Elementwise,
            inputs: vec![2],
            outputs: vec![3],
            span: InstrSpan { start: 11, end: 11 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let mul_node = AccelNode {
            id: 2,
            label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
            category: AccelOpCategory::Elementwise,
            inputs: vec![3, 0],
            outputs: vec![4],
            span: InstrSpan { start: 12, end: 12 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let add_node = AccelNode {
            id: 3,
            label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
            category: AccelOpCategory::Elementwise,
            inputs: vec![1, 4],
            outputs: vec![5],
            span: InstrSpan { start: 13, end: 13 },
            tags: vec![AccelGraphTag::Elementwise],
        };

        let graph = AccelGraph {
            nodes: vec![tanh_node, single_node, mul_node, add_node],
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        };

        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);

        let plan = FusionPlan::from_graph(&graph, &groups);
        let group_plan = &plan.groups[0];
        assert!(group_plan.kernel.supported);
        let shader = group_plan.generate_wgsl("f32");
        assert!(shader
            .as_ref()
            .map(|wgsl| wgsl.contains("tanh") && wgsl.contains("output.data"))
            .unwrap_or(false));
    }
}