1use anyhow::{anyhow, Result};
2
3use crate::fusion::{
4 FusionGroupPlan, FusionKind, FusionPattern, FusionStoreMaterialization, ImageScalar,
5};
6use crate::fusion_residency;
7use crate::graph;
8use crate::graph::{ShapeInfo, ValueId};
9use crate::precision::ensure_provider_supports_dtype;
10use log;
11use runmat_accelerate_api::{
12 provider, AccelProvider, CovRows, CovarianceOptions, GpuTensorHandle, HostTensorView,
13 ImageNormalizeDescriptor, PowerStepEpilogue, ProviderPrecision, ReductionFlavor,
14};
15use runmat_builtins::{NumericDType, Value};
16use runmat_runtime::builtins::common::shape::normalize_scalar_shape;
17use runmat_runtime::gather_if_needed;
18use runmat_time::Instant;
19use std::collections::HashMap;
20use std::sync::OnceLock;
21
22struct PreparedInput {
23 handle: GpuTensorHandle,
24 owned: Option<GpuTensorHandle>,
25}
26
27pub struct FusionExecutionRequest<'a> {
28 pub plan: &'a FusionGroupPlan,
29 pub inputs: Vec<Value>,
30}
31
32pub struct ElementwiseExecutionResult {
33 pub final_value: Value,
34 pub materialized_stores: Vec<(FusionStoreMaterialization, Value)>,
35}
36
/// Reports whether per-stage fusion timing was requested through the
/// `RUNMAT_FUSION_TIMING` environment variable.
///
/// The variable is read once and the answer is cached for the lifetime of the
/// process. After trimming and ASCII-lowercasing, the values `1`, `true`,
/// `yes` and `on` enable timing; anything else (or an unset/unreadable
/// variable) disables it.
#[inline]
fn fusion_timing_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    *ENABLED.get_or_init(|| {
        std::env::var("RUNMAT_FUSION_TIMING")
            .map(|raw| {
                let normalized = raw.trim().to_ascii_lowercase();
                ["1", "true", "yes", "on"].contains(&normalized.as_str())
            })
            .unwrap_or(false)
    })
}
48
49struct FusionStageTimer {
50 inner: Option<FusionStageTimerInner>,
51}
52
53struct FusionStageTimerInner {
54 plan_index: usize,
55 kind: &'static str,
56 len: usize,
57 start: Instant,
58 last: Instant,
59 stages: Vec<(&'static str, f64)>,
60}
61
62impl FusionStageTimer {
63 fn new(kind: &'static str, plan_index: usize, len: usize) -> Self {
64 if fusion_timing_enabled() && log::log_enabled!(log::Level::Debug) {
65 let now = Instant::now();
66 Self {
67 inner: Some(FusionStageTimerInner {
68 plan_index,
69 kind,
70 len,
71 start: now,
72 last: now,
73 stages: Vec::new(),
74 }),
75 }
76 } else {
77 Self { inner: None }
78 }
79 }
80
81 fn mark(&mut self, label: &'static str) {
82 if let Some(inner) = &mut self.inner {
83 let now = Instant::now();
84 let delta = now.duration_since(inner.last).as_secs_f64() * 1000.0;
85 inner.stages.push((label, delta));
86 inner.last = now;
87 }
88 }
89
90 fn finish(self) {
91 if let Some(inner) = self.inner {
92 let total = inner.start.elapsed().as_secs_f64() * 1000.0;
93 let summary = inner
94 .stages
95 .into_iter()
96 .map(|(label, ms)| format!("{label}={ms:.3}ms"))
97 .collect::<Vec<_>>()
98 .join(" ");
99 log::debug!(
100 "fusion timing plan={} kind={} len={} {} total={:.3}ms",
101 inner.plan_index,
102 inner.kind,
103 inner.len,
104 summary,
105 total
106 );
107 }
108 }
109}
110
111fn ensure_gpu_tensor(
112 provider: &dyn AccelProvider,
113 value: &Value,
114) -> Result<(GpuTensorHandle, Option<GpuTensorHandle>)> {
115 match value {
116 Value::GpuTensor(handle) => Ok((handle.clone(), None)),
117 Value::Tensor(tensor) => {
118 let view = HostTensorView {
119 data: &tensor.data,
120 shape: &tensor.shape,
121 };
122 let handle = provider.upload(&view)?;
123 Ok((handle.clone(), Some(handle)))
124 }
125 _ => Err(anyhow!("fusion: expected tensor input")),
126 }
127}
128
129fn scalar_upload_dtype(provider: &dyn AccelProvider) -> NumericDType {
130 match provider.precision() {
131 ProviderPrecision::F32 => NumericDType::F32,
132 ProviderPrecision::F64 => NumericDType::F64,
133 }
134}
135
136fn value_to_f64(value: &Value) -> Option<f64> {
137 match value {
138 Value::Num(n) => Some(*n),
139 Value::Int(i) => Some(i.to_f64()),
140 _ => None,
141 }
142}
143
144fn scalar_from_value(value: &Value) -> Result<f64> {
145 if let Some(v) = value_to_f64(value) {
146 return Ok(v);
147 }
148 match value {
149 Value::Tensor(t) => {
150 if t.data.len() == 1 {
151 Ok(t.data[0])
152 } else {
153 Err(anyhow!(
154 "image normalize: expected scalar tensor, got {} elements",
155 t.data.len()
156 ))
157 }
158 }
159 Value::GpuTensor(_) => {
160 let gathered = gather_if_needed(value).map_err(|e| anyhow!("image normalize: {e}"))?;
161 scalar_from_value(&gathered)
162 }
163 _ => Err(anyhow!(
164 "image normalize: expected numeric scalar value, got {:?}",
165 value
166 )),
167 }
168}
169
170fn resolve_image_scalar_value(
171 scalar: &ImageScalar,
172 plan: &FusionGroupPlan,
173 request: &FusionExecutionRequest<'_>,
174) -> Result<f64> {
175 match scalar {
176 ImageScalar::Constant(v) => Ok(*v),
177 ImageScalar::Value(vid) => {
178 if let Some(value) = plan.const_values.get(vid) {
179 return scalar_from_value(value);
180 }
181 if let Some(idx) = plan.inputs.iter().position(|id| *id == *vid) {
182 let runtime_value = request
183 .inputs
184 .get(idx)
185 .ok_or_else(|| anyhow!("image normalize: runtime scalar missing"))?;
186 return scalar_from_value(runtime_value);
187 }
188 Err(anyhow!(
189 "image normalize: scalar input {:?} not materialized in plan",
190 vid
191 ))
192 }
193 }
194}
195
196fn execute_elementwise_outputs(
197 request: &FusionExecutionRequest<'_>,
198 output_ids: &[ValueId],
199) -> Result<HashMap<ValueId, Value>> {
200 crate::ensure_residency_hooks();
201 if !request.plan.group.kind.is_elementwise() {
202 return Err(anyhow!("unsupported fusion kind"));
203 }
204 let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
205 if !request.plan.kernel.supported {
206 return Err(anyhow!("fusion kernel not supported for this plan"));
207 }
208 if request.inputs.len() != request.plan.inputs.len() {
209 return Err(anyhow!(
210 "fusion input mismatch: expected {}, got {}",
211 request.plan.inputs.len(),
212 request.inputs.len()
213 ));
214 }
215 fn runtime_broadcast_shape(values: &[Value]) -> Option<Vec<usize>> {
217 let mut shapes: Vec<Vec<usize>> = Vec::new();
219 for v in values {
220 match v {
221 Value::GpuTensor(h) => shapes.push(h.shape.clone()),
222 Value::Tensor(t) => shapes.push(t.shape.clone()),
223 Value::Num(_) | Value::Int(_) => shapes.push(Vec::new()),
224 _ => return None, }
226 }
227 let rank = shapes.iter().map(|s| s.len()).max().unwrap_or(0);
228 let mut out = vec![1usize; rank];
229 for shape in shapes {
230 let offset = rank.saturating_sub(shape.len());
231 for (i, &dim) in shape.iter().enumerate() {
232 let j = offset + i;
233 let a = out[j];
234 let b = dim;
235 if a == 1 {
236 out[j] = b.max(1);
237 } else if b == 1 || a == b {
238 } else {
240 return None; }
242 }
243 }
244 Some(out)
245 }
246 let runtime_shape = runtime_broadcast_shape(&request.inputs);
248 let mut output_shape = match &request.plan.group.shape {
249 ShapeInfo::Tensor(dims) if !dims.is_empty() && dims.iter().all(|d| d.is_some()) => {
250 dims.iter().map(|d| d.unwrap_or(1)).collect()
251 }
252 ShapeInfo::Tensor(dims) if !dims.is_empty() => {
253 let rt_shape = runtime_shape
254 .clone()
255 .ok_or_else(|| anyhow!("fusion: unknown output shape"))?;
256 if rt_shape.len() != dims.len() {
257 rt_shape
258 } else {
259 dims.iter()
260 .zip(rt_shape.iter())
261 .map(|(planned, runtime)| planned.unwrap_or(*runtime))
262 .collect()
263 }
264 }
265 _ => runtime_shape.ok_or_else(|| anyhow!("fusion: unknown output shape"))?,
266 };
267 let mut len: usize = output_shape.iter().copied().product();
268 if len == 0 {
269 if let Some(rt_shape) = runtime_broadcast_shape(&request.inputs) {
270 output_shape = rt_shape;
271 len = output_shape.iter().copied().product();
272 }
273 if len == 0 {
274 return Err(anyhow!("fusion: zero-length execution not supported"));
275 }
276 }
277 output_shape = normalize_scalar_shape(&output_shape);
278 let mut timer = FusionStageTimer::new("elementwise", request.plan.index, len);
279 let scalar_shape = normalize_scalar_shape(&vec![1; output_shape.len()]);
280 let mut prepared = Vec::with_capacity(request.inputs.len());
281 let mut temp_scalars: Vec<Vec<f64>> = Vec::new();
282 let scalar_dtype = scalar_upload_dtype(provider);
283 for value in &request.inputs {
284 match value {
285 Value::GpuTensor(handle) => prepared.push(PreparedInput {
286 handle: handle.clone(),
287 owned: None,
288 }),
289 Value::Tensor(t) => {
290 if let Err(msg) = ensure_provider_supports_dtype(provider, t.dtype) {
291 return Err(anyhow!(
292 "fusion: tensor input requires unsupported precision ({msg})"
293 ));
294 }
295 let view = HostTensorView {
296 data: &t.data,
297 shape: &t.shape,
298 };
299 let handle = provider.upload(&view)?;
300 prepared.push(PreparedInput {
301 handle: handle.clone(),
302 owned: Some(handle),
303 });
304 }
305 Value::Num(n) => {
306 if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
307 return Err(anyhow!(
308 "fusion: scalar input requires unsupported precision ({msg})"
309 ));
310 }
311 let scalar = match provider.precision() {
312 ProviderPrecision::F32 => (*n as f32) as f64,
313 ProviderPrecision::F64 => *n,
314 };
315 temp_scalars.push(vec![scalar]);
316 let data = temp_scalars.last().unwrap();
317 let view = HostTensorView {
318 data,
319 shape: &scalar_shape,
320 };
321 let handle = provider.upload(&view)?;
322 prepared.push(PreparedInput {
323 handle: handle.clone(),
324 owned: Some(handle),
325 });
326 }
327 Value::Int(i) => {
328 if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
329 return Err(anyhow!(
330 "fusion: scalar input requires unsupported precision ({msg})"
331 ));
332 }
333 let scalar = match provider.precision() {
334 ProviderPrecision::F32 => (i.to_f64() as f32) as f64,
335 ProviderPrecision::F64 => i.to_f64(),
336 };
337 temp_scalars.push(vec![scalar]);
338 let data = temp_scalars.last().unwrap();
339 let view = HostTensorView {
340 data,
341 shape: &scalar_shape,
342 };
343 let handle = provider.upload(&view)?;
344 prepared.push(PreparedInput {
345 handle: handle.clone(),
346 owned: Some(handle),
347 });
348 }
349 _ => {
350 return Err(anyhow!("fusion: unsupported value type"));
351 }
352 }
353 }
354 timer.mark("prepare_inputs");
355
356 let scalar_ty = match provider.precision() {
357 ProviderPrecision::F32 => "f32",
358 ProviderPrecision::F64 => "f64",
359 };
360 let handles: Vec<GpuTensorHandle> = prepared.iter().map(|p| p.handle.clone()).collect();
361 let mut outputs = HashMap::new();
362
363 if output_ids.len() == 1 {
364 let output_id = output_ids[0];
366 let shader = request
367 .plan
368 .generate_wgsl_for_output(output_id, scalar_ty)
369 .ok_or_else(|| anyhow!("fusion: WGSL generation failed for output {output_id}"))?;
370 timer.mark("generate_wgsl");
371 let output = provider.fused_elementwise(&shader, &handles, &output_shape, len)?;
372 fusion_residency::mark(&output);
373 outputs.insert(output_id, Value::GpuTensor(output));
374 } else {
375 let shader = request
378 .plan
379 .generate_wgsl_for_outputs(output_ids, scalar_ty)
380 .ok_or_else(|| anyhow!("fusion: multi-output WGSL generation failed"))?;
381 timer.mark("generate_wgsl");
382 match provider.fused_elementwise_multi(
383 &shader,
384 &handles,
385 &output_shape,
386 len,
387 output_ids.len(),
388 ) {
389 Ok(out_handles) => {
390 for (output_id, handle) in output_ids.iter().copied().zip(out_handles) {
391 fusion_residency::mark(&handle);
392 outputs.insert(output_id, Value::GpuTensor(handle));
393 }
394 }
395 Err(_) => {
396 for output_id in output_ids.iter().copied() {
398 let single_shader = request
399 .plan
400 .generate_wgsl_for_output(output_id, scalar_ty)
401 .ok_or_else(|| {
402 anyhow!("fusion: WGSL generation failed for output {output_id}")
403 })?;
404 let output =
405 provider.fused_elementwise(&single_shader, &handles, &output_shape, len)?;
406 fusion_residency::mark(&output);
407 outputs.insert(output_id, Value::GpuTensor(output));
408 }
409 }
410 }
411 }
412 timer.mark("dispatch");
413
414 for input in prepared {
416 if let Some(handle) = input.owned {
417 let _ = provider.free(&handle);
418 }
419 }
420 timer.mark("cleanup");
421 timer.finish();
422
423 Ok(outputs)
424}
425
426pub fn execute_elementwise(
427 request: FusionExecutionRequest<'_>,
428) -> Result<ElementwiseExecutionResult> {
429 let final_output = request
430 .plan
431 .output
432 .ok_or_else(|| anyhow!("fusion: missing final output value id"))?;
433 let mut output_ids = vec![final_output];
434 for store in &request.plan.materialized_stores {
435 if !output_ids.contains(&store.value_id) {
436 output_ids.push(store.value_id);
437 }
438 }
439 let mut outputs = execute_elementwise_outputs(&request, &output_ids)?;
440 let final_value = outputs
441 .remove(&final_output)
442 .ok_or_else(|| anyhow!("fusion: final output materialization missing"))?;
443 let materialized_stores = request
444 .plan
445 .materialized_stores
446 .iter()
447 .filter_map(|store| {
448 let value = if store.value_id == final_output {
451 Some(final_value.clone())
452 } else {
453 outputs.remove(&store.value_id)
454 };
455 value.map(|v| (store.clone(), v))
456 })
457 .collect();
458 Ok(ElementwiseExecutionResult {
459 final_value,
460 materialized_stores,
461 })
462}
463
464pub fn execute_reduction(
465 request: FusionExecutionRequest<'_>,
466 reduce_len: usize,
467 num_slices: usize,
468 workgroup_size: u32,
469) -> Result<Value> {
470 if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION").is_ok() {
471 return Err(anyhow!("fused reduction disabled by env"));
472 }
473 crate::ensure_residency_hooks();
474 if !request.plan.group.kind.is_reduction() {
475 return Err(anyhow!("unsupported fusion kind"));
476 }
477 let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
478 if !request.plan.kernel.supported {
479 return Err(anyhow!("fusion kernel not supported for this plan"));
480 }
481 if request.inputs.len() != request.plan.inputs.len() {
482 return Err(anyhow!(
483 "fusion input mismatch: expected {}, got {}",
484 request.plan.inputs.len(),
485 request.inputs.len()
486 ));
487 }
488 let len = reduce_len * num_slices;
489 if len == 0 {
490 return Err(anyhow!("fusion: zero-length execution not supported"));
491 }
492 let scalar_shape = {
493 let constant_shape = request.plan.constant_shape(len);
494 normalize_scalar_shape(&vec![1; constant_shape.len()])
495 };
496 let mut timer = FusionStageTimer::new("reduction", request.plan.index, len);
497 let mut prepared = Vec::with_capacity(request.inputs.len());
498 let mut temp_scalars: Vec<Vec<f64>> = Vec::new();
499 let scalar_dtype = scalar_upload_dtype(provider);
500 for value in &request.inputs {
501 match value {
502 Value::GpuTensor(handle) => prepared.push(PreparedInput {
503 handle: handle.clone(),
504 owned: None,
505 }),
506 Value::Tensor(t) => {
507 if let Err(msg) = ensure_provider_supports_dtype(provider, t.dtype) {
508 return Err(anyhow!(
509 "fusion: tensor input requires unsupported precision ({msg})"
510 ));
511 }
512 let view = HostTensorView {
513 data: &t.data,
514 shape: &t.shape,
515 };
516 let handle = provider.upload(&view)?;
517 prepared.push(PreparedInput {
518 handle: handle.clone(),
519 owned: Some(handle),
520 });
521 }
522 Value::Num(n) => {
523 if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
524 return Err(anyhow!(
525 "fusion: scalar input requires unsupported precision ({msg})"
526 ));
527 }
528 let scalar = match provider.precision() {
529 ProviderPrecision::F32 => (*n as f32) as f64,
530 ProviderPrecision::F64 => *n,
531 };
532 temp_scalars.push(vec![scalar]);
533 let data = temp_scalars.last().unwrap();
534 let view = HostTensorView {
535 data,
536 shape: &scalar_shape,
537 };
538 let handle = provider.upload(&view)?;
539 prepared.push(PreparedInput {
540 handle: handle.clone(),
541 owned: Some(handle),
542 });
543 }
544 Value::Int(i) => {
545 if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
546 return Err(anyhow!(
547 "fusion: scalar input requires unsupported precision ({msg})"
548 ));
549 }
550 let scalar = match provider.precision() {
551 ProviderPrecision::F32 => (i.to_f64() as f32) as f64,
552 ProviderPrecision::F64 => i.to_f64(),
553 };
554 temp_scalars.push(vec![scalar]);
555 let data = temp_scalars.last().unwrap();
556 let view = HostTensorView {
557 data,
558 shape: &scalar_shape,
559 };
560 let handle = provider.upload(&view)?;
561 prepared.push(PreparedInput {
562 handle: handle.clone(),
563 owned: Some(handle),
564 });
565 }
566 _ => return Err(anyhow!("fusion: unsupported value type")),
567 }
568 }
569 timer.mark("prepare_inputs");
570
571 let handles: Vec<GpuTensorHandle> = prepared.iter().map(|p| p.handle.clone()).collect();
572 let output_shape = vec![num_slices];
573
574 let scalar_ty = match provider.precision() {
575 ProviderPrecision::F32 => "f32",
576 ProviderPrecision::F64 => "f64",
577 };
578 let shader = request
579 .plan
580 .generate_reduction_wgsl(scalar_ty)
581 .ok_or_else(|| anyhow!("fusion: reduction WGSL generation failed"))?;
582 timer.mark("generate_wgsl");
583 if std::env::var("RUNMAT_DEBUG_DUMP_FUSED_WGSL").is_ok() {
584 println!(
585 "---- fused reduction WGSL ----\n{}\n------------------------------",
586 shader
587 );
588 }
589
590 let mut wg = if workgroup_size == 0 {
591 provider.default_reduction_workgroup_size()
592 } else {
593 workgroup_size
594 };
595 if let Ok(raw) = std::env::var("RUNMAT_FUSED_WG") {
596 if let Ok(val) = raw.trim().parse::<u32>() {
597 if val > 0 {
598 let capped = val.min(provider.default_reduction_workgroup_size());
599 wg = capped.max(1);
600 }
601 }
602 }
603 let flavor = request
604 .plan
605 .reduction_flavor
606 .unwrap_or(ReductionFlavor::Sum);
607 let output = provider.fused_reduction(
608 &shader,
609 &handles,
610 &output_shape,
611 reduce_len,
612 num_slices,
613 wg,
614 flavor,
615 )?;
616 timer.mark("dispatch");
617 fusion_residency::mark(&output);
618
619 for input in prepared {
620 if let Some(handle) = input.owned {
621 let _ = provider.free(&handle);
622 }
623 }
624 timer.mark("cleanup");
625 timer.finish();
626
627 Ok(Value::GpuTensor(output))
628}
629
630pub async fn execute_centered_gram(request: FusionExecutionRequest<'_>) -> Result<Value> {
631 crate::ensure_residency_hooks();
632 if request.plan.group.kind != FusionKind::CenteredGram {
633 return Err(anyhow!("unsupported fusion kind"));
634 }
635 let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
636 let (matrix_vid, normalization) = match request.plan.pattern.as_ref() {
637 Some(FusionPattern::CenteredGram {
638 matrix,
639 normalization,
640 }) => (*matrix, *normalization),
641 _ => return Err(anyhow!("centered gram: missing pattern metadata")),
642 };
643
644 let matrix_index = request
645 .plan
646 .inputs
647 .iter()
648 .position(|vid| *vid == matrix_vid)
649 .ok_or_else(|| anyhow!("centered gram: matrix input not found"))?;
650 let matrix_value = request
651 .inputs
652 .get(matrix_index)
653 .ok_or_else(|| anyhow!("centered gram: matrix value missing"))?;
654
655 let (matrix_handle, owned_matrix) = ensure_gpu_tensor(provider, matrix_value)?;
656
657 let options = CovarianceOptions {
658 normalization,
659 rows: CovRows::All,
660 has_weight_vector: false,
661 };
662
663 let output = provider
664 .covariance(&matrix_handle, None, None, &options)
665 .await?;
666
667 if let Some(temp) = owned_matrix {
668 let _ = provider.free(&temp);
669 }
670
671 fusion_residency::mark(&output);
672 Ok(Value::GpuTensor(output))
673}
674
675pub async fn execute_power_step_normalize(request: FusionExecutionRequest<'_>) -> Result<Value> {
676 crate::ensure_residency_hooks();
677 if request.plan.group.kind != FusionKind::PowerStepNormalize {
678 return Err(anyhow!("unsupported fusion kind"));
679 }
680 let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
681 let (lhs_vid, rhs_vid, epsilon) = match request.plan.pattern.as_ref() {
682 Some(FusionPattern::PowerStepNormalize { lhs, rhs, epsilon }) => (*lhs, *rhs, *epsilon),
683 _ => {
684 return Err(anyhow!(
685 "power-step normalization: missing pattern metadata"
686 ))
687 }
688 };
689
690 let lhs_index = request
691 .plan
692 .inputs
693 .iter()
694 .position(|vid| *vid == lhs_vid)
695 .ok_or_else(|| anyhow!("power-step normalization: lhs input not found"))?;
696 let rhs_index = request
697 .plan
698 .inputs
699 .iter()
700 .position(|vid| *vid == rhs_vid)
701 .ok_or_else(|| anyhow!("power-step normalization: rhs input not found"))?;
702
703 let lhs_value = request
704 .inputs
705 .get(lhs_index)
706 .ok_or_else(|| anyhow!("power-step normalization: lhs value missing"))?;
707 let rhs_value = request
708 .inputs
709 .get(rhs_index)
710 .ok_or_else(|| anyhow!("power-step normalization: rhs value missing"))?;
711
712 let (lhs_handle, lhs_owned) = ensure_gpu_tensor(provider, lhs_value)?;
713 let (rhs_handle, rhs_owned) = ensure_gpu_tensor(provider, rhs_value)?;
714
715 let desc = PowerStepEpilogue { epsilon };
716 let output = provider
717 .matmul_power_step(&lhs_handle, &rhs_handle, &desc)
718 .await?;
719
720 if let Some(temp) = lhs_owned {
721 let _ = provider.free(&temp);
722 }
723 if let Some(temp) = rhs_owned {
724 let _ = provider.free(&temp);
725 }
726
727 fusion_residency::mark(&output);
728 Ok(Value::GpuTensor(output))
729}
730
731pub async fn execute_explained_variance(request: FusionExecutionRequest<'_>) -> Result<Value> {
732 crate::ensure_residency_hooks();
733 if request.plan.group.kind != FusionKind::ExplainedVariance {
734 return Err(anyhow!("unsupported fusion kind"));
735 }
736 let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
737 let (q_vid, g_vid) = match request.plan.pattern.as_ref() {
738 Some(FusionPattern::ExplainedVariance { q, g }) => (*q, *g),
739 _ => return Err(anyhow!("explained variance: missing pattern metadata")),
740 };
741
742 let find_value = |vid: ValueId| -> Result<&Value> {
743 if let Some(pos) = request.plan.inputs.iter().position(|id| *id == vid) {
744 request
745 .inputs
746 .get(pos)
747 .ok_or_else(|| anyhow!("explained variance: missing runtime value"))
748 } else {
749 request
750 .plan
751 .const_values
752 .get(&vid)
753 .ok_or_else(|| anyhow!("explained variance: value not materialized"))
754 }
755 };
756
757 let q_value = find_value(q_vid)?;
758 let g_value = find_value(g_vid)?;
759
760 let (mut q_handle, q_owned) = ensure_gpu_tensor(provider, q_value)?;
761 let (g_handle, g_owned) = ensure_gpu_tensor(provider, g_value)?;
762
763 let debug_explained = std::env::var("RUNMAT_DEBUG_EXPLAINED").is_ok();
764 if debug_explained {
765 println!(
766 "[explained] initial Q shape {:?}, G shape {:?}",
767 q_handle.shape, g_handle.shape
768 );
769 if let Ok(info) = provider.download(&q_handle).await {
770 println!(
771 "[explained] Q (sample) len={} first=[{:?}]",
772 info.data.len(),
773 info.data.get(0..4)
774 );
775 }
776 }
777
778 let q_shape = q_handle.shape.clone();
779 if q_shape.len() < 2 {
780 return Err(anyhow!("explained variance: Q must be 2-D"));
781 }
782 let q_rows = q_shape[0];
783 let q_cols = q_shape[1];
784 if q_rows == 0 || q_cols == 0 {
785 return Err(anyhow!("explained variance: zero-sized Q"));
786 }
787
788 let g_shape = g_handle.shape.clone();
789 if g_shape.len() < 2 {
790 return Err(anyhow!("explained variance: G must be 2-D"));
791 }
792 if g_shape[0] != q_rows || g_shape[1] != q_rows {
793 return Err(anyhow!("explained variance: G shape mismatch"));
794 }
795
796 let mut tmp = provider.matmul(&q_handle, &g_handle).await?;
797 let tmp_shape = tmp.shape.clone();
798 if tmp_shape.len() < 2 {
799 return Err(anyhow!("explained variance: intermediate must be 2-D"));
800 }
801 if tmp_shape[0] != q_cols {
802 return Err(anyhow!(
803 "explained variance: expected intermediate rows {}, got {}",
804 q_cols,
805 tmp_shape[0]
806 ));
807 }
808
809 if debug_explained {
810 println!("[explained] after Q*G tmp shape {:?}", tmp.shape);
811 }
812
813 let mut transposed_shape = q_shape.clone();
817 transposed_shape.swap(0, 1);
818 let q_transposed_view = provider.reshape(&q_handle, &transposed_shape)?;
819
820 tmp = provider.matmul(&q_transposed_view, &g_handle).await?;
821
822 if debug_explained {
823 println!(
824 "[explained] after reshape(matmul) tmp shape {:?}",
825 tmp.shape
826 );
827 }
828
829 q_handle = provider.reshape(&q_handle, &q_shape)?;
831
832 let product = provider.matmul(&tmp, &q_handle).await?;
833
834 if debug_explained {
835 println!("[explained] product shape {:?}", product.shape);
836 }
837
838 let diag = provider.diag_extract(&product, 0)?;
839 let diag = match diag.shape.as_slice() {
840 [len] => provider.reshape(&diag, &[*len, 1])?,
841 [_len, 1] => diag,
842 _ => diag,
843 };
844
845 if debug_explained {
846 if let Ok(host) = provider.download(&tmp).await {
847 println!("tmp runtime shape {:?} data {:?}", host.shape, host.data);
848 }
849 if let Ok(host) = provider.download(&product).await {
850 println!("prod runtime shape {:?} data {:?}", host.shape, host.data);
851 }
852 if let Ok(host) = provider.download(&diag).await {
853 println!("diag runtime shape {:?} data {:?}", host.shape, host.data);
854 }
855 }
856
857 let _ = provider.free(&tmp);
858 let _ = provider.free(&product);
859 if let Some(temp) = q_owned {
860 let _ = provider.free(&temp);
861 }
862 if let Some(temp) = g_owned {
863 let _ = provider.free(&temp);
864 }
865
866 fusion_residency::mark(&diag);
867 Ok(Value::GpuTensor(diag))
868}
869
870pub async fn execute_image_normalize(request: FusionExecutionRequest<'_>) -> Result<Value> {
871 crate::ensure_residency_hooks();
872 if request.plan.group.kind != FusionKind::ImageNormalize {
873 return Err(anyhow!("unsupported fusion kind"));
874 }
875 let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
876 let pattern = match request.plan.pattern.as_ref() {
877 Some(FusionPattern::ImageNormalize(p)) => p,
878 _ => return Err(anyhow!("image normalize: missing pattern metadata")),
879 };
880 if log::log_enabled!(log::Level::Debug) {
881 log::debug!(
882 "execute_image_normalize: plan inputs={:?} stack={:?}",
883 request.plan.inputs,
884 request.plan.stack_pattern
885 );
886 }
887
888 let find_value = |vid: ValueId| -> Result<&Value> {
889 if let Some(pos) = request.plan.inputs.iter().position(|id| *id == vid) {
890 request
891 .inputs
892 .get(pos)
893 .ok_or_else(|| anyhow!("image normalize: runtime value missing"))
894 } else {
895 request
896 .plan
897 .const_values
898 .get(&vid)
899 .ok_or_else(|| anyhow!("image normalize: value {vid:?} not materialized"))
900 }
901 };
902
903 let input_value = find_value(pattern.input)?;
904 let (input_handle, input_owned) = ensure_gpu_tensor(provider, input_value)?;
905 let shape = input_handle.shape.clone();
906 if shape.len() != 3 {
907 return Err(anyhow!(
908 "image normalize: expected 3-D input tensor, got shape {:?}",
909 shape
910 ));
911 }
912 let batch = shape[0];
913 let height = shape[1];
914 let width = shape[2];
915
916 let epsilon = resolve_image_scalar_value(&pattern.epsilon, request.plan, &request)?;
917 let gain = match &pattern.gain {
918 Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
919 None => None,
920 };
921 let bias = match &pattern.bias {
922 Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
923 None => None,
924 };
925 let gamma = match &pattern.gamma {
926 Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
927 None => None,
928 };
929
930 let desc = ImageNormalizeDescriptor {
931 batch,
932 height,
933 width,
934 epsilon,
935 gain,
936 bias,
937 gamma,
938 };
939 if log::log_enabled!(log::Level::Trace) {
940 log::trace!("execute_image_normalize: desc {:?}", desc);
941 }
942
943 let output = provider.image_normalize(&input_handle, &desc).await?;
944
945 if let Some(temp) = input_owned {
946 provider.free(&temp).ok();
947 }
948
949 fusion_residency::mark(&output);
950 Ok(Value::GpuTensor(output))
951}
952
953pub async fn execute_matmul_epilogue(request: FusionExecutionRequest<'_>) -> Result<Value> {
954 crate::ensure_residency_hooks();
955 if request.plan.group.kind != crate::fusion::FusionKind::MatmulEpilogue {
956 return Err(anyhow!("unsupported fusion kind"));
957 }
958 let prov = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
959
960 let mut prepared: Vec<(graph::ValueId, GpuTensorHandle, Option<GpuTensorHandle>)> = Vec::new();
962 let mut owned: Vec<GpuTensorHandle> = Vec::new();
963 for (idx, vid) in request.plan.inputs.iter().copied().enumerate() {
964 let v = request
965 .inputs
966 .get(idx)
967 .ok_or_else(|| anyhow!("fusion: missing input value"))?;
968 let handle = match v {
969 Value::GpuTensor(h) => h.clone(),
970 Value::Tensor(t) => {
971 let view = HostTensorView {
972 data: &t.data,
973 shape: &t.shape,
974 };
975 let h = prov.upload(&view)?;
976 owned.push(h.clone());
977 h
978 }
979 _ => return Err(anyhow!("matmul_epilogue: unsupported input value kind")),
980 };
981 prepared.push((vid, handle.clone(), None));
982 }
983
984 let find_handle = |vid: graph::ValueId| -> Option<GpuTensorHandle> {
986 prepared
987 .iter()
988 .find_map(|(v, h, _)| if *v == vid { Some(h.clone()) } else { None })
989 };
990
991 let mut cur_out: Option<graph::ValueId> = None;
993 let mut a_vid: Option<graph::ValueId> = None;
994 let mut b_vid: Option<graph::ValueId> = None;
995 for op in &request.plan.operations {
996 if let crate::fusion::FusionOp::Builtin {
997 name,
998 inputs,
999 output,
1000 } = op
1001 {
1002 if name.eq_ignore_ascii_case("mtimes") {
1003 a_vid = inputs.first().copied();
1004 b_vid = inputs.get(1).copied();
1005 cur_out = *output;
1006 break;
1007 }
1008 }
1009 }
1010 let (a_vid, b_vid, mut cur) = (
1011 a_vid.ok_or_else(|| anyhow!("mtimes not found"))?,
1012 b_vid.ok_or_else(|| anyhow!("mtimes not found"))?,
1013 cur_out.ok_or_else(|| anyhow!("mtimes output missing"))?,
1014 );
1015
1016 let mut alpha: f64 = 1.0;
1018 let mut beta: f64 = 0.0;
1019 let mut row_scale: Option<GpuTensorHandle> = None;
1020 let mut col_scale: Option<GpuTensorHandle> = None;
1021 let mut clamp_min: Option<f64> = None;
1022 let mut clamp_max: Option<f64> = None;
1023 let mut pow_exponent: Option<f64> = None;
1024 let mut row_div = false;
1025 let mut col_div = false;
1026 let mut diag_vid: Option<graph::ValueId> = None;
1027
1028 for op in &request.plan.operations {
1029 match op {
1030 crate::fusion::FusionOp::Primitive { op, inputs, output } => {
1031 let Some(out) = output else { continue };
1032 if !inputs.contains(&cur) {
1033 continue;
1034 }
1035 let other = if inputs[0] == cur {
1036 inputs[1]
1037 } else {
1038 inputs[0]
1039 };
1040 let const_opt = request.plan.const_values.get(&other);
1041 let const_f64 = const_opt.and_then(value_to_f64);
1042 match op {
1043 crate::graph::PrimitiveOp::Mul | crate::graph::PrimitiveOp::ElemMul => {
1044 if let Some(val) = const_f64 {
1045 alpha *= val;
1046 } else if row_scale.is_none() || col_scale.is_none() {
1047 if let Some(h) = find_handle(other) {
1048 let r = h.shape.first().copied().unwrap_or(1);
1049 let c = h.shape.get(1).copied().unwrap_or(1);
1050 if c == 1 && row_scale.is_none() {
1051 row_scale = Some(h);
1052 row_div = false;
1053 } else if r == 1 && col_scale.is_none() {
1054 col_scale = Some(h);
1055 col_div = false;
1056 }
1057 }
1058 }
1059 }
1060 crate::graph::PrimitiveOp::ElemDiv => {
1061 if let Some(val) = const_f64 {
1062 if val != 0.0 {
1063 alpha *= 1.0 / val;
1064 }
1065 } else if row_scale.is_none() || col_scale.is_none() {
1066 if let Some(h) = find_handle(other) {
1067 let r = h.shape.first().copied().unwrap_or(1);
1068 let c = h.shape.get(1).copied().unwrap_or(1);
1069 if c == 1 && row_scale.is_none() {
1070 row_scale = Some(h);
1071 row_div = true;
1072 } else if r == 1 && col_scale.is_none() {
1073 col_scale = Some(h);
1074 col_div = true;
1075 }
1076 }
1077 }
1078 }
1079 crate::graph::PrimitiveOp::Add => {
1080 if let Some(val) = const_f64 {
1081 beta += val;
1082 }
1083 }
1084 crate::graph::PrimitiveOp::Sub => {
1085 if let Some(val) = const_f64 {
1086 beta -= val;
1087 }
1088 }
1089 crate::graph::PrimitiveOp::Pow | crate::graph::PrimitiveOp::ElemPow => {
1090 if pow_exponent.is_none() && inputs[0] == cur {
1091 pow_exponent = const_f64;
1092 }
1093 }
1094 _ => {}
1095 }
1096 cur = *out;
1097 }
1098 crate::fusion::FusionOp::Builtin {
1099 name,
1100 inputs,
1101 output,
1102 } => {
1103 let Some(out) = output else { continue };
1104 if !inputs.contains(&cur) {
1105 continue;
1106 }
1107 let lower = name.to_ascii_lowercase();
1108 if lower == "max" || lower == "min" {
1109 if let Some(&other) = inputs.iter().find(|&&v| v != cur) {
1110 if let Some(val) =
1111 request.plan.const_values.get(&other).and_then(value_to_f64)
1112 {
1113 if lower == "max" {
1114 clamp_min = Some(clamp_min.map_or(val, |prev| prev.max(val)));
1115 } else {
1116 clamp_max = Some(clamp_max.map_or(val, |prev| prev.min(val)));
1117 }
1118 }
1119 }
1120 } else if lower == "pow" && pow_exponent.is_none() {
1121 if let Some(&other) = inputs.iter().find(|&&v| v != cur) {
1122 if let Some(val) =
1123 request.plan.const_values.get(&other).and_then(value_to_f64)
1124 {
1125 pow_exponent = Some(val);
1126 }
1127 }
1128 } else if lower == "diag" {
1129 diag_vid = Some(*out);
1130 }
1131 cur = *out;
1132 }
1133 }
1134 }
1135
1136 let mut ep = runmat_accelerate_api::MatmulEpilogue::noop();
1138 ep.alpha = alpha;
1139 ep.beta = beta;
1140 ep.clamp_min = clamp_min;
1141 ep.clamp_max = clamp_max;
1142 ep.pow_exponent = pow_exponent;
1143 ep.row_op = if row_div {
1144 runmat_accelerate_api::ScaleOp::Divide
1145 } else {
1146 runmat_accelerate_api::ScaleOp::Multiply
1147 };
1148 ep.col_op = if col_div {
1149 runmat_accelerate_api::ScaleOp::Divide
1150 } else {
1151 runmat_accelerate_api::ScaleOp::Multiply
1152 };
1153 if let Some(h) = row_scale.clone() {
1154 ep.row_scale = Some(h);
1155 }
1156 if let Some(h) = col_scale.clone() {
1157 ep.col_scale = Some(h);
1158 }
1159
1160 let a = find_handle(a_vid).ok_or_else(|| anyhow!("missing A"))?;
1161 let b = find_handle(b_vid).ok_or_else(|| anyhow!("missing B"))?;
1162
1163 let mut diag_handle: Option<(graph::ValueId, GpuTensorHandle)> = None;
1164 if let Some(vid) = diag_vid {
1165 let diag_len = std::cmp::min(
1166 a.shape.first().copied().unwrap_or(0),
1167 b.shape.get(1).copied().unwrap_or(0),
1168 );
1169 let mut diag_shape = vec![diag_len, 1];
1170 if diag_len == 0 {
1171 diag_shape[1] = 1;
1172 }
1173 let handle = prov.zeros(&diag_shape)?;
1174 ep.diag_output = Some(handle.clone());
1175 diag_handle = Some((vid, handle));
1176 }
1177
1178 let out = prov.matmul_epilogue(&a, &b, &ep).await?;
1179 for h in owned {
1180 let _ = prov.free(&h);
1181 }
1182
1183 if let Some((_, diag)) = &diag_handle {
1184 fusion_residency::mark(diag);
1185 }
1186
1187 let final_vid = request.plan.output.or(Some(cur));
1188 let mut result = out.clone();
1189 let mut free_out = false;
1190 if let Some((vid, diag)) = &diag_handle {
1191 if Some(*vid) == final_vid {
1192 result = diag.clone();
1193 free_out = true;
1194 }
1195 }
1196
1197 if free_out {
1198 let _ = prov.free(&out);
1199 } else {
1200 fusion_residency::mark(&out);
1201 }
1202
1203 fusion_residency::mark(&result);
1204 Ok(Value::GpuTensor(result))
1205}