1use crate::accel::residency as accel_residency;
2use crate::bytecode::program::ExecutionContext;
3use crate::bytecode::Instr;
4use crate::interpreter::engine as interp_engine;
5use crate::interpreter::errors::mex;
6use crate::runtime::workspace::refresh_workspace_state;
7use runmat_accelerate::fusion::FusionStoreMaterialization;
8use runmat_accelerate::fusion_exec::{
9 execute_centered_gram, execute_elementwise, execute_explained_variance,
10 execute_image_normalize, execute_matmul_epilogue, execute_power_step_normalize,
11 execute_reduction, FusionExecutionRequest,
12};
13use runmat_accelerate::InstrSpan;
14use runmat_accelerate::{value_is_all_keyword, FusionKind, ShapeInfo, ValueOrigin, VarKind};
15use runmat_builtins::Value;
16use runmat_runtime::builtins::common::shape::is_scalar_shape;
17use runmat_runtime::RuntimeError;
18use std::collections::HashMap;
19
20#[inline]
21pub fn value_kind(value: &Value) -> &'static str {
22 match value {
23 Value::Int(_) => "Int",
24 Value::Num(_) => "Num",
25 Value::Complex(_, _) => "Complex",
26 Value::Bool(_) => "Bool",
27 Value::LogicalArray(_) => "LogicalArray",
28 Value::String(_) => "String",
29 Value::StringArray(_) => "StringArray",
30 Value::CharArray(_) => "CharArray",
31 Value::Tensor(_) => "Tensor",
32 Value::SparseTensor(_) => "SparseTensor",
33 Value::ComplexTensor(_) => "ComplexTensor",
34 Value::Cell(_) => "Cell",
35 Value::Struct(_) => "Struct",
36 Value::GpuTensor(_) => "GpuTensor",
37 Value::Object(_) => "Object",
38 Value::HandleObject(_) => "HandleObject",
39 Value::Listener(_) => "Listener",
40 Value::FunctionHandle(_)
41 | Value::ExternalFunctionHandle(_)
42 | Value::MethodFunctionHandle(_) => "FunctionHandle",
43 Value::BoundFunctionHandle { .. } => "FunctionHandle",
44 Value::Closure(_) => "Closure",
45 Value::ClassRef(_) => "ClassRef",
46 Value::MException(_) => "MException",
47 Value::OutputList(_) => "OutputList",
48 }
49}
50
51#[inline]
52pub fn summarize_value(i: usize, v: &Value) -> String {
53 match v {
54 Value::GpuTensor(h) => format!("in#{i}:GpuTensor shape={:?}", h.shape),
55 Value::Tensor(t) => format!("in#{i}:Tensor shape={:?}", t.shape),
56 Value::Num(n) => format!("in#{i}:Num({n:.6})"),
57 Value::Int(n) => format!("in#{i}:Int({})", n.to_i64()),
58 Value::Bool(b) => format!("in#{i}:Bool({})", if *b { 1 } else { 0 }),
59 Value::String(s) => format!("in#{i}:String({})", s),
60 _ => format!("in#{i}:{}", value_kind(v)),
61 }
62}
63
64#[inline]
65fn is_scalarish_runtime_value(value: &Value) -> bool {
66 match value {
67 Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => true,
68 Value::Tensor(tensor) => is_scalar_shape(&tensor.shape),
69 Value::ComplexTensor(tensor) => is_scalar_shape(&tensor.shape),
70 Value::LogicalArray(array) => is_scalar_shape(&array.shape),
71 Value::GpuTensor(handle) => is_scalar_shape(&handle.shape),
72 Value::CharArray(array) => array.rows * array.cols == 1,
73 _ => false,
74 }
75}
76
77pub fn fusion_span_live_result_count(instructions: &[Instr], span: &InstrSpan) -> Option<usize> {
78 if span.start > span.end || span.end >= instructions.len() {
79 return None;
80 }
81 let mut current_depth = 0usize;
82 for instr in &instructions[span.start..=span.end] {
83 let effect = instr.stack_effect()?;
84 if current_depth < effect.pops {
85 current_depth = effect.pops;
86 }
87 current_depth = current_depth - effect.pops + effect.pushes;
88 }
89 Some(current_depth)
90}
91
92pub fn fusion_span_has_vm_barrier(instructions: &[Instr], span: &InstrSpan) -> bool {
93 if span.start > span.end || span.end >= instructions.len() {
94 return true;
95 }
96 for instr in &instructions[span.start..=span.end] {
97 if matches!(
98 instr,
99 Instr::StoreIndex(_)
100 | Instr::StoreIndexDelete(_)
101 | Instr::StoreSlice(_, _, _, _)
102 | Instr::StoreSliceDelete(_, _, _, _)
103 | Instr::StoreSliceExpr { .. }
104 | Instr::StoreSliceExprDelete { .. }
105 | Instr::StoreIndexCell { .. }
106 | Instr::StoreIndexCellDelete { .. }
107 | Instr::StoreMember(_)
108 | Instr::StoreMemberOrInit(_)
109 | Instr::StoreMemberDynamic
110 | Instr::StoreMemberDynamicOrInit
111 ) {
112 return true;
113 }
114 }
115 fusion_span_live_result_count(instructions, span) != Some(1)
116}
117
118pub struct StackSliceGuard<'a> {
119 stack: *mut Vec<Value>,
120 slice: Option<Vec<Value>>,
121 _marker: std::marker::PhantomData<&'a mut Vec<Value>>,
122}
123
124impl<'a> StackSliceGuard<'a> {
125 pub fn new(stack: &'a mut Vec<Value>, slice_start: usize) -> Self {
126 let slice = stack.split_off(slice_start);
127 Self {
128 stack,
129 slice: Some(slice),
130 _marker: std::marker::PhantomData,
131 }
132 }
133
134 pub fn slice(&self) -> &[Value] {
135 self.slice.as_ref().expect("stack slice missing").as_slice()
136 }
137
138 pub fn commit(mut self) {
139 self.slice = None;
140 }
141}
142
143impl Drop for StackSliceGuard<'_> {
144 fn drop(&mut self) {
145 if let Some(slice) = self.slice.take() {
146 unsafe { (&mut *self.stack).extend(slice) }
147 }
148 }
149}
150
151pub fn gather_fusion_inputs<'a>(
152 plan: &'a runmat_accelerate::FusionGroupPlan,
153 graph: &runmat_accelerate::AccelGraph,
154 stack: &'a mut Vec<Value>,
155 vars: &mut [Value],
156 context: &mut ExecutionContext,
157) -> Result<
158 (
159 StackSliceGuard<'a>,
160 FusionExecutionRequest<'a>,
161 Vec<Option<Value>>,
162 ),
163 RuntimeError,
164> {
165 if plan.group.stack_layout.is_none() && !plan.stack_pattern.is_empty() {
166 return Err(mex(
167 "FusionMissingStackLayout",
168 "fusion: missing compile-time stack layout metadata",
169 ));
170 }
171 let required_stack_operands = plan
172 .group
173 .stack_layout
174 .as_ref()
175 .map(|layout| layout.required_stack_operands)
176 .unwrap_or_else(|| plan.stack_pattern.len());
177 let mut inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
178
179 for (idx, value) in &plan.constants {
180 if let Some(slot) = inputs.get_mut(*idx) {
181 if slot.is_none() {
182 *slot = Some(value.clone());
183 }
184 }
185 }
186
187 for (idx, value_id) in plan.inputs.iter().enumerate() {
188 let info = graph
189 .value(*value_id)
190 .ok_or_else(|| format!("fusion: missing value metadata for id {value_id}"))?;
191 match &info.origin {
192 ValueOrigin::Variable { kind, index } => {
193 let value =
194 match kind {
195 VarKind::Global => vars
196 .get(*index)
197 .cloned()
198 .ok_or_else(|| format!("fusion: global var {index} out of range"))?,
199 VarKind::Local => {
200 if let Some(frame) = context.call_stack.last() {
201 let absolute = frame.locals_start + index;
202 context.locals.get(absolute).cloned().ok_or_else(|| {
203 format!("fusion: local var {index} unavailable")
204 })?
205 } else {
206 vars.get(*index).cloned().ok_or_else(|| {
207 format!("fusion: local var {index} unavailable")
208 })?
209 }
210 }
211 };
212 debug_assert!(
213 inputs[idx].is_none(),
214 "fusion: duplicate input slot {} for plan {}",
215 idx,
216 plan.index
217 );
218 inputs[idx] = Some(value);
219 }
220 ValueOrigin::Constant | ValueOrigin::NodeOutput { .. } | ValueOrigin::Unknown => {}
221 }
222 }
223
224 if log::log_enabled!(log::Level::Debug) && interp_engine::fusion_debug_enabled() {
225 let stack_needed_preview = required_stack_operands;
226 let stack_snapshot: Vec<&Value> = stack.iter().rev().take(stack_needed_preview).collect();
227 let stack_kinds: Vec<&'static str> =
228 stack_snapshot.iter().rev().map(|v| value_kind(v)).collect();
229 let input_meta: Vec<String> = plan
230 .inputs
231 .iter()
232 .enumerate()
233 .map(|(i, value_id)| {
234 if let Some(info) = graph.value(*value_id) {
235 format!("#{i}:id={} origin={:?}", value_id, info.origin)
236 } else {
237 format!("#{i}:id={} origin=<missing>", value_id)
238 }
239 })
240 .collect();
241 log::debug!(
242 "fusion group {} gather: stack_depth={} stack_needed={} stack_kinds={:?} pattern={:?} inputs={:?}",
243 plan.index, stack.len(), stack_needed_preview, stack_kinds, &plan.stack_pattern, input_meta
244 );
245 }
246
247 if stack.len() < required_stack_operands {
248 if interp_engine::fusion_debug_enabled() {
249 log::debug!(
250 "fusion stack underflow: plan={} needed={} available={} pattern={:?}",
251 plan.index,
252 required_stack_operands,
253 stack.len(),
254 plan.stack_pattern
255 );
256 }
257 return Err(mex(
258 "FusionStackUnderflow",
259 "fusion: stack underflow gathering inputs",
260 ));
261 }
262 let available = required_stack_operands;
263 let slice_start = stack.len() - available;
264 let stack_guard = StackSliceGuard::new(stack, slice_start);
265 let slice = stack_guard.slice().to_vec();
266 let mut consumed_inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
267 let input_positions: HashMap<runmat_accelerate::graph::ValueId, usize> = plan
268 .inputs
269 .iter()
270 .enumerate()
271 .map(|(idx, value_id)| (*value_id, idx))
272 .collect();
273
274 let allow_stack_value = |val: &Value| {
275 if plan.group.kind.is_reduction() {
276 matches!(val, Value::GpuTensor(_) | Value::Tensor(_))
277 } else {
278 true
279 }
280 };
281
282 if let Some(layout) = plan.group.stack_layout.as_ref() {
283 for binding in &layout.bindings {
284 let Some(input_idx) = input_positions.get(&binding.value_id).copied() else {
285 continue;
286 };
287 let Some(val) = slice.get(binding.stack_offset).cloned() else {
288 continue;
289 };
290 consumed_inputs[input_idx] = Some(val.clone());
291 if inputs[input_idx].is_none() && allow_stack_value(&val) {
292 inputs[input_idx] = Some(val);
293 }
294 }
295 } else {
296 for (offset, input_idx) in plan.stack_pattern.iter().enumerate() {
297 let Some(val) = slice.get(offset).cloned() else {
298 continue;
299 };
300 consumed_inputs[*input_idx] = Some(val.clone());
301 if inputs[*input_idx].is_none() && allow_stack_value(&val) {
302 inputs[*input_idx] = Some(val);
303 }
304 }
305 }
306
307 for (idx, slot) in inputs.iter_mut().enumerate() {
308 if slot.is_some() {
309 continue;
310 }
311 let vid = plan.inputs[idx];
312 let info = graph.value(vid);
313 if let Some(info) = info {
314 match &info.origin {
315 ValueOrigin::Variable { kind, index } => {
316 let value_opt = match kind {
317 VarKind::Global => vars.get(*index).cloned(),
318 VarKind::Local => {
319 if let Some(frame) = context.call_stack.last() {
320 let absolute = frame.locals_start + index;
321 context.locals.get(absolute).cloned()
322 } else {
323 vars.get(*index).cloned()
324 }
325 }
326 };
327 if let Some(value) = value_opt {
328 *slot = Some(value);
329 continue;
330 }
331 }
332 ValueOrigin::Constant => {
333 if let Some(value) = plan.const_values.get(&vid) {
334 *slot = Some(value.clone());
335 continue;
336 }
337 }
338 _ => {}
339 }
340 }
341 if slot.is_none() {
342 if let Some(binding) = graph.var_binding(vid) {
343 let value_opt = match binding.kind {
344 VarKind::Global => vars.get(binding.index).cloned(),
345 VarKind::Local => {
346 if let Some(frame) = context.call_stack.last() {
347 let absolute = frame.locals_start + binding.index;
348 context.locals.get(absolute).cloned()
349 } else {
350 vars.get(binding.index).cloned()
351 }
352 }
353 };
354 if let Some(value) = value_opt {
355 *slot = Some(value);
356 continue;
357 }
358 }
359 }
360 if slot.is_none() {
361 if let Some(info) = info {
362 if let ValueOrigin::NodeOutput { node, .. } = info.origin {
363 if let Some(binding) = graph.node_binding(node) {
364 let value_opt = match binding.kind {
365 VarKind::Global => vars.get(binding.index).cloned(),
366 VarKind::Local => {
367 if let Some(frame) = context.call_stack.last() {
368 let absolute = frame.locals_start + binding.index;
369 context.locals.get(absolute).cloned()
370 } else {
371 vars.get(binding.index).cloned()
372 }
373 }
374 };
375 if let Some(value) = value_opt {
376 *slot = Some(value);
377 continue;
378 }
379 }
380 }
381 }
382 }
383 if slot.is_none() {
384 if let Some(value) = plan.const_values.get(&vid) {
385 *slot = Some(value.clone());
386 }
387 }
388 }
389
390 let inputs: Vec<Value> = inputs
391 .into_iter()
392 .map(|opt| opt.ok_or_else(|| mex("FusionMissingInput", "fusion: missing input value")))
393 .collect::<Result<_, _>>()?;
394
395 if log::log_enabled!(log::Level::Debug) {
396 let summaries: Vec<String> = inputs
397 .iter()
398 .enumerate()
399 .map(|(i, v)| summarize_value(i, v))
400 .collect();
401 log::debug!("fusion inputs runtime: [{}]", summaries.join(", "));
402 }
403
404 Ok((
405 stack_guard,
406 FusionExecutionRequest { plan, inputs },
407 consumed_inputs,
408 ))
409}
410
411pub fn write_elementwise_materialized_stores(
412 materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
413 vars: &mut Vec<Value>,
414 context: &mut ExecutionContext,
415) {
416 for (store, value) in materialized_stores {
417 match store.binding.kind {
418 VarKind::Global => {
419 let i = store.binding.index;
420 if i < vars.len() {
421 accel_residency::clear_value_excluding(&vars[i], &value);
422 }
423 if i >= vars.len() {
424 vars.resize(i + 1, Value::Num(0.0));
425 refresh_workspace_state(vars);
426 }
427 vars[i] = value;
428 }
429 VarKind::Local => {
430 if let Some(frame) = context.call_stack.last() {
431 let absolute = frame.locals_start + store.binding.index;
432 while context.locals.len() <= absolute {
433 context.locals.push(Value::Num(0.0));
434 }
435 accel_residency::clear_value_excluding(&context.locals[absolute], &value);
436 context.locals[absolute] = value;
437 } else {
438 let i = store.binding.index;
439 if i < vars.len() {
440 accel_residency::clear_value_excluding(&vars[i], &value);
441 }
442 if i >= vars.len() {
443 vars.resize(i + 1, Value::Num(0.0));
444 refresh_workspace_state(vars);
445 }
446 vars[i] = value;
447 }
448 }
449 }
450 }
451}
452
453pub fn execute_fusion_elementwise(
454 request: FusionExecutionRequest<'_>,
455 stack_guard: StackSliceGuard<'_>,
456 vars: &mut Vec<Value>,
457 context: &mut ExecutionContext,
458) -> Result<Value, RuntimeError> {
459 match execute_elementwise(request) {
460 Ok(result) => {
461 write_elementwise_materialized_stores(result.materialized_stores, vars, context);
462 stack_guard.commit();
463 Ok(result.final_value)
464 }
465 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
466 }
467}
468
469pub async fn execute_fusion_special_kind(
470 kind: FusionKind,
471 plan_inputs: &[runmat_accelerate::graph::ValueId],
472 request: FusionExecutionRequest<'_>,
473 stack_guard: StackSliceGuard<'_>,
474) -> Result<Value, RuntimeError> {
475 match kind {
476 FusionKind::CenteredGram => match execute_centered_gram(request).await {
477 Ok(result) => {
478 stack_guard.commit();
479 Ok(result)
480 }
481 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
482 },
483 FusionKind::PowerStepNormalize => match execute_power_step_normalize(request).await {
484 Ok(result) => {
485 stack_guard.commit();
486 Ok(result)
487 }
488 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
489 },
490 FusionKind::ExplainedVariance => {
491 log::debug!("explained variance plan inputs {:?}", plan_inputs);
492 match execute_explained_variance(request).await {
493 Ok(result) => {
494 stack_guard.commit();
495 Ok(result)
496 }
497 Err(err) => {
498 log::debug!("explained variance fusion fallback: {}", err);
499 Err(mex("FusionExecutionFailed", &err.to_string()))
500 }
501 }
502 }
503 FusionKind::MatmulEpilogue => match execute_matmul_epilogue(request).await {
504 Ok(result) => {
505 stack_guard.commit();
506 Ok(result)
507 }
508 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
509 },
510 FusionKind::ImageNormalize => match execute_image_normalize(request).await {
511 Ok(result) => {
512 stack_guard.commit();
513 Ok(result)
514 }
515 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
516 },
517 _ => Err(mex(
518 "FusionUnsupportedKind",
519 "fusion: unsupported fusion kind",
520 )),
521 }
522}
523
/// Launch geometry of a fused reduction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReductionGeometry {
    /// Zero-based axis the reduction runs along.
    pub axis: usize,
    /// Number of elements folded into each output value.
    pub reduce_len: usize,
    /// Number of independent slices, i.e. output values produced.
    pub num_slices: usize,
}
529
530pub fn resolve_reduction_geometry(
531 plan: &runmat_accelerate::FusionGroupPlan,
532 graph: &runmat_accelerate::AccelGraph,
533 request: &FusionExecutionRequest<'_>,
534 consumed_inputs: &[Option<Value>],
535 vars: &[Value],
536 context: &ExecutionContext,
537) -> Result<ReductionGeometry, RuntimeError> {
538 fn detect_reduce_all(
539 plan: &runmat_accelerate::FusionGroupPlan,
540 graph: &runmat_accelerate::AccelGraph,
541 ) -> bool {
542 let mut reduce_all = matches!(
543 plan.reduction_axes,
544 Some(runmat_accelerate::ReductionAxes::All)
545 );
546 let has_all = reduce_all
547 || plan.constants.values().any(value_is_all_keyword)
548 || plan.const_values.values().any(value_is_all_keyword);
549 if has_all {
550 return true;
551 }
552 for node_id in &plan.group.nodes {
553 if let Some(node) = graph.node(*node_id) {
554 if let runmat_accelerate::graph::AccelNodeLabel::Builtin { name } = &node.label {
555 if name.eq_ignore_ascii_case("mean") {
556 for input_vid in &node.inputs {
557 if let Some(info) = graph.value(*input_vid) {
558 if let Some(constant) = &info.constant {
559 if value_is_all_keyword(constant) {
560 reduce_all = true;
561 break;
562 }
563 }
564 }
565 }
566 }
567 }
568 }
569 if reduce_all {
570 break;
571 }
572 }
573 reduce_all
574 }
575
576 fn resolve_reduction_axis(plan: &runmat_accelerate::FusionGroupPlan) -> (usize, bool) {
577 let mut axis = 0usize;
578 let mut axis_explicit = false;
579 if let Some(runmat_accelerate::ReductionAxes::Explicit(dims)) = &plan.reduction_axes {
580 if let Some(first) = dims.first().copied() {
581 axis = first.saturating_sub(1);
582 axis_explicit = true;
583 }
584 }
585 if let Some(dim_vid) = plan.reduction_dim {
586 if let Some(cv) = plan.const_values.get(&dim_vid) {
587 axis = match cv {
588 Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
589 Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
590 _ => axis,
591 };
592 axis_explicit = true;
593 } else if let Some(input_idx) = plan.inputs.iter().position(|v| *v == dim_vid) {
594 if let Some(cv) = plan.constants.get(&input_idx) {
595 axis = match cv {
596 Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
597 Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
598 _ => axis,
599 };
600 axis_explicit = true;
601 }
602 }
603 } else if let Some(dim_const) = plan.constants.get(&1) {
604 axis = match dim_const {
605 Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
606 Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
607 _ => axis,
608 };
609 axis_explicit = true;
610 }
611 (axis, axis_explicit)
612 }
613
614 fn derive_rows_cols(
615 plan: &runmat_accelerate::FusionGroupPlan,
616 graph: &runmat_accelerate::AccelGraph,
617 request: &FusionExecutionRequest<'_>,
618 consumed_inputs: &[Option<Value>],
619 vars: &[Value],
620 context: &ExecutionContext,
621 ) -> Option<(usize, usize)> {
622 let shape_of = |value: &Value| -> Option<(usize, usize)> {
623 match value {
624 Value::GpuTensor(h) => Some((
625 h.shape.first().copied().unwrap_or(1).max(1),
626 h.shape.get(1).copied().unwrap_or(1).max(1),
627 )),
628 Value::Tensor(t) => Some((
629 t.shape.first().copied().unwrap_or(1).max(1),
630 t.shape.get(1).copied().unwrap_or(1).max(1),
631 )),
632 _ => None,
633 }
634 };
635
636 if let Some(shape) = plan.reduction_data_shape(graph) {
637 if shape.len() >= 2 {
638 return Some((shape[0].max(1), shape[1].max(1)));
639 }
640 if shape.len() == 1 {
641 return Some((shape[0].max(1), 1));
642 }
643 }
644
645 for &vid in &plan.inputs {
646 if let Some(binding) = graph.var_binding(vid) {
647 let value_opt = match binding.kind {
648 VarKind::Global => vars.get(binding.index).cloned(),
649 VarKind::Local => {
650 if let Some(frame) = context.call_stack.last() {
651 let absolute = frame.locals_start + binding.index;
652 context.locals.get(absolute).cloned()
653 } else {
654 vars.get(binding.index).cloned()
655 }
656 }
657 };
658 if let Some(value) = value_opt {
659 if let Some(shape) = shape_of(&value) {
660 return Some(shape);
661 }
662 }
663 }
664 }
665
666 for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
667 if let Some(shape) = shape_of(v) {
668 return Some(shape);
669 }
670 }
671
672 if let Some(data_id) = plan.reduction_data {
673 if let Some(input_index) = plan.inputs.iter().position(|vid| *vid == data_id) {
674 if let Some(val) = consumed_inputs.get(input_index).and_then(|v| v.as_ref()) {
675 if let Some(shape) = shape_of(val) {
676 return Some(shape);
677 }
678 }
679 if let Some(val) = request.inputs.get(input_index) {
680 if let Some(shape) = shape_of(val) {
681 return Some(shape);
682 }
683 }
684 }
685 if let Some(info) = graph.value(data_id) {
686 if let ValueOrigin::Variable { kind, index } = &info.origin {
687 let val = match kind {
688 VarKind::Global => vars.get(*index).cloned(),
689 VarKind::Local => {
690 if let Some(frame) = context.call_stack.last() {
691 let absolute = frame.locals_start + index;
692 context.locals.get(absolute).cloned()
693 } else {
694 vars.get(*index).cloned()
695 }
696 }
697 };
698 if let Some(v) = val {
699 if let Some(shape) = shape_of(&v) {
700 return Some(shape);
701 }
702 }
703 }
704 if let ShapeInfo::Tensor(dims) = &info.shape {
705 if !dims.is_empty() {
706 let r = dims.first().and_then(|d| *d).unwrap_or(1);
707 let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
708 return Some((r.max(1), c.max(1)));
709 }
710 }
711 }
712 }
713
714 for v in &request.inputs {
715 if let Some(shape) = shape_of(v) {
716 return Some(shape);
717 }
718 }
719
720 if let ShapeInfo::Tensor(dims) = &plan.group.shape {
721 if !dims.is_empty() {
722 let r = dims.first().and_then(|d| *d).unwrap_or(1);
723 let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
724 return Some((r.max(1), c.max(1)));
725 }
726 }
727 None
728 }
729
730 if log::log_enabled!(log::Level::Debug) {
731 let meta: Vec<String> = plan
732 .inputs
733 .iter()
734 .map(|vid| {
735 if let Some(info) = graph.value(*vid) {
736 format!(
737 "vid={} origin={:?} shape={:?}",
738 vid, info.origin, info.shape
739 )
740 } else {
741 format!("vid={} origin=<missing>", vid)
742 }
743 })
744 .collect();
745 log::debug!("reduction gather meta: [{}]", meta.join(", "));
746 }
747
748 let reduce_all = detect_reduce_all(plan, graph);
749 let (mut axis, axis_explicit) = if reduce_all {
750 (0usize, false)
751 } else {
752 resolve_reduction_axis(plan)
753 };
754 if reduce_all && interp_engine::fusion_debug_enabled() {
755 log::debug!(
756 "fusion reduction (all) meta: data_vid={:?} inputs={:?} stack_pattern={:?}",
757 plan.reduction_data,
758 plan.inputs,
759 plan.stack_pattern
760 );
761 }
762
763 let (r, c) =
764 derive_rows_cols(plan, graph, request, consumed_inputs, vars, context).unwrap_or((1, 1));
765 let (reduce_len, num_slices) = if reduce_all {
766 let total_from_runtime = consumed_inputs
767 .iter()
768 .filter_map(|v| v.as_ref())
769 .chain(request.inputs.iter())
770 .find_map(|value| match value {
771 Value::GpuTensor(handle) => Some(if handle.shape.is_empty() {
772 1
773 } else {
774 handle
775 .shape
776 .iter()
777 .copied()
778 .map(|d| d.max(1))
779 .product::<usize>()
780 }),
781 Value::Tensor(tensor) => Some(if tensor.shape.is_empty() {
782 1
783 } else {
784 tensor
785 .shape
786 .iter()
787 .copied()
788 .map(|d| d.max(1))
789 .product::<usize>()
790 }),
791 _ => None,
792 });
793 let total = plan
794 .reduction_data_shape(graph)
795 .map(|shape| shape.into_iter().map(|d| d.max(1)).product::<usize>())
796 .or(total_from_runtime)
797 .or_else(|| plan.element_count())
798 .filter(|v| *v > 0)
799 .ok_or_else(|| {
800 mex(
801 "FusionReductionExtentUnknown",
802 "fusion: reduction all extent unknown",
803 )
804 })?;
805 if interp_engine::fusion_debug_enabled() {
806 log::debug!(
807 "fusion reduction (all): total_elems={} fallback_rows={} fallback_cols={}",
808 total,
809 r,
810 c
811 );
812 }
813 (total, 1usize)
814 } else {
815 if !axis_explicit {
816 axis = if r == 1 && c > 1 {
817 1
818 } else if r > 1 {
819 0
820 } else {
821 axis
822 };
823 }
824 if interp_engine::fusion_debug_enabled() {
825 if r == 1 && c == 1 {
826 log::debug!(
827 "fusion reduction: unresolved shape (defaulted to 1x1); axis={}, constants={:?}",
828 axis,
829 plan.constants
830 );
831 } else {
832 log::debug!(
833 "fusion reduction: resolved shape rows={} cols={} axis={} constants={:?}",
834 r,
835 c,
836 axis,
837 plan.constants
838 );
839 }
840 }
841 if axis == 0 {
842 (r, c)
843 } else {
844 (c, r)
845 }
846 };
847
848 if interp_engine::fusion_debug_enabled() {
849 log::debug!(
850 "fusion reduction: axis={} reduce_len={} num_slices={} constants={:?}",
851 axis,
852 reduce_len,
853 num_slices,
854 plan.constants
855 );
856 }
857
858 let looks_wrong = reduce_len == 1 && num_slices == 1 && {
859 let mut big = false;
860 let mut check_val = |v: &Value| match v {
861 Value::GpuTensor(h) => {
862 let prod = h.shape.iter().copied().product::<usize>();
863 if prod > 1 {
864 big = true;
865 }
866 }
867 Value::Tensor(t) => {
868 let prod = t.shape.iter().copied().product::<usize>();
869 if prod > 1 {
870 big = true;
871 }
872 }
873 _ => {}
874 };
875 for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
876 check_val(v);
877 }
878 for v in &request.inputs {
879 check_val(v);
880 }
881 big
882 };
883 if looks_wrong {
884 log::debug!("fusion reduction: skipping fusion due to unresolved shape; falling back to provider path");
885 return Err(mex(
886 "FusionReductionShapeUnresolved",
887 "fusion: reduction shape unresolved",
888 ));
889 }
890 if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION")
891 .ok()
892 .as_deref()
893 == Some("1")
894 {
895 return Err(mex(
896 "FusionReductionDisabled",
897 "fusion: fused reductions disabled",
898 ));
899 }
900
901 Ok(ReductionGeometry {
902 axis,
903 reduce_len,
904 num_slices,
905 })
906}
907
908pub fn execute_fusion_reduction(
909 plan: &runmat_accelerate::FusionGroupPlan,
910 graph: &runmat_accelerate::AccelGraph,
911 request: FusionExecutionRequest<'_>,
912 consumed_inputs: &[Option<Value>],
913 stack_guard: StackSliceGuard<'_>,
914 vars: &[Value],
915 context: &ExecutionContext,
916) -> Result<Value, RuntimeError> {
917 let geom = resolve_reduction_geometry(plan, graph, &request, consumed_inputs, vars, context)?;
918 match execute_reduction(request, geom.reduce_len, geom.num_slices, 256u32) {
919 Ok(result) => {
920 stack_guard.commit();
921 Ok(result)
922 }
923 Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
924 }
925}
926
927pub async fn try_execute_fusion_group(
928 plan: &runmat_accelerate::FusionGroupPlan,
929 graph: &runmat_accelerate::AccelGraph,
930 stack: &mut Vec<Value>,
931 vars: &mut Vec<Value>,
932 context: &mut ExecutionContext,
933) -> Result<Value, RuntimeError> {
934 let (stack_guard, request, consumed_inputs) =
935 gather_fusion_inputs(plan, graph, stack, vars, context)?;
936 if plan.group.kind.is_elementwise()
937 && !request.inputs.is_empty()
938 && request.inputs.iter().all(is_scalarish_runtime_value)
939 {
940 return Err(mex(
941 "FusionScalarBypass",
942 "fusion: bypass scalar-only elementwise group",
943 ));
944 }
945 log::debug!(
946 "dispatch fusion kind {:?}, supported {}",
947 plan.group.kind,
948 plan.kernel.supported
949 );
950 if plan.group.kind.is_elementwise() {
951 execute_fusion_elementwise(request, stack_guard, vars, context)
952 } else if plan.group.kind.is_reduction() {
953 execute_fusion_reduction(
954 plan,
955 graph,
956 request,
957 &consumed_inputs,
958 stack_guard,
959 vars,
960 context,
961 )
962 } else {
963 execute_fusion_special_kind(plan.group.kind.clone(), &plan.inputs, request, stack_guard)
964 .await
965 }
966}
967
#[cfg(all(test, feature = "native-accel"))]
mod tests {
    use super::write_elementwise_materialized_stores;
    use crate::bytecode::program::ExecutionContext;
    use runmat_accelerate::fusion::FusionStoreMaterialization;
    use runmat_accelerate::fusion_residency;
    use runmat_accelerate::graph::VarBinding;
    use runmat_accelerate::VarKind;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::Value;

    /// Overwriting a global whose old value holds two GPU handles with a value
    /// that reuses one of them must keep the reused handle resident and clear
    /// residency only for the handle that is no longer referenced.
    #[test]
    fn fusion_writeback_preserves_shared_gpu_handles() {
        let kept = GpuTensorHandle {
            shape: vec![1],
            device_id: 17,
            buffer_id: 17001,
        };
        let replaced = GpuTensorHandle {
            buffer_id: 17002,
            ..kept.clone()
        };
        fusion_residency::mark(&kept);
        fusion_residency::mark(&replaced);
        assert!(fusion_residency::is_resident(&kept));
        assert!(fusion_residency::is_resident(&replaced));

        // Global slot 0 initially references both handles.
        let previous = Value::OutputList(vec![
            Value::GpuTensor(kept.clone()),
            Value::GpuTensor(replaced.clone()),
        ]);
        let mut vars = vec![previous];
        let mut context = ExecutionContext {
            call_stack: Vec::new(),
            locals: Vec::new(),
            instruction_pointer: 0,
            spawned_task_ids: std::collections::HashSet::new(),
            next_spawn_task_id: 0,
        };

        // Write back a value that reuses only the `kept` handle.
        let store = FusionStoreMaterialization {
            value_id: 1,
            binding: VarBinding {
                kind: VarKind::Global,
                index: 0,
            },
        };
        write_elementwise_materialized_stores(
            vec![(store, Value::GpuTensor(kept.clone()))],
            &mut vars,
            &mut context,
        );

        assert!(fusion_residency::is_resident(&kept));
        assert!(!fusion_residency::is_resident(&replaced));
        // Leave the residency registry clean for other tests.
        fusion_residency::clear(&kept);
    }
}