Skip to main content

runmat_accelerate/
fusion_exec.rs

1use anyhow::{anyhow, Result};
2
3use crate::fusion::{
4    FusionGroupPlan, FusionKind, FusionPattern, FusionStoreMaterialization, ImageScalar,
5};
6use crate::fusion_residency;
7use crate::graph;
8use crate::graph::{ShapeInfo, ValueId};
9use crate::precision::ensure_provider_supports_dtype;
10use log;
11use runmat_accelerate_api::{
12    provider, AccelProvider, CovRows, CovarianceOptions, GpuTensorHandle, HostTensorView,
13    ImageNormalizeDescriptor, PowerStepEpilogue, ProviderPrecision, ReductionFlavor,
14};
15use runmat_builtins::{NumericDType, Value};
16use runmat_runtime::builtins::common::shape::normalize_scalar_shape;
17use runmat_runtime::gather_if_needed;
18use runmat_time::Instant;
19use std::collections::HashMap;
20use std::sync::OnceLock;
21
22struct PreparedInput {
23    handle: GpuTensorHandle,
24    owned: Option<GpuTensorHandle>,
25}
26
27pub struct FusionExecutionRequest<'a> {
28    pub plan: &'a FusionGroupPlan,
29    pub inputs: Vec<Value>,
30}
31
32pub struct ElementwiseExecutionResult {
33    pub final_value: Value,
34    pub materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
35}
36
/// Reports whether per-stage fusion timing was requested via `RUNMAT_FUSION_TIMING`.
///
/// The variable is read once per process and cached. Accepted truthy values
/// (case-insensitive, surrounding whitespace ignored) are `1`, `true`, `yes`
/// and `on`. Anything else, including an unset or non-UTF-8 value, disables timing.
#[inline]
fn fusion_timing_enabled() -> bool {
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("RUNMAT_FUSION_TIMING")
            .map(|raw| {
                let normalized = raw.trim().to_ascii_lowercase();
                ["1", "true", "yes", "on"].contains(&normalized.as_str())
            })
            .unwrap_or(false)
    })
}
48
49struct FusionStageTimer {
50    inner: Option<FusionStageTimerInner>,
51}
52
53struct FusionStageTimerInner {
54    plan_index: usize,
55    kind: &'static str,
56    len: usize,
57    start: Instant,
58    last: Instant,
59    stages: Vec<(&'static str, f64)>,
60}
61
62impl FusionStageTimer {
63    fn new(kind: &'static str, plan_index: usize, len: usize) -> Self {
64        if fusion_timing_enabled() && log::log_enabled!(log::Level::Debug) {
65            let now = Instant::now();
66            Self {
67                inner: Some(FusionStageTimerInner {
68                    plan_index,
69                    kind,
70                    len,
71                    start: now,
72                    last: now,
73                    stages: Vec::new(),
74                }),
75            }
76        } else {
77            Self { inner: None }
78        }
79    }
80
81    fn mark(&mut self, label: &'static str) {
82        if let Some(inner) = &mut self.inner {
83            let now = Instant::now();
84            let delta = now.duration_since(inner.last).as_secs_f64() * 1000.0;
85            inner.stages.push((label, delta));
86            inner.last = now;
87        }
88    }
89
90    fn finish(self) {
91        if let Some(inner) = self.inner {
92            let total = inner.start.elapsed().as_secs_f64() * 1000.0;
93            let summary = inner
94                .stages
95                .into_iter()
96                .map(|(label, ms)| format!("{label}={ms:.3}ms"))
97                .collect::<Vec<_>>()
98                .join(" ");
99            log::debug!(
100                "fusion timing plan={} kind={} len={} {} total={:.3}ms",
101                inner.plan_index,
102                inner.kind,
103                inner.len,
104                summary,
105                total
106            );
107        }
108    }
109}
110
111fn ensure_gpu_tensor(
112    provider: &dyn AccelProvider,
113    value: &Value,
114) -> Result<(GpuTensorHandle, Option<GpuTensorHandle>)> {
115    match value {
116        Value::GpuTensor(handle) => Ok((handle.clone(), None)),
117        Value::Tensor(tensor) => {
118            let view = HostTensorView {
119                data: &tensor.data,
120                shape: &tensor.shape,
121            };
122            let handle = provider.upload(&view)?;
123            Ok((handle.clone(), Some(handle)))
124        }
125        _ => Err(anyhow!("fusion: expected tensor input")),
126    }
127}
128
129fn scalar_upload_dtype(provider: &dyn AccelProvider) -> NumericDType {
130    match provider.precision() {
131        ProviderPrecision::F32 => NumericDType::F32,
132        ProviderPrecision::F64 => NumericDType::F64,
133    }
134}
135
136fn value_to_f64(value: &Value) -> Option<f64> {
137    match value {
138        Value::Num(n) => Some(*n),
139        Value::Int(i) => Some(i.to_f64()),
140        _ => None,
141    }
142}
143
144fn scalar_from_value(value: &Value) -> Result<f64> {
145    if let Some(v) = value_to_f64(value) {
146        return Ok(v);
147    }
148    match value {
149        Value::Tensor(t) => {
150            if t.data.len() == 1 {
151                Ok(t.data[0])
152            } else {
153                Err(anyhow!(
154                    "image normalize: expected scalar tensor, got {} elements",
155                    t.data.len()
156                ))
157            }
158        }
159        Value::GpuTensor(_) => {
160            let gathered = gather_if_needed(value).map_err(|e| anyhow!("image normalize: {e}"))?;
161            scalar_from_value(&gathered)
162        }
163        _ => Err(anyhow!(
164            "image normalize: expected numeric scalar value, got {:?}",
165            value
166        )),
167    }
168}
169
170fn resolve_image_scalar_value(
171    scalar: &ImageScalar,
172    plan: &FusionGroupPlan,
173    request: &FusionExecutionRequest<'_>,
174) -> Result<f64> {
175    match scalar {
176        ImageScalar::Constant(v) => Ok(*v),
177        ImageScalar::Value(vid) => {
178            if let Some(value) = plan.const_values.get(vid) {
179                return scalar_from_value(value);
180            }
181            if let Some(idx) = plan.inputs.iter().position(|id| *id == *vid) {
182                let runtime_value = request
183                    .inputs
184                    .get(idx)
185                    .ok_or_else(|| anyhow!("image normalize: runtime scalar missing"))?;
186                return scalar_from_value(runtime_value);
187            }
188            Err(anyhow!(
189                "image normalize: scalar input {:?} not materialized in plan",
190                vid
191            ))
192        }
193    }
194}
195
196fn execute_elementwise_outputs(
197    request: &FusionExecutionRequest<'_>,
198    output_ids: &[ValueId],
199) -> Result<HashMap<ValueId, Value>> {
200    crate::ensure_residency_hooks();
201    if !request.plan.group.kind.is_elementwise() {
202        return Err(anyhow!("unsupported fusion kind"));
203    }
204    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
205    if !request.plan.kernel.supported {
206        return Err(anyhow!("fusion kernel not supported for this plan"));
207    }
208    if request.inputs.len() != request.plan.inputs.len() {
209        return Err(anyhow!(
210            "fusion input mismatch: expected {}, got {}",
211            request.plan.inputs.len(),
212            request.inputs.len()
213        ));
214    }
215    // Determine output shape from the fusion plan; if unknown, derive from runtime inputs via broadcasting.
216    fn runtime_broadcast_shape(values: &[Value]) -> Option<Vec<usize>> {
217        // Collect shapes; scalars map to empty shape which broadcasts to any
218        let mut shapes: Vec<Vec<usize>> = Vec::new();
219        for v in values {
220            match v {
221                Value::GpuTensor(h) => shapes.push(h.shape.clone()),
222                Value::Tensor(t) => shapes.push(t.shape.clone()),
223                Value::Num(_) | Value::Int(_) => shapes.push(Vec::new()),
224                _ => return None, // unsupported at runtime for broadcasting
225            }
226        }
227        let rank = shapes.iter().map(|s| s.len()).max().unwrap_or(0);
228        let mut out = vec![1usize; rank];
229        for shape in shapes {
230            let offset = rank.saturating_sub(shape.len());
231            for (i, &dim) in shape.iter().enumerate() {
232                let j = offset + i;
233                let a = out[j];
234                let b = dim;
235                if a == 1 {
236                    out[j] = b.max(1);
237                } else if b == 1 || a == b {
238                    // keep a
239                } else {
240                    return None; // incompatible
241                }
242            }
243        }
244        Some(out)
245    }
246    // Determine output shape from the fusion plan and derive the element count from it.
247    let runtime_shape = runtime_broadcast_shape(&request.inputs);
248    let mut output_shape = match &request.plan.group.shape {
249        ShapeInfo::Tensor(dims) if !dims.is_empty() && dims.iter().all(|d| d.is_some()) => {
250            dims.iter().map(|d| d.unwrap_or(1)).collect()
251        }
252        ShapeInfo::Tensor(dims) if !dims.is_empty() => {
253            let rt_shape = runtime_shape
254                .clone()
255                .ok_or_else(|| anyhow!("fusion: unknown output shape"))?;
256            if rt_shape.len() != dims.len() {
257                rt_shape
258            } else {
259                dims.iter()
260                    .zip(rt_shape.iter())
261                    .map(|(planned, runtime)| planned.unwrap_or(*runtime))
262                    .collect()
263            }
264        }
265        _ => runtime_shape.ok_or_else(|| anyhow!("fusion: unknown output shape"))?,
266    };
267    let mut len: usize = output_shape.iter().copied().product();
268    if len == 0 {
269        if let Some(rt_shape) = runtime_broadcast_shape(&request.inputs) {
270            output_shape = rt_shape;
271            len = output_shape.iter().copied().product();
272        }
273        if len == 0 {
274            return Err(anyhow!("fusion: zero-length execution not supported"));
275        }
276    }
277    output_shape = normalize_scalar_shape(&output_shape);
278    let mut timer = FusionStageTimer::new("elementwise", request.plan.index, len);
279    let scalar_shape = normalize_scalar_shape(&vec![1; output_shape.len()]);
280    let mut prepared = Vec::with_capacity(request.inputs.len());
281    let mut temp_scalars: Vec<Vec<f64>> = Vec::new();
282    let scalar_dtype = scalar_upload_dtype(provider);
283    for value in &request.inputs {
284        match value {
285            Value::GpuTensor(handle) => prepared.push(PreparedInput {
286                handle: handle.clone(),
287                owned: None,
288            }),
289            Value::Tensor(t) => {
290                if let Err(msg) = ensure_provider_supports_dtype(provider, t.dtype) {
291                    return Err(anyhow!(
292                        "fusion: tensor input requires unsupported precision ({msg})"
293                    ));
294                }
295                let view = HostTensorView {
296                    data: &t.data,
297                    shape: &t.shape,
298                };
299                let handle = provider.upload(&view)?;
300                prepared.push(PreparedInput {
301                    handle: handle.clone(),
302                    owned: Some(handle),
303                });
304            }
305            Value::Num(n) => {
306                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
307                    return Err(anyhow!(
308                        "fusion: scalar input requires unsupported precision ({msg})"
309                    ));
310                }
311                let scalar = match provider.precision() {
312                    ProviderPrecision::F32 => (*n as f32) as f64,
313                    ProviderPrecision::F64 => *n,
314                };
315                temp_scalars.push(vec![scalar]);
316                let data = temp_scalars.last().unwrap();
317                let view = HostTensorView {
318                    data,
319                    shape: &scalar_shape,
320                };
321                let handle = provider.upload(&view)?;
322                prepared.push(PreparedInput {
323                    handle: handle.clone(),
324                    owned: Some(handle),
325                });
326            }
327            Value::Int(i) => {
328                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
329                    return Err(anyhow!(
330                        "fusion: scalar input requires unsupported precision ({msg})"
331                    ));
332                }
333                let scalar = match provider.precision() {
334                    ProviderPrecision::F32 => (i.to_f64() as f32) as f64,
335                    ProviderPrecision::F64 => i.to_f64(),
336                };
337                temp_scalars.push(vec![scalar]);
338                let data = temp_scalars.last().unwrap();
339                let view = HostTensorView {
340                    data,
341                    shape: &scalar_shape,
342                };
343                let handle = provider.upload(&view)?;
344                prepared.push(PreparedInput {
345                    handle: handle.clone(),
346                    owned: Some(handle),
347                });
348            }
349            _ => {
350                return Err(anyhow!("fusion: unsupported value type"));
351            }
352        }
353    }
354    timer.mark("prepare_inputs");
355
356    let scalar_ty = match provider.precision() {
357        ProviderPrecision::F32 => "f32",
358        ProviderPrecision::F64 => "f64",
359    };
360    let handles: Vec<GpuTensorHandle> = prepared.iter().map(|p| p.handle.clone()).collect();
361    let mut outputs = HashMap::new();
362
363    if output_ids.len() == 1 {
364        // Fast path: single output, use the existing single-dispatch API.
365        let output_id = output_ids[0];
366        let shader = request
367            .plan
368            .generate_wgsl_for_output(output_id, scalar_ty)
369            .ok_or_else(|| anyhow!("fusion: WGSL generation failed for output {output_id}"))?;
370        timer.mark("generate_wgsl");
371        let output = provider.fused_elementwise(&shader, &handles, &output_shape, len)?;
372        fusion_residency::mark(&output);
373        outputs.insert(output_id, Value::GpuTensor(output));
374    } else {
375        // Multi-output path: generate one shader that writes all outputs in a single dispatch,
376        // reducing O(N²) GPU work (N full-chain recomputations) to O(N).
377        let shader = request
378            .plan
379            .generate_wgsl_for_outputs(output_ids, scalar_ty)
380            .ok_or_else(|| anyhow!("fusion: multi-output WGSL generation failed"))?;
381        timer.mark("generate_wgsl");
382        match provider.fused_elementwise_multi(
383            &shader,
384            &handles,
385            &output_shape,
386            len,
387            output_ids.len(),
388        ) {
389            Ok(out_handles) => {
390                for (output_id, handle) in output_ids.iter().copied().zip(out_handles) {
391                    fusion_residency::mark(&handle);
392                    outputs.insert(output_id, Value::GpuTensor(handle));
393                }
394            }
395            Err(_) => {
396                // Provider does not support multi-output dispatch; fall back to N single dispatches.
397                for output_id in output_ids.iter().copied() {
398                    let single_shader = request
399                        .plan
400                        .generate_wgsl_for_output(output_id, scalar_ty)
401                        .ok_or_else(|| {
402                            anyhow!("fusion: WGSL generation failed for output {output_id}")
403                        })?;
404                    let output =
405                        provider.fused_elementwise(&single_shader, &handles, &output_shape, len)?;
406                    fusion_residency::mark(&output);
407                    outputs.insert(output_id, Value::GpuTensor(output));
408                }
409            }
410        }
411    }
412    timer.mark("dispatch");
413
414    // Clean up temporary uploads
415    for input in prepared {
416        if let Some(handle) = input.owned {
417            let _ = provider.free(&handle);
418        }
419    }
420    timer.mark("cleanup");
421    timer.finish();
422
423    Ok(outputs)
424}
425
426pub fn execute_elementwise(
427    request: FusionExecutionRequest<'_>,
428) -> Result<ElementwiseExecutionResult> {
429    let final_output = request
430        .plan
431        .output
432        .ok_or_else(|| anyhow!("fusion: missing final output value id"))?;
433    let mut output_ids = vec![final_output];
434    for store in &request.plan.materialized_stores {
435        if !output_ids.contains(&store.value_id) {
436            output_ids.push(store.value_id);
437        }
438    }
439    let mut outputs = execute_elementwise_outputs(&request, &output_ids)?;
440    let final_value = outputs
441        .remove(&final_output)
442        .ok_or_else(|| anyhow!("fusion: final output materialization missing"))?;
443    let materialized_stores = request
444        .plan
445        .materialized_stores
446        .iter()
447        .filter_map(|store| {
448            // A store whose value_id equals final_output was already consumed above;
449            // reuse a clone of final_value rather than silently dropping the store.
450            let value = if store.value_id == final_output {
451                Some(final_value.clone())
452            } else {
453                outputs.remove(&store.value_id)
454            };
455            value.map(|v| (store.clone(), v))
456        })
457        .collect();
458    Ok(ElementwiseExecutionResult {
459        final_value,
460        materialized_stores,
461    })
462}
463
464pub fn execute_reduction(
465    request: FusionExecutionRequest<'_>,
466    reduce_len: usize,
467    num_slices: usize,
468    workgroup_size: u32,
469) -> Result<Value> {
470    if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION").is_ok() {
471        return Err(anyhow!("fused reduction disabled by env"));
472    }
473    crate::ensure_residency_hooks();
474    if !request.plan.group.kind.is_reduction() {
475        return Err(anyhow!("unsupported fusion kind"));
476    }
477    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
478    if !request.plan.kernel.supported {
479        return Err(anyhow!("fusion kernel not supported for this plan"));
480    }
481    if request.inputs.len() != request.plan.inputs.len() {
482        return Err(anyhow!(
483            "fusion input mismatch: expected {}, got {}",
484            request.plan.inputs.len(),
485            request.inputs.len()
486        ));
487    }
488    let len = reduce_len * num_slices;
489    if len == 0 {
490        return Err(anyhow!("fusion: zero-length execution not supported"));
491    }
492    let scalar_shape = {
493        let constant_shape = request.plan.constant_shape(len);
494        normalize_scalar_shape(&vec![1; constant_shape.len()])
495    };
496    let mut timer = FusionStageTimer::new("reduction", request.plan.index, len);
497    let mut prepared = Vec::with_capacity(request.inputs.len());
498    let mut temp_scalars: Vec<Vec<f64>> = Vec::new();
499    let scalar_dtype = scalar_upload_dtype(provider);
500    for value in &request.inputs {
501        match value {
502            Value::GpuTensor(handle) => prepared.push(PreparedInput {
503                handle: handle.clone(),
504                owned: None,
505            }),
506            Value::Tensor(t) => {
507                if let Err(msg) = ensure_provider_supports_dtype(provider, t.dtype) {
508                    return Err(anyhow!(
509                        "fusion: tensor input requires unsupported precision ({msg})"
510                    ));
511                }
512                let view = HostTensorView {
513                    data: &t.data,
514                    shape: &t.shape,
515                };
516                let handle = provider.upload(&view)?;
517                prepared.push(PreparedInput {
518                    handle: handle.clone(),
519                    owned: Some(handle),
520                });
521            }
522            Value::Num(n) => {
523                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
524                    return Err(anyhow!(
525                        "fusion: scalar input requires unsupported precision ({msg})"
526                    ));
527                }
528                let scalar = match provider.precision() {
529                    ProviderPrecision::F32 => (*n as f32) as f64,
530                    ProviderPrecision::F64 => *n,
531                };
532                temp_scalars.push(vec![scalar]);
533                let data = temp_scalars.last().unwrap();
534                let view = HostTensorView {
535                    data,
536                    shape: &scalar_shape,
537                };
538                let handle = provider.upload(&view)?;
539                prepared.push(PreparedInput {
540                    handle: handle.clone(),
541                    owned: Some(handle),
542                });
543            }
544            Value::Int(i) => {
545                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
546                    return Err(anyhow!(
547                        "fusion: scalar input requires unsupported precision ({msg})"
548                    ));
549                }
550                let scalar = match provider.precision() {
551                    ProviderPrecision::F32 => (i.to_f64() as f32) as f64,
552                    ProviderPrecision::F64 => i.to_f64(),
553                };
554                temp_scalars.push(vec![scalar]);
555                let data = temp_scalars.last().unwrap();
556                let view = HostTensorView {
557                    data,
558                    shape: &scalar_shape,
559                };
560                let handle = provider.upload(&view)?;
561                prepared.push(PreparedInput {
562                    handle: handle.clone(),
563                    owned: Some(handle),
564                });
565            }
566            _ => return Err(anyhow!("fusion: unsupported value type")),
567        }
568    }
569    timer.mark("prepare_inputs");
570
571    let handles: Vec<GpuTensorHandle> = prepared.iter().map(|p| p.handle.clone()).collect();
572    let output_shape = vec![num_slices];
573
574    let scalar_ty = match provider.precision() {
575        ProviderPrecision::F32 => "f32",
576        ProviderPrecision::F64 => "f64",
577    };
578    let shader = request
579        .plan
580        .generate_reduction_wgsl(scalar_ty)
581        .ok_or_else(|| anyhow!("fusion: reduction WGSL generation failed"))?;
582    timer.mark("generate_wgsl");
583    if std::env::var("RUNMAT_DEBUG_DUMP_FUSED_WGSL").is_ok() {
584        println!(
585            "---- fused reduction WGSL ----\n{}\n------------------------------",
586            shader
587        );
588    }
589
590    let mut wg = if workgroup_size == 0 {
591        provider.default_reduction_workgroup_size()
592    } else {
593        workgroup_size
594    };
595    if let Ok(raw) = std::env::var("RUNMAT_FUSED_WG") {
596        if let Ok(val) = raw.trim().parse::<u32>() {
597            if val > 0 {
598                let capped = val.min(provider.default_reduction_workgroup_size());
599                wg = capped.max(1);
600            }
601        }
602    }
603    let flavor = request
604        .plan
605        .reduction_flavor
606        .unwrap_or(ReductionFlavor::Sum);
607    let output = provider.fused_reduction(
608        &shader,
609        &handles,
610        &output_shape,
611        reduce_len,
612        num_slices,
613        wg,
614        flavor,
615    )?;
616    timer.mark("dispatch");
617    fusion_residency::mark(&output);
618
619    for input in prepared {
620        if let Some(handle) = input.owned {
621            let _ = provider.free(&handle);
622        }
623    }
624    timer.mark("cleanup");
625    timer.finish();
626
627    Ok(Value::GpuTensor(output))
628}
629
630pub async fn execute_centered_gram(request: FusionExecutionRequest<'_>) -> Result<Value> {
631    crate::ensure_residency_hooks();
632    if request.plan.group.kind != FusionKind::CenteredGram {
633        return Err(anyhow!("unsupported fusion kind"));
634    }
635    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
636    let (matrix_vid, normalization) = match request.plan.pattern.as_ref() {
637        Some(FusionPattern::CenteredGram {
638            matrix,
639            normalization,
640        }) => (*matrix, *normalization),
641        _ => return Err(anyhow!("centered gram: missing pattern metadata")),
642    };
643
644    let matrix_index = request
645        .plan
646        .inputs
647        .iter()
648        .position(|vid| *vid == matrix_vid)
649        .ok_or_else(|| anyhow!("centered gram: matrix input not found"))?;
650    let matrix_value = request
651        .inputs
652        .get(matrix_index)
653        .ok_or_else(|| anyhow!("centered gram: matrix value missing"))?;
654
655    let (matrix_handle, owned_matrix) = ensure_gpu_tensor(provider, matrix_value)?;
656
657    let options = CovarianceOptions {
658        normalization,
659        rows: CovRows::All,
660        has_weight_vector: false,
661    };
662
663    let output = provider
664        .covariance(&matrix_handle, None, None, &options)
665        .await?;
666
667    if let Some(temp) = owned_matrix {
668        let _ = provider.free(&temp);
669    }
670
671    fusion_residency::mark(&output);
672    Ok(Value::GpuTensor(output))
673}
674
675pub async fn execute_power_step_normalize(request: FusionExecutionRequest<'_>) -> Result<Value> {
676    crate::ensure_residency_hooks();
677    if request.plan.group.kind != FusionKind::PowerStepNormalize {
678        return Err(anyhow!("unsupported fusion kind"));
679    }
680    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
681    let (lhs_vid, rhs_vid, epsilon) = match request.plan.pattern.as_ref() {
682        Some(FusionPattern::PowerStepNormalize { lhs, rhs, epsilon }) => (*lhs, *rhs, *epsilon),
683        _ => {
684            return Err(anyhow!(
685                "power-step normalization: missing pattern metadata"
686            ))
687        }
688    };
689
690    let lhs_index = request
691        .plan
692        .inputs
693        .iter()
694        .position(|vid| *vid == lhs_vid)
695        .ok_or_else(|| anyhow!("power-step normalization: lhs input not found"))?;
696    let rhs_index = request
697        .plan
698        .inputs
699        .iter()
700        .position(|vid| *vid == rhs_vid)
701        .ok_or_else(|| anyhow!("power-step normalization: rhs input not found"))?;
702
703    let lhs_value = request
704        .inputs
705        .get(lhs_index)
706        .ok_or_else(|| anyhow!("power-step normalization: lhs value missing"))?;
707    let rhs_value = request
708        .inputs
709        .get(rhs_index)
710        .ok_or_else(|| anyhow!("power-step normalization: rhs value missing"))?;
711
712    let (lhs_handle, lhs_owned) = ensure_gpu_tensor(provider, lhs_value)?;
713    let (rhs_handle, rhs_owned) = ensure_gpu_tensor(provider, rhs_value)?;
714
715    let desc = PowerStepEpilogue { epsilon };
716    let output = provider
717        .matmul_power_step(&lhs_handle, &rhs_handle, &desc)
718        .await?;
719
720    if let Some(temp) = lhs_owned {
721        let _ = provider.free(&temp);
722    }
723    if let Some(temp) = rhs_owned {
724        let _ = provider.free(&temp);
725    }
726
727    fusion_residency::mark(&output);
728    Ok(Value::GpuTensor(output))
729}
730
731pub async fn execute_explained_variance(request: FusionExecutionRequest<'_>) -> Result<Value> {
732    crate::ensure_residency_hooks();
733    if request.plan.group.kind != FusionKind::ExplainedVariance {
734        return Err(anyhow!("unsupported fusion kind"));
735    }
736    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
737    let (q_vid, g_vid) = match request.plan.pattern.as_ref() {
738        Some(FusionPattern::ExplainedVariance { q, g }) => (*q, *g),
739        _ => return Err(anyhow!("explained variance: missing pattern metadata")),
740    };
741
742    let find_value = |vid: ValueId| -> Result<&Value> {
743        if let Some(pos) = request.plan.inputs.iter().position(|id| *id == vid) {
744            request
745                .inputs
746                .get(pos)
747                .ok_or_else(|| anyhow!("explained variance: missing runtime value"))
748        } else {
749            request
750                .plan
751                .const_values
752                .get(&vid)
753                .ok_or_else(|| anyhow!("explained variance: value not materialized"))
754        }
755    };
756
757    let q_value = find_value(q_vid)?;
758    let g_value = find_value(g_vid)?;
759
760    let (mut q_handle, q_owned) = ensure_gpu_tensor(provider, q_value)?;
761    let (g_handle, g_owned) = ensure_gpu_tensor(provider, g_value)?;
762
763    let debug_explained = std::env::var("RUNMAT_DEBUG_EXPLAINED").is_ok();
764    if debug_explained {
765        println!(
766            "[explained] initial Q shape {:?}, G shape {:?}",
767            q_handle.shape, g_handle.shape
768        );
769        if let Ok(info) = provider.download(&q_handle).await {
770            println!(
771                "[explained] Q (sample) len={} first=[{:?}]",
772                info.data.len(),
773                info.data.get(0..4)
774            );
775        }
776    }
777
778    let q_shape = q_handle.shape.clone();
779    if q_shape.len() < 2 {
780        return Err(anyhow!("explained variance: Q must be 2-D"));
781    }
782    let q_rows = q_shape[0];
783    let q_cols = q_shape[1];
784    if q_rows == 0 || q_cols == 0 {
785        return Err(anyhow!("explained variance: zero-sized Q"));
786    }
787
788    let g_shape = g_handle.shape.clone();
789    if g_shape.len() < 2 {
790        return Err(anyhow!("explained variance: G must be 2-D"));
791    }
792    if g_shape[0] != q_rows || g_shape[1] != q_rows {
793        return Err(anyhow!("explained variance: G shape mismatch"));
794    }
795
796    let mut tmp = provider.matmul(&q_handle, &g_handle).await?;
797    let tmp_shape = tmp.shape.clone();
798    if tmp_shape.len() < 2 {
799        return Err(anyhow!("explained variance: intermediate must be 2-D"));
800    }
801    if tmp_shape[0] != q_cols {
802        return Err(anyhow!(
803            "explained variance: expected intermediate rows {}, got {}",
804            q_cols,
805            tmp_shape[0]
806        ));
807    }
808
809    if debug_explained {
810        println!("[explained] after Q*G tmp shape {:?}", tmp.shape);
811    }
812
813    // Interpreter's transpose retains the original data layout. Mimic that by
814    // reshaping rather than launching a real transpose so downstream matmul
815    // observes the same misoriented layout.
816    let mut transposed_shape = q_shape.clone();
817    transposed_shape.swap(0, 1);
818    let q_transposed_view = provider.reshape(&q_handle, &transposed_shape)?;
819
820    tmp = provider.matmul(&q_transposed_view, &g_handle).await?;
821
822    if debug_explained {
823        println!(
824            "[explained] after reshape(matmul) tmp shape {:?}",
825            tmp.shape
826        );
827    }
828
829    // Restore Q's original shape before the second multiplication.
830    q_handle = provider.reshape(&q_handle, &q_shape)?;
831
832    let product = provider.matmul(&tmp, &q_handle).await?;
833
834    if debug_explained {
835        println!("[explained] product shape {:?}", product.shape);
836    }
837
838    let diag = provider.diag_extract(&product, 0)?;
839    let diag = match diag.shape.as_slice() {
840        [len] => provider.reshape(&diag, &[*len, 1])?,
841        [_len, 1] => diag,
842        _ => diag,
843    };
844
845    if debug_explained {
846        if let Ok(host) = provider.download(&tmp).await {
847            println!("tmp runtime shape {:?} data {:?}", host.shape, host.data);
848        }
849        if let Ok(host) = provider.download(&product).await {
850            println!("prod runtime shape {:?} data {:?}", host.shape, host.data);
851        }
852        if let Ok(host) = provider.download(&diag).await {
853            println!("diag runtime shape {:?} data {:?}", host.shape, host.data);
854        }
855    }
856
857    let _ = provider.free(&tmp);
858    let _ = provider.free(&product);
859    if let Some(temp) = q_owned {
860        let _ = provider.free(&temp);
861    }
862    if let Some(temp) = g_owned {
863        let _ = provider.free(&temp);
864    }
865
866    fusion_residency::mark(&diag);
867    Ok(Value::GpuTensor(diag))
868}
869
870pub async fn execute_image_normalize(request: FusionExecutionRequest<'_>) -> Result<Value> {
871    crate::ensure_residency_hooks();
872    if request.plan.group.kind != FusionKind::ImageNormalize {
873        return Err(anyhow!("unsupported fusion kind"));
874    }
875    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
876    let pattern = match request.plan.pattern.as_ref() {
877        Some(FusionPattern::ImageNormalize(p)) => p,
878        _ => return Err(anyhow!("image normalize: missing pattern metadata")),
879    };
880    if log::log_enabled!(log::Level::Debug) {
881        log::debug!(
882            "execute_image_normalize: plan inputs={:?} stack={:?}",
883            request.plan.inputs,
884            request.plan.stack_pattern
885        );
886    }
887
888    let find_value = |vid: ValueId| -> Result<&Value> {
889        if let Some(pos) = request.plan.inputs.iter().position(|id| *id == vid) {
890            request
891                .inputs
892                .get(pos)
893                .ok_or_else(|| anyhow!("image normalize: runtime value missing"))
894        } else {
895            request
896                .plan
897                .const_values
898                .get(&vid)
899                .ok_or_else(|| anyhow!("image normalize: value {vid:?} not materialized"))
900        }
901    };
902
903    let input_value = find_value(pattern.input)?;
904    let (input_handle, input_owned) = ensure_gpu_tensor(provider, input_value)?;
905    let shape = input_handle.shape.clone();
906    if shape.len() != 3 {
907        return Err(anyhow!(
908            "image normalize: expected 3-D input tensor, got shape {:?}",
909            shape
910        ));
911    }
912    let batch = shape[0];
913    let height = shape[1];
914    let width = shape[2];
915
916    let epsilon = resolve_image_scalar_value(&pattern.epsilon, request.plan, &request)?;
917    let gain = match &pattern.gain {
918        Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
919        None => None,
920    };
921    let bias = match &pattern.bias {
922        Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
923        None => None,
924    };
925    let gamma = match &pattern.gamma {
926        Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
927        None => None,
928    };
929
930    let desc = ImageNormalizeDescriptor {
931        batch,
932        height,
933        width,
934        epsilon,
935        gain,
936        bias,
937        gamma,
938    };
939    if log::log_enabled!(log::Level::Trace) {
940        log::trace!("execute_image_normalize: desc {:?}", desc);
941    }
942
943    let output = provider.image_normalize(&input_handle, &desc).await?;
944
945    if let Some(temp) = input_owned {
946        provider.free(&temp).ok();
947    }
948
949    fusion_residency::mark(&output);
950    Ok(Value::GpuTensor(output))
951}
952
953pub async fn execute_matmul_epilogue(request: FusionExecutionRequest<'_>) -> Result<Value> {
954    crate::ensure_residency_hooks();
955    if request.plan.group.kind != crate::fusion::FusionKind::MatmulEpilogue {
956        return Err(anyhow!("unsupported fusion kind"));
957    }
958    let prov = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
959
960    // Map ValueId -> prepared GpuTensorHandle
961    let mut prepared: Vec<(graph::ValueId, GpuTensorHandle, Option<GpuTensorHandle>)> = Vec::new();
962    let mut owned: Vec<GpuTensorHandle> = Vec::new();
963    for (idx, vid) in request.plan.inputs.iter().copied().enumerate() {
964        let v = request
965            .inputs
966            .get(idx)
967            .ok_or_else(|| anyhow!("fusion: missing input value"))?;
968        let handle = match v {
969            Value::GpuTensor(h) => h.clone(),
970            Value::Tensor(t) => {
971                let view = HostTensorView {
972                    data: &t.data,
973                    shape: &t.shape,
974                };
975                let h = prov.upload(&view)?;
976                owned.push(h.clone());
977                h
978            }
979            _ => return Err(anyhow!("matmul_epilogue: unsupported input value kind")),
980        };
981        prepared.push((vid, handle.clone(), None));
982    }
983
984    // Helper: find handle by ValueId
985    let find_handle = |vid: graph::ValueId| -> Option<GpuTensorHandle> {
986        prepared
987            .iter()
988            .find_map(|(v, h, _)| if *v == vid { Some(h.clone()) } else { None })
989    };
990
991    // Find matmul op and its output
992    let mut cur_out: Option<graph::ValueId> = None;
993    let mut a_vid: Option<graph::ValueId> = None;
994    let mut b_vid: Option<graph::ValueId> = None;
995    for op in &request.plan.operations {
996        if let crate::fusion::FusionOp::Builtin {
997            name,
998            inputs,
999            output,
1000        } = op
1001        {
1002            if name.eq_ignore_ascii_case("mtimes") {
1003                a_vid = inputs.first().copied();
1004                b_vid = inputs.get(1).copied();
1005                cur_out = *output;
1006                break;
1007            }
1008        }
1009    }
1010    let (a_vid, b_vid, mut cur) = (
1011        a_vid.ok_or_else(|| anyhow!("mtimes not found"))?,
1012        b_vid.ok_or_else(|| anyhow!("mtimes not found"))?,
1013        cur_out.ok_or_else(|| anyhow!("mtimes output missing"))?,
1014    );
1015
1016    // Derive epilogue (scale/bias, clamp, pow, diag) by walking subsequent ops that consume cur
1017    let mut alpha: f64 = 1.0;
1018    let mut beta: f64 = 0.0;
1019    let mut row_scale: Option<GpuTensorHandle> = None;
1020    let mut col_scale: Option<GpuTensorHandle> = None;
1021    let mut clamp_min: Option<f64> = None;
1022    let mut clamp_max: Option<f64> = None;
1023    let mut pow_exponent: Option<f64> = None;
1024    let mut row_div = false;
1025    let mut col_div = false;
1026    let mut diag_vid: Option<graph::ValueId> = None;
1027
1028    for op in &request.plan.operations {
1029        match op {
1030            crate::fusion::FusionOp::Primitive { op, inputs, output } => {
1031                let Some(out) = output else { continue };
1032                if !inputs.contains(&cur) {
1033                    continue;
1034                }
1035                let other = if inputs[0] == cur {
1036                    inputs[1]
1037                } else {
1038                    inputs[0]
1039                };
1040                let const_opt = request.plan.const_values.get(&other);
1041                let const_f64 = const_opt.and_then(value_to_f64);
1042                match op {
1043                    crate::graph::PrimitiveOp::Mul | crate::graph::PrimitiveOp::ElemMul => {
1044                        if let Some(val) = const_f64 {
1045                            alpha *= val;
1046                        } else if row_scale.is_none() || col_scale.is_none() {
1047                            if let Some(h) = find_handle(other) {
1048                                let r = h.shape.first().copied().unwrap_or(1);
1049                                let c = h.shape.get(1).copied().unwrap_or(1);
1050                                if c == 1 && row_scale.is_none() {
1051                                    row_scale = Some(h);
1052                                    row_div = false;
1053                                } else if r == 1 && col_scale.is_none() {
1054                                    col_scale = Some(h);
1055                                    col_div = false;
1056                                }
1057                            }
1058                        }
1059                    }
1060                    crate::graph::PrimitiveOp::ElemDiv => {
1061                        if let Some(val) = const_f64 {
1062                            if val != 0.0 {
1063                                alpha *= 1.0 / val;
1064                            }
1065                        } else if row_scale.is_none() || col_scale.is_none() {
1066                            if let Some(h) = find_handle(other) {
1067                                let r = h.shape.first().copied().unwrap_or(1);
1068                                let c = h.shape.get(1).copied().unwrap_or(1);
1069                                if c == 1 && row_scale.is_none() {
1070                                    row_scale = Some(h);
1071                                    row_div = true;
1072                                } else if r == 1 && col_scale.is_none() {
1073                                    col_scale = Some(h);
1074                                    col_div = true;
1075                                }
1076                            }
1077                        }
1078                    }
1079                    crate::graph::PrimitiveOp::Add => {
1080                        if let Some(val) = const_f64 {
1081                            beta += val;
1082                        }
1083                    }
1084                    crate::graph::PrimitiveOp::Sub => {
1085                        if let Some(val) = const_f64 {
1086                            beta -= val;
1087                        }
1088                    }
1089                    crate::graph::PrimitiveOp::Pow | crate::graph::PrimitiveOp::ElemPow => {
1090                        if pow_exponent.is_none() && inputs[0] == cur {
1091                            pow_exponent = const_f64;
1092                        }
1093                    }
1094                    _ => {}
1095                }
1096                cur = *out;
1097            }
1098            crate::fusion::FusionOp::Builtin {
1099                name,
1100                inputs,
1101                output,
1102            } => {
1103                let Some(out) = output else { continue };
1104                if !inputs.contains(&cur) {
1105                    continue;
1106                }
1107                let lower = name.to_ascii_lowercase();
1108                if lower == "max" || lower == "min" {
1109                    if let Some(&other) = inputs.iter().find(|&&v| v != cur) {
1110                        if let Some(val) =
1111                            request.plan.const_values.get(&other).and_then(value_to_f64)
1112                        {
1113                            if lower == "max" {
1114                                clamp_min = Some(clamp_min.map_or(val, |prev| prev.max(val)));
1115                            } else {
1116                                clamp_max = Some(clamp_max.map_or(val, |prev| prev.min(val)));
1117                            }
1118                        }
1119                    }
1120                } else if lower == "pow" && pow_exponent.is_none() {
1121                    if let Some(&other) = inputs.iter().find(|&&v| v != cur) {
1122                        if let Some(val) =
1123                            request.plan.const_values.get(&other).and_then(value_to_f64)
1124                        {
1125                            pow_exponent = Some(val);
1126                        }
1127                    }
1128                } else if lower == "diag" {
1129                    diag_vid = Some(*out);
1130                }
1131                cur = *out;
1132            }
1133        }
1134    }
1135
1136    // Build epilogue descriptor
1137    let mut ep = runmat_accelerate_api::MatmulEpilogue::noop();
1138    ep.alpha = alpha;
1139    ep.beta = beta;
1140    ep.clamp_min = clamp_min;
1141    ep.clamp_max = clamp_max;
1142    ep.pow_exponent = pow_exponent;
1143    ep.row_op = if row_div {
1144        runmat_accelerate_api::ScaleOp::Divide
1145    } else {
1146        runmat_accelerate_api::ScaleOp::Multiply
1147    };
1148    ep.col_op = if col_div {
1149        runmat_accelerate_api::ScaleOp::Divide
1150    } else {
1151        runmat_accelerate_api::ScaleOp::Multiply
1152    };
1153    if let Some(h) = row_scale.clone() {
1154        ep.row_scale = Some(h);
1155    }
1156    if let Some(h) = col_scale.clone() {
1157        ep.col_scale = Some(h);
1158    }
1159
1160    let a = find_handle(a_vid).ok_or_else(|| anyhow!("missing A"))?;
1161    let b = find_handle(b_vid).ok_or_else(|| anyhow!("missing B"))?;
1162
1163    let mut diag_handle: Option<(graph::ValueId, GpuTensorHandle)> = None;
1164    if let Some(vid) = diag_vid {
1165        let diag_len = std::cmp::min(
1166            a.shape.first().copied().unwrap_or(0),
1167            b.shape.get(1).copied().unwrap_or(0),
1168        );
1169        let mut diag_shape = vec![diag_len, 1];
1170        if diag_len == 0 {
1171            diag_shape[1] = 1;
1172        }
1173        let handle = prov.zeros(&diag_shape)?;
1174        ep.diag_output = Some(handle.clone());
1175        diag_handle = Some((vid, handle));
1176    }
1177
1178    let out = prov.matmul_epilogue(&a, &b, &ep).await?;
1179    for h in owned {
1180        let _ = prov.free(&h);
1181    }
1182
1183    if let Some((_, diag)) = &diag_handle {
1184        fusion_residency::mark(diag);
1185    }
1186
1187    let final_vid = request.plan.output.or(Some(cur));
1188    let mut result = out.clone();
1189    let mut free_out = false;
1190    if let Some((vid, diag)) = &diag_handle {
1191        if Some(*vid) == final_vid {
1192            result = diag.clone();
1193            free_out = true;
1194        }
1195    }
1196
1197    if free_out {
1198        let _ = prov.free(&out);
1199    } else {
1200        fusion_residency::mark(&out);
1201    }
1202
1203    fusion_residency::mark(&result);
1204    Ok(Value::GpuTensor(result))
1205}