use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, OnceLock, RwLock, Weak};

use once_cell::sync::Lazy;
use runmat_accelerate_api::{CovNormalization, ReductionFlavor};
use runmat_builtins::Value;
use serde::{Deserialize, Serialize};

use crate::graph::{
    AccelGraph, AccelNode, AccelNodeLabel, AccelOpCategory, InstrSpan, NodeId, PrimitiveOp,
    ShapeInfo, ValueId, ValueInfo, ValueOrigin,
};
use crate::reduction_meta::{detect_reduction_signature, ReductionAxes, ReductionBehavior};

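/// Classification of a fusible subgraph. `ElementwiseChain`, `Reduction`, and
/// `MatmulEpilogue` are produced by the generic sweeps in
/// `detect_fusion_groups`; the remaining variants come from the dedicated
/// pattern detectors below.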
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum FusionKind {
    ElementwiseChain,
    Reduction,
    MatmulEpilogue,
    CenteredGram,
    ImageNormalize,
    PowerStepNormalize,
    ExplainedVariance,
}

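/// A set of graph nodes scheduled to execute as a single fused kernel,
/// together with the unified output shape and the instruction span covered.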
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionGroup {
    pub id: usize,
    pub kind: FusionKind,
    pub nodes: Vec<NodeId>,
    pub shape: ShapeInfo,
    pub span: InstrSpan,
    pub pattern: Option<FusionPattern>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionPattern {
    CenteredGram {
        matrix: ValueId,
        normalization: CovNormalization,
    },
    ImageNormalize(ImageNormalizePattern),
    PowerStepNormalize {
        lhs: ValueId,
        rhs: ValueId,
        epsilon: f64,
    },
    ExplainedVariance {
        q: ValueId,
        g: ValueId,
    },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageNormalizePattern {
    pub input: ValueId,
    pub epsilon: ImageScalar,
    pub gain: Option<ImageScalar>,
    pub bias: Option<ImageScalar>,
    pub gamma: Option<ImageScalar>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImageScalar {
    Constant(f64),
    Value(ValueId),
}

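/// Detects fusible subgraphs in `graph` and returns them as `FusionGroup`s.
///
/// The dedicated pattern detectors (image normalize, explained variance,
/// power-step normalize, centered Gram) run first and claim their nodes;
/// generic elementwise-chain, reduction, and matmul-epilogue sweeps then
/// cover the remainder. A minimal usage sketch (constructing the
/// `AccelGraph` is assumed, not shown):
///
/// ```ignore
/// let groups = detect_fusion_groups(&graph);
/// for group in &groups {
///     println!("group {} kind={:?} nodes={:?}", group.id, group.kind, group.nodes);
/// }
/// ```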
pub fn detect_fusion_groups(graph: &AccelGraph) -> Vec<FusionGroup> {
    if graph.nodes.is_empty() {
        return Vec::new();
    }

    let consumer_map = build_consumer_map(graph);
    let mut assigned: HashSet<NodeId> = HashSet::new();
    let mut groups = Vec::new();
    let mut group_id = 0usize;

    detect_image_normalize(graph, &mut assigned, &mut groups, &mut group_id);
    detect_explained_variance(graph, &mut assigned, &mut groups, &mut group_id);
    detect_power_step_normalize(graph, &mut assigned, &mut groups, &mut group_id);
    detect_centered_gram(graph, &mut assigned, &mut groups, &mut group_id);

    for node in &graph.nodes {
        if assigned.contains(&node.id) {
            continue;
        }
        let elementwise_like = node.is_elementwise() || is_elementwise_max_min(graph, node);
        if !elementwise_like {
            continue;
        }
        if node.outputs.is_empty() {
            continue;
        }
        let mut current_shape = node_output_shape(graph, node);
        if matches!(current_shape, ShapeInfo::Unknown) {
            continue;
        }
        let mut chain: Vec<NodeId> = Vec::new();
        let mut frontier = node.id;
        let mut local_seen: HashSet<NodeId> = HashSet::new();

        loop {
            if !local_seen.insert(frontier) {
                break;
            }
            chain.push(frontier);
            let next = find_next_elementwise(
                graph,
                frontier,
                &assigned,
                &local_seen,
                &consumer_map,
                &current_shape,
            );
            match next {
                Some((next_id, next_shape)) => {
                    frontier = next_id;
                    current_shape = next_shape;
                }
                None => break,
            }
        }

        if chain.len() > 1 {
            expand_group_with_fanout(graph, &mut chain, &assigned, &consumer_map);
            chain.sort_unstable_by_key(|id| {
                graph
                    .node(*id)
                    .map(|node| node.span.start)
                    .unwrap_or_default()
            });
            chain.dedup();
            for id in &chain {
                assigned.insert(*id);
            }
            let span = group_span(graph, &chain);
            groups.push(FusionGroup {
                id: group_id,
                kind: FusionKind::ElementwiseChain,
                nodes: chain,
                shape: current_shape.clone(),
                span,
                pattern: None,
            });
            group_id += 1;
        }
    }

    for node in &graph.nodes {
        if assigned.contains(&node.id) {
            continue;
        }
        if !node.is_reduction() || is_elementwise_max_min(graph, node) {
            continue;
        }
        let span = InstrSpan {
            start: node.span.start,
            end: node.span.end,
        };
        groups.push(FusionGroup {
            id: group_id,
            kind: FusionKind::Reduction,
            nodes: vec![node.id],
            shape: node_output_shape(graph, node),
            span,
            pattern: None,
        });
        group_id += 1;
    }

    for node in &graph.nodes {
        if node.category != AccelOpCategory::MatMul || assigned.contains(&node.id) {
            continue;
        }
        if node.outputs.is_empty() {
            continue;
        }
        let mut chain: Vec<NodeId> = vec![node.id];
        let mut frontier = node.id;
        let mut ok = false;
        loop {
            // Follow the unique consumer of the frontier's outputs; any
            // fan-out ends the epilogue chain.
            let mut next_id_opt: Option<NodeId> = None;
            for &out in &graph.node(frontier).unwrap().outputs {
                if let Some(cons) = consumer_map.get(&out) {
                    if cons.len() == 1 {
                        next_id_opt = cons.iter().copied().next();
                    } else {
                        next_id_opt = None;
                    }
                }
            }
            let Some(next_id) = next_id_opt else { break };
            let next = graph.node(next_id).unwrap();
            if !next.is_elementwise() {
                break;
            }
            let allowed = matches!(
                next.label,
                AccelNodeLabel::Primitive(PrimitiveOp::Add)
                    | AccelNodeLabel::Primitive(PrimitiveOp::Sub)
                    | AccelNodeLabel::Primitive(PrimitiveOp::Mul)
                    | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul)
                    | AccelNodeLabel::Primitive(PrimitiveOp::Div)
                    | AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv)
            );
            if !allowed {
                break;
            }
            chain.push(next_id);
            frontier = next_id;
            ok = true;
        }
        if ok {
            for id in &chain {
                assigned.insert(*id);
            }
            let span = group_span(graph, &chain);
            groups.push(FusionGroup {
                id: group_id,
                kind: FusionKind::MatmulEpilogue,
                nodes: chain,
                shape: node_output_shape(graph, node),
                span,
                pattern: None,
            });
            group_id += 1;
        }
    }

    merge_downstream_fanout(graph, &mut groups, &consumer_map);
    groups
}

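/// Grows an elementwise chain by absorbing fan-out nodes whose consumers all
/// lie inside the group and whose external producers start before the group's
/// first instruction, iterating until a fixed point is reached.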
fn expand_group_with_fanout(
    graph: &AccelGraph,
    chain: &mut Vec<NodeId>,
    assigned: &HashSet<NodeId>,
    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
) {
    let base_start = chain
        .iter()
        .filter_map(|id| graph.node(*id).map(|node| node.span.start))
        .min()
        .unwrap_or(0);
    let mut node_set: HashSet<NodeId> = chain.iter().copied().collect();
    let mut changed = true;
    while changed {
        changed = false;
        for node in &graph.nodes {
            if node_set.contains(&node.id) {
                continue;
            }
            if node.span.start < base_start {
                continue;
            }
            if assigned.contains(&node.id) {
                continue;
            }
            if !(node.is_elementwise() || is_elementwise_max_min(graph, node)) {
                continue;
            }
            if node.outputs.is_empty() {
                continue;
            }
            let mut feeds_group = false;
            let mut all_consumers_ok = true;
            for &out in &node.outputs {
                if let Some(consumers) = consumer_map.get(&out) {
                    let mut consumer_in_group = false;
                    for consumer in consumers {
                        if node_set.contains(consumer) {
                            consumer_in_group = true;
                        } else {
                            all_consumers_ok = false;
                            break;
                        }
                    }
                    if !all_consumers_ok {
                        break;
                    }
                    if consumer_in_group {
                        feeds_group = true;
                    }
                } else {
                    all_consumers_ok = false;
                    break;
                }
            }
            if !feeds_group || !all_consumers_ok {
                continue;
            }
            let mut inputs_ok = true;
            for &input in &node.inputs {
                if let Some(info) = graph.value(input) {
                    if let ValueOrigin::NodeOutput { node: producer, .. } = info.origin {
                        if !node_set.contains(&producer) {
                            if let Some(prod_node) = graph.node(producer) {
                                if prod_node.span.start >= base_start {
                                    inputs_ok = false;
                                    break;
                                }
                            } else {
                                inputs_ok = false;
                                break;
                            }
                        }
                    }
                }
            }
            if inputs_ok {
                node_set.insert(node.id);
                chain.push(node.id);
                changed = true;
            }
        }
    }
}

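/// Maps every node-produced value to the set of nodes that consume it;
/// values originating from variables or constants are deliberately skipped.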
fn build_consumer_map(graph: &AccelGraph) -> HashMap<ValueId, HashSet<NodeId>> {
    let mut map: HashMap<ValueId, HashSet<NodeId>> = HashMap::new();
    for node in &graph.nodes {
        for &input in &node.inputs {
            if let Some(value) = graph.value(input) {
                if matches!(value.origin, crate::graph::ValueOrigin::NodeOutput { .. }) {
                    map.entry(input).or_default().insert(node.id);
                }
            }
        }
    }
    map
}

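/// Repeatedly merges an elementwise source group into the elementwise group
/// that consumes its outputs, provided every consumer of the source lives in
/// one of the two groups; merged sources are emptied and then dropped.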
fn merge_downstream_fanout(
    graph: &AccelGraph,
    groups: &mut Vec<FusionGroup>,
    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
) {
    let mut changed = true;
    while changed {
        changed = false;
        let mut node_group: HashMap<NodeId, usize> = HashMap::new();
        for (idx, group) in groups.iter().enumerate() {
            if group.kind.is_elementwise() {
                for &node in &group.nodes {
                    node_group.insert(node, idx);
                }
            }
        }
        'outer: for target_idx in 0..groups.len() {
            if !groups[target_idx].kind.is_elementwise() {
                continue;
            }
            let base_start = groups[target_idx].span.start;
            let mut merge_indices: Vec<usize> = Vec::new();
            for &node_id in &groups[target_idx].nodes {
                let Some(node) = graph.node(node_id) else {
                    continue;
                };
                for &input in &node.inputs {
                    if let Some(info) = graph.value(input) {
                        if let ValueOrigin::NodeOutput { node: producer, .. } = info.origin {
                            if let Some(&source_idx) = node_group.get(&producer) {
                                if source_idx == target_idx {
                                    continue;
                                }
                                let source_group = &groups[source_idx];
                                if !source_group.kind.is_elementwise() {
                                    continue;
                                }
                                if source_group.span.start < base_start {
                                    continue;
                                }
                                if !group_consumers_subset(
                                    source_group,
                                    target_idx,
                                    groups,
                                    consumer_map,
                                    graph,
                                ) {
                                    continue;
                                }
                                merge_indices.push(source_idx);
                            }
                        }
                    }
                }
            }
            if merge_indices.is_empty() {
                continue;
            }
            merge_indices.sort_unstable();
            merge_indices.dedup();
            for idx in &merge_indices {
                let nodes = groups[*idx].nodes.clone();
                groups[target_idx].nodes.extend(nodes);
                groups[*idx].nodes.clear();
            }
            groups[target_idx]
                .nodes
                .sort_unstable_by_key(|id| graph.node(*id).map(|n| n.span.start).unwrap_or(0));
            groups[target_idx].nodes.dedup();
            groups[target_idx].span = group_span(graph, &groups[target_idx].nodes);
            changed = true;
            break 'outer;
        }
        if changed {
            groups.retain(|group| !group.nodes.is_empty());
        }
    }
}

fn group_consumers_subset(
    source_group: &FusionGroup,
    target_idx: usize,
    groups: &[FusionGroup],
    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
    graph: &AccelGraph,
) -> bool {
    let target_nodes: HashSet<NodeId> = groups[target_idx].nodes.iter().copied().collect();
    let source_nodes: HashSet<NodeId> = source_group.nodes.iter().copied().collect();
    for &node_id in &source_group.nodes {
        let Some(node) = graph.node(node_id) else {
            continue;
        };
        for &out in &node.outputs {
            if let Some(consumers) = consumer_map.get(&out) {
                for consumer in consumers {
                    if !source_nodes.contains(consumer) && !target_nodes.contains(consumer) {
                        return false;
                    }
                }
            }
        }
    }
    true
}

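/// Unifies the shapes of all outputs of `node`, starting from `Scalar`.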
fn node_output_shape(graph: &AccelGraph, node: &AccelNode) -> ShapeInfo {
    let mut shape = ShapeInfo::Scalar;
    for &output in &node.outputs {
        if let Some(info) = graph.value(output) {
            shape = shape.unify(&info.shape);
        }
    }
    shape
}

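/// Walks from `node_id` to its unique downstream elementwise consumer, if
/// any, returning that consumer plus its shape unified with `current_shape`.
/// Fan-out, an already-claimed consumer, or an unknown shape ends the chain.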
fn find_next_elementwise(
    graph: &AccelGraph,
    node_id: NodeId,
    assigned: &HashSet<NodeId>,
    local_seen: &HashSet<NodeId>,
    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
    current_shape: &ShapeInfo,
) -> Option<(NodeId, ShapeInfo)> {
    let node = graph.node(node_id)?;
    let mut candidate: Option<(NodeId, ShapeInfo)> = None;

    for &output in &node.outputs {
        let consumers = consumer_map.get(&output)?;
        if consumers.len() != 1 {
            return None;
        }
        let next_id = *consumers.iter().next()?;
        if next_id <= node_id || assigned.contains(&next_id) || local_seen.contains(&next_id) {
            return None;
        }
        let next_node = graph.node(next_id)?;
        if !(next_node.is_elementwise() || is_elementwise_max_min(graph, next_node)) {
            return None;
        }
        if !next_node.inputs.contains(&output) {
            continue;
        }
        let next_shape = node_output_shape(graph, next_node);
        if matches!(next_shape, ShapeInfo::Unknown) {
            return None;
        }
        let unified = current_shape.unify(&next_shape);
        if matches!(unified, ShapeInfo::Unknown) {
            return None;
        }
        candidate = Some((next_id, unified));
        break;
    }

    candidate
}

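/// `max`/`min` builtins act elementwise when given two real operands; with an
/// empty placeholder second argument they are reductions and excluded here.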
fn is_elementwise_max_min(graph: &AccelGraph, node: &AccelNode) -> bool {
    match &node.label {
        AccelNodeLabel::Builtin { name }
            if name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min") =>
        {
            if node.inputs.len() < 2 {
                return false;
            }
            !value_is_placeholder(graph, node.inputs[1])
        }
        _ => false,
    }
}

fn value_is_placeholder(graph: &AccelGraph, vid: ValueId) -> bool {
    let Some(info) = graph.value(vid) else {
        return false;
    };
    let Some(constant) = &info.constant else {
        return false;
    };
    match constant {
        Value::Tensor(t) => t.data.is_empty(),
        Value::LogicalArray(l) => l.data.is_empty(),
        Value::StringArray(sa) => sa.data.is_empty(),
        Value::CharArray(ca) => ca.data.is_empty(),
        Value::Cell(cell) => cell.data.is_empty(),
        Value::String(s) => s.is_empty(),
        _ => false,
    }
}

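/// Smallest instruction span covering every node in `nodes`.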
fn group_span(graph: &AccelGraph, nodes: &[NodeId]) -> InstrSpan {
    let mut start = usize::MAX;
    let mut end = 0usize;
    for &id in nodes {
        if let Some(node) = graph.node(id) {
            start = start.min(node.span.start);
            end = end.max(node.span.end);
        }
    }
    if start == usize::MAX {
        start = 0;
    }
    InstrSpan { start, end }
}

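/// Executable lowering of the detected `FusionGroup`s: per-group operation
/// lists, input ordering, captured constants, and kernel-support flags.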
#[derive(Debug, Clone)]
pub struct FusionPlan {
    pub groups: Vec<FusionGroupPlan>,
}

#[derive(Debug, Clone)]
pub struct FusionGroupPlan {
    pub index: usize,
    pub group: FusionGroup,
    pub operations: Vec<FusionOp>,
    pub inputs: Vec<ValueId>,
    pub stack_pattern: Vec<usize>,
    pub constants: HashMap<usize, Value>,
    pub const_values: HashMap<ValueId, Value>,
    pub output: Option<ValueId>,
    pub kernel: FusionKernelSpec,
    pub reduction_data: Option<ValueId>,
    pub reduction_dim: Option<ValueId>,
    pub reduction_flavor: Option<ReductionFlavor>,
    pub reduction_axes: Option<ReductionAxes>,
    pub pattern: Option<FusionPattern>,
}

#[derive(Debug, Clone)]
pub enum FusionOp {
    Primitive {
        op: PrimitiveOp,
        inputs: Vec<ValueId>,
        output: Option<ValueId>,
    },
    Builtin {
        name: String,
        inputs: Vec<ValueId>,
        output: Option<ValueId>,
    },
}

#[derive(Debug, Clone)]
pub struct FusionKernelSpec {
    pub kind: FusionKind,
    pub supported: bool,
}

impl FusionKernelSpec {
    fn new(kind: FusionKind, supported: bool) -> Self {
        Self { kind, supported }
    }
}

#[derive(Clone, Debug)]
pub struct ActiveFusion {
    pub kind: FusionKind,
    pub span: InstrSpan,
    pub element_count: Option<usize>,
    pub supported: bool,
}

struct ActiveContext {
    plan: Arc<FusionPlan>,
    active_group: Option<usize>,
}

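// Plans are cached by the graph's address; `Weak` entries let a plan die with
// its last `Arc` holder instead of leaking across graph rebuilds.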
static PLAN_CACHE: Lazy<RwLock<HashMap<usize, Weak<FusionPlan>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

thread_local! {
    static ACTIVE_PLAN: RefCell<Option<ActiveContext>> = const { RefCell::new(None) };
}

fn fusion_debug_enabled() -> bool {
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| match std::env::var("RUNMAT_DEBUG_FUSION") {
        Ok(v) => v == "1" || v.eq_ignore_ascii_case("true") || v.eq_ignore_ascii_case("yes"),
        Err(_) => false,
    })
}

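/// Builds the `FusionPlan` for `graph`, or returns the cached plan when one
/// is still alive. A sketch of the intended call sequence around an
/// interpreter loop, using only functions from this module:
///
/// ```ignore
/// let plan = prepare_fusion_plan(Some(&graph), &groups);
/// activate_fusion_plan(plan);
/// set_current_pc(pc);
/// if let Some(active) = active_fusion() {
///     // dispatch a fused kernel covering `active.span`
/// }
/// deactivate_fusion_plan();
/// ```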
pub fn prepare_fusion_plan(
    graph: Option<&AccelGraph>,
    groups: &[FusionGroup],
) -> Option<Arc<FusionPlan>> {
    let graph = graph?;
    if groups.is_empty() {
        return None;
    }
    let key = graph as *const AccelGraph as usize;
    if let Some(plan) = PLAN_CACHE
        .read()
        .ok()
        .and_then(|guard| guard.get(&key).and_then(|weak| weak.upgrade()))
    {
        return Some(plan);
    }

    let plan = FusionPlan::from_graph(graph, groups);
    let plan = Arc::new(plan);
    if let Ok(mut guard) = PLAN_CACHE.write() {
        guard.insert(key, Arc::downgrade(&plan));
    }
    Some(plan)
}

pub fn activate_fusion_plan(plan: Option<Arc<FusionPlan>>) {
    ACTIVE_PLAN.with(|ctx| {
        let mut slot = ctx.borrow_mut();
        *slot = plan.map(|plan| ActiveContext {
            plan,
            active_group: None,
        });
    });
}

pub fn deactivate_fusion_plan() {
    ACTIVE_PLAN.with(|ctx| {
        ctx.borrow_mut().take();
    });
}

pub fn set_current_pc(pc: usize) {
    ACTIVE_PLAN.with(|ctx| {
        if let Some(context) = ctx.borrow_mut().as_mut() {
            context.active_group = context.plan.group_for_pc(pc);
        }
    });
}

pub fn active_fusion() -> Option<ActiveFusion> {
    ACTIVE_PLAN.with(|ctx| {
        ctx.borrow()
            .as_ref()
            .and_then(|context| {
                context
                    .active_group
                    .and_then(|idx| context.plan.groups.get(idx))
            })
            .map(|plan| ActiveFusion {
                kind: plan.group.kind.clone(),
                span: plan.group.span.clone(),
                element_count: plan.element_count(),
                supported: plan.kernel.supported,
            })
    })
}

pub fn active_group_plan_clone() -> Option<FusionGroupPlan> {
    ACTIVE_PLAN.with(|ctx| {
        ctx.borrow().as_ref().and_then(|context| {
            context
                .active_group
                .and_then(|idx| context.plan.groups.get(idx).cloned())
        })
    })
}

impl FusionPlan {
    pub fn from_graph(graph: &AccelGraph, groups: &[FusionGroup]) -> Self {
        let plans = groups
            .iter()
            .enumerate()
            .map(|(idx, group)| FusionGroupPlan::new(idx, group.clone(), graph))
            .collect();
        Self { groups: plans }
    }

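    /// Index of the group whose instruction span contains `pc`, if any.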
    fn group_for_pc(&self, pc: usize) -> Option<usize> {
        self.groups
            .iter()
            .find(|plan| pc >= plan.group.span.start && pc <= plan.group.span.end)
            .map(|plan| plan.index)
    }
}

impl From<Vec<FusionGroupPlan>> for FusionPlan {
    fn from(groups: Vec<FusionGroupPlan>) -> Self {
        Self { groups }
    }
}

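/// Debug-only dump of a plan's stack pattern with per-slot value metadata.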
fn log_plan_stack_pattern(stage: &str, plan: &FusionGroupPlan, graph: &AccelGraph) {
    if !fusion_debug_enabled() || plan.stack_pattern.is_empty() {
        return;
    }
    let mut pattern_meta: Vec<String> = Vec::with_capacity(plan.stack_pattern.len());
    for (pos, input_idx) in plan.stack_pattern.iter().enumerate() {
        let value_id = plan.inputs.get(*input_idx).copied();
        if let Some(vid) = value_id {
            if let Some(info) = graph.value(vid) {
                let node_label = match info.origin {
                    ValueOrigin::NodeOutput { node, .. } => graph
                        .node(node)
                        .map(|n| format!("{:?}", n.label))
                        .unwrap_or_else(|| "<missing-node>".to_string()),
                    _ => String::new(),
                };
                pattern_meta.push(format!(
                    "#{}:input_idx={} vid={} origin={:?} label={}",
                    pos, input_idx, vid, info.origin, node_label
                ));
            } else {
                pattern_meta.push(format!(
                    "#{}:input_idx={} vid={} origin=<missing>",
                    pos, input_idx, vid
                ));
            }
        } else {
            pattern_meta.push(format!("#{}:input_idx={} vid=<missing>", pos, input_idx));
        }
    }
    log::debug!(
        "fusion plan {} {} stack_pattern={:?} meta={:?}",
        plan.index,
        stage,
        plan.stack_pattern,
        pattern_meta
    );
}

impl FusionGroupPlan {
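    /// Lowers `group` into an executable plan: classifies each operand as
    /// group-internal, variable-backed, constant, or stack-fed; records the
    /// group's ops in program order; and, for reductions, rewires the input
    /// list around the data operand found by `detect_reduction_signature`.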
    fn new(index: usize, group: FusionGroup, graph: &AccelGraph) -> Self {
        let node_set: HashSet<NodeId> = group.nodes.iter().copied().collect();
        let mut seen_inputs: HashMap<ValueId, usize> = HashMap::new();
        let mut inputs: Vec<ValueId> = Vec::new();
        let mut stack_pattern: Vec<usize> = Vec::new();
        let mut constants: HashMap<usize, Value> = HashMap::new();
        let const_values: HashMap<ValueId, Value> = HashMap::new();
        let mut operations = Vec::new();
        let mut reduction_flavor: Option<ReductionFlavor> = None;
        let mut reduction_axes: Option<ReductionAxes> = None;
        let mut reduction_data: Option<ValueId> = None;
        let mut reduction_dim: Option<ValueId> = None;
        let mut output: Option<ValueId> = None;

        let is_reduction_group = group.kind.is_reduction();
        for node_id in &group.nodes {
            let Some(node) = graph.node(*node_id) else {
                continue;
            };
            for input in &node.inputs {
                let binding = graph.var_binding(*input);
                let (external, is_variable, maybe_constant) = match graph.value(*input) {
                    Some(info) => match &info.origin {
                        ValueOrigin::NodeOutput { node: origin, .. }
                            if node_set.contains(origin) =>
                        {
                            (false, false, None)
                        }
                        ValueOrigin::Variable { .. } => (true, true, None),
                        ValueOrigin::NodeOutput { .. } if binding.is_some() => (true, true, None),
                        ValueOrigin::Constant => (true, false, info.constant.clone()),
                        _ => (true, false, None),
                    },
                    None => (true, false, None),
                };
                if external {
                    if is_reduction_group {
                        if let Some(constant) = maybe_constant.clone() {
                            // The offset keeps synthetic constant slots clear
                            // of real input indices; reduction constants are
                            // rebuilt per input index further below.
                            let key = constants.len() + 1000;
                            constants.insert(key, constant);
                            continue;
                        }
                        if let Some(data_id) = reduction_data {
                            if *input != data_id {
                                continue;
                            }
                        }
                    }

                    let mut newly_added = false;
                    let input_idx = if let Some(idx) = seen_inputs.get(input) {
                        *idx
                    } else {
                        let idx = inputs.len();
                        inputs.push(*input);
                        seen_inputs.insert(*input, idx);
                        newly_added = true;
                        idx
                    };

                    if fusion_debug_enabled() {
                        let origin = graph.value(*input).map(|v| v.origin.clone());
                        log::debug!(
                            "fusion plan #{:?} consider input vid={} origin={:?} binding={:?} newly_added={} is_variable={} stack_candidate={}",
                            index,
                            input,
                            origin,
                            binding,
                            newly_added,
                            is_variable,
                            !is_variable && newly_added
                        );
                    }
                    if let Some(constant) = maybe_constant.clone() {
                        constants.insert(input_idx, constant);
                    } else if !is_variable && newly_added {
                        let allow_stack = match graph.value(*input) {
                            Some(info) => match info.origin {
                                ValueOrigin::NodeOutput { node, .. } => graph
                                    .node(node)
                                    .map(|n| n.span.start <= group.span.start)
                                    .unwrap_or(false),
                                _ => true,
                            },
                            None => true,
                        };
                        if allow_stack {
                            stack_pattern.push(input_idx);
                        } else if fusion_debug_enabled() {
                            log::debug!(
                                "fusion plan {} skipping stack candidate vid={} origin_after_span",
                                index,
                                input
                            );
                        }
                    }
                }
            }

            let op = match &node.label {
                AccelNodeLabel::Primitive(p) => FusionOp::Primitive {
                    op: *p,
                    inputs: node.inputs.clone(),
                    output: node.outputs.first().copied(),
                },
                AccelNodeLabel::Builtin { name } => FusionOp::Builtin {
                    name: name.clone(),
                    inputs: node.inputs.clone(),
                    output: node.outputs.first().copied(),
                },
                AccelNodeLabel::Unknown => FusionOp::Primitive {
                    op: PrimitiveOp::UPlus,
                    inputs: node.inputs.clone(),
                    output: node.outputs.first().copied(),
                },
            };
            operations.push(op);

            if let Some(out) = node.outputs.first().copied() {
                output = Some(out);
            }
            if node.is_reduction() {
                if let Some(sig) = detect_reduction_signature(graph, node) {
                    reduction_data = Some(sig.data_input);
                    reduction_dim = sig.dim_arg;
                    reduction_flavor = Some(match sig.behavior {
                        ReductionBehavior::MeanLike => ReductionFlavor::Mean,
                        _ => ReductionFlavor::Sum,
                    });
                    reduction_axes = Some(sig.axes.clone());
                }
            }
        }

        let kind = group.kind.clone();
        let pattern = group.pattern.clone();
        let mut plan = Self {
            index,
            group,
            operations,
            stack_pattern,
            constants,
            const_values,
            inputs,
            output,
            kernel: FusionKernelSpec::new(kind, true),
            reduction_data,
            reduction_dim,
            reduction_flavor,
            reduction_axes,
            pattern,
        };

        log_plan_stack_pattern("initial", &plan, graph);

        // Record the constant value of every operand referenced by the group.
        for node_id in &plan.group.nodes {
            if let Some(node) = graph.node(*node_id) {
                for &inp in &node.inputs {
                    if let Some(info) = graph.value(inp) {
                        if let Some(cv) = info.constant.clone() {
                            plan.const_values.insert(inp, cv);
                        }
                    }
                }
            }
        }
949
950 if plan.group.kind.is_reduction() {
952 if let Some(data_vid) = plan.reduction_data {
953 let original_inputs = plan.inputs.clone();
954 let original_stack_pattern = plan.stack_pattern.clone();
955 let mut prod: HashMap<ValueId, Vec<ValueId>> = HashMap::new();
958 for op in &plan.operations {
959 match op {
960 FusionOp::Primitive {
961 inputs,
962 output,
963 op: _,
964 } => {
965 if let Some(out) = output {
966 prod.insert(*out, inputs.clone());
967 }
968 }
969 FusionOp::Builtin {
970 name: _,
971 inputs,
972 output,
973 } => {
974 if let Some(out) = output {
975 prod.insert(*out, inputs.clone());
976 }
977 }
978 }
979 }
980 let mut deps: Vec<ValueId> = Vec::new();
981 let mut visited: HashSet<ValueId> = HashSet::new();
982 let mut stack: Vec<ValueId> = vec![data_vid];
983 let mut extra_ops: Vec<FusionOp> = Vec::new();
985 let mut added_nodes: HashSet<ValueId> = HashSet::new();
986 while let Some(cur) = stack.pop() {
987 if !visited.insert(cur) {
988 continue;
989 }
990 if graph.var_binding(cur).is_some() {
991 if !deps.contains(&cur) {
992 deps.push(cur);
993 }
994 continue;
995 }
996 if let Some(info) = graph.value(cur) {
997 if matches!(info.origin, ValueOrigin::Variable { .. }) {
998 if !deps.contains(&cur) {
999 deps.push(cur);
1000 }
1001 continue;
1002 }
1003 }
1004 if original_inputs.contains(&cur) && cur != data_vid {
1006 if !deps.contains(&cur) {
1007 deps.push(cur);
1008 }
1009 continue;
1010 }
1011 if let Some(parents) = prod.get(&cur) {
1012 for p in parents {
1013 stack.push(*p);
1014 }
1015 continue;
1016 }
1017 if let Some((_, node)) = node_from_value(graph, cur) {
1019 match &node.label {
1021 AccelNodeLabel::Primitive(PrimitiveOp::Mul)
1022 | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul)
1023 | AccelNodeLabel::Primitive(PrimitiveOp::Div)
1024 | AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv)
1025 | AccelNodeLabel::Primitive(PrimitiveOp::ElemLeftDiv)
1026 | AccelNodeLabel::Primitive(PrimitiveOp::Add)
1027 | AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
1028 if added_nodes.insert(cur) {
1030 extra_ops.push(FusionOp::Primitive {
1031 op: match node.label {
1032 AccelNodeLabel::Primitive(op) => op,
1033 _ => PrimitiveOp::UPlus,
1034 },
1035 inputs: node.inputs.clone(),
1036 output: node.outputs.first().copied(),
1037 });
1038 }
1039 for &p in &node.inputs {
1040 stack.push(p);
1041 }
1042 continue;
1043 }
1044 AccelNodeLabel::Primitive(PrimitiveOp::ElemPow) => {
1045 if node.inputs.len() == 2 {
1047 if let Some(exp) = value_constant_f64(graph, node.inputs[1]) {
1048 if exp.is_finite() {
1049 if added_nodes.insert(cur) {
1050 extra_ops.push(FusionOp::Primitive {
1051 op: PrimitiveOp::ElemPow,
1052 inputs: node.inputs.clone(),
1053 output: node.outputs.first().copied(),
1054 });
1055 }
1056 stack.push(node.inputs[0]);
1057 stack.push(node.inputs[1]);
1059 continue;
1060 }
1061 }
1062 }
1063 }
1065 AccelNodeLabel::Builtin { name } => {
1066 if (name.eq_ignore_ascii_case("single")
1068 || name.eq_ignore_ascii_case("double"))
1069 && node.inputs.len() == 1
1070 {
1071 stack.push(node.inputs[0]);
1072 continue;
1073 }
1074 }
1076 _ => {
1077 }
1079 }
1080 }
1081 }
1082 if let Some(parents) = prod.get(&data_vid) {
1084 for &p in parents {
1085 if !deps.contains(&p) {
1086 let is_const = plan.const_values.contains_key(&p)
1088 || graph.value(p).and_then(|vi| vi.constant.as_ref()).is_some();
1089 if !is_const {
1090 deps.push(p);
1091 }
1092 }
1093 }
1094 }
1095 if !extra_ops.is_empty() {
1098 let mut new_ops = Vec::with_capacity(extra_ops.len() + plan.operations.len());
1100 new_ops.extend(extra_ops);
1101 new_ops.append(&mut plan.operations);
1102 plan.operations = new_ops;
1103 }
1104 plan.inputs = deps;
1105 for op in &plan.operations {
1107 let inputs = match op {
1108 FusionOp::Primitive { inputs, .. } => inputs,
1109 FusionOp::Builtin { inputs, .. } => inputs,
1110 };
1111 for vid in inputs {
1112 if plan.const_values.contains_key(vid) {
1113 continue;
1114 }
1115 if let Some(info) = graph.value(*vid) {
1116 if let Some(cv) = info.constant.clone() {
1117 plan.const_values.insert(*vid, cv);
1118 }
1119 }
1120 }
1121 }
1122
1123 let mut new_stack_pattern: Vec<usize> = Vec::new();
1126 for (new_idx, vid) in plan.inputs.iter().enumerate() {
1127 if let Some(old_idx) = original_inputs.iter().position(|v| v == vid) {
1128 if original_stack_pattern.contains(&old_idx) {
1129 new_stack_pattern.push(new_idx);
1130 }
1131 }
1132 }
1133
1134 let mut new_constants: HashMap<usize, Value> = HashMap::new();
1136 for (idx, vid) in plan.inputs.iter().enumerate() {
1137 if let Some(value) = plan.const_values.get(vid) {
1138 new_constants.insert(idx, value.clone());
1139 } else if let Some(info) = graph.value(*vid) {
1140 if let Some(cv) = info.constant.clone() {
1141 new_constants.insert(idx, cv);
1142 }
1143 }
1144 }
1145 plan.constants = new_constants;
1146
1147 if new_stack_pattern.is_empty() {
1148 for (idx, vid) in plan.inputs.iter().enumerate() {
1149 if plan.constants.contains_key(&idx) {
1150 continue;
1151 }
1152 if let Some(info) = graph.value(*vid) {
1153 if matches!(
1154 info.origin,
1155 ValueOrigin::Variable { .. } | ValueOrigin::Constant
1156 ) {
1157 continue;
1158 }
1159 }
1160 new_stack_pattern.push(idx);
1161 }
1162 }
1163 plan.stack_pattern = new_stack_pattern;
1164 }
1165 }

        if plan.group.kind.is_reduction() {
            // Drop constant operands from the runtime input list and remap
            // the stack pattern onto the surviving indices.
            let original_inputs = plan.inputs.clone();
            plan.inputs.retain(|vid| {
                if let Some(info) = graph.value(*vid) {
                    !matches!(info.origin, ValueOrigin::Constant)
                        && !plan.const_values.contains_key(vid)
                } else {
                    true
                }
            });
            if plan.inputs.len() != original_inputs.len() {
                let mut new_stack: Vec<usize> = Vec::new();
                for old_idx in &plan.stack_pattern {
                    if *old_idx < original_inputs.len() {
                        let vid = original_inputs[*old_idx];
                        if let Some(new_idx) = plan.inputs.iter().position(|v| *v == vid) {
                            new_stack.push(new_idx);
                        }
                    }
                }
                plan.stack_pattern = new_stack;
            }
        }

        // Probe WGSL generation to decide whether this group can actually be
        // dispatched as a fused kernel.
        let supported = if plan.kernel.kind.is_elementwise() {
            plan.generate_wgsl("f32").is_some()
        } else if plan.kernel.kind.is_reduction() {
            plan.generate_reduction_wgsl("f32").is_some()
        } else {
            true
        };
        plan.kernel.supported = plan.kernel.supported && supported;
        if !plan.kernel.supported && fusion_debug_enabled() {
            let const_ids: Vec<ValueId> = plan.const_values.keys().copied().collect();
            log::debug!(
                "fusion plan {} unsupported: kind={:?} group_kind={:?} inputs={:?} reduction_data={:?} reduction_dim={:?} const_ids={:?}",
                plan.index,
                plan.kernel.kind,
                plan.group.kind,
                plan.inputs,
                plan.reduction_data,
                plan.reduction_dim,
                const_ids
            );
            if plan.kernel.kind.is_reduction() {
                let mut seen: HashSet<ValueId> = HashSet::new();
                let mut value_info: Vec<String> = Vec::new();
                for op in &plan.operations {
                    let inputs = match op {
                        FusionOp::Primitive { inputs, .. } => inputs,
                        FusionOp::Builtin { inputs, .. } => inputs,
                    };
                    for vid in inputs {
                        if seen.insert(*vid) {
                            if let Some(info) = graph.value(*vid) {
                                value_info.push(format!(
                                    "vid={} origin={:?} constant={}",
                                    vid,
                                    info.origin,
                                    info.constant.is_some()
                                ));
                            } else {
                                value_info.push(format!("vid={} origin=<missing>", vid));
                            }
                        }
                    }
                }
                log::debug!(
                    "fusion reduction plan {} value summary: [{}]",
                    plan.index,
                    value_info.join(", ")
                );
            }
        }

        if matches!(plan.group.kind, FusionKind::CenteredGram) && plan.stack_pattern.is_empty() {
            let mut centered_stack_idxs: Vec<usize> = Vec::new();
            for (idx, vid) in plan.inputs.iter().enumerate() {
                if plan.constants.contains_key(&idx) {
                    continue;
                }
                if let Some(info) = graph.value(*vid) {
                    if matches!(info.origin, ValueOrigin::NodeOutput { .. }) {
                        centered_stack_idxs.push(idx);
                        continue;
                    }
                    if matches!(info.origin, ValueOrigin::Variable { .. }) {
                        continue;
                    }
                }
                centered_stack_idxs.push(idx);
            }
            if centered_stack_idxs.is_empty() && !plan.inputs.is_empty() {
                centered_stack_idxs.push(0);
            }
            plan.stack_pattern = centered_stack_idxs;
        }

        log_plan_stack_pattern("final", &plan, graph);

        plan
    }

    pub fn reduction_data_shape(&self, graph: &AccelGraph) -> Option<Vec<usize>> {
        let vid = self.reduction_data?;
        let info = graph.value(vid)?;
        match &info.shape {
            ShapeInfo::Tensor(dims) if !dims.is_empty() && dims.iter().all(|d| d.is_some()) => {
                Some(dims.iter().map(|d| d.unwrap()).collect())
            }
            _ => None,
        }
    }

    pub fn element_count(&self) -> Option<usize> {
        self.group.element_count()
    }

    pub fn constant_shape(&self, len: usize) -> Vec<usize> {
        match &self.group.shape {
            ShapeInfo::Tensor(dims) if !dims.is_empty() && dims.iter().all(|dim| dim.is_some()) => {
                dims.iter().map(|dim| dim.unwrap()).collect()
            }
            _ => vec![len],
        }
    }

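    /// Emits a WGSL kernel for an elementwise group, or `None` if some op has
    /// no WGSL lowering. Bindings are the inputs in order, then the output,
    /// then a uniform `Params` block; the host must substitute the `@WG@`
    /// workgroup-size placeholder before compiling. A minimal sketch (the 256
    /// is an arbitrary example size):
    ///
    /// ```ignore
    /// if let Some(src) = plan.generate_wgsl("f32") {
    ///     let src = src.replace("@WG@", "256");
    ///     // compile `src` with your wgpu/WGSL pipeline
    /// }
    /// ```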
    pub fn generate_wgsl(&self, scalar_ty: &str) -> Option<String> {
        if !self.kernel.kind.is_elementwise() {
            return None;
        }
        if !self.kernel.supported {
            return None;
        }
        let output_id = self.output?;
        let mut exprs: HashMap<ValueId, String> = HashMap::new();
        for (idx, input_id) in self.inputs.iter().enumerate() {
            exprs.insert(*input_id, format!("input{idx}.data[i{idx}]"));
        }

        let mut body = String::new();
        for (node_idx, op) in self.operations.iter().enumerate() {
            let tmp_name = format!("tmp{node_idx}");
            match op {
                FusionOp::Primitive { op, inputs, output } => {
                    let expr = primitive_expr(*op, inputs, &exprs)?;
                    body.push_str(&format!(" let {tmp_name}: {scalar_ty} = {expr};\n"));
                    if let Some(out) = output {
                        exprs.insert(*out, tmp_name.clone());
                    }
                }
                FusionOp::Builtin {
                    name,
                    inputs,
                    output,
                } => {
                    let expr = builtin_expr(name, inputs, &exprs, scalar_ty)?;
                    body.push_str(&format!(" let {tmp_name}: {scalar_ty} = {expr};\n"));
                    if let Some(out) = output {
                        exprs.insert(*out, tmp_name.clone());
                    }
                }
            }
        }

        let final_expr = exprs.get(&output_id)?.clone();

        let mut shader = String::new();
        shader.push_str("const MAX_RANK: u32 = 128u;\n");
        shader.push_str("struct PackedValue { value: u32, _pad0: u32, _pad1: u32, _pad2: u32 };\n");
        shader.push_str("alias PackedArray = array<PackedValue, MAX_RANK>;\n\n");
        shader.push_str(&format!("struct Tensor {{ data: array<{scalar_ty}>, }};\n"));
        shader.push_str("struct Params {\n len: u32,\n offset: u32,\n rank: u32,\n _pad: u32,\n out_shape: PackedArray,\n");
        for idx in 0..self.inputs.len() {
            shader.push_str(&format!(" in{}_shape: PackedArray,\n", idx));
            shader.push_str(&format!(" in{}_stride: PackedArray,\n", idx));
        }
        shader.push_str("}\n\n");
        if scalar_ty == "f32" {
            shader.push_str("fn isNan(x: f32) -> bool { return x != x; }\n");
            shader.push_str("fn isFinite(x: f32) -> bool { return (x == x) && (abs(x) < 3.4028234663852886e38); }\n");
            shader.push_str("fn isInf(x: f32) -> bool { return (x == x) && !(abs(x) < 3.4028234663852886e38); }\n\n");
        } else {
            shader.push_str("fn isNan(x: f64) -> bool { return x != x; }\n");
            shader.push_str("fn isFinite(x: f64) -> bool { return (x == x) && (abs(x) < f64(1.7976931348623157e308)); }\n");
            shader.push_str("fn isInf(x: f64) -> bool { return (x == x) && !(abs(x) < f64(1.7976931348623157e308)); }\n\n");
        }
        for (idx, _) in self.inputs.iter().enumerate() {
            shader.push_str(&format!(
                "@group(0) @binding({}) var<storage, read> input{}: Tensor;\n",
                idx, idx
            ));
        }
        shader.push_str(&format!(
            "@group(0) @binding({}) var<storage, read_write> output: Tensor;\n",
            self.inputs.len()
        ));
        shader.push_str(&format!(
            "@group(0) @binding({}) var<uniform> params: Params;\n\n",
            self.inputs.len() + 1
        ));
        shader.push_str("@compute @workgroup_size(@WG@)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n");
        shader.push_str(" let idx = gid.x;\n if (idx >= params.len) { return; }\n");
        shader.push_str(" let g = idx + params.offset;\n");
        shader.push_str(" // Compute N-D coordinates from global index (with chunk offset)\n var coord: array<u32, MAX_RANK>;\n var tmp: u32 = g;\n var d: u32 = 0u;\n loop { if d >= params.rank { break; } let dim = params.out_shape[d].value; if dim == 0u { coord[d] = 0u; } else { coord[d] = tmp % dim; tmp = tmp / dim; } d = d + 1u; }\n");
        for (idx, _) in self.inputs.iter().enumerate() {
            shader.push_str(&format!(
                " var i{}: u32 = 0u; d = 0u; loop {{ if d >= params.rank {{ break; }} let sd = params.in{}_shape[d].value; let st = params.in{}_stride[d].value; let c = select(coord[d], 0u, sd == 1u); i{} = i{} + c * st; d = d + 1u; }}\n",
                idx, idx, idx, idx, idx
            ));
        }
        shader.push_str(&body);
        shader.push_str(&format!(" output.data[g] = {final_expr};\n}}\n"));
        Some(shader)
    }

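    /// Emits a WGSL kernel for a one-axis reduction: each workgroup strides
    /// over a row or column accumulating into `acc`, then tree-reduces the
    /// per-thread partials in the shared `tile`. With `omitnan`, NaNs are
    /// skipped; otherwise a single NaN poisons the slice. Mean reductions
    /// multiply by a `1/n` post-scale at write-out.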
    pub fn generate_reduction_wgsl(&self, scalar_ty: &str) -> Option<String> {
        if !self.kernel.kind.is_reduction() {
            return None;
        }
        if self.inputs.is_empty() {
            return None;
        }
        // Determine the reduction axis: an explicit "all" flag wins, then a
        // constant dim argument, then the first plausible numeric constant.
        let mut axis = 0usize;
        let reduce_all = self
            .constants
            .values()
            .any(|v| matches!(v, Value::String(s) if s.eq_ignore_ascii_case("all")))
            || self
                .const_values
                .values()
                .any(|v| matches!(v, Value::String(s) if s.eq_ignore_ascii_case("all")));
        if reduce_all {
            axis = 0;
        } else if let Some(dim_vid) = self.reduction_dim {
            if let Some(v) = self.const_values.get(&dim_vid) {
                match v {
                    Value::Num(n) if *n >= 1.0 => {
                        axis = (*n as usize).saturating_sub(1);
                    }
                    Value::Int(i) => {
                        let val = i.to_f64();
                        if val >= 1.0 {
                            axis = (val as usize).saturating_sub(1);
                        }
                    }
                    _ => {}
                }
            }
        } else {
            for v in self.constants.values() {
                match v {
                    Value::Num(n) if *n >= 1.0 => {
                        axis = (*n as usize).saturating_sub(1);
                        break;
                    }
                    Value::Int(i) => {
                        let val = i.to_f64();
                        if val >= 1.0 {
                            axis = (val as usize).saturating_sub(1);
                            break;
                        }
                    }
                    _ => {}
                }
            }
        }

        let omitnan = self.constants.values().any(|v| match v {
            Value::String(s) => s.eq_ignore_ascii_case("omitnan"),
            _ => false,
        });

        let data_vid = self.reduction_data?;
        let ext_input = self.inputs[0];
        let mut exprs: HashMap<ValueId, String> = HashMap::new();
        exprs.insert(ext_input, "v".to_string());
        for (idx, &vid) in self.inputs.iter().enumerate().skip(1) {
            exprs.insert(vid, format!("v{idx}"));
        }
        for (vid, val) in &self.const_values {
            let lit = match val {
                Value::Num(n) => {
                    if scalar_ty == "f64" {
                        format!("f64({})", n)
                    } else {
                        format!("{:?}", *n as f32)
                    }
                }
                Value::Int(i) => {
                    let f = i.to_f64();
                    if scalar_ty == "f64" {
                        format!("f64({})", f)
                    } else {
                        format!("{:?}", f as f32)
                    }
                }
                Value::Tensor(t) if t.data.len() == 1 => {
                    let scalar = t.data[0];
                    if scalar_ty == "f64" {
                        format!("f64({})", scalar)
                    } else {
                        format!("{:?}", scalar as f32)
                    }
                }
                _ => {
                    if scalar_ty == "f64" {
                        "f64(0.0)".to_string()
                    } else {
                        "0.0".to_string()
                    }
                }
            };
            exprs.insert(*vid, lit);
        }
        let mut progressed = true;
        while progressed {
            progressed = false;
            for op in &self.operations {
                match op {
                    FusionOp::Primitive { op, inputs, output } => {
                        if let Some(out) = output {
                            if exprs.contains_key(out) {
                                continue;
                            }
                            if let Some(code) = primitive_expr(*op, inputs, &exprs) {
                                exprs.insert(*out, code);
                                progressed = true;
                            }
                        }
                    }
                    FusionOp::Builtin {
                        name,
                        inputs,
                        output,
                    } => {
                        if let Some(out) = output {
                            if exprs.contains_key(out) {
                                continue;
                            }
                            if let Some(code) = builtin_expr(name, inputs, &exprs, scalar_ty) {
                                exprs.insert(*out, code);
                                progressed = true;
                            }
                        }
                    }
                }
            }
            if exprs.contains_key(&data_vid) {
                break;
            }
        }
        let val_expr = match exprs.get(&data_vid) {
            Some(s) => s.clone(),
            None => {
                if fusion_debug_enabled() {
                    let expr_keys: Vec<ValueId> = exprs.keys().copied().collect();
                    log::debug!(
                        "fusion reduction WGSL: missing expression for data {:?}; inputs={:?} expr_keys={:?} ops={:?}",
                        data_vid,
                        self.inputs,
                        expr_keys,
                        self.operations
                    );
                }
                return None;
            }
        };

        let mut shader = String::new();
        shader.push_str(&format!("struct Tensor {{ data: array<{scalar_ty}>, }};\n"));
        shader.push_str("struct MParams { nrows: u32, ncols: u32, ld: u32, flags: u32 }\n\n");
        for (idx, _) in self.inputs.iter().enumerate() {
            shader.push_str(&format!(
                "@group(0) @binding({}) var<storage, read> input{}: Tensor;\n",
                idx, idx
            ));
        }
        shader.push_str(&format!(
            "@group(0) @binding({}) var<storage, read_write> output: Tensor;\n",
            self.inputs.len()
        ));
        shader.push_str(&format!(
            "@group(0) @binding({}) var<uniform> params: MParams;\n\n",
            self.inputs.len() + 1
        ));
        shader.push_str(&format!(
            "var<workgroup> tile: array<{scalar_ty}, @WG@u>;\n\n"
        ));
        shader.push_str(&format!(
            "const OMITNAN: bool = {};\n\n",
            if omitnan { "true" } else { "false" }
        ));
        let is_mean = matches!(self.reduction_flavor, Some(ReductionFlavor::Mean));
        let post_scale = if is_mean {
            let dim = if axis == 0 {
                "params.nrows"
            } else {
                "params.ncols"
            };
            if scalar_ty == "f64" {
                format!("(1.0 / f64(f32({dim})))")
            } else {
                format!("(1.0 / f32({dim}))")
            }
        } else if scalar_ty == "f64" {
            "f64(1.0)".to_string()
        } else {
            "1.0".to_string()
        };
        shader.push_str(&format!(
            "fn isNanF(x: {scalar}) -> bool {{ return x != x; }}\n\n",
            scalar = scalar_ty
        ));
        shader.push_str("@compute @workgroup_size(@WG@)\n");
        if axis == 0 {
            shader.push_str(
                "fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {\n",
            );
            shader.push_str(" let col = wid.x;\n if (col >= params.ncols) { return; }\n");
            // Emit a balanced accumulator initializer for both scalar types.
            shader.push_str(&format!(
                " var acc: {scalar_ty} = {};\n",
                if scalar_ty == "f64" { "f64(0.0)" } else { "0.0" }
            ));
            shader.push_str(" var saw_nan: bool = false;\n var r = lid.x;\n");
            {
                let mut loop_body = String::new();
                loop_body.push_str(" let v = input0.data[ (col * params.nrows) + r ];\n");
                for (idx, _) in self.inputs.iter().enumerate().skip(1) {
                    loop_body.push_str(&format!(
                        " let v{idx} = input{idx}.data[ (col * params.nrows) + r ];\n"
                    ));
                }
                loop_body.push_str(&format!(
                    " let val: {scalar} = {val};\n if (OMITNAN) {{ if (!isNanF(val)) {{ acc = acc + val; }} }} else {{ if (isNanF(val)) {{ saw_nan = true; }} else {{ acc = acc + val; }} }}\n",
                    scalar = scalar_ty,
                    val = val_expr
                ));
                shader.push_str(" while (r < params.nrows) {\n");
                shader.push_str(&loop_body);
                shader.push_str(" r += @WG@u;\n }\n");
            }
            if scalar_ty == "f64" {
                shader.push_str(
                    " if (!OMITNAN && saw_nan) { acc = bitcast<f64>(0x7ff8000000000000u); }\n",
                );
            } else {
                shader
                    .push_str(" if (!OMITNAN && saw_nan) { acc = bitcast<f32>(0x7fc00000u); }\n");
            }
            shader.push_str(" tile[lid.x] = acc;\n workgroupBarrier();\n");
            shader.push_str(
                " var off = (@WG@u) / 2u;\n loop { if (off == 0u) { break; } if (lid.x < off) {\n let a = tile[lid.x]; let b = tile[lid.x + off];\n tile[lid.x] = a + b;\n } workgroupBarrier(); off = off / 2u; }\n",
            );
            shader.push_str(&format!(
                " if (lid.x == 0u) {{ output.data[col] = tile[0u] * {}; }}\n}}\n",
                post_scale
            ));
        } else {
            shader.push_str(
                "fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {\n",
            );
            shader.push_str(" let row = wid.x;\n // For axis=1, number of output slices equals rows (params.ncols)\n if (row >= params.ncols) { return; }\n");
            // Same balanced accumulator initializer as the axis-0 branch.
            shader.push_str(&format!(
                " var acc: {scalar_ty} = {};\n",
                if scalar_ty == "f64" { "f64(0.0)" } else { "0.0" }
            ));
            shader.push_str(" var saw_nan: bool = false;\n var c = lid.x;\n");
            {
                let mut loop_body = String::new();
                loop_body.push_str(" let v = input0.data[ row + (c * params.ncols) ];\n");
                for (idx, _) in self.inputs.iter().enumerate().skip(1) {
                    loop_body.push_str(&format!(
                        " let v{idx} = input{idx}.data[ row + (c * params.ncols) ];\n"
                    ));
                }
                loop_body.push_str(&format!(
                    " let val: {scalar} = {val};\n if (OMITNAN) {{ if (!isNanF(val)) {{ acc = acc + val; }} }} else {{ if (isNanF(val)) {{ saw_nan = true; }} else {{ acc = acc + val; }} }}\n",
                    scalar = scalar_ty,
                    val = val_expr
                ));
                shader.push_str(" while (c < params.nrows) {\n");
                shader.push_str(&loop_body);
                shader.push_str(" c += @WG@u;\n }\n");
            }
            if scalar_ty == "f64" {
                shader.push_str(
                    " if (!OMITNAN && saw_nan) { acc = bitcast<f64>(0x7ff8000000000000u); }\n",
                );
            } else {
                shader
                    .push_str(" if (!OMITNAN && saw_nan) { acc = bitcast<f32>(0x7fc00000u); }\n");
            }
            shader.push_str(" tile[lid.x] = acc;\n workgroupBarrier();\n");
            shader.push_str(
                " var off = (@WG@u) / 2u;\n loop { if (off == 0u) { break; } if (lid.x < off) {\n let a = tile[lid.x]; let b = tile[lid.x + off];\n tile[lid.x] = a + b;\n } workgroupBarrier(); off = off / 2u; }\n",
            );
            shader.push_str(&format!(
                " if (lid.x == 0u) {{ output.data[row] = tile[0u] * {}; }}\n}}\n",
                post_scale
            ));
        }
        Some(shader)
    }
}

impl FusionGroup {
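    /// Product of all known dimensions: a fully known 3x4 shape yields
    /// `Some(12)`; any unknown dimension (or overflow) yields `None`.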
    pub fn element_count(&self) -> Option<usize> {
        match &self.shape {
            ShapeInfo::Scalar => Some(1),
            ShapeInfo::Tensor(dims) => dims
                .iter()
                .try_fold(1usize, |acc, dim| dim.and_then(|d| acc.checked_mul(d))),
            ShapeInfo::Unknown => None,
        }
    }
}

impl FusionKind {
    pub fn is_elementwise(&self) -> bool {
        matches!(self, FusionKind::ElementwiseChain)
    }

    pub fn is_reduction(&self) -> bool {
        matches!(self, FusionKind::Reduction)
    }
}

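/// Recognizes the covariance-style pattern
/// `((X - mean(X))' * (X - mean(X))) / d`, classifying the constant divisor
/// `d` as unbiased (`n - 1`) or biased (`n`) normalization via `approx_eq`.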
fn detect_centered_gram(
    graph: &AccelGraph,
    assigned: &mut HashSet<NodeId>,
    groups: &mut Vec<FusionGroup>,
    next_group_id: &mut usize,
) {
    for div_node in &graph.nodes {
        if assigned.contains(&div_node.id) {
            continue;
        }
        let div_op = match div_node.label {
            AccelNodeLabel::Primitive(op) => op,
            _ => continue,
        };
        if div_op != PrimitiveOp::Div && div_op != PrimitiveOp::ElemDiv {
            continue;
        }
        if div_node.inputs.len() != 2 {
            continue;
        }
        let (numerator_id, denom_id) = (div_node.inputs[0], div_node.inputs[1]);
        let denom_info = match graph.value(denom_id) {
            Some(info) => info,
            None => continue,
        };
        let denom_const = match &denom_info.constant {
            Some(Value::Num(v)) => Some(*v),
            Some(Value::Int(i)) => Some(i.to_f64()),
            _ => None,
        };
        if denom_const.is_some_and(|v| v == 0.0) {
            continue;
        }

        let mul_node_id = match graph
            .value(numerator_id)
            .and_then(|info| match &info.origin {
                ValueOrigin::NodeOutput { node, .. } => Some(*node),
                _ => None,
            }) {
            Some(id) => id,
            None => continue,
        };
        if assigned.contains(&mul_node_id) {
            continue;
        }
        let mul_node = match graph.node(mul_node_id) {
            Some(node) => node,
            None => continue,
        };
        let mul_op = match mul_node.label {
            AccelNodeLabel::Primitive(op) => op,
            _ => continue,
        };
        if mul_op != PrimitiveOp::Mul && mul_op != PrimitiveOp::ElemMul {
            continue;
        }
        if mul_node.inputs.len() != 2 {
            continue;
        }

        let mut transpose_node_id: Option<NodeId> = None;
        let mut centered_val_id: Option<ValueId> = None;
        for input_vid in &mul_node.inputs {
            let candidate_node_id =
                match graph.value(*input_vid).and_then(|info| match &info.origin {
                    ValueOrigin::NodeOutput { node, .. } => Some(*node),
                    _ => None,
                }) {
                    Some(id) => id,
                    None => continue,
                };
            if let Some(trans_node) = graph.node(candidate_node_id) {
                if matches!(
                    trans_node.label,
                    AccelNodeLabel::Primitive(PrimitiveOp::Transpose)
                ) {
                    if let Some(centered) = trans_node.inputs.first().copied() {
                        transpose_node_id = Some(candidate_node_id);
                        centered_val_id = Some(centered);
                        break;
                    }
                }
            }
        }

        let transpose_node_id = match transpose_node_id {
            Some(id) if !assigned.contains(&id) => id,
            _ => continue,
        };
        let centered_val_id = match centered_val_id {
            Some(id) => id,
            None => continue,
        };

        if assigned.contains(&transpose_node_id) {
            continue;
        }
        if graph.node(transpose_node_id).is_none() {
            continue;
        }

        let centered_node_id =
            match graph
                .value(centered_val_id)
                .and_then(|info| match &info.origin {
                    ValueOrigin::NodeOutput { node, .. } => Some(*node),
                    _ => None,
                }) {
                Some(id) => id,
                None => continue,
            };
        if assigned.contains(&centered_node_id) {
            continue;
        }
        let centered_node = match graph.node(centered_node_id) {
            Some(node) => node,
            None => continue,
        };
        if !matches!(
            centered_node.label,
            AccelNodeLabel::Primitive(PrimitiveOp::Sub)
        ) {
            continue;
        }
        if centered_node.inputs.len() != 2 {
            continue;
        }
        let matrix_val_id = centered_node.inputs[0];
        let mean_val_id = centered_node.inputs[1];

        let mean_node_id = match graph
            .value(mean_val_id)
            .and_then(|info| match &info.origin {
                ValueOrigin::NodeOutput { node, .. } => Some(*node),
                _ => None,
            }) {
            Some(id) => id,
            None => continue,
        };
        if assigned.contains(&mean_node_id) {
            continue;
        }
        let mean_node = match graph.node(mean_node_id) {
            Some(node) => node,
            None => continue,
        };
        match &mean_node.label {
            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mean") => {}
            _ => continue,
        }
        if mean_node.inputs.is_empty() || mean_node.inputs[0] != matrix_val_id {
            continue;
        }

        let matrix_info = match graph.value(matrix_val_id) {
            Some(info) => info,
            None => continue,
        };
        let matrix_rows = match &matrix_info.shape {
            ShapeInfo::Tensor(dims) if !dims.is_empty() => dims[0].unwrap_or(0),
            _ => 0,
        };
        let normalization = if matrix_rows > 1 {
            if let Some(value) = denom_const {
                let unbiased = (matrix_rows as f64 - 1.0).max(1.0);
                let biased = matrix_rows as f64;
                if approx_eq(value, unbiased) {
                    CovNormalization::Unbiased
                } else if approx_eq(value, biased) {
                    CovNormalization::Biased
                } else {
                    CovNormalization::Unbiased
                }
            } else {
                CovNormalization::Unbiased
            }
        } else {
            CovNormalization::Unbiased
        };

        let mut nodes = vec![
            mean_node_id,
            centered_node_id,
            transpose_node_id,
            mul_node_id,
            div_node.id,
        ];
        nodes.sort_by_key(|node_id| {
            graph
                .node(*node_id)
                .map(|node| node.span.start)
                .unwrap_or(usize::MAX)
        });
        let span = group_span(graph, &nodes);
        let shape = node_output_shape(graph, div_node);

        groups.push(FusionGroup {
            id: *next_group_id,
            kind: FusionKind::CenteredGram,
            nodes: nodes.clone(),
            shape,
            span,
            pattern: Some(FusionPattern::CenteredGram {
                matrix: matrix_val_id,
                normalization,
            }),
        });
        *next_group_id += 1;
        for id in nodes {
            assigned.insert(id);
        }
    }
}

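/// Recognizes image-normalization chains rooted at a power node (decomposed
/// by `analyze_image_normalize`) and records the epsilon/gain/bias/gamma
/// scalars in an `ImageNormalizePattern`.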
fn detect_image_normalize(
    graph: &AccelGraph,
    assigned: &mut HashSet<NodeId>,
    groups: &mut Vec<FusionGroup>,
    next_group_id: &mut usize,
) {
    for pow_node in &graph.nodes {
        if assigned.contains(&pow_node.id) {
            continue;
        }
        let Some(match_info) = analyze_image_normalize(graph, pow_node.id, assigned) else {
            continue;
        };

        let pow_node_ref = match graph.node(pow_node.id) {
            Some(node) => node,
            None => continue,
        };

        let shape = node_output_shape(graph, pow_node_ref);
        let span = group_span(graph, &match_info.nodes);

        let pattern = ImageNormalizePattern {
            input: match_info.input,
            epsilon: match_info.epsilon.clone(),
            gain: match_info.gain.clone(),
            bias: match_info.bias.clone(),
            gamma: match_info.gamma.clone(),
        };

        groups.push(FusionGroup {
            id: *next_group_id,
            kind: FusionKind::ImageNormalize,
            nodes: match_info.nodes.clone(),
            shape,
            span: span.clone(),
            pattern: Some(FusionPattern::ImageNormalize(pattern)),
        });
        if fusion_debug_enabled() {
            log::debug!(
                "fusion: detected image normalize group id={} span={:?} nodes={:?}",
                next_group_id,
                span,
                match_info.nodes
            );
        }
        *next_group_id += 1;
        for node_id in match_info.nodes {
            assigned.insert(node_id);
        }
    }
}

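/// Relative-tolerance comparison used to match the covariance divisor
/// against `n` and `n - 1`.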
fn approx_eq(a: f64, b: f64) -> bool {
    let scale = a.abs().max(b.abs()).max(1.0);
    (a - b).abs() <= scale * 1e-6
}

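/// Recognizes a power-iteration normalization step: a matmul numerator
/// divided by a pow -> sum -> (optional epsilon add) -> sqrt chain over that
/// same numerator, as decomposed by `analyze_power_step_denominator`.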
fn detect_power_step_normalize(
    graph: &AccelGraph,
    assigned: &mut HashSet<NodeId>,
    groups: &mut Vec<FusionGroup>,
    next_group_id: &mut usize,
) {
    'outer: for div_node in &graph.nodes {
        if assigned.contains(&div_node.id) {
            continue;
        }
        let div_op = match div_node.label {
            AccelNodeLabel::Primitive(op) => op,
            _ => continue,
        };
        if div_op != PrimitiveOp::Div && div_op != PrimitiveOp::ElemDiv {
            continue;
        }
        if div_node.inputs.len() != 2 {
            continue;
        }
        let numerator_vid = div_node.inputs[0];
        let denom_vid = div_node.inputs[1];

        let (matmul_id, matmul_node) = match node_from_value(graph, numerator_vid) {
            Some((id, node)) => (id, node),
            None => continue,
        };
        if assigned.contains(&matmul_id) {
            continue;
        }
        match &matmul_node.label {
            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mtimes") => {}
            _ => continue,
        }
        if matmul_node.inputs.len() != 2 {
            continue;
        }

        let Some(denom_info) = analyze_power_step_denominator(graph, denom_vid, numerator_vid)
        else {
            continue;
        };
        if assigned.contains(&denom_info.sqrt_node) {
            continue;
        }
        if assigned.contains(&denom_info.sum_node) {
            continue;
        }
        if assigned.contains(&denom_info.pow_node) {
            continue;
        }
        if let Some(add_id) = denom_info.add_node {
            if assigned.contains(&add_id) {
                continue;
            }
        }
        if denom_info.pow_input != numerator_vid {
            continue;
        }

        let mut nodes = vec![matmul_id, denom_info.pow_node, denom_info.sum_node];
        if let Some(add_id) = denom_info.add_node {
            nodes.push(add_id);
        }
        nodes.push(denom_info.sqrt_node);
        nodes.push(div_node.id);

        for node_id in &nodes {
            if assigned.contains(node_id) {
                continue 'outer;
            }
        }

        nodes.sort_by_key(|node_id| {
            graph
                .node(*node_id)
                .map(|node| node.span.start)
                .unwrap_or(usize::MAX)
        });

        let span = group_span(graph, &nodes);
        let shape = node_output_shape(graph, div_node);

        groups.push(FusionGroup {
            id: *next_group_id,
            kind: FusionKind::PowerStepNormalize,
            nodes: nodes.clone(),
            shape,
            span,
            pattern: Some(FusionPattern::PowerStepNormalize {
                lhs: matmul_node.inputs[0],
                rhs: matmul_node.inputs[1],
                epsilon: denom_info.epsilon,
            }),
        });
        *next_group_id += 1;
        for id in nodes {
            assigned.insert(id);
        }
    }
}

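/// Detects the explained-variance pattern `diag(Q' * G * Q)` (the transpose is
/// accepted on either operand of the inner `mtimes`): a `diag` whose input is
/// a matmul of a matmul involving `transpose(Q)`, `G`, and the same `Q`.
/// Matched nodes are grouped as `FusionKind::ExplainedVariance`.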
fn detect_explained_variance(
    graph: &AccelGraph,
    assigned: &mut HashSet<NodeId>,
    groups: &mut Vec<FusionGroup>,
    next_group_id: &mut usize,
) {
    for diag_node in &graph.nodes {
        if assigned.contains(&diag_node.id) {
            continue;
        }
        match &diag_node.label {
            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("diag") => {}
            _ => continue,
        }
        if diag_node.inputs.len() != 1 {
            continue;
        }
        let matmul2_vid = diag_node.inputs[0];
        let (matmul2_id, matmul2_node) = match node_from_value(graph, matmul2_vid) {
            Some(pair) => pair,
            None => continue,
        };
        if assigned.contains(&matmul2_id) {
            continue;
        }
        match &matmul2_node.label {
            AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mtimes") => {}
            _ => continue,
        }
        if matmul2_node.inputs.len() != 2 {
            continue;
        }

        let (matmul1_id, matmul1_node, q_vid) = if let Some((mm_id, mm_node)) =
            node_from_value(graph, matmul2_node.inputs[0])
        {
            if matches!(mm_node.label, AccelNodeLabel::Builtin { ref name } if name.eq_ignore_ascii_case("mtimes"))
            {
                (mm_id, mm_node, matmul2_node.inputs[1])
            } else {
                continue;
            }
        } else if let Some((mm_id, mm_node)) = node_from_value(graph, matmul2_node.inputs[1]) {
            if matches!(mm_node.label, AccelNodeLabel::Builtin { ref name } if name.eq_ignore_ascii_case("mtimes"))
            {
                (mm_id, mm_node, matmul2_node.inputs[0])
            } else {
                continue;
            }
        } else {
            continue;
        };

        if assigned.contains(&matmul1_id) {
            continue;
        }

        if matmul1_node.inputs.len() != 2 {
            continue;
        }

        let (transpose_id, transpose_input_vid, g_vid) =
            if let Some((t_id, src_vid)) = is_transpose_node(graph, matmul1_node.inputs[0]) {
                (t_id, src_vid, matmul1_node.inputs[1])
            } else if let Some((t_id, src_vid)) = is_transpose_node(graph, matmul1_node.inputs[1]) {
                (t_id, src_vid, matmul1_node.inputs[0])
            } else {
                continue;
            };

        if assigned.contains(&transpose_id) {
            continue;
        }

        if transpose_input_vid != q_vid {
            continue;
        }

        let mut nodes = vec![diag_node.id, matmul2_id, matmul1_id, transpose_id];
        nodes.sort_by_key(|node_id| {
            graph
                .node(*node_id)
                .map(|node| node.span.start)
                .unwrap_or(usize::MAX)
        });
        let span = group_span(graph, &nodes);
        let shape = node_output_shape(graph, diag_node);
        groups.push(FusionGroup {
            id: *next_group_id,
            kind: FusionKind::ExplainedVariance,
            nodes: nodes.clone(),
            shape,
            span,
            pattern: Some(FusionPattern::ExplainedVariance { q: q_vid, g: g_vid }),
        });
        *next_group_id += 1;
        for id in nodes {
            assigned.insert(id);
        }
    }
}

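/// Nodes and constants recovered from a power-step denominator of the form
/// `sqrt(sum(x.^2) + eps)` or `sqrt(sum(x.^2)) + eps`.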
struct PowerStepDenominatorInfo {
    sqrt_node: NodeId,
    add_node: Option<NodeId>,
    sum_node: NodeId,
    pow_node: NodeId,
    pow_input: ValueId,
    epsilon: f64,
}

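/// Walks the denominator of a candidate power-step division and verifies it
/// is `sqrt(sum(x.^2) + eps)` or `sqrt(sum(x.^2)) + eps`, where `x` must be
/// the value identified by `expected_source_vid` (the division's numerator)
/// and the elementwise-power exponent must be the constant 2. Both epsilon
/// placements are reported through the same `epsilon` field.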
fn analyze_power_step_denominator(
    graph: &AccelGraph,
    denom_vid: ValueId,
    expected_source_vid: ValueId,
) -> Option<PowerStepDenominatorInfo> {
    let (sqrt_node_id, sqrt_input_vid, add_node_opt, epsilon) =
        if let Some((sqrt_id, sqrt_in)) = is_sqrt_node(graph, denom_vid) {
            // sqrt(sum + eps): epsilon folded inside the sqrt.
            if let Some((add_node, sum_vid, epsilon_inner)) =
                extract_add_with_constant(graph, sqrt_in)
            {
                (sqrt_id, sum_vid, Some(add_node), epsilon_inner)
            } else {
                (sqrt_id, sqrt_in, None, 0.0)
            }
        } else if let Some((add_node, other_vid, epsilon_inner)) =
            extract_add_with_constant(graph, denom_vid)
        {
            // sqrt(sum) + eps: epsilon added outside the sqrt.
            let (sqrt_id, sqrt_in) = is_sqrt_node(graph, other_vid)?;
            (sqrt_id, sqrt_in, Some(add_node), epsilon_inner)
        } else {
            return None;
        };

    let (sum_node_id, sum_node) = node_from_value(graph, sqrt_input_vid)?;
    match &sum_node.label {
        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("sum") => {}
        _ => return None,
    }
    if sum_node.inputs.is_empty() {
        return None;
    }
    let pow_vid = sum_node.inputs[0];
    let (pow_node_id, pow_node) = node_from_value(graph, pow_vid)?;
    let pow_input = match pow_node.label {
        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow) => {
            if pow_node.inputs.len() != 2 {
                return None;
            }
            let base = pow_node.inputs[0];
            let exponent_vid = pow_node.inputs[1];
            let exponent = value_constant_f64(graph, exponent_vid)?;
            if !approx_eq(exponent, 2.0) {
                return None;
            }
            base
        }
        _ => return None,
    };

    if pow_input != expected_source_vid {
        return None;
    }

    Some(PowerStepDenominatorInfo {
        sqrt_node: sqrt_node_id,
        add_node: add_node_opt,
        sum_node: sum_node_id,
        pow_node: pow_node_id,
        pow_input,
        epsilon,
    })
}

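/// Resolves a value id back to the node that produced it, when the value is a
/// node output (constants and variables yield `None`).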
fn node_from_value(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, &AccelNode)> {
    let info = graph.value(vid)?;
    match info.origin {
        ValueOrigin::NodeOutput { node, .. } => graph.node(node).map(|n| (node, n)),
        _ => None,
    }
}

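/// Returns the producing node and its input when `vid` is the output of a
/// `sqrt` builtin.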
fn is_sqrt_node(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId)> {
    let (node_id, node) = node_from_value(graph, vid)?;
    match &node.label {
        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("sqrt") => {
            let input = node.inputs.first().copied()?;
            Some((node_id, input))
        }
        _ => None,
    }
}

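/// Returns the producing node and its input when `vid` is the output of a
/// transpose primitive.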
fn is_transpose_node(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId)> {
    let (node_id, node) = node_from_value(graph, vid)?;
    match &node.label {
        AccelNodeLabel::Primitive(PrimitiveOp::Transpose) => {
            let input = node.inputs.first().copied()?;
            Some((node_id, input))
        }
        _ => None,
    }
}

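/// Matches `x + c`, `c + x`, `x - c`, or `c - x` where `c` is a scalar
/// constant, returning the add/sub node, the non-constant operand, and the
/// signed constant (`-c` for `x - c`). Note that the `c - x` arm returns `x`
/// with `+c` as written, i.e. it does not account for the sign flip on `x`.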
fn extract_add_with_constant(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId, f64)> {
    let (node_id, node) = node_from_value(graph, vid)?;
    match node.label {
        AccelNodeLabel::Primitive(PrimitiveOp::Add) => {
            if node.inputs.len() != 2 {
                return None;
            }
            let lhs = node.inputs[0];
            let rhs = node.inputs[1];
            if let Some(eps) = value_constant_f64(graph, rhs) {
                return Some((node_id, lhs, eps));
            }
            if let Some(eps) = value_constant_f64(graph, lhs) {
                return Some((node_id, rhs, eps));
            }
            None
        }
        AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
            if node.inputs.len() != 2 {
                return None;
            }
            let lhs = node.inputs[0];
            let rhs = node.inputs[1];
            if let Some(eps) = value_constant_f64(graph, rhs) {
                return Some((node_id, lhs, -eps));
            }
            if let Some(eps) = value_constant_f64(graph, lhs) {
                return Some((node_id, rhs, eps));
            }
            None
        }
        _ => None,
    }
}

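/// A scalar constant recovered by peeling cast and sign nodes, together with
/// the peeled nodes so callers can check them against the assigned set.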
struct ConstantTrace {
    value: f64,
    nodes: Vec<NodeId>,
}

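/// Follows `single`/`double`/`gpuArray` casts and unary `+`/`-` back from
/// `vid` to a constant scalar, accumulating the traversed nodes and the net
/// sign. Returns `None` on any other node kind or on a cycle.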
fn collect_scalar_constant(graph: &AccelGraph, vid: ValueId) -> Option<ConstantTrace> {
    let mut current = vid;
    let mut nodes: Vec<NodeId> = Vec::new();
    let mut sign = 1.0f64;
    let mut visited: HashSet<NodeId> = HashSet::new();

    loop {
        let info = graph.value(current)?;
        match &info.origin {
            ValueOrigin::Constant => {
                let base = value_info_scalar(info)?;
                return Some(ConstantTrace {
                    value: sign * base,
                    nodes,
                });
            }
            ValueOrigin::NodeOutput { node, .. } => {
                if !visited.insert(*node) {
                    return None;
                }
                let node_ref = graph.node(*node)?;
                match &node_ref.label {
                    AccelNodeLabel::Builtin { name }
                        if name.eq_ignore_ascii_case("single")
                            || name.eq_ignore_ascii_case("double")
                            || name.eq_ignore_ascii_case("gpuarray") =>
                    {
                        if node_ref.inputs.len() != 1 {
                            return None;
                        }
                        nodes.push(*node);
                        current = node_ref.inputs[0];
                    }
                    AccelNodeLabel::Primitive(PrimitiveOp::Neg) => {
                        if node_ref.inputs.len() != 1 {
                            return None;
                        }
                        nodes.push(*node);
                        sign = -sign;
                        current = node_ref.inputs[0];
                    }
                    AccelNodeLabel::Primitive(PrimitiveOp::UPlus) => {
                        if node_ref.inputs.len() != 1 {
                            return None;
                        }
                        nodes.push(*node);
                        current = node_ref.inputs[0];
                    }
                    _ => return None,
                }
            }
            _ => return None,
        }
    }
}

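/// True when the shape is provably a scalar: `Scalar`, or a tensor whose
/// dimensions are all known to be 1.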
fn scalar_shape_known_one(shape: &ShapeInfo) -> bool {
    match shape {
        ShapeInfo::Scalar => true,
        ShapeInfo::Tensor(dims) => {
            if dims.is_empty() {
                return true;
            }
            dims.iter().all(|dim| matches!(dim, Some(1)))
        }
        ShapeInfo::Unknown => false,
    }
}

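/// Captures a scalar operand of the image-normalize pattern: either a
/// traceable constant (rejected if any node on its trace is already assigned
/// to another group) or a runtime value whose shape is provably scalar.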
fn capture_image_scalar(
    graph: &AccelGraph,
    vid: ValueId,
    assigned: &HashSet<NodeId>,
    _nodes: &mut Vec<NodeId>,
) -> Option<ImageScalar> {
    if let Some(trace) = collect_scalar_constant(graph, vid) {
        if trace.nodes.iter().any(|id| assigned.contains(id)) {
            return None;
        }
        return Some(ImageScalar::Constant(trace.value));
    }
    let info = graph.value(vid)?;
    if scalar_shape_known_one(&info.shape) {
        return Some(ImageScalar::Value(vid));
    }
    if log::log_enabled!(log::Level::Debug) {
        log::debug!(
            "capture_image_scalar: reject vid={vid:?} shape={:?} origin={:?}",
            info.shape,
            info.origin
        );
    }
    None
}

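/// Skips over `single`/`double`/`gpuArray` cast nodes, returning the
/// underlying value id. Bails out if a cast node is already assigned to
/// another fusion group.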
fn peel_numeric_casts(
    graph: &AccelGraph,
    mut vid: ValueId,
    assigned: &HashSet<NodeId>,
    _nodes: &mut Vec<NodeId>,
) -> Option<ValueId> {
    loop {
        let info = graph.value(vid)?;
        match &info.origin {
            ValueOrigin::NodeOutput { node, .. } => {
                if assigned.contains(node) {
                    return None;
                }
                let node_ref = graph.node(*node)?;
                if let AccelNodeLabel::Builtin { name } = &node_ref.label {
                    if name.eq_ignore_ascii_case("single")
                        || name.eq_ignore_ascii_case("double")
                        || name.eq_ignore_ascii_case("gpuarray")
                    {
                        if node_ref.inputs.len() != 1 {
                            return None;
                        }
                        vid = node_ref.inputs[0];
                        continue;
                    }
                }
                return Some(vid);
            }
            _ => return Some(vid),
        }
    }
}

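/// Convenience wrapper over [`collect_scalar_constant`] that discards the
/// node trace.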
fn resolve_scalar_constant(graph: &AccelGraph, vid: ValueId) -> Option<f64> {
    collect_scalar_constant(graph, vid).map(|trace| trace.value)
}

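/// Extracts an `f64` from a `ValueInfo` constant: numbers, ints, one-element
/// tensors and logical arrays, and bools (as 0.0/1.0).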
fn value_info_scalar(info: &ValueInfo) -> Option<f64> {
    match &info.constant {
        Some(Value::Num(v)) => Some(*v),
        Some(Value::Int(i)) => Some(i.to_f64()),
        Some(Value::Tensor(t)) if t.data.len() == 1 => Some(t.data[0]),
        Some(Value::LogicalArray(arr)) if arr.data.len() == 1 => Some(arr.data[0] as f64),
        Some(Value::Bool(flag)) => Some(if *flag { 1.0 } else { 0.0 }),
        _ => None,
    }
}

fn value_constant_f64(graph: &AccelGraph, vid: ValueId) -> Option<f64> {
    resolve_scalar_constant(graph, vid)
}

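/// Renders a primitive op as a scalar expression string for the generated
/// kernel source, given the already rendered input expressions. Returns
/// `None` for unsupported ops or missing inputs. Note: `ElemLeftDiv` is
/// emitted as `lhs / rhs` like ordinary division, which presumes the graph
/// builder has already ordered the operands numerator-first (left division
/// `a .\ b` equals `b ./ a`).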
fn primitive_expr(
    op: PrimitiveOp,
    inputs: &[ValueId],
    exprs: &HashMap<ValueId, String>,
) -> Option<String> {
    let binary = |exprs: &HashMap<ValueId, String>| -> Option<(String, String)> {
        let lhs = exprs.get(inputs.first()?).cloned()?;
        let rhs = exprs.get(inputs.get(1)?).cloned()?;
        Some((lhs, rhs))
    };
    match op {
        PrimitiveOp::Add => {
            let (lhs, rhs) = binary(exprs)?;
            Some(format!("({lhs} + {rhs})"))
        }
        PrimitiveOp::Sub => {
            let (lhs, rhs) = binary(exprs)?;
            Some(format!("({lhs} - {rhs})"))
        }
        PrimitiveOp::Mul | PrimitiveOp::ElemMul => {
            let (lhs, rhs) = binary(exprs)?;
            Some(format!("({lhs} * {rhs})"))
        }
        PrimitiveOp::Div | PrimitiveOp::ElemDiv | PrimitiveOp::ElemLeftDiv => {
            let (lhs, rhs) = binary(exprs)?;
            Some(format!("({lhs} / {rhs})"))
        }
        PrimitiveOp::Pow | PrimitiveOp::ElemPow => {
            let (lhs, rhs) = binary(exprs)?;
            Some(format!("pow({lhs}, {rhs})"))
        }
        PrimitiveOp::Neg => {
            let arg = exprs.get(inputs.first()?).cloned()?;
            Some(format!("(-{arg})"))
        }
        PrimitiveOp::UPlus => {
            let arg = exprs.get(inputs.first()?).cloned()?;
            Some(format!("(+{arg})"))
        }
        _ => None,
    }
}

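/// Renders a supported builtin as a scalar expression string. Most names map
/// one-to-one onto kernel intrinsics; `log10`, `log1p`, and `expm1` are
/// expanded in terms of `log`/`exp`, and numeric casts (`single`, `double`,
/// `gpuArray`) pass their argument through unchanged.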
fn builtin_expr(
    name: &str,
    inputs: &[ValueId],
    exprs: &HashMap<ValueId, String>,
    scalar_ty: &str,
) -> Option<String> {
    let func = match name.to_ascii_lowercase().as_str() {
        "isfinite" => return builtin_unary_call("isFinite", inputs, exprs),
        "isinf" => return builtin_unary_call("isInf", inputs, exprs),
        "isnan" => return builtin_unary_call("isNan", inputs, exprs),
        "single" | "double" | "gpuarray" => return builtin_identity(inputs, exprs),
        "sin" => "sin",
        "cos" => "cos",
        "tan" => "tan",
        "asin" => "asin",
        "acos" => "acos",
        "atan" => "atan",
        "atan2" => return builtin_binary("atan2", inputs, exprs),
        "sinh" => "sinh",
        "cosh" => "cosh",
        "tanh" => "tanh",
        "exp" => "exp",
        "log" => "log",
        "log2" => "log2",
        "sqrt" => "sqrt",
        "abs" => "abs",
        "exp2" => "exp2",
        "floor" => "floor",
        "ceil" => "ceil",
        "round" => "round",
        "trunc" => "trunc",
        "max" => return builtin_binary("max", inputs, exprs),
        "min" => return builtin_binary("min", inputs, exprs),
        "log10" => {
            // log10(x) = ln(x) * (1 / ln(10))
            let arg = exprs.get(inputs.first()?).cloned()?;
            let constant = cast_literal(scalar_ty, "0.4342944819032518");
            return Some(format!("(log({arg}) * {constant})"));
        }
        "log1p" => {
            let arg = exprs.get(inputs.first()?).cloned()?;
            let one = cast_literal(scalar_ty, "1.0");
            return Some(format!("log({arg} + {one})"));
        }
        "expm1" => {
            let arg = exprs.get(inputs.first()?).cloned()?;
            let one = cast_literal(scalar_ty, "1.0");
            return Some(format!("(exp({arg}) - {one})"));
        }
        _ => return None,
    };
    let arg = exprs.get(inputs.first()?).cloned()?;
    Some(format!("{func}({arg})"))
}

fn builtin_binary(
    func: &str,
    inputs: &[ValueId],
    exprs: &HashMap<ValueId, String>,
) -> Option<String> {
    let lhs = exprs.get(inputs.first()?).cloned()?;
    let rhs = exprs.get(inputs.get(1)?).cloned()?;
    Some(format!("{func}({lhs}, {rhs})"))
}

fn builtin_unary_call(
    func: &str,
    inputs: &[ValueId],
    exprs: &HashMap<ValueId, String>,
) -> Option<String> {
    let arg = exprs.get(inputs.first()?).cloned()?;
    Some(format!("{func}({arg})"))
}

fn builtin_identity(inputs: &[ValueId], exprs: &HashMap<ValueId, String>) -> Option<String> {
    exprs.get(inputs.first()?).cloned()
}

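/// Wraps a numeric literal in an explicit cast for `f64` kernels; for other
/// scalar types the literal is emitted verbatim.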
fn cast_literal(scalar_ty: &str, literal: &str) -> String {
    if scalar_ty == "f64" {
        format!("{scalar_ty}({literal})")
    } else {
        literal.to_string()
    }
}

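/// Matches `x + s` / `s + x` (or `x - c` with a constant `c`, folded as
/// `+(-c)`) where `s` is an image-normalize scalar, returning the node, the
/// tensor operand, and the scalar.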
fn split_add_with_scalar(
    graph: &AccelGraph,
    vid: ValueId,
    assigned: &HashSet<NodeId>,
    nodes: &mut Vec<NodeId>,
) -> Option<(NodeId, ValueId, ImageScalar)> {
    let (node_id, node) = node_from_value(graph, vid)?;
    match node.label {
        AccelNodeLabel::Primitive(PrimitiveOp::Add) => {
            if node.inputs.len() != 2 {
                return None;
            }
            let lhs = node.inputs[0];
            let rhs = node.inputs[1];
            if let Some(scalar) = capture_image_scalar(graph, rhs, assigned, nodes) {
                return Some((node_id, lhs, scalar));
            }
            if let Some(scalar) = capture_image_scalar(graph, lhs, assigned, nodes) {
                return Some((node_id, rhs, scalar));
            }
            None
        }
        AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
            if node.inputs.len() != 2 {
                return None;
            }
            let lhs = node.inputs[0];
            let rhs = node.inputs[1];
            if let Some(ImageScalar::Constant(value)) =
                capture_image_scalar(graph, rhs, assigned, nodes)
            {
                return Some((node_id, lhs, ImageScalar::Constant(-value)));
            }
            None
        }
        _ => None,
    }
}

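/// Matches `x * s` / `s * x` where `s` is an image-normalize scalar,
/// returning the node, the tensor operand, and the scalar.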
fn split_mul_with_scalar(
    graph: &AccelGraph,
    vid: ValueId,
    assigned: &HashSet<NodeId>,
    nodes: &mut Vec<NodeId>,
) -> Option<(NodeId, ValueId, ImageScalar)> {
    let (node_id, node) = node_from_value(graph, vid)?;
    match node.label {
        AccelNodeLabel::Primitive(PrimitiveOp::Mul)
        | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul) => {
            if node.inputs.len() != 2 {
                return None;
            }
            let lhs = node.inputs[0];
            let rhs = node.inputs[1];
            if let Some(scalar) = capture_image_scalar(graph, rhs, assigned, nodes) {
                return Some((node_id, lhs, scalar));
            }
            if let Some(scalar) = capture_image_scalar(graph, lhs, assigned, nodes) {
                return Some((node_id, rhs, scalar));
            }
            None
        }
        _ => None,
    }
}

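/// Matches a two-argument `max` where one operand is the constant zero,
/// i.e. the `max(x, 0)` clamp of the image-normalize pattern, returning the
/// node and the clamped operand.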
fn split_max_with_zero_scalar(
    graph: &AccelGraph,
    vid: ValueId,
    assigned: &HashSet<NodeId>,
    nodes: &mut Vec<NodeId>,
) -> Option<(NodeId, ValueId)> {
    let (node_id, node) = node_from_value(graph, vid)?;
    match &node.label {
        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("max") => {
            if node.inputs.len() != 2 {
                if log::log_enabled!(log::Level::Debug) {
                    log::debug!(
                        "split_max_with_zero_scalar: node {node_id:?} has {} inputs",
                        node.inputs.len()
                    );
                }
                return None;
            }
            let lhs = node.inputs[0];
            let rhs = node.inputs[1];
            if let Some(ImageScalar::Constant(value)) =
                capture_image_scalar(graph, rhs, assigned, nodes)
            {
                if approx_eq(value, 0.0) {
                    if log::log_enabled!(log::Level::Debug) {
                        log::debug!(
                            "split_max_with_zero_scalar: rhs zero constant for node {node_id:?}"
                        );
                    }
                    return Some((node_id, lhs));
                }
            }
            if let Some(ImageScalar::Constant(value)) =
                capture_image_scalar(graph, lhs, assigned, nodes)
            {
                if approx_eq(value, 0.0) {
                    if log::log_enabled!(log::Level::Debug) {
                        log::debug!(
                            "split_max_with_zero_scalar: lhs zero constant for node {node_id:?}"
                        );
                    }
                    return Some((node_id, rhs));
                }
            }
            if log::log_enabled!(log::Level::Debug) {
                log::debug!(
                    "split_max_with_zero_scalar: node {node_id:?} inputs not zero constants"
                );
            }
            None
        }
        _ => None,
    }
}

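/// Resolves a value to a numeric vector constant: scalar constants become
/// one-element vectors; tensor, logical, bool, and int constants are
/// converted elementwise to `f64`.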
fn resolve_numeric_vector_constant(graph: &AccelGraph, vid: ValueId) -> Option<Vec<f64>> {
    if let Some(scalar) = resolve_scalar_constant(graph, vid) {
        return Some(vec![scalar]);
    }
    let info = graph.value(vid)?;
    match &info.constant {
        Some(Value::Tensor(tensor)) if !tensor.data.is_empty() => Some(tensor.data.clone()),
        Some(Value::LogicalArray(arr)) if !arr.data.is_empty() => Some(
            arr.data
                .iter()
                .map(|v| if *v == 0 { 0.0 } else { 1.0 })
                .collect(),
        ),
        Some(Value::Bool(flag)) => Some(vec![if *flag { 1.0 } else { 0.0 }]),
        Some(Value::Int(iv)) => Some(vec![iv.to_f64()]),
        Some(Value::Num(num)) => Some(vec![*num]),
        _ => None,
    }
}

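/// Matches a `mean(data, dims)` builtin whose `dims` argument is a constant
/// numeric vector, returning the node, the data operand, and the axes.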
fn match_mean_axes(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId, Vec<f64>)> {
    let (node_id, node) = node_from_value(graph, vid)?;
    match &node.label {
        AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mean") => {}
        _ => return None,
    }
    if node.inputs.len() < 2 {
        return None;
    }
    let data_vid = node.inputs[0];
    let dims_vid = node.inputs[1];
    let dims = resolve_numeric_vector_constant(graph, dims_vid)?;
    Some((node_id, data_vid, dims))
}

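/// Compares two axis lists as unordered multisets of rounded integers.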
fn dims_match_unordered(found: &[f64], expected: &[f64]) -> bool {
    if found.len() != expected.len() {
        return false;
    }
    let mut a: Vec<i64> = found.iter().map(|d| d.round() as i64).collect();
    let mut b: Vec<i64> = expected.iter().map(|d| d.round() as i64).collect();
    a.sort_unstable();
    b.sort_unstable();
    a == b
}

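/// Peels one or more `mean` reductions off `vid` until the expected axes are
/// consumed, accepting either a single `mean(x, [3 2])` or a nested
/// `mean(mean(x, 3), 2)` spelling. Returns the underlying data value.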
fn peel_mean_dims(
    graph: &AccelGraph,
    vid: ValueId,
    expected_dims: &[f64],
    assigned: &HashSet<NodeId>,
    nodes: &mut Vec<NodeId>,
) -> Option<ValueId> {
    if expected_dims.is_empty() {
        return Some(vid);
    }
    let (node_id, data_vid, dims) = match_mean_axes(graph, vid)?;
    if assigned.contains(&node_id) {
        return None;
    }
    if dims.len() == expected_dims.len() && dims_match_unordered(&dims, expected_dims) {
        nodes.push(node_id);
        return Some(data_vid);
    }
    if dims.len() == 1 && approx_eq(dims[0], expected_dims[0]) {
        nodes.push(node_id);
        return peel_mean_dims(graph, data_vid, &expected_dims[1..], assigned, nodes);
    }
    None
}

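/// Everything recovered from a matched image-normalize subgraph: the nodes to
/// fuse, the input image, and the epsilon/gain/bias/gamma scalars (optional
/// ones are `None` when they equal their identity values).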
struct ImageNormalizeMatch {
    nodes: Vec<NodeId>,
    input: ValueId,
    epsilon: ImageScalar,
    gain: Option<ImageScalar>,
    bias: Option<ImageScalar>,
    gamma: Option<ImageScalar>,
}

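/// Matches the image-normalize pattern rooted at an elementwise power node:
///
/// `max((x - mu) ./ sqrt(mean((x - mu).^2, [3 2]) + eps) .* gain + bias, 0) .^ gamma`
///
/// where `mu = mean(x, [3 2])` (nested `mean` spellings are accepted), the
/// gain and bias stages are optional, and numeric casts may appear between
/// stages. Logs the rejection reason when debug tracing is enabled.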
fn analyze_image_normalize(
    graph: &AccelGraph,
    pow_node_id: NodeId,
    assigned: &HashSet<NodeId>,
) -> Option<ImageNormalizeMatch> {
    let pow_node = graph.node(pow_node_id)?;
    if log::log_enabled!(log::Level::Debug) {
        log::debug!(
            "image_normalize: inspect pow candidate node={pow_node_id:?} label={:?}",
            pow_node.label
        );
    }
    macro_rules! img_norm_fail {
        ($reason:expr) => {{
            if log::log_enabled!(log::Level::Debug) {
                log::debug!(
                    "image_normalize: reject node {pow_node_id:?} reason={}",
                    $reason
                );
            }
            return None;
        }};
    }
    if !matches!(
        pow_node.label,
        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow)
    ) {
        img_norm_fail!("not elem pow");
    }
    if pow_node.inputs.len() != 2 || pow_node.outputs.len() != 1 {
        img_norm_fail!("unexpected pow arity");
    }

    let mut nodes: Vec<NodeId> = vec![pow_node_id];

    let gamma_scalar = capture_image_scalar(graph, pow_node.inputs[1], assigned, &mut nodes)?;
    if log::log_enabled!(log::Level::Debug) {
        log::debug!("image_normalize: node {pow_node_id:?} gamma scalar={gamma_scalar:?}");
    }
    let gamma_opt = match &gamma_scalar {
        ImageScalar::Constant(value) if approx_eq(*value, 1.0) => None,
        _ => Some(gamma_scalar),
    };

    let (clamp_node_id, clamp_input_vid) =
        split_max_with_zero_scalar(graph, pow_node.inputs[0], assigned, &mut nodes)?;
    if assigned.contains(&clamp_node_id) {
        img_norm_fail!("clamp node already assigned");
    }
    nodes.push(clamp_node_id);

    let pre_bias_vid = peel_numeric_casts(graph, clamp_input_vid, assigned, &mut nodes)?;
    let (pre_gain_vid, bias_opt) = if let Some((add_node_id, base_vid, bias_scalar)) =
        split_add_with_scalar(graph, pre_bias_vid, assigned, &mut nodes)
    {
        if assigned.contains(&add_node_id) {
            img_norm_fail!("bias add already assigned");
        }
        nodes.push(add_node_id);
        let bias = match &bias_scalar {
            ImageScalar::Constant(value) if approx_eq(*value, 0.0) => None,
            _ => Some(bias_scalar),
        };
        let base_vid = peel_numeric_casts(graph, base_vid, assigned, &mut nodes)?;
        (base_vid, bias)
    } else {
        (pre_bias_vid, None)
    };

    let (mut norm_vid, gain_opt) = if let Some((mul_node_id, base_vid, gain_scalar)) =
        split_mul_with_scalar(graph, pre_gain_vid, assigned, &mut nodes)
    {
        if assigned.contains(&mul_node_id) {
            img_norm_fail!("gain mul already assigned");
        }
        nodes.push(mul_node_id);
        let gain = match &gain_scalar {
            ImageScalar::Constant(value) if approx_eq(*value, 1.0) => None,
            _ => Some(gain_scalar),
        };
        let base_vid = peel_numeric_casts(graph, base_vid, assigned, &mut nodes)?;
        (base_vid, gain)
    } else {
        (pre_gain_vid, None)
    };

    norm_vid = peel_numeric_casts(graph, norm_vid, assigned, &mut nodes)?;

    let (div_node_id, div_node) = node_from_value(graph, norm_vid)?;
    if assigned.contains(&div_node_id) {
        img_norm_fail!("div node already assigned");
    }
    match div_node.label {
        AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv)
        | AccelNodeLabel::Primitive(PrimitiveOp::Div) => {}
        _ => img_norm_fail!("not div primitive"),
    }
    if div_node.inputs.len() != 2 {
        img_norm_fail!("div arity");
    }

    let diff_vid = div_node.inputs[0];
    let sigma_vid = peel_numeric_casts(graph, div_node.inputs[1], assigned, &mut nodes)?;
    let (sigma_node_id, sigma_input_vid) = match is_sqrt_node(graph, sigma_vid) {
        Some(pair) => pair,
        None => img_norm_fail!("sigma not sqrt"),
    };
    if assigned.contains(&sigma_node_id) {
        img_norm_fail!("sqrt node already assigned");
    }
    nodes.push(div_node_id);
    nodes.push(sigma_node_id);

    let (add_node_id, mean_sq_vid, epsilon) =
        split_add_with_scalar(graph, sigma_input_vid, assigned, &mut nodes)?;
    if assigned.contains(&add_node_id) {
        img_norm_fail!("epsilon add already assigned");
    }
    nodes.push(add_node_id);
    let mean_sq_vid = peel_numeric_casts(graph, mean_sq_vid, assigned, &mut nodes)?;

    let squared_diff_vid = peel_mean_dims(graph, mean_sq_vid, &[3.0, 2.0], assigned, &mut nodes)?;

    let (square_pow_node_id, square_pow_node) = node_from_value(graph, squared_diff_vid)?;
    if assigned.contains(&square_pow_node_id) {
        img_norm_fail!("square pow already assigned");
    }
    if !matches!(
        square_pow_node.label,
        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow)
    ) {
        img_norm_fail!("variance pow not elem pow");
    }
    if square_pow_node.inputs.len() != 2 {
        img_norm_fail!("variance pow arity");
    }
    let exponent_trace = collect_scalar_constant(graph, square_pow_node.inputs[1])?;
    if !approx_eq(exponent_trace.value, 2.0) {
        img_norm_fail!("variance exponent != 2");
    }
    if exponent_trace.nodes.iter().any(|id| assigned.contains(id)) {
        img_norm_fail!("variance exponent nodes already assigned");
    }
    nodes.push(square_pow_node_id);
    nodes.extend(exponent_trace.nodes.iter().copied());

    let diff_var_vid = square_pow_node.inputs[0];
    let (diff_var_node_id, diff_var_node) = node_from_value(graph, diff_var_vid)?;
    if assigned.contains(&diff_var_node_id) {
        img_norm_fail!("diff variance node already assigned");
    }
    if !matches!(
        diff_var_node.label,
        AccelNodeLabel::Primitive(PrimitiveOp::Sub)
    ) {
        img_norm_fail!("diff variance node not sub");
    }
    if diff_var_node.inputs.len() != 2 {
        img_norm_fail!("diff variance arity");
    }
    let imgs_vid = diff_var_node.inputs[0];
    let mu_vid = peel_numeric_casts(graph, diff_var_node.inputs[1], assigned, &mut nodes)?;
    nodes.push(diff_var_node_id);

    let (diff_node_id, diff_node) = node_from_value(graph, diff_vid)?;
    if assigned.contains(&diff_node_id) {
        img_norm_fail!("diff node already assigned");
    }
    if !matches!(diff_node.label, AccelNodeLabel::Primitive(PrimitiveOp::Sub)) {
        img_norm_fail!("diff node not sub");
    }
    if diff_node.inputs.len() != 2 {
        img_norm_fail!("diff node arity");
    }
    let diff_mu_vid = peel_numeric_casts(graph, diff_node.inputs[1], assigned, &mut nodes)?;
    if diff_node.inputs[0] != imgs_vid || diff_mu_vid != mu_vid {
        img_norm_fail!("diff inputs mismatch with variance pair");
    }
    nodes.push(diff_node_id);

    let mean_mu_input_vid = peel_mean_dims(graph, mu_vid, &[3.0, 2.0], assigned, &mut nodes)?;
    if mean_mu_input_vid != imgs_vid {
        img_norm_fail!("mean mu input mismatch");
    }

    let input_info = graph.value(imgs_vid)?;
    match &input_info.shape {
        ShapeInfo::Tensor(dims) if dims.len() >= 2 => {}
        ShapeInfo::Unknown => {}
        other => {
            if log::log_enabled!(log::Level::Debug) {
                log::debug!(
                    "image_normalize: node {pow_node_id:?} input shape {:?}",
                    other
                );
            }
            img_norm_fail!("input not at least a 2-d tensor");
        }
    }

    nodes.sort_unstable();
    nodes.dedup();

    Some(ImageNormalizeMatch {
        nodes,
        input: imgs_vid,
        epsilon,
        gain: gain_opt,
        bias: bias_opt,
        gamma: gamma_opt,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{
        AccelGraph, AccelGraphTag, AccelNode, AccelNodeLabel, AccelOpCategory, InstrSpan,
        PrimitiveOp, ValueId, ValueInfo, ValueOrigin, VarKind,
    };
    use runmat_builtins::{Type, Value};
    use std::collections::HashMap as StdHashMap;

    fn simple_elementwise_graph() -> AccelGraph {
        let values = vec![
            ValueInfo {
                id: 0,
                origin: ValueOrigin::Variable {
                    kind: VarKind::Global,
                    index: 0,
                },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
            ValueInfo {
                id: 1,
                origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
            ValueInfo {
                id: 2,
                origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
                constant: None,
            },
        ];

        let node0 = AccelNode {
            id: 0,
            label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
            category: AccelOpCategory::Elementwise,
            inputs: vec![0, 0],
            outputs: vec![1],
            span: InstrSpan { start: 10, end: 10 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let node1 = AccelNode {
            id: 1,
            label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
            category: AccelOpCategory::Elementwise,
            inputs: vec![1, 0],
            outputs: vec![2],
            span: InstrSpan { start: 11, end: 11 },
            tags: vec![AccelGraphTag::Elementwise],
        };

        AccelGraph {
            nodes: vec![node0, node1],
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        }
    }

    #[test]
    fn detects_chain() {
        let graph = simple_elementwise_graph();
        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);
        let group = &groups[0];
        assert_eq!(group.nodes, vec![0, 1]);
        assert_eq!(group.kind, FusionKind::ElementwiseChain);
    }

    #[test]
    fn builds_plan_and_template() {
        let graph = simple_elementwise_graph();
        let groups = detect_fusion_groups(&graph);
        let plan = FusionPlan::from_graph(&graph, &groups);
        assert_eq!(plan.groups.len(), 1);
        let group_plan = &plan.groups[0];
        assert!(group_plan.kernel.supported);
        let wgsl = group_plan.generate_wgsl("f32").expect("wgsl");
        assert!(wgsl.contains("@compute"));
        assert!(group_plan.group.element_count().is_some());
    }

    #[test]
    fn stack_pattern_tracks_repeated_constants() {
        let values = vec![
            ValueInfo {
                id: 0,
                origin: ValueOrigin::Variable {
                    kind: VarKind::Global,
                    index: 0,
                },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: None,
            },
            ValueInfo {
                id: 1,
                origin: ValueOrigin::Constant,
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: Some(Value::Num(1.0)),
            },
            ValueInfo {
                id: 2,
                origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: None,
            },
            ValueInfo {
                id: 3,
                origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(4)]),
                constant: None,
            },
        ];

        let node0 = AccelNode {
            id: 0,
            label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
            category: AccelOpCategory::Elementwise,
            inputs: vec![0, 1],
            outputs: vec![2],
            span: InstrSpan { start: 5, end: 5 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let node1 = AccelNode {
            id: 1,
            label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
            category: AccelOpCategory::Elementwise,
            inputs: vec![2, 1],
            outputs: vec![3],
            span: InstrSpan { start: 6, end: 6 },
            tags: vec![AccelGraphTag::Elementwise],
        };

        let graph = AccelGraph {
            nodes: vec![node0, node1],
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        };

        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);
        let plan = FusionPlan::from_graph(&graph, &groups);
        let group_plan = &plan.groups[0];
        assert_eq!(group_plan.inputs.len(), 2);
        assert!(group_plan.stack_pattern.is_empty());
        assert!(group_plan.constants.contains_key(&1));
        assert!(group_plan.const_values.contains_key(&1));
    }

    #[test]
    fn builtin_expr_supports_extended_set() {
        let mut exprs: StdHashMap<ValueId, String> = StdHashMap::new();
        exprs.insert(0, "v0".to_string());
        exprs.insert(1, "v1".to_string());

        let log1p = super::builtin_expr("log1p", &[0], &exprs, "f32");
        assert!(log1p.is_some());

        let log10 = super::builtin_expr("log10", &[0], &exprs, "f64");
        assert!(log10.unwrap().contains("log"));

        let expm1 = super::builtin_expr("expm1", &[0], &exprs, "f32");
        assert!(expm1.unwrap().contains("exp"));

        let floor = super::builtin_expr("floor", &[0], &exprs, "f32");
        assert_eq!(floor.unwrap(), "floor(v0)");

        let atan2 = super::builtin_expr("atan2", &[0, 1], &exprs, "f32");
        assert_eq!(atan2.unwrap(), "atan2(v0, v1)");

        let single = super::builtin_expr("single", &[0], &exprs, "f32");
        assert_eq!(single.unwrap(), "v0");

        let double = super::builtin_expr("double", &[0], &exprs, "f64");
        assert_eq!(double.unwrap(), "v0");
    }

    #[test]
    fn fanout_chain_with_casts_supported() {
        let values = vec![
            ValueInfo {
                id: 0,
                origin: ValueOrigin::Variable {
                    kind: VarKind::Global,
                    index: 0,
                },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
            ValueInfo {
                id: 1,
                origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
            ValueInfo {
                id: 2,
                origin: ValueOrigin::Constant,
                ty: Type::Num,
                shape: ShapeInfo::Scalar,
                constant: Some(Value::Num(0.1)),
            },
            ValueInfo {
                id: 3,
                origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
                ty: Type::Num,
                shape: ShapeInfo::Scalar,
                constant: None,
            },
            ValueInfo {
                id: 4,
                origin: ValueOrigin::NodeOutput { node: 2, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
            ValueInfo {
                id: 5,
                origin: ValueOrigin::NodeOutput { node: 3, output: 0 },
                ty: Type::tensor(),
                shape: ShapeInfo::Tensor(vec![Some(8)]),
                constant: None,
            },
        ];

        let tanh_node = AccelNode {
            id: 0,
            label: AccelNodeLabel::Builtin {
                name: "tanh".to_string(),
            },
            category: AccelOpCategory::Elementwise,
            inputs: vec![0],
            outputs: vec![1],
            span: InstrSpan { start: 10, end: 10 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let single_node = AccelNode {
            id: 1,
            label: AccelNodeLabel::Builtin {
                name: "single".to_string(),
            },
            category: AccelOpCategory::Elementwise,
            inputs: vec![2],
            outputs: vec![3],
            span: InstrSpan { start: 11, end: 11 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let mul_node = AccelNode {
            id: 2,
            label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
            category: AccelOpCategory::Elementwise,
            inputs: vec![3, 0],
            outputs: vec![4],
            span: InstrSpan { start: 12, end: 12 },
            tags: vec![AccelGraphTag::Elementwise],
        };
        let add_node = AccelNode {
            id: 3,
            label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
            category: AccelOpCategory::Elementwise,
            inputs: vec![1, 4],
            outputs: vec![5],
            span: InstrSpan { start: 13, end: 13 },
            tags: vec![AccelGraphTag::Elementwise],
        };

        let graph = AccelGraph {
            nodes: vec![tanh_node, single_node, mul_node, add_node],
            values,
            var_bindings: StdHashMap::new(),
            node_bindings: StdHashMap::new(),
        };

        let groups = detect_fusion_groups(&graph);
        assert_eq!(groups.len(), 1);

        let plan = FusionPlan::from_graph(&graph, &groups);
        let group_plan = &plan.groups[0];
        assert!(group_plan.kernel.supported);
        let shader = group_plan.generate_wgsl("f32");
        assert!(shader
            .as_ref()
            .map(|wgsl| wgsl.contains("tanh") && wgsl.contains("output.data"))
            .unwrap_or(false));
    }
}
3361}