1use anyhow::{anyhow, Result};
2
3use crate::fusion::{FusionGroupPlan, FusionKind, FusionPattern, ImageScalar};
4use crate::fusion_residency;
5use crate::graph;
6use crate::graph::{ShapeInfo, ValueId};
7use crate::precision::ensure_provider_supports_dtype;
8use log;
9use runmat_accelerate_api::{
10 provider, AccelProvider, CovRows, CovarianceOptions, GpuTensorHandle, HostTensorView,
11 ImageNormalizeDescriptor, PowerStepEpilogue, ProviderPrecision, ReductionFlavor,
12};
13use runmat_builtins::{NumericDType, Value};
14use runmat_runtime::gather_if_needed;
15use std::sync::OnceLock;
16use std::time::Instant;
17
/// A fusion input resolved to a GPU handle, plus ownership bookkeeping.
///
/// `handle` is what gets bound to the fused kernel. `owned` carries the same
/// handle when this execution uploaded the data itself (and must free it
/// after dispatch); it is `None` when the caller supplied an already
/// GPU-resident tensor that we must not free.
struct PreparedInput {
    handle: GpuTensorHandle,
    owned: Option<GpuTensorHandle>,
}
22
/// A request to execute one compiled fusion group against runtime values.
pub struct FusionExecutionRequest<'a> {
    /// Compiled plan describing the fused operations, expected inputs,
    /// kernel metadata, and any constant values captured at plan time.
    pub plan: &'a FusionGroupPlan,
    /// Runtime values, positionally aligned with `plan.inputs`.
    pub inputs: Vec<Value>,
}
27
/// Whether per-stage fusion timing was requested via `RUNMAT_FUSION_TIMING`.
///
/// The environment variable is consulted exactly once; the parsed answer is
/// cached for the lifetime of the process. Accepted truthy spellings are
/// "1", "true", "yes", and "on" (case-insensitive, surrounding whitespace
/// ignored); anything else — including an unset variable — is false.
#[inline]
fn fusion_timing_enabled() -> bool {
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("RUNMAT_FUSION_TIMING")
            .map(|raw| {
                let normalized = raw.trim().to_ascii_lowercase();
                ["1", "true", "yes", "on"].contains(&normalized.as_str())
            })
            .unwrap_or(false)
    })
}
39
/// Opt-in stage timer for fusion execution.
///
/// `inner` is `None` unless timing was enabled at construction time, so the
/// hot path pays only an `Option` check per `mark`.
struct FusionStageTimer {
    inner: Option<FusionStageTimerInner>,
}
43
/// Live state of an active `FusionStageTimer`.
struct FusionStageTimerInner {
    // Index of the plan being executed (for the log line).
    plan_index: usize,
    // Short label for the fusion kind, e.g. "elementwise" or "reduction".
    kind: &'static str,
    // Element count of the dispatch, logged for context.
    len: usize,
    // Construction time; used for the total duration.
    start: Instant,
    // Time of the previous `mark`, advanced on each call.
    last: Instant,
    // (stage label, elapsed milliseconds since previous mark) pairs.
    stages: Vec<(&'static str, f64)>,
}
52
53impl FusionStageTimer {
54 fn new(kind: &'static str, plan_index: usize, len: usize) -> Self {
55 if fusion_timing_enabled() && log::log_enabled!(log::Level::Debug) {
56 let now = Instant::now();
57 Self {
58 inner: Some(FusionStageTimerInner {
59 plan_index,
60 kind,
61 len,
62 start: now,
63 last: now,
64 stages: Vec::new(),
65 }),
66 }
67 } else {
68 Self { inner: None }
69 }
70 }
71
72 fn mark(&mut self, label: &'static str) {
73 if let Some(inner) = &mut self.inner {
74 let now = Instant::now();
75 let delta = now.duration_since(inner.last).as_secs_f64() * 1000.0;
76 inner.stages.push((label, delta));
77 inner.last = now;
78 }
79 }
80
81 fn finish(self) {
82 if let Some(inner) = self.inner {
83 let total = inner.start.elapsed().as_secs_f64() * 1000.0;
84 let summary = inner
85 .stages
86 .into_iter()
87 .map(|(label, ms)| format!("{label}={ms:.3}ms"))
88 .collect::<Vec<_>>()
89 .join(" ");
90 log::debug!(
91 "fusion timing plan={} kind={} len={} {} total={:.3}ms",
92 inner.plan_index,
93 inner.kind,
94 inner.len,
95 summary,
96 total
97 );
98 }
99 }
100}
101
102fn ensure_gpu_tensor(
103 provider: &dyn AccelProvider,
104 value: &Value,
105) -> Result<(GpuTensorHandle, Option<GpuTensorHandle>)> {
106 match value {
107 Value::GpuTensor(handle) => Ok((handle.clone(), None)),
108 Value::Tensor(tensor) => {
109 let view = HostTensorView {
110 data: &tensor.data,
111 shape: &tensor.shape,
112 };
113 let handle = provider.upload(&view)?;
114 Ok((handle.clone(), Some(handle)))
115 }
116 _ => Err(anyhow!("fusion: expected tensor input")),
117 }
118}
119
120fn scalar_upload_dtype(provider: &dyn AccelProvider) -> NumericDType {
121 match provider.precision() {
122 ProviderPrecision::F32 => NumericDType::F32,
123 ProviderPrecision::F64 => NumericDType::F64,
124 }
125}
126
127fn value_to_f64(value: &Value) -> Option<f64> {
128 match value {
129 Value::Num(n) => Some(*n),
130 Value::Int(i) => Some(i.to_f64()),
131 _ => None,
132 }
133}
134
135fn scalar_from_value(value: &Value) -> Result<f64> {
136 if let Some(v) = value_to_f64(value) {
137 return Ok(v);
138 }
139 match value {
140 Value::Tensor(t) => {
141 if t.data.len() == 1 {
142 Ok(t.data[0])
143 } else {
144 Err(anyhow!(
145 "image normalize: expected scalar tensor, got {} elements",
146 t.data.len()
147 ))
148 }
149 }
150 Value::GpuTensor(_) => {
151 let gathered = gather_if_needed(value).map_err(|e| anyhow!("image normalize: {e}"))?;
152 scalar_from_value(&gathered)
153 }
154 _ => Err(anyhow!(
155 "image normalize: expected numeric scalar value, got {:?}",
156 value
157 )),
158 }
159}
160
161fn resolve_image_scalar_value(
162 scalar: &ImageScalar,
163 plan: &FusionGroupPlan,
164 request: &FusionExecutionRequest<'_>,
165) -> Result<f64> {
166 match scalar {
167 ImageScalar::Constant(v) => Ok(*v),
168 ImageScalar::Value(vid) => {
169 if let Some(value) = plan.const_values.get(vid) {
170 return scalar_from_value(value);
171 }
172 if let Some(idx) = plan.inputs.iter().position(|id| *id == *vid) {
173 let runtime_value = request
174 .inputs
175 .get(idx)
176 .ok_or_else(|| anyhow!("image normalize: runtime scalar missing"))?;
177 return scalar_from_value(runtime_value);
178 }
179 Err(anyhow!(
180 "image normalize: scalar input {:?} not materialized in plan",
181 vid
182 ))
183 }
184 }
185}
186
/// Execute a fused elementwise plan on the GPU.
///
/// Uploads any host-resident inputs (tensors and scalars), generates the
/// plan's WGSL shader at the provider's precision, dispatches it, then frees
/// the temporary uploads. Returns the result as a resident GPU tensor.
///
/// # Errors
/// Fails when no provider is registered, the plan is unsupported, inputs do
/// not match the plan, the output length is zero, or any provider call fails.
pub fn execute_elementwise(request: FusionExecutionRequest<'_>) -> Result<Value> {
    crate::ensure_residency_hooks();
    if !request.plan.group.kind.is_elementwise() {
        return Err(anyhow!("unsupported fusion kind"));
    }
    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
    if !request.plan.kernel.supported {
        return Err(anyhow!("fusion kernel not supported for this plan"));
    }
    // Runtime inputs must line up positionally with the plan's input ids.
    if request.inputs.len() != request.plan.inputs.len() {
        return Err(anyhow!(
            "fusion input mismatch: expected {}, got {}",
            request.plan.inputs.len(),
            request.inputs.len()
        ));
    }
    // Trailing-aligned broadcast of all input shapes; plain scalars
    // contribute an empty shape. Returns None for an incompatible pair or an
    // unsupported value kind.
    fn runtime_broadcast_shape(values: &[Value]) -> Option<Vec<usize>> {
        let mut shapes: Vec<Vec<usize>> = Vec::new();
        for v in values {
            match v {
                Value::GpuTensor(h) => shapes.push(h.shape.clone()),
                Value::Tensor(t) => shapes.push(t.shape.clone()),
                Value::Num(_) | Value::Int(_) => shapes.push(Vec::new()),
                _ => return None, }
        }
        let rank = shapes.iter().map(|s| s.len()).max().unwrap_or(0);
        let mut out = vec![1usize; rank];
        for shape in shapes {
            // Right-align shorter shapes against the maximum rank.
            let offset = rank.saturating_sub(shape.len());
            for (i, &dim) in shape.iter().enumerate() {
                let j = offset + i;
                let a = out[j];
                let b = dim;
                if a == 1 {
                    out[j] = b.max(1);
                } else if b == 1 || a == b {
                } else {
                    return None; }
            }
        }
        Some(out)
    }
    // Prefer the statically planned shape; fall back to a runtime broadcast
    // when the plan did not resolve one.
    let mut output_shape = match &request.plan.group.shape {
        ShapeInfo::Tensor(dims) if !dims.is_empty() => {
            let resolved: Vec<usize> = dims.iter().map(|d| d.unwrap_or(1)).collect();
            resolved
        }
        _ => {
            runtime_broadcast_shape(&request.inputs)
                .ok_or_else(|| anyhow!("fusion: unknown output shape"))?
        }
    };
    let mut len: usize = output_shape.iter().copied().product();
    if len == 0 {
        // A zero static length may just mean unresolved planned dims; retry
        // once with shapes observed at runtime before giving up.
        if let Some(rt_shape) = runtime_broadcast_shape(&request.inputs) {
            output_shape = rt_shape;
            len = output_shape.iter().copied().product();
        }
        if len == 0 {
            return Err(anyhow!("fusion: zero-length execution not supported"));
        }
    }
    let mut timer = FusionStageTimer::new("elementwise", request.plan.index, len);
    // Shape used when uploading host scalars: all-ones at the output's rank.
    let scalar_shape: Vec<usize> = if output_shape.is_empty() {
        vec![1]
    } else {
        vec![1; output_shape.len()]
    };
    let mut prepared = Vec::with_capacity(request.inputs.len());
    // Staging storage that keeps each scalar's data alive across its upload.
    let mut temp_scalars: Vec<Vec<f64>> = Vec::new();
    let scalar_dtype = scalar_upload_dtype(provider);
    for value in &request.inputs {
        match value {
            Value::GpuTensor(handle) => prepared.push(PreparedInput {
                handle: handle.clone(),
                owned: None,
            }),
            Value::Tensor(t) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, t.dtype) {
                    return Err(anyhow!(
                        "fusion: tensor input requires unsupported precision ({msg})"
                    ));
                }
                let view = HostTensorView {
                    data: &t.data,
                    shape: &t.shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            Value::Num(n) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
                    return Err(anyhow!(
                        "fusion: scalar input requires unsupported precision ({msg})"
                    ));
                }
                // Round-trip through f32 so the uploaded scalar matches what
                // an f32 provider will actually compute with.
                let scalar = match provider.precision() {
                    ProviderPrecision::F32 => (*n as f32) as f64,
                    ProviderPrecision::F64 => *n,
                };
                temp_scalars.push(vec![scalar]);
                let data = temp_scalars.last().unwrap();
                let view = HostTensorView {
                    data,
                    shape: &scalar_shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            Value::Int(i) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
                    return Err(anyhow!(
                        "fusion: scalar input requires unsupported precision ({msg})"
                    ));
                }
                let scalar = match provider.precision() {
                    ProviderPrecision::F32 => (i.to_f64() as f32) as f64,
                    ProviderPrecision::F64 => i.to_f64(),
                };
                temp_scalars.push(vec![scalar]);
                let data = temp_scalars.last().unwrap();
                let view = HostTensorView {
                    data,
                    shape: &scalar_shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            _ => {
                return Err(anyhow!("fusion: unsupported value type"));
            }
        }
    }
    timer.mark("prepare_inputs");

    // Generate the fused shader at the provider's scalar precision.
    let scalar_ty = match provider.precision() {
        ProviderPrecision::F32 => "f32",
        ProviderPrecision::F64 => "f64",
    };
    let shader = request
        .plan
        .generate_wgsl(scalar_ty)
        .ok_or_else(|| anyhow!("fusion: WGSL generation failed"))?;
    timer.mark("generate_wgsl");

    let handles: Vec<GpuTensorHandle> = prepared.iter().map(|p| p.handle.clone()).collect();
    let output = provider.fused_elementwise(&shader, &handles, &output_shape, len)?;
    timer.mark("dispatch");
    fusion_residency::mark(&output);

    // Free only the handles this function uploaded; caller-owned GPU
    // tensors are left untouched.
    for input in prepared {
        if let Some(handle) = input.owned {
            let _ = provider.free(&handle);
        }
    }
    timer.mark("cleanup");
    timer.finish();

    Ok(Value::GpuTensor(output))
}
364
/// Execute a fused reduction plan on the GPU.
///
/// `reduce_len` is the length of each reduced slice, `num_slices` the number
/// of independent slices (the output length). `workgroup_size` of 0 means
/// "use the provider default"; `RUNMAT_FUSED_WG` can override it (capped at
/// the provider default). `RUNMAT_DISABLE_FUSED_REDUCTION` disables this
/// path entirely, and `RUNMAT_DEBUG_DUMP_FUSED_WGSL` dumps the shader.
///
/// # Errors
/// Fails when disabled by env, no provider is registered, the plan is
/// unsupported, inputs mismatch, the total length is zero, or any provider
/// call fails.
pub fn execute_reduction(
    request: FusionExecutionRequest<'_>,
    reduce_len: usize,
    num_slices: usize,
    workgroup_size: u32,
) -> Result<Value> {
    if std::env::var("RUNMAT_DISABLE_FUSED_REDUCTION").is_ok() {
        return Err(anyhow!("fused reduction disabled by env"));
    }
    crate::ensure_residency_hooks();
    if !request.plan.group.kind.is_reduction() {
        return Err(anyhow!("unsupported fusion kind"));
    }
    let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
    if !request.plan.kernel.supported {
        return Err(anyhow!("fusion kernel not supported for this plan"));
    }
    // Runtime inputs must line up positionally with the plan's input ids.
    if request.inputs.len() != request.plan.inputs.len() {
        return Err(anyhow!(
            "fusion input mismatch: expected {}, got {}",
            request.plan.inputs.len(),
            request.inputs.len()
        ));
    }
    let len = reduce_len * num_slices;
    if len == 0 {
        return Err(anyhow!("fusion: zero-length execution not supported"));
    }
    // Shape used when uploading host scalars: all-ones at the rank of the
    // plan's constant shape for this length.
    let scalar_shape: Vec<usize> = {
        let constant_shape = request.plan.constant_shape(len);
        if constant_shape.is_empty() {
            vec![1]
        } else {
            vec![1; constant_shape.len()]
        }
    };
    let mut timer = FusionStageTimer::new("reduction", request.plan.index, len);
    let mut prepared = Vec::with_capacity(request.inputs.len());
    // Staging storage that keeps each scalar's data alive across its upload.
    let mut temp_scalars: Vec<Vec<f64>> = Vec::new();
    let scalar_dtype = scalar_upload_dtype(provider);
    for value in &request.inputs {
        match value {
            Value::GpuTensor(handle) => prepared.push(PreparedInput {
                handle: handle.clone(),
                owned: None,
            }),
            Value::Tensor(t) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, t.dtype) {
                    return Err(anyhow!(
                        "fusion: tensor input requires unsupported precision ({msg})"
                    ));
                }
                let view = HostTensorView {
                    data: &t.data,
                    shape: &t.shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            Value::Num(n) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
                    return Err(anyhow!(
                        "fusion: scalar input requires unsupported precision ({msg})"
                    ));
                }
                // Round-trip through f32 so the uploaded scalar matches what
                // an f32 provider will actually compute with.
                let scalar = match provider.precision() {
                    ProviderPrecision::F32 => (*n as f32) as f64,
                    ProviderPrecision::F64 => *n,
                };
                temp_scalars.push(vec![scalar]);
                let data = temp_scalars.last().unwrap();
                let view = HostTensorView {
                    data,
                    shape: &scalar_shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            Value::Int(i) => {
                if let Err(msg) = ensure_provider_supports_dtype(provider, scalar_dtype) {
                    return Err(anyhow!(
                        "fusion: scalar input requires unsupported precision ({msg})"
                    ));
                }
                let scalar = match provider.precision() {
                    ProviderPrecision::F32 => (i.to_f64() as f32) as f64,
                    ProviderPrecision::F64 => i.to_f64(),
                };
                temp_scalars.push(vec![scalar]);
                let data = temp_scalars.last().unwrap();
                let view = HostTensorView {
                    data,
                    shape: &scalar_shape,
                };
                let handle = provider.upload(&view)?;
                prepared.push(PreparedInput {
                    handle: handle.clone(),
                    owned: Some(handle),
                });
            }
            _ => return Err(anyhow!("fusion: unsupported value type")),
        }
    }
    timer.mark("prepare_inputs");

    let handles: Vec<GpuTensorHandle> = prepared.iter().map(|p| p.handle.clone()).collect();
    // One output element per independent slice.
    let output_shape = vec![num_slices];

    // Generate the reduction shader at the provider's scalar precision.
    let scalar_ty = match provider.precision() {
        ProviderPrecision::F32 => "f32",
        ProviderPrecision::F64 => "f64",
    };
    let shader = request
        .plan
        .generate_reduction_wgsl(scalar_ty)
        .ok_or_else(|| anyhow!("fusion: reduction WGSL generation failed"))?;
    timer.mark("generate_wgsl");
    if std::env::var("RUNMAT_DEBUG_DUMP_FUSED_WGSL").is_ok() {
        println!(
            "---- fused reduction WGSL ----\n{}\n------------------------------",
            shader
        );
    }

    // Resolve the workgroup size: explicit argument, else provider default,
    // optionally overridden (but capped) by RUNMAT_FUSED_WG.
    let mut wg = if workgroup_size == 0 {
        provider.default_reduction_workgroup_size()
    } else {
        workgroup_size
    };
    if let Ok(raw) = std::env::var("RUNMAT_FUSED_WG") {
        if let Ok(val) = raw.trim().parse::<u32>() {
            if val > 0 {
                let capped = val.min(provider.default_reduction_workgroup_size());
                wg = capped.max(1);
            }
        }
    }
    let flavor = request
        .plan
        .reduction_flavor
        .unwrap_or(ReductionFlavor::Sum);
    let output = provider.fused_reduction(
        &shader,
        &handles,
        &output_shape,
        reduce_len,
        num_slices,
        wg,
        flavor,
    )?;
    timer.mark("dispatch");
    fusion_residency::mark(&output);

    // Free only the handles this function uploaded.
    for input in prepared {
        if let Some(handle) = input.owned {
            let _ = provider.free(&handle);
        }
    }
    timer.mark("cleanup");
    timer.finish();

    Ok(Value::GpuTensor(output))
}
534
535pub fn execute_centered_gram(request: FusionExecutionRequest<'_>) -> Result<Value> {
536 crate::ensure_residency_hooks();
537 if request.plan.group.kind != FusionKind::CenteredGram {
538 return Err(anyhow!("unsupported fusion kind"));
539 }
540 let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
541 let (matrix_vid, normalization) = match request.plan.pattern.as_ref() {
542 Some(FusionPattern::CenteredGram {
543 matrix,
544 normalization,
545 }) => (*matrix, *normalization),
546 _ => return Err(anyhow!("centered gram: missing pattern metadata")),
547 };
548
549 let matrix_index = request
550 .plan
551 .inputs
552 .iter()
553 .position(|vid| *vid == matrix_vid)
554 .ok_or_else(|| anyhow!("centered gram: matrix input not found"))?;
555 let matrix_value = request
556 .inputs
557 .get(matrix_index)
558 .ok_or_else(|| anyhow!("centered gram: matrix value missing"))?;
559
560 let (matrix_handle, owned_matrix) = ensure_gpu_tensor(provider, matrix_value)?;
561
562 let options = CovarianceOptions {
563 normalization,
564 rows: CovRows::All,
565 has_weight_vector: false,
566 };
567
568 let output = provider.covariance(&matrix_handle, None, None, &options)?;
569
570 if let Some(temp) = owned_matrix {
571 let _ = provider.free(&temp);
572 }
573
574 fusion_residency::mark(&output);
575 Ok(Value::GpuTensor(output))
576}
577
578pub fn execute_power_step_normalize(request: FusionExecutionRequest<'_>) -> Result<Value> {
579 crate::ensure_residency_hooks();
580 if request.plan.group.kind != FusionKind::PowerStepNormalize {
581 return Err(anyhow!("unsupported fusion kind"));
582 }
583 let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
584 let (lhs_vid, rhs_vid, epsilon) = match request.plan.pattern.as_ref() {
585 Some(FusionPattern::PowerStepNormalize { lhs, rhs, epsilon }) => (*lhs, *rhs, *epsilon),
586 _ => {
587 return Err(anyhow!(
588 "power-step normalization: missing pattern metadata"
589 ))
590 }
591 };
592
593 let lhs_index = request
594 .plan
595 .inputs
596 .iter()
597 .position(|vid| *vid == lhs_vid)
598 .ok_or_else(|| anyhow!("power-step normalization: lhs input not found"))?;
599 let rhs_index = request
600 .plan
601 .inputs
602 .iter()
603 .position(|vid| *vid == rhs_vid)
604 .ok_or_else(|| anyhow!("power-step normalization: rhs input not found"))?;
605
606 let lhs_value = request
607 .inputs
608 .get(lhs_index)
609 .ok_or_else(|| anyhow!("power-step normalization: lhs value missing"))?;
610 let rhs_value = request
611 .inputs
612 .get(rhs_index)
613 .ok_or_else(|| anyhow!("power-step normalization: rhs value missing"))?;
614
615 let (lhs_handle, lhs_owned) = ensure_gpu_tensor(provider, lhs_value)?;
616 let (rhs_handle, rhs_owned) = ensure_gpu_tensor(provider, rhs_value)?;
617
618 let desc = PowerStepEpilogue { epsilon };
619 let output = provider.matmul_power_step(&lhs_handle, &rhs_handle, &desc)?;
620
621 if let Some(temp) = lhs_owned {
622 let _ = provider.free(&temp);
623 }
624 if let Some(temp) = rhs_owned {
625 let _ = provider.free(&temp);
626 }
627
628 fusion_residency::mark(&output);
629 Ok(Value::GpuTensor(output))
630}
631
632pub fn execute_explained_variance(request: FusionExecutionRequest<'_>) -> Result<Value> {
633 crate::ensure_residency_hooks();
634 if request.plan.group.kind != FusionKind::ExplainedVariance {
635 return Err(anyhow!("unsupported fusion kind"));
636 }
637 let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
638 let (q_vid, g_vid) = match request.plan.pattern.as_ref() {
639 Some(FusionPattern::ExplainedVariance { q, g }) => (*q, *g),
640 _ => return Err(anyhow!("explained variance: missing pattern metadata")),
641 };
642
643 let find_value = |vid: ValueId| -> Result<&Value> {
644 if let Some(pos) = request.plan.inputs.iter().position(|id| *id == vid) {
645 request
646 .inputs
647 .get(pos)
648 .ok_or_else(|| anyhow!("explained variance: missing runtime value"))
649 } else {
650 request
651 .plan
652 .const_values
653 .get(&vid)
654 .ok_or_else(|| anyhow!("explained variance: value not materialized"))
655 }
656 };
657
658 let q_value = find_value(q_vid)?;
659 let g_value = find_value(g_vid)?;
660
661 let (mut q_handle, q_owned) = ensure_gpu_tensor(provider, q_value)?;
662 let (g_handle, g_owned) = ensure_gpu_tensor(provider, g_value)?;
663
664 let debug_explained = std::env::var("RUNMAT_DEBUG_EXPLAINED").is_ok();
665 if debug_explained {
666 println!(
667 "[explained] initial Q shape {:?}, G shape {:?}",
668 q_handle.shape, g_handle.shape
669 );
670 if let Ok(info) = provider.download(&q_handle) {
671 println!(
672 "[explained] Q (sample) len={} first=[{:?}]",
673 info.data.len(),
674 info.data.get(0..4)
675 );
676 }
677 }
678
679 let q_shape = q_handle.shape.clone();
680 if q_shape.len() < 2 {
681 return Err(anyhow!("explained variance: Q must be 2-D"));
682 }
683 let q_rows = q_shape[0];
684 let q_cols = q_shape[1];
685 if q_rows == 0 || q_cols == 0 {
686 return Err(anyhow!("explained variance: zero-sized Q"));
687 }
688
689 let g_shape = g_handle.shape.clone();
690 if g_shape.len() < 2 {
691 return Err(anyhow!("explained variance: G must be 2-D"));
692 }
693 if g_shape[0] != q_rows || g_shape[1] != q_rows {
694 return Err(anyhow!("explained variance: G shape mismatch"));
695 }
696
697 let mut tmp = provider.matmul(&q_handle, &g_handle)?;
698 let tmp_shape = tmp.shape.clone();
699 if tmp_shape.len() < 2 {
700 return Err(anyhow!("explained variance: intermediate must be 2-D"));
701 }
702 if tmp_shape[0] != q_cols {
703 return Err(anyhow!(
704 "explained variance: expected intermediate rows {}, got {}",
705 q_cols,
706 tmp_shape[0]
707 ));
708 }
709
710 if debug_explained {
711 println!("[explained] after Q*G tmp shape {:?}", tmp.shape);
712 }
713
714 let mut transposed_shape = q_shape.clone();
718 transposed_shape.swap(0, 1);
719 let q_transposed_view = provider.reshape(&q_handle, &transposed_shape)?;
720
721 tmp = provider.matmul(&q_transposed_view, &g_handle)?;
722
723 if debug_explained {
724 println!(
725 "[explained] after reshape(matmul) tmp shape {:?}",
726 tmp.shape
727 );
728 }
729
730 q_handle = provider.reshape(&q_handle, &q_shape)?;
732
733 let product = provider.matmul(&tmp, &q_handle)?;
734
735 if debug_explained {
736 println!("[explained] product shape {:?}", product.shape);
737 }
738
739 let diag = provider.diag_extract(&product, 0)?;
740 let diag = match diag.shape.as_slice() {
741 [len] => provider.reshape(&diag, &[*len, 1])?,
742 [_len, 1] => diag,
743 _ => diag,
744 };
745
746 if debug_explained {
747 if let Ok(host) = provider.download(&tmp) {
748 println!("tmp runtime shape {:?} data {:?}", host.shape, host.data);
749 }
750 if let Ok(host) = provider.download(&product) {
751 println!("prod runtime shape {:?} data {:?}", host.shape, host.data);
752 }
753 if let Ok(host) = provider.download(&diag) {
754 println!("diag runtime shape {:?} data {:?}", host.shape, host.data);
755 }
756 }
757
758 let _ = provider.free(&tmp);
759 let _ = provider.free(&product);
760 if let Some(temp) = q_owned {
761 let _ = provider.free(&temp);
762 }
763 if let Some(temp) = g_owned {
764 let _ = provider.free(&temp);
765 }
766
767 fusion_residency::mark(&diag);
768 Ok(Value::GpuTensor(diag))
769}
770
771pub fn execute_image_normalize(request: FusionExecutionRequest<'_>) -> Result<Value> {
772 crate::ensure_residency_hooks();
773 if request.plan.group.kind != FusionKind::ImageNormalize {
774 return Err(anyhow!("unsupported fusion kind"));
775 }
776 let provider = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
777 let pattern = match request.plan.pattern.as_ref() {
778 Some(FusionPattern::ImageNormalize(p)) => p,
779 _ => return Err(anyhow!("image normalize: missing pattern metadata")),
780 };
781 if log::log_enabled!(log::Level::Debug) {
782 log::debug!(
783 "execute_image_normalize: plan inputs={:?} stack={:?}",
784 request.plan.inputs,
785 request.plan.stack_pattern
786 );
787 }
788
789 let find_value = |vid: ValueId| -> Result<&Value> {
790 if let Some(pos) = request.plan.inputs.iter().position(|id| *id == vid) {
791 request
792 .inputs
793 .get(pos)
794 .ok_or_else(|| anyhow!("image normalize: runtime value missing"))
795 } else {
796 request
797 .plan
798 .const_values
799 .get(&vid)
800 .ok_or_else(|| anyhow!("image normalize: value {vid:?} not materialized"))
801 }
802 };
803
804 let input_value = find_value(pattern.input)?;
805 let (input_handle, input_owned) = ensure_gpu_tensor(provider, input_value)?;
806 let shape = input_handle.shape.clone();
807 if shape.len() != 3 {
808 return Err(anyhow!(
809 "image normalize: expected 3-D input tensor, got shape {:?}",
810 shape
811 ));
812 }
813 let batch = shape[0];
814 let height = shape[1];
815 let width = shape[2];
816
817 let epsilon = resolve_image_scalar_value(&pattern.epsilon, request.plan, &request)?;
818 let gain = match &pattern.gain {
819 Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
820 None => None,
821 };
822 let bias = match &pattern.bias {
823 Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
824 None => None,
825 };
826 let gamma = match &pattern.gamma {
827 Some(s) => Some(resolve_image_scalar_value(s, request.plan, &request)?),
828 None => None,
829 };
830
831 let desc = ImageNormalizeDescriptor {
832 batch,
833 height,
834 width,
835 epsilon,
836 gain,
837 bias,
838 gamma,
839 };
840 if log::log_enabled!(log::Level::Debug) {
841 log::debug!("execute_image_normalize: desc {:?}", desc);
842 }
843
844 let output = provider.image_normalize(&input_handle, &desc)?;
845
846 if let Some(temp) = input_owned {
847 provider.free(&temp).ok();
848 }
849
850 fusion_residency::mark(&output);
851 Ok(Value::GpuTensor(output))
852}
853
854pub fn execute_matmul_epilogue(request: FusionExecutionRequest<'_>) -> Result<Value> {
855 crate::ensure_residency_hooks();
856 if request.plan.group.kind != crate::fusion::FusionKind::MatmulEpilogue {
857 return Err(anyhow!("unsupported fusion kind"));
858 }
859 let prov = provider().ok_or_else(|| anyhow!("no acceleration provider registered"))?;
860
861 let mut prepared: Vec<(graph::ValueId, GpuTensorHandle, Option<GpuTensorHandle>)> = Vec::new();
863 let mut owned: Vec<GpuTensorHandle> = Vec::new();
864 for (idx, vid) in request.plan.inputs.iter().copied().enumerate() {
865 let v = request
866 .inputs
867 .get(idx)
868 .ok_or_else(|| anyhow!("fusion: missing input value"))?;
869 let handle = match v {
870 Value::GpuTensor(h) => h.clone(),
871 Value::Tensor(t) => {
872 let view = HostTensorView {
873 data: &t.data,
874 shape: &t.shape,
875 };
876 let h = prov.upload(&view)?;
877 owned.push(h.clone());
878 h
879 }
880 _ => return Err(anyhow!("matmul_epilogue: unsupported input value kind")),
881 };
882 prepared.push((vid, handle.clone(), None));
883 }
884
885 let find_handle = |vid: graph::ValueId| -> Option<GpuTensorHandle> {
887 prepared
888 .iter()
889 .find_map(|(v, h, _)| if *v == vid { Some(h.clone()) } else { None })
890 };
891
892 let mut cur_out: Option<graph::ValueId> = None;
894 let mut a_vid: Option<graph::ValueId> = None;
895 let mut b_vid: Option<graph::ValueId> = None;
896 for op in &request.plan.operations {
897 if let crate::fusion::FusionOp::Builtin {
898 name,
899 inputs,
900 output,
901 } = op
902 {
903 if name.eq_ignore_ascii_case("mtimes") {
904 a_vid = inputs.first().copied();
905 b_vid = inputs.get(1).copied();
906 cur_out = *output;
907 break;
908 }
909 }
910 }
911 let (a_vid, b_vid, mut cur) = (
912 a_vid.ok_or_else(|| anyhow!("mtimes not found"))?,
913 b_vid.ok_or_else(|| anyhow!("mtimes not found"))?,
914 cur_out.ok_or_else(|| anyhow!("mtimes output missing"))?,
915 );
916
917 let mut alpha: f64 = 1.0;
919 let mut beta: f64 = 0.0;
920 let mut row_scale: Option<GpuTensorHandle> = None;
921 let mut col_scale: Option<GpuTensorHandle> = None;
922 let mut clamp_min: Option<f64> = None;
923 let mut clamp_max: Option<f64> = None;
924 let mut pow_exponent: Option<f64> = None;
925 let mut row_div = false;
926 let mut col_div = false;
927 let mut diag_vid: Option<graph::ValueId> = None;
928
929 for op in &request.plan.operations {
930 match op {
931 crate::fusion::FusionOp::Primitive { op, inputs, output } => {
932 let Some(out) = output else { continue };
933 if !inputs.contains(&cur) {
934 continue;
935 }
936 let other = if inputs[0] == cur {
937 inputs[1]
938 } else {
939 inputs[0]
940 };
941 let const_opt = request.plan.const_values.get(&other);
942 let const_f64 = const_opt.and_then(value_to_f64);
943 match op {
944 crate::graph::PrimitiveOp::Mul | crate::graph::PrimitiveOp::ElemMul => {
945 if let Some(val) = const_f64 {
946 alpha *= val;
947 } else if row_scale.is_none() || col_scale.is_none() {
948 if let Some(h) = find_handle(other) {
949 let r = h.shape.first().copied().unwrap_or(1);
950 let c = h.shape.get(1).copied().unwrap_or(1);
951 if c == 1 && row_scale.is_none() {
952 row_scale = Some(h);
953 row_div = false;
954 } else if r == 1 && col_scale.is_none() {
955 col_scale = Some(h);
956 col_div = false;
957 }
958 }
959 }
960 }
961 crate::graph::PrimitiveOp::Div | crate::graph::PrimitiveOp::ElemDiv => {
962 if let Some(val) = const_f64 {
963 if val != 0.0 {
964 alpha *= 1.0 / val;
965 }
966 } else if row_scale.is_none() || col_scale.is_none() {
967 if let Some(h) = find_handle(other) {
968 let r = h.shape.first().copied().unwrap_or(1);
969 let c = h.shape.get(1).copied().unwrap_or(1);
970 if c == 1 && row_scale.is_none() {
971 row_scale = Some(h);
972 row_div = true;
973 } else if r == 1 && col_scale.is_none() {
974 col_scale = Some(h);
975 col_div = true;
976 }
977 }
978 }
979 }
980 crate::graph::PrimitiveOp::Add => {
981 if let Some(val) = const_f64 {
982 beta += val;
983 }
984 }
985 crate::graph::PrimitiveOp::Sub => {
986 if let Some(val) = const_f64 {
987 beta -= val;
988 }
989 }
990 crate::graph::PrimitiveOp::Pow | crate::graph::PrimitiveOp::ElemPow => {
991 if pow_exponent.is_none() && inputs[0] == cur {
992 pow_exponent = const_f64;
993 }
994 }
995 _ => {}
996 }
997 cur = *out;
998 }
999 crate::fusion::FusionOp::Builtin {
1000 name,
1001 inputs,
1002 output,
1003 } => {
1004 let Some(out) = output else { continue };
1005 if !inputs.contains(&cur) {
1006 continue;
1007 }
1008 let lower = name.to_ascii_lowercase();
1009 if lower == "max" || lower == "min" {
1010 if let Some(&other) = inputs.iter().find(|&&v| v != cur) {
1011 if let Some(val) =
1012 request.plan.const_values.get(&other).and_then(value_to_f64)
1013 {
1014 if lower == "max" {
1015 clamp_min = Some(clamp_min.map_or(val, |prev| prev.max(val)));
1016 } else {
1017 clamp_max = Some(clamp_max.map_or(val, |prev| prev.min(val)));
1018 }
1019 }
1020 }
1021 } else if lower == "pow" && pow_exponent.is_none() {
1022 if let Some(&other) = inputs.iter().find(|&&v| v != cur) {
1023 if let Some(val) =
1024 request.plan.const_values.get(&other).and_then(value_to_f64)
1025 {
1026 pow_exponent = Some(val);
1027 }
1028 }
1029 } else if lower == "diag" {
1030 diag_vid = Some(*out);
1031 }
1032 cur = *out;
1033 }
1034 }
1035 }
1036
1037 let mut ep = runmat_accelerate_api::MatmulEpilogue::noop();
1039 ep.alpha = alpha;
1040 ep.beta = beta;
1041 ep.clamp_min = clamp_min;
1042 ep.clamp_max = clamp_max;
1043 ep.pow_exponent = pow_exponent;
1044 ep.row_op = if row_div {
1045 runmat_accelerate_api::ScaleOp::Divide
1046 } else {
1047 runmat_accelerate_api::ScaleOp::Multiply
1048 };
1049 ep.col_op = if col_div {
1050 runmat_accelerate_api::ScaleOp::Divide
1051 } else {
1052 runmat_accelerate_api::ScaleOp::Multiply
1053 };
1054 if let Some(h) = row_scale.clone() {
1055 ep.row_scale = Some(h);
1056 }
1057 if let Some(h) = col_scale.clone() {
1058 ep.col_scale = Some(h);
1059 }
1060
1061 let a = find_handle(a_vid).ok_or_else(|| anyhow!("missing A"))?;
1062 let b = find_handle(b_vid).ok_or_else(|| anyhow!("missing B"))?;
1063
1064 let mut diag_handle: Option<(graph::ValueId, GpuTensorHandle)> = None;
1065 if let Some(vid) = diag_vid {
1066 let diag_len = std::cmp::min(
1067 a.shape.first().copied().unwrap_or(0),
1068 b.shape.get(1).copied().unwrap_or(0),
1069 );
1070 let mut diag_shape = vec![diag_len, 1];
1071 if diag_len == 0 {
1072 diag_shape[1] = 1;
1073 }
1074 let handle = prov.zeros(&diag_shape)?;
1075 ep.diag_output = Some(handle.clone());
1076 diag_handle = Some((vid, handle));
1077 }
1078
1079 let out = prov.matmul_epilogue(&a, &b, &ep)?;
1080 for h in owned {
1081 let _ = prov.free(&h);
1082 }
1083
1084 if let Some((_, diag)) = &diag_handle {
1085 fusion_residency::mark(diag);
1086 }
1087
1088 let final_vid = request.plan.output.or(Some(cur));
1089 let mut result = out.clone();
1090 let mut free_out = false;
1091 if let Some((vid, diag)) = &diag_handle {
1092 if Some(*vid) == final_vid {
1093 result = diag.clone();
1094 free_out = true;
1095 }
1096 }
1097
1098 if free_out {
1099 let _ = prov.free(&out);
1100 } else {
1101 fusion_residency::mark(&out);
1102 }
1103
1104 fusion_residency::mark(&result);
1105 Ok(Value::GpuTensor(result))
1106}