1#[cfg(not(target_arch = "wasm32"))]
2use std::cell::RefCell;
3use std::collections::{HashMap, HashSet};
4#[cfg(target_arch = "wasm32")]
5use std::sync::Mutex;
6use std::sync::{Arc, OnceLock, RwLock, Weak};
7
8use once_cell::sync::Lazy;
9use runmat_accelerate_api::ReductionFlavor;
10use runmat_builtins::Value;
11use serde::{Deserialize, Serialize};
12
13use crate::graph::{
14 AccelGraph, AccelNode, AccelNodeLabel, AccelOpCategory, InstrSpan, NodeId, PrimitiveOp,
15 ShapeInfo, ValueId, ValueInfo, ValueOrigin, VarBinding,
16};
17use crate::reduction_meta::{detect_reduction_signature, ReductionAxes, ReductionBehavior};
18use runmat_accelerate_api::CovNormalization;
19
/// Category of a fusion group.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum FusionKind {
    /// A chain of elementwise nodes (two-argument `max`/`min` builtins count as
    /// elementwise during detection).
    ElementwiseChain,
    /// A single reduction node.
    Reduction,
    /// A matmul node followed by a single-consumer chain of elementwise
    /// `Add`/`Sub`/`Mul`/`ElemMul`/`ElemDiv` nodes.
    MatmulEpilogue,
    /// Specialized pattern group; presumably paired with
    /// `FusionPattern::CenteredGram` (detector not shown here).
    CenteredGram,
    /// Specialized pattern group; presumably paired with
    /// `FusionPattern::ImageNormalize` (detector not shown here).
    ImageNormalize,
    /// Specialized pattern group; presumably paired with
    /// `FusionPattern::PowerStepNormalize` (detector not shown here).
    PowerStepNormalize,
    /// Specialized pattern group; presumably paired with
    /// `FusionPattern::ExplainedVariance` (detector not shown here).
    ExplainedVariance,
}
30
/// A set of graph nodes that were detected as fusable into a single unit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionGroup {
    /// Identifier handed out sequentially while groups are detected.
    pub id: usize,
    /// Category of the group.
    pub kind: FusionKind,
    /// Member nodes of the group.
    pub nodes: Vec<NodeId>,
    /// Shape associated with the group (for elementwise chains, the unified
    /// output shape along the chain).
    pub shape: ShapeInfo,
    /// Instruction span covering the member nodes.
    pub span: InstrSpan,
    /// Extra pattern data; the generic detectors in this file set it to `None`.
    pub pattern: Option<FusionPattern>,
    /// Optional stack layout; falls back to `None` when absent during
    /// deserialization (`serde(default)`).
    #[serde(default)]
    pub stack_layout: Option<FusionStackLayout>,
}
42
/// Describes which values a fusion group expects to find on the stack.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FusionStackLayout {
    /// Number of stack operands the group requires.
    pub required_stack_operands: usize,
    /// Mapping from values to their stack offsets.
    pub bindings: Vec<FusionStackValueBinding>,
}
48
/// Associates one graph value with a stack offset.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FusionStackValueBinding {
    /// The graph value being bound.
    pub value_id: ValueId,
    /// Offset of the value within the stack operands
    /// (NOTE(review): direction/origin of the offset is not visible here).
    pub stack_offset: usize,
}
54
/// Operand metadata attached to the specialized (non-generic) fusion kinds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionPattern {
    /// Operands of a centered-Gram group.
    CenteredGram {
        /// The input matrix value.
        matrix: ValueId,
        /// Normalization mode, as defined by `runmat_accelerate_api`.
        normalization: CovNormalization,
    },
    /// Operands of an image-normalize group.
    ImageNormalize(ImageNormalizePattern),
    /// Operands of a power-step-normalize group.
    PowerStepNormalize {
        /// Left-hand operand value.
        lhs: ValueId,
        /// Right-hand operand value.
        rhs: ValueId,
        /// Epsilon constant (presumably a stabilizer for the normalization;
        /// confirm against the detector).
        epsilon: f64,
    },
    /// Operands of an explained-variance group.
    ExplainedVariance {
        /// The `q` operand value.
        q: ValueId,
        /// The `g` operand value.
        g: ValueId,
    },
}
72
/// Operands of an image-normalize fusion pattern.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageNormalizePattern {
    /// The input value being normalized.
    pub input: ValueId,
    /// Epsilon operand (always present).
    pub epsilon: ImageScalar,
    /// Optional gain operand.
    pub gain: Option<ImageScalar>,
    /// Optional bias operand.
    pub bias: Option<ImageScalar>,
    /// Optional gamma operand.
    pub gamma: Option<ImageScalar>,
}
81
/// A scalar operand that is either a literal or a reference to a graph value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImageScalar {
    /// A literal floating-point constant.
    Constant(f64),
    /// A reference to a value in the graph.
    Value(ValueId),
}
87
88pub fn detect_fusion_groups(graph: &AccelGraph) -> Vec<FusionGroup> {
89 if graph.nodes.is_empty() {
90 return Vec::new();
91 }
92
93 let consumer_map = build_consumer_map(graph);
94 let mut assigned: HashSet<NodeId> = HashSet::new();
95 let mut groups = Vec::new();
96 let mut group_id = 0usize;
97
98 detect_image_normalize(graph, &mut assigned, &mut groups, &mut group_id);
99 detect_explained_variance(graph, &mut assigned, &mut groups, &mut group_id);
100 detect_power_step_normalize(graph, &mut assigned, &mut groups, &mut group_id);
101 detect_centered_gram(graph, &mut assigned, &mut groups, &mut group_id);
102
103 for node in &graph.nodes {
104 if assigned.contains(&node.id) {
106 continue;
107 }
108 let elementwise_like = node.is_elementwise() || is_elementwise_max_min(graph, node);
109 if !elementwise_like {
110 continue;
111 }
112 if node.outputs.is_empty() {
113 continue;
114 }
115 let mut current_shape = node_output_shape(graph, node);
116 if matches!(current_shape, ShapeInfo::Unknown | ShapeInfo::Scalar) {
117 continue;
118 }
119 let mut chain: Vec<NodeId> = Vec::new();
120 let mut frontier = node.id;
121 let mut local_seen: HashSet<NodeId> = HashSet::new();
122
123 loop {
124 if !local_seen.insert(frontier) {
125 break;
126 }
127 chain.push(frontier);
128 let next = find_next_elementwise(
129 graph,
130 frontier,
131 &assigned,
132 &local_seen,
133 &consumer_map,
134 ¤t_shape,
135 );
136 match next {
137 Some((next_id, next_shape)) => {
138 frontier = next_id;
139 current_shape = next_shape;
140 }
141 None => break,
142 }
143 }
144
145 if chain.len() > 1 {
146 expand_group_with_fanout(graph, &mut chain, &assigned, &consumer_map);
147 chain.sort_unstable_by_key(|id| {
148 graph
149 .node(*id)
150 .map(|node| node.span.start)
151 .unwrap_or_default()
152 });
153 chain.dedup();
154 for id in &chain {
155 assigned.insert(*id);
156 }
157 let span = group_span(graph, &chain);
158 groups.push(FusionGroup {
159 id: group_id,
160 kind: FusionKind::ElementwiseChain,
161 nodes: chain,
162 shape: current_shape.clone(),
163 span,
164 pattern: None,
165 stack_layout: None,
166 });
167 group_id += 1;
168 }
169 }
170
171 for node in &graph.nodes {
173 if assigned.contains(&node.id) {
174 continue;
175 }
176 if !node.is_reduction() || is_elementwise_max_min(graph, node) {
177 continue;
178 }
179 let span = InstrSpan {
180 start: node.span.start,
181 end: node.span.end,
182 };
183 groups.push(FusionGroup {
184 id: group_id,
185 kind: FusionKind::Reduction,
186 nodes: vec![node.id],
187 shape: node_output_shape(graph, node),
188 span,
189 pattern: None,
190 stack_layout: None,
191 });
192 group_id += 1;
193 }
194
195 for node in &graph.nodes {
197 if node.category != AccelOpCategory::MatMul || assigned.contains(&node.id) {
198 continue;
199 }
200 if node.outputs.is_empty() {
201 continue;
202 }
203 let mut chain: Vec<NodeId> = vec![node.id];
205 let mut frontier = node.id;
206 let mut ok = false;
207 loop {
208 let mut next_id_opt: Option<NodeId> = None;
210 for &out in &graph.node(frontier).unwrap().outputs {
211 if let Some(cons) = consumer_map.get(&out) {
212 if cons.len() == 1 {
213 next_id_opt = cons.iter().copied().next();
214 } else {
215 next_id_opt = None;
216 }
217 }
218 }
219 let Some(next_id) = next_id_opt else { break };
220 let next = graph.node(next_id).unwrap();
221 if !next.is_elementwise() {
222 break;
223 }
224 let allowed = matches!(
226 next.label,
227 AccelNodeLabel::Primitive(PrimitiveOp::Add)
228 | AccelNodeLabel::Primitive(PrimitiveOp::Sub)
229 | AccelNodeLabel::Primitive(PrimitiveOp::Mul)
230 | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul)
231 | AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv)
232 );
233 if !allowed {
234 break;
235 }
236 chain.push(next_id);
237 frontier = next_id;
238 ok = true;
239 }
240 if ok {
241 for id in &chain {
242 assigned.insert(*id);
243 }
244 let span = group_span(graph, &chain);
245 groups.push(FusionGroup {
246 id: group_id,
247 kind: FusionKind::MatmulEpilogue,
248 nodes: chain,
249 shape: node_output_shape(graph, node),
250 span,
251 pattern: None,
252 stack_layout: None,
253 });
254 group_id += 1;
255 }
256 }
257
258 merge_downstream_fanout(graph, &mut groups, &consumer_map);
259 groups
260}
261
262fn expand_group_with_fanout(
263 graph: &AccelGraph,
264 chain: &mut Vec<NodeId>,
265 assigned: &HashSet<NodeId>,
266 consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
267) {
268 let base_start = chain
269 .iter()
270 .filter_map(|id| graph.node(*id).map(|node| node.span.start))
271 .min()
272 .unwrap_or(0);
273 let mut node_set: HashSet<NodeId> = chain.iter().copied().collect();
274 let mut changed = true;
275 while changed {
276 changed = false;
277 for node in &graph.nodes {
278 if node_set.contains(&node.id) {
279 continue;
280 }
281 if node.span.start < base_start {
282 continue;
283 }
284 if assigned.contains(&node.id) {
285 continue;
286 }
287 if !(node.is_elementwise() || is_elementwise_max_min(graph, node)) {
288 continue;
289 }
290 if node.outputs.is_empty() {
291 continue;
292 }
293 let mut feeds_group = false;
294 let mut all_consumers_ok = true;
295 for &out in &node.outputs {
296 if let Some(consumers) = consumer_map.get(&out) {
297 let mut consumer_in_group = false;
298 for consumer in consumers {
299 if node_set.contains(consumer) {
300 consumer_in_group = true;
301 } else {
302 all_consumers_ok = false;
303 break;
304 }
305 }
306 if !all_consumers_ok {
307 break;
308 }
309 if consumer_in_group {
310 feeds_group = true;
311 }
312 } else {
313 all_consumers_ok = false;
314 break;
315 }
316 }
317 if !feeds_group || !all_consumers_ok {
318 continue;
319 }
320 let mut inputs_ok = true;
321 for &input in &node.inputs {
322 if let Some(info) = graph.value(input) {
323 if let ValueOrigin::NodeOutput { node: producer, .. } = info.origin {
324 if !node_set.contains(&producer) {
325 if let Some(prod_node) = graph.node(producer) {
326 if prod_node.span.start >= base_start {
327 inputs_ok = false;
328 break;
329 }
330 } else {
331 inputs_ok = false;
332 break;
333 }
334 }
335 }
336 }
337 }
338 if inputs_ok {
339 node_set.insert(node.id);
340 chain.push(node.id);
341 changed = true;
342 }
343 }
344 }
345}
346
347fn build_consumer_map(graph: &AccelGraph) -> HashMap<ValueId, HashSet<NodeId>> {
348 let mut map: HashMap<ValueId, HashSet<NodeId>> = HashMap::new();
349 for node in &graph.nodes {
350 for &input in &node.inputs {
351 if let Some(value) = graph.value(input) {
352 if matches!(value.origin, crate::graph::ValueOrigin::NodeOutput { .. }) {
353 map.entry(input).or_default().insert(node.id);
354 }
355 }
356 }
357 }
358 map
359}
360
/// Merges elementwise groups into the elementwise groups that consume them.
///
/// For each elementwise "target" group, any other elementwise "source" group
/// that produces one of the target's inputs is absorbed into the target when
/// the source does not start before the target and all consumers of the
/// source's outputs lie inside the source or the target. After a merge the
/// target's nodes are re-sorted by span start, its span is recomputed, emptied
/// groups are removed, and the whole scan restarts until no merge happens.
///
/// Group `id` and `shape` fields are left as they were.
fn merge_downstream_fanout(
    graph: &AccelGraph,
    groups: &mut Vec<FusionGroup>,
    consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
) {
    let mut changed = true;
    while changed {
        changed = false;
        // Rebuilt every round because group indices shift after `retain`.
        let mut node_group: HashMap<NodeId, usize> = HashMap::new();
        for (idx, group) in groups.iter().enumerate() {
            if group.kind.is_elementwise() {
                for &node in &group.nodes {
                    node_group.insert(node, idx);
                }
            }
        }
        'outer: for target_idx in 0..groups.len() {
            if !groups[target_idx].kind.is_elementwise() {
                continue;
            }
            let base_start = groups[target_idx].span.start;
            // Indices of source groups to fold into this target.
            let mut merge_indices: Vec<usize> = Vec::new();
            for &node_id in &groups[target_idx].nodes {
                let Some(node) = graph.node(node_id) else {
                    continue;
                };
                for &input in &node.inputs {
                    if let Some(info) = graph.value(input) {
                        if let ValueOrigin::NodeOutput { node: producer, .. } = info.origin {
                            if let Some(&source_idx) = node_group.get(&producer) {
                                if source_idx == target_idx {
                                    continue;
                                }
                                let source_group = &groups[source_idx];
                                if !source_group.kind.is_elementwise() {
                                    continue;
                                }
                                // Sources that begin before the target are left alone.
                                if source_group.span.start < base_start {
                                    continue;
                                }
                                // The source may only feed itself or the target.
                                if !group_consumers_subset(
                                    source_group,
                                    target_idx,
                                    groups,
                                    consumer_map,
                                    graph,
                                ) {
                                    continue;
                                }
                                merge_indices.push(source_idx);
                            }
                        }
                    }
                }
            }
            if merge_indices.is_empty() {
                continue;
            }
            merge_indices.sort_unstable();
            merge_indices.dedup();
            for idx in &merge_indices {
                let nodes = groups[*idx].nodes.clone();
                groups[target_idx].nodes.extend(nodes);
                // Emptied groups are dropped by the `retain` below.
                groups[*idx].nodes.clear();
            }
            groups[target_idx]
                .nodes
                .sort_unstable_by_key(|id| graph.node(*id).map(|n| n.span.start).unwrap_or(0));
            groups[target_idx].nodes.dedup();
            groups[target_idx].span = group_span(graph, &groups[target_idx].nodes);
            changed = true;
            // Restart: `node_group` and the indices are stale after a merge.
            break 'outer;
        }
        if changed {
            groups.retain(|group| !group.nodes.is_empty());
        }
    }
}
439
440fn group_consumers_subset(
441 source_group: &FusionGroup,
442 target_idx: usize,
443 groups: &[FusionGroup],
444 consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
445 graph: &AccelGraph,
446) -> bool {
447 let target_nodes: HashSet<NodeId> = groups[target_idx].nodes.iter().copied().collect();
448 let source_nodes: HashSet<NodeId> = source_group.nodes.iter().copied().collect();
449 for &node_id in &source_group.nodes {
450 let Some(node) = graph.node(node_id) else {
451 continue;
452 };
453 for &out in &node.outputs {
454 if let Some(consumers) = consumer_map.get(&out) {
455 for consumer in consumers {
456 if !source_nodes.contains(consumer) && !target_nodes.contains(consumer) {
457 return false;
458 }
459 }
460 }
461 }
462 }
463 true
464}
465
466fn node_output_shape(graph: &AccelGraph, node: &AccelNode) -> ShapeInfo {
467 let mut shape = ShapeInfo::Scalar;
468 for &output in &node.outputs {
469 if let Some(info) = graph.value(output) {
470 shape = shape.unify(&info.shape);
471 }
472 }
473 shape
474}
475
476fn find_next_elementwise(
477 graph: &AccelGraph,
478 node_id: NodeId,
479 assigned: &HashSet<NodeId>,
480 local_seen: &HashSet<NodeId>,
481 consumer_map: &HashMap<ValueId, HashSet<NodeId>>,
482 current_shape: &ShapeInfo,
483) -> Option<(NodeId, ShapeInfo)> {
484 let node = graph.node(node_id)?;
485 let mut candidate: Option<(NodeId, ShapeInfo)> = None;
486
487 for &output in &node.outputs {
488 let consumers = consumer_map.get(&output)?;
489 if consumers.len() != 1 {
490 return None;
491 }
492 let next_id = *consumers.iter().next()?;
493 if next_id <= node_id || assigned.contains(&next_id) || local_seen.contains(&next_id) {
494 return None;
495 }
496 let next_node = graph.node(next_id)?;
497 if !(next_node.is_elementwise() || is_elementwise_max_min(graph, next_node)) {
498 return None;
499 }
500 if !next_node.inputs.contains(&output) {
502 continue;
503 }
504 let next_shape = node_output_shape(graph, next_node);
505 if matches!(next_shape, ShapeInfo::Unknown) {
506 return None;
507 }
508 let unified = current_shape.unify(&next_shape);
509 if matches!(unified, ShapeInfo::Unknown) {
510 return None;
511 }
512 candidate = Some((next_id, unified));
513 break;
514 }
515
516 candidate
517}
518
519fn is_elementwise_max_min(graph: &AccelGraph, node: &AccelNode) -> bool {
520 match &node.label {
521 AccelNodeLabel::Builtin { name }
522 if name.eq_ignore_ascii_case("max") || name.eq_ignore_ascii_case("min") =>
523 {
524 if node.inputs.len() < 2 {
525 return false;
526 }
527 !value_is_placeholder(graph, node.inputs[1])
528 }
529 _ => false,
530 }
531}
532
533fn value_is_placeholder(graph: &AccelGraph, vid: ValueId) -> bool {
534 let Some(info) = graph.value(vid) else {
535 return false;
536 };
537 let Some(constant) = &info.constant else {
538 return false;
539 };
540 match constant {
541 Value::Tensor(t) => t.data.is_empty(),
542 Value::LogicalArray(l) => l.data.is_empty(),
543 Value::StringArray(sa) => sa.data.is_empty(),
544 Value::CharArray(ca) => ca.data.is_empty(),
545 Value::Cell(cell) => cell.data.is_empty(),
546 Value::String(s) => s.is_empty(),
547 _ => false,
548 }
549}
550
551fn group_span(graph: &AccelGraph, nodes: &[NodeId]) -> InstrSpan {
552 let mut start = usize::MAX;
553 let mut end = 0usize;
554 for &id in nodes {
555 if let Some(node) = graph.node(id) {
556 start = start.min(node.span.start);
557 end = end.max(node.span.end);
558 }
559 }
560 if start == usize::MAX {
561 start = 0;
562 }
563 InstrSpan { start, end }
564}
565
566fn merge_stack_layout_with_stack_pattern(
567 existing: Option<&FusionStackLayout>,
568 inputs: &[ValueId],
569 stack_pattern: &[usize],
570) -> Option<FusionStackLayout> {
571 if existing.is_none() && stack_pattern.is_empty() {
572 return None;
573 }
574
575 let mut bindings = existing
576 .map(|layout| layout.bindings.clone())
577 .unwrap_or_default();
578 for (stack_offset, &input_idx) in stack_pattern.iter().enumerate() {
579 let &value_id = inputs.get(input_idx)?;
580 if bindings.iter().any(|binding| binding.value_id == value_id) {
581 continue;
582 }
583 bindings.push(FusionStackValueBinding {
584 value_id,
585 stack_offset,
586 });
587 }
588
589 let required_stack_operands = existing
590 .map(|layout| layout.required_stack_operands)
591 .unwrap_or(0)
592 .max(stack_pattern.len());
593
594 Some(FusionStackLayout {
595 required_stack_operands,
596 bindings,
597 })
598}
599
/// The per-group execution plans derived from a graph's fusion groups.
#[derive(Debug, Clone)]
pub struct FusionPlan {
    /// One plan per fusion group.
    pub groups: Vec<FusionGroupPlan>,
}
604
/// Execution plan for a single fusion group.
#[derive(Debug, Clone)]
pub struct FusionGroupPlan {
    /// Index assigned at construction; `FusionPlan::from_graph` passes the
    /// group's position in the plan.
    pub index: usize,
    /// The fusion group this plan was built from.
    pub group: FusionGroup,
    /// Operations to execute, derived from the group's nodes.
    pub operations: Vec<FusionOp>,
    /// Values the group reads from outside itself.
    pub inputs: Vec<ValueId>,
    /// Indices into `inputs` selected as stack operands.
    pub stack_pattern: Vec<usize>,
    /// Constant input values keyed by a `usize` slot (for non-reduction groups
    /// this is the index into `inputs`).
    pub constants: HashMap<usize, Value>,
    /// Constant values keyed by the value id they belong to.
    pub const_values: HashMap<ValueId, Value>,
    /// Store materializations attached to this plan (initially empty).
    pub materialized_stores: Vec<FusionStoreMaterialization>,
    /// First output of the last group node that has outputs, if any.
    pub output: Option<ValueId>,
    /// Kernel descriptor for the group.
    pub kernel: FusionKernelSpec,
    /// Data operand of the reduction, when a reduction signature was detected.
    pub reduction_data: Option<ValueId>,
    /// Dimension argument of the reduction, when present in the signature.
    pub reduction_dim: Option<ValueId>,
    /// Reduction flavor (mean-like behavior maps to `Mean`, otherwise `Sum`).
    pub reduction_flavor: Option<ReductionFlavor>,
    /// Axes of the detected reduction signature.
    pub reduction_axes: Option<ReductionAxes>,
    /// Pattern metadata copied from the group.
    pub pattern: Option<FusionPattern>,
}
627
/// Pairs a value with the variable binding it is associated with
/// (NOTE(review): presumably the variable the value must be stored to —
/// confirm against the code that fills `materialized_stores`).
#[derive(Debug, Clone)]
pub struct FusionStoreMaterialization {
    /// The value involved.
    pub value_id: ValueId,
    /// The variable binding involved.
    pub binding: VarBinding,
}
633
/// A single operation inside a fusion group plan.
#[derive(Debug, Clone)]
pub enum FusionOp {
    /// A primitive operator applied to `inputs`.
    Primitive {
        /// The primitive operator.
        op: PrimitiveOp,
        /// Operand values, in node input order.
        inputs: Vec<ValueId>,
        /// The node's first output, if it has one.
        output: Option<ValueId>,
    },
    /// A named builtin applied to `inputs`.
    Builtin {
        /// Name of the builtin.
        name: String,
        /// Operand values, in node input order.
        inputs: Vec<ValueId>,
        /// The node's first output, if it has one.
        output: Option<ValueId>,
    },
}
647
/// Describes the kernel of a fusion group plan.
#[derive(Debug, Clone)]
pub struct FusionKernelSpec {
    /// Kind of the fusion group the kernel belongs to.
    pub kind: FusionKind,
    /// Whether the kernel is considered supported.
    pub supported: bool,
}
653
impl FusionKernelSpec {
    /// Creates a kernel spec from its two fields.
    fn new(kind: FusionKind, supported: bool) -> Self {
        Self { kind, supported }
    }
}
659
/// Snapshot of the fusion group that is active for the current program counter.
#[derive(Clone, Debug)]
pub struct ActiveFusion {
    /// Kind of the active group.
    pub kind: FusionKind,
    /// Instruction span of the active group.
    pub span: InstrSpan,
    /// Element count reported by the group plan, if known.
    pub element_count: Option<usize>,
    /// Whether the group's kernel is marked as supported.
    pub supported: bool,
}
667
/// The currently activated plan together with the group selected by the last
/// `set_current_pc` call.
struct ActiveContext {
    /// The activated fusion plan.
    plan: Arc<FusionPlan>,
    /// Group selected for the current program counter, if any.
    active_group: Option<usize>,
}
672
/// Cache of fusion plans keyed by the address of the `AccelGraph` they were
/// built from. Entries are weak, so the cache does not keep plans alive.
static PLAN_CACHE: Lazy<RwLock<HashMap<usize, Weak<FusionPlan>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
675
// Storage for the active fusion context: one slot per thread on non-wasm
// targets, a single mutex-guarded slot on wasm32.
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    static ACTIVE_PLAN: RefCell<Option<ActiveContext>> = const { RefCell::new(None) };
}
#[cfg(target_arch = "wasm32")]
static ACTIVE_PLAN: Lazy<Mutex<Option<ActiveContext>>> = Lazy::new(|| Mutex::new(None));
682
683#[cfg(not(target_arch = "wasm32"))]
684fn with_active_context<R>(f: impl FnOnce(&mut Option<ActiveContext>) -> R) -> R {
685 ACTIVE_PLAN.with(|ctx| {
686 let mut slot = ctx.borrow_mut();
687 f(&mut slot)
688 })
689}
690
691#[cfg(target_arch = "wasm32")]
692fn with_active_context<R>(f: impl FnOnce(&mut Option<ActiveContext>) -> R) -> R {
693 let mut slot = ACTIVE_PLAN.lock().expect("active plan mutex poisoned");
694 f(&mut slot)
695}
696
/// Reports whether fusion debug logging is enabled.
///
/// The `RUNMAT_DEBUG_FUSION` environment variable is read once, on the first
/// call; the result is cached for the lifetime of the process. The flag is on
/// when the variable equals `1`, or `true`/`yes` in any ASCII case.
fn fusion_debug_enabled() -> bool {
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("RUNMAT_DEBUG_FUSION").is_ok_and(|raw| {
            raw == "1"
                || ["true", "yes"]
                    .iter()
                    .any(|&truthy| raw.eq_ignore_ascii_case(truthy))
        })
    })
}
704
705pub fn prepare_fusion_plan(
706 graph: Option<&AccelGraph>,
707 groups: &[FusionGroup],
708 candidate_group_count: usize,
709) -> Option<Arc<FusionPlan>> {
710 let graph = graph?;
711 if candidate_group_count == 0 {
712 if !groups.is_empty() && fusion_debug_enabled() {
713 log::debug!(
714 "fusion plan preparation: executable bytecode fusion groups present ({}) but semantic candidate groups are absent",
715 groups.len()
716 );
717 }
718 return None;
719 }
720 if groups.is_empty() {
721 if candidate_group_count > 0 && fusion_debug_enabled() {
722 log::debug!(
723 "fusion plan preparation: semantic candidate groups present ({}) but executable bytecode fusion groups are empty",
724 candidate_group_count
725 );
726 }
727 return None;
728 }
729 let groups = sanitize_runtime_groups(graph, groups);
730 if groups.is_empty() {
731 if fusion_debug_enabled() {
732 log::debug!(
733 "fusion plan preparation: semantic-gated bytecode groups could not be reconciled against runtime accel graph nodes"
734 );
735 }
736 return None;
737 }
738 let key = graph as *const AccelGraph as usize;
739 if let Some(plan) = PLAN_CACHE
740 .read()
741 .ok()
742 .and_then(|guard| guard.get(&key).and_then(|weak| weak.upgrade()))
743 {
744 return Some(plan);
745 }
746
747 let plan = FusionPlan::from_graph(graph, &groups);
748 let plan = Arc::new(plan);
749 if let Ok(mut guard) = PLAN_CACHE.write() {
750 guard.insert(key, Arc::downgrade(&plan));
751 }
752 Some(plan)
753}
754
755fn sanitize_runtime_groups(graph: &AccelGraph, groups: &[FusionGroup]) -> Vec<FusionGroup> {
756 groups
757 .iter()
758 .filter_map(|group| {
759 let had_explicit_mapped_nodes = !group.nodes.is_empty();
760 let mut sanitized = group.clone();
761 sanitized.nodes.retain(|id| {
762 graph
763 .node(*id)
764 .map(|node| {
765 node_matches_runtime_group_kind(graph, node, &sanitized.kind)
766 && node_within_group_span(node, &sanitized.span)
767 })
768 .unwrap_or(false)
769 });
770 if sanitized.nodes.is_empty() && !had_explicit_mapped_nodes {
771 sanitized.nodes = graph
772 .nodes
773 .iter()
774 .filter(|node| {
775 node_matches_runtime_group_kind(graph, node, &sanitized.kind)
776 && node_within_group_span(node, &sanitized.span)
777 })
778 .map(|node| node.id)
779 .collect();
780 }
781 sanitized.nodes.sort_unstable_by_key(|node_id| {
782 graph
783 .node(*node_id)
784 .map(|node| (node.span.start, node.span.end, node.id))
785 .unwrap_or((usize::MAX, usize::MAX, *node_id))
786 });
787 sanitized.nodes.dedup();
788 if sanitized.nodes.is_empty() {
789 None
790 } else {
791 Some(sanitized)
792 }
793 })
794 .collect()
795}
796
797fn node_matches_runtime_group_kind(
798 graph: &AccelGraph,
799 node: &AccelNode,
800 kind: &FusionKind,
801) -> bool {
802 match kind {
803 FusionKind::ElementwiseChain => {
804 node.is_elementwise()
805 || node.category == AccelOpCategory::Transpose
806 || is_elementwise_max_min(graph, node)
807 }
808 FusionKind::Reduction => node.is_reduction(),
809 FusionKind::MatmulEpilogue => {
810 node.category == AccelOpCategory::MatMul
811 || node.is_elementwise()
812 || node.category == AccelOpCategory::Transpose
813 }
814 FusionKind::CenteredGram
815 | FusionKind::ImageNormalize
816 | FusionKind::PowerStepNormalize
817 | FusionKind::ExplainedVariance => true,
818 }
819}
820
821fn node_within_group_span(node: &AccelNode, span: &InstrSpan) -> bool {
822 node.span.start >= span.start && node.span.end <= span.end
823}
824
825pub fn activate_fusion_plan(plan: Option<Arc<FusionPlan>>) {
826 with_active_context(|slot| {
827 *slot = plan.map(|plan| ActiveContext {
828 plan,
829 active_group: None,
830 });
831 });
832}
833
834pub fn deactivate_fusion_plan() {
835 with_active_context(|slot| {
836 slot.take();
837 });
838}
839
840pub fn set_current_pc(pc: usize) {
841 with_active_context(|slot| {
842 if let Some(context) = slot.as_mut() {
843 context.active_group = context.plan.group_for_pc(pc);
844 }
845 });
846}
847
848pub fn active_fusion() -> Option<ActiveFusion> {
849 with_active_context(|slot| {
850 slot.as_ref()
851 .and_then(|context| {
852 context
853 .active_group
854 .and_then(|idx| context.plan.groups.get(idx))
855 })
856 .map(|plan| ActiveFusion {
857 kind: plan.group.kind.clone(),
858 span: plan.group.span.clone(),
859 element_count: plan.element_count(),
860 supported: plan.kernel.supported,
861 })
862 })
863}
864
865pub fn active_group_plan_clone() -> Option<FusionGroupPlan> {
866 with_active_context(|slot| {
867 slot.as_ref().and_then(|context| {
868 context
869 .active_group
870 .and_then(|idx| context.plan.groups.get(idx).cloned())
871 })
872 })
873}
874
875impl FusionPlan {
876 pub fn from_graph(graph: &AccelGraph, groups: &[FusionGroup]) -> Self {
877 let plans = groups
878 .iter()
879 .enumerate()
880 .map(|(idx, group)| FusionGroupPlan::new(idx, group.clone(), graph))
881 .collect();
882 Self { groups: plans }
883 }
884
885 fn group_for_pc(&self, pc: usize) -> Option<usize> {
886 self.groups
887 .iter()
888 .find(|plan| pc >= plan.group.span.start && pc <= plan.group.span.end)
889 .map(|plan| plan.index)
890 }
891}
892
/// Wraps an already-built list of group plans without modifying it; each
/// plan's `index` field is taken as given.
impl From<Vec<FusionGroupPlan>> for FusionPlan {
    fn from(groups: Vec<FusionGroupPlan>) -> Self {
        Self { groups }
    }
}
898
899fn log_plan_stack_pattern(stage: &str, plan: &FusionGroupPlan, graph: &AccelGraph) {
900 if !fusion_debug_enabled() || plan.stack_pattern.is_empty() {
901 return;
902 }
903 let mut pattern_meta: Vec<String> = Vec::with_capacity(plan.stack_pattern.len());
904 for (pos, input_idx) in plan.stack_pattern.iter().enumerate() {
905 let value_id = plan.inputs.get(*input_idx).copied();
906 if let Some(vid) = value_id {
907 if let Some(info) = graph.value(vid) {
908 let node_label = match info.origin {
909 ValueOrigin::NodeOutput { node, .. } => graph
910 .node(node)
911 .map(|n| format!("{:?}", n.label))
912 .unwrap_or_else(|| "<missing-node>".to_string()),
913 _ => String::new(),
914 };
915 pattern_meta.push(format!(
916 "#{}:input_idx={} vid={} origin={:?} label={}",
917 pos, input_idx, vid, info.origin, node_label
918 ));
919 } else {
920 pattern_meta.push(format!(
921 "#{}:input_idx={} vid={} origin=<missing>",
922 pos, input_idx, vid
923 ));
924 }
925 } else {
926 pattern_meta.push(format!("#{}:input_idx={} vid=<missing>", pos, input_idx));
927 }
928 }
929 log::trace!(
930 "fusion plan {} {} stack_pattern={:?} meta={:?}",
931 plan.index,
932 stage,
933 plan.stack_pattern,
934 pattern_meta
935 );
936}
937
938impl FusionGroupPlan {
939 fn new(index: usize, group: FusionGroup, graph: &AccelGraph) -> Self {
940 let node_set: HashSet<NodeId> = group.nodes.iter().copied().collect();
941 let mut seen_inputs: HashMap<ValueId, usize> = HashMap::new();
942 let mut inputs: Vec<ValueId> = Vec::new();
943 let mut stack_pattern: Vec<usize> = Vec::new();
944 let mut constants: HashMap<usize, Value> = HashMap::new();
945 let const_values: HashMap<ValueId, Value> = HashMap::new();
946 let mut operations = Vec::new();
947 let mut reduction_flavor: Option<ReductionFlavor> = None;
948 let mut reduction_axes: Option<ReductionAxes> = None;
949 let mut reduction_data: Option<ValueId> = None;
950 let mut reduction_dim: Option<ValueId> = None;
951 let mut output: Option<ValueId> = None;
952
953 let is_reduction_group = group.kind.is_reduction();
954 for node_id in &group.nodes {
955 let Some(node) = graph.node(*node_id) else {
956 continue;
957 };
958 for input in &node.inputs {
959 let binding = graph.var_binding(*input);
960 let (external, is_variable, maybe_constant) = match graph.value(*input) {
961 Some(info) => match &info.origin {
962 ValueOrigin::NodeOutput { node: origin, .. }
963 if node_set.contains(origin) =>
964 {
965 (false, false, None)
966 }
967 ValueOrigin::Variable { .. } => (true, true, None),
968 ValueOrigin::NodeOutput { .. } if binding.is_some() => (true, true, None),
969 ValueOrigin::Constant => (true, false, info.constant.clone()),
970 _ => (true, false, None),
971 },
972 None => (true, false, None),
973 };
974 if external {
975 if is_reduction_group {
978 if let Some(constant) = maybe_constant.clone() {
979 let key = constants.len() + 1000;
981 constants.insert(key, constant);
982 continue;
983 }
984 if let Some(data_id) = reduction_data {
986 if *input != data_id {
987 continue;
989 }
990 }
991 }
992
993 let mut newly_added = false;
994 let input_idx = if let Some(idx) = seen_inputs.get(input) {
995 *idx
996 } else {
997 let idx = inputs.len();
998 inputs.push(*input);
999 seen_inputs.insert(*input, idx);
1000 newly_added = true;
1001 idx
1002 };
1003
1004 if fusion_debug_enabled() {
1005 let origin = graph.value(*input).map(|v| v.origin.clone());
1006 log::trace!(
1007 "fusion plan #{:?} consider input vid={} origin={:?} binding={:?} newly_added={} is_variable={} stack_candidate={}",
1008 index,
1009 input,
1010 origin,
1011 binding,
1012 newly_added,
1013 is_variable,
1014 !is_variable && newly_added
1015 );
1016 }
1017 if let Some(constant) = maybe_constant.clone() {
1018 constants.insert(input_idx, constant);
1019 } else if !is_variable && newly_added {
1020 let allow_stack = match graph.value(*input) {
1021 Some(info) => match info.origin {
1022 ValueOrigin::NodeOutput { node, .. } => graph
1023 .node(node)
1024 .map(|n| n.span.start <= group.span.start)
1025 .unwrap_or(false),
1026 _ => true,
1027 },
1028 None => true,
1029 };
1030 if allow_stack {
1031 stack_pattern.push(input_idx);
1032 } else if fusion_debug_enabled() {
1033 log::trace!(
1034 "fusion plan {} skipping stack candidate vid={} origin_after_span",
1035 index,
1036 input
1037 );
1038 }
1039 } else if !is_variable
1040 && !newly_added
1041 && matches!(
1042 graph.value(*input).map(|v| &v.origin),
1043 Some(ValueOrigin::Constant)
1044 )
1045 {
1046 }
1047 }
1048 }
1049
1050 let op = match &node.label {
1051 AccelNodeLabel::Primitive(p) => FusionOp::Primitive {
1052 op: *p,
1053 inputs: node.inputs.clone(),
1054 output: node.outputs.first().copied(),
1055 },
1056 AccelNodeLabel::Builtin { name } => FusionOp::Builtin {
1057 name: name.clone(),
1058 inputs: node.inputs.clone(),
1059 output: node.outputs.first().copied(),
1060 },
1061 AccelNodeLabel::Unknown => FusionOp::Primitive {
1062 op: PrimitiveOp::UPlus,
1063 inputs: node.inputs.clone(),
1064 output: node.outputs.first().copied(),
1065 },
1066 };
1067 operations.push(op);
1068
1069 if let Some(out) = node.outputs.first().copied() {
1070 output = Some(out);
1071 }
1072 if node.is_reduction() {
1074 if let Some(sig) = detect_reduction_signature(graph, node) {
1075 reduction_data = Some(sig.data_input);
1076 reduction_dim = sig.dim_arg;
1077 reduction_flavor = Some(match sig.behavior {
1078 ReductionBehavior::MeanLike => ReductionFlavor::Mean,
1079 _ => ReductionFlavor::Sum,
1080 });
1081 reduction_axes = Some(sig.axes.clone());
1082 }
1083 }
1084 }
1085
1086 let kind = group.kind.clone();
1087 let pattern = group.pattern.clone();
1088 let mut plan = Self {
1089 index,
1090 group,
1091 operations,
1092 stack_pattern,
1093 constants,
1094 const_values,
1095 materialized_stores: Vec::new(),
1096 inputs,
1097 output,
1098 kernel: FusionKernelSpec::new(kind, true),
1099 reduction_data,
1100 reduction_dim,
1101 reduction_flavor,
1102 reduction_axes,
1103 pattern,
1104 };
1105
1106 log_plan_stack_pattern("initial", &plan, graph);
1107
1108 for node_id in &plan.group.nodes {
1110 if let Some(node) = graph.node(*node_id) {
1111 for &inp in &node.inputs {
1112 if let Some(info) = graph.value(inp) {
1113 if let Some(cv) = info.constant.clone() {
1114 plan.const_values.insert(inp, cv);
1115 }
1116 }
1117 }
1118 }
1119 }
1120
1121 if plan.group.kind.is_reduction() {
1123 if let Some(data_vid) = plan.reduction_data {
1124 let original_inputs = plan.inputs.clone();
1125 let original_stack_pattern = plan.stack_pattern.clone();
1126 let mut prod: HashMap<ValueId, Vec<ValueId>> = HashMap::new();
1129 for op in &plan.operations {
1130 match op {
1131 FusionOp::Primitive {
1132 inputs,
1133 output,
1134 op: _,
1135 } => {
1136 if let Some(out) = output {
1137 prod.insert(*out, inputs.clone());
1138 }
1139 }
1140 FusionOp::Builtin {
1141 name: _,
1142 inputs,
1143 output,
1144 } => {
1145 if let Some(out) = output {
1146 prod.insert(*out, inputs.clone());
1147 }
1148 }
1149 }
1150 }
1151 let mut deps: Vec<ValueId> = Vec::new();
1152 let mut visited: HashSet<ValueId> = HashSet::new();
1153 let mut stack: Vec<ValueId> = vec![data_vid];
1154 let mut extra_ops: Vec<FusionOp> = Vec::new();
1156 let mut added_nodes: HashSet<ValueId> = HashSet::new();
1157 while let Some(cur) = stack.pop() {
1158 if !visited.insert(cur) {
1159 continue;
1160 }
1161 if graph.var_binding(cur).is_some() {
1162 if !deps.contains(&cur) {
1163 deps.push(cur);
1164 }
1165 continue;
1166 }
1167 if let Some(info) = graph.value(cur) {
1168 if matches!(info.origin, ValueOrigin::Variable { .. }) {
1169 if !deps.contains(&cur) {
1170 deps.push(cur);
1171 }
1172 continue;
1173 }
1174 }
1175 if original_inputs.contains(&cur) && cur != data_vid {
1177 if !deps.contains(&cur) {
1178 deps.push(cur);
1179 }
1180 continue;
1181 }
1182 if let Some(parents) = prod.get(&cur) {
1183 for p in parents {
1184 stack.push(*p);
1185 }
1186 continue;
1187 }
1188 if let Some((_, node)) = node_from_value(graph, cur) {
1190 match &node.label {
1192 AccelNodeLabel::Primitive(PrimitiveOp::Mul)
1193 | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul)
1194 | AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv)
1195 | AccelNodeLabel::Primitive(PrimitiveOp::ElemLeftDiv)
1196 | AccelNodeLabel::Primitive(PrimitiveOp::Add)
1197 | AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
1198 if added_nodes.insert(cur) {
1200 extra_ops.push(FusionOp::Primitive {
1201 op: match node.label {
1202 AccelNodeLabel::Primitive(op) => op,
1203 _ => PrimitiveOp::UPlus,
1204 },
1205 inputs: node.inputs.clone(),
1206 output: node.outputs.first().copied(),
1207 });
1208 }
1209 for &p in &node.inputs {
1210 stack.push(p);
1211 }
1212 continue;
1213 }
1214 AccelNodeLabel::Primitive(PrimitiveOp::ElemPow) => {
1215 if node.inputs.len() == 2 {
1217 if let Some(exp) = value_constant_f64(graph, node.inputs[1]) {
1218 if exp.is_finite() {
1219 if added_nodes.insert(cur) {
1220 extra_ops.push(FusionOp::Primitive {
1221 op: PrimitiveOp::ElemPow,
1222 inputs: node.inputs.clone(),
1223 output: node.outputs.first().copied(),
1224 });
1225 }
1226 stack.push(node.inputs[0]);
1227 stack.push(node.inputs[1]);
1229 continue;
1230 }
1231 }
1232 }
1233 }
1235 AccelNodeLabel::Builtin { name } => {
1236 if (name.eq_ignore_ascii_case("single")
1238 || name.eq_ignore_ascii_case("double"))
1239 && node.inputs.len() == 1
1240 {
1241 stack.push(node.inputs[0]);
1242 continue;
1243 }
1244 }
1246 _ => {
1247 }
1249 }
1250 }
1251 }
1252 if let Some(parents) = prod.get(&data_vid) {
1254 for &p in parents {
1255 if !deps.contains(&p) {
1256 let is_const = plan.const_values.contains_key(&p)
1258 || graph.value(p).and_then(|vi| vi.constant.as_ref()).is_some();
1259 if !is_const {
1260 deps.push(p);
1261 }
1262 }
1263 }
1264 }
1265 if !extra_ops.is_empty() {
1268 let mut new_ops = Vec::with_capacity(extra_ops.len() + plan.operations.len());
1270 new_ops.extend(extra_ops);
1271 new_ops.append(&mut plan.operations);
1272 plan.operations = new_ops;
1273 }
1274 plan.inputs = deps;
1275 for op in &plan.operations {
1277 let inputs = match op {
1278 FusionOp::Primitive { inputs, .. } => inputs,
1279 FusionOp::Builtin { inputs, .. } => inputs,
1280 };
1281 for vid in inputs {
1282 if plan.const_values.contains_key(vid) {
1283 continue;
1284 }
1285 if let Some(info) = graph.value(*vid) {
1286 if let Some(cv) = info.constant.clone() {
1287 plan.const_values.insert(*vid, cv);
1288 }
1289 }
1290 }
1291 }
1292
1293 let mut new_stack_pattern: Vec<usize> = Vec::new();
1296 for (new_idx, vid) in plan.inputs.iter().enumerate() {
1297 if let Some(old_idx) = original_inputs.iter().position(|v| v == vid) {
1298 if original_stack_pattern.contains(&old_idx) {
1299 new_stack_pattern.push(new_idx);
1300 }
1301 }
1302 }
1303
1304 let mut new_constants: HashMap<usize, Value> = HashMap::new();
1306 for (idx, vid) in plan.inputs.iter().enumerate() {
1307 if let Some(value) = plan.const_values.get(vid) {
1308 new_constants.insert(idx, value.clone());
1309 } else if let Some(info) = graph.value(*vid) {
1310 if let Some(cv) = info.constant.clone() {
1311 new_constants.insert(idx, cv);
1312 }
1313 }
1314 }
1315 plan.constants = new_constants;
1316
1317 if new_stack_pattern.is_empty() {
1318 for (idx, vid) in plan.inputs.iter().enumerate() {
1319 if plan.constants.contains_key(&idx) {
1320 continue;
1321 }
1322 if let Some(info) = graph.value(*vid) {
1323 if matches!(
1324 info.origin,
1325 ValueOrigin::Variable { .. } | ValueOrigin::Constant
1326 ) {
1327 continue;
1328 }
1329 }
1330 new_stack_pattern.push(idx);
1331 }
1332 }
1333 plan.stack_pattern = new_stack_pattern;
1334 }
1335 }
1336
1337 if plan.group.kind.is_reduction() {
1339 let original_inputs = plan.inputs.clone();
1340 plan.inputs.retain(|vid| {
1341 if let Some(info) = graph.value(*vid) {
1342 !matches!(info.origin, ValueOrigin::Constant)
1343 && !plan.const_values.contains_key(vid)
1344 } else {
1345 true
1346 }
1347 });
1348 if plan.inputs.len() != original_inputs.len() {
1349 let mut new_stack: Vec<usize> = Vec::new();
1350 for old_idx in &plan.stack_pattern {
1351 if *old_idx < original_inputs.len() {
1352 let vid = original_inputs[*old_idx];
1353 if let Some(new_idx) = plan.inputs.iter().position(|v| *v == vid) {
1354 new_stack.push(new_idx);
1355 }
1356 }
1357 }
1358 plan.stack_pattern = new_stack;
1359 }
1360 }
1361
1362 let supported = if plan.kernel.kind.is_elementwise() {
1367 if scalar_shape_known_one(&plan.group.shape) {
1371 false
1372 } else {
1373 plan.generate_wgsl("f32").is_some()
1374 }
1375 } else if plan.kernel.kind.is_reduction() {
1376 plan.generate_reduction_wgsl("f32").is_some()
1377 } else {
1378 true
1379 };
1380 plan.kernel.supported = plan.kernel.supported && supported;
1381 if !plan.kernel.supported && fusion_debug_enabled() {
1382 let const_ids: Vec<ValueId> = plan.const_values.keys().copied().collect();
1383 log::debug!(
1384 "fusion plan {} unsupported: kind={:?} group_kind={:?} inputs={:?} reduction_data={:?} reduction_dim={:?} const_ids={:?}",
1385 plan.index,
1386 plan.kernel.kind,
1387 plan.group.kind,
1388 plan.inputs,
1389 plan.reduction_data,
1390 plan.reduction_dim,
1391 const_ids
1392 );
1393 if plan.kernel.kind.is_reduction() {
1394 let mut seen: HashSet<ValueId> = HashSet::new();
1395 let mut value_info: Vec<String> = Vec::new();
1396 for op in &plan.operations {
1397 let inputs = match op {
1398 FusionOp::Primitive { inputs, .. } => inputs,
1399 FusionOp::Builtin { inputs, .. } => inputs,
1400 };
1401 for vid in inputs {
1402 if seen.insert(*vid) {
1403 if let Some(info) = graph.value(*vid) {
1404 value_info.push(format!(
1405 "vid={} origin={:?} constant={}",
1406 vid,
1407 info.origin,
1408 info.constant.is_some()
1409 ));
1410 } else {
1411 value_info.push(format!("vid={} origin=<missing>", vid));
1412 }
1413 }
1414 }
1415 }
1416 log::debug!(
1417 "fusion reduction plan {} value summary: [{}]",
1418 plan.index,
1419 value_info.join(", ")
1420 );
1421 }
1422 }
1423
1424 if matches!(plan.group.kind, FusionKind::CenteredGram) && plan.stack_pattern.is_empty() {
1425 let mut centered_stack_idxs: Vec<usize> = Vec::new();
1426 for (idx, vid) in plan.inputs.iter().enumerate() {
1427 if plan.constants.contains_key(&idx) {
1428 continue;
1429 }
1430 if let Some(info) = graph.value(*vid) {
1431 if matches!(info.origin, ValueOrigin::NodeOutput { .. }) {
1432 centered_stack_idxs.push(idx);
1433 continue;
1434 }
1435 if matches!(info.origin, ValueOrigin::Variable { .. }) {
1436 continue;
1437 }
1438 }
1439 centered_stack_idxs.push(idx);
1440 }
1441 if centered_stack_idxs.is_empty() && !plan.inputs.is_empty() {
1442 centered_stack_idxs.push(0);
1443 }
1444 plan.stack_pattern = centered_stack_idxs;
1445 }
1446
1447 if !plan.stack_pattern.is_empty() || plan.group.stack_layout.is_some() {
1448 plan.group.stack_layout = merge_stack_layout_with_stack_pattern(
1449 plan.group.stack_layout.as_ref(),
1450 &plan.inputs,
1451 &plan.stack_pattern,
1452 );
1453 }
1454
1455 if plan.group.kind.is_elementwise() {
1456 let mut stores = Vec::new();
1457 for op in &plan.operations {
1458 let output = match op {
1459 FusionOp::Primitive { output, .. } => *output,
1460 FusionOp::Builtin { output, .. } => *output,
1461 };
1462 let Some(value_id) = output else {
1463 continue;
1464 };
1465 let Some(binding) = graph.var_binding(value_id).cloned() else {
1466 continue;
1467 };
1468 stores.push(FusionStoreMaterialization { value_id, binding });
1469 }
1470 plan.materialized_stores = stores;
1471 }
1472
1473 log_plan_stack_pattern("final", &plan, graph);
1474
1475 plan
1478 }
1479
1480 pub fn reduction_data_shape(&self, graph: &AccelGraph) -> Option<Vec<usize>> {
1481 let vid = self.reduction_data?;
1482 let info = graph.value(vid)?;
1483 match &info.shape {
1484 ShapeInfo::Tensor(dims) if !dims.is_empty() && dims.iter().all(|d| d.is_some()) => {
1485 Some(dims.iter().map(|d| d.unwrap()).collect())
1486 }
1487 _ => None,
1488 }
1489 }
1490
1491 pub fn element_count(&self) -> Option<usize> {
1492 self.group.element_count()
1493 }
1494
1495 pub fn constant_shape(&self, len: usize) -> Vec<usize> {
1496 match &self.group.shape {
1497 ShapeInfo::Tensor(dims) if !dims.is_empty() && dims.iter().all(|dim| dim.is_some()) => {
1498 dims.iter().map(|dim| dim.unwrap()).collect()
1499 }
1500 _ => vec![len],
1501 }
1502 }
1503
1504 pub fn generate_wgsl(&self, scalar_ty: &str) -> Option<String> {
1505 self.generate_wgsl_for_output(self.output?, scalar_ty)
1506 }
1507
1508 fn build_wgsl_shader(
1517 &self,
1518 scalar_ty: &str,
1519 output_bindings: &str,
1520 params_binding_idx: usize,
1521 body: &str,
1522 final_writes: &str,
1523 ) -> String {
1524 let mut shader = String::new();
1525
1526 shader.push_str("const MAX_RANK: u32 = 128u;\n");
1528 shader.push_str("struct PackedValue { value: u32, _pad0: u32, _pad1: u32, _pad2: u32 };\n");
1529 shader.push_str("alias PackedArray = array<PackedValue, MAX_RANK>;\n\n");
1530 shader.push_str(&format!("struct Tensor {{ data: array<{scalar_ty}>, }};\n"));
1531
1532 shader.push_str(
1534 "struct Params {\n len: u32,\n offset: u32,\n rank: u32,\n _pad: u32,\n out_shape: PackedArray,\n",
1535 );
1536 for idx in 0..self.inputs.len() {
1537 shader.push_str(&format!(" in{}_shape: PackedArray,\n", idx));
1538 shader.push_str(&format!(" in{}_stride: PackedArray,\n", idx));
1539 }
1540 shader.push_str("}\n\n");
1541
1542 if scalar_ty == "f32" {
1550 shader.push_str("fn isNan(x: f32) -> bool { return x != x; }\n");
1551 shader.push_str("fn isFinite(x: f32) -> bool { return (x == x) && (abs(x) < 3.4028234663852886e38); }\n");
1552 shader.push_str("fn isInf(x: f32) -> bool { return (x == x) && !(abs(x) < 3.4028234663852886e38); }\n");
1553 shader.push_str(concat!(
1554 "fn hypot(a: f32, b: f32) -> f32 {\n",
1555 " let lo = min(abs(a), abs(b));\n",
1556 " let hi = max(abs(a), abs(b));\n",
1557 " if hi == 0.0 { return 0.0; }\n",
1558 " if isInf(hi) { return hi; }\n",
1559 " let r = lo / hi;\n",
1560 " return hi * sqrt(1.0 + r * r);\n",
1561 "}\n\n",
1562 ));
1563 } else {
1564 shader.push_str("fn isNan(x: f64) -> bool { return x != x; }\n");
1565 shader.push_str("fn isFinite(x: f64) -> bool { return (x == x) && (abs(x) < f64(1.7976931348623157e308)); }\n");
1566 shader.push_str("fn isInf(x: f64) -> bool { return (x == x) && !(abs(x) < f64(1.7976931348623157e308)); }\n");
1567 shader.push_str(concat!(
1568 "fn hypot(a: f64, b: f64) -> f64 {\n",
1569 " let lo = min(abs(a), abs(b));\n",
1570 " let hi = max(abs(a), abs(b));\n",
1571 " if hi == f64(0.0) { return f64(0.0); }\n",
1572 " if isInf(hi) { return hi; }\n",
1573 " let r = lo / hi;\n",
1574 " return hi * sqrt(f64(1.0) + r * r);\n",
1575 "}\n\n",
1576 ));
1577 }
1578
1579 for (idx, _) in self.inputs.iter().enumerate() {
1581 shader.push_str(&format!(
1582 "@group(0) @binding({idx}) var<storage, read> input{idx}: Tensor;\n",
1583 ));
1584 }
1585 shader.push_str(output_bindings);
1586 shader.push_str(&format!(
1587 "@group(0) @binding({params_binding_idx}) var<uniform> params: Params;\n\n",
1588 ));
1589
1590 shader.push_str(
1592 "@compute @workgroup_size(@WG@)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {\n",
1593 );
1594 shader.push_str(" let idx = gid.x;\n if (idx >= params.len) { return; }\n");
1595 shader.push_str(" let g = idx + params.offset;\n");
1596
1597 shader.push_str(
1599 " var coord: array<u32, MAX_RANK>;\n var tmp: u32 = g;\n var d: u32 = 0u;\n loop { if d >= params.rank { break; } let dim = params.out_shape[d].value; if dim == 0u { coord[d] = 0u; } else { coord[d] = tmp % dim; tmp = tmp / dim; } d = d + 1u; }\n",
1600 );
1601
1602 for (idx, _) in self.inputs.iter().enumerate() {
1604 shader.push_str(&format!(
1605 " var i{idx}: u32 = 0u; d = 0u; loop {{ if d >= params.rank {{ break; }} let sd = params.in{idx}_shape[d].value; let st = params.in{idx}_stride[d].value; let c = select(coord[d], 0u, sd == 1u); i{idx} = i{idx} + c * st; d = d + 1u; }}\n",
1606 ));
1607 }
1608
1609 shader.push_str(body);
1610 shader.push_str(final_writes);
1611 shader.push_str("}\n");
1612 shader
1613 }
1614
1615 pub fn generate_wgsl_for_outputs(
1624 &self,
1625 output_ids: &[ValueId],
1626 scalar_ty: &str,
1627 ) -> Option<String> {
1628 if output_ids.is_empty() {
1629 return None;
1630 }
1631 if output_ids.len() == 1 {
1632 return self.generate_wgsl_for_output(output_ids[0], scalar_ty);
1633 }
1634 if !self.kernel.kind.is_elementwise() {
1635 return None;
1636 }
1637 if !self.kernel.supported {
1638 return None;
1639 }
1640
1641 let mut exprs: HashMap<ValueId, String> = HashMap::new();
1642 for (idx, input_id) in self.inputs.iter().enumerate() {
1643 exprs.insert(*input_id, format!("input{idx}.data[i{idx}]"));
1644 }
1645
1646 let mut body = String::new();
1647 for (node_idx, op) in self.operations.iter().enumerate() {
1648 let tmp_name = format!("tmp{node_idx}");
1649 match op {
1650 FusionOp::Primitive { op, inputs, output } => {
1651 let expr = primitive_expr(*op, inputs, &exprs)?;
1652 body.push_str(&format!(" let {tmp_name}: {scalar_ty} = {expr};\n"));
1653 if let Some(out) = output {
1654 exprs.insert(*out, tmp_name.clone());
1655 }
1656 }
1657 FusionOp::Builtin {
1658 name,
1659 inputs,
1660 output,
1661 } => {
1662 let expr = builtin_expr(name, inputs, &exprs, scalar_ty)?;
1663 body.push_str(&format!(" let {tmp_name}: {scalar_ty} = {expr};\n"));
1664 if let Some(out) = output {
1665 exprs.insert(*out, tmp_name.clone());
1666 }
1667 }
1668 }
1669 }
1670
1671 let mut final_exprs = Vec::with_capacity(output_ids.len());
1672 for output_id in output_ids {
1673 final_exprs.push(exprs.get(output_id)?.clone());
1674 }
1675
1676 let num_outputs = output_ids.len();
1677 let n_inputs = self.inputs.len();
1678
1679 let mut output_bindings = String::new();
1680 for k in 0..num_outputs {
1681 output_bindings.push_str(&format!(
1682 "@group(0) @binding({}) var<storage, read_write> output{k}: Tensor;\n",
1683 n_inputs + k,
1684 ));
1685 }
1686
1687 let mut final_writes = String::new();
1688 for (k, expr) in final_exprs.iter().enumerate() {
1689 final_writes.push_str(&format!(" output{k}.data[g] = {expr};\n"));
1690 }
1691
1692 Some(self.build_wgsl_shader(
1693 scalar_ty,
1694 &output_bindings,
1695 n_inputs + num_outputs,
1696 &body,
1697 &final_writes,
1698 ))
1699 }
1700
1701 pub fn generate_wgsl_for_output(&self, output_id: ValueId, scalar_ty: &str) -> Option<String> {
1702 if !self.kernel.kind.is_elementwise() {
1703 return None;
1704 }
1705 if !self.kernel.supported {
1706 return None;
1707 }
1708
1709 let mut exprs: HashMap<ValueId, String> = HashMap::new();
1710 for (idx, input_id) in self.inputs.iter().enumerate() {
1711 exprs.insert(*input_id, format!("input{idx}.data[i{idx}]"));
1713 }
1714
1715 let mut body = String::new();
1716 for (node_idx, op) in self.operations.iter().enumerate() {
1717 let tmp_name = format!("tmp{node_idx}");
1718 match op {
1719 FusionOp::Primitive { op, inputs, output } => {
1720 let expr = primitive_expr(*op, inputs, &exprs)?;
1721 body.push_str(&format!(" let {tmp_name}: {scalar_ty} = {expr};\n"));
1722 if let Some(out) = output {
1723 exprs.insert(*out, tmp_name.clone());
1724 }
1725 }
1726 FusionOp::Builtin {
1727 name,
1728 inputs,
1729 output,
1730 } => {
1731 let expr = builtin_expr(name, inputs, &exprs, scalar_ty)?;
1732 body.push_str(&format!(" let {tmp_name}: {scalar_ty} = {expr};\n"));
1733 if let Some(out) = output {
1734 exprs.insert(*out, tmp_name.clone());
1735 }
1736 }
1737 }
1738 }
1739
1740 let final_expr = exprs.get(&output_id)?.clone();
1741 let n_inputs = self.inputs.len();
1742
1743 let output_bindings =
1744 format!("@group(0) @binding({n_inputs}) var<storage, read_write> output: Tensor;\n",);
1745 let final_writes = format!(" output.data[g] = {final_expr};\n");
1746
1747 Some(self.build_wgsl_shader(
1748 scalar_ty,
1749 &output_bindings,
1750 n_inputs + 1,
1751 &body,
1752 &final_writes,
1753 ))
1754 }
1755
1756 pub fn generate_reduction_wgsl(&self, scalar_ty: &str) -> Option<String> {
1757 if !self.kernel.kind.is_reduction() {
1758 return None;
1759 }
1760 if self.inputs.is_empty() {
1763 return None;
1764 }
1765 let mut axis = 0usize;
1768 let reduce_all = self
1770 .constants
1771 .values()
1772 .any(|v| matches!(v, Value::String(s) if s.eq_ignore_ascii_case("all")))
1773 || self
1774 .const_values
1775 .values()
1776 .any(|v| matches!(v, Value::String(s) if s.eq_ignore_ascii_case("all")));
1777 if reduce_all {
1778 axis = 0;
1780 } else if let Some(dim_vid) = self.reduction_dim {
1781 if let Some(v) = self.const_values.get(&dim_vid) {
1782 match v {
1783 Value::Num(n) if *n >= 1.0 => {
1784 axis = (*n as usize).saturating_sub(1);
1785 }
1786 Value::Int(i) => {
1787 let val = i.to_f64();
1788 if val >= 1.0 {
1789 axis = (val as usize).saturating_sub(1);
1790 }
1791 }
1792 _ => {}
1793 }
1794 }
1795 } else {
1796 for v in self.constants.values() {
1798 match v {
1799 Value::Num(n) if *n >= 1.0 => {
1800 axis = (*n as usize).saturating_sub(1);
1801 break;
1802 }
1803 Value::Int(i) => {
1804 let val = i.to_f64();
1805 if val >= 1.0 {
1806 axis = (val as usize).saturating_sub(1);
1807 break;
1808 }
1809 }
1810 _ => {}
1811 }
1812 }
1813 }
1814
1815 let omitnan = self.constants.values().any(|v| match v {
1817 Value::String(s) => s.eq_ignore_ascii_case("omitnan"),
1818 _ => false,
1819 });
1820
1821 let data_vid = self.reduction_data?;
1823 let ext_input = self.inputs[0];
1824 let mut exprs: HashMap<ValueId, String> = HashMap::new();
1825 exprs.insert(ext_input, "v".to_string());
1826 for (idx, &vid) in self.inputs.iter().enumerate().skip(1) {
1828 exprs.insert(vid, format!("v{idx}"));
1829 }
1830 for (vid, val) in &self.const_values {
1831 let lit = match val {
1832 Value::Num(n) => {
1833 if scalar_ty == "f64" {
1834 format!("f64({})", n)
1835 } else {
1836 format!("{:?}", *n as f32)
1837 }
1838 }
1839 Value::Int(i) => {
1840 let f = i.to_f64();
1841 if scalar_ty == "f64" {
1842 format!("f64({})", f)
1843 } else {
1844 format!("{:?}", f as f32)
1845 }
1846 }
1847 Value::Tensor(t) if t.data.len() == 1 => {
1848 let scalar = t.data[0];
1849 if scalar_ty == "f64" {
1850 format!("f64({})", scalar)
1851 } else {
1852 format!("{:?}", scalar as f32)
1853 }
1854 }
1855 _ => {
1856 if scalar_ty == "f64" {
1857 "f64(0.0)".to_string()
1858 } else {
1859 "0.0".to_string()
1860 }
1861 }
1862 };
1863 exprs.insert(*vid, lit);
1864 }
1865 let mut progressed = true;
1866 while progressed {
1867 progressed = false;
1868 for op in &self.operations {
1869 match op {
1870 FusionOp::Primitive { op, inputs, output } => {
1871 if let Some(out) = output {
1872 if exprs.contains_key(out) {
1873 continue;
1874 }
1875 if let Some(code) = primitive_expr(*op, inputs, &exprs) {
1876 exprs.insert(*out, code);
1877 progressed = true;
1878 }
1879 }
1880 }
1881 FusionOp::Builtin {
1882 name,
1883 inputs,
1884 output,
1885 } => {
1886 if let Some(out) = output {
1887 if exprs.contains_key(out) {
1888 continue;
1889 }
1890 if let Some(code) = builtin_expr(name, inputs, &exprs, scalar_ty) {
1891 exprs.insert(*out, code);
1892 progressed = true;
1893 }
1894 }
1895 }
1896 }
1897 }
1898 if exprs.contains_key(&data_vid) {
1899 break;
1900 }
1901 }
1902 let val_expr = match exprs.get(&data_vid) {
1904 Some(s) => s.clone(),
1905 None => {
1906 if fusion_debug_enabled() {
1907 let expr_keys: Vec<ValueId> = exprs.keys().copied().collect();
1908 log::debug!(
1909 "fusion reduction WGSL: missing expression for data {:?}; inputs={:?} expr_keys={:?} ops={:?}",
1910 data_vid,
1911 self.inputs,
1912 expr_keys,
1913 self.operations
1914 );
1915 }
1916 return None;
1917 }
1918 };
1919
1920 let mut shader = String::new();
1921 shader.push_str(&format!("struct Tensor {{ data: array<{scalar_ty}>, }};\n"));
1922 shader.push_str("struct MParams { nrows: u32, ncols: u32, ld: u32, flags: u32 }\n\n");
1923 for (idx, _) in self.inputs.iter().enumerate() {
1925 shader.push_str(&format!(
1926 "@group(0) @binding({}) var<storage, read> input{}: Tensor;\n",
1927 idx, idx
1928 ));
1929 }
1930 shader.push_str(&format!(
1931 "@group(0) @binding({}) var<storage, read_write> output: Tensor;\n",
1932 self.inputs.len()
1933 ));
1934 shader.push_str(&format!(
1935 "@group(0) @binding({}) var<uniform> params: MParams;\n\n",
1936 self.inputs.len() + 1
1937 ));
1938 shader.push_str(&format!(
1940 "var<workgroup> tile: array<{scalar_ty}, @WG@u>;\n\n"
1941 ));
1942 shader.push_str(&format!(
1943 "const OMITNAN: bool = {};\n\n",
1944 if omitnan { "true" } else { "false" }
1945 ));
1946 let is_mean = matches!(self.reduction_flavor, Some(ReductionFlavor::Mean));
1948 let post_scale = if is_mean {
1949 let dim = if axis == 0 {
1950 "params.nrows"
1951 } else {
1952 "params.ncols"
1953 };
1954 if scalar_ty == "f64" {
1955 format!("(1.0 / f64(f32({dim})))")
1956 } else {
1957 format!("(1.0 / f32({dim}))")
1958 }
1959 } else if scalar_ty == "f64" {
1960 "f64(1.0)".to_string()
1961 } else {
1962 "1.0".to_string()
1963 };
1964 shader.push_str(&format!(
1966 "fn isNanF(x: {scalar}) -> bool {{ return x != x; }}\n",
1967 scalar = scalar_ty
1968 ));
1969 if scalar_ty == "f64" {
1970 shader.push_str("fn canonicalNan() -> f64 {\n var bits: u64 = 0x7ff8000000000000u;\n return bitcast<f64>(bits);\n}\n\n");
1971 } else {
1972 shader.push_str("fn canonicalNan() -> f32 {\n var bits: u32 = 0x7fc00000u;\n return bitcast<f32>(bits);\n}\n\n");
1973 }
1974 shader.push_str("@compute @workgroup_size(@WG@)\n");
1975 if axis == 0 {
1976 shader.push_str(
1978 "fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {\n",
1979 );
1980 shader.push_str(" let col = wid.x;\n if (col >= params.ncols) { return; }\n");
1981 shader.push_str(&format!(
1982 " var acc: {scalar_ty} = {}0.0;\n",
1983 if scalar_ty == "f64" { "f64(" } else { "" }
1984 ));
1985 if scalar_ty == "f64" {
1986 shader.push_str(" // close cast for f64 literal\n");
1987 }
1988 shader.push_str(" var saw_nan: bool = false;\n var r = lid.x;\n");
1990 {
1992 let mut loop_body = String::new();
1994 loop_body.push_str(" let v = input0.data[ (col * params.nrows) + r ];\n");
1996 for (idx, _) in self.inputs.iter().enumerate().skip(1) {
1998 loop_body.push_str(&format!(
1999 " let v{idx} = input{idx}.data[ (col * params.nrows) + r ];\n"
2000 ));
2001 }
2002 loop_body.push_str(&format!(
2004 " let val: {scalar} = {val};\n if (OMITNAN) {{ if (!isNanF(val)) {{ acc = acc + val; }} }} else {{ if (isNanF(val)) {{ saw_nan = true; }} else {{ acc = acc + val; }} }}\n",
2005 scalar = scalar_ty,
2006 val = val_expr
2007 ));
2008 shader.push_str(" while (r < params.nrows) {\n");
2009 shader.push_str(&loop_body);
2010 shader.push_str(" r += @WG@u;\n }\n");
2011 }
2012 shader.push_str(" if (!OMITNAN && saw_nan) { acc = canonicalNan(); }\n");
2013 shader.push_str(" tile[lid.x] = acc;\n workgroupBarrier();\n");
2014 shader.push_str(
2015 " var off = (@WG@u) / 2u;\n loop { if (off == 0u) { break; } if (lid.x < off) {\n let a = tile[lid.x]; let b = tile[lid.x + off];\n tile[lid.x] = a + b;\n } workgroupBarrier(); off = off / 2u; }\n",
2016 );
2017 shader.push_str(&format!(
2019 " if (lid.x == 0u) {{ output.data[col] = tile[0u] * {}; }}\n}}\n",
2020 post_scale
2021 ));
2022 } else {
2023 shader.push_str(
2025 "fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {\n",
2026 );
2027 shader.push_str(" let row = wid.x;\n // For axis=1, number of output slices equals rows (params.ncols)\n if (row >= params.ncols) { return; }\n");
2028 shader.push_str(&format!(
2029 " var acc: {scalar_ty} = {}0.0;\n",
2030 if scalar_ty == "f64" { "f64(" } else { "" }
2031 ));
2032 if scalar_ty == "f64" {
2033 shader.push_str(" // close cast for f64 literal\n");
2034 }
2035 shader.push_str(" var saw_nan: bool = false;\n var c = lid.x;\n");
2037 {
2038 let mut loop_body = String::new();
2039 loop_body.push_str(" let v = input0.data[ row + (c * params.ncols) ];\n");
2041 for (idx, _) in self.inputs.iter().enumerate().skip(1) {
2043 loop_body.push_str(&format!(
2044 " let v{idx} = input{idx}.data[ row + (c * params.ncols) ];\n"
2045 ));
2046 }
2047 loop_body.push_str(&format!(
2048 " let val: {scalar} = {val};\n if (OMITNAN) {{ if (!isNanF(val)) {{ acc = acc + val; }} }} else {{ if (isNanF(val)) {{ saw_nan = true; }} else {{ acc = acc + val; }} }}\n",
2049 scalar = scalar_ty,
2050 val = val_expr
2051 ));
2052 shader.push_str(" while (c < params.nrows) {\n");
2054 shader.push_str(&loop_body);
2055 shader.push_str(" c += @WG@u;\n }\n");
2056 }
2057 shader.push_str(" if (!OMITNAN && saw_nan) { acc = canonicalNan(); }\n");
2058 shader.push_str(" tile[lid.x] = acc;\n workgroupBarrier();\n");
2059 shader.push_str(
2060 " var off = (@WG@u) / 2u;\n loop { if (off == 0u) { break; } if (lid.x < off) {\n let a = tile[lid.x]; let b = tile[lid.x + off];\n tile[lid.x] = a + b;\n } workgroupBarrier(); off = off / 2u; }\n",
2061 );
2062 shader.push_str(&format!(
2063 " if (lid.x == 0u) {{ output.data[row] = tile[0u] * {}; }}\n}}\n",
2064 post_scale
2065 ));
2066 }
2067 Some(shader)
2068 }
2069}
2070
2071impl FusionGroup {
2072 pub fn element_count(&self) -> Option<usize> {
2073 match &self.shape {
2074 ShapeInfo::Scalar => Some(1),
2075 ShapeInfo::Tensor(dims) => dims
2076 .iter()
2077 .try_fold(1usize, |acc, dim| dim.and_then(|d| acc.checked_mul(d))),
2078 ShapeInfo::Unknown => None,
2079 }
2080 }
2081}
2082
2083impl FusionKind {
2084 pub fn is_elementwise(&self) -> bool {
2085 matches!(self, FusionKind::ElementwiseChain)
2086 }
2087
2088 pub fn is_reduction(&self) -> bool {
2089 matches!(self, FusionKind::Reduction)
2090 }
2091}
2092
2093fn detect_centered_gram(
2094 graph: &AccelGraph,
2095 assigned: &mut HashSet<NodeId>,
2096 groups: &mut Vec<FusionGroup>,
2097 next_group_id: &mut usize,
2098) {
2099 for div_node in &graph.nodes {
2100 if assigned.contains(&div_node.id) {
2101 continue;
2102 }
2103 let div_op = match div_node.label {
2104 AccelNodeLabel::Primitive(op) => op,
2105 _ => continue,
2106 };
2107 if div_op != PrimitiveOp::ElemDiv {
2108 continue;
2109 }
2110 if div_node.inputs.len() != 2 {
2111 continue;
2112 }
2113 let (numerator_id, denom_id) = (div_node.inputs[0], div_node.inputs[1]);
2114 let denom_info = match graph.value(denom_id) {
2115 Some(info) => info,
2116 None => continue,
2117 };
2118 let denom_const = match &denom_info.constant {
2119 Some(Value::Num(v)) => Some(*v),
2120 Some(Value::Int(i)) => Some(i.to_f64()),
2121 _ => None,
2122 };
2123 if denom_const.is_some_and(|v| v == 0.0) {
2124 continue;
2125 }
2126
2127 let mul_node_id = match graph
2128 .value(numerator_id)
2129 .and_then(|info| match &info.origin {
2130 ValueOrigin::NodeOutput { node, .. } => Some(*node),
2131 _ => None,
2132 }) {
2133 Some(id) => id,
2134 None => continue,
2135 };
2136 if assigned.contains(&mul_node_id) {
2137 continue;
2138 }
2139 let mul_node = match graph.node(mul_node_id) {
2140 Some(node) => node,
2141 None => continue,
2142 };
2143 let mul_op = match mul_node.label {
2144 AccelNodeLabel::Primitive(op) => op,
2145 _ => continue,
2146 };
2147 if mul_op != PrimitiveOp::Mul && mul_op != PrimitiveOp::ElemMul {
2148 continue;
2149 }
2150 if mul_node.inputs.len() != 2 {
2151 continue;
2152 }
2153
2154 let mut transpose_node_id: Option<NodeId> = None;
2155 let mut centered_val_id: Option<ValueId> = None;
2156 for input_vid in &mul_node.inputs {
2157 let candidate_node_id =
2158 match graph.value(*input_vid).and_then(|info| match &info.origin {
2159 ValueOrigin::NodeOutput { node, .. } => Some(*node),
2160 _ => None,
2161 }) {
2162 Some(id) => id,
2163 None => continue,
2164 };
2165 if let Some(trans_node) = graph.node(candidate_node_id) {
2166 if matches!(
2167 trans_node.label,
2168 AccelNodeLabel::Primitive(PrimitiveOp::Transpose)
2169 ) {
2170 if let Some(centered) = trans_node.inputs.first().copied() {
2171 transpose_node_id = Some(candidate_node_id);
2172 centered_val_id = Some(centered);
2173 break;
2174 }
2175 }
2176 }
2177 }
2178
2179 let transpose_node_id = match transpose_node_id {
2180 Some(id) if !assigned.contains(&id) => id,
2181 _ => continue,
2182 };
2183 let centered_val_id = match centered_val_id {
2184 Some(id) => id,
2185 None => continue,
2186 };
2187
2188 if assigned.contains(&transpose_node_id) {
2189 continue;
2190 }
2191 if graph.node(transpose_node_id).is_none() {
2192 continue;
2193 }
2194
2195 let centered_node_id =
2196 match graph
2197 .value(centered_val_id)
2198 .and_then(|info| match &info.origin {
2199 ValueOrigin::NodeOutput { node, .. } => Some(*node),
2200 _ => None,
2201 }) {
2202 Some(id) => id,
2203 None => continue,
2204 };
2205 if assigned.contains(¢ered_node_id) {
2206 continue;
2207 }
2208 let centered_node = match graph.node(centered_node_id) {
2209 Some(node) => node,
2210 None => continue,
2211 };
2212 if !matches!(
2213 centered_node.label,
2214 AccelNodeLabel::Primitive(PrimitiveOp::Sub)
2215 ) {
2216 continue;
2217 }
2218 if centered_node.inputs.len() != 2 {
2219 continue;
2220 }
2221 let matrix_val_id = centered_node.inputs[0];
2222 let mean_val_id = centered_node.inputs[1];
2223
2224 let mean_node_id = match graph
2225 .value(mean_val_id)
2226 .and_then(|info| match &info.origin {
2227 ValueOrigin::NodeOutput { node, .. } => Some(*node),
2228 _ => None,
2229 }) {
2230 Some(id) => id,
2231 None => continue,
2232 };
2233 if assigned.contains(&mean_node_id) {
2234 continue;
2235 }
2236 let mean_node = match graph.node(mean_node_id) {
2237 Some(node) => node,
2238 None => continue,
2239 };
2240 match &mean_node.label {
2241 AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mean") => {}
2242 _ => continue,
2243 }
2244 if mean_node.inputs.is_empty() || mean_node.inputs[0] != matrix_val_id {
2245 continue;
2246 }
2247
2248 let matrix_info = match graph.value(matrix_val_id) {
2249 Some(info) => info,
2250 None => continue,
2251 };
2252 let matrix_rows = match &matrix_info.shape {
2253 ShapeInfo::Tensor(dims) if !dims.is_empty() => dims[0].unwrap_or(0),
2254 _ => 0,
2255 };
2256 let normalization = if matrix_rows > 1 {
2257 if let Some(value) = denom_const {
2258 let unbiased = (matrix_rows as f64 - 1.0).max(1.0);
2259 let biased = matrix_rows as f64;
2260 if approx_eq(value, unbiased) {
2261 CovNormalization::Unbiased
2262 } else if approx_eq(value, biased) {
2263 CovNormalization::Biased
2264 } else {
2265 CovNormalization::Unbiased
2266 }
2267 } else {
2268 CovNormalization::Unbiased
2269 }
2270 } else {
2271 CovNormalization::Unbiased
2272 };
2273
2274 let mut nodes = vec![
2275 mean_node_id,
2276 centered_node_id,
2277 transpose_node_id,
2278 mul_node_id,
2279 div_node.id,
2280 ];
2281 nodes.sort_by_key(|node_id| {
2282 graph
2283 .node(*node_id)
2284 .map(|node| node.span.start)
2285 .unwrap_or(usize::MAX)
2286 });
2287 let span = group_span(graph, &nodes);
2288 let shape = node_output_shape(graph, div_node);
2289
2290 groups.push(FusionGroup {
2291 id: *next_group_id,
2292 kind: FusionKind::CenteredGram,
2293 nodes: nodes.clone(),
2294 shape,
2295 span,
2296 pattern: Some(FusionPattern::CenteredGram {
2297 matrix: matrix_val_id,
2298 normalization,
2299 }),
2300 stack_layout: None,
2301 });
2302 *next_group_id += 1;
2303 for id in nodes {
2304 assigned.insert(id);
2305 }
2306 }
2307}
2308
2309fn detect_image_normalize(
2310 graph: &AccelGraph,
2311 assigned: &mut HashSet<NodeId>,
2312 groups: &mut Vec<FusionGroup>,
2313 next_group_id: &mut usize,
2314) {
2315 for pow_node in &graph.nodes {
2316 if assigned.contains(&pow_node.id) {
2317 continue;
2318 }
2319 let Some(match_info) = analyze_image_normalize(graph, pow_node.id, assigned) else {
2320 continue;
2321 };
2322
2323 let pow_node_ref = match graph.node(pow_node.id) {
2324 Some(node) => node,
2325 None => continue,
2326 };
2327
2328 let shape = node_output_shape(graph, pow_node_ref);
2329 let span = group_span(graph, &match_info.nodes);
2330
2331 let pattern = ImageNormalizePattern {
2332 input: match_info.input,
2333 epsilon: match_info.epsilon.clone(),
2334 gain: match_info.gain.clone(),
2335 bias: match_info.bias.clone(),
2336 gamma: match_info.gamma.clone(),
2337 };
2338
2339 groups.push(FusionGroup {
2340 id: *next_group_id,
2341 kind: FusionKind::ImageNormalize,
2342 nodes: match_info.nodes.clone(),
2343 shape,
2344 span: span.clone(),
2345 pattern: Some(FusionPattern::ImageNormalize(pattern)),
2346 stack_layout: None,
2347 });
2348 if fusion_debug_enabled() {
2349 log::debug!(
2350 "fusion: detected image normalize group id={} span={:?} nodes={:?}",
2351 next_group_id,
2352 span,
2353 match_info.nodes
2354 );
2355 }
2356 *next_group_id += 1;
2357 for node_id in match_info.nodes {
2358 assigned.insert(node_id);
2359 }
2360 }
2361}
2362
/// Returns `true` when `a` and `b` agree to within a relative tolerance of `1e-6`.
///
/// The tolerance is scaled by the larger magnitude of the two operands, floored at `1.0`,
/// so values near zero are compared with an absolute tolerance of `1e-6`.
///
/// Non-finite operands are only ever equal to themselves: two infinities of the same sign
/// compare equal, an infinity never matches a finite value, and NaN matches nothing.
fn approx_eq(a: f64, b: f64) -> bool {
    // Exact equality also covers `inf == inf`, where `a - b` would be NaN.
    if a == b {
        return true;
    }
    // With an infinite operand both the difference and the scaled tolerance become
    // infinite, and `inf <= inf` would wrongly report a match against any finite value.
    if !a.is_finite() || !b.is_finite() {
        return false;
    }
    let scale = a.abs().max(b.abs()).max(1.0);
    (a - b).abs() <= scale * 1e-6
}
2367
2368fn detect_power_step_normalize(
2369 graph: &AccelGraph,
2370 assigned: &mut HashSet<NodeId>,
2371 groups: &mut Vec<FusionGroup>,
2372 next_group_id: &mut usize,
2373) {
2374 'outer: for div_node in &graph.nodes {
2375 if assigned.contains(&div_node.id) {
2376 continue;
2377 }
2378 let div_op = match div_node.label {
2379 AccelNodeLabel::Primitive(op) => op,
2380 _ => continue,
2381 };
2382 if div_op != PrimitiveOp::ElemDiv {
2383 continue;
2384 }
2385 if div_node.inputs.len() != 2 {
2386 continue;
2387 }
2388 let numerator_vid = div_node.inputs[0];
2389 let denom_vid = div_node.inputs[1];
2390
2391 let (matmul_id, matmul_node) = match node_from_value(graph, numerator_vid) {
2392 Some((id, node)) => (id, node),
2393 None => continue,
2394 };
2395 if assigned.contains(&matmul_id) {
2396 continue;
2397 }
2398 match &matmul_node.label {
2399 AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mtimes") => {}
2400 _ => continue,
2401 }
2402 if matmul_node.inputs.len() != 2 {
2403 continue;
2404 }
2405
2406 let Some(denom_info) = analyze_power_step_denominator(graph, denom_vid, numerator_vid)
2407 else {
2408 continue;
2409 };
2410 if assigned.contains(&denom_info.sqrt_node) {
2411 continue;
2412 }
2413 if assigned.contains(&denom_info.sum_node) {
2414 continue;
2415 }
2416 if assigned.contains(&denom_info.pow_node) {
2417 continue;
2418 }
2419 if let Some(add_id) = denom_info.add_node {
2420 if assigned.contains(&add_id) {
2421 continue;
2422 }
2423 }
2424 if denom_info.pow_input != numerator_vid {
2425 continue;
2426 }
2427
2428 let mut nodes = vec![matmul_id, denom_info.pow_node, denom_info.sum_node];
2429 if let Some(add_id) = denom_info.add_node {
2430 nodes.push(add_id);
2431 }
2432 nodes.push(denom_info.sqrt_node);
2433 nodes.push(div_node.id);
2434
2435 for node_id in &nodes {
2436 if assigned.contains(node_id) {
2437 continue 'outer;
2438 }
2439 }
2440
2441 nodes.sort_by_key(|node_id| {
2442 graph
2443 .node(*node_id)
2444 .map(|node| node.span.start)
2445 .unwrap_or(usize::MAX)
2446 });
2447
2448 let span = group_span(graph, &nodes);
2449 let shape = node_output_shape(graph, div_node);
2450
2451 groups.push(FusionGroup {
2452 id: *next_group_id,
2453 kind: FusionKind::PowerStepNormalize,
2454 nodes: nodes.clone(),
2455 shape,
2456 span,
2457 pattern: Some(FusionPattern::PowerStepNormalize {
2458 lhs: matmul_node.inputs[0],
2459 rhs: matmul_node.inputs[1],
2460 epsilon: denom_info.epsilon,
2461 }),
2462 stack_layout: None,
2463 });
2464 *next_group_id += 1;
2465 for id in nodes {
2466 assigned.insert(id);
2467 }
2468 }
2469}
2470
2471fn detect_explained_variance(
2472 graph: &AccelGraph,
2473 assigned: &mut HashSet<NodeId>,
2474 groups: &mut Vec<FusionGroup>,
2475 next_group_id: &mut usize,
2476) {
2477 for diag_node in &graph.nodes {
2478 if assigned.contains(&diag_node.id) {
2479 continue;
2480 }
2481 match &diag_node.label {
2482 AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("diag") => {}
2483 _ => continue,
2484 }
2485 if diag_node.inputs.len() != 1 {
2486 continue;
2487 }
2488 let matmul2_vid = diag_node.inputs[0];
2489 let (matmul2_id, matmul2_node) = match node_from_value(graph, matmul2_vid) {
2490 Some(pair) => pair,
2491 None => continue,
2492 };
2493 if assigned.contains(&matmul2_id) {
2494 continue;
2495 }
2496 match &matmul2_node.label {
2497 AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mtimes") => {}
2498 _ => continue,
2499 }
2500 if matmul2_node.inputs.len() != 2 {
2501 continue;
2502 }
2503
2504 let (matmul1_id, matmul1_node, q_vid) = if let Some((mm_id, mm_node)) =
2505 node_from_value(graph, matmul2_node.inputs[0])
2506 {
2507 if matches!(mm_node.label, AccelNodeLabel::Builtin { ref name } if name.eq_ignore_ascii_case("mtimes"))
2508 {
2509 (mm_id, mm_node, matmul2_node.inputs[1])
2510 } else {
2511 continue;
2512 }
2513 } else if let Some((mm_id, mm_node)) = node_from_value(graph, matmul2_node.inputs[1]) {
2514 if matches!(mm_node.label, AccelNodeLabel::Builtin { ref name } if name.eq_ignore_ascii_case("mtimes"))
2515 {
2516 (mm_id, mm_node, matmul2_node.inputs[0])
2517 } else {
2518 continue;
2519 }
2520 } else {
2521 continue;
2522 };
2523
2524 if assigned.contains(&matmul1_id) {
2525 continue;
2526 }
2527
2528 if matmul1_node.inputs.len() != 2 {
2529 continue;
2530 }
2531
2532 let (transpose_id, transpose_input_vid, g_vid) =
2533 if let Some((t_id, src_vid)) = is_transpose_node(graph, matmul1_node.inputs[0]) {
2534 (t_id, src_vid, matmul1_node.inputs[1])
2535 } else if let Some((t_id, src_vid)) = is_transpose_node(graph, matmul1_node.inputs[1]) {
2536 (t_id, src_vid, matmul1_node.inputs[0])
2537 } else {
2538 continue;
2539 };
2540
2541 if assigned.contains(&transpose_id) {
2542 continue;
2543 }
2544
2545 if transpose_input_vid != q_vid {
2546 continue;
2547 }
2548
2549 let mut nodes = vec![diag_node.id, matmul2_id, matmul1_id, transpose_id];
2550 nodes.sort_by_key(|node_id| {
2551 graph
2552 .node(*node_id)
2553 .map(|node| node.span.start)
2554 .unwrap_or(usize::MAX)
2555 });
2556 let span = group_span(graph, &nodes);
2557 let shape = node_output_shape(graph, diag_node);
2558 groups.push(FusionGroup {
2559 id: *next_group_id,
2560 kind: FusionKind::ExplainedVariance,
2561 nodes: nodes.clone(),
2562 shape,
2563 span,
2564 pattern: Some(FusionPattern::ExplainedVariance { q: q_vid, g: g_vid }),
2565 stack_layout: None,
2566 });
2567 *next_group_id += 1;
2568 for id in nodes {
2569 assigned.insert(id);
2570 }
2571 }
2572}
2573
/// Nodes and constants recovered by `analyze_power_step_denominator` from the denominator
/// of a power-step normalisation.
struct PowerStepDenominatorInfo {
    /// The `sqrt` builtin node.
    sqrt_node: NodeId,
    /// The add/sub node that folds in a scalar constant, when one was found.
    add_node: Option<NodeId>,
    /// The `sum` builtin node feeding the square root.
    sum_node: NodeId,
    /// The element-wise power node (exponent 2) feeding the sum.
    pow_node: NodeId,
    /// Base operand of `pow_node`.
    pow_input: ValueId,
    /// Constant offset taken from `add_node`; `0.0` when there is no such node.
    epsilon: f64,
}
2582
2583fn analyze_power_step_denominator(
2584 graph: &AccelGraph,
2585 denom_vid: ValueId,
2586 expected_source_vid: ValueId,
2587) -> Option<PowerStepDenominatorInfo> {
2588 let (sqrt_node_id, sqrt_input_vid, add_node_opt, epsilon_from_outer) =
2589 if let Some((sqrt_id, sqrt_in)) = is_sqrt_node(graph, denom_vid) {
2590 if let Some((add_node, sum_vid, epsilon_inner)) =
2591 extract_add_with_constant(graph, sqrt_in)
2592 {
2593 (sqrt_id, sum_vid, Some(add_node), epsilon_inner)
2594 } else {
2595 (sqrt_id, sqrt_in, None, 0.0)
2596 }
2597 } else if let Some((add_node, other_vid, epsilon_inner)) =
2598 extract_add_with_constant(graph, denom_vid)
2599 {
2600 let (sqrt_id, sqrt_in) = is_sqrt_node(graph, other_vid)?;
2601 (sqrt_id, sqrt_in, Some(add_node), epsilon_inner)
2602 } else {
2603 return None;
2604 };
2605
2606 let (sum_node_id, sum_node) = node_from_value(graph, sqrt_input_vid)?;
2607 match &sum_node.label {
2608 AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("sum") => {}
2609 _ => return None,
2610 }
2611 if sum_node.inputs.is_empty() {
2612 return None;
2613 }
2614 let pow_vid = sum_node.inputs[0];
2615 let (pow_node_id, pow_node) = node_from_value(graph, pow_vid)?;
2616 let pow_input = match pow_node.label {
2617 AccelNodeLabel::Primitive(PrimitiveOp::ElemPow) => {
2618 if pow_node.inputs.len() != 2 {
2619 return None;
2620 }
2621 let base = pow_node.inputs[0];
2622 let exponent_vid = pow_node.inputs[1];
2623 let exponent = value_constant_f64(graph, exponent_vid)?;
2624 if !approx_eq(exponent, 2.0) {
2625 return None;
2626 }
2627 base
2628 }
2629 _ => return None,
2630 };
2631
2632 if pow_input != expected_source_vid {
2633 return None;
2634 }
2635
2636 let epsilon = epsilon_from_outer;
2637 Some(PowerStepDenominatorInfo {
2638 sqrt_node: sqrt_node_id,
2639 add_node: add_node_opt,
2640 sum_node: sum_node_id,
2641 pow_node: pow_node_id,
2642 pow_input,
2643 epsilon,
2644 })
2645}
2646
2647fn node_from_value(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, &AccelNode)> {
2648 let info = graph.value(vid)?;
2649 match info.origin {
2650 ValueOrigin::NodeOutput { node, .. } => graph.node(node).map(|n| (node, n)),
2651 _ => None,
2652 }
2653}
2654
2655fn is_sqrt_node(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId)> {
2656 let (node_id, node) = node_from_value(graph, vid)?;
2657 match &node.label {
2658 AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("sqrt") => {
2659 let input = node.inputs.first().copied()?;
2660 Some((node_id, input))
2661 }
2662 _ => None,
2663 }
2664}
2665
2666fn is_transpose_node(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId)> {
2667 let (node_id, node) = node_from_value(graph, vid)?;
2668 match &node.label {
2669 AccelNodeLabel::Primitive(PrimitiveOp::Transpose) => {
2670 let input = node.inputs.first().copied()?;
2671 Some((node_id, input))
2672 }
2673 _ => None,
2674 }
2675}
2676
2677fn extract_add_with_constant(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId, f64)> {
2678 let (node_id, node) = node_from_value(graph, vid)?;
2679 match node.label {
2680 AccelNodeLabel::Primitive(PrimitiveOp::Add) => {
2681 if node.inputs.len() != 2 {
2682 return None;
2683 }
2684 let lhs = node.inputs[0];
2685 let rhs = node.inputs[1];
2686 if let Some(eps) = value_constant_f64(graph, rhs) {
2687 return Some((node_id, lhs, eps));
2688 }
2689 if let Some(eps) = value_constant_f64(graph, lhs) {
2690 return Some((node_id, rhs, eps));
2691 }
2692 None
2693 }
2694 AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
2695 if node.inputs.len() != 2 {
2696 return None;
2697 }
2698 let lhs = node.inputs[0];
2699 let rhs = node.inputs[1];
2700 if let Some(eps) = value_constant_f64(graph, rhs) {
2701 return Some((node_id, lhs, -eps));
2702 }
2703 if let Some(eps) = value_constant_f64(graph, lhs) {
2704 return Some((node_id, rhs, eps));
2705 }
2706 None
2707 }
2708 _ => None,
2709 }
2710}
2711
/// Result of following a value back to a scalar constant with `collect_scalar_constant`.
struct ConstantTrace {
    /// The constant's value with any sign flips met along the way applied.
    value: f64,
    /// The nodes stepped through between the queried value and the constant, outermost first.
    nodes: Vec<NodeId>,
}
2716
2717fn collect_scalar_constant(graph: &AccelGraph, vid: ValueId) -> Option<ConstantTrace> {
2718 let mut current = vid;
2719 let mut nodes: Vec<NodeId> = Vec::new();
2720 let mut sign = 1.0f64;
2721 let mut visited: HashSet<NodeId> = HashSet::new();
2722
2723 loop {
2724 let info = graph.value(current)?;
2725 match &info.origin {
2726 ValueOrigin::Constant => {
2727 let base = value_info_scalar(info)?;
2728 return Some(ConstantTrace {
2729 value: sign * base,
2730 nodes,
2731 });
2732 }
2733 ValueOrigin::NodeOutput { node, .. } => {
2734 if !visited.insert(*node) {
2735 return None;
2736 }
2737 let node_ref = graph.node(*node)?;
2738 match &node_ref.label {
2739 AccelNodeLabel::Builtin { name }
2740 if name.eq_ignore_ascii_case("single")
2741 || name.eq_ignore_ascii_case("double")
2742 || name.eq_ignore_ascii_case("gpuarray") =>
2743 {
2744 if node_ref.inputs.len() != 1 {
2745 return None;
2746 }
2747 nodes.push(*node);
2748 current = node_ref.inputs[0];
2749 }
2750 AccelNodeLabel::Primitive(PrimitiveOp::Neg) => {
2751 if node_ref.inputs.len() != 1 {
2752 return None;
2753 }
2754 nodes.push(*node);
2755 sign = -sign;
2756 current = node_ref.inputs[0];
2757 }
2758 AccelNodeLabel::Primitive(PrimitiveOp::UPlus) => {
2759 if node_ref.inputs.len() != 1 {
2760 return None;
2761 }
2762 nodes.push(*node);
2763 current = node_ref.inputs[0];
2764 }
2765 _ => return None,
2766 }
2767 }
2768 _ => return None,
2769 }
2770 }
2771}
2772
2773fn scalar_shape_known_one(shape: &ShapeInfo) -> bool {
2774 match shape {
2775 ShapeInfo::Scalar => true,
2776 ShapeInfo::Tensor(dims) => {
2777 if dims.is_empty() {
2778 return true;
2779 }
2780 dims.iter().all(|dim| matches!(dim, Some(1)))
2781 }
2782 ShapeInfo::Unknown => false,
2783 }
2784}
2785
2786fn capture_image_scalar(
2787 graph: &AccelGraph,
2788 vid: ValueId,
2789 assigned: &HashSet<NodeId>,
2790 _nodes: &mut Vec<NodeId>,
2791) -> Option<ImageScalar> {
2792 if let Some(trace) = collect_scalar_constant(graph, vid) {
2793 if trace.nodes.iter().any(|id| assigned.contains(id)) {
2794 return None;
2795 }
2796 return Some(ImageScalar::Constant(trace.value));
2797 }
2798 let info = graph.value(vid)?;
2799 if scalar_shape_known_one(&info.shape) {
2800 return Some(ImageScalar::Value(vid));
2801 }
2802 if log::log_enabled!(log::Level::Debug) {
2803 log::debug!(
2804 "capture_image_scalar: reject vid={vid:?} shape={:?} origin={:?}",
2805 info.shape,
2806 info.origin
2807 );
2808 }
2809 None
2810}
2811
2812fn peel_numeric_casts(
2813 graph: &AccelGraph,
2814 mut vid: ValueId,
2815 assigned: &HashSet<NodeId>,
2816 _nodes: &mut Vec<NodeId>,
2817) -> Option<ValueId> {
2818 loop {
2819 let info = graph.value(vid)?;
2820 match &info.origin {
2821 ValueOrigin::NodeOutput { node, .. } => {
2822 if assigned.contains(node) {
2823 return None;
2824 }
2825 let node_ref = graph.node(*node)?;
2826 if let AccelNodeLabel::Builtin { name } = &node_ref.label {
2827 if name.eq_ignore_ascii_case("single")
2828 || name.eq_ignore_ascii_case("double")
2829 || name.eq_ignore_ascii_case("gpuarray")
2830 {
2831 if node_ref.inputs.len() != 1 {
2832 return None;
2833 }
2834 vid = node_ref.inputs[0];
2835 continue;
2836 }
2837 }
2838 return Some(vid);
2839 }
2840 _ => return Some(vid),
2841 }
2842 }
2843}
2844
2845fn resolve_scalar_constant(graph: &AccelGraph, vid: ValueId) -> Option<f64> {
2846 collect_scalar_constant(graph, vid).map(|trace| trace.value)
2847}
2848
2849fn value_info_scalar(info: &ValueInfo) -> Option<f64> {
2850 match &info.constant {
2851 Some(Value::Num(v)) => Some(*v),
2852 Some(Value::Int(i)) => Some(i.to_f64()),
2853 Some(Value::Tensor(t)) if t.data.len() == 1 => Some(t.data[0]),
2854 Some(Value::LogicalArray(arr)) if arr.data.len() == 1 => Some(arr.data[0] as f64),
2855 Some(Value::Bool(flag)) => Some(if *flag { 1.0 } else { 0.0 }),
2856 _ => None,
2857 }
2858}
2859
2860fn value_constant_f64(graph: &AccelGraph, vid: ValueId) -> Option<f64> {
2861 resolve_scalar_constant(graph, vid)
2862}
2863
2864fn primitive_expr(
2865 op: PrimitiveOp,
2866 inputs: &[ValueId],
2867 exprs: &HashMap<ValueId, String>,
2868) -> Option<String> {
2869 let binary = |exprs: &HashMap<ValueId, String>| -> Option<(String, String)> {
2870 let lhs = exprs.get(inputs.first()?).cloned()?;
2871 let rhs = exprs.get(inputs.get(1)?).cloned()?;
2872 Some((lhs, rhs))
2873 };
2874 match op {
2875 PrimitiveOp::Add => {
2876 let (lhs, rhs) = binary(exprs)?;
2877 Some(format!("({lhs} + {rhs})"))
2878 }
2879 PrimitiveOp::Sub => {
2880 let (lhs, rhs) = binary(exprs)?;
2881 Some(format!("({lhs} - {rhs})"))
2882 }
2883 PrimitiveOp::Mul | PrimitiveOp::ElemMul => {
2884 let (lhs, rhs) = binary(exprs)?;
2885 Some(format!("({lhs} * {rhs})"))
2886 }
2887 PrimitiveOp::ElemDiv | PrimitiveOp::ElemLeftDiv => {
2888 let (lhs, rhs) = binary(exprs)?;
2889 Some(format!("({lhs} / {rhs})"))
2890 }
2891 PrimitiveOp::Pow | PrimitiveOp::ElemPow => {
2892 let (lhs, rhs) = binary(exprs)?;
2893 Some(format!("pow({lhs}, {rhs})"))
2894 }
2895 PrimitiveOp::Neg => {
2896 let arg = exprs.get(inputs.first()?).cloned()?;
2897 Some(format!("(-{arg})"))
2898 }
2899 PrimitiveOp::UPlus => {
2900 let arg = exprs.get(inputs.first()?).cloned()?;
2901 Some(format!("(+{arg})"))
2902 }
2903 _ => None,
2904 }
2905}
2906
2907fn builtin_expr(
2908 name: &str,
2909 inputs: &[ValueId],
2910 exprs: &HashMap<ValueId, String>,
2911 scalar_ty: &str,
2912) -> Option<String> {
2913 let func = match name.to_ascii_lowercase().as_str() {
2914 "isfinite" => return builtin_unary_call("isFinite", inputs, exprs),
2915 "isinf" => return builtin_unary_call("isInf", inputs, exprs),
2916 "isnan" => return builtin_unary_call("isNan", inputs, exprs),
2917 "single" | "double" | "gpuarray" => return builtin_identity(inputs, exprs),
2918 "fix" => return builtin_unary_call("trunc", inputs, exprs),
2919 "sign" => return builtin_unary_call("sign", inputs, exprs),
2920 "mod" => {
2921 let lhs = exprs.get(inputs.first()?).cloned()?;
2922 let rhs = exprs.get(inputs.get(1)?).cloned()?;
2923 return Some(format!(
2927 "select(({lhs} - {rhs} * floor({lhs} / {rhs})), select({rhs}, {lhs}, ({lhs} == 0.0 || sign({lhs}) == sign({rhs}))), (isInf({rhs}) && isFinite({lhs})))"
2928 ));
2929 }
2930 "rem" => {
2931 let lhs = exprs.get(inputs.first()?).cloned()?;
2932 let rhs = exprs.get(inputs.get(1)?).cloned()?;
2933 return Some(format!(
2934 "select(({lhs} - {rhs} * trunc({lhs} / {rhs})), {lhs}, (isInf({rhs}) && isFinite({lhs})))"
2935 ));
2936 }
2937 "sin" => "sin",
2938 "cos" => "cos",
2939 "tan" => "tan",
2940 "asin" => "asin",
2941 "acos" => "acos",
2942 "atan" => "atan",
2943 "atan2" => return builtin_binary("atan2", inputs, exprs),
2944 "hypot" => return builtin_binary("hypot", inputs, exprs),
2945 "pow2" => {
2946 if inputs.len() == 1 {
2947 return builtin_unary_call("exp2", inputs, exprs);
2948 }
2949 return None;
2950 }
2951 "sinh" => "sinh",
2952 "cosh" => "cosh",
2953 "tanh" => "tanh",
2954 "exp" => "exp",
2955 "log" => "log",
2956 "log2" => "log2",
2957 "sqrt" => "sqrt",
2958 "abs" => "abs",
2959 "exp2" => "exp2",
2960 "floor" => "floor",
2961 "ceil" => "ceil",
2962 "round" => "round",
2963 "trunc" => "trunc",
2964 "asinh" => return builtin_unary_call("asinh", inputs, exprs),
2965 "acosh" => return builtin_unary_call("acosh", inputs, exprs),
2966 "atanh" => return builtin_unary_call("atanh", inputs, exprs),
2967 "max" => return builtin_binary("max", inputs, exprs),
2968 "min" => return builtin_binary("min", inputs, exprs),
2969 _ => {
2970 return match name.to_ascii_lowercase().as_str() {
2971 "log10" => {
2972 let arg = exprs.get(inputs.first()?).cloned()?;
2973 let constant = cast_literal(scalar_ty, "0.4342944819032518");
2974 Some(format!("(log({arg}) * {constant})"))
2975 }
2976 "log1p" => {
2977 let arg = exprs.get(inputs.first()?).cloned()?;
2978 let one = cast_literal(scalar_ty, "1.0");
2979 Some(format!("log({arg} + {one})"))
2980 }
2981 "expm1" => {
2982 let arg = exprs.get(inputs.first()?).cloned()?;
2983 let one = cast_literal(scalar_ty, "1.0");
2984 Some(format!("(exp({arg}) - {one})"))
2985 }
2986 _ => None,
2987 }
2988 }
2989 };
2990 let arg = exprs.get(inputs.first()?).cloned()?;
2991 Some(format!("{func}({arg})"))
2992}
2993
2994fn builtin_binary(
2995 func: &str,
2996 inputs: &[ValueId],
2997 exprs: &HashMap<ValueId, String>,
2998) -> Option<String> {
2999 let lhs = exprs.get(inputs.first()?).cloned()?;
3000 let rhs = exprs.get(inputs.get(1)?).cloned()?;
3001 Some(format!("{func}({lhs}, {rhs})"))
3002}
3003
3004fn builtin_unary_call(
3005 func: &str,
3006 inputs: &[ValueId],
3007 exprs: &HashMap<ValueId, String>,
3008) -> Option<String> {
3009 let arg = exprs.get(inputs.first()?).cloned()?;
3010 Some(format!("{func}({arg})"))
3011}
3012
3013fn builtin_identity(inputs: &[ValueId], exprs: &HashMap<ValueId, String>) -> Option<String> {
3014 exprs.get(inputs.first()?).cloned()
3015}
3016
/// Formats a numeric literal for the given scalar type.
///
/// For `"f64"` the literal is wrapped in an explicit `f64(..)` conversion; for every
/// other type name it is returned unchanged.
fn cast_literal(scalar_ty: &str, literal: &str) -> String {
    match scalar_ty {
        "f64" => format!("{scalar_ty}({literal})"),
        _ => literal.to_owned(),
    }
}
3024
3025fn split_add_with_scalar(
3026 graph: &AccelGraph,
3027 vid: ValueId,
3028 assigned: &HashSet<NodeId>,
3029 nodes: &mut Vec<NodeId>,
3030) -> Option<(NodeId, ValueId, ImageScalar)> {
3031 let (node_id, node) = node_from_value(graph, vid)?;
3032 match node.label {
3033 AccelNodeLabel::Primitive(PrimitiveOp::Add) => {
3034 if node.inputs.len() != 2 {
3035 return None;
3036 }
3037 let lhs = node.inputs[0];
3038 let rhs = node.inputs[1];
3039 if let Some(scalar) = capture_image_scalar(graph, rhs, assigned, nodes) {
3040 return Some((node_id, lhs, scalar));
3041 }
3042 if let Some(scalar) = capture_image_scalar(graph, lhs, assigned, nodes) {
3043 return Some((node_id, rhs, scalar));
3044 }
3045 None
3046 }
3047 AccelNodeLabel::Primitive(PrimitiveOp::Sub) => {
3048 if node.inputs.len() != 2 {
3049 return None;
3050 }
3051 let lhs = node.inputs[0];
3052 let rhs = node.inputs[1];
3053 if let Some(ImageScalar::Constant(value)) =
3054 capture_image_scalar(graph, rhs, assigned, nodes)
3055 {
3056 return Some((node_id, lhs, ImageScalar::Constant(-value)));
3057 }
3058 None
3059 }
3060 _ => None,
3061 }
3062}
3063
3064fn split_mul_with_scalar(
3065 graph: &AccelGraph,
3066 vid: ValueId,
3067 assigned: &HashSet<NodeId>,
3068 nodes: &mut Vec<NodeId>,
3069) -> Option<(NodeId, ValueId, ImageScalar)> {
3070 let (node_id, node) = node_from_value(graph, vid)?;
3071 match node.label {
3072 AccelNodeLabel::Primitive(PrimitiveOp::Mul)
3073 | AccelNodeLabel::Primitive(PrimitiveOp::ElemMul) => {
3074 if node.inputs.len() != 2 {
3075 return None;
3076 }
3077 let lhs = node.inputs[0];
3078 let rhs = node.inputs[1];
3079 if let Some(scalar) = capture_image_scalar(graph, rhs, assigned, nodes) {
3080 return Some((node_id, lhs, scalar));
3081 }
3082 if let Some(scalar) = capture_image_scalar(graph, lhs, assigned, nodes) {
3083 return Some((node_id, rhs, scalar));
3084 }
3085 None
3086 }
3087 _ => None,
3088 }
3089}
3090
3091fn split_max_with_zero_scalar(
3092 graph: &AccelGraph,
3093 vid: ValueId,
3094 assigned: &HashSet<NodeId>,
3095 nodes: &mut Vec<NodeId>,
3096) -> Option<(NodeId, ValueId)> {
3097 let (node_id, node) = node_from_value(graph, vid)?;
3098 match &node.label {
3099 AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("max") => {
3100 if node.inputs.len() != 2 {
3101 if log::log_enabled!(log::Level::Debug) {
3102 log::debug!(
3103 "split_max_with_zero_scalar: node {node_id:?} has {} inputs",
3104 node.inputs.len()
3105 );
3106 }
3107 return None;
3108 }
3109 let lhs = node.inputs[0];
3110 let rhs = node.inputs[1];
3111 if let Some(ImageScalar::Constant(value)) =
3112 capture_image_scalar(graph, rhs, assigned, nodes)
3113 {
3114 if approx_eq(value, 0.0) {
3115 if log::log_enabled!(log::Level::Debug) {
3116 log::debug!(
3117 "split_max_with_zero_scalar: rhs zero constant for node {node_id:?}"
3118 );
3119 }
3120 return Some((node_id, lhs));
3121 }
3122 }
3123 if let Some(ImageScalar::Constant(value)) =
3124 capture_image_scalar(graph, lhs, assigned, nodes)
3125 {
3126 if approx_eq(value, 0.0) {
3127 if log::log_enabled!(log::Level::Debug) {
3128 log::debug!(
3129 "split_max_with_zero_scalar: lhs zero constant for node {node_id:?}"
3130 );
3131 }
3132 return Some((node_id, rhs));
3133 }
3134 }
3135 if log::log_enabled!(log::Level::Debug) {
3136 log::debug!(
3137 "split_max_with_zero_scalar: node {node_id:?} inputs not zero constants"
3138 );
3139 }
3140 None
3141 }
3142 _ => None,
3143 }
3144}
3145
3146fn resolve_numeric_vector_constant(graph: &AccelGraph, vid: ValueId) -> Option<Vec<f64>> {
3147 if let Some(scalar) = resolve_scalar_constant(graph, vid) {
3148 return Some(vec![scalar]);
3149 }
3150 let info = graph.value(vid)?;
3151 match &info.constant {
3152 Some(Value::Tensor(tensor)) if !tensor.data.is_empty() => Some(tensor.data.clone()),
3153 Some(Value::LogicalArray(arr)) if !arr.data.is_empty() => Some(
3154 arr.data
3155 .iter()
3156 .map(|v| if *v == 0 { 0.0 } else { 1.0 })
3157 .collect(),
3158 ),
3159 Some(Value::Bool(flag)) => Some(vec![if *flag { 1.0 } else { 0.0 }]),
3160 Some(Value::Int(iv)) => Some(vec![iv.to_f64()]),
3161 Some(Value::Num(num)) => Some(vec![*num]),
3162 _ => None,
3163 }
3164}
3165
3166fn match_mean_axes(graph: &AccelGraph, vid: ValueId) -> Option<(NodeId, ValueId, Vec<f64>)> {
3167 let (node_id, node) = node_from_value(graph, vid)?;
3168 match &node.label {
3169 AccelNodeLabel::Builtin { name } if name.eq_ignore_ascii_case("mean") => {}
3170 _ => return None,
3171 }
3172 if node.inputs.len() < 2 {
3173 return None;
3174 }
3175 let data_vid = node.inputs[0];
3176 let dims_vid = node.inputs[1];
3177 let dims = resolve_numeric_vector_constant(graph, dims_vid)?;
3178 Some((node_id, data_vid, dims))
3179}
3180
/// Reports whether `found` and `expected` name the same dimensions, ignoring order.
///
/// Each entry is rounded to the nearest integer before comparison, and repeated entries
/// must occur equally often on both sides. Lists of different length never match.
fn dims_match_unordered(found: &[f64], expected: &[f64]) -> bool {
    let normalise = |dims: &[f64]| -> Vec<i64> {
        let mut rounded: Vec<i64> = dims.iter().map(|d| d.round() as i64).collect();
        rounded.sort_unstable();
        rounded
    };
    found.len() == expected.len() && normalise(found) == normalise(expected)
}
3191
3192fn peel_mean_dims(
3193 graph: &AccelGraph,
3194 vid: ValueId,
3195 expected_dims: &[f64],
3196 assigned: &HashSet<NodeId>,
3197 nodes: &mut Vec<NodeId>,
3198) -> Option<ValueId> {
3199 if expected_dims.is_empty() {
3200 return Some(vid);
3201 }
3202 let (node_id, data_vid, dims) = match_mean_axes(graph, vid)?;
3203 if assigned.contains(&node_id) {
3204 return None;
3205 }
3206 if dims.len() == expected_dims.len() && dims_match_unordered(&dims, expected_dims) {
3207 nodes.push(node_id);
3208 return Some(data_vid);
3209 }
3210 if dims.len() == 1 && approx_eq(dims[0], expected_dims[0]) {
3211 nodes.push(node_id);
3212 return peel_mean_dims(graph, data_vid, &expected_dims[1..], assigned, nodes);
3213 }
3214 None
3215}
3216
/// Outcome of a successful `analyze_image_normalize` match.
struct ImageNormalizeMatch {
    /// Ids of the nodes that make up the matched chain, sorted and de-duplicated.
    nodes: Vec<NodeId>,
    /// The value being normalised.
    input: ValueId,
    /// Scalar added under the square root.
    epsilon: ImageScalar,
    /// Scalar multiplier applied after normalisation; `None` when absent or a constant 1.
    gain: Option<ImageScalar>,
    /// Scalar offset applied after the gain; `None` when absent or a constant 0.
    bias: Option<ImageScalar>,
    /// Exponent of the final power; `None` when it is a constant 1.
    gamma: Option<ImageScalar>,
}
3225
/// Tries to match an image-normalisation chain that ends at `pow_node_id`.
///
/// The chain is walked backwards from the final node. Numeric cast nodes may be peeled
/// off between steps via `peel_numeric_casts`. The accepted shape is:
///
/// 1. an element-wise power whose exponent is a scalar (`gamma`);
/// 2. its base is a `max` against a constant zero;
/// 3. optionally an add/sub of a scalar (`bias`), then optionally a multiply by a scalar
///    (`gain`);
/// 4. an element-wise division of a difference by a `sqrt`;
/// 5. the `sqrt` input is an add of a scalar (`epsilon`) to a mean over dimensions 3 and 2
///    of an element-wise square of a subtraction `imgs - mu`;
/// 6. the numerator of the division is a subtraction with the same `imgs` and `mu`;
/// 7. `mu` is itself a mean over dimensions 3 and 2 of `imgs`.
///
/// Any node of the chain that is already in `assigned` rejects the match. On success the
/// collected node ids are returned sorted and de-duplicated together with the input value
/// and the extracted scalars; `gamma` and `gain` equal to a constant 1 and `bias` equal to
/// a constant 0 are reported as `None`.
///
/// Rejections routed through the local `img_norm_fail!` macro are logged at trace level;
/// those that come from a `?` on a helper return silently.
fn analyze_image_normalize(
    graph: &AccelGraph,
    pow_node_id: NodeId,
    assigned: &HashSet<NodeId>,
) -> Option<ImageNormalizeMatch> {
    let pow_node = graph.node(pow_node_id)?;
    if log::log_enabled!(log::Level::Trace) {
        log::trace!(
            "image_normalize: inspect pow candidate node={pow_node_id:?} label={:?}",
            pow_node.label
        );
    }
    // Logs the rejection reason at trace level and returns `None` from this function.
    macro_rules! img_norm_fail {
        ($reason:expr) => {{
            if log::log_enabled!(log::Level::Trace) {
                log::trace!(
                    "image_normalize: reject node {pow_node_id:?} reason={}",
                    $reason
                );
            }
            return None;
        }};
    }
    if !matches!(
        pow_node.label,
        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow)
    ) {
        img_norm_fail!("not elem pow");
    }
    if pow_node.inputs.len() != 2 || pow_node.outputs.len() != 1 {
        img_norm_fail!("unexpected pow arity");
    }

    let mut nodes: Vec<NodeId> = vec![pow_node_id];

    // Step 1: the exponent must be a scalar; a constant 1 means "no gamma".
    let gamma_scalar = capture_image_scalar(graph, pow_node.inputs[1], assigned, &mut nodes)?;
    if log::log_enabled!(log::Level::Trace) {
        log::trace!("image_normalize: node {pow_node_id:?} gamma scalar={gamma_scalar:?}");
    }
    let gamma_opt = match &gamma_scalar {
        ImageScalar::Constant(value) if approx_eq(*value, 1.0) => None,
        _ => Some(gamma_scalar),
    };

    // Step 2: the base must be a `max` against a constant zero.
    let (clamp_node_id, clamp_input_vid) =
        split_max_with_zero_scalar(graph, pow_node.inputs[0], assigned, &mut nodes)?;
    if assigned.contains(&clamp_node_id) {
        img_norm_fail!("clamp node already assigned");
    }
    nodes.push(clamp_node_id);

    // Step 3a: optional scalar bias; a constant 0 is treated as "no bias".
    let pre_bias_vid = peel_numeric_casts(graph, clamp_input_vid, assigned, &mut nodes)?;
    let (pre_gain_vid, bias_opt) = if let Some((add_node_id, base_vid, bias_scalar)) =
        split_add_with_scalar(graph, pre_bias_vid, assigned, &mut nodes)
    {
        if assigned.contains(&add_node_id) {
            img_norm_fail!("bias add already assigned");
        }
        nodes.push(add_node_id);
        let bias = match &bias_scalar {
            ImageScalar::Constant(value) if approx_eq(*value, 0.0) => None,
            _ => Some(bias_scalar),
        };
        let base_vid = peel_numeric_casts(graph, base_vid, assigned, &mut nodes)?;
        (base_vid, bias)
    } else {
        (pre_bias_vid, None)
    };

    // Step 3b: optional scalar gain; a constant 1 is treated as "no gain".
    let (mut norm_vid, gain_opt) = if let Some((mul_node_id, base_vid, gain_scalar)) =
        split_mul_with_scalar(graph, pre_gain_vid, assigned, &mut nodes)
    {
        if assigned.contains(&mul_node_id) {
            img_norm_fail!("gain mul already assigned");
        }
        nodes.push(mul_node_id);
        let gain = match &gain_scalar {
            ImageScalar::Constant(value) if approx_eq(*value, 1.0) => None,
            _ => Some(gain_scalar),
        };
        let base_vid = peel_numeric_casts(graph, base_vid, assigned, &mut nodes)?;
        (base_vid, gain)
    } else {
        (pre_gain_vid, None)
    };

    norm_vid = peel_numeric_casts(graph, norm_vid, assigned, &mut nodes)?;

    // Step 4: the normalised value must be `diff ./ sqrt(..)`.
    let (div_node_id, div_node) = node_from_value(graph, norm_vid)?;
    if assigned.contains(&div_node_id) {
        img_norm_fail!("div node already assigned");
    }
    match div_node.label {
        AccelNodeLabel::Primitive(PrimitiveOp::ElemDiv) => {}
        _ => img_norm_fail!("not div primitive"),
    }
    if div_node.inputs.len() != 2 {
        img_norm_fail!("div arity");
    }

    let diff_vid = div_node.inputs[0];
    let sigma_vid = peel_numeric_casts(graph, div_node.inputs[1], assigned, &mut nodes)?;
    let (sigma_node_id, sigma_input_vid) = match is_sqrt_node(graph, sigma_vid) {
        Some(pair) => pair,
        None => img_norm_fail!("sigma not sqrt"),
    };
    if assigned.contains(&sigma_node_id) {
        img_norm_fail!("sqrt node already assigned");
    }
    nodes.push(div_node_id);
    nodes.push(sigma_node_id);

    // Step 5: under the root, `mean_sq + epsilon`.
    let (add_node_id, mean_sq_vid, epsilon_scalar) =
        split_add_with_scalar(graph, sigma_input_vid, assigned, &mut nodes)?;
    if assigned.contains(&add_node_id) {
        img_norm_fail!("epsilon add already assigned");
    }
    nodes.push(add_node_id);
    let epsilon = epsilon_scalar;
    let mean_sq_vid = peel_numeric_casts(graph, mean_sq_vid, assigned, &mut nodes)?;

    // The mean must reduce over dimensions 3 and 2, either in one call or nested.
    let squared_diff_vid = peel_mean_dims(graph, mean_sq_vid, &[3.0, 2.0], assigned, &mut nodes)?;

    // The averaged quantity must be an element-wise square (constant exponent 2).
    let (square_pow_node_id, square_pow_node) = node_from_value(graph, squared_diff_vid)?;
    if assigned.contains(&square_pow_node_id) {
        img_norm_fail!("square pow already assigned");
    }
    if !matches!(
        square_pow_node.label,
        AccelNodeLabel::Primitive(PrimitiveOp::ElemPow)
    ) {
        img_norm_fail!("variance pow not elem pow");
    }
    if square_pow_node.inputs.len() != 2 {
        img_norm_fail!("variance pow arity");
    }
    let exponent_trace = collect_scalar_constant(graph, square_pow_node.inputs[1])?;
    if !approx_eq(exponent_trace.value, 2.0) {
        img_norm_fail!("variance exponent != 2");
    }
    if exponent_trace.nodes.iter().any(|id| assigned.contains(id)) {
        img_norm_fail!("variance exponent nodes already assigned");
    }
    nodes.push(square_pow_node_id);
    // Nodes wrapping the exponent constant become part of the group as well.
    nodes.extend(exponent_trace.nodes.iter().copied());

    // The squared quantity must be `imgs - mu`.
    let diff_var_vid = square_pow_node.inputs[0];
    let (diff_var_node_id, diff_var_node) = node_from_value(graph, diff_var_vid)?;
    if assigned.contains(&diff_var_node_id) {
        img_norm_fail!("diff variance node already assigned");
    }
    if !matches!(
        diff_var_node.label,
        AccelNodeLabel::Primitive(PrimitiveOp::Sub)
    ) {
        img_norm_fail!("diff variance node not sub");
    }
    if diff_var_node.inputs.len() != 2 {
        img_norm_fail!("diff variance arity");
    }
    let imgs_vid = diff_var_node.inputs[0];
    let mu_vid = peel_numeric_casts(graph, diff_var_node.inputs[1], assigned, &mut nodes)?;
    nodes.push(diff_var_node_id);

    // Step 6: the numerator must subtract the same `mu` from the same `imgs`.
    let (diff_node_id, diff_node) = node_from_value(graph, diff_vid)?;
    if assigned.contains(&diff_node_id) {
        img_norm_fail!("diff node already assigned");
    }
    if !matches!(diff_node.label, AccelNodeLabel::Primitive(PrimitiveOp::Sub)) {
        img_norm_fail!("diff node not sub");
    }
    if diff_node.inputs.len() != 2 {
        img_norm_fail!("diff node arity");
    }
    let diff_mu_vid = peel_numeric_casts(graph, diff_node.inputs[1], assigned, &mut nodes)?;
    if diff_node.inputs[0] != imgs_vid || diff_mu_vid != mu_vid {
        img_norm_fail!("diff inputs mismatch with variance pair");
    }
    nodes.push(diff_node_id);

    // Step 7: `mu` must be the mean of `imgs` over the same dimensions.
    let mean_mu_input_vid = peel_mean_dims(graph, mu_vid, &[3.0, 2.0], assigned, &mut nodes)?;
    if mean_mu_input_vid != imgs_vid {
        img_norm_fail!("mean mu input mismatch");
    }

    // Accept tensors with at least two dimensions and unknown shapes; reject the rest.
    // NOTE(review): the rejection message says "3-d" while the check admits rank >= 2 —
    // confirm which of the two is intended.
    let input_info = graph.value(imgs_vid)?;
    match &input_info.shape {
        ShapeInfo::Tensor(dims) if dims.len() >= 2 => {}
        ShapeInfo::Unknown => {}
        other => {
            if log::log_enabled!(log::Level::Debug) {
                log::debug!(
                    "image_normalize: node {pow_node_id:?} input shape {:?}",
                    other
                );
            }
            img_norm_fail!("input not 3-d tensor");
        }
    }

    // Nodes were pushed in discovery order and may repeat (e.g. a shared mean chain).
    nodes.sort_unstable();
    nodes.dedup();

    Some(ImageNormalizeMatch {
        nodes,
        input: imgs_vid,
        epsilon,
        gain: gain_opt,
        bias: bias_opt,
        gamma: gamma_opt,
    })
}
3438
3439#[cfg(test)]
3440mod tests {
3441 use super::*;
3442 use crate::graph::{
3443 AccelGraph, AccelGraphTag, AccelNode, AccelNodeLabel, AccelOpCategory, InstrSpan,
3444 PrimitiveOp, ValueId, ValueInfo, ValueOrigin, VarKind,
3445 };
3446 use runmat_builtins::{Type, Value};
3447 use std::collections::HashMap as StdHashMap;
3448
3449 fn simple_elementwise_graph() -> AccelGraph {
3450 let values = vec![
3451 ValueInfo {
3453 id: 0,
3454 origin: ValueOrigin::Variable {
3455 kind: VarKind::Global,
3456 index: 0,
3457 },
3458 ty: Type::tensor(),
3459 shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
3460 constant: None,
3461 },
3462 ValueInfo {
3464 id: 1,
3465 origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
3466 ty: Type::tensor(),
3467 shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
3468 constant: None,
3469 },
3470 ValueInfo {
3472 id: 2,
3473 origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
3474 ty: Type::tensor(),
3475 shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
3476 constant: None,
3477 },
3478 ];
3479
3480 let node0 = AccelNode {
3481 id: 0,
3482 label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
3483 category: AccelOpCategory::Elementwise,
3484 inputs: vec![0, 0],
3485 outputs: vec![1],
3486 span: InstrSpan { start: 10, end: 10 },
3487 tags: vec![AccelGraphTag::Elementwise],
3488 };
3489 let node1 = AccelNode {
3490 id: 1,
3491 label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
3492 category: AccelOpCategory::Elementwise,
3493 inputs: vec![1, 0],
3494 outputs: vec![2],
3495 span: InstrSpan { start: 11, end: 11 },
3496 tags: vec![AccelGraphTag::Elementwise],
3497 };
3498
3499 AccelGraph {
3500 nodes: vec![node0, node1],
3501 values,
3502 var_bindings: StdHashMap::new(),
3503 node_bindings: StdHashMap::new(),
3504 }
3505 }
3506
3507 #[test]
3508 fn detects_chain() {
3509 let graph = simple_elementwise_graph();
3510 let groups = detect_fusion_groups(&graph);
3511 assert_eq!(groups.len(), 1);
3512 let group = &groups[0];
3513 assert_eq!(group.nodes, vec![0, 1]);
3514 assert_eq!(group.kind, FusionKind::ElementwiseChain);
3515 }
3516
3517 #[test]
3518 fn prepare_fusion_plan_requires_semantic_candidate_groups() {
3519 let graph = simple_elementwise_graph();
3520 let groups = detect_fusion_groups(&graph);
3521 assert_eq!(groups.len(), 1);
3522
3523 let plan = prepare_fusion_plan(Some(&graph), &groups, 0);
3524 assert!(
3525 plan.is_none(),
3526 "bytecode groups alone should not produce an executable fusion plan without semantic candidate evidence"
3527 );
3528 }
3529
3530 #[test]
3531 fn prepare_fusion_plan_allows_semantic_gated_groups() {
3532 let graph = simple_elementwise_graph();
3533 let groups = detect_fusion_groups(&graph);
3534 assert_eq!(groups.len(), 1);
3535
3536 let plan = prepare_fusion_plan(Some(&graph), &groups, 1);
3537 assert!(
3538 plan.is_some(),
3539 "semantic candidate evidence should allow executable fusion plan preparation"
3540 );
3541 }
3542
3543 #[test]
3544 fn prepare_fusion_plan_recovers_empty_group_nodes_from_contained_runtime_span() {
3545 let graph = simple_elementwise_graph();
3546 let groups = vec![FusionGroup {
3547 id: 0,
3548 kind: FusionKind::ElementwiseChain,
3549 nodes: Vec::new(),
3550 shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
3551 span: InstrSpan { start: 10, end: 10 },
3552 pattern: None,
3553 stack_layout: None,
3554 }];
3555
3556 let plan = prepare_fusion_plan(Some(&graph), &groups, 1)
3557 .expect("runtime group sanitization should recover contained elementwise nodes");
3558 assert_eq!(plan.groups.len(), 1);
3559 assert_eq!(
3560 plan.groups[0].group.nodes,
3561 vec![0],
3562 "runtime sanitization should recover a compatible contained node for empty group mapping"
3563 );
3564 }
3565
3566 #[test]
3567 fn prepare_fusion_plan_rejects_empty_group_nodes_when_runtime_graph_is_too_far() {
3568 let graph = simple_elementwise_graph();
3569 let groups = vec![FusionGroup {
3570 id: 0,
3571 kind: FusionKind::ElementwiseChain,
3572 nodes: Vec::new(),
3573 shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
3574 span: InstrSpan { start: 20, end: 20 },
3575 pattern: None,
3576 stack_layout: None,
3577 }];
3578
3579 let plan = prepare_fusion_plan(Some(&graph), &groups, 1);
3580 assert!(
3581 plan.is_none(),
3582 "runtime sanitization should reject empty group mapping when no compatible nearby nodes exist"
3583 );
3584 }
3585
3586 #[test]
3587 fn prepare_fusion_plan_rejects_empty_group_nodes_when_runtime_node_covers_group_span() {
3588 let values = vec![
3589 ValueInfo {
3590 id: 0,
3591 origin: ValueOrigin::Variable {
3592 kind: VarKind::Global,
3593 index: 0,
3594 },
3595 ty: Type::tensor(),
3596 shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
3597 constant: None,
3598 },
3599 ValueInfo {
3600 id: 1,
3601 origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
3602 ty: Type::tensor(),
3603 shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
3604 constant: None,
3605 },
3606 ];
3607 let graph = AccelGraph {
3608 nodes: vec![AccelNode {
3609 id: 0,
3610 label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
3611 category: AccelOpCategory::Elementwise,
3612 inputs: vec![0, 0],
3613 outputs: vec![1],
3614 span: InstrSpan { start: 10, end: 12 },
3615 tags: vec![AccelGraphTag::Elementwise],
3616 }],
3617 values,
3618 var_bindings: StdHashMap::new(),
3619 node_bindings: StdHashMap::new(),
3620 };
3621 let groups = vec![FusionGroup {
3622 id: 0,
3623 kind: FusionKind::ElementwiseChain,
3624 nodes: Vec::new(),
3625 shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
3626 span: InstrSpan { start: 11, end: 11 },
3627 pattern: None,
3628 stack_layout: None,
3629 }];
3630
3631 let plan = prepare_fusion_plan(Some(&graph), &groups, 1);
3632 assert!(
3633 plan.is_none(),
3634 "runtime sanitization should reject covering runtime-node spans when semantic group spans are narrower"
3635 );
3636 }
3637
3638 #[test]
3639 fn prepare_fusion_plan_rejects_stale_mapped_nodes_without_runtime_remap() {
3640 let graph = simple_elementwise_graph();
3641 let groups = vec![FusionGroup {
3642 id: 0,
3643 kind: FusionKind::ElementwiseChain,
3644 nodes: vec![1],
3645 shape: ShapeInfo::Tensor(vec![Some(4), Some(4)]),
3646 span: InstrSpan { start: 10, end: 10 },
3647 pattern: None,
3648 stack_layout: None,
3649 }];
3650
3651 let plan = prepare_fusion_plan(Some(&graph), &groups, 1);
3652 assert!(
3653 plan.is_none(),
3654 "runtime sanitization should reject stale mapped nodes instead of remapping from runtime graph scan"
3655 );
3656 }
3657
3658 #[test]
3659 fn builds_plan_and_template() {
3660 let graph = simple_elementwise_graph();
3661 let groups = detect_fusion_groups(&graph);
3662 let plan = FusionPlan::from_graph(&graph, &groups);
3663 assert_eq!(plan.groups.len(), 1);
3664 let group_plan = &plan.groups[0];
3665 assert!(group_plan.kernel.supported);
3666 let wgsl = group_plan.generate_wgsl("f32").expect("wgsl");
3667 assert!(wgsl.contains("@compute"));
3668 assert!(group_plan.group.element_count().is_some());
3669 }
3670
3671 #[test]
3672 fn stack_pattern_tracks_repeated_constants() {
3673 let values = vec![
3674 ValueInfo {
3675 id: 0,
3676 origin: ValueOrigin::Variable {
3677 kind: VarKind::Global,
3678 index: 0,
3679 },
3680 ty: Type::tensor(),
3681 shape: ShapeInfo::Tensor(vec![Some(4)]),
3682 constant: None,
3683 },
3684 ValueInfo {
3685 id: 1,
3686 origin: ValueOrigin::Constant,
3687 ty: Type::tensor(),
3688 shape: ShapeInfo::Tensor(vec![Some(4)]),
3689 constant: Some(Value::Num(1.0)),
3690 },
3691 ValueInfo {
3692 id: 2,
3693 origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
3694 ty: Type::tensor(),
3695 shape: ShapeInfo::Tensor(vec![Some(4)]),
3696 constant: None,
3697 },
3698 ValueInfo {
3699 id: 3,
3700 origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
3701 ty: Type::tensor(),
3702 shape: ShapeInfo::Tensor(vec![Some(4)]),
3703 constant: None,
3704 },
3705 ];
3706
3707 let node0 = AccelNode {
3708 id: 0,
3709 label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
3710 category: AccelOpCategory::Elementwise,
3711 inputs: vec![0, 1],
3712 outputs: vec![2],
3713 span: InstrSpan { start: 5, end: 5 },
3714 tags: vec![AccelGraphTag::Elementwise],
3715 };
3716 let node1 = AccelNode {
3717 id: 1,
3718 label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
3719 category: AccelOpCategory::Elementwise,
3720 inputs: vec![2, 1],
3721 outputs: vec![3],
3722 span: InstrSpan { start: 6, end: 6 },
3723 tags: vec![AccelGraphTag::Elementwise],
3724 };
3725
3726 let graph = AccelGraph {
3727 nodes: vec![node0, node1],
3728 values,
3729 var_bindings: StdHashMap::new(),
3730 node_bindings: StdHashMap::new(),
3731 };
3732
3733 let groups = detect_fusion_groups(&graph);
3734 assert_eq!(groups.len(), 1);
3735 let plan = FusionPlan::from_graph(&graph, &groups);
3736 let group_plan = &plan.groups[0];
3737 assert_eq!(group_plan.inputs.len(), 2);
3738 assert!(group_plan.stack_pattern.is_empty());
3739 assert!(group_plan.constants.contains_key(&1));
3740 assert!(group_plan.const_values.contains_key(&1));
3741 }
3742
3743 #[test]
3744 fn builtin_expr_supports_extended_set() {
3745 let mut exprs: StdHashMap<ValueId, String> = StdHashMap::new();
3746 exprs.insert(0, "v0".to_string());
3747 exprs.insert(1, "v1".to_string());
3748
3749 let log1p = super::builtin_expr("log1p", &[0], &exprs, "f32");
3750 assert!(log1p.is_some());
3751
3752 let log10 = super::builtin_expr("log10", &[0], &exprs, "f64");
3753 assert!(log10.unwrap().contains("log"));
3754
3755 let expm1 = super::builtin_expr("expm1", &[0], &exprs, "f32");
3756 assert!(expm1.unwrap().contains("exp"));
3757
3758 let floor = super::builtin_expr("floor", &[0], &exprs, "f32");
3759 assert_eq!(floor.unwrap(), "floor(v0)");
3760
3761 let atan2 = super::builtin_expr("atan2", &[0, 1], &exprs, "f32");
3762 assert_eq!(atan2.unwrap(), "atan2(v0, v1)");
3763
3764 let asinh = super::builtin_expr("asinh", &[0], &exprs, "f32");
3765 assert_eq!(asinh.unwrap(), "asinh(v0)");
3766
3767 let acosh = super::builtin_expr("acosh", &[0], &exprs, "f32");
3768 assert_eq!(acosh.unwrap(), "acosh(v0)");
3769
3770 let atanh = super::builtin_expr("atanh", &[0], &exprs, "f32");
3771 assert_eq!(atanh.unwrap(), "atanh(v0)");
3772
3773 let hypot = super::builtin_expr("hypot", &[0, 1], &exprs, "f32");
3774 assert_eq!(hypot.unwrap(), "hypot(v0, v1)");
3775
3776 let sign = super::builtin_expr("sign", &[0], &exprs, "f32");
3777 assert_eq!(sign.unwrap(), "sign(v0)");
3778
3779 let fix = super::builtin_expr("fix", &[0], &exprs, "f32");
3780 assert_eq!(fix.unwrap(), "trunc(v0)");
3781
3782 let modulo = super::builtin_expr("mod", &[0, 1], &exprs, "f32");
3783 let modulo = modulo.unwrap();
3784 assert!(modulo.contains("floor"));
3785 assert!(modulo.contains("isInf"));
3786
3787 let rem = super::builtin_expr("rem", &[0, 1], &exprs, "f32");
3788 let rem = rem.unwrap();
3789 assert!(rem.contains("trunc"));
3790 assert!(rem.contains("isInf"));
3791
3792 let pow2 = super::builtin_expr("pow2", &[0], &exprs, "f32");
3793 assert_eq!(pow2.unwrap(), "exp2(v0)");
3794
3795 let single = super::builtin_expr("single", &[0], &exprs, "f32");
3796 assert_eq!(single.unwrap(), "v0");
3797
3798 let double = super::builtin_expr("double", &[0], &exprs, "f64");
3799 assert_eq!(double.unwrap(), "v0");
3800 }
3801
3802 #[test]
3803 fn fanout_chain_with_casts_supported() {
3804 let values = vec![
3805 ValueInfo {
3807 id: 0,
3808 origin: ValueOrigin::Variable {
3809 kind: VarKind::Global,
3810 index: 0,
3811 },
3812 ty: Type::tensor(),
3813 shape: ShapeInfo::Tensor(vec![Some(8)]),
3814 constant: None,
3815 },
3816 ValueInfo {
3818 id: 1,
3819 origin: ValueOrigin::NodeOutput { node: 0, output: 0 },
3820 ty: Type::tensor(),
3821 shape: ShapeInfo::Tensor(vec![Some(8)]),
3822 constant: None,
3823 },
3824 ValueInfo {
3826 id: 2,
3827 origin: ValueOrigin::Constant,
3828 ty: Type::Num,
3829 shape: ShapeInfo::Scalar,
3830 constant: Some(Value::Num(0.1)),
3831 },
3832 ValueInfo {
3834 id: 3,
3835 origin: ValueOrigin::NodeOutput { node: 1, output: 0 },
3836 ty: Type::Num,
3837 shape: ShapeInfo::Scalar,
3838 constant: None,
3839 },
3840 ValueInfo {
3842 id: 4,
3843 origin: ValueOrigin::NodeOutput { node: 2, output: 0 },
3844 ty: Type::tensor(),
3845 shape: ShapeInfo::Tensor(vec![Some(8)]),
3846 constant: None,
3847 },
3848 ValueInfo {
3850 id: 5,
3851 origin: ValueOrigin::NodeOutput { node: 3, output: 0 },
3852 ty: Type::tensor(),
3853 shape: ShapeInfo::Tensor(vec![Some(8)]),
3854 constant: None,
3855 },
3856 ];
3857
3858 let tanh_node = AccelNode {
3859 id: 0,
3860 label: AccelNodeLabel::Builtin {
3861 name: "tanh".to_string(),
3862 },
3863 category: AccelOpCategory::Elementwise,
3864 inputs: vec![0],
3865 outputs: vec![1],
3866 span: InstrSpan { start: 10, end: 10 },
3867 tags: vec![AccelGraphTag::Elementwise],
3868 };
3869 let single_node = AccelNode {
3870 id: 1,
3871 label: AccelNodeLabel::Builtin {
3872 name: "single".to_string(),
3873 },
3874 category: AccelOpCategory::Elementwise,
3875 inputs: vec![2],
3876 outputs: vec![3],
3877 span: InstrSpan { start: 11, end: 11 },
3878 tags: vec![AccelGraphTag::Elementwise],
3879 };
3880 let mul_node = AccelNode {
3881 id: 2,
3882 label: AccelNodeLabel::Primitive(PrimitiveOp::ElemMul),
3883 category: AccelOpCategory::Elementwise,
3884 inputs: vec![3, 0],
3885 outputs: vec![4],
3886 span: InstrSpan { start: 12, end: 12 },
3887 tags: vec![AccelGraphTag::Elementwise],
3888 };
3889 let add_node = AccelNode {
3890 id: 3,
3891 label: AccelNodeLabel::Primitive(PrimitiveOp::Add),
3892 category: AccelOpCategory::Elementwise,
3893 inputs: vec![1, 4],
3894 outputs: vec![5],
3895 span: InstrSpan { start: 13, end: 13 },
3896 tags: vec![AccelGraphTag::Elementwise],
3897 };
3898
3899 let graph = AccelGraph {
3900 nodes: vec![tanh_node, single_node, mul_node, add_node],
3901 values,
3902 var_bindings: StdHashMap::new(),
3903 node_bindings: StdHashMap::new(),
3904 };
3905
3906 let groups = detect_fusion_groups(&graph);
3907 assert_eq!(groups.len(), 1);
3908
3909 let plan = FusionPlan::from_graph(&graph, &groups);
3910 let group_plan = &plan.groups[0];
3911 assert!(group_plan.kernel.supported);
3912 let shader = group_plan.generate_wgsl("f32");
3913 assert!(shader
3914 .as_ref()
3915 .map(|wgsl| wgsl.contains("tanh") && wgsl.contains("output.data"))
3916 .unwrap_or(false));
3917 }
3918}