use crate::accel::residency as accel_residency;
use crate::bytecode::ExecutionContext;
use crate::bytecode::Instr;
use crate::interpreter::engine as interp_engine;
use crate::interpreter::errors::mex;
use crate::runtime::workspace::refresh_workspace_state;
use runmat_accelerate::fusion::FusionStoreMaterialization;
use runmat_accelerate::fusion_exec::{
    execute_centered_gram, execute_elementwise, execute_explained_variance,
    execute_image_normalize, execute_matmul_epilogue, execute_power_step_normalize,
    execute_reduction, FusionExecutionRequest,
};
use runmat_accelerate::InstrSpan;
use runmat_accelerate::{value_is_all_keyword, FusionKind, ShapeInfo, ValueOrigin, VarKind};
use runmat_builtins::Value;
use runmat_runtime::RuntimeError;
use std::collections::HashMap;

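/// Maps a `Value` to a short static tag naming its variant; used by the fusion
/// debug logging below to describe stack contents compactly.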
#[inline]
pub fn value_kind(value: &Value) -> &'static str {
    match value {
        Value::Int(_) => "Int",
        Value::Num(_) => "Num",
        Value::Complex(_, _) => "Complex",
        Value::Bool(_) => "Bool",
        Value::LogicalArray(_) => "LogicalArray",
        Value::String(_) => "String",
        Value::StringArray(_) => "StringArray",
        Value::CharArray(_) => "CharArray",
        Value::Tensor(_) => "Tensor",
        Value::ComplexTensor(_) => "ComplexTensor",
        Value::Cell(_) => "Cell",
        Value::Struct(_) => "Struct",
        Value::GpuTensor(_) => "GpuTensor",
        Value::Object(_) => "Object",
        Value::HandleObject(_) => "HandleObject",
        Value::Listener(_) => "Listener",
        Value::FunctionHandle(_) => "FunctionHandle",
        Value::Closure(_) => "Closure",
        Value::ClassRef(_) => "ClassRef",
        Value::MException(_) => "MException",
        Value::OutputList(_) => "OutputList",
    }
}

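/// Formats a one-line summary of fusion input `i` for debug logs: tensors show
/// their shape, scalars show their value, everything else falls back to
/// `value_kind`.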
#[inline]
pub fn summarize_value(i: usize, v: &Value) -> String {
    match v {
        Value::GpuTensor(h) => format!("in#{i}:GpuTensor shape={:?}", h.shape),
        Value::Tensor(t) => format!("in#{i}:Tensor shape={:?}", t.shape),
        Value::Num(n) => format!("in#{i}:Num({n:.6})"),
        Value::Int(n) => format!("in#{i}:Int({})", n.to_i64()),
        Value::Bool(b) => format!("in#{i}:Bool({})", if *b { 1 } else { 0 }),
        Value::String(s) => format!("in#{i}:String({})", s),
        _ => format!("in#{i}:{}", value_kind(v)),
    }
}

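/// Simulates the stack effect of every instruction in `span` and returns how
/// many values the span leaves live on the operand stack. Returns `None` if the
/// span is out of range or contains an instruction without a statically known
/// stack effect. Operands consumed from below the span's starting depth are
/// modeled by raising the simulated depth before applying pops and pushes.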
pub fn fusion_span_live_result_count(instructions: &[Instr], span: &InstrSpan) -> Option<usize> {
    if span.start > span.end || span.end >= instructions.len() {
        return None;
    }
    let mut current_depth = 0usize;
    for instr in &instructions[span.start..=span.end] {
        let effect = instr.stack_effect()?;
        if current_depth < effect.pops {
            current_depth = effect.pops;
        }
        current_depth = current_depth - effect.pops + effect.pushes;
    }
    Some(current_depth)
}

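/// Returns `true` when `span` cannot be replaced by a single fused result:
/// either it contains an indexed/member store (which mutates state the fusion
/// cannot see) or it does not leave exactly one live value on the stack.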
pub fn fusion_span_has_vm_barrier(instructions: &[Instr], span: &InstrSpan) -> bool {
    if span.start > span.end || span.end >= instructions.len() {
        return true;
    }
    for instr in &instructions[span.start..=span.end] {
        if matches!(
            instr,
            Instr::StoreIndex(_)
                | Instr::StoreSlice(_, _, _, _)
                | Instr::StoreSliceExpr { .. }
                | Instr::StoreIndexCell(_)
                | Instr::StoreMember(_)
                | Instr::StoreMemberOrInit(_)
                | Instr::StoreMemberDynamic
                | Instr::StoreMemberDynamicOrInit
        ) {
            return true;
        }
    }
    fusion_span_live_result_count(instructions, span) != Some(1)
}

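/// RAII guard over the top operands of the VM stack. `new` splits the operands
/// off the stack; if the fusion fails and the guard is dropped uncommitted, the
/// operands are pushed back so the interpreter can fall back to the unfused
/// path. `commit` forgets the slice once the fused kernel has produced its
/// result, leaving the operands consumed.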
pub struct StackSliceGuard<'a> {
    stack: *mut Vec<Value>,
    slice: Option<Vec<Value>>,
    _marker: std::marker::PhantomData<&'a mut Vec<Value>>,
}

impl<'a> StackSliceGuard<'a> {
    pub fn new(stack: &'a mut Vec<Value>, slice_start: usize) -> Self {
        let slice = stack.split_off(slice_start);
        Self {
            stack,
            slice: Some(slice),
            _marker: std::marker::PhantomData,
        }
    }

    pub fn slice(&self) -> &[Value] {
        self.slice.as_ref().expect("stack slice missing").as_slice()
    }

    pub fn commit(mut self) {
        self.slice = None;
    }
}

impl Drop for StackSliceGuard<'_> {
    fn drop(&mut self) {
        if let Some(slice) = self.slice.take() {
            unsafe { (&mut *self.stack).extend(slice) }
        }
    }
}

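/// Resolves every input of a fusion plan, in priority order: compile-time
/// constants, variable bindings (globals or frame locals), operands captured
/// from the VM stack via the plan's stack layout (or legacy stack pattern),
/// and finally graph-level bindings and `const_values`. Returns the stack
/// guard protecting the consumed operands, the ready-to-run
/// `FusionExecutionRequest`, and the raw stack operands (used later for
/// reduction geometry inference).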
pub fn gather_fusion_inputs<'a>(
    plan: &'a runmat_accelerate::FusionGroupPlan,
    graph: &runmat_accelerate::AccelGraph,
    stack: &'a mut Vec<Value>,
    vars: &mut [Value],
    context: &mut ExecutionContext,
) -> Result<
    (
        StackSliceGuard<'a>,
        FusionExecutionRequest<'a>,
        Vec<Option<Value>>,
    ),
    RuntimeError,
> {
    if plan.group.stack_layout.is_none() && !plan.stack_pattern.is_empty() {
        return Err(mex(
            "FusionMissingStackLayout",
            "fusion: missing compile-time stack layout metadata",
        ));
    }
    let required_stack_operands = plan
        .group
        .stack_layout
        .as_ref()
        .map(|layout| layout.required_stack_operands)
        .unwrap_or_else(|| plan.stack_pattern.len());
    let mut inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];

    for (idx, value) in &plan.constants {
        if let Some(slot) = inputs.get_mut(*idx) {
            if slot.is_none() {
                *slot = Some(value.clone());
            }
        }
    }

    for (idx, value_id) in plan.inputs.iter().enumerate() {
        let info = graph
            .value(*value_id)
            .ok_or_else(|| format!("fusion: missing value metadata for id {value_id}"))?;
        match &info.origin {
            ValueOrigin::Variable { kind, index } => {
                let value =
                    match kind {
                        VarKind::Global => vars
                            .get(*index)
                            .cloned()
                            .ok_or_else(|| format!("fusion: global var {index} out of range"))?,
                        VarKind::Local => {
                            if let Some(frame) = context.call_stack.last() {
                                let absolute = frame.locals_start + index;
                                context.locals.get(absolute).cloned().ok_or_else(|| {
                                    format!("fusion: local var {index} unavailable")
                                })?
                            } else {
                                vars.get(*index).cloned().ok_or_else(|| {
                                    format!("fusion: local var {index} unavailable")
                                })?
                            }
                        }
                    };
                debug_assert!(
                    inputs[idx].is_none(),
                    "fusion: duplicate input slot {} for plan {}",
                    idx,
                    plan.index
                );
                inputs[idx] = Some(value);
            }
            ValueOrigin::Constant | ValueOrigin::NodeOutput { .. } | ValueOrigin::Unknown => {}
        }
    }

    if log::log_enabled!(log::Level::Debug) && interp_engine::fusion_debug_enabled() {
        let stack_needed_preview = required_stack_operands;
        let stack_snapshot: Vec<&Value> = stack.iter().rev().take(stack_needed_preview).collect();
        let stack_kinds: Vec<&'static str> =
            stack_snapshot.iter().rev().map(|v| value_kind(v)).collect();
        let input_meta: Vec<String> = plan
            .inputs
            .iter()
            .enumerate()
            .map(|(i, value_id)| {
                if let Some(info) = graph.value(*value_id) {
                    format!("#{i}:id={} origin={:?}", value_id, info.origin)
                } else {
                    format!("#{i}:id={} origin=<missing>", value_id)
                }
            })
            .collect();
        log::debug!(
            "fusion group {} gather: stack_depth={} stack_needed={} stack_kinds={:?} pattern={:?} inputs={:?}",
            plan.index, stack.len(), stack_needed_preview, stack_kinds, &plan.stack_pattern, input_meta
        );
    }

    if stack.len() < required_stack_operands {
        if interp_engine::fusion_debug_enabled() {
            log::debug!(
                "fusion stack underflow: plan={} needed={} available={} pattern={:?}",
                plan.index,
                required_stack_operands,
                stack.len(),
                plan.stack_pattern
            );
        }
        return Err(mex(
            "FusionStackUnderflow",
            "fusion: stack underflow gathering inputs",
        ));
    }
    let available = required_stack_operands;
    let slice_start = stack.len() - available;
    let stack_guard = StackSliceGuard::new(stack, slice_start);
    let slice = stack_guard.slice().to_vec();
    let mut consumed_inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
    let input_positions: HashMap<runmat_accelerate::graph::ValueId, usize> = plan
        .inputs
        .iter()
        .enumerate()
        .map(|(idx, value_id)| (*value_id, idx))
        .collect();

    let allow_stack_value = |val: &Value| {
        if plan.group.kind.is_reduction() {
            matches!(val, Value::GpuTensor(_) | Value::Tensor(_))
        } else {
            true
        }
    };

    if let Some(layout) = plan.group.stack_layout.as_ref() {
        for binding in &layout.bindings {
            let Some(input_idx) = input_positions.get(&binding.value_id).copied() else {
                continue;
            };
            let Some(val) = slice.get(binding.stack_offset).cloned() else {
                continue;
            };
            consumed_inputs[input_idx] = Some(val.clone());
            if inputs[input_idx].is_none() && allow_stack_value(&val) {
                inputs[input_idx] = Some(val);
            }
        }
    } else {
        for (offset, input_idx) in plan.stack_pattern.iter().enumerate() {
            let Some(val) = slice.get(offset).cloned() else {
                continue;
            };
            consumed_inputs[*input_idx] = Some(val.clone());
            if inputs[*input_idx].is_none() && allow_stack_value(&val) {
                inputs[*input_idx] = Some(val);
            }
        }
    }

    for (idx, slot) in inputs.iter_mut().enumerate() {
        if slot.is_some() {
            continue;
        }
        let vid = plan.inputs[idx];
        let info = graph.value(vid);
        if let Some(info) = info {
            match &info.origin {
                ValueOrigin::Variable { kind, index } => {
                    let value_opt = match kind {
                        VarKind::Global => vars.get(*index).cloned(),
                        VarKind::Local => {
                            if let Some(frame) = context.call_stack.last() {
                                let absolute = frame.locals_start + index;
                                context.locals.get(absolute).cloned()
                            } else {
                                vars.get(*index).cloned()
                            }
                        }
                    };
                    if let Some(value) = value_opt {
                        *slot = Some(value);
                        continue;
                    }
                }
                ValueOrigin::Constant => {
                    if let Some(value) = plan.const_values.get(&vid) {
                        *slot = Some(value.clone());
                        continue;
                    }
                }
                _ => {}
            }
        }
        if slot.is_none() {
            if let Some(binding) = graph.var_binding(vid) {
                let value_opt = match binding.kind {
                    VarKind::Global => vars.get(binding.index).cloned(),
                    VarKind::Local => {
                        if let Some(frame) = context.call_stack.last() {
                            let absolute = frame.locals_start + binding.index;
                            context.locals.get(absolute).cloned()
                        } else {
                            vars.get(binding.index).cloned()
                        }
                    }
                };
                if let Some(value) = value_opt {
                    *slot = Some(value);
                    continue;
                }
            }
        }
        if slot.is_none() {
            if let Some(info) = info {
                if let ValueOrigin::NodeOutput { node, .. } = info.origin {
                    if let Some(binding) = graph.node_binding(node) {
                        let value_opt = match binding.kind {
                            VarKind::Global => vars.get(binding.index).cloned(),
                            VarKind::Local => {
                                if let Some(frame) = context.call_stack.last() {
                                    let absolute = frame.locals_start + binding.index;
                                    context.locals.get(absolute).cloned()
                                } else {
                                    vars.get(binding.index).cloned()
                                }
                            }
                        };
                        if let Some(value) = value_opt {
                            *slot = Some(value);
                            continue;
                        }
                    }
                }
            }
        }
        if slot.is_none() {
            if let Some(value) = plan.const_values.get(&vid) {
                *slot = Some(value.clone());
            }
        }
    }

    let inputs: Vec<Value> = inputs
        .into_iter()
        .map(|opt| opt.ok_or_else(|| mex("FusionMissingInput", "fusion: missing input value")))
        .collect::<Result<_, _>>()?;

    if log::log_enabled!(log::Level::Debug) {
        let summaries: Vec<String> = inputs
            .iter()
            .enumerate()
            .map(|(i, v)| summarize_value(i, v))
            .collect();
        log::debug!("fusion inputs runtime: [{}]", summaries.join(", "));
    }

    Ok((
        stack_guard,
        FusionExecutionRequest { plan, inputs },
        consumed_inputs,
    ))
}

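/// Writes values materialized by a fused elementwise kernel back into their
/// variable bindings, releasing any GPU residency held by the previous value
/// (unless it is the same GPU handle) and growing the variable or local
/// storage on demand.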
pub fn write_elementwise_materialized_stores(
    materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
    vars: &mut Vec<Value>,
    context: &mut ExecutionContext,
) {
    for (store, value) in materialized_stores {
        match store.binding.kind {
            VarKind::Global => {
                let i = store.binding.index;
                if i < vars.len() && !accel_residency::same_gpu_handle(&vars[i], &value) {
                    accel_residency::clear_value(&vars[i]);
                }
                if i >= vars.len() {
                    vars.resize(i + 1, Value::Num(0.0));
                    refresh_workspace_state(vars);
                }
                vars[i] = value;
            }
            VarKind::Local => {
                if let Some(frame) = context.call_stack.last() {
                    let absolute = frame.locals_start + store.binding.index;
                    while context.locals.len() <= absolute {
                        context.locals.push(Value::Num(0.0));
                    }
                    if !accel_residency::same_gpu_handle(&context.locals[absolute], &value) {
                        accel_residency::clear_value(&context.locals[absolute]);
                    }
                    context.locals[absolute] = value;
                } else {
                    let i = store.binding.index;
                    if i < vars.len() && !accel_residency::same_gpu_handle(&vars[i], &value) {
                        accel_residency::clear_value(&vars[i]);
                    }
                    if i >= vars.len() {
                        vars.resize(i + 1, Value::Num(0.0));
                        refresh_workspace_state(vars);
                    }
                    vars[i] = value;
                }
            }
        }
    }
}

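/// Runs a fused elementwise kernel, persists its materialized stores, and
/// commits the stack guard on success so the consumed operands are not
/// restored.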
pub fn execute_fusion_elementwise(
    request: FusionExecutionRequest<'_>,
    stack_guard: StackSliceGuard<'_>,
    vars: &mut Vec<Value>,
    context: &mut ExecutionContext,
) -> Result<Value, RuntimeError> {
    match execute_elementwise(request) {
        Ok(result) => {
            write_elementwise_materialized_stores(result.materialized_stores, vars, context);
            stack_guard.commit();
            Ok(result.final_value)
        }
        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
    }
}

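/// Dispatches the non-elementwise, non-reduction fusion kinds (centered Gram,
/// power-step normalize, explained variance, matmul epilogue, image normalize)
/// to their dedicated executors, committing the stack guard only on success.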
pub async fn execute_fusion_special_kind(
    kind: FusionKind,
    plan_inputs: &[runmat_accelerate::graph::ValueId],
    request: FusionExecutionRequest<'_>,
    stack_guard: StackSliceGuard<'_>,
) -> Result<Value, RuntimeError> {
    match kind {
        FusionKind::CenteredGram => match execute_centered_gram(request).await {
            Ok(result) => {
                stack_guard.commit();
                Ok(result)
            }
            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
        },
        FusionKind::PowerStepNormalize => match execute_power_step_normalize(request).await {
            Ok(result) => {
                stack_guard.commit();
                Ok(result)
            }
            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
        },
        FusionKind::ExplainedVariance => {
            log::debug!("explained variance plan inputs {:?}", plan_inputs);
            match execute_explained_variance(request).await {
                Ok(result) => {
                    stack_guard.commit();
                    Ok(result)
                }
                Err(err) => {
                    log::debug!("explained variance fusion fallback: {}", err);
                    Err(mex("FusionExecutionFailed", &err.to_string()))
                }
            }
        }
        FusionKind::MatmulEpilogue => match execute_matmul_epilogue(request).await {
            Ok(result) => {
                stack_guard.commit();
                Ok(result)
            }
            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
        },
        FusionKind::ImageNormalize => match execute_image_normalize(request).await {
            Ok(result) => {
                stack_guard.commit();
                Ok(result)
            }
            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
        },
        _ => Err(mex(
            "FusionUnsupportedKind",
            "fusion: unsupported fusion kind",
        )),
    }
}

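/// Geometry of a fused reduction: the axis being reduced, the length of each
/// reduced slice, and the number of independent slices.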
pub struct ReductionGeometry {
    pub axis: usize,
    pub reduce_len: usize,
    pub num_slices: usize,
}

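/// Derives the reduction geometry for a fusion plan. Detects reduce-all (via
/// `ReductionAxes::All` or an `'all'`-style keyword constant recognized by
/// `value_is_all_keyword`), resolves an explicit reduction dimension from plan
/// metadata or constants, and infers rows/columns from static shapes or runtime
/// tensor operands. Errors if the total extent cannot be determined, if the
/// geometry looks unresolved (1x1 with multi-element inputs), or if
/// `RUNMAT_DISABLE_FUSED_REDUCTION=1` is set.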
pub fn resolve_reduction_geometry(
    plan: &runmat_accelerate::FusionGroupPlan,
    graph: &runmat_accelerate::AccelGraph,
    request: &FusionExecutionRequest<'_>,
    consumed_inputs: &[Option<Value>],
    vars: &[Value],
    context: &ExecutionContext,
) -> Result<ReductionGeometry, RuntimeError> {
    fn detect_reduce_all(
        plan: &runmat_accelerate::FusionGroupPlan,
        graph: &runmat_accelerate::AccelGraph,
    ) -> bool {
        let mut reduce_all = matches!(
            plan.reduction_axes,
            Some(runmat_accelerate::ReductionAxes::All)
        );
        let has_all = reduce_all
            || plan.constants.values().any(value_is_all_keyword)
            || plan.const_values.values().any(value_is_all_keyword);
        if has_all {
            return true;
        }
        for node_id in &plan.group.nodes {
            if let Some(node) = graph.node(*node_id) {
                if let runmat_accelerate::graph::AccelNodeLabel::Builtin { name } = &node.label {
                    if name.eq_ignore_ascii_case("mean") {
                        for input_vid in &node.inputs {
                            if let Some(info) = graph.value(*input_vid) {
                                if let Some(constant) = &info.constant {
                                    if value_is_all_keyword(constant) {
                                        reduce_all = true;
                                        break;
                                    }
                                }
                            }
                        }
                    }
                }
            }
            if reduce_all {
                break;
            }
        }
        reduce_all
    }

    fn resolve_reduction_axis(plan: &runmat_accelerate::FusionGroupPlan) -> (usize, bool) {
        let mut axis = 0usize;
        let mut axis_explicit = false;
        if let Some(runmat_accelerate::ReductionAxes::Explicit(dims)) = &plan.reduction_axes {
            if let Some(first) = dims.first().copied() {
                axis = first.saturating_sub(1);
                axis_explicit = true;
            }
        }
        if let Some(dim_vid) = plan.reduction_dim {
            if let Some(cv) = plan.const_values.get(&dim_vid) {
                axis = match cv {
                    Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
                    Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
                    _ => axis,
                };
                axis_explicit = true;
            } else if let Some(input_idx) = plan.inputs.iter().position(|v| *v == dim_vid) {
                if let Some(cv) = plan.constants.get(&input_idx) {
                    axis = match cv {
                        Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
                        Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
                        _ => axis,
                    };
                    axis_explicit = true;
                }
            }
        } else if let Some(dim_const) = plan.constants.get(&1) {
            axis = match dim_const {
                Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
                Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
                _ => axis,
            };
            axis_explicit = true;
        }
        (axis, axis_explicit)
    }

    fn derive_rows_cols(
        plan: &runmat_accelerate::FusionGroupPlan,
        graph: &runmat_accelerate::AccelGraph,
        request: &FusionExecutionRequest<'_>,
        consumed_inputs: &[Option<Value>],
        vars: &[Value],
        context: &ExecutionContext,
    ) -> Option<(usize, usize)> {
        let shape_of = |value: &Value| -> Option<(usize, usize)> {
            match value {
                Value::GpuTensor(h) => Some((
                    h.shape.first().copied().unwrap_or(1).max(1),
                    h.shape.get(1).copied().unwrap_or(1).max(1),
                )),
                Value::Tensor(t) => Some((
                    t.shape.first().copied().unwrap_or(1).max(1),
                    t.shape.get(1).copied().unwrap_or(1).max(1),
                )),
                _ => None,
            }
        };

        if let Some(shape) = plan.reduction_data_shape(graph) {
            if shape.len() >= 2 {
                return Some((shape[0].max(1), shape[1].max(1)));
            }
            if shape.len() == 1 {
                return Some((shape[0].max(1), 1));
            }
        }

        for &vid in &plan.inputs {
            if let Some(binding) = graph.var_binding(vid) {
                let value_opt = match binding.kind {
                    VarKind::Global => vars.get(binding.index).cloned(),
                    VarKind::Local => {
                        if let Some(frame) = context.call_stack.last() {
                            let absolute = frame.locals_start + binding.index;
                            context.locals.get(absolute).cloned()
                        } else {
                            vars.get(binding.index).cloned()
                        }
                    }
                };
                if let Some(value) = value_opt {
                    if let Some(shape) = shape_of(&value) {
                        return Some(shape);
                    }
                }
            }
        }

        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
            if let Some(shape) = shape_of(v) {
                return Some(shape);
            }
        }

        if let Some(data_id) = plan.reduction_data {
            if let Some(input_index) = plan.inputs.iter().position(|vid| *vid == data_id) {
                if let Some(val) = consumed_inputs.get(input_index).and_then(|v| v.as_ref()) {
                    if let Some(shape) = shape_of(val) {
                        return Some(shape);
                    }
                }
                if let Some(val) = request.inputs.get(input_index) {
                    if let Some(shape) = shape_of(val) {
                        return Some(shape);
                    }
                }
            }
            if let Some(info) = graph.value(data_id) {
                if let ValueOrigin::Variable { kind, index } = &info.origin {
                    let val = match kind {
                        VarKind::Global => vars.get(*index).cloned(),
                        VarKind::Local => {
                            if let Some(frame) = context.call_stack.last() {
                                let absolute = frame.locals_start + index;
                                context.locals.get(absolute).cloned()
                            } else {
                                vars.get(*index).cloned()
                            }
                        }
                    };
                    if let Some(v) = val {
                        if let Some(shape) = shape_of(&v) {
                            return Some(shape);
                        }
                    }
                }
                if let ShapeInfo::Tensor(dims) = &info.shape {
                    if !dims.is_empty() {
                        let r = dims.first().and_then(|d| *d).unwrap_or(1);
                        let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
                        return Some((r.max(1), c.max(1)));
                    }
                }
            }
        }

        for v in &request.inputs {
            if let Some(shape) = shape_of(v) {
                return Some(shape);
            }
        }

        if let ShapeInfo::Tensor(dims) = &plan.group.shape {
            if !dims.is_empty() {
                let r = dims.first().and_then(|d| *d).unwrap_or(1);
                let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
                return Some((r.max(1), c.max(1)));
            }
        }
        None
    }

    if log::log_enabled!(log::Level::Debug) {
        let meta: Vec<String> = plan
            .inputs
            .iter()
            .map(|vid| {
                if let Some(info) = graph.value(*vid) {
                    format!(
                        "vid={} origin={:?} shape={:?}",
                        vid, info.origin, info.shape
                    )
                } else {
                    format!("vid={} origin=<missing>", vid)
                }
            })
            .collect();
        log::debug!("reduction gather meta: [{}]", meta.join(", "));
    }

    let reduce_all = detect_reduce_all(plan, graph);
    let (mut axis, axis_explicit) = if reduce_all {
        (0usize, false)
    } else {
        resolve_reduction_axis(plan)
    };
    if reduce_all && interp_engine::fusion_debug_enabled() {
        log::debug!(
            "fusion reduction (all) meta: data_vid={:?} inputs={:?} stack_pattern={:?}",
            plan.reduction_data,
            plan.inputs,
            plan.stack_pattern
        );
    }

    let (r, c) =
        derive_rows_cols(plan, graph, request, consumed_inputs, vars, context).unwrap_or((1, 1));
    let (reduce_len, num_slices) = if reduce_all {
        let total_from_runtime = consumed_inputs
            .iter()
            .filter_map(|v| v.as_ref())
            .chain(request.inputs.iter())
            .find_map(|value| match value {
                Value::GpuTensor(handle) => Some(if handle.shape.is_empty() {
                    1
                } else {
                    handle
                        .shape
                        .iter()
                        .copied()
                        .map(|d| d.max(1))
                        .product::<usize>()
                }),
                Value::Tensor(tensor) => Some(if tensor.shape.is_empty() {
                    1
                } else {
                    tensor
                        .shape
                        .iter()
                        .copied()
                        .map(|d| d.max(1))
                        .product::<usize>()
                }),
                _ => None,
            });
        let total = plan
            .reduction_data_shape(graph)
            .map(|shape| shape.into_iter().map(|d| d.max(1)).product::<usize>())
            .or(total_from_runtime)
            .or_else(|| plan.element_count())
            .filter(|v| *v > 0)
            .ok_or_else(|| {
                mex(
                    "FusionReductionExtentUnknown",
                    "fusion: reduction all extent unknown",
                )
            })?;
        if interp_engine::fusion_debug_enabled() {
            log::debug!(
                "fusion reduction (all): total_elems={} fallback_rows={} fallback_cols={}",
                total,
                r,
                c
            );
        }
        (total, 1usize)
    } else {
        if !axis_explicit {
            axis = if r == 1 && c > 1 {
                1
            } else if r > 1 {
                0
            } else {
                axis
            };
        }
        if interp_engine::fusion_debug_enabled() {
            if r == 1 && c == 1 {
                log::debug!(
                    "fusion reduction: unresolved shape (defaulted to 1x1); axis={}, constants={:?}",
                    axis,
                    plan.constants
                );
            } else {
                log::debug!(
                    "fusion reduction: resolved shape rows={} cols={} axis={} constants={:?}",
                    r,
                    c,
                    axis,
                    plan.constants
                );
            }
        }
        if axis == 0 {
            (r, c)
        } else {
            (c, r)
        }
    };

    if interp_engine::fusion_debug_enabled() {
        log::debug!(
            "fusion reduction: axis={} reduce_len={} num_slices={} constants={:?}",
            axis,
            reduce_len,
            num_slices,
            plan.constants
        );
    }

    let looks_wrong = reduce_len == 1 && num_slices == 1 && {
        let mut big = false;
        let mut check_val = |v: &Value| match v {
            Value::GpuTensor(h) => {
                let prod = h.shape.iter().copied().product::<usize>();
                if prod > 1 {
                    big = true;
                }
            }
            Value::Tensor(t) => {
                let prod = t.shape.iter().copied().product::<usize>();
                if prod > 1 {
                    big = true;
                }
            }
            _ => {}
        };
        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
            check_val(v);
        }
        for v in &request.inputs {
            check_val(v);
        }
        big
    };
    if looks_wrong {
        log::debug!("fusion reduction: skipping fusion due to unresolved shape; falling back to provider path");
        return Err(mex(
            "FusionReductionShapeUnresolved",
            "fusion: reduction shape unresolved",
        ));
    }
    if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION")
        .ok()
        .as_deref()
        == Some("1")
    {
        return Err(mex(
            "FusionReductionDisabled",
            "fusion: fused reductions disabled",
        ));
    }

    Ok(ReductionGeometry {
        axis,
        reduce_len,
        num_slices,
    })
}

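/// Executes a fused reduction using the resolved geometry and a fixed
/// workgroup size of 256, committing the stack guard on success.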
pub fn execute_fusion_reduction(
    plan: &runmat_accelerate::FusionGroupPlan,
    graph: &runmat_accelerate::AccelGraph,
    request: FusionExecutionRequest<'_>,
    consumed_inputs: &[Option<Value>],
    stack_guard: StackSliceGuard<'_>,
    vars: &[Value],
    context: &ExecutionContext,
) -> Result<Value, RuntimeError> {
    let geom = resolve_reduction_geometry(plan, graph, &request, consumed_inputs, vars, context)?;
    match execute_reduction(request, geom.reduce_len, geom.num_slices, 256u32) {
        Ok(result) => {
            stack_guard.commit();
            Ok(result)
        }
        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
    }
}

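/// Entry point for executing a fusion group: gathers inputs from the stack and
/// variables, then dispatches to the elementwise, reduction, or special-kind
/// path based on the group's fusion kind.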
pub async fn try_execute_fusion_group(
    plan: &runmat_accelerate::FusionGroupPlan,
    graph: &runmat_accelerate::AccelGraph,
    stack: &mut Vec<Value>,
    vars: &mut Vec<Value>,
    context: &mut ExecutionContext,
) -> Result<Value, RuntimeError> {
    let (stack_guard, request, consumed_inputs) =
        gather_fusion_inputs(plan, graph, stack, vars, context)?;
    log::debug!(
        "dispatch fusion kind {:?}, supported {}",
        plan.group.kind,
        plan.kernel.supported
    );
    if plan.group.kind.is_elementwise() {
        execute_fusion_elementwise(request, stack_guard, vars, context)
    } else if plan.group.kind.is_reduction() {
        execute_fusion_reduction(
            plan,
            graph,
            request,
            &consumed_inputs,
            stack_guard,
            vars,
            context,
        )
    } else {
        execute_fusion_special_kind(plan.group.kind.clone(), &plan.inputs, request, stack_guard)
            .await
    }
}