Skip to main content

runmat_vm/accel/
fusion.rs

1use crate::accel::residency as accel_residency;
2use crate::bytecode::program::ExecutionContext;
3use crate::bytecode::Instr;
4use crate::interpreter::engine as interp_engine;
5use crate::interpreter::errors::mex;
6use crate::runtime::workspace::refresh_workspace_state;
7use runmat_accelerate::fusion::FusionStoreMaterialization;
8use runmat_accelerate::fusion_exec::{
9    execute_centered_gram, execute_elementwise, execute_explained_variance,
10    execute_image_normalize, execute_matmul_epilogue, execute_power_step_normalize,
11    execute_reduction, FusionExecutionRequest,
12};
13use runmat_accelerate::InstrSpan;
14use runmat_accelerate::{value_is_all_keyword, FusionKind, ShapeInfo, ValueOrigin, VarKind};
15use runmat_builtins::Value;
16use runmat_runtime::builtins::common::shape::is_scalar_shape;
17use runmat_runtime::RuntimeError;
18use std::collections::HashMap;
19
20#[inline]
21pub fn value_kind(value: &Value) -> &'static str {
22    match value {
23        Value::Int(_) => "Int",
24        Value::Num(_) => "Num",
25        Value::Complex(_, _) => "Complex",
26        Value::Bool(_) => "Bool",
27        Value::LogicalArray(_) => "LogicalArray",
28        Value::String(_) => "String",
29        Value::StringArray(_) => "StringArray",
30        Value::CharArray(_) => "CharArray",
31        Value::Tensor(_) => "Tensor",
32        Value::ComplexTensor(_) => "ComplexTensor",
33        Value::Cell(_) => "Cell",
34        Value::Struct(_) => "Struct",
35        Value::GpuTensor(_) => "GpuTensor",
36        Value::Object(_) => "Object",
37        Value::HandleObject(_) => "HandleObject",
38        Value::Listener(_) => "Listener",
39        Value::FunctionHandle(_)
40        | Value::ExternalFunctionHandle(_)
41        | Value::MethodFunctionHandle(_) => "FunctionHandle",
42        Value::BoundFunctionHandle { .. } => "FunctionHandle",
43        Value::Closure(_) => "Closure",
44        Value::ClassRef(_) => "ClassRef",
45        Value::MException(_) => "MException",
46        Value::OutputList(_) => "OutputList",
47    }
48}
49
50#[inline]
51pub fn summarize_value(i: usize, v: &Value) -> String {
52    match v {
53        Value::GpuTensor(h) => format!("in#{i}:GpuTensor shape={:?}", h.shape),
54        Value::Tensor(t) => format!("in#{i}:Tensor shape={:?}", t.shape),
55        Value::Num(n) => format!("in#{i}:Num({n:.6})"),
56        Value::Int(n) => format!("in#{i}:Int({})", n.to_i64()),
57        Value::Bool(b) => format!("in#{i}:Bool({})", if *b { 1 } else { 0 }),
58        Value::String(s) => format!("in#{i}:String({})", s),
59        _ => format!("in#{i}:{}", value_kind(v)),
60    }
61}
62
63#[inline]
64fn is_scalarish_runtime_value(value: &Value) -> bool {
65    match value {
66        Value::Num(_) | Value::Int(_) | Value::Bool(_) | Value::Complex(_, _) => true,
67        Value::Tensor(tensor) => is_scalar_shape(&tensor.shape),
68        Value::ComplexTensor(tensor) => is_scalar_shape(&tensor.shape),
69        Value::LogicalArray(array) => is_scalar_shape(&array.shape),
70        Value::GpuTensor(handle) => is_scalar_shape(&handle.shape),
71        Value::CharArray(array) => array.rows * array.cols == 1,
72        _ => false,
73    }
74}
75
76pub fn fusion_span_live_result_count(instructions: &[Instr], span: &InstrSpan) -> Option<usize> {
77    if span.start > span.end || span.end >= instructions.len() {
78        return None;
79    }
80    let mut current_depth = 0usize;
81    for instr in &instructions[span.start..=span.end] {
82        let effect = instr.stack_effect()?;
83        if current_depth < effect.pops {
84            current_depth = effect.pops;
85        }
86        current_depth = current_depth - effect.pops + effect.pushes;
87    }
88    Some(current_depth)
89}
90
91pub fn fusion_span_has_vm_barrier(instructions: &[Instr], span: &InstrSpan) -> bool {
92    if span.start > span.end || span.end >= instructions.len() {
93        return true;
94    }
95    for instr in &instructions[span.start..=span.end] {
96        if matches!(
97            instr,
98            Instr::StoreIndex(_)
99                | Instr::StoreIndexDelete(_)
100                | Instr::StoreSlice(_, _, _, _)
101                | Instr::StoreSliceDelete(_, _, _, _)
102                | Instr::StoreSliceExpr { .. }
103                | Instr::StoreSliceExprDelete { .. }
104                | Instr::StoreIndexCell { .. }
105                | Instr::StoreIndexCellDelete { .. }
106                | Instr::StoreMember(_)
107                | Instr::StoreMemberOrInit(_)
108                | Instr::StoreMemberDynamic
109                | Instr::StoreMemberDynamicOrInit
110        ) {
111            return true;
112        }
113    }
114    fusion_span_live_result_count(instructions, span) != Some(1)
115}
116
117pub struct StackSliceGuard<'a> {
118    stack: *mut Vec<Value>,
119    slice: Option<Vec<Value>>,
120    _marker: std::marker::PhantomData<&'a mut Vec<Value>>,
121}
122
123impl<'a> StackSliceGuard<'a> {
124    pub fn new(stack: &'a mut Vec<Value>, slice_start: usize) -> Self {
125        let slice = stack.split_off(slice_start);
126        Self {
127            stack,
128            slice: Some(slice),
129            _marker: std::marker::PhantomData,
130        }
131    }
132
133    pub fn slice(&self) -> &[Value] {
134        self.slice.as_ref().expect("stack slice missing").as_slice()
135    }
136
137    pub fn commit(mut self) {
138        self.slice = None;
139    }
140}
141
142impl Drop for StackSliceGuard<'_> {
143    fn drop(&mut self) {
144        if let Some(slice) = self.slice.take() {
145            unsafe { (&mut *self.stack).extend(slice) }
146        }
147    }
148}
149
150pub fn gather_fusion_inputs<'a>(
151    plan: &'a runmat_accelerate::FusionGroupPlan,
152    graph: &runmat_accelerate::AccelGraph,
153    stack: &'a mut Vec<Value>,
154    vars: &mut [Value],
155    context: &mut ExecutionContext,
156) -> Result<
157    (
158        StackSliceGuard<'a>,
159        FusionExecutionRequest<'a>,
160        Vec<Option<Value>>,
161    ),
162    RuntimeError,
163> {
164    if plan.group.stack_layout.is_none() && !plan.stack_pattern.is_empty() {
165        return Err(mex(
166            "FusionMissingStackLayout",
167            "fusion: missing compile-time stack layout metadata",
168        ));
169    }
170    let required_stack_operands = plan
171        .group
172        .stack_layout
173        .as_ref()
174        .map(|layout| layout.required_stack_operands)
175        .unwrap_or_else(|| plan.stack_pattern.len());
176    let mut inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
177
178    for (idx, value) in &plan.constants {
179        if let Some(slot) = inputs.get_mut(*idx) {
180            if slot.is_none() {
181                *slot = Some(value.clone());
182            }
183        }
184    }
185
186    for (idx, value_id) in plan.inputs.iter().enumerate() {
187        let info = graph
188            .value(*value_id)
189            .ok_or_else(|| format!("fusion: missing value metadata for id {value_id}"))?;
190        match &info.origin {
191            ValueOrigin::Variable { kind, index } => {
192                let value =
193                    match kind {
194                        VarKind::Global => vars
195                            .get(*index)
196                            .cloned()
197                            .ok_or_else(|| format!("fusion: global var {index} out of range"))?,
198                        VarKind::Local => {
199                            if let Some(frame) = context.call_stack.last() {
200                                let absolute = frame.locals_start + index;
201                                context.locals.get(absolute).cloned().ok_or_else(|| {
202                                    format!("fusion: local var {index} unavailable")
203                                })?
204                            } else {
205                                vars.get(*index).cloned().ok_or_else(|| {
206                                    format!("fusion: local var {index} unavailable")
207                                })?
208                            }
209                        }
210                    };
211                debug_assert!(
212                    inputs[idx].is_none(),
213                    "fusion: duplicate input slot {} for plan {}",
214                    idx,
215                    plan.index
216                );
217                inputs[idx] = Some(value);
218            }
219            ValueOrigin::Constant | ValueOrigin::NodeOutput { .. } | ValueOrigin::Unknown => {}
220        }
221    }
222
223    if log::log_enabled!(log::Level::Debug) && interp_engine::fusion_debug_enabled() {
224        let stack_needed_preview = required_stack_operands;
225        let stack_snapshot: Vec<&Value> = stack.iter().rev().take(stack_needed_preview).collect();
226        let stack_kinds: Vec<&'static str> =
227            stack_snapshot.iter().rev().map(|v| value_kind(v)).collect();
228        let input_meta: Vec<String> = plan
229            .inputs
230            .iter()
231            .enumerate()
232            .map(|(i, value_id)| {
233                if let Some(info) = graph.value(*value_id) {
234                    format!("#{i}:id={} origin={:?}", value_id, info.origin)
235                } else {
236                    format!("#{i}:id={} origin=<missing>", value_id)
237                }
238            })
239            .collect();
240        log::debug!(
241            "fusion group {} gather: stack_depth={} stack_needed={} stack_kinds={:?} pattern={:?} inputs={:?}",
242            plan.index, stack.len(), stack_needed_preview, stack_kinds, &plan.stack_pattern, input_meta
243        );
244    }
245
246    if stack.len() < required_stack_operands {
247        if interp_engine::fusion_debug_enabled() {
248            log::debug!(
249                "fusion stack underflow: plan={} needed={} available={} pattern={:?}",
250                plan.index,
251                required_stack_operands,
252                stack.len(),
253                plan.stack_pattern
254            );
255        }
256        return Err(mex(
257            "FusionStackUnderflow",
258            "fusion: stack underflow gathering inputs",
259        ));
260    }
261    let available = required_stack_operands;
262    let slice_start = stack.len() - available;
263    let stack_guard = StackSliceGuard::new(stack, slice_start);
264    let slice = stack_guard.slice().to_vec();
265    let mut consumed_inputs: Vec<Option<Value>> = vec![None; plan.inputs.len()];
266    let input_positions: HashMap<runmat_accelerate::graph::ValueId, usize> = plan
267        .inputs
268        .iter()
269        .enumerate()
270        .map(|(idx, value_id)| (*value_id, idx))
271        .collect();
272
273    let allow_stack_value = |val: &Value| {
274        if plan.group.kind.is_reduction() {
275            matches!(val, Value::GpuTensor(_) | Value::Tensor(_))
276        } else {
277            true
278        }
279    };
280
281    if let Some(layout) = plan.group.stack_layout.as_ref() {
282        for binding in &layout.bindings {
283            let Some(input_idx) = input_positions.get(&binding.value_id).copied() else {
284                continue;
285            };
286            let Some(val) = slice.get(binding.stack_offset).cloned() else {
287                continue;
288            };
289            consumed_inputs[input_idx] = Some(val.clone());
290            if inputs[input_idx].is_none() && allow_stack_value(&val) {
291                inputs[input_idx] = Some(val);
292            }
293        }
294    } else {
295        for (offset, input_idx) in plan.stack_pattern.iter().enumerate() {
296            let Some(val) = slice.get(offset).cloned() else {
297                continue;
298            };
299            consumed_inputs[*input_idx] = Some(val.clone());
300            if inputs[*input_idx].is_none() && allow_stack_value(&val) {
301                inputs[*input_idx] = Some(val);
302            }
303        }
304    }
305
306    for (idx, slot) in inputs.iter_mut().enumerate() {
307        if slot.is_some() {
308            continue;
309        }
310        let vid = plan.inputs[idx];
311        let info = graph.value(vid);
312        if let Some(info) = info {
313            match &info.origin {
314                ValueOrigin::Variable { kind, index } => {
315                    let value_opt = match kind {
316                        VarKind::Global => vars.get(*index).cloned(),
317                        VarKind::Local => {
318                            if let Some(frame) = context.call_stack.last() {
319                                let absolute = frame.locals_start + index;
320                                context.locals.get(absolute).cloned()
321                            } else {
322                                vars.get(*index).cloned()
323                            }
324                        }
325                    };
326                    if let Some(value) = value_opt {
327                        *slot = Some(value);
328                        continue;
329                    }
330                }
331                ValueOrigin::Constant => {
332                    if let Some(value) = plan.const_values.get(&vid) {
333                        *slot = Some(value.clone());
334                        continue;
335                    }
336                }
337                _ => {}
338            }
339        }
340        if slot.is_none() {
341            if let Some(binding) = graph.var_binding(vid) {
342                let value_opt = match binding.kind {
343                    VarKind::Global => vars.get(binding.index).cloned(),
344                    VarKind::Local => {
345                        if let Some(frame) = context.call_stack.last() {
346                            let absolute = frame.locals_start + binding.index;
347                            context.locals.get(absolute).cloned()
348                        } else {
349                            vars.get(binding.index).cloned()
350                        }
351                    }
352                };
353                if let Some(value) = value_opt {
354                    *slot = Some(value);
355                    continue;
356                }
357            }
358        }
359        if slot.is_none() {
360            if let Some(info) = info {
361                if let ValueOrigin::NodeOutput { node, .. } = info.origin {
362                    if let Some(binding) = graph.node_binding(node) {
363                        let value_opt = match binding.kind {
364                            VarKind::Global => vars.get(binding.index).cloned(),
365                            VarKind::Local => {
366                                if let Some(frame) = context.call_stack.last() {
367                                    let absolute = frame.locals_start + binding.index;
368                                    context.locals.get(absolute).cloned()
369                                } else {
370                                    vars.get(binding.index).cloned()
371                                }
372                            }
373                        };
374                        if let Some(value) = value_opt {
375                            *slot = Some(value);
376                            continue;
377                        }
378                    }
379                }
380            }
381        }
382        if slot.is_none() {
383            if let Some(value) = plan.const_values.get(&vid) {
384                *slot = Some(value.clone());
385            }
386        }
387    }
388
389    let inputs: Vec<Value> = inputs
390        .into_iter()
391        .map(|opt| opt.ok_or_else(|| mex("FusionMissingInput", "fusion: missing input value")))
392        .collect::<Result<_, _>>()?;
393
394    if log::log_enabled!(log::Level::Debug) {
395        let summaries: Vec<String> = inputs
396            .iter()
397            .enumerate()
398            .map(|(i, v)| summarize_value(i, v))
399            .collect();
400        log::debug!("fusion inputs runtime: [{}]", summaries.join(", "));
401    }
402
403    Ok((
404        stack_guard,
405        FusionExecutionRequest { plan, inputs },
406        consumed_inputs,
407    ))
408}
409
410pub fn write_elementwise_materialized_stores(
411    materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
412    vars: &mut Vec<Value>,
413    context: &mut ExecutionContext,
414) {
415    for (store, value) in materialized_stores {
416        match store.binding.kind {
417            VarKind::Global => {
418                let i = store.binding.index;
419                if i < vars.len() {
420                    accel_residency::clear_value_excluding(&vars[i], &value);
421                }
422                if i >= vars.len() {
423                    vars.resize(i + 1, Value::Num(0.0));
424                    refresh_workspace_state(vars);
425                }
426                vars[i] = value;
427            }
428            VarKind::Local => {
429                if let Some(frame) = context.call_stack.last() {
430                    let absolute = frame.locals_start + store.binding.index;
431                    while context.locals.len() <= absolute {
432                        context.locals.push(Value::Num(0.0));
433                    }
434                    accel_residency::clear_value_excluding(&context.locals[absolute], &value);
435                    context.locals[absolute] = value;
436                } else {
437                    let i = store.binding.index;
438                    if i < vars.len() {
439                        accel_residency::clear_value_excluding(&vars[i], &value);
440                    }
441                    if i >= vars.len() {
442                        vars.resize(i + 1, Value::Num(0.0));
443                        refresh_workspace_state(vars);
444                    }
445                    vars[i] = value;
446                }
447            }
448        }
449    }
450}
451
452pub fn execute_fusion_elementwise(
453    request: FusionExecutionRequest<'_>,
454    stack_guard: StackSliceGuard<'_>,
455    vars: &mut Vec<Value>,
456    context: &mut ExecutionContext,
457) -> Result<Value, RuntimeError> {
458    match execute_elementwise(request) {
459        Ok(result) => {
460            write_elementwise_materialized_stores(result.materialized_stores, vars, context);
461            stack_guard.commit();
462            Ok(result.final_value)
463        }
464        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
465    }
466}
467
468pub async fn execute_fusion_special_kind(
469    kind: FusionKind,
470    plan_inputs: &[runmat_accelerate::graph::ValueId],
471    request: FusionExecutionRequest<'_>,
472    stack_guard: StackSliceGuard<'_>,
473) -> Result<Value, RuntimeError> {
474    match kind {
475        FusionKind::CenteredGram => match execute_centered_gram(request).await {
476            Ok(result) => {
477                stack_guard.commit();
478                Ok(result)
479            }
480            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
481        },
482        FusionKind::PowerStepNormalize => match execute_power_step_normalize(request).await {
483            Ok(result) => {
484                stack_guard.commit();
485                Ok(result)
486            }
487            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
488        },
489        FusionKind::ExplainedVariance => {
490            log::debug!("explained variance plan inputs {:?}", plan_inputs);
491            match execute_explained_variance(request).await {
492                Ok(result) => {
493                    stack_guard.commit();
494                    Ok(result)
495                }
496                Err(err) => {
497                    log::debug!("explained variance fusion fallback: {}", err);
498                    Err(mex("FusionExecutionFailed", &err.to_string()))
499                }
500            }
501        }
502        FusionKind::MatmulEpilogue => match execute_matmul_epilogue(request).await {
503            Ok(result) => {
504                stack_guard.commit();
505                Ok(result)
506            }
507            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
508        },
509        FusionKind::ImageNormalize => match execute_image_normalize(request).await {
510            Ok(result) => {
511                stack_guard.commit();
512                Ok(result)
513            }
514            Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
515        },
516        _ => Err(mex(
517            "FusionUnsupportedKind",
518            "fusion: unsupported fusion kind",
519        )),
520    }
521}
522
/// Geometry of a fused reduction: which axis is reduced and how the data is
/// partitioned for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReductionGeometry {
    /// Zero-based axis being reduced.
    pub axis: usize,
    /// Number of elements folded into each reduction result.
    // NOTE(review): presumably the extent along `axis` (or the total element
    // count for an "all" reduction) — confirm against the producer.
    pub reduce_len: usize,
    /// Number of independent slices, i.e. reduction results produced.
    pub num_slices: usize,
}
528
529pub fn resolve_reduction_geometry(
530    plan: &runmat_accelerate::FusionGroupPlan,
531    graph: &runmat_accelerate::AccelGraph,
532    request: &FusionExecutionRequest<'_>,
533    consumed_inputs: &[Option<Value>],
534    vars: &[Value],
535    context: &ExecutionContext,
536) -> Result<ReductionGeometry, RuntimeError> {
537    fn detect_reduce_all(
538        plan: &runmat_accelerate::FusionGroupPlan,
539        graph: &runmat_accelerate::AccelGraph,
540    ) -> bool {
541        let mut reduce_all = matches!(
542            plan.reduction_axes,
543            Some(runmat_accelerate::ReductionAxes::All)
544        );
545        let has_all = reduce_all
546            || plan.constants.values().any(value_is_all_keyword)
547            || plan.const_values.values().any(value_is_all_keyword);
548        if has_all {
549            return true;
550        }
551        for node_id in &plan.group.nodes {
552            if let Some(node) = graph.node(*node_id) {
553                if let runmat_accelerate::graph::AccelNodeLabel::Builtin { name } = &node.label {
554                    if name.eq_ignore_ascii_case("mean") {
555                        for input_vid in &node.inputs {
556                            if let Some(info) = graph.value(*input_vid) {
557                                if let Some(constant) = &info.constant {
558                                    if value_is_all_keyword(constant) {
559                                        reduce_all = true;
560                                        break;
561                                    }
562                                }
563                            }
564                        }
565                    }
566                }
567            }
568            if reduce_all {
569                break;
570            }
571        }
572        reduce_all
573    }
574
575    fn resolve_reduction_axis(plan: &runmat_accelerate::FusionGroupPlan) -> (usize, bool) {
576        let mut axis = 0usize;
577        let mut axis_explicit = false;
578        if let Some(runmat_accelerate::ReductionAxes::Explicit(dims)) = &plan.reduction_axes {
579            if let Some(first) = dims.first().copied() {
580                axis = first.saturating_sub(1);
581                axis_explicit = true;
582            }
583        }
584        if let Some(dim_vid) = plan.reduction_dim {
585            if let Some(cv) = plan.const_values.get(&dim_vid) {
586                axis = match cv {
587                    Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
588                    Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
589                    _ => axis,
590                };
591                axis_explicit = true;
592            } else if let Some(input_idx) = plan.inputs.iter().position(|v| *v == dim_vid) {
593                if let Some(cv) = plan.constants.get(&input_idx) {
594                    axis = match cv {
595                        Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
596                        Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
597                        _ => axis,
598                    };
599                    axis_explicit = true;
600                }
601            }
602        } else if let Some(dim_const) = plan.constants.get(&1) {
603            axis = match dim_const {
604                Value::Num(n) if *n >= 1.0 => (*n as usize).saturating_sub(1),
605                Value::Int(i) => (i.to_f64() as usize).saturating_sub(1),
606                _ => axis,
607            };
608            axis_explicit = true;
609        }
610        (axis, axis_explicit)
611    }
612
613    fn derive_rows_cols(
614        plan: &runmat_accelerate::FusionGroupPlan,
615        graph: &runmat_accelerate::AccelGraph,
616        request: &FusionExecutionRequest<'_>,
617        consumed_inputs: &[Option<Value>],
618        vars: &[Value],
619        context: &ExecutionContext,
620    ) -> Option<(usize, usize)> {
621        let shape_of = |value: &Value| -> Option<(usize, usize)> {
622            match value {
623                Value::GpuTensor(h) => Some((
624                    h.shape.first().copied().unwrap_or(1).max(1),
625                    h.shape.get(1).copied().unwrap_or(1).max(1),
626                )),
627                Value::Tensor(t) => Some((
628                    t.shape.first().copied().unwrap_or(1).max(1),
629                    t.shape.get(1).copied().unwrap_or(1).max(1),
630                )),
631                _ => None,
632            }
633        };
634
635        if let Some(shape) = plan.reduction_data_shape(graph) {
636            if shape.len() >= 2 {
637                return Some((shape[0].max(1), shape[1].max(1)));
638            }
639            if shape.len() == 1 {
640                return Some((shape[0].max(1), 1));
641            }
642        }
643
644        for &vid in &plan.inputs {
645            if let Some(binding) = graph.var_binding(vid) {
646                let value_opt = match binding.kind {
647                    VarKind::Global => vars.get(binding.index).cloned(),
648                    VarKind::Local => {
649                        if let Some(frame) = context.call_stack.last() {
650                            let absolute = frame.locals_start + binding.index;
651                            context.locals.get(absolute).cloned()
652                        } else {
653                            vars.get(binding.index).cloned()
654                        }
655                    }
656                };
657                if let Some(value) = value_opt {
658                    if let Some(shape) = shape_of(&value) {
659                        return Some(shape);
660                    }
661                }
662            }
663        }
664
665        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
666            if let Some(shape) = shape_of(v) {
667                return Some(shape);
668            }
669        }
670
671        if let Some(data_id) = plan.reduction_data {
672            if let Some(input_index) = plan.inputs.iter().position(|vid| *vid == data_id) {
673                if let Some(val) = consumed_inputs.get(input_index).and_then(|v| v.as_ref()) {
674                    if let Some(shape) = shape_of(val) {
675                        return Some(shape);
676                    }
677                }
678                if let Some(val) = request.inputs.get(input_index) {
679                    if let Some(shape) = shape_of(val) {
680                        return Some(shape);
681                    }
682                }
683            }
684            if let Some(info) = graph.value(data_id) {
685                if let ValueOrigin::Variable { kind, index } = &info.origin {
686                    let val = match kind {
687                        VarKind::Global => vars.get(*index).cloned(),
688                        VarKind::Local => {
689                            if let Some(frame) = context.call_stack.last() {
690                                let absolute = frame.locals_start + index;
691                                context.locals.get(absolute).cloned()
692                            } else {
693                                vars.get(*index).cloned()
694                            }
695                        }
696                    };
697                    if let Some(v) = val {
698                        if let Some(shape) = shape_of(&v) {
699                            return Some(shape);
700                        }
701                    }
702                }
703                if let ShapeInfo::Tensor(dims) = &info.shape {
704                    if !dims.is_empty() {
705                        let r = dims.first().and_then(|d| *d).unwrap_or(1);
706                        let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
707                        return Some((r.max(1), c.max(1)));
708                    }
709                }
710            }
711        }
712
713        for v in &request.inputs {
714            if let Some(shape) = shape_of(v) {
715                return Some(shape);
716            }
717        }
718
719        if let ShapeInfo::Tensor(dims) = &plan.group.shape {
720            if !dims.is_empty() {
721                let r = dims.first().and_then(|d| *d).unwrap_or(1);
722                let c = dims.get(1).and_then(|d| *d).unwrap_or(1);
723                return Some((r.max(1), c.max(1)));
724            }
725        }
726        None
727    }
728
729    if log::log_enabled!(log::Level::Debug) {
730        let meta: Vec<String> = plan
731            .inputs
732            .iter()
733            .map(|vid| {
734                if let Some(info) = graph.value(*vid) {
735                    format!(
736                        "vid={} origin={:?} shape={:?}",
737                        vid, info.origin, info.shape
738                    )
739                } else {
740                    format!("vid={} origin=<missing>", vid)
741                }
742            })
743            .collect();
744        log::debug!("reduction gather meta: [{}]", meta.join(", "));
745    }
746
747    let reduce_all = detect_reduce_all(plan, graph);
748    let (mut axis, axis_explicit) = if reduce_all {
749        (0usize, false)
750    } else {
751        resolve_reduction_axis(plan)
752    };
753    if reduce_all && interp_engine::fusion_debug_enabled() {
754        log::debug!(
755            "fusion reduction (all) meta: data_vid={:?} inputs={:?} stack_pattern={:?}",
756            plan.reduction_data,
757            plan.inputs,
758            plan.stack_pattern
759        );
760    }
761
762    let (r, c) =
763        derive_rows_cols(plan, graph, request, consumed_inputs, vars, context).unwrap_or((1, 1));
764    let (reduce_len, num_slices) = if reduce_all {
765        let total_from_runtime = consumed_inputs
766            .iter()
767            .filter_map(|v| v.as_ref())
768            .chain(request.inputs.iter())
769            .find_map(|value| match value {
770                Value::GpuTensor(handle) => Some(if handle.shape.is_empty() {
771                    1
772                } else {
773                    handle
774                        .shape
775                        .iter()
776                        .copied()
777                        .map(|d| d.max(1))
778                        .product::<usize>()
779                }),
780                Value::Tensor(tensor) => Some(if tensor.shape.is_empty() {
781                    1
782                } else {
783                    tensor
784                        .shape
785                        .iter()
786                        .copied()
787                        .map(|d| d.max(1))
788                        .product::<usize>()
789                }),
790                _ => None,
791            });
792        let total = plan
793            .reduction_data_shape(graph)
794            .map(|shape| shape.into_iter().map(|d| d.max(1)).product::<usize>())
795            .or(total_from_runtime)
796            .or_else(|| plan.element_count())
797            .filter(|v| *v > 0)
798            .ok_or_else(|| {
799                mex(
800                    "FusionReductionExtentUnknown",
801                    "fusion: reduction all extent unknown",
802                )
803            })?;
804        if interp_engine::fusion_debug_enabled() {
805            log::debug!(
806                "fusion reduction (all): total_elems={} fallback_rows={} fallback_cols={}",
807                total,
808                r,
809                c
810            );
811        }
812        (total, 1usize)
813    } else {
814        if !axis_explicit {
815            axis = if r == 1 && c > 1 {
816                1
817            } else if r > 1 {
818                0
819            } else {
820                axis
821            };
822        }
823        if interp_engine::fusion_debug_enabled() {
824            if r == 1 && c == 1 {
825                log::debug!(
826                    "fusion reduction: unresolved shape (defaulted to 1x1); axis={}, constants={:?}",
827                    axis,
828                    plan.constants
829                );
830            } else {
831                log::debug!(
832                    "fusion reduction: resolved shape rows={} cols={} axis={} constants={:?}",
833                    r,
834                    c,
835                    axis,
836                    plan.constants
837                );
838            }
839        }
840        if axis == 0 {
841            (r, c)
842        } else {
843            (c, r)
844        }
845    };
846
847    if interp_engine::fusion_debug_enabled() {
848        log::debug!(
849            "fusion reduction: axis={} reduce_len={} num_slices={} constants={:?}",
850            axis,
851            reduce_len,
852            num_slices,
853            plan.constants
854        );
855    }
856
857    let looks_wrong = reduce_len == 1 && num_slices == 1 && {
858        let mut big = false;
859        let mut check_val = |v: &Value| match v {
860            Value::GpuTensor(h) => {
861                let prod = h.shape.iter().copied().product::<usize>();
862                if prod > 1 {
863                    big = true;
864                }
865            }
866            Value::Tensor(t) => {
867                let prod = t.shape.iter().copied().product::<usize>();
868                if prod > 1 {
869                    big = true;
870                }
871            }
872            _ => {}
873        };
874        for v in consumed_inputs.iter().filter_map(|v| v.as_ref()) {
875            check_val(v);
876        }
877        for v in &request.inputs {
878            check_val(v);
879        }
880        big
881    };
882    if looks_wrong {
883        log::debug!("fusion reduction: skipping fusion due to unresolved shape; falling back to provider path");
884        return Err(mex(
885            "FusionReductionShapeUnresolved",
886            "fusion: reduction shape unresolved",
887        ));
888    }
889    if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION")
890        .ok()
891        .as_deref()
892        == Some("1")
893    {
894        return Err(mex(
895            "FusionReductionDisabled",
896            "fusion: fused reductions disabled",
897        ));
898    }
899
900    Ok(ReductionGeometry {
901        axis,
902        reduce_len,
903        num_slices,
904    })
905}
906
907pub fn execute_fusion_reduction(
908    plan: &runmat_accelerate::FusionGroupPlan,
909    graph: &runmat_accelerate::AccelGraph,
910    request: FusionExecutionRequest<'_>,
911    consumed_inputs: &[Option<Value>],
912    stack_guard: StackSliceGuard<'_>,
913    vars: &[Value],
914    context: &ExecutionContext,
915) -> Result<Value, RuntimeError> {
916    let geom = resolve_reduction_geometry(plan, graph, &request, consumed_inputs, vars, context)?;
917    match execute_reduction(request, geom.reduce_len, geom.num_slices, 256u32) {
918        Ok(result) => {
919            stack_guard.commit();
920            Ok(result)
921        }
922        Err(err) => Err(mex("FusionExecutionFailed", &err.to_string())),
923    }
924}
925
926pub async fn try_execute_fusion_group(
927    plan: &runmat_accelerate::FusionGroupPlan,
928    graph: &runmat_accelerate::AccelGraph,
929    stack: &mut Vec<Value>,
930    vars: &mut Vec<Value>,
931    context: &mut ExecutionContext,
932) -> Result<Value, RuntimeError> {
933    let (stack_guard, request, consumed_inputs) =
934        gather_fusion_inputs(plan, graph, stack, vars, context)?;
935    if plan.group.kind.is_elementwise()
936        && !request.inputs.is_empty()
937        && request.inputs.iter().all(is_scalarish_runtime_value)
938    {
939        return Err(mex(
940            "FusionScalarBypass",
941            "fusion: bypass scalar-only elementwise group",
942        ));
943    }
944    log::debug!(
945        "dispatch fusion kind {:?}, supported {}",
946        plan.group.kind,
947        plan.kernel.supported
948    );
949    if plan.group.kind.is_elementwise() {
950        execute_fusion_elementwise(request, stack_guard, vars, context)
951    } else if plan.group.kind.is_reduction() {
952        execute_fusion_reduction(
953            plan,
954            graph,
955            request,
956            &consumed_inputs,
957            stack_guard,
958            vars,
959            context,
960        )
961    } else {
962        execute_fusion_special_kind(plan.group.kind.clone(), &plan.inputs, request, stack_guard)
963            .await
964    }
965}
966
#[cfg(all(test, feature = "native-accel"))]
mod tests {
    use super::write_elementwise_materialized_stores;
    use crate::bytecode::program::ExecutionContext;
    use runmat_accelerate::fusion::FusionStoreMaterialization;
    use runmat_accelerate::fusion_residency;
    use runmat_accelerate::graph::VarBinding;
    use runmat_accelerate::VarKind;
    use runmat_accelerate_api::GpuTensorHandle;
    use runmat_builtins::Value;
    use std::collections::HashSet;

    /// Overwriting a global slot that held two GPU handles with a value that
    /// reuses one of them must leave the reused handle resident and drop the
    /// residency mark of the handle that is no longer referenced.
    #[test]
    fn fusion_writeback_preserves_shared_gpu_handles() {
        // Handle present both in the old slot contents and in the new value.
        let kept = GpuTensorHandle {
            shape: vec![1],
            device_id: 17,
            buffer_id: 17001,
        };
        // Handle present only in the old slot contents.
        let dropped = GpuTensorHandle {
            shape: vec![1],
            device_id: 17,
            buffer_id: 17002,
        };

        for handle in [&kept, &dropped] {
            fusion_residency::mark(handle);
        }
        assert!(fusion_residency::is_resident(&kept));
        assert!(fusion_residency::is_resident(&dropped));

        // Slot 0 initially holds both handles.
        let previous_slot = Value::OutputList(vec![
            Value::GpuTensor(kept.clone()),
            Value::GpuTensor(dropped.clone()),
        ]);
        let mut vars = vec![previous_slot];
        let mut context = ExecutionContext {
            call_stack: Vec::new(),
            locals: Vec::new(),
            instruction_pointer: 0,
            spawned_task_ids: HashSet::new(),
            next_spawn_task_id: 0,
        };

        // Store targeting global slot 0 with a value that only reuses `kept`.
        let target = FusionStoreMaterialization {
            value_id: 1,
            binding: VarBinding {
                kind: VarKind::Global,
                index: 0,
            },
        };
        let pending_stores = vec![(target, Value::GpuTensor(kept.clone()))];
        write_elementwise_materialized_stores(pending_stores, &mut vars, &mut context);

        assert!(fusion_residency::is_resident(&kept));
        assert!(!fusion_residency::is_resident(&dropped));

        // Clear the mark this test left behind in the residency tracker.
        fusion_residency::clear(&kept);
    }
}