runmat_accelerate_api/lib.rs

use anyhow::anyhow;
use once_cell::sync::{Lazy, OnceCell};
use serde::{Deserialize, Serialize};
use std::cell::Cell;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::RwLock;

type ResidencyClearFn = fn(&GpuTensorHandle);
type SequenceThresholdFn = fn() -> Option<usize>;

static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();

static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    pub base_rows: usize,
    pub base_cols: usize,
}

/// Register a callback used to clear residency tracking when GPU tensors are
/// gathered back to the host. Backends that maintain residency metadata should
/// install this hook during initialization.
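///
/// # Examples
///
/// A minimal sketch of a backend installing the hook during startup (the hook
/// body here is illustrative):
///
/// ```no_run
/// use runmat_accelerate_api::{register_residency_clear, GpuTensorHandle};
///
/// fn clear_backend_metadata(handle: &GpuTensorHandle) {
///     // Drop backend-side residency records keyed by `handle.buffer_id`.
/// }
///
/// register_residency_clear(clear_backend_metadata);
/// ```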
pub fn register_residency_clear(handler: ResidencyClearFn) {
    let _ = RESIDENCY_CLEAR.set(handler);
}

/// Clear residency metadata for the provided GPU tensor handle, if a backend
/// has registered a handler via [`register_residency_clear`].
pub fn clear_residency(handle: &GpuTensorHandle) {
    if let Some(handler) = RESIDENCY_CLEAR.get() {
        handler(handle);
    }
}

/// Register a callback that exposes the current sequence length threshold
/// derived from the auto-offload planner. Array constructors can use this hint
/// to decide when to prefer GPU residency automatically.
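///
/// # Examples
///
/// A sketch with a fixed threshold; a real planner would derive the value from
/// its cost model:
///
/// ```no_run
/// use runmat_accelerate_api::register_sequence_threshold_provider;
///
/// fn threshold_hint() -> Option<usize> {
///     Some(4096)
/// }
///
/// register_sequence_threshold_provider(threshold_hint);
/// ```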
pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
    let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
}

/// Query the currently registered sequence threshold hint, if any.
pub fn sequence_threshold_hint() -> Option<usize> {
    SEQUENCE_THRESHOLD_PROVIDER
        .get()
        .and_then(|provider| provider())
}

/// Record the precision associated with a GPU tensor handle so host operations can
/// reconstruct the original dtype when gathering back to the CPU.
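///
/// # Examples
///
/// Round-trip sketch (the handle fields are fabricated for illustration):
///
/// ```no_run
/// use runmat_accelerate_api::{
///     clear_handle_precision, handle_precision, set_handle_precision, GpuTensorHandle,
///     ProviderPrecision,
/// };
///
/// let handle = GpuTensorHandle { shape: vec![2, 2], device_id: 0, buffer_id: 42 };
/// set_handle_precision(&handle, ProviderPrecision::F32);
/// assert_eq!(handle_precision(&handle), Some(ProviderPrecision::F32));
/// clear_handle_precision(&handle);
/// assert_eq!(handle_precision(&handle), None);
/// ```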
pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
        guard.insert(handle.buffer_id, precision);
    }
}

/// Look up the recorded precision for a GPU tensor handle, if any.
pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
    HANDLE_PRECISIONS
        .read()
        .ok()
        .and_then(|guard| guard.get(&handle.buffer_id).copied())
}

/// Clear any recorded precision metadata for a GPU tensor handle.
pub fn clear_handle_precision(handle: &GpuTensorHandle) {
    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
        guard.remove(&handle.buffer_id);
    }
}

/// Annotate a GPU tensor handle as logically-typed (`logical` in MATLAB terms)
/// or clear the logical flag when `logical` is `false`.
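///
/// # Examples
///
/// ```no_run
/// use runmat_accelerate_api::{handle_is_logical, set_handle_logical, GpuTensorHandle};
///
/// let handle = GpuTensorHandle { shape: vec![4], device_id: 0, buffer_id: 7 };
/// set_handle_logical(&handle, true);
/// assert!(handle_is_logical(&handle));
/// set_handle_logical(&handle, false);
/// assert!(!handle_is_logical(&handle));
/// ```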
pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
    if let Ok(mut guard) = LOGICAL_HANDLES.write() {
        if logical {
            guard.insert(handle.buffer_id);
            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
                *hits.entry(handle.buffer_id).or_insert(0) += 1;
            }
        } else {
            guard.remove(&handle.buffer_id);
            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
                hits.remove(&handle.buffer_id);
            }
        }
    }
}

/// Convenience helper for clearing logical annotations explicitly.
pub fn clear_handle_logical(handle: &GpuTensorHandle) {
    set_handle_logical(handle, false);
}

/// Returns true when the supplied handle has been marked as logical.
pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
    LOGICAL_HANDLES
        .read()
        .map(|guard| guard.contains(&handle.buffer_id))
        .unwrap_or(false)
}

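/// Returns the number of times `set_handle_logical(_, true)` has been recorded
/// for `buffer_id`, if the buffer is still flagged as logical.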
pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
    LOGICAL_HANDLE_HITS
        .read()
        .ok()
        .and_then(|guard| guard.get(&buffer_id).copied())
}

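/// Mark `handle` as a transposed view, remembering the base (pre-transpose)
/// matrix dimensions.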
pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
        guard.insert(
            handle.buffer_id,
            TransposeInfo {
                base_rows,
                base_cols,
            },
        );
    }
}

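/// Remove any transpose annotation recorded for `handle`.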
pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
        guard.remove(&handle.buffer_id);
    }
}

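/// Look up the transpose annotation recorded for `handle`, if any.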
pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
    TRANSPOSED_HANDLES
        .read()
        .ok()
        .and_then(|guard| guard.get(&handle.buffer_id).copied())
}

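/// Returns true when `handle` carries a transpose annotation.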
pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
    handle_transpose_info(handle).is_some()
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuTensorHandle {
    pub shape: Vec<usize>,
    pub device_id: u32,
    pub buffer_id: u64,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    pub device_id: u32,
    pub name: String,
    pub vendor: String,
    pub memory_bytes: Option<u64>,
    pub backend: Option<String>,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Result payload returned by provider-side `cummax` scans.
///
/// Alias of [`ProviderCumminResult`] because both operations return the same pair of tensors
/// (running values and MATLAB-compatible indices).
pub type ProviderCummaxResult = ProviderCumminResult;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    Mtimes,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    pub op: PagefunOp,
    pub inputs: Vec<GpuTensorHandle>,
    pub output_shape: Vec<usize>,
    pub page_dims: Vec<usize>,
    pub input_page_dims: Vec<Vec<usize>>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    pub linear: GpuTensorHandle,
    pub rows: GpuTensorHandle,
    pub cols: GpuTensorHandle,
    pub values: Option<GpuTensorHandle>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    pub lower: u32,
    pub upper: u32,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    pub combined: GpuTensorHandle,
    pub lower: GpuTensorHandle,
    pub upper: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    pub factor: GpuTensorHandle,
    /// MATLAB-compatible failure index (0 indicates success).
    pub info: u32,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

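/// Structure flags mirroring MATLAB `linsolve` options; `Default` requests a
/// plain general solve.
///
/// ```
/// use runmat_accelerate_api::ProviderLinsolveOptions;
///
/// // Symmetric positive-definite system with an explicit rcond threshold.
/// let options = ProviderLinsolveOptions {
///     symmetric: true,
///     posdef: true,
///     rcond: Some(1e-12),
///     ..ProviderLinsolveOptions::default()
/// };
/// assert!(options.posdef);
/// ```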
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    pub rcond: Option<f64>,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    pub reciprocal_condition: f64,
}

#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    pub tolerance: Option<f64>,
}

#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}

#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    pub coefficients: Vec<f64>,
    pub r_matrix: Vec<f64>,
    pub normr: f64,
    pub df: f64,
    pub mu: [f64; 2],
}

/// Numerator/denominator payload returned by provider-backed `polyder` quotient rule.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}

/// Supported norm specifications for the `cond` builtin.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}

/// Supported norm orders for the `norm` builtin.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    Nuc,
    P(f64),
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    pub eigenvalues: GpuTensorHandle,
    pub diagonal: GpuTensorHandle,
    pub right: GpuTensorHandle,
    pub left: Option<GpuTensorHandle>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    pub economy: bool,
    pub pivot: ProviderQrPivot,
}

impl Default for ProviderQrOptions {
    fn default() -> Self {
        Self {
            economy: false,
            pivot: ProviderQrPivot::Matrix,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionTwoPassMode {
    Auto,
    ForceOn,
    ForceOff,
}

impl ReductionTwoPassMode {
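    /// Stable lowercase name for this mode: `"auto"`, `"force_on"`, or `"force_off"`.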
    pub fn as_str(self) -> &'static str {
        match self {
            ReductionTwoPassMode::Auto => "auto",
            ReductionTwoPassMode::ForceOn => "force_on",
            ReductionTwoPassMode::ForceOff => "force_off",
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ReductionFlavor {
    Sum,
    Mean,
    CustomScale(f64),
}

impl ReductionFlavor {
    pub fn is_mean(self) -> bool {
        matches!(self, ReductionFlavor::Mean)
    }

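    /// Scale factor applied after summation: `1.0` for `Sum`, `1 / reduce_len`
    /// for `Mean` (an empty reduction is treated as length 1 to avoid dividing
    /// by zero), and the supplied multiplier for `CustomScale`.
    ///
    /// ```
    /// use runmat_accelerate_api::ReductionFlavor;
    ///
    /// assert_eq!(ReductionFlavor::Sum.scale(8), 1.0);
    /// assert_eq!(ReductionFlavor::Mean.scale(4), 0.25);
    /// assert_eq!(ReductionFlavor::Mean.scale(0), 1.0);
    /// assert_eq!(ReductionFlavor::CustomScale(0.5).scale(3), 0.5);
    /// ```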
    pub fn scale(self, reduce_len: usize) -> f64 {
        match self {
            ReductionFlavor::Sum => 1.0,
            ReductionFlavor::Mean => {
                if reduce_len == 0 {
                    1.0
                } else {
                    1.0 / reduce_len as f64
                }
            }
            ReductionFlavor::CustomScale(scale) => scale,
        }
    }
}

/// Normalisation mode for correlation coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}

/// Row-selection strategy for correlation coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}

/// Options controlling provider-backed correlation coefficient computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}

impl Default for CorrcoefOptions {
    fn default() -> Self {
        Self {
            normalization: CorrcoefNormalization::Unbiased,
            rows: CorrcoefRows::All,
        }
    }
}

/// Normalisation mode used by covariance computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}

/// Row handling strategy for covariance computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}

/// Options controlling provider-backed covariance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    pub has_weight_vector: bool,
}

impl Default for CovarianceOptions {
    fn default() -> Self {
        Self {
            normalization: CovNormalization::Unbiased,
            rows: CovRows::All,
            has_weight_vector: false,
        }
    }
}

/// Normalization strategy used by provider-backed standard deviation reductions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    Sample,
    Population,
}

/// NaN handling mode for provider-backed reductions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}

/// Direction used when computing prefix sums on the device.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}

/// Sort direction used by acceleration providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}

/// Comparison strategy applied during sorting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    Auto,
    Real,
    Abs,
}

/// Host-resident outputs returned by provider-backed sort operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    pub index: usize,
    pub order: SortOrder,
}

/// Ordering applied by provider-backed `unique` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}

/// Occurrence selection for provider-backed `unique` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}

/// Options controlling provider-backed `unique` operations.
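///
/// # Examples
///
/// ```
/// use runmat_accelerate_api::{UniqueOccurrence, UniqueOptions, UniqueOrder};
///
/// // MATLAB's plain `unique(A)`: element-wise, sorted, first occurrence.
/// let options = UniqueOptions {
///     rows: false,
///     order: UniqueOrder::Sorted,
///     occurrence: UniqueOccurrence::First,
/// };
/// assert!(!options.rows);
/// ```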
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}

/// Host-resident outputs returned by provider-backed `unique` operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ic: HostTensorOwned,
}

/// Ordering applied by provider-backed `union` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}

/// Options controlling provider-backed `union` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    pub rows: bool,
    pub order: UnionOrder,
}

/// Host-resident outputs returned by provider-backed `union` operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ib: HostTensorOwned,
}

/// Parameterisation of 2-D filters generated by `fspecial`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    Average {
        rows: u32,
        cols: u32,
    },
    Disk {
        radius: f64,
        size: u32,
    },
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Laplacian {
        alpha: f64,
    },
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        oversample: u32,
    },
    Prewitt,
    Sobel,
    Unsharp {
        alpha: f64,
    },
}

/// Request dispatched to acceleration providers for `fspecial` kernels.
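///
/// # Examples
///
/// ```
/// use runmat_accelerate_api::{FspecialFilter, FspecialRequest};
///
/// // Request a 5x5 Gaussian kernel with sigma = 1.0.
/// let request = FspecialRequest {
///     filter: FspecialFilter::Gaussian { rows: 5, cols: 5, sigma: 1.0 },
/// };
/// assert_eq!(
///     request.filter,
///     FspecialFilter::Gaussian { rows: 5, cols: 5, sigma: 1.0 }
/// );
/// ```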
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}

/// Padding strategy used by `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    Constant,
    Replicate,
    Symmetric,
    Circular,
}

/// Output sizing mode used by `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}

/// Correlation vs convolution behaviour for `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}

/// Options supplied to acceleration providers for `imfilter`.
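///
/// # Examples
///
/// ```
/// use runmat_accelerate_api::{ImfilterOptions, ImfilterPadding};
///
/// // Start from the MATLAB-compatible defaults and switch to replicate padding.
/// let options = ImfilterOptions {
///     padding: ImfilterPadding::Replicate,
///     ..ImfilterOptions::default()
/// };
/// assert_eq!(options.padding, ImfilterPadding::Replicate);
/// assert_eq!(options.constant_value, 0.0);
/// ```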
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}

impl Default for ImfilterOptions {
    fn default() -> Self {
        Self {
            padding: ImfilterPadding::Constant,
            constant_value: 0.0,
            shape: ImfilterShape::Same,
            mode: ImfilterMode::Correlation,
        }
    }
}

/// Ordering applied by provider-backed `setdiff` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}

/// Options controlling provider-backed `setdiff` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    pub rows: bool,
    pub order: SetdiffOrder,
}

/// Host-resident outputs returned by provider-backed `setdiff` operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}

/// Options controlling provider-backed `ismember` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    pub rows: bool,
}

/// Host-resident logical output returned by providers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}

/// Host-resident outputs returned by provider-backed `ismember` operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Zero-based dimension along which filtering should be applied.
    pub dim: usize,
    /// Optional initial conditions (state vector) residing on the device.
    pub zi: Option<GpuTensorHandle>,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    /// Filtered output tensor, matching the input signal shape.
    pub output: GpuTensorHandle,
    /// Final conditions for the filter state (same shape as the requested `zi` layout).
    pub final_state: Option<GpuTensorHandle>,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    pub ex2: GpuTensorHandle,
}

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of GPU dispatches recorded for this category.
    pub count: u64,
    /// Accumulated wall-clock time of dispatches in nanoseconds (host measured).
    pub total_wall_time_ns: u64,
}

#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    pub upload_bytes: u64,
    pub download_bytes: u64,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Optional per-layout bind group cache counters (layout tags and their hit/miss counts)
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    /// Recent kernel launch metadata (bounded log; newest last)
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    pub kernel: String,
    pub precision: Option<String>,
    pub shape: Vec<KernelAttrTelemetry>,
    pub tuning: Vec<KernelAttrTelemetry>,
}

/// Device/provider interface that backends implement and register into the runtime layer.
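///
/// Backends must provide the transfer primitives (`upload`, `download`, `free`)
/// and `device_info`; the optional operator hooks default to returning an
/// "unsupported" error so providers can adopt the surface incrementally.
///
/// # Examples
///
/// A minimal sketch of a provider implementing only the required methods
/// (marked `ignore` because a useful backend needs real device storage):
///
/// ```ignore
/// use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, HostTensorOwned, HostTensorView};
///
/// struct NullProvider;
///
/// impl AccelProvider for NullProvider {
///     fn upload(&self, _host: &HostTensorView) -> anyhow::Result<GpuTensorHandle> {
///         Err(anyhow::anyhow!("this sketch has no device storage"))
///     }
///     fn download(&self, _h: &GpuTensorHandle) -> anyhow::Result<HostTensorOwned> {
///         Err(anyhow::anyhow!("this sketch has no device storage"))
///     }
///     fn free(&self, _h: &GpuTensorHandle) -> anyhow::Result<()> {
///         Ok(())
///     }
///     fn device_info(&self) -> String {
///         "null device".to_string()
///     }
/// }
/// ```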
pub trait AccelProvider: Send + Sync {
    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
    fn download(&self, h: &GpuTensorHandle) -> anyhow::Result<crate::HostTensorOwned>;
    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
    fn device_info(&self) -> String;
    fn device_id(&self) -> u32 {
        0
    }

    /// Gather elements from `source` at the provided zero-based linear `indices`, materialising
    /// a dense tensor with the specified `output_shape`.
    fn gather_linear(
        &self,
        _source: &GpuTensorHandle,
        _indices: &[u32],
        _output_shape: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("gather_linear not supported by provider"))
    }

    /// Scatter the contents of `values` into `target` at the provided zero-based linear `indices`.
    ///
    /// The provider must ensure `values.len() == indices.len()` and update `target` in place.
    fn scatter_linear(
        &self,
        _target: &GpuTensorHandle,
        _indices: &[u32],
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<()> {
        Err(anyhow::anyhow!("scatter_linear not supported by provider"))
    }

    /// Structured device information (optional to override). Default adapts from
    /// `device_info()` and `device_id()`.
    fn device_info_struct(&self) -> ApiDeviceInfo {
        ApiDeviceInfo {
            device_id: self.device_id(),
            name: self.device_info(),
            vendor: String::new(),
            memory_bytes: None,
            backend: None,
        }
    }

    fn precision(&self) -> ProviderPrecision {
        ProviderPrecision::F64
    }

    /// Read a single scalar at linear index from a device tensor, returning it as f64.
    fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
        Err(anyhow::anyhow!("read_scalar not supported by provider"))
    }

    /// Allocate a zero-initialised tensor with the provided shape on the device.
    fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("zeros not supported by provider"))
    }

    /// Allocate a one-initialised tensor with the provided shape on the device.
    fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("ones not supported by provider"))
    }

    /// Allocate a zero-initialised tensor matching the prototype tensor.
    fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        self.zeros(&prototype.shape)
    }

    /// Allocate a tensor filled with a constant value on the device.
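    ///
    /// The default implementation first tries to build the result on the
    /// device via `zeros` + `scalar_add`, and falls back to filling a host
    /// buffer and uploading it when those hooks are unavailable.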
864    fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
865        if value == 0.0 {
866            return self.zeros(shape);
867        }
868        if let Ok(base) = self.zeros(shape) {
869            match self.scalar_add(&base, value) {
870                Ok(out) => {
871                    let _ = self.free(&base);
872                    return Ok(out);
873                }
874                Err(_) => {
875                    let _ = self.free(&base);
876                }
877            }
878        }
879        let len: usize = shape.iter().copied().product();
880        let data = vec![value; len];
881        let view = HostTensorView { data: &data, shape };
882        self.upload(&view)
883    }
884
885    /// Allocate a tensor filled with a constant value, matching a prototype's residency.
886    fn fill_like(
887        &self,
888        prototype: &GpuTensorHandle,
889        value: f64,
890    ) -> anyhow::Result<GpuTensorHandle> {
891        if value == 0.0 {
892            return self.zeros_like(prototype);
893        }
894        if let Ok(base) = self.zeros_like(prototype) {
895            match self.scalar_add(&base, value) {
896                Ok(out) => {
897                    let _ = self.free(&base);
898                    return Ok(out);
899                }
900                Err(_) => {
901                    let _ = self.free(&base);
902                }
903            }
904        }
905        self.fill(&prototype.shape, value)
906    }
907
908    /// Allocate a one-initialised tensor matching the prototype tensor.
909    fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
910        self.ones(&prototype.shape)
911    }
912
913    /// Allocate an identity tensor with ones along the leading diagonal of the first two axes.
914    fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
915        Err(anyhow::anyhow!("eye not supported by provider"))
916    }
917
918    /// Allocate an identity tensor matching the prototype tensor's shape.
919    fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
920        self.eye(&prototype.shape)
921    }
922
923    /// Construct MATLAB-style coordinate grids from axis vectors.
924    fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
925        Err(anyhow::anyhow!("meshgrid not supported by provider"))
926    }
927
928    /// Construct a diagonal matrix from a vector-like tensor. `offset` matches MATLAB semantics.
929    fn diag_from_vector(
930        &self,
931        _vector: &GpuTensorHandle,
932        _offset: isize,
933    ) -> anyhow::Result<GpuTensorHandle> {
934        Err(anyhow::anyhow!(
935            "diag_from_vector not supported by provider"
936        ))
937    }
938
939    /// Extract a diagonal from a matrix-like tensor. The result is always a column vector.
940    fn diag_extract(
941        &self,
942        _matrix: &GpuTensorHandle,
943        _offset: isize,
944    ) -> anyhow::Result<GpuTensorHandle> {
945        Err(anyhow::anyhow!("diag_extract not supported by provider"))
946    }
947
948    /// Apply a lower-triangular mask to the first two dimensions of a tensor.
949    fn tril(&self, _matrix: &GpuTensorHandle, _offset: isize) -> anyhow::Result<GpuTensorHandle> {
950        Err(anyhow!("tril not supported by provider"))
951    }
952
953    /// Apply an upper-triangular mask to the first two dimensions of a tensor.
954    fn triu(&self, _matrix: &GpuTensorHandle, _offset: isize) -> anyhow::Result<GpuTensorHandle> {
955        Err(anyhow!("triu not supported by provider"))
956    }
957
958    /// Evaluate a polynomial expressed by `coefficients` at each element in `points`.
959    fn polyval(
960        &self,
961        _coefficients: &GpuTensorHandle,
962        _points: &GpuTensorHandle,
963        _options: &ProviderPolyvalOptions,
964    ) -> anyhow::Result<GpuTensorHandle> {
965        Err(anyhow::anyhow!("polyval not supported by provider"))
966    }
967
968    /// Fit a polynomial of degree `degree` to `(x, y)` samples. Optional weights must match `x`.
969    fn polyfit(
970        &self,
971        _x: &GpuTensorHandle,
972        _y: &GpuTensorHandle,
973        _degree: usize,
974        _weights: Option<&GpuTensorHandle>,
975    ) -> anyhow::Result<ProviderPolyfitResult> {
976        Err(anyhow::anyhow!("polyfit not supported by provider"))
977    }
978
979    /// Differentiate a polynomial represented as a vector of coefficients.
980    fn polyder_single(&self, _polynomial: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
981        Err(anyhow::anyhow!("polyder_single not supported by provider"))
982    }
983
984    /// Apply the product rule to polynomials `p` and `q`.
985    fn polyder_product(
986        &self,
987        _p: &GpuTensorHandle,
988        _q: &GpuTensorHandle,
989    ) -> anyhow::Result<GpuTensorHandle> {
990        Err(anyhow::anyhow!("polyder_product not supported by provider"))
991    }
992
993    /// Apply the quotient rule to polynomials `u` and `v`.
994    fn polyder_quotient(
995        &self,
996        _u: &GpuTensorHandle,
997        _v: &GpuTensorHandle,
998    ) -> anyhow::Result<ProviderPolyderQuotient> {
999        Err(anyhow::anyhow!(
1000            "polyder_quotient not supported by provider"
1001        ))
1002    }
1003
1004    /// Integrate a polynomial represented as a vector of coefficients and append a constant term.
1005    fn polyint(
1006        &self,
1007        _polynomial: &GpuTensorHandle,
1008        _constant: f64,
1009    ) -> anyhow::Result<GpuTensorHandle> {
1010        Err(anyhow::anyhow!("polyint not supported by provider"))
1011    }
1012
1013    /// Allocate a tensor filled with random values drawn from U(0, 1).
1014    fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1015        Err(anyhow::anyhow!("random_uniform not supported by provider"))
1016    }
1017
1018    /// Allocate a tensor filled with random values matching the prototype shape.
1019    fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1020        self.random_uniform(&prototype.shape)
1021    }
1022
1023    /// Allocate a tensor filled with standard normal (mean 0, stddev 1) random values.
1024    fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1025        Err(anyhow::anyhow!("random_normal not supported by provider"))
1026    }
1027
1028    /// Allocate a tensor of standard normal values matching a prototype's shape.
1029    fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1030        self.random_normal(&prototype.shape)
1031    }
1032
1033    fn stochastic_evolution(
1034        &self,
1035        _state: &GpuTensorHandle,
1036        _drift: f64,
1037        _scale: f64,
1038        _steps: u32,
1039    ) -> anyhow::Result<GpuTensorHandle> {
1040        Err(anyhow::anyhow!(
1041            "stochastic_evolution not supported by provider"
1042        ))
1043    }
1044
1045    /// Set the provider RNG state to align with the host RNG.
1046    fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1047        Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1048    }
1049
1050    /// Generate a 2-D correlation kernel matching MATLAB's `fspecial` builtin.
1051    fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1052        Err(anyhow::anyhow!("fspecial not supported by provider"))
1053    }
1054
1055    /// Apply an N-D correlation/convolution with padding semantics matching MATLAB's `imfilter`.
1056    fn imfilter(
1057        &self,
1058        _image: &GpuTensorHandle,
1059        _kernel: &GpuTensorHandle,
1060        _options: &ImfilterOptions,
1061    ) -> anyhow::Result<GpuTensorHandle> {
1062        Err(anyhow::anyhow!("imfilter not supported by provider"))
1063    }
1064
1065    /// Allocate a tensor filled with random integers over an inclusive range.
1066    fn random_integer_range(
1067        &self,
1068        _lower: i64,
1069        _upper: i64,
1070        _shape: &[usize],
1071    ) -> anyhow::Result<GpuTensorHandle> {
1072        Err(anyhow::anyhow!(
1073            "random_integer_range not supported by provider"
1074        ))
1075    }
1076
1077    /// Allocate a random integer tensor matching the prototype shape.
1078    fn random_integer_like(
1079        &self,
1080        prototype: &GpuTensorHandle,
1081        lower: i64,
1082        upper: i64,
1083    ) -> anyhow::Result<GpuTensorHandle> {
1084        self.random_integer_range(lower, upper, &prototype.shape)
1085    }
1086
1087    /// Allocate a random permutation of 1..=n, returning the first k elements.
1088    fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1089        Err(anyhow!("random_permutation not supported by provider"))
1090    }
1091
1092    /// Allocate a random permutation matching the prototype residency.
1093    fn random_permutation_like(
1094        &self,
1095        _prototype: &GpuTensorHandle,
1096        n: usize,
1097        k: usize,
1098    ) -> anyhow::Result<GpuTensorHandle> {
1099        self.random_permutation(n, k)
1100    }
1101
1102    /// Compute a covariance matrix across the columns of `matrix`.
1103    fn covariance(
1104        &self,
1105        _matrix: &GpuTensorHandle,
1106        _second: Option<&GpuTensorHandle>,
1107        _weights: Option<&GpuTensorHandle>,
1108        _options: &CovarianceOptions,
1109    ) -> anyhow::Result<GpuTensorHandle> {
1110        Err(anyhow::anyhow!("covariance not supported by provider"))
1111    }
1112
1113    /// Compute a correlation coefficient matrix across the columns of `matrix`.
1114    fn corrcoef(
1115        &self,
1116        _matrix: &GpuTensorHandle,
1117        _options: &CorrcoefOptions,
1118    ) -> anyhow::Result<GpuTensorHandle> {
1119        Err(anyhow::anyhow!("corrcoef not supported by provider"))
1120    }
1121
1122    // Optional operator hooks (default to unsupported)
1123    fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1124        Err(anyhow::anyhow!("linspace not supported by provider"))
1125    }
1126    fn elem_add(
1127        &self,
1128        _a: &GpuTensorHandle,
1129        _b: &GpuTensorHandle,
1130    ) -> anyhow::Result<GpuTensorHandle> {
1131        Err(anyhow::anyhow!("elem_add not supported by provider"))
1132    }
1133    fn elem_mul(
1134        &self,
1135        _a: &GpuTensorHandle,
1136        _b: &GpuTensorHandle,
1137    ) -> anyhow::Result<GpuTensorHandle> {
1138        Err(anyhow::anyhow!("elem_mul not supported by provider"))
1139    }
1140    fn elem_max(
1141        &self,
1142        _a: &GpuTensorHandle,
1143        _b: &GpuTensorHandle,
1144    ) -> anyhow::Result<GpuTensorHandle> {
1145        Err(anyhow::anyhow!("elem_max not supported by provider"))
1146    }
1147    fn elem_min(
1148        &self,
1149        _a: &GpuTensorHandle,
1150        _b: &GpuTensorHandle,
1151    ) -> anyhow::Result<GpuTensorHandle> {
1152        Err(anyhow::anyhow!("elem_min not supported by provider"))
1153    }
1154    fn elem_sub(
1155        &self,
1156        _a: &GpuTensorHandle,
1157        _b: &GpuTensorHandle,
1158    ) -> anyhow::Result<GpuTensorHandle> {
1159        Err(anyhow::anyhow!("elem_sub not supported by provider"))
1160    }
1161    fn elem_div(
1162        &self,
1163        _a: &GpuTensorHandle,
1164        _b: &GpuTensorHandle,
1165    ) -> anyhow::Result<GpuTensorHandle> {
1166        Err(anyhow::anyhow!("elem_div not supported by provider"))
1167    }
1168    fn elem_pow(
1169        &self,
1170        _a: &GpuTensorHandle,
1171        _b: &GpuTensorHandle,
1172    ) -> anyhow::Result<GpuTensorHandle> {
1173        Err(anyhow::anyhow!("elem_pow not supported by provider"))
1174    }
1175
1176    fn elem_hypot(
1177        &self,
1178        _a: &GpuTensorHandle,
1179        _b: &GpuTensorHandle,
1180    ) -> anyhow::Result<GpuTensorHandle> {
1181        Err(anyhow::anyhow!("elem_hypot not supported by provider"))
1182    }
1183    fn elem_ge(
1184        &self,
1185        _a: &GpuTensorHandle,
1186        _b: &GpuTensorHandle,
1187    ) -> anyhow::Result<GpuTensorHandle> {
1188        Err(anyhow::anyhow!("elem_ge not supported by provider"))
1189    }
1190    fn elem_le(
1191        &self,
1192        _a: &GpuTensorHandle,
1193        _b: &GpuTensorHandle,
1194    ) -> anyhow::Result<GpuTensorHandle> {
1195        Err(anyhow::anyhow!("elem_le not supported by provider"))
1196    }
1197    fn elem_lt(
1198        &self,
1199        _a: &GpuTensorHandle,
1200        _b: &GpuTensorHandle,
1201    ) -> anyhow::Result<GpuTensorHandle> {
1202        Err(anyhow::anyhow!("elem_lt not supported by provider"))
1203    }
1204    fn elem_gt(
1205        &self,
1206        _a: &GpuTensorHandle,
1207        _b: &GpuTensorHandle,
1208    ) -> anyhow::Result<GpuTensorHandle> {
1209        Err(anyhow::anyhow!("elem_gt not supported by provider"))
1210    }
1211    fn elem_eq(
1212        &self,
1213        _a: &GpuTensorHandle,
1214        _b: &GpuTensorHandle,
1215    ) -> anyhow::Result<GpuTensorHandle> {
1216        Err(anyhow::anyhow!("elem_eq not supported by provider"))
1217    }
1218    fn elem_ne(
1219        &self,
1220        _a: &GpuTensorHandle,
1221        _b: &GpuTensorHandle,
1222    ) -> anyhow::Result<GpuTensorHandle> {
1223        Err(anyhow::anyhow!("elem_ne not supported by provider"))
1224    }
1225    fn logical_and(
1226        &self,
1227        _a: &GpuTensorHandle,
1228        _b: &GpuTensorHandle,
1229    ) -> anyhow::Result<GpuTensorHandle> {
1230        Err(anyhow::anyhow!("logical_and not supported by provider"))
1231    }
1232    fn logical_or(
1233        &self,
1234        _a: &GpuTensorHandle,
1235        _b: &GpuTensorHandle,
1236    ) -> anyhow::Result<GpuTensorHandle> {
1237        Err(anyhow::anyhow!("logical_or not supported by provider"))
1238    }
1239    fn logical_xor(
1240        &self,
1241        _a: &GpuTensorHandle,
1242        _b: &GpuTensorHandle,
1243    ) -> anyhow::Result<GpuTensorHandle> {
1244        Err(anyhow::anyhow!("logical_xor not supported by provider"))
1245    }
1246    fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1247        Err(anyhow::anyhow!("logical_not not supported by provider"))
1248    }
1249    fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
1250        Ok(handle_is_logical(a))
1251    }
1252    fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1253        Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1254    }
1255    fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1256        Err(anyhow::anyhow!(
1257            "logical_isfinite not supported by provider"
1258        ))
1259    }
1260    fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1261        Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1262    }
1263    fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1264        Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1265    }
1266    fn elem_atan2(
1267        &self,
1268        _y: &GpuTensorHandle,
1269        _x: &GpuTensorHandle,
1270    ) -> anyhow::Result<GpuTensorHandle> {
1271        Err(anyhow::anyhow!("elem_atan2 not supported by provider"))
1272    }
1273    // Unary elementwise operations (optional)
1274    fn unary_sin(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1275        Err(anyhow::anyhow!("unary_sin not supported by provider"))
1276    }
1277    fn unary_gamma(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1278        Err(anyhow::anyhow!("unary_gamma not supported by provider"))
1279    }
1280    fn unary_factorial(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1281        Err(anyhow::anyhow!("unary_factorial not supported by provider"))
1282    }
1283    fn unary_asinh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1284        Err(anyhow::anyhow!("unary_asinh not supported by provider"))
1285    }
1286    fn unary_sinh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1287        Err(anyhow::anyhow!("unary_sinh not supported by provider"))
1288    }
1289    fn unary_cosh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1290        Err(anyhow::anyhow!("unary_cosh not supported by provider"))
1291    }
1292    fn unary_asin(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1293        Err(anyhow::anyhow!("unary_asin not supported by provider"))
1294    }
1295    fn unary_acos(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1296        Err(anyhow::anyhow!("unary_acos not supported by provider"))
1297    }
1298    fn unary_acosh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1299        Err(anyhow::anyhow!("unary_acosh not supported by provider"))
1300    }
1301    fn unary_tan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1302        Err(anyhow::anyhow!("unary_tan not supported by provider"))
1303    }
1304    fn unary_tanh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1305        Err(anyhow::anyhow!("unary_tanh not supported by provider"))
1306    }
1307    fn unary_atan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1308        Err(anyhow::anyhow!("unary_atan not supported by provider"))
1309    }
1310    fn unary_atanh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1311        Err(anyhow::anyhow!("unary_atanh not supported by provider"))
1312    }
1313    fn unary_ceil(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1314        Err(anyhow::anyhow!("unary_ceil not supported by provider"))
1315    }
1316    fn unary_floor(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1317        Err(anyhow::anyhow!("unary_floor not supported by provider"))
1318    }
1319    fn unary_round(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1320        Err(anyhow::anyhow!("unary_round not supported by provider"))
1321    }
1322    fn unary_fix(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1323        Err(anyhow::anyhow!("unary_fix not supported by provider"))
1324    }
1325    fn unary_cos(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1326        Err(anyhow::anyhow!("unary_cos not supported by provider"))
1327    }
1328    fn unary_angle(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1329        Err(anyhow::anyhow!("unary_angle not supported by provider"))
1330    }
1331    fn unary_imag(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1332        Err(anyhow::anyhow!("unary_imag not supported by provider"))
1333    }
1334    fn unary_real(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1335        Err(anyhow::anyhow!("unary_real not supported by provider"))
1336    }
1337    fn unary_conj(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1338        Err(anyhow::anyhow!("unary_conj not supported by provider"))
1339    }
1340    fn unary_abs(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1341        Err(anyhow::anyhow!("unary_abs not supported by provider"))
1342    }
1343    fn unary_sign(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1344        Err(anyhow::anyhow!("unary_sign not supported by provider"))
1345    }
1346    fn unary_exp(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1347        Err(anyhow::anyhow!("unary_exp not supported by provider"))
1348    }
1349    fn unary_expm1(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1350        Err(anyhow::anyhow!("unary_expm1 not supported by provider"))
1351    }
1352    fn unary_log(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1353        Err(anyhow::anyhow!("unary_log not supported by provider"))
1354    }
1355    fn unary_log2(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1356        Err(anyhow::anyhow!("unary_log2 not supported by provider"))
1357    }
1358    fn unary_log10(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1359        Err(anyhow::anyhow!("unary_log10 not supported by provider"))
1360    }
1361    fn unary_log1p(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1362        Err(anyhow::anyhow!("unary_log1p not supported by provider"))
1363    }
1364    fn unary_sqrt(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1365        Err(anyhow::anyhow!("unary_sqrt not supported by provider"))
1366    }
1367    fn unary_double(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1368        Err(anyhow::anyhow!("unary_double not supported by provider"))
1369    }
1370    fn unary_single(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1371        Err(anyhow::anyhow!("unary_single not supported by provider"))
1372    }
1373    fn unary_pow2(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1374        Err(anyhow::anyhow!("unary_pow2 not supported by provider"))
1375    }
1376    fn pow2_scale(
1377        &self,
1378        _mantissa: &GpuTensorHandle,
1379        _exponent: &GpuTensorHandle,
1380    ) -> anyhow::Result<GpuTensorHandle> {
1381        Err(anyhow::anyhow!("pow2_scale not supported by provider"))
1382    }
1383    // Left-scalar operations (broadcast with scalar on the left)
1384    fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1385        Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
1386    }
1387    fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1388        Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
1389    }
1390    // Scalar operations: apply op with scalar right-hand side (broadcast over a)
1391    fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1392        Err(anyhow::anyhow!("scalar_add not supported by provider"))
1393    }
1394    fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1395        Err(anyhow::anyhow!("scalar_sub not supported by provider"))
1396    }
1397    fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1398        Err(anyhow::anyhow!("scalar_mul not supported by provider"))
1399    }
1400    fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1401        Err(anyhow::anyhow!("scalar_max not supported by provider"))
1402    }
1403    fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1404        Err(anyhow::anyhow!("scalar_min not supported by provider"))
1405    }
1406    fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1407        Err(anyhow::anyhow!("scalar_div not supported by provider"))
1408    }
1409    fn sort_dim(
1410        &self,
1411        _a: &GpuTensorHandle,
1412        _dim: usize,
1413        _order: SortOrder,
1414        _comparison: SortComparison,
1415    ) -> anyhow::Result<SortResult> {
1416        Err(anyhow::anyhow!("sort_dim not supported by provider"))
1417    }
1418    fn sort_rows(
1419        &self,
1420        _a: &GpuTensorHandle,
1421        _columns: &[SortRowsColumnSpec],
1422        _comparison: SortComparison,
1423    ) -> anyhow::Result<SortResult> {
1424        Err(anyhow::anyhow!("sort_rows not supported by provider"))
1425    }
1426    fn matmul(
1427        &self,
1428        _a: &GpuTensorHandle,
1429        _b: &GpuTensorHandle,
1430    ) -> anyhow::Result<GpuTensorHandle> {
1431        Err(anyhow::anyhow!("matmul not supported by provider"))
1432    }
1433
1434    fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1435        Err(anyhow::anyhow!("syrk not supported by provider"))
1436    }
1437    fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
1438        Err(anyhow::anyhow!("pagefun not supported by provider"))
1439    }
1440
1441    /// Optional: matrix multiplication with an epilogue applied before store.
1442    ///
1443    /// The default implementation falls back to `matmul` when the epilogue is effectively a no-op
1444    /// (alpha=1, beta=0, no row/col scales), and otherwise returns `Err`.
1445    fn matmul_epilogue(
1446        &self,
1447        a: &GpuTensorHandle,
1448        b: &GpuTensorHandle,
1449        epilogue: &MatmulEpilogue,
1450    ) -> anyhow::Result<GpuTensorHandle> {
1451        if epilogue.is_noop() {
1452            return self.matmul(a, b);
1453        }
1454        Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1455    }
1456    fn image_normalize(
1457        &self,
1458        _input: &GpuTensorHandle,
1459        _desc: &ImageNormalizeDescriptor,
1460    ) -> anyhow::Result<GpuTensorHandle> {
1461        Err(anyhow::anyhow!(
1462            "image_normalize fusion not supported by provider"
1463        ))
1464    }
1465    fn matmul_power_step(
1466        &self,
1467        _lhs: &GpuTensorHandle,
1468        _rhs: &GpuTensorHandle,
1469        _epilogue: &PowerStepEpilogue,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "matmul_power_step normalization not supported by provider"
        ))
    }
    fn linsolve(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
        _options: &ProviderLinsolveOptions,
    ) -> anyhow::Result<ProviderLinsolveResult> {
        Err(anyhow::anyhow!("linsolve not supported by provider"))
    }
    fn inv(
        &self,
        _matrix: &GpuTensorHandle,
        _options: ProviderInvOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("inv not supported by provider"))
    }
    fn pinv(
        &self,
        _matrix: &GpuTensorHandle,
        _options: ProviderPinvOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pinv not supported by provider"))
    }
    fn cond(
        &self,
        _matrix: &GpuTensorHandle,
        _norm: ProviderCondNorm,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cond not supported by provider"))
    }
    fn norm(
        &self,
        _tensor: &GpuTensorHandle,
        _order: ProviderNormOrder,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("norm not supported by provider"))
    }
    fn rank(
        &self,
        _matrix: &GpuTensorHandle,
        _tolerance: Option<f64>,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("rank not supported by provider"))
    }
    fn rcond(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("rcond not supported by provider"))
    }
    fn mldivide(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("mldivide not supported by provider"))
    }
    fn mrdivide(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("mrdivide not supported by provider"))
    }
    fn eig(&self, _a: &GpuTensorHandle, _compute_left: bool) -> anyhow::Result<ProviderEigResult> {
        Err(anyhow::anyhow!("eig not supported by provider"))
    }
    fn lu(&self, _a: &GpuTensorHandle) -> anyhow::Result<ProviderLuResult> {
        Err(anyhow::anyhow!("lu not supported by provider"))
    }

    fn chol(&self, _a: &GpuTensorHandle, _lower: bool) -> anyhow::Result<ProviderCholResult> {
        Err(anyhow::anyhow!("chol not supported by provider"))
    }
    fn qr(
        &self,
        _a: &GpuTensorHandle,
        _options: ProviderQrOptions,
    ) -> anyhow::Result<ProviderQrResult> {
        Err(anyhow::anyhow!("qr not supported by provider"))
    }
    fn take_matmul_sources(
        &self,
        _product: &GpuTensorHandle,
    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
        None
    }
    fn qr_power_iter(
        &self,
        _product: &GpuTensorHandle,
        _product_lhs: Option<&GpuTensorHandle>,
        _q_handle: &GpuTensorHandle,
        _options: &ProviderQrOptions,
    ) -> anyhow::Result<Option<ProviderQrPowerIterResult>> {
        Ok(None)
    }
    fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("transpose not supported by provider"))
    }
    fn conv1d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _options: ProviderConv1dOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv1d not supported by provider"))
    }
    fn conv2d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _mode: ProviderConvMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv2d not supported by provider"))
    }
    fn iir_filter(
        &self,
        _b: &GpuTensorHandle,
        _a: &GpuTensorHandle,
        _x: &GpuTensorHandle,
        _options: ProviderIirFilterOptions,
    ) -> anyhow::Result<ProviderIirFilterResult> {
        Err(anyhow::anyhow!("iir_filter not supported by provider"))
    }
    /// Reorder tensor dimensions according to `order`, expressed as zero-based indices.
    fn permute(
        &self,
        _handle: &GpuTensorHandle,
        _order: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("permute not supported by provider"))
    }
    fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("flip not supported by provider"))
    }
    fn circshift(
        &self,
        _handle: &GpuTensorHandle,
        _shifts: &[isize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("circshift not supported by provider"))
    }
    fn diff_dim(
        &self,
        _handle: &GpuTensorHandle,
        _order: usize,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("diff_dim not supported by provider"))
    }
    /// Compute the FFT along a zero-based dimension on the device, optionally padding or truncating the transform length to `len`.
    fn fft_dim(
        &self,
        _handle: &GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("fft_dim not supported by provider"))
    }
    fn ifft_dim(
        &self,
        _handle: &GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("ifft_dim not supported by provider"))
    }
    fn unique(
        &self,
        _handle: &GpuTensorHandle,
        _options: &UniqueOptions,
    ) -> anyhow::Result<UniqueResult> {
        Err(anyhow::anyhow!("unique not supported by provider"))
    }
    fn union(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
        _options: &UnionOptions,
    ) -> anyhow::Result<UnionResult> {
        Err(anyhow::anyhow!("union not supported by provider"))
    }
    fn setdiff(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
        _options: &SetdiffOptions,
    ) -> anyhow::Result<SetdiffResult> {
        Err(anyhow::anyhow!("setdiff not supported by provider"))
    }
    fn ismember(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
        _options: &IsMemberOptions,
    ) -> anyhow::Result<IsMemberResult> {
        Err(anyhow::anyhow!("ismember not supported by provider"))
    }
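    /// Reinterpret `handle` with `new_shape`. The default implementation only
    /// rewrites shape metadata and assumes the element count is unchanged and
    /// the underlying buffer is contiguous; providers with strided or padded
    /// layouts should override this.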
    fn reshape(
        &self,
        handle: &GpuTensorHandle,
        new_shape: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        let mut updated = handle.clone();
        updated.shape = new_shape.to_vec();
        Ok(updated)
    }
    /// Concatenate the provided tensors along the 1-based dimension `dim`.
    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cat not supported by provider"))
    }
    fn repmat(
        &self,
        _handle: &GpuTensorHandle,
        _reps: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("repmat not supported by provider"))
    }
    /// Compute the Kronecker product of two tensors, matching MATLAB semantics.
    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("kron not supported by provider"))
    }
    fn reduce_sum(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_sum not supported by provider"))
    }
    fn reduce_sum_dim(&self, _a: &GpuTensorHandle, _dim: usize) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_sum_dim not supported by provider"))
    }
    fn dot(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
        _dim: Option<usize>,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("dot not supported by provider"))
    }
    fn reduce_nnz(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_nnz not supported by provider"))
    }
    fn reduce_nnz_dim(&self, _a: &GpuTensorHandle, _dim: usize) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_nnz_dim not supported by provider"))
    }
    fn reduce_prod(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_prod not supported by provider"))
    }
    fn reduce_prod_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_prod_dim not supported by provider"))
    }
    fn reduce_mean(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_mean not supported by provider"))
    }
    /// Reduce mean across multiple zero-based dimensions in one device pass.
    fn reduce_mean_nd(
        &self,
        _a: &GpuTensorHandle,
        _dims_zero_based: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_mean_nd not supported by provider"))
    }
    /// Reduce moments across multiple zero-based dimensions in one device pass.
    /// Returns mean (E[x]) and mean of squares (E[x^2]).
    fn reduce_moments_nd(
        &self,
        _a: &GpuTensorHandle,
        _dims_zero_based: &[usize],
    ) -> anyhow::Result<ProviderMoments2> {
        Err(anyhow::anyhow!(
            "reduce_moments_nd not supported by provider"
        ))
    }
    fn reduce_mean_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_mean_dim not supported by provider"))
    }
    fn reduce_std(
        &self,
        _a: &GpuTensorHandle,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_std not supported by provider"))
    }
    fn reduce_std_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_std_dim not supported by provider"))
    }
    fn reduce_any(&self, _a: &GpuTensorHandle, _omit_nan: bool) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_any not supported by provider"))
    }
    fn reduce_any_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_any_dim not supported by provider"))
    }
    fn reduce_all(&self, _a: &GpuTensorHandle, _omit_nan: bool) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_all not supported by provider"))
    }
    fn reduce_all_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_all_dim not supported by provider"))
    }
    fn reduce_median(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_median not supported by provider"))
    }
    fn reduce_median_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "reduce_median_dim not supported by provider"
        ))
    }
    fn reduce_min(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_min not supported by provider"))
    }
    fn reduce_min_dim(&self, _a: &GpuTensorHandle, _dim: usize) -> anyhow::Result<ReduceDimResult> {
        Err(anyhow::anyhow!("reduce_min_dim not supported by provider"))
    }
    fn reduce_max(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_max not supported by provider"))
    }
    fn reduce_max_dim(&self, _a: &GpuTensorHandle, _dim: usize) -> anyhow::Result<ReduceDimResult> {
        Err(anyhow::anyhow!("reduce_max_dim not supported by provider"))
    }
    fn cumsum_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
    }
    fn cumprod_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
    }
    fn cummin_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCumminResult> {
        Err(anyhow::anyhow!("cummin_scan not supported by provider"))
    }
    fn cummax_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCummaxResult> {
        Err(anyhow::anyhow!("cummax_scan not supported by provider"))
    }

    fn find(
        &self,
        _a: &GpuTensorHandle,
        _limit: Option<usize>,
        _direction: FindDirection,
    ) -> anyhow::Result<ProviderFindResult> {
        Err(anyhow::anyhow!("find not supported by provider"))
    }

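    /// Execute a fused elementwise kernel described by `shader` over `inputs`,
    /// producing a tensor with `output_shape` and `len` elements. (The shader
    /// source and binding layout contract are provider-defined; this summary
    /// only describes the call shape.)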
    fn fused_elementwise(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _len: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "fused_elementwise not supported by provider"
        ))
    }

    /// Build a numeric tensor where NaNs in `a` are replaced with 0.0 (device side).
    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
    }

    /// Build a numeric mask tensor: 1.0 where the value is not NaN, 0.0 where it is NaN.
    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
    }

    /// Generic fused reduction entrypoint.
    ///
    /// The shader is expected to implement a column-major reduction across `reduce_len` with
    /// `num_slices` independent slices (e.g., columns). Providers should create a uniform buffer
    /// compatible with the expected `Params/MParams` struct in the shader and dispatch
    /// `num_slices` workgroups with `workgroup_size` threads, or an equivalent strategy.
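    ///
    /// A minimal dispatch sketch under those assumptions (`pass` is an
    /// illustrative compute-pass handle, not part of this API):
    ///
    /// ```ignore
    /// // One workgroup per slice; each workgroup strides across `reduce_len`.
    /// let wg_size = provider.default_reduction_workgroup_size();
    /// let two_pass = reduce_len > provider.two_pass_threshold();
    /// pass.dispatch_workgroups(num_slices as u32, 1, 1);
    /// ```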
    #[allow(clippy::too_many_arguments)]
    fn fused_reduction(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _reduce_len: usize,
        _num_slices: usize,
        _workgroup_size: u32,
        _flavor: ReductionFlavor,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
    }

    /// Optionally pre-compile commonly used pipelines to amortize first-dispatch costs.
    fn warmup(&self) {}

    /// Returns `(cache_hits, cache_misses)` for the fused pipeline cache, if supported.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }

    /// Returns the duration of the last provider warmup in milliseconds, if known.
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }

    /// Returns a snapshot of provider telemetry counters if supported.
    fn telemetry_snapshot(&self) -> ProviderTelemetry {
        let (hits, misses) = self.fused_cache_counters();
        ProviderTelemetry {
            fused_elementwise: ProviderDispatchStats::default(),
            fused_reduction: ProviderDispatchStats::default(),
            matmul: ProviderDispatchStats::default(),
            upload_bytes: 0,
            download_bytes: 0,
            fusion_cache_hits: hits,
            fusion_cache_misses: misses,
            bind_group_cache_hits: 0,
            bind_group_cache_misses: 0,
            bind_group_cache_by_layout: None,
            kernel_launches: Vec::new(),
        }
    }

    /// Reset all telemetry counters maintained by the provider, if supported.
    fn reset_telemetry(&self) {}

    /// Default reduction workgroup size the provider prefers.
    fn default_reduction_workgroup_size(&self) -> u32 {
        256
    }

    /// Threshold above which the provider prefers a two-pass reduction.
    fn two_pass_threshold(&self) -> usize {
        1024
    }

    /// Current two-pass mode preference (auto/forced on/off).
    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }

    /// Fast-path: write a GPU column in a matrix from a GPU vector, returning a new handle.
    /// Expected: `values.shape == [rows, 1]` (or `[rows]`) and `col_index < cols`.
    fn scatter_column(
        &self,
        _matrix: &GpuTensorHandle,
        _col_index: usize,
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scatter_column not supported by provider"))
    }

    /// Fast-path: write a GPU row in a matrix from a GPU vector, returning a new handle.
    /// Expected: `values.shape == [1, cols]` (or `[cols]`) and `row_index < rows`.
    fn scatter_row(
        &self,
        _matrix: &GpuTensorHandle,
        _row_index: usize,
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scatter_row not supported by provider"))
    }

    fn sub2ind(
        &self,
        _dims: &[usize],
        _strides: &[usize],
        _inputs: &[&GpuTensorHandle],
        _scalar_mask: &[bool],
        _len: usize,
        _output_shape: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("sub2ind not supported by provider"))
    }

    /// Returns true if the provider offers a device-side `ind2sub` implementation.
    fn supports_ind2sub(&self) -> bool {
        false
    }

    /// Convert linear indices into per-dimension subscripts on the device.
    fn ind2sub(
        &self,
        _dims: &[usize],
        _strides: &[usize],
        _indices: &GpuTensorHandle,
        _total: usize,
        _len: usize,
        _output_shape: &[usize],
    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
        Err(anyhow::anyhow!("ind2sub not supported by provider"))
    }

    /// Determine if a matrix is symmetric (or skew-symmetric) without gathering it to the host.
    fn issymmetric(
        &self,
        _matrix: &GpuTensorHandle,
        _kind: ProviderSymmetryKind,
        _tolerance: f64,
    ) -> anyhow::Result<bool> {
        Err(anyhow::anyhow!(
            "issymmetric predicate not supported by provider"
        ))
    }

    /// Determine if a matrix is Hermitian (or skew-Hermitian) without gathering it to the host.
    fn ishermitian(
        &self,
        _matrix: &GpuTensorHandle,
        _kind: ProviderHermitianKind,
        _tolerance: f64,
    ) -> anyhow::Result<bool> {
        Err(anyhow::anyhow!(
            "ishermitian predicate not supported by provider"
        ))
    }

    /// Inspect the bandwidth of a matrix without gathering it back to the host.
    fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
        Err(anyhow::anyhow!("bandwidth not supported by provider"))
    }

    /// Compute the symmetric reverse Cuthill-McKee permutation for the matrix.
    ///
    /// Implementations may execute on the device or gather to the host. The permutation should be
    /// returned as zero-based indices.
    fn sym_rcm(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<Vec<usize>> {
        Err(anyhow::anyhow!("sym_rcm not supported by provider"))
    }
}

static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
thread_local! {
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}

/// Register a global acceleration provider.
///
/// # Safety
/// - The caller must guarantee that `p` is valid for the entire program lifetime
///   (e.g., a `'static` singleton), as the runtime stores a raw reference globally.
/// - Concurrent callers must ensure registration happens once or is properly
///   synchronized; this function does not enforce thread-safety for re-registration.
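///
/// A hedged registration sketch (`MyProvider` is an illustrative backend
/// singleton, not part of this crate):
///
/// ```ignore
/// static PROVIDER: MyProvider = MyProvider::new();
/// // Safety: `PROVIDER` is a process-lifetime singleton.
/// unsafe { runmat_accelerate_api::register_provider(&PROVIDER) };
/// ```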
pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
        *guard = Some(p);
    }
    register_provider_for_device(p.device_id(), p);
}

unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
    if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
        guard.insert(device_id, provider);
    }
}

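/// Resolve the active provider for this thread: a thread-local override
/// installed via [`ThreadProviderGuard`] or [`set_thread_provider`] takes
/// precedence over the globally registered provider.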
pub fn provider() -> Option<&'static dyn AccelProvider> {
    if let Some(p) = THREAD_PROVIDER.with(|cell| cell.get()) {
        return Some(p);
    }
    GLOBAL_PROVIDER
        .read()
        .ok()
        .and_then(|guard| guard.as_ref().copied())
}

/// Clear the globally registered provider. Intended for tests to ensure deterministic behaviour.
pub fn clear_provider() {
    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
        *guard = None;
    }
    if let Ok(mut map) = PROVIDER_REGISTRY.write() {
        map.clear();
    }
}

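/// Look up the provider registered for `device_id`, falling back to the
/// active global/thread provider when the device has no dedicated entry.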
pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
    PROVIDER_REGISTRY
        .read()
        .ok()
        .and_then(|guard| guard.get(&device_id).copied())
        .or_else(provider)
}

pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
    provider_for_device(handle.device_id)
}

pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}

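/// RAII guard that installs a thread-local provider override and restores the
/// previous value when dropped. A minimal usage sketch (`my_provider` is
/// assumed to be a `'static` backend reference):
///
/// ```ignore
/// let _guard = ThreadProviderGuard::set(Some(my_provider));
/// // `provider()` now resolves to `my_provider` on this thread...
/// // ...until `_guard` drops and the previous value is restored.
/// ```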
pub struct ThreadProviderGuard {
    prev: Option<&'static dyn AccelProvider>,
}

impl ThreadProviderGuard {
    pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
        let prev = THREAD_PROVIDER.with(|cell| {
            let old = cell.get();
            cell.set(provider);
            old
        });
        ThreadProviderGuard { prev }
    }
}

impl Drop for ThreadProviderGuard {
    fn drop(&mut self) {
        let prev = self.prev.take();
        THREAD_PROVIDER.with(|cell| cell.set(prev));
    }
}

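/// Install or clear the thread-local provider without scoping; prefer
/// [`ThreadProviderGuard::set`] when the override should be restored
/// automatically.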
pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
    THREAD_PROVIDER.with(|cell| cell.set(provider));
}

/// Convenience: perform elementwise add via provider if possible; otherwise return `None`.
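///
/// A hedged usage sketch (`use_gpu_result` and `fall_back_to_host` are
/// illustrative callbacks, not part of this crate):
///
/// ```ignore
/// match try_elem_add(&a, &b) {
///     Some(handle) => use_gpu_result(handle),
///     None => fall_back_to_host(&a, &b),
/// }
/// ```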
pub fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
    if let Some(p) = provider() {
        if let Ok(h) = p.elem_add(a, b) {
            return Some(h);
        }
    }
    None
}

/// Convenience: perform elementwise hypot via provider if possible; otherwise return `None`.
pub fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
    if let Some(p) = provider() {
        if let Ok(h) = p.elem_hypot(a, b) {
            return Some(h);
        }
    }
    None
}

/// Convenience: perform elementwise max via provider if possible; otherwise return `None`.
pub fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
    if let Some(p) = provider() {
        if let Ok(h) = p.elem_max(a, b) {
            return Some(h);
        }
    }
    None
}

/// Convenience: perform elementwise min via provider if possible; otherwise return `None`.
pub fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
    if let Some(p) = provider() {
        if let Ok(h) = p.elem_min(a, b) {
            return Some(h);
        }
    }
    None
}

/// Convenience: perform elementwise atan2 via provider if possible; otherwise return `None`.
pub fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
    if let Some(p) = provider() {
        if let Ok(h) = p.elem_atan2(y, x) {
            return Some(h);
        }
    }
    None
}

// Minimal host tensor views to avoid depending on runmat-builtins and cycles.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    pub data: Vec<f64>,
    pub shape: Vec<usize>,
}

#[derive(Debug)]
pub struct HostTensorView<'a> {
    pub data: &'a [f64],
    pub shape: &'a [usize],
}

/// Lightweight 1-D axis view used by provider meshgrid hooks.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    pub data: &'a [f64],
}

/// Provider-side meshgrid result containing coordinate tensor handles.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    pub outputs: Vec<GpuTensorHandle>,
}

/// Whether a per-row or per-column scale vector multiplies or divides the
/// matmul output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    Multiply,
    Divide,
}

/// Descriptor for GEMM epilogues applied to `C = A * B` before storing to `C`.
///
/// Supported operations:
/// - Scale by `alpha` and add scalar `beta`.
/// - Multiply or divide the output by per-row and/or per-column scale vectors (broadcasted).
/// - Clamp, raise to a power, and capture the diagonal (see the field docs below).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    /// Scalar multiply applied to each output element.
    pub alpha: f64,
    /// Scalar add applied to each output element after scaling.
    pub beta: f64,
    /// Optional per-row scale (length m). When present, output[row, col] *= row_scale[row].
    pub row_scale: Option<GpuTensorHandle>,
    /// Optional per-column scale (length n). When present, output[row, col] *= col_scale[col].
    pub col_scale: Option<GpuTensorHandle>,
    /// Row scale operation (multiply or divide). Ignored when `row_scale` is None.
    pub row_op: ScaleOp,
    /// Column scale operation (multiply or divide). Ignored when `col_scale` is None.
    pub col_op: ScaleOp,
    /// Optional lower clamp bound applied after scale/bias.
    #[serde(default)]
    pub clamp_min: Option<f64>,
    /// Optional upper clamp bound applied after scale/bias.
    #[serde(default)]
    pub clamp_max: Option<f64>,
    /// Optional power exponent applied after clamp (final operation in the epilogue).
    #[serde(default)]
    pub pow_exponent: Option<f64>,
    /// Optional output buffer for the diagonal of the result (length min(m, n)).
    #[serde(default)]
    pub diag_output: Option<GpuTensorHandle>,
}

impl MatmulEpilogue {
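    /// Identity epilogue: `alpha = 1`, `beta = 0`, and no scales, clamps,
    /// power, or diagonal capture. A hedged construction sketch using
    /// struct-update syntax:
    ///
    /// ```ignore
    /// let relu_like = MatmulEpilogue { clamp_min: Some(0.0), ..MatmulEpilogue::noop() };
    /// assert!(!relu_like.is_noop());
    /// ```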
    pub fn noop() -> Self {
        Self {
            alpha: 1.0,
            beta: 0.0,
            row_scale: None,
            col_scale: None,
            row_op: ScaleOp::Multiply,
            col_op: ScaleOp::Multiply,
            clamp_min: None,
            clamp_max: None,
            pow_exponent: None,
            diag_output: None,
        }
    }

    pub fn is_noop(&self) -> bool {
        self.alpha == 1.0
            && self.beta == 0.0
            && self.row_scale.is_none()
            && self.col_scale.is_none()
            && self.clamp_min.is_none()
            && self.clamp_max.is_none()
            && self.pow_exponent.is_none()
            && self.diag_output.is_none()
    }
}

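/// Epilogue parameters for the fused power-iteration matmul step; `epsilon`
/// guards the normalization denominator against division by zero (the exact
/// normalization is provider-defined).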
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct PowerStepEpilogue {
    pub epsilon: f64,
}

impl Default for PowerStepEpilogue {
    fn default() -> Self {
        Self { epsilon: 0.0 }
    }
}

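/// Parameters for provider-side image normalization over `batch` images of
/// `height` x `width`; `epsilon` stabilizes the variance denominator, while
/// `gain`, `bias`, and `gamma` optionally apply a post-normalization
/// adjustment (interpretation is provider-defined).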
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    pub batch: usize,
    pub height: usize,
    pub width: usize,
    pub epsilon: f64,
    #[serde(default)]
    pub gain: Option<f64>,
    #[serde(default)]
    pub bias: Option<f64>,
    #[serde(default)]
    pub gamma: Option<f64>,
}