runmat_accelerate/
fusion_exec.rs

1use anyhow::{anyhow, Result};
2
3use crate::fusion::{FusionGroupPlan, FusionKind, FusionPattern, ImageScalar};
4use crate::fusion_residency;
5use crate::graph;
6use crate::graph::{ShapeInfo, ValueId};
7use crate::precision::ensure_provider_supports_dtype;
8use log;
9use runmat_accelerate_api::{
10    provider, AccelProvider, CovRows, CovarianceOptions, GpuTensorHandle, HostTensorView,
11    ImageNormalizeDescriptor, PowerStepEpilogue, ProviderPrecision, ReductionFlavor,
12};
13use runmat_builtins::{NumericDType, Value};
14use runmat_runtime::gather_if_needed;
15use std::sync::OnceLock;
16use std::time::Instant;
17
/// A fusion input resolved to a device tensor handle.
///
/// `handle` is what gets passed to the provider kernel; `owned` is `Some`
/// when the handle was uploaded by us (host tensor or scalar) and therefore
/// must be freed after the dispatch completes.
struct PreparedInput {
    handle: GpuTensorHandle,
    owned: Option<GpuTensorHandle>,
}
22
/// Bundles a fusion plan with the runtime values bound to its inputs.
///
/// `inputs` is positional: it must have the same length and order as
/// `plan.inputs` (the executors below verify the length).
pub struct FusionExecutionRequest<'a> {
    pub plan: &'a FusionGroupPlan,
    pub inputs: Vec<Value>,
}
27
/// Returns true when `RUNMAT_FUSION_TIMING` is set to a truthy value
/// (`1`, `true`, `yes`, `on`, case-insensitive, surrounding whitespace
/// ignored). The env var is read once and cached for the process lifetime.
#[inline]
fn fusion_timing_enabled() -> bool {
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("RUNMAT_FUSION_TIMING")
            .map(|raw| {
                let normalized = raw.trim().to_ascii_lowercase();
                matches!(normalized.as_str(), "1" | "true" | "yes" | "on")
            })
            .unwrap_or(false)
    })
}
39
/// Optional per-stage wall-clock timer for fusion execution.
///
/// `inner` is `None` unless timing is enabled, which makes `mark`/`finish`
/// effectively free in the common (disabled) case.
struct FusionStageTimer {
    inner: Option<FusionStageTimerInner>,
}
43
/// Timing state collected while a fusion plan executes.
struct FusionStageTimerInner {
    // Index of the plan being executed (for log correlation).
    plan_index: usize,
    // Kernel family label, e.g. "elementwise" or "reduction".
    kind: &'static str,
    // Element count processed by the kernel.
    len: usize,
    // Creation time; used to compute the total duration in `finish`.
    start: Instant,
    // End time of the most recently recorded stage.
    last: Instant,
    // (stage label, elapsed milliseconds) pairs, in recording order.
    stages: Vec<(&'static str, f64)>,
}
52
53impl FusionStageTimer {
54    fn new(kind: &'static str, plan_index: usize, len: usize) -> Self {
55        if fusion_timing_enabled() && log::log_enabled!(log::Level::Debug) {
56            let now = Instant::now();
57            Self {
58                inner: Some(FusionStageTimerInner {
59                    plan_index,
60                    kind,
61                    len,
62                    start: now,
63                    last: now,
64                    stages: Vec::new(),
65                }),
66            }
67        } else {
68            Self { inner: None }
69        }
70    }
71
72    fn mark(&mut self, label: &'static str) {
73        if let Some(inner) = &mut self.inner {
74            let now = Instant::now();
75            let delta = now.duration_since(inner.last).as_secs_f64() * 1000.0;
76            inner.stages.push((label, delta));
77            inner.last = now;
78        }
79    }
80
81    fn finish(self) {
82        if let Some(inner) = self.inner {
83            let total = inner.start.elapsed().as_secs_f64() * 1000.0;
84            let summary = inner
85                .stages
86                .into_iter()
87                .map(|(label, ms)| format!("{label}={ms:.3}ms"))
88                .collect::<Vec<_>>()
89                .join(" ");
90            log::debug!(
91                "fusion timing plan={} kind={} len={} {} total={:.3}ms",
92                inner.plan_index,
93                inner.kind,
94                inner.len,
95                summary,
96                total
97            );
98        }
99    }
100}
101
102fn ensure_gpu_tensor(
103    provider: &dyn AccelProvider,
104    value: &Value,
105) -> Result<(GpuTensorHandle, Option<GpuTensorHandle>)> {
106    match value {
107        Value::GpuTensor(handle) => Ok((handle.clone(), None)),
108        Value::Tensor(tensor) => {
109            let view = HostTensorView {
110                data: &tensor.data,
111                shape: &tensor.shape,
112            };
113            let handle = provider.upload(&view)?;
114            Ok((handle.clone(), Some(handle)))
115        }
116        _ => Err(anyhow!("fusion: expected tensor input")),
117    }
118}
119
120fn scalar_upload_dtype(provider: &dyn AccelProvider) -> NumericDType {
121    match provider.precision() {
122        ProviderPrecision::F32 => NumericDType::F32,
123        ProviderPrecision::F64 => NumericDType::F64,
124    }
125}
126
127fn value_to_f64(value: &Value) -> Option<f64> {
128    match value {
129        Value::Num(n) => Some(*n),
130        Value::Int(i) => Some(i.to_f64()),
131        _ => None,
132    }
133}
134
135fn scalar_from_value(value: &Value) -> Result<f64> {
136    if let Some(v) = value_to_f64(value) {
137        return Ok(v);
138    }
139    match value {
140        Value::Tensor(t) => {
141            if t.data.len() == 1 {
142                Ok(t.data[0])
143            } else {
144                Err(anyhow!(
145                    "image normalize: expected scalar tensor, got {} elements",
146                    t.data.len()
147                ))
148            }
149        }
150        Value::GpuTensor(_) => {
151            let gathered = gather_if_needed(value).map_err(|e| anyhow!("image normalize: {e}"))?;
152            scalar_from_value(&gathered)
153        }
154        _ => Err(anyhow!(
155            "image normalize: expected numeric scalar value, got {:?}",
156            value
157        )),
158    }
159}
160
161fn resolve_image_scalar_value(
162    scalar: &ImageScalar,
163    plan: &FusionGroupPlan,
164    request: &FusionExecutionRequest<'_>,
165) -> Result<f64> {
166    match scalar {
167        ImageScalar::Constant(v) => Ok(*v),
168        ImageScalar::Value(vid) => {
169            if let Some(value) = plan.const_values.get(vid) {
170                return scalar_from_value(value);
171            }
172            if let Some(idx) = plan.inputs.iter().position(|id| *id == *vid) {
173                let runtime_value = request
174                    .inputs
175                    .get(idx)
176                    .ok_or_else(|| anyhow!("image normalize: runtime scalar missing"))?;
177                return scalar_from_value(runtime_value);
178            }
179            Err(anyhow!(
180                "image normalize: scalar input {:?} not materialized in plan",
181                vid
182            ))
183        }
184    }
185}
186
/// Executes an elementwise fusion plan by dispatching the generated WGSL
/// kernel through the registered acceleration provider.
///
/// Runtime `inputs` are bound positionally to `plan.inputs`. Host tensors
/// and plain scalars are uploaded as temporary device buffers and freed
/// after dispatch; GPU tensors are used in place. Returns the result as a
/// residency-tracked GPU tensor, or an error if the plan/kernel is
/// unsupported or the output shape cannot be determined.
pub fn execute_elementwise(request: FusionExecutionRequest<'_>) -> Result<Value> {
    crate::ensure_residency_hooks();
    if !request.plan.group.kind.is_elementwise() {
        return Err(anyhow!("unsupported fusion kind"));
    }
    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
    if !request.plan.kernel.supported {
        return Err(anyhow!("fusion kernel not supported for this plan"));
    }
    if request.inputs.len() != request.plan.inputs.len() {
        return Err(anyhow!(
            "fusion input mismatch: expected {}, got {}",
            request.plan.inputs.len(),
            request.inputs.len()
        ));
    }
    // Determine output shape from the fusion plan; if unknown, derive from runtime inputs via broadcasting.
    fn runtime_broadcast_shape(values: &[Value]) -> Option<Vec<usize>> {
        // Collect shapes; scalars map to empty shape which broadcasts to any
        let mut shapes: Vec<Vec<usize>> = Vec::new();
        for v in values {
            match v {
                Value::GpuTensor(h) => shapes.push(h.shape.clone()),
                Value::Tensor(t) => shapes.push(t.shape.clone()),
                Value::Num(_) | Value::Int(_) => shapes.push(Vec::new()),
                _ => return None, // unsupported at runtime for broadcasting
            }
        }
        // Right-align every shape against the maximum rank and merge
        // dimension-by-dimension (size-1 dims broadcast against any size).
        let rank = shapes.iter().map(|s| s.len()).max().unwrap_or(0);
        let mut out = vec![1usize; rank];
        for shape in shapes {
            let offset = rank.saturating_sub(shape.len());
            for (i, &dim) in shape.iter().enumerate() {
                let j = offset + i;
                let a = out[j];
                let b = dim;
                if a == 1 {
                    out[j] = b.max(1);
                } else if b == 1 || a == b {
                    // keep a
                } else {
                    return None; // incompatible
                }
            }
        }
        Some(out)
    }
    // Determine output shape from the fusion plan and derive the element count from it.
    let mut output_shape = match &request.plan.group.shape {
        ShapeInfo::Tensor(dims) if !dims.is_empty() => {
            // Unknown dims default to 1; the zero-length fallback below
            // repairs cases where this underestimates the true extent.
            let resolved: Vec<usize> = dims.iter().map(|d| d.unwrap_or(1)).collect();
            resolved
        }
        _ => {
            // Fallback to runtime broadcasting inference
            runtime_broadcast_shape(&request.inputs)
                .ok_or_else(|| anyhow!("fusion: unknown output shape"))?
        }
    };
    let mut len: usize = output_shape.iter().copied().product();
    if len == 0 {
        // Plan-reported shape may be empty/zero; retry once with the shapes
        // observed at runtime before giving up.
        if let Some(rt_shape) = runtime_broadcast_shape(&request.inputs) {
            output_shape = rt_shape;
            len = output_shape.iter().copied().product();
        }
        if len == 0 {
            return Err(anyhow!("fusion: zero-length execution not supported"));
        }
    }
    let mut timer = FusionStageTimer::new("elementwise", request.plan.index, len);
    // Scalars are uploaded with an all-ones shape of matching rank so they
    // broadcast against the tensor operands.
    let scalar_shape: Vec<usize> = if output_shape.is_empty() {
        vec![1]
    } else {
        vec![1; output_shape.len()]
    };
    let mut prepared = Vec::with_capacity(request.inputs.len());
    // Keeps scalar host buffers alive until the provider has consumed them.
    let mut temp_scalars: Vec<Vec<f64>> = Vec::new();
    let scalar_dtype = scalar_upload_dtype(provider);
    for value in &request.inputs {
        match value {
            Value::GpuTensor(handle) => prepared.push(PreparedInput {
                handle: handle.clone(),
                owned: None,
            }),
            Value::Tensor(t) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, t.dtype) {
                    return Err(anyhow!(
                        "fusion: tensor input requires unsupported precision ({msg})"
                    ));
                }
                let view = HostTensorView {
                    data: &t.data,
                    shape: &t.shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            Value::Num(n) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
                    return Err(anyhow!(
                        "fusion: scalar input requires unsupported precision ({msg})"
                    ));
                }
                // Round-trip through f32 so the uploaded value matches the
                // provider's effective precision.
                let scalar = match provider.precision() {
                    ProviderPrecision::F32 => (*n as f32) as f64,
                    ProviderPrecision::F64 => *n,
                };
                temp_scalars.push(vec![scalar]);
                let data = temp_scalars.last().unwrap();
                let view = HostTensorView {
                    data,
                    shape: &scalar_shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            Value::Int(i) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
                    return Err(anyhow!(
                        "fusion: scalar input requires unsupported precision ({msg})"
                    ));
                }
                // Same precision round-trip as the Num arm, via i.to_f64().
                let scalar = match provider.precision() {
                    ProviderPrecision::F32 => (i.to_f64() as f32) as f64,
                    ProviderPrecision::F64 => i.to_f64(),
                };
                temp_scalars.push(vec![scalar]);
                let data = temp_scalars.last().unwrap();
                let view = HostTensorView {
                    data,
                    shape: &scalar_shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            _ => {
                return Err(anyhow!("fusion: unsupported value type"));
            }
        }
    }
    timer.mark("prepare_inputs");

    // Generate the WGSL source for the provider's scalar precision.
    let scalar_ty = match provider.precision() {
        ProviderPrecision::F32 => "f32",
        ProviderPrecision::F64 => "f64",
    };
    let shader = request
        .plan
        .generate_wgsl(scalar_ty)
        .ok_or_else(|| anyhow!("fusion: WGSL generation failed"))?;
    timer.mark("generate_wgsl");

    let handles: Vec<GpuTensorHandle> = prepared.iter().map(|p| p.handle.clone()).collect();
    let output = provider.fused_elementwise(&shader, &handles, &output_shape, len)?;
    timer.mark("dispatch");
    fusion_residency::mark(&output);

    // Clean up temporary uploads
    for input in prepared {
        if let Some(handle) = input.owned {
            let _ = provider.free(&handle);
        }
    }
    timer.mark("cleanup");
    timer.finish();

    Ok(Value::GpuTensor(output))
}
364
/// Executes a reduction fusion plan (`reduce_len` elements per slice across
/// `num_slices` slices) through the registered acceleration provider.
///
/// `workgroup_size == 0` selects the provider's default; `RUNMAT_FUSED_WG`
/// can override it (capped at the provider default). Set
/// `RUNMAT_DISABLE_FUSED_REDUCTION` to force callers onto the non-fused
/// path, and `RUNMAT_DEBUG_DUMP_FUSED_WGSL` to print the generated shader.
pub fn execute_reduction(
    request: FusionExecutionRequest<'_>,
    reduce_len: usize,
    num_slices: usize,
    workgroup_size: u32,
) -> Result<Value> {
    if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION").is_ok() {
        return Err(anyhow!("fused reduction disabled by env"));
    }
    crate::ensure_residency_hooks();
    if !request.plan.group.kind.is_reduction() {
        return Err(anyhow!("unsupported fusion kind"));
    }
    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
    if !request.plan.kernel.supported {
        return Err(anyhow!("fusion kernel not supported for this plan"));
    }
    if request.inputs.len() != request.plan.inputs.len() {
        return Err(anyhow!(
            "fusion input mismatch: expected {}, got {}",
            request.plan.inputs.len(),
            request.inputs.len()
        ));
    }
    let len = reduce_len * num_slices;
    if len == 0 {
        return Err(anyhow!("fusion: zero-length execution not supported"));
    }
    // Scalars are uploaded with an all-ones shape whose rank matches the
    // plan's constant shape so they broadcast against tensor operands.
    let scalar_shape: Vec<usize> = {
        let constant_shape = request.plan.constant_shape(len);
        if constant_shape.is_empty() {
            vec![1]
        } else {
            vec![1; constant_shape.len()]
        }
    };
    let mut timer = FusionStageTimer::new("reduction", request.plan.index, len);
    let mut prepared = Vec::with_capacity(request.inputs.len());
    // Keeps scalar host buffers alive until the provider has consumed them.
    let mut temp_scalars: Vec<Vec<f64>> = Vec::new();
    let scalar_dtype = scalar_upload_dtype(provider);
    for value in &request.inputs {
        match value {
            Value::GpuTensor(handle) => prepared.push(PreparedInput {
                handle: handle.clone(),
                owned: None,
            }),
            Value::Tensor(t) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, t.dtype) {
                    return Err(anyhow!(
                        "fusion: tensor input requires unsupported precision ({msg})"
                    ));
                }
                let view = HostTensorView {
                    data: &t.data,
                    shape: &t.shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            Value::Num(n) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
                    return Err(anyhow!(
                        "fusion: scalar input requires unsupported precision ({msg})"
                    ));
                }
                // Round-trip through f32 so the uploaded value matches the
                // provider's effective precision.
                let scalar = match provider.precision() {
                    ProviderPrecision::F32 => (*n as f32) as f64,
                    ProviderPrecision::F64 => *n,
                };
                temp_scalars.push(vec![scalar]);
                let data = temp_scalars.last().unwrap();
                let view = HostTensorView {
                    data,
                    shape: &scalar_shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            Value::Int(i) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
                    return Err(anyhow!(
                        "fusion: scalar input requires unsupported precision ({msg})"
                    ));
                }
                // Same precision round-trip as the Num arm, via i.to_f64().
                let scalar = match provider.precision() {
                    ProviderPrecision::F32 => (i.to_f64() as f32) as f64,
                    ProviderPrecision::F64 => i.to_f64(),
                };
                temp_scalars.push(vec![scalar]);
                let data = temp_scalars.last().unwrap();
                let view = HostTensorView {
                    data,
                    shape: &scalar_shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            _ => return Err(anyhow!("fusion: unsupported value type")),
        }
    }
    timer.mark("prepare_inputs");

    let handles: Vec<GpuTensorHandle> = prepared.iter().map(|p| p.handle.clone()).collect();
    // The reduction collapses each slice to one element.
    let output_shape = vec![num_slices];

    let scalar_ty = match provider.precision() {
        ProviderPrecision::F32 => "f32",
        ProviderPrecision::F64 => "f64",
    };
    let shader = request
        .plan
        .generate_reduction_wgsl(scalar_ty)
        .ok_or_else(|| anyhow!("fusion: reduction WGSL generation failed"))?;
    timer.mark("generate_wgsl");
    if std::env::var("RUNMAT_DEBUG_DUMP_FUSED_WGSL").is_ok() {
        println!(
            "---- fused reduction WGSL ----\n{}\n------------------------------",
            shader
        );
    }

    // Workgroup size: caller value, provider default when 0, then an optional
    // env override capped at the provider default (and floored at 1).
    let mut wg = if workgroup_size == 0 {
        provider.default_reduction_workgroup_size()
    } else {
        workgroup_size
    };
    if let Ok(raw) = std::env::var("RUNMAT_FUSED_WG") {
        if let Ok(val) = raw.trim().parse::<u32>() {
            if val > 0 {
                let capped = val.min(provider.default_reduction_workgroup_size());
                wg = capped.max(1);
            }
        }
    }
    // Plans without an explicit flavor reduce by summation.
    let flavor = request
        .plan
        .reduction_flavor
        .unwrap_or(ReductionFlavor::Sum);
    let output = provider.fused_reduction(
        &shader,
        &handles,
        &output_shape,
        reduce_len,
        num_slices,
        wg,
        flavor,
    )?;
    timer.mark("dispatch");
    fusion_residency::mark(&output);

    // Free the temporary uploads created above.
    for input in prepared {
        if let Some(handle) = input.owned {
            let _ = provider.free(&handle);
        }
    }
    timer.mark("cleanup");
    timer.finish();

    Ok(Value::GpuTensor(output))
}
534
535pub fn execute_centered_gram(request: FusionExecutionRequest<'_>) -> Result<Value> {
536    crate::ensure_residency_hooks();
537    if request.plan.group.kind != FusionKind::CenteredGram {
538        return Err(anyhow!("unsupported fusion kind"));
539    }
540    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
541    let (matrix_vid, normalization) = match request.plan.pattern.as_ref() {
542        Some(FusionPattern::CenteredGram {
543            matrix,
544            normalization,
545        }) => (*matrix, *normalization),
546        _ => return Err(anyhow!("centered gram: missing pattern metadata")),
547    };
548
549    let matrix_index = request
550        .plan
551        .inputs
552        .iter()
553        .position(|vid| *vid == matrix_vid)
554        .ok_or_else(|| anyhow!("centered gram: matrix input not found"))?;
555    let matrix_value = request
556        .inputs
557        .get(matrix_index)
558        .ok_or_else(|| anyhow!("centered gram: matrix value missing"))?;
559
560    let (matrix_handle, owned_matrix) = ensure_gpu_tensor(provider, matrix_value)?;
561
562    let options = CovarianceOptions {
563        normalization,
564        rows: CovRows::All,
565        has_weight_vector: false,
566    };
567
568    let output = provider.covariance(&matrix_handle, None, None, &options)?;
569
570    if let Some(temp) = owned_matrix {
571        let _ = provider.free(&temp);
572    }
573
574    fusion_residency::mark(&output);
575    Ok(Value::GpuTensor(output))
576}
577
578pub fn execute_power_step_normalize(request: FusionExecutionRequest<'_>) -> Result<Value> {
579    crate::ensure_residency_hooks();
580    if request.plan.group.kind != FusionKind::PowerStepNormalize {
581        return Err(anyhow!("unsupported fusion kind"));
582    }
583    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
584    let (lhs_vid, rhs_vid, epsilon) = match request.plan.pattern.as_ref() {
585        Some(FusionPattern::PowerStepNormalize { lhs, rhs, epsilon }) => (*lhs, *rhs, *epsilon),
586        _ => {
587            return Err(anyhow!(
588                "power-step normalization: missing pattern metadata"
589            ))
590        }
591    };
592
593    let lhs_index = request
594        .plan
595        .inputs
596        .iter()
597        .position(|vid| *vid == lhs_vid)
598        .ok_or_else(|| anyhow!("power-step normalization: lhs input not found"))?;
599    let rhs_index = request
600        .plan
601        .inputs
602        .iter()
603        .position(|vid| *vid == rhs_vid)
604        .ok_or_else(|| anyhow!("power-step normalization: rhs input not found"))?;
605
606    let lhs_value = request
607        .inputs
608        .get(lhs_index)
609        .ok_or_else(|| anyhow!("power-step normalization: lhs value missing"))?;
610    let rhs_value = request
611        .inputs
612        .get(rhs_index)
613        .ok_or_else(|| anyhow!("power-step normalization: rhs value missing"))?;
614
615    let (lhs_handle, lhs_owned) = ensure_gpu_tensor(provider, lhs_value)?;
616    let (rhs_handle, rhs_owned) = ensure_gpu_tensor(provider, rhs_value)?;
617
618    let desc = PowerStepEpilogue { epsilon };
619    let output = provider.matmul_power_step(&lhs_handle, &rhs_handle, &desc)?;
620
621    if let Some(temp) = lhs_owned {
622        let _ = provider.free(&temp);
623    }
624    if let Some(temp) = rhs_owned {
625        let _ = provider.free(&temp);
626    }
627
628    fusion_residency::mark(&output);
629    Ok(Value::GpuTensor(output))
630}
631
/// Executes a fused explained-variance plan, producing the diagonal of the
/// projected Gram product for basis Q and square matrix G.
///
/// Operand values are resolved from runtime inputs first, then from the
/// plan's captured constants. The intermediate layout deliberately mimics
/// the interpreter's transpose quirk (see inline comments) so results match
/// the non-fused path. Set `RUNMAT_DEBUG_EXPLAINED` to dump intermediate
/// shapes and data. Returns the diagonal as a column-shaped GPU tensor.
pub fn execute_explained_variance(request: FusionExecutionRequest<'_>) -> Result<Value> {
    crate::ensure_residency_hooks();
    if request.plan.group.kind != FusionKind::ExplainedVariance {
        return Err(anyhow!("unsupported fusion kind"));
    }
    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
    let (q_vid, g_vid) = match request.plan.pattern.as_ref() {
        Some(FusionPattern::ExplainedVariance { q, g }) => (*q, *g),
        _ => return Err(anyhow!("explained variance: missing pattern metadata")),
    };

    // Runtime inputs take precedence; fall back to plan-captured constants.
    let find_value = |vid: ValueId| -> Result<&Value> {
        if let Some(pos) = request.plan.inputs.iter().position(|id| *id == vid) {
            request
                .inputs
                .get(pos)
                .ok_or_else(|| anyhow!("explained variance: missing runtime value"))
        } else {
            request
                .plan
                .const_values
                .get(&vid)
                .ok_or_else(|| anyhow!("explained variance: value not materialized"))
        }
    };

    let q_value = find_value(q_vid)?;
    let g_value = find_value(g_vid)?;

    let (mut q_handle, q_owned) = ensure_gpu_tensor(provider, q_value)?;
    let (g_handle, g_owned) = ensure_gpu_tensor(provider, g_value)?;

    let debug_explained = std::env::var("RUNMAT_DEBUG_EXPLAINED").is_ok();
    if debug_explained {
        println!(
            "[explained] initial Q shape {:?}, G shape {:?}",
            q_handle.shape, g_handle.shape
        );
        if let Ok(info) = provider.download(&q_handle) {
            println!(
                "[explained] Q (sample) len={} first=[{:?}]",
                info.data.len(),
                info.data.get(0..4)
            );
        }
    }

    // Validate Q: 2-D with nonzero extents.
    let q_shape = q_handle.shape.clone();
    if q_shape.len() < 2 {
        return Err(anyhow!("explained variance: Q must be 2-D"));
    }
    let q_rows = q_shape[0];
    let q_cols = q_shape[1];
    if q_rows == 0 || q_cols == 0 {
        return Err(anyhow!("explained variance: zero-sized Q"));
    }

    // Validate G: 2-D and square with side length equal to Q's row count.
    let g_shape = g_handle.shape.clone();
    if g_shape.len() < 2 {
        return Err(anyhow!("explained variance: G must be 2-D"));
    }
    if g_shape[0] != q_rows || g_shape[1] != q_rows {
        return Err(anyhow!("explained variance: G shape mismatch"));
    }

    // First multiply is a conformability/shape probe; its result is
    // replaced by the reshaped-layout multiply below before further use.
    let mut tmp = provider.matmul(&q_handle, &g_handle)?;
    let tmp_shape = tmp.shape.clone();
    if tmp_shape.len() < 2 {
        return Err(anyhow!("explained variance: intermediate must be 2-D"));
    }
    if tmp_shape[0] != q_cols {
        return Err(anyhow!(
            "explained variance: expected intermediate rows {}, got {}",
            q_cols,
            tmp_shape[0]
        ));
    }

    if debug_explained {
        println!("[explained] after Q*G tmp shape {:?}", tmp.shape);
    }

    // Interpreter's transpose retains the original data layout. Mimic that by
    // reshaping rather than launching a real transpose so downstream matmul
    // observes the same misoriented layout.
    let mut transposed_shape = q_shape.clone();
    transposed_shape.swap(0, 1);
    let q_transposed_view = provider.reshape(&q_handle, &transposed_shape)?;

    tmp = provider.matmul(&q_transposed_view, &g_handle)?;

    if debug_explained {
        println!(
            "[explained] after reshape(matmul) tmp shape {:?}",
            tmp.shape
        );
    }

    // Restore Q's original shape before the second multiplication.
    q_handle = provider.reshape(&q_handle, &q_shape)?;

    let product = provider.matmul(&tmp, &q_handle)?;

    if debug_explained {
        println!("[explained] product shape {:?}", product.shape);
    }

    // Extract the main diagonal and normalize it to a column vector shape.
    let diag = provider.diag_extract(&product, 0)?;
    let diag = match diag.shape.as_slice() {
        [len] => provider.reshape(&diag, &[*len, 1])?,
        [_len, 1] => diag,
        _ => diag,
    };

    if debug_explained {
        if let Ok(host) = provider.download(&tmp) {
            println!("tmp runtime shape {:?} data {:?}", host.shape, host.data);
        }
        if let Ok(host) = provider.download(&product) {
            println!("prod runtime shape {:?} data {:?}", host.shape, host.data);
        }
        if let Ok(host) = provider.download(&diag) {
            println!("diag runtime shape {:?} data {:?}", host.shape, host.data);
        }
    }

    // Free intermediates and any temporary uploads before returning.
    let _ = provider.free(&tmp);
    let _ = provider.free(&product);
    if let Some(temp) = q_owned {
        let _ = provider.free(&temp);
    }
    if let Some(temp) = g_owned {
        let _ = provider.free(&temp);
    }

    fusion_residency::mark(&diag);
    Ok(Value::GpuTensor(diag))
}
770
771pub fn execute_image_normalize(request: FusionExecutionRequest<'_>) -> Result<Value> {
772    crate::ensure_residency_hooks();
773    if request.plan.group.kind != FusionKind::ImageNormalize {
774        return Err(anyhow!("unsupported fusion kind"));
775    }
776    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
777    let pattern = match request.plan.pattern.as_ref() {
778        Some(FusionPattern::ImageNormalize(p)) => p,
779        _ => return Err(anyhow!("image normalize: missing pattern metadata")),
780    };
781    if log::log_enabled!(log::Level::Debug) {
782        log::debug!(
783            "execute_image_normalize: plan inputs={:?} stack={:?}",
784            request.plan.inputs,
785            request.plan.stack_pattern
786        );
787    }
788
789    let find_value = |vid: ValueId| -> Result<&Value> {
790        if let Some(pos) = request.plan.inputs.iter().position(|id| *id == vid) {
791            request
792                .inputs
793                .get(pos)
794                .ok_or_else(|| anyhow!("image normalize: runtime value missing"))
795        } else {
796            request
797                .plan
798                .const_values
799                .get(&vid)
800                .ok_or_else(|| anyhow!("image normalize: value {vid:?} not materialized"))
801        }
802    };
803
804    let input_value = find_value(pattern.input)?;
805    let (input_handle, input_owned) = ensure_gpu_tensor(provider, input_value)?;
806    let shape = input_handle.shape.clone();
807    if shape.len() != 3 {
808        return Err(anyhow!(
809            "image normalize: expected 3-D input tensor, got shape {:?}",
810            shape
811        ));
812    }
813    let batch = shape[0];
814    let height = shape[1];
815    let width = shape[2];
816
817    let epsilon = resolve_image_scalar_value(&pattern.epsilon, request.plan, &request)?;
818    let gain = match &pattern.gain {
819        Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
820        None => None,
821    };
822    let bias = match &pattern.bias {
823        Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
824        None => None,
825    };
826    let gamma = match &pattern.gamma {
827        Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
828        None => None,
829    };
830
831    let desc = ImageNormalizeDescriptor {
832        batch,
833        height,
834        width,
835        epsilon,
836        gain,
837        bias,
838        gamma,
839    };
840    if log::log_enabled!(log::Level::Debug) {
841        log::debug!("execute_image_normalize: desc {:?}", desc);
842    }
843
844    let output = provider.image_normalize(&input_handle, &desc)?;
845
846    if let Some(temp) = input_owned {
847        provider.free(&temp).ok();
848    }
849
850    fusion_residency::mark(&output);
851    Ok(Value::GpuTensor(output))
852}
853
854pub fn execute_matmul_epilogue(request: FusionExecutionRequest<'_>) -> Result<Value> {
855    crate::ensure_residency_hooks();
856    if request.plan.group.kind != crate::fusion::FusionKind::MatmulEpilogue {
857        return Err(anyhow!("unsupported fusion kind"));
858    }
859    let prov = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
860
861    // Map ValueId -> prepared GpuTensorHandle
862    let mut prepared: Vec<(graph::ValueId, GpuTensorHandle, Option<GpuTensorHandle>)> = Vec::new();
863    let mut owned: Vec<GpuTensorHandle> = Vec::new();
864    for (idx, vid) in request.plan.inputs.iter().copied().enumerate() {
865        let v = request
866            .inputs
867            .get(idx)
868            .ok_or_else(|| anyhow!("fusion: missing input value"))?;
869        let handle = match v {
870            Value::GpuTensor(h) => h.clone(),
871            Value::Tensor(t) => {
872                let view = HostTensorView {
873                    data: &t.data,
874                    shape: &t.shape,
875                };
876                let h = prov.upload(&view)?;
877                owned.push(h.clone());
878                h
879            }
880            _ => return Err(anyhow!("matmul_epilogue: unsupported input value kind")),
881        };
882        prepared.push((vid, handle.clone(), None));
883    }
884
885    // Helper: find handle by ValueId
886    let find_handle = |vid: graph::ValueId| -> Option<GpuTensorHandle> {
887        prepared
888            .iter()
889            .find_map(|(v, h, _)| if *v == vid { Some(h.clone()) } else { None })
890    };
891
892    // Find matmul op and its output
893    let mut cur_out: Option<graph::ValueId> = None;
894    let mut a_vid: Option<graph::ValueId> = None;
895    let mut b_vid: Option<graph::ValueId> = None;
896    for op in &request.plan.operations {
897        if let crate::fusion::FusionOp::Builtin {
898            name,
899            inputs,
900            output,
901        } = op
902        {
903            if name.eq_ignore_ascii_case("mtimes") {
904                a_vid = inputs.first().copied();
905                b_vid = inputs.get(1).copied();
906                cur_out = *output;
907                break;
908            }
909        }
910    }
911    let (a_vid, b_vid, mut cur) = (
912        a_vid.ok_or_else(|| anyhow!("mtimes not found"))?,
913        b_vid.ok_or_else(|| anyhow!("mtimes not found"))?,
914        cur_out.ok_or_else(|| anyhow!("mtimes output missing"))?,
915    );
916
917    // Derive epilogue (scale/bias, clamp, pow, diag) by walking subsequent ops that consume cur
918    let mut alpha: f64 = 1.0;
919    let mut beta: f64 = 0.0;
920    let mut row_scale: Option<GpuTensorHandle> = None;
921    let mut col_scale: Option<GpuTensorHandle> = None;
922    let mut clamp_min: Option<f64> = None;
923    let mut clamp_max: Option<f64> = None;
924    let mut pow_exponent: Option<f64> = None;
925    let mut row_div = false;
926    let mut col_div = false;
927    let mut diag_vid: Option<graph::ValueId> = None;
928
929    for op in &request.plan.operations {
930        match op {
931            crate::fusion::FusionOp::Primitive { op, inputs, output } => {
932                let Some(out) = output else { continue };
933                if !inputs.contains(&cur) {
934                    continue;
935                }
936                let other = if inputs[0] == cur {
937                    inputs[1]
938                } else {
939                    inputs[0]
940                };
941                let const_opt = request.plan.const_values.get(&other);
942                let const_f64 = const_opt.and_then(value_to_f64);
943                match op {
944                    crate::graph::PrimitiveOp::Mul | crate::graph::PrimitiveOp::ElemMul => {
945                        if let Some(val) = const_f64 {
946                            alpha *= val;
947                        } else if row_scale.is_none() || col_scale.is_none() {
948                            if let Some(h) = find_handle(other) {
949                                let r = h.shape.first().copied().unwrap_or(1);
950                                let c = h.shape.get(1).copied().unwrap_or(1);
951                                if c == 1 && row_scale.is_none() {
952                                    row_scale = Some(h);
953                                    row_div = false;
954                                } else if r == 1 && col_scale.is_none() {
955                                    col_scale = Some(h);
956                                    col_div = false;
957                                }
958                            }
959                        }
960                    }
961                    crate::graph::PrimitiveOp::Div | crate::graph::PrimitiveOp::ElemDiv => {
962                        if let Some(val) = const_f64 {
963                            if val != 0.0 {
964                                alpha *= 1.0 / val;
965                            }
966                        } else if row_scale.is_none() || col_scale.is_none() {
967                            if let Some(h) = find_handle(other) {
968                                let r = h.shape.first().copied().unwrap_or(1);
969                                let c = h.shape.get(1).copied().unwrap_or(1);
970                                if c == 1 && row_scale.is_none() {
971                                    row_scale = Some(h);
972                                    row_div = true;
973                                } else if r == 1 && col_scale.is_none() {
974                                    col_scale = Some(h);
975                                    col_div = true;
976                                }
977                            }
978                        }
979                    }
980                    crate::graph::PrimitiveOp::Add => {
981                        if let Some(val) = const_f64 {
982                            beta += val;
983                        }
984                    }
985                    crate::graph::PrimitiveOp::Sub => {
986                        if let Some(val) = const_f64 {
987                            beta -= val;
988                        }
989                    }
990                    crate::graph::PrimitiveOp::Pow | crate::graph::PrimitiveOp::ElemPow => {
991                        if pow_exponent.is_none() && inputs[0] == cur {
992                            pow_exponent = const_f64;
993                        }
994                    }
995                    _ => {}
996                }
997                cur = *out;
998            }
999            crate::fusion::FusionOp::Builtin {
1000                name,
1001                inputs,
1002                output,
1003            } => {
1004                let Some(out) = output else { continue };
1005                if !inputs.contains(&cur) {
1006                    continue;
1007                }
1008                let lower = name.to_ascii_lowercase();
1009                if lower == "max" || lower == "min" {
1010                    if let Some(&other) = inputs.iter().find(|&&v| v != cur) {
1011                        if let Some(val) =
1012                            request.plan.const_values.get(&other).and_then(value_to_f64)
1013                        {
1014                            if lower == "max" {
1015                                clamp_min = Some(clamp_min.map_or(val, |prev| prev.max(val)));
1016                            } else {
1017                                clamp_max = Some(clamp_max.map_or(val, |prev| prev.min(val)));
1018                            }
1019                        }
1020                    }
1021                } else if lower == "pow" && pow_exponent.is_none() {
1022                    if let Some(&other) = inputs.iter().find(|&&v| v != cur) {
1023                        if let Some(val) =
1024                            request.plan.const_values.get(&other).and_then(value_to_f64)
1025                        {
1026                            pow_exponent = Some(val);
1027                        }
1028                    }
1029                } else if lower == "diag" {
1030                    diag_vid = Some(*out);
1031                }
1032                cur = *out;
1033            }
1034        }
1035    }
1036
1037    // Build epilogue descriptor
1038    let mut ep = runmat_accelerate_api::MatmulEpilogue::noop();
1039    ep.alpha = alpha;
1040    ep.beta = beta;
1041    ep.clamp_min = clamp_min;
1042    ep.clamp_max = clamp_max;
1043    ep.pow_exponent = pow_exponent;
1044    ep.row_op = if row_div {
1045        runmat_accelerate_api::ScaleOp::Divide
1046    } else {
1047        runmat_accelerate_api::ScaleOp::Multiply
1048    };
1049    ep.col_op = if col_div {
1050        runmat_accelerate_api::ScaleOp::Divide
1051    } else {
1052        runmat_accelerate_api::ScaleOp::Multiply
1053    };
1054    if let Some(h) = row_scale.clone() {
1055        ep.row_scale = Some(h);
1056    }
1057    if let Some(h) = col_scale.clone() {
1058        ep.col_scale = Some(h);
1059    }
1060
1061    let a = find_handle(a_vid).ok_or_else(|| anyhow!("missing A"))?;
1062    let b = find_handle(b_vid).ok_or_else(|| anyhow!("missing B"))?;
1063
1064    let mut diag_handle: Option<(graph::ValueId, GpuTensorHandle)> = None;
1065    if let Some(vid) = diag_vid {
1066        let diag_len = std::cmp::min(
1067            a.shape.first().copied().unwrap_or(0),
1068            b.shape.get(1).copied().unwrap_or(0),
1069        );
1070        let mut diag_shape = vec![diag_len, 1];
1071        if diag_len == 0 {
1072            diag_shape[1] = 1;
1073        }
1074        let handle = prov.zeros(&diag_shape)?;
1075        ep.diag_output = Some(handle.clone());
1076        diag_handle = Some((vid, handle));
1077    }
1078
1079    let out = prov.matmul_epilogue(&a, &b, &ep)?;
1080    for h in owned {
1081        let _ = prov.free(&h);
1082    }
1083
1084    if let Some((_, diag)) = &diag_handle {
1085        fusion_residency::mark(diag);
1086    }
1087
1088    let final_vid = request.plan.output.or(Some(cur));
1089    let mut result = out.clone();
1090    let mut free_out = false;
1091    if let Some((vid, diag)) = &diag_handle {
1092        if Some(*vid) == final_vid {
1093            result = diag.clone();
1094            free_out = true;
1095        }
1096    }
1097
1098    if free_out {
1099        let _ = prov.free(&out);
1100    } else {
1101        fusion_residency::mark(&out);
1102    }
1103
1104    fusion_residency::mark(&result);
1105    Ok(Value::GpuTensor(result))
1106}