Skip to main content

runmat_accelerate_api/
lib.rs

1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
/// Backend hook invoked by [`mark_residency`] for a GPU tensor handle.
type ResidencyMarkFn = fn(&GpuTensorHandle);
/// Backend hook invoked by [`clear_residency`] for a GPU tensor handle.
type ResidencyClearFn = fn(&GpuTensorHandle);
/// Hook queried by [`sequence_threshold_hint`]; it may return `None`.
type SequenceThresholdFn = fn() -> Option<usize>;
/// Hook queried by [`workgroup_size_hint`]; it may return `None`.
type WorkgroupSizeHintFn = fn() -> Option<u32>;

// Process-wide hook slots filled by the `register_*` functions. Each slot is a
// `OnceCell`, so the first registration is kept and later ones are discarded.
static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();
25
26static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
27static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
28    Lazy::new(|| RwLock::new(HashMap::new()));
29static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
30    Lazy::new(|| RwLock::new(HashMap::new()));
31
32static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
33    Lazy::new(|| RwLock::new(HashMap::new()));
34static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
35    Lazy::new(|| RwLock::new(HashMap::new()));
36
/// Dimensions recorded by [`record_handle_transpose`] for a handle that is tracked as
/// transposed, and returned by [`handle_transpose_info`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// The `base_rows` value passed to [`record_handle_transpose`].
    pub base_rows: usize,
    /// The `base_cols` value passed to [`record_handle_transpose`].
    pub base_cols: usize,
}
42
43/// Register a callback used to mark residency tracking when GPU tensors are
44/// created or returned by device-side execution paths.
45pub fn register_residency_mark(handler: ResidencyMarkFn) {
46    let _ = RESIDENCY_MARK.set(handler);
47}
48
49/// Mark residency metadata for the provided GPU tensor handle, if a backend
50/// has registered a handler via [`register_residency_mark`].
51pub fn mark_residency(handle: &GpuTensorHandle) {
52    if let Some(handler) = RESIDENCY_MARK.get() {
53        handler(handle);
54    }
55}
56
57/// Register a callback used to clear residency tracking when GPU tensors are
58/// gathered back to the host. Backends that maintain residency metadata should
59/// install this hook during initialization.
60pub fn register_residency_clear(handler: ResidencyClearFn) {
61    let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64/// Clear residency metadata for the provided GPU tensor handle, if a backend
65/// has registered a handler via [`register_residency_clear`].
66pub fn clear_residency(handle: &GpuTensorHandle) {
67    if let Some(handler) = RESIDENCY_CLEAR.get() {
68        handler(handle);
69    }
70}
71
72/// Register a callback that exposes the current sequence length threshold
73/// derived from the auto-offload planner. Array constructors can use this hint
74/// to decide when to prefer GPU residency automatically.
75pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76    let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79/// Query the currently registered sequence threshold hint, if any.
80pub fn sequence_threshold_hint() -> Option<usize> {
81    SEQUENCE_THRESHOLD_PROVIDER
82        .get()
83        .and_then(|provider| provider())
84}
85
86/// Register a callback that reports the calibrated workgroup size selected by
87/// the active acceleration provider (if any). Plotting kernels can reuse this
88/// hint to match backend tuning.
89pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90    let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93/// Query the current workgroup size hint exposed by the provider.
94pub fn workgroup_size_hint() -> Option<u32> {
95    WORKGROUP_SIZE_HINT_PROVIDER
96        .get()
97        .and_then(|provider| provider())
98}
99
100/// Export a shared acceleration context (e.g., the active WGPU device) when the
101/// current provider exposes one.
102pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103    provider().and_then(|p| p.export_context(kind))
104}
105
106/// Request a provider-owned WGPU buffer for zero-copy consumers. Returns `None`
107/// when the active provider does not expose buffers or does not support the
108/// supplied handle.
109#[cfg(feature = "wgpu")]
110pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
111    provider().and_then(|p| p.export_wgpu_buffer(handle))
112}
113
114/// Record the precision associated with a GPU tensor handle so host operations can
115/// reconstruct the original dtype when gathering back to the CPU.
116pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118        guard.insert(handle.buffer_id, precision);
119    }
120}
121
122/// Look up the recorded precision for a GPU tensor handle, if any.
123pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124    HANDLE_PRECISIONS
125        .read()
126        .ok()
127        .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130/// Clear any recorded precision metadata for a GPU tensor handle.
131pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133        guard.remove(&handle.buffer_id);
134    }
135}
136
137/// Annotate a GPU tensor handle as logically-typed (`logical` in MATLAB terms)
138/// or clear the logical flag when `logical` is `false`.
139pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
140    if let Ok(mut guard) = LOGICAL_HANDLES.write() {
141        if logical {
142            guard.insert(handle.buffer_id);
143            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
144                *hits.entry(handle.buffer_id).or_insert(0) += 1;
145            }
146        } else {
147            guard.remove(&handle.buffer_id);
148            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
149                hits.remove(&handle.buffer_id);
150            }
151        }
152    }
153}
154
155/// Convenience helper for clearing logical annotations explicitly.
156pub fn clear_handle_logical(handle: &GpuTensorHandle) {
157    set_handle_logical(handle, false);
158}
159
160/// Returns true when the supplied handle has been marked as logical.
161pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162    LOGICAL_HANDLES
163        .read()
164        .map(|guard| guard.contains(&handle.buffer_id))
165        .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169    LOGICAL_HANDLE_HITS
170        .read()
171        .ok()
172        .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177        guard.insert(
178            handle.buffer_id,
179            TransposeInfo {
180                base_rows,
181                base_cols,
182            },
183        );
184    }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189        guard.remove(&handle.buffer_id);
190    }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194    TRANSPOSED_HANDLES
195        .read()
196        .ok()
197        .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
200pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
201    handle_transpose_info(handle).is_some()
202}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
205pub enum GpuTensorStorage {
206    Real,
207    ComplexInterleaved,
208}
209
210impl Default for GpuTensorStorage {
211    fn default() -> Self {
212        Self::Real
213    }
214}
215
/// Handle to a tensor that lives on an acceleration device.
///
/// `buffer_id` is the key used by this module's metadata side tables (precision,
/// logical flag, transpose info, storage layout). Annotations therefore attach to the
/// buffer id, not to a particular clone of the handle.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuTensorHandle {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Device the buffer belongs to (presumably the owning provider's `device_id()` —
    /// confirm).
    pub device_id: u32,
    /// Buffer identifier; also the key for the metadata side tables. NOTE(review):
    /// assumed unique among live buffers — confirm.
    pub buffer_id: u64,
}

/// Structured device description, as returned by `AccelProvider::device_info_struct`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    /// Device identifier.
    pub device_id: u32,
    /// Human-readable device name.
    pub name: String,
    /// Vendor string; may be empty (the default `device_info_struct` leaves it empty).
    pub vendor: String,
    /// Device memory in bytes, when the provider reports it.
    pub memory_bytes: Option<u64>,
    /// Backend label, when the provider reports it.
    pub backend: Option<String>,
}

/// Values and indices produced by a provider-side reduction along a dimension
/// (NOTE(review): presumably min/max-style reductions — confirm).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Result payload returned by provider-side `cummin` scans: running values plus their
/// MATLAB-compatible indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Result payload returned by provider-side `cummax` scans.
///
/// Alias of [`ProviderCumminResult`] because both operations return the same pair of tensors
/// (running values and MATLAB-compatible indices).
pub type ProviderCummaxResult = ProviderCumminResult;
249
/// Names a shared acceleration context that callers may request (e.g. plotting).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelContextKind {
    /// Context requested by the plotting subsystem.
    Plotting,
}

/// Handle returned by [`export_context`] that describes a shared GPU context.
///
/// The only variant is gated on the `wgpu` feature. Without that feature the enum has
/// no variants, so no value can be built and [`export_context`] can only yield `None`.
#[derive(Clone)]
pub enum AccelContextHandle {
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}

impl AccelContextHandle {
    /// Returns the underlying WGPU context when available.
    ///
    /// `Wgpu` is currently the only variant, so this always returns `Some`. The `Option`
    /// leaves room for future non-WGPU variants.
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        match self {
            AccelContextHandle::Wgpu(ctx) => Some(ctx),
        }
    }
}
272
/// Shared WGPU device/queue pair exported by the acceleration provider.
///
/// The instance, device, queue and adapter are `Arc`s, so cloning this handle shares
/// them. The adapter info, limits and features are plain value copies.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    pub adapter_info: wgpu::AdapterInfo,
    pub limits: wgpu::Limits,
    pub features: wgpu::Features,
}

/// Borrowed reference to a provider-owned WGPU buffer corresponding to a `GpuTensorHandle`.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    /// Shared reference to the provider's buffer.
    pub buffer: Arc<wgpu::Buffer>,
    /// NOTE(review): presumably the element count (not bytes) — confirm.
    pub len: usize,
    /// Tensor shape (presumably the same as the originating handle's `shape`).
    pub shape: Vec<usize>,
    /// NOTE(review): presumably bytes per element, consistent with `precision` — confirm.
    pub element_size: usize,
    /// Precision of the stored elements.
    pub precision: ProviderPrecision,
}
296
297pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
298    if let Ok(mut guard) = HANDLE_STORAGES.write() {
299        guard.insert(handle.buffer_id, storage);
300    }
301}
302
303pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
304    HANDLE_STORAGES
305        .read()
306        .ok()
307        .and_then(|guard| guard.get(&handle.buffer_id).cloned())
308        .unwrap_or(GpuTensorStorage::Real)
309}
310
311pub fn clear_handle_storage(handle: &GpuTensorHandle) {
312    if let Ok(mut guard) = HANDLE_STORAGES.write() {
313        guard.remove(&handle.buffer_id);
314    }
315}
316
/// Operation applied page by page in a provider-backed `pagefun` request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    /// Page-wise matrix multiplication (presumably MATLAB `pagefun(@mtimes, ...)`).
    Mtimes,
}

/// Request for a page-wise operation over device tensors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    /// Operation to apply to each page.
    pub op: PagefunOp,
    /// Device-resident operands.
    pub inputs: Vec<GpuTensorHandle>,
    /// Shape of the result tensor.
    pub output_shape: Vec<usize>,
    /// Page (batch) dimensions of the output. NOTE(review): assumed to exclude the
    /// leading matrix dimensions — confirm.
    pub page_dims: Vec<usize>,
    /// Page dimensions per operand; presumably parallel to `inputs` — confirm.
    pub input_page_dims: Vec<Vec<usize>>,
}

/// Whether a provider-backed `find` selects the first or the last matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}

/// Device-resident outputs of a provider-backed `find`. Field meanings are presumably
/// those of MATLAB `find`'s outputs — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    pub linear: GpuTensorHandle,
    pub rows: GpuTensorHandle,
    pub cols: GpuTensorHandle,
    /// Values at the matches; `None` when the provider did not produce them.
    pub values: Option<GpuTensorHandle>,
}
344
/// Lower/upper bandwidth pair (presumably as reported by MATLAB `bandwidth` — confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    pub lower: u32,
    pub upper: u32,
}

/// Which symmetry a provider-side symmetry test checks for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}

/// Which Hermitian property a provider-side Hermitian test checks for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}

/// Device-resident outputs of a provider-side LU decomposition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    /// NOTE(review): presumably L and U packed into one matrix — confirm.
    pub combined: GpuTensorHandle,
    pub lower: GpuTensorHandle,
    pub upper: GpuTensorHandle,
    /// Row permutation in matrix form.
    pub perm_matrix: GpuTensorHandle,
    /// Row permutation in vector form.
    pub perm_vector: GpuTensorHandle,
}

/// Outputs of a provider-side Cholesky factorisation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    pub factor: GpuTensorHandle,
    /// MATLAB-compatible failure index (0 indicates success).
    pub info: u32,
}

/// Device-resident outputs of a provider-side QR decomposition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    /// Column permutation in matrix form.
    pub perm_matrix: GpuTensorHandle,
    /// Column permutation in vector form.
    pub perm_vector: GpuTensorHandle,
}

/// Outputs of the provider's QR power-iteration path. The fields are the same as
/// [`ProviderQrResult`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}
394
/// Structural hints attached to a provider-backed `linsolve` request. All flags default
/// to `false` and `rcond` defaults to `None`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    /// Presumably requests a reciprocal-condition estimate in the result — confirm.
    pub need_rcond: bool,
    /// NOTE(review): role not visible here (caller threshold vs. precomputed estimate)
    /// — confirm.
    pub rcond: Option<f64>,
}

/// Device-resident solution of a provider-backed `linsolve`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    pub reciprocal_condition: f64,
}

/// Options for provider-backed `pinv`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    /// Optional explicit tolerance; `None` presumably lets the provider choose — confirm.
    pub tolerance: Option<f64>,
}

/// Centering/scaling pair for `polyval`, presumably MATLAB's `mu = [mean, std]`, so
/// evaluation uses `(x - mean) / scale` — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}

/// Options for provider-backed `polyval`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}

/// Options for provider-backed `inv`; currently carries no fields.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}

/// Host-resident (`Vec<f64>` / scalar) outputs of provider-backed `polyfit`, unlike most
/// provider results which stay on the device.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    pub coefficients: Vec<f64>,
    /// NOTE(review): presumably the triangular factor MATLAB returns as `S.R` — confirm.
    pub r_matrix: Vec<f64>,
    pub normr: f64,
    pub df: f64,
    /// Centering/scaling pair, presumably `[mean, scale]` — confirm.
    pub mu: [f64; 2],
}

/// Numerator/denominator payload returned by provider-backed `polyder` quotient rule.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}
448
/// Supported norm specifications for the `cond` builtin.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}

/// Supported norm orders for the `norm` builtin.
///
/// Because of the `f64` payload in `P`, this type derives `PartialEq` but not `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    Nuc,
    /// Arbitrary order given by the payload.
    P(f64),
}

/// Device-resident outputs of a provider-side eigen-decomposition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    pub eigenvalues: GpuTensorHandle,
    pub diagonal: GpuTensorHandle,
    pub right: GpuTensorHandle,
    /// Left eigenvectors; `None` when they were not produced.
    pub left: Option<GpuTensorHandle>,
}

/// Form in which a pivoted QR returns its permutation (matrix or vector).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}

/// Options for provider-backed `qr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    /// Presumably requests the economy-size decomposition — confirm.
    pub economy: bool,
    pub pivot: ProviderQrPivot,
}

impl Default for ProviderQrOptions {
    /// Non-economy decomposition with the permutation returned in matrix form.
    fn default() -> Self {
        Self {
            economy: false,
            pivot: ProviderQrPivot::Matrix,
        }
    }
}
499
/// Floating-point precision used for device tensors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}

/// Policy for two-pass reductions: leave the choice to the provider (`Auto`) or force
/// two-pass on or off. See [`ReductionTwoPassMode::as_str`] for the textual labels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionTwoPassMode {
    Auto,
    ForceOn,
    ForceOff,
}
512
513impl ReductionTwoPassMode {
514    pub fn as_str(self) -> &'static str {
515        match self {
516            ReductionTwoPassMode::Auto => "auto",
517            ReductionTwoPassMode::ForceOn => "force_on",
518            ReductionTwoPassMode::ForceOff => "force_off",
519        }
520    }
521}
522
/// How a reduction's accumulated sum is scaled; see [`ReductionFlavor::scale`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ReductionFlavor {
    Sum,
    Mean,
    /// Fixed multiplier supplied by the caller.
    CustomScale(f64),
}
529
530impl ReductionFlavor {
531    pub fn is_mean(self) -> bool {
532        matches!(self, ReductionFlavor::Mean)
533    }
534
535    pub fn scale(self, reduce_len: usize) -> f64 {
536        match self {
537            ReductionFlavor::Sum => 1.0,
538            ReductionFlavor::Mean => {
539                if reduce_len == 0 {
540                    1.0
541                } else {
542                    1.0 / reduce_len as f64
543                }
544            }
545            ReductionFlavor::CustomScale(scale) => scale,
546        }
547    }
548}
549
/// Normalisation mode for correlation coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}

/// Row-selection strategy for correlation coefficients (presumably MATLAB's
/// `'Rows'` option values — confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}

/// Options controlling provider-backed correlation coefficient computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}

impl Default for CorrcoefOptions {
    /// Unbiased normalisation over all rows.
    fn default() -> Self {
        Self {
            normalization: CorrcoefNormalization::Unbiased,
            rows: CorrcoefRows::All,
        }
    }
}

/// Normalisation mode used by covariance computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}

/// Row handling strategy for covariance computations (presumably MATLAB's `'all'`,
/// `'omitrows'` and `'partialrows'` — confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}

/// Options controlling provider-backed covariance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    /// Whether the request carries a weight vector.
    pub has_weight_vector: bool,
}

impl Default for CovarianceOptions {
    /// Unbiased normalisation over all rows, without a weight vector.
    fn default() -> Self {
        Self {
            normalization: CovNormalization::Unbiased,
            rows: CovRows::All,
            has_weight_vector: false,
        }
    }
}
613
/// Normalization strategy used by provider-backed standard deviation reductions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    Sample,
    Population,
}

/// NaN handling mode for provider-backed reductions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}

/// Direction used when computing prefix sums on the device.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}

/// Sort direction used by acceleration providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}

/// Comparison strategy applied during sorting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    Auto,
    Real,
    Abs,
}

/// Host-resident outputs returned by provider-backed sort operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}

/// One sort key for a `sortrows`-style request: which column, and in which direction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    /// Column to sort by. NOTE(review): zero- vs one-based is not visible here — confirm.
    pub index: usize,
    pub order: SortOrder,
}
662
/// Ordering applied by provider-backed `unique` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}

/// Occurrence selection for provider-backed `unique` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}

/// Options controlling provider-backed `unique` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    /// Presumably treats each row as one element (MATLAB `'rows'`) — confirm.
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}

/// Host-resident outputs returned by provider-backed `unique` operations
/// (presumably MATLAB's `[C, ia, ic]` — confirm).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ic: HostTensorOwned,
}

/// Ordering applied by provider-backed `union` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}

/// Options controlling provider-backed `union` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    /// Presumably treats each row as one element (MATLAB `'rows'`) — confirm.
    pub rows: bool,
    pub order: UnionOrder,
}

/// Host-resident outputs returned by provider-backed `union` operations
/// (presumably MATLAB's `[C, ia, ib]` — confirm).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ib: HostTensorOwned,
}
714
/// Parameterisation of 2-D filters generated by `fspecial`.
///
/// Each variant carries the parameters of the matching MATLAB filter type
/// (NOTE(review): assumed to mirror MATLAB's `fspecial` arguments — confirm).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    Average {
        rows: u32,
        cols: u32,
    },
    Disk {
        radius: f64,
        size: u32,
    },
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Laplacian {
        alpha: f64,
    },
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        oversample: u32,
    },
    Prewitt,
    Sobel,
    Unsharp {
        alpha: f64,
    },
}

/// Request dispatched to acceleration providers for `fspecial` kernels.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}

/// Padding strategy used by `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    Constant,
    Replicate,
    Symmetric,
    Circular,
}

/// Output sizing mode used by `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}

/// Correlation vs convolution behaviour for `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}

/// Options supplied to acceleration providers for `imfilter`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    /// Fill value, presumably used only with [`ImfilterPadding::Constant`] — confirm.
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}

impl Default for ImfilterOptions {
    /// Zero constant padding, `Same`-sized output, correlation mode.
    fn default() -> Self {
        Self {
            padding: ImfilterPadding::Constant,
            constant_value: 0.0,
            shape: ImfilterShape::Same,
            mode: ImfilterMode::Correlation,
        }
    }
}
801
/// Ordering applied by provider-backed `setdiff` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}

/// Options controlling provider-backed `setdiff` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// Presumably treats each row as one element (MATLAB `'rows'`) — confirm.
    pub rows: bool,
    pub order: SetdiffOrder,
}

/// Host-resident outputs returned by provider-backed `setdiff` operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}

/// Options controlling provider-backed `ismember` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// Presumably treats each row as one element (MATLAB `'rows'`) — confirm.
    pub rows: bool,
}

/// Host-resident logical output returned by providers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    /// One byte per element (NOTE(review): presumably 0/1 values — confirm).
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}

/// Host-resident outputs returned by provider-backed `ismember` operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}
842
/// Output sizing mode for provider-backed 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

/// Orientation of the result of a provider-backed 1-D convolution (row or column).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options for provider-backed 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}

/// Options for provider-backed IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Zero-based dimension along which filtering should be applied.
    pub dim: usize,
    /// Optional initial conditions (state vector) residing on the device.
    pub zi: Option<GpuTensorHandle>,
}

/// Outputs of provider-backed IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    /// Filtered output tensor, matching the input signal shape.
    pub output: GpuTensorHandle,
    /// Final conditions for the filter state (same shape as the requested `zi` layout).
    pub final_state: Option<GpuTensorHandle>,
}

/// First and second raw moments computed on the device.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    /// Presumably the mean of squares, E[x^2] — confirm.
    pub ex2: GpuTensorHandle,
}
883
/// Dispatch counter and accumulated host-measured time for one kernel category.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of GPU dispatches recorded for this category.
    pub count: u64,
    /// Accumulated wall-clock time of dispatches in nanoseconds (host measured).
    pub total_wall_time_ns: u64,
}

/// Count of solver fallbacks grouped by a textual reason.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderFallbackStat {
    pub reason: String,
    pub count: u64,
}

/// Snapshot of a provider's runtime counters (dispatches, transfers, cache behaviour,
/// recent kernel launches).
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    pub linsolve: ProviderDispatchStats,
    pub mldivide: ProviderDispatchStats,
    pub mrdivide: ProviderDispatchStats,
    /// Bytes transferred host -> device (presumably cumulative — confirm).
    pub upload_bytes: u64,
    /// Bytes transferred device -> host (presumably cumulative — confirm).
    pub download_bytes: u64,
    pub solve_fallbacks: Vec<ProviderFallbackStat>,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Optional per-layout bind group cache counters (layout tags and their hit/miss counts)
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    /// Recent kernel launch metadata (bounded log; newest last)
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}

/// Bind-group cache hit/miss counters for one layout tag.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}

/// Named integer attribute attached to a kernel launch record.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}

/// Metadata describing a single kernel launch.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    pub kernel: String,
    pub precision: Option<String>,
    /// Shape-related attributes of the launch.
    pub shape: Vec<KernelAttrTelemetry>,
    /// Tuning-related attributes of the launch.
    pub tuning: Vec<KernelAttrTelemetry>,
}
939
/// Boxed future returned by asynchronous provider operations. It carries no `Send`
/// bound and may borrow for `'a`.
pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
/// Future returned by [`AccelProvider::download`], resolving to an owned host tensor.
pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;

/// Build a future that resolves to `Err(message)` on its first poll, for operations a
/// provider does not support.
fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
    // An `async` block (rather than `std::future::ready`) keeps this free of a
    // `T: 'static` bound: the block only captures the `'static` message.
    Box::pin(async move { Err(anyhow::anyhow!(message)) })
}
946
947/// Device/provider interface that backends implement and register into the runtime layer
948pub trait AccelProvider: Send + Sync {
    /// Upload a host tensor view to the device and return a handle to the device copy.
    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
    /// Read the tensor behind `h` back to the host. The returned future borrows both the
    /// provider and the handle for `'a`.
    fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
    /// Release the device buffer behind `h`.
    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
    /// Human-readable device description. The default `device_info_struct` reports it as
    /// the device `name`.
    fn device_info(&self) -> String;
    /// Identifier of the device this provider drives. Defaults to `0`.
    fn device_id(&self) -> u32 {
        0
    }
956
957    /// Export a shared GPU context handle, allowing downstream systems (plotting, visualization)
958    /// to reuse the same device/queue without copying tensor data back to the host.
959    fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
960        None
961    }
962
963    /// Export a provider-owned WGPU buffer for zero-copy integrations.
964    #[cfg(feature = "wgpu")]
965    fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
966        let _ = _handle;
967        None
968    }
969
970    /// Gather elements from `source` at the provided zero-based linear `indices`, materialising
971    /// a dense tensor with the specified `output_shape`.
972    fn gather_linear(
973        &self,
974        _source: &GpuTensorHandle,
975        _indices: &[u32],
976        _output_shape: &[usize],
977    ) -> anyhow::Result<GpuTensorHandle> {
978        Err(anyhow::anyhow!("gather_linear not supported by provider"))
979    }
980
981    /// Scatter the contents of `values` into `target` at the provided zero-based linear `indices`.
982    ///
983    /// The provider must ensure `values.len() == indices.len()` and update `target` in place.
984    fn scatter_linear(
985        &self,
986        _target: &GpuTensorHandle,
987        _indices: &[u32],
988        _values: &GpuTensorHandle,
989    ) -> anyhow::Result<()> {
990        Err(anyhow::anyhow!("scatter_linear not supported by provider"))
991    }
992
993    /// Structured device information (optional to override). Default adapts from `device_info()`.
994    fn device_info_struct(&self) -> ApiDeviceInfo {
995        ApiDeviceInfo {
996            device_id: 0,
997            name: self.device_info(),
998            vendor: String::new(),
999            memory_bytes: None,
1000            backend: None,
1001        }
1002    }
1003
1004    fn precision(&self) -> ProviderPrecision {
1005        ProviderPrecision::F64
1006    }
1007
1008    /// Read a single scalar at linear index from a device tensor, returning it as f64.
1009    fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
1010        Err(anyhow::anyhow!("read_scalar not supported by provider"))
1011    }
1012
1013    /// Allocate a zero-initialised tensor with the provided shape on the device.
1014    fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1015        Err(anyhow::anyhow!("zeros not supported by provider"))
1016    }
1017
1018    /// Allocate a one-initialised tensor with the provided shape on the device.
1019    fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1020        Err(anyhow::anyhow!("ones not supported by provider"))
1021    }
1022
1023    /// Allocate a zero-initialised tensor matching the prototype tensor.
1024    fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1025        self.zeros(&prototype.shape)
1026    }
1027
1028    /// Allocate a tensor filled with a constant value on the device.
1029    fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
1030        if value == 0.0 {
1031            return self.zeros(shape);
1032        }
1033        if let Ok(base) = self.zeros(shape) {
1034            match self.scalar_add(&base, value) {
1035                Ok(out) => {
1036                    let _ = self.free(&base);
1037                    return Ok(out);
1038                }
1039                Err(_) => {
1040                    let _ = self.free(&base);
1041                }
1042            }
1043        }
1044        let len: usize = shape.iter().copied().product();
1045        let data = vec![value; len];
1046        let view = HostTensorView { data: &data, shape };
1047        self.upload(&view)
1048    }
1049
1050    /// Allocate a tensor filled with a constant value, matching a prototype's residency.
1051    fn fill_like(
1052        &self,
1053        prototype: &GpuTensorHandle,
1054        value: f64,
1055    ) -> anyhow::Result<GpuTensorHandle> {
1056        if value == 0.0 {
1057            return self.zeros_like(prototype);
1058        }
1059        if let Ok(base) = self.zeros_like(prototype) {
1060            match self.scalar_add(&base, value) {
1061                Ok(out) => {
1062                    let _ = self.free(&base);
1063                    return Ok(out);
1064                }
1065                Err(_) => {
1066                    let _ = self.free(&base);
1067                }
1068            }
1069        }
1070        self.fill(&prototype.shape, value)
1071    }
1072
1073    /// Allocate a one-initialised tensor matching the prototype tensor.
1074    fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1075        self.ones(&prototype.shape)
1076    }
1077
1078    /// Allocate an identity tensor with ones along the leading diagonal of the first two axes.
1079    fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1080        Err(anyhow::anyhow!("eye not supported by provider"))
1081    }
1082
1083    /// Allocate an identity tensor matching the prototype tensor's shape.
1084    fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1085        self.eye(&prototype.shape)
1086    }
1087
1088    /// Construct MATLAB-style coordinate grids from axis vectors.
1089    fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1090        Err(anyhow::anyhow!("meshgrid not supported by provider"))
1091    }
1092
1093    /// Construct a diagonal matrix from a vector-like tensor. `offset` matches MATLAB semantics.
1094    fn diag_from_vector(
1095        &self,
1096        _vector: &GpuTensorHandle,
1097        _offset: isize,
1098    ) -> anyhow::Result<GpuTensorHandle> {
1099        Err(anyhow::anyhow!(
1100            "diag_from_vector not supported by provider"
1101        ))
1102    }
1103
1104    /// Extract a diagonal from a matrix-like tensor. The result is always a column vector.
1105    fn diag_extract(
1106        &self,
1107        _matrix: &GpuTensorHandle,
1108        _offset: isize,
1109    ) -> anyhow::Result<GpuTensorHandle> {
1110        Err(anyhow::anyhow!("diag_extract not supported by provider"))
1111    }
1112
1113    /// Apply a lower-triangular mask to the first two dimensions of a tensor.
1114    fn tril<'a>(
1115        &'a self,
1116        _matrix: &'a GpuTensorHandle,
1117        _offset: isize,
1118    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1119        Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1120    }
1121
1122    /// Apply an upper-triangular mask to the first two dimensions of a tensor.
1123    fn triu<'a>(
1124        &'a self,
1125        _matrix: &'a GpuTensorHandle,
1126        _offset: isize,
1127    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1128        Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1129    }
1130
1131    /// Evaluate a polynomial expressed by `coefficients` at each element in `points`.
1132    fn polyval(
1133        &self,
1134        _coefficients: &GpuTensorHandle,
1135        _points: &GpuTensorHandle,
1136        _options: &ProviderPolyvalOptions,
1137    ) -> anyhow::Result<GpuTensorHandle> {
1138        Err(anyhow::anyhow!("polyval not supported by provider"))
1139    }
1140
1141    /// Fit a polynomial of degree `degree` to `(x, y)` samples. Optional weights must match `x`.
1142    fn polyfit<'a>(
1143        &'a self,
1144        _x: &'a GpuTensorHandle,
1145        _y: &'a GpuTensorHandle,
1146        _degree: usize,
1147        _weights: Option<&'a GpuTensorHandle>,
1148    ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1149        Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1150    }
1151
1152    /// Differentiate a polynomial represented as a vector of coefficients.
1153    fn polyder_single<'a>(
1154        &'a self,
1155        _polynomial: &'a GpuTensorHandle,
1156    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1157        Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1158    }
1159
1160    /// Apply the product rule to polynomials `p` and `q`.
1161    fn polyder_product<'a>(
1162        &'a self,
1163        _p: &'a GpuTensorHandle,
1164        _q: &'a GpuTensorHandle,
1165    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1166        Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1167    }
1168
1169    /// Apply the quotient rule to polynomials `u` and `v`.
1170    fn polyder_quotient<'a>(
1171        &'a self,
1172        _u: &'a GpuTensorHandle,
1173        _v: &'a GpuTensorHandle,
1174    ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1175        Box::pin(async move {
1176            Err(anyhow::anyhow!(
1177                "polyder_quotient not supported by provider"
1178            ))
1179        })
1180    }
1181
1182    /// Integrate a polynomial represented as a vector of coefficients and append a constant term.
1183    fn polyint(
1184        &self,
1185        _polynomial: &GpuTensorHandle,
1186        _constant: f64,
1187    ) -> anyhow::Result<GpuTensorHandle> {
1188        Err(anyhow::anyhow!("polyint not supported by provider"))
1189    }
1190
1191    /// Allocate a tensor filled with random values drawn from U(0, 1).
1192    fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1193        Err(anyhow::anyhow!("random_uniform not supported by provider"))
1194    }
1195
1196    /// Allocate a tensor filled with random values matching the prototype shape.
1197    fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1198        self.random_uniform(&prototype.shape)
1199    }
1200
1201    /// Allocate a tensor filled with standard normal (mean 0, stddev 1) random values.
1202    fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1203        Err(anyhow::anyhow!("random_normal not supported by provider"))
1204    }
1205
1206    /// Allocate a tensor of standard normal values matching a prototype's shape.
1207    fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1208        self.random_normal(&prototype.shape)
1209    }
1210
1211    /// Exponentially-distributed random values with mean `mu`.
1212    fn random_exponential(&self, _mu: f64, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1213        Err(anyhow::anyhow!(
1214            "random_exponential not supported by provider"
1215        ))
1216    }
1217
1218    /// Normal random values with mean `mu` and standard deviation `sigma`.
1219    fn random_normrnd(
1220        &self,
1221        _mu: f64,
1222        _sigma: f64,
1223        _shape: &[usize],
1224    ) -> anyhow::Result<GpuTensorHandle> {
1225        Err(anyhow::anyhow!("random_normrnd not supported by provider"))
1226    }
1227
1228    /// Uniform random values on the interval `[a, b)`.
1229    fn random_unifrnd(
1230        &self,
1231        _a: f64,
1232        _b: f64,
1233        _shape: &[usize],
1234    ) -> anyhow::Result<GpuTensorHandle> {
1235        Err(anyhow::anyhow!("random_unifrnd not supported by provider"))
1236    }
1237
1238    fn stochastic_evolution(
1239        &self,
1240        _state: &GpuTensorHandle,
1241        _drift: f64,
1242        _scale: f64,
1243        _steps: u32,
1244    ) -> anyhow::Result<GpuTensorHandle> {
1245        Err(anyhow::anyhow!(
1246            "stochastic_evolution not supported by provider"
1247        ))
1248    }
1249
1250    /// Set the provider RNG state to align with the host RNG.
1251    fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1252        Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1253    }
1254
1255    /// Generate a 2-D correlation kernel matching MATLAB's `fspecial` builtin.
1256    fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1257        Err(anyhow::anyhow!("fspecial not supported by provider"))
1258    }
1259
1260    /// Evaluate the `peaks` test surface on an n×n grid spanning [-3,3]×[-3,3].
1261    /// Returns the Z matrix (n×n) as a GPU tensor.
1262    fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
1263        Err(anyhow::anyhow!("peaks not supported by provider"))
1264    }
1265
1266    /// Evaluate the `peaks` formula element-wise on caller-supplied GPU coordinate tensors.
1267    /// X and Y must have the same shape. Returns a Z tensor of the same shape.
1268    fn peaks_xy(
1269        &self,
1270        _x: &GpuTensorHandle,
1271        _y: &GpuTensorHandle,
1272    ) -> anyhow::Result<GpuTensorHandle> {
1273        Err(anyhow::anyhow!("peaks_xy not supported by provider"))
1274    }
1275
1276    fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1277        Err(anyhow::anyhow!("hann_window not supported by provider"))
1278    }
1279
1280    fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1281        Err(anyhow::anyhow!("hamming_window not supported by provider"))
1282    }
1283
1284    fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1285        Err(anyhow::anyhow!("blackman_window not supported by provider"))
1286    }
1287
1288    /// Apply an N-D correlation/convolution with padding semantics matching MATLAB's `imfilter`.
1289    fn imfilter<'a>(
1290        &'a self,
1291        _image: &'a GpuTensorHandle,
1292        _kernel: &'a GpuTensorHandle,
1293        _options: &'a ImfilterOptions,
1294    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1295        unsupported_future("imfilter not supported by provider")
1296    }
1297
1298    /// Allocate a tensor filled with random integers over an inclusive range.
1299    fn random_integer_range(
1300        &self,
1301        _lower: i64,
1302        _upper: i64,
1303        _shape: &[usize],
1304    ) -> anyhow::Result<GpuTensorHandle> {
1305        Err(anyhow::anyhow!(
1306            "random_integer_range not supported by provider"
1307        ))
1308    }
1309
1310    /// Allocate a random integer tensor matching the prototype shape.
1311    fn random_integer_like(
1312        &self,
1313        prototype: &GpuTensorHandle,
1314        lower: i64,
1315        upper: i64,
1316    ) -> anyhow::Result<GpuTensorHandle> {
1317        self.random_integer_range(lower, upper, &prototype.shape)
1318    }
1319
1320    /// Allocate a random permutation of 1..=n, returning the first k elements.
1321    fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1322        Err(anyhow!("random_permutation not supported by provider"))
1323    }
1324
1325    /// Allocate a random permutation matching the prototype residency.
1326    fn random_permutation_like(
1327        &self,
1328        _prototype: &GpuTensorHandle,
1329        n: usize,
1330        k: usize,
1331    ) -> anyhow::Result<GpuTensorHandle> {
1332        self.random_permutation(n, k)
1333    }
1334
1335    /// Compute a covariance matrix across the columns of `matrix`.
1336    fn covariance<'a>(
1337        &'a self,
1338        _matrix: &'a GpuTensorHandle,
1339        _second: Option<&'a GpuTensorHandle>,
1340        _weights: Option<&'a GpuTensorHandle>,
1341        _options: &'a CovarianceOptions,
1342    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1343        unsupported_future("covariance not supported by provider")
1344    }
1345
1346    /// Compute a correlation coefficient matrix across the columns of `matrix`.
1347    fn corrcoef<'a>(
1348        &'a self,
1349        _matrix: &'a GpuTensorHandle,
1350        _options: &'a CorrcoefOptions,
1351    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1352        unsupported_future("corrcoef not supported by provider")
1353    }
1354
1355    // Optional operator hooks (default to unsupported)
1356    fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1357        Err(anyhow::anyhow!("linspace not supported by provider"))
1358    }
1359    fn elem_add<'a>(
1360        &'a self,
1361        _a: &'a GpuTensorHandle,
1362        _b: &'a GpuTensorHandle,
1363    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1364        unsupported_future("elem_add not supported by provider")
1365    }
1366    fn elem_mul<'a>(
1367        &'a self,
1368        _a: &'a GpuTensorHandle,
1369        _b: &'a GpuTensorHandle,
1370    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1371        unsupported_future("elem_mul not supported by provider")
1372    }
1373    fn elem_max<'a>(
1374        &'a self,
1375        _a: &'a GpuTensorHandle,
1376        _b: &'a GpuTensorHandle,
1377    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1378        unsupported_future("elem_max not supported by provider")
1379    }
1380    fn elem_min<'a>(
1381        &'a self,
1382        _a: &'a GpuTensorHandle,
1383        _b: &'a GpuTensorHandle,
1384    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1385        unsupported_future("elem_min not supported by provider")
1386    }
1387    fn elem_sub<'a>(
1388        &'a self,
1389        _a: &'a GpuTensorHandle,
1390        _b: &'a GpuTensorHandle,
1391    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1392        unsupported_future("elem_sub not supported by provider")
1393    }
1394    fn elem_div<'a>(
1395        &'a self,
1396        _a: &'a GpuTensorHandle,
1397        _b: &'a GpuTensorHandle,
1398    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1399        unsupported_future("elem_div not supported by provider")
1400    }
1401    fn elem_pow<'a>(
1402        &'a self,
1403        _a: &'a GpuTensorHandle,
1404        _b: &'a GpuTensorHandle,
1405    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1406        unsupported_future("elem_pow not supported by provider")
1407    }
1408
1409    fn elem_hypot<'a>(
1410        &'a self,
1411        _a: &'a GpuTensorHandle,
1412        _b: &'a GpuTensorHandle,
1413    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1414        unsupported_future("elem_hypot not supported by provider")
1415    }
1416    fn elem_ge<'a>(
1417        &'a self,
1418        _a: &'a GpuTensorHandle,
1419        _b: &'a GpuTensorHandle,
1420    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1421        unsupported_future("elem_ge not supported by provider")
1422    }
1423    fn elem_le<'a>(
1424        &'a self,
1425        _a: &'a GpuTensorHandle,
1426        _b: &'a GpuTensorHandle,
1427    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1428        unsupported_future("elem_le not supported by provider")
1429    }
1430    fn elem_lt<'a>(
1431        &'a self,
1432        _a: &'a GpuTensorHandle,
1433        _b: &'a GpuTensorHandle,
1434    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1435        unsupported_future("elem_lt not supported by provider")
1436    }
1437    fn elem_gt<'a>(
1438        &'a self,
1439        _a: &'a GpuTensorHandle,
1440        _b: &'a GpuTensorHandle,
1441    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1442        unsupported_future("elem_gt not supported by provider")
1443    }
1444    fn elem_eq<'a>(
1445        &'a self,
1446        _a: &'a GpuTensorHandle,
1447        _b: &'a GpuTensorHandle,
1448    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1449        unsupported_future("elem_eq not supported by provider")
1450    }
1451    fn elem_ne<'a>(
1452        &'a self,
1453        _a: &'a GpuTensorHandle,
1454        _b: &'a GpuTensorHandle,
1455    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1456        unsupported_future("elem_ne not supported by provider")
1457    }
1458    fn logical_and(
1459        &self,
1460        _a: &GpuTensorHandle,
1461        _b: &GpuTensorHandle,
1462    ) -> anyhow::Result<GpuTensorHandle> {
1463        Err(anyhow::anyhow!("logical_and not supported by provider"))
1464    }
1465    fn logical_or(
1466        &self,
1467        _a: &GpuTensorHandle,
1468        _b: &GpuTensorHandle,
1469    ) -> anyhow::Result<GpuTensorHandle> {
1470        Err(anyhow::anyhow!("logical_or not supported by provider"))
1471    }
1472    fn logical_xor(
1473        &self,
1474        _a: &GpuTensorHandle,
1475        _b: &GpuTensorHandle,
1476    ) -> anyhow::Result<GpuTensorHandle> {
1477        Err(anyhow::anyhow!("logical_xor not supported by provider"))
1478    }
1479    fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1480        Err(anyhow::anyhow!("logical_not not supported by provider"))
1481    }
1482    fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
1483        Ok(handle_is_logical(a))
1484    }
1485    fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1486        Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1487    }
1488    fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1489        Err(anyhow::anyhow!(
1490            "logical_isfinite not supported by provider"
1491        ))
1492    }
1493    fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1494        Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1495    }
1496    fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1497        Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1498    }
1499    fn elem_atan2<'a>(
1500        &'a self,
1501        _y: &'a GpuTensorHandle,
1502        _x: &'a GpuTensorHandle,
1503    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1504        unsupported_future("elem_atan2 not supported by provider")
1505    }
1506    // Unary elementwise operations (optional)
1507    fn unary_sin<'a>(
1508        &'a self,
1509        _a: &'a GpuTensorHandle,
1510    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1511        unsupported_future("unary_sin not supported by provider")
1512    }
1513    fn unary_sinc<'a>(
1514        &'a self,
1515        _a: &'a GpuTensorHandle,
1516    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1517        unsupported_future("unary_sinc not supported by provider")
1518    }
1519    fn unary_gamma<'a>(
1520        &'a self,
1521        _a: &'a GpuTensorHandle,
1522    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1523        unsupported_future("unary_gamma not supported by provider")
1524    }
1525    fn unary_factorial<'a>(
1526        &'a self,
1527        _a: &'a GpuTensorHandle,
1528    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1529        unsupported_future("unary_factorial not supported by provider")
1530    }
1531    fn unary_asinh<'a>(
1532        &'a self,
1533        _a: &'a GpuTensorHandle,
1534    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1535        unsupported_future("unary_asinh not supported by provider")
1536    }
1537    fn unary_sinh<'a>(
1538        &'a self,
1539        _a: &'a GpuTensorHandle,
1540    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1541        unsupported_future("unary_sinh not supported by provider")
1542    }
1543    fn unary_cosh<'a>(
1544        &'a self,
1545        _a: &'a GpuTensorHandle,
1546    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1547        unsupported_future("unary_cosh not supported by provider")
1548    }
1549    fn unary_asin<'a>(
1550        &'a self,
1551        _a: &'a GpuTensorHandle,
1552    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1553        unsupported_future("unary_asin not supported by provider")
1554    }
1555    fn unary_acos<'a>(
1556        &'a self,
1557        _a: &'a GpuTensorHandle,
1558    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1559        unsupported_future("unary_acos not supported by provider")
1560    }
1561    fn unary_acosh<'a>(
1562        &'a self,
1563        _a: &'a GpuTensorHandle,
1564    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1565        unsupported_future("unary_acosh not supported by provider")
1566    }
1567    fn unary_tan<'a>(
1568        &'a self,
1569        _a: &'a GpuTensorHandle,
1570    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1571        unsupported_future("unary_tan not supported by provider")
1572    }
1573    fn unary_tanh<'a>(
1574        &'a self,
1575        _a: &'a GpuTensorHandle,
1576    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1577        unsupported_future("unary_tanh not supported by provider")
1578    }
1579    fn unary_atan<'a>(
1580        &'a self,
1581        _a: &'a GpuTensorHandle,
1582    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1583        unsupported_future("unary_atan not supported by provider")
1584    }
1585    fn unary_atanh<'a>(
1586        &'a self,
1587        _a: &'a GpuTensorHandle,
1588    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1589        unsupported_future("unary_atanh not supported by provider")
1590    }
1591    fn unary_ceil<'a>(
1592        &'a self,
1593        _a: &'a GpuTensorHandle,
1594    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1595        unsupported_future("unary_ceil not supported by provider")
1596    }
1597    fn unary_floor<'a>(
1598        &'a self,
1599        _a: &'a GpuTensorHandle,
1600    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1601        unsupported_future("unary_floor not supported by provider")
1602    }
1603    fn unary_round<'a>(
1604        &'a self,
1605        _a: &'a GpuTensorHandle,
1606    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1607        unsupported_future("unary_round not supported by provider")
1608    }
1609    fn unary_fix<'a>(
1610        &'a self,
1611        _a: &'a GpuTensorHandle,
1612    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1613        unsupported_future("unary_fix not supported by provider")
1614    }
1615    fn unary_cos<'a>(
1616        &'a self,
1617        _a: &'a GpuTensorHandle,
1618    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1619        unsupported_future("unary_cos not supported by provider")
1620    }
1621    fn unary_angle<'a>(
1622        &'a self,
1623        _a: &'a GpuTensorHandle,
1624    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1625        unsupported_future("unary_angle not supported by provider")
1626    }
1627    fn unary_imag<'a>(
1628        &'a self,
1629        _a: &'a GpuTensorHandle,
1630    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1631        unsupported_future("unary_imag not supported by provider")
1632    }
1633    fn unary_real<'a>(
1634        &'a self,
1635        _a: &'a GpuTensorHandle,
1636    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1637        unsupported_future("unary_real not supported by provider")
1638    }
1639    fn unary_conj<'a>(
1640        &'a self,
1641        _a: &'a GpuTensorHandle,
1642    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1643        unsupported_future("unary_conj not supported by provider")
1644    }
1645    fn unary_abs<'a>(
1646        &'a self,
1647        _a: &'a GpuTensorHandle,
1648    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1649        unsupported_future("unary_abs not supported by provider")
1650    }
1651    fn unary_sign<'a>(
1652        &'a self,
1653        _a: &'a GpuTensorHandle,
1654    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1655        unsupported_future("unary_sign not supported by provider")
1656    }
1657    fn unary_exp<'a>(
1658        &'a self,
1659        _a: &'a GpuTensorHandle,
1660    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1661        unsupported_future("unary_exp not supported by provider")
1662    }
1663    fn unary_expm1<'a>(
1664        &'a self,
1665        _a: &'a GpuTensorHandle,
1666    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1667        unsupported_future("unary_expm1 not supported by provider")
1668    }
1669    fn unary_log<'a>(
1670        &'a self,
1671        _a: &'a GpuTensorHandle,
1672    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1673        unsupported_future("unary_log not supported by provider")
1674    }
1675    fn unary_log2<'a>(
1676        &'a self,
1677        _a: &'a GpuTensorHandle,
1678    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1679        unsupported_future("unary_log2 not supported by provider")
1680    }
1681    fn unary_log10<'a>(
1682        &'a self,
1683        _a: &'a GpuTensorHandle,
1684    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1685        unsupported_future("unary_log10 not supported by provider")
1686    }
1687    fn unary_log1p<'a>(
1688        &'a self,
1689        _a: &'a GpuTensorHandle,
1690    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1691        unsupported_future("unary_log1p not supported by provider")
1692    }
1693    fn unary_sqrt<'a>(
1694        &'a self,
1695        _a: &'a GpuTensorHandle,
1696    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1697        unsupported_future("unary_sqrt not supported by provider")
1698    }
1699    fn unary_double<'a>(
1700        &'a self,
1701        _a: &'a GpuTensorHandle,
1702    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1703        unsupported_future("unary_double not supported by provider")
1704    }
1705    fn unary_single<'a>(
1706        &'a self,
1707        _a: &'a GpuTensorHandle,
1708    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1709        unsupported_future("unary_single not supported by provider")
1710    }
1711    fn unary_pow2<'a>(
1712        &'a self,
1713        _a: &'a GpuTensorHandle,
1714    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1715        unsupported_future("unary_pow2 not supported by provider")
1716    }
1717    fn unary_nextpow2<'a>(
1718        &'a self,
1719        _a: &'a GpuTensorHandle,
1720    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1721        unsupported_future("unary_nextpow2 not supported by provider")
1722    }
1723    fn pow2_scale(
1724        &self,
1725        _mantissa: &GpuTensorHandle,
1726        _exponent: &GpuTensorHandle,
1727    ) -> anyhow::Result<GpuTensorHandle> {
1728        Err(anyhow::anyhow!("pow2_scale not supported by provider"))
1729    }
1730    // Left-scalar operations (broadcast with scalar on the left)
1731    fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1732        Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
1733    }
1734    fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1735        Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
1736    }
1737    // Scalar operations: apply op with scalar right-hand side (broadcast over a)
1738    fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1739        Err(anyhow::anyhow!("scalar_add not supported by provider"))
1740    }
1741    fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1742        Err(anyhow::anyhow!("scalar_sub not supported by provider"))
1743    }
1744    fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1745        Err(anyhow::anyhow!("scalar_mul not supported by provider"))
1746    }
1747    fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1748        Err(anyhow::anyhow!("scalar_max not supported by provider"))
1749    }
1750    fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1751        Err(anyhow::anyhow!("scalar_min not supported by provider"))
1752    }
1753    fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1754        Err(anyhow::anyhow!("scalar_div not supported by provider"))
1755    }
1756    fn sort_dim<'a>(
1757        &'a self,
1758        _a: &'a GpuTensorHandle,
1759        _dim: usize,
1760        _order: SortOrder,
1761        _comparison: SortComparison,
1762    ) -> AccelProviderFuture<'a, SortResult> {
1763        unsupported_future("sort_dim not supported by provider")
1764    }
1765    fn sort_rows<'a>(
1766        &'a self,
1767        _a: &'a GpuTensorHandle,
1768        _columns: &'a [SortRowsColumnSpec],
1769        _comparison: SortComparison,
1770    ) -> AccelProviderFuture<'a, SortResult> {
1771        unsupported_future("sort_rows not supported by provider")
1772    }
1773    fn matmul<'a>(
1774        &'a self,
1775        _a: &'a GpuTensorHandle,
1776        _b: &'a GpuTensorHandle,
1777    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1778        unsupported_future("matmul not supported by provider")
1779    }
1780
1781    fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1782        Err(anyhow::anyhow!("syrk not supported by provider"))
1783    }
1784    fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
1785        Err(anyhow::anyhow!("pagefun not supported by provider"))
1786    }
1787
1788    /// Optional: matrix multiplication with an epilogue applied before store.
1789    ///
1790    /// The default implementation falls back to `matmul` when the epilogue is effectively a no-op
1791    /// (alpha=1, beta=0, no row/col scales), and otherwise returns `Err`.
1792    fn matmul_epilogue<'a>(
1793        &'a self,
1794        a: &'a GpuTensorHandle,
1795        b: &'a GpuTensorHandle,
1796        epilogue: &'a MatmulEpilogue,
1797    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1798        Box::pin(async move {
1799            if epilogue.is_noop() {
1800                return self.matmul(a, b).await;
1801            }
1802            Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1803        })
1804    }
    /// Provider hook for fused image normalization of `_input` as described by `_desc`.
    ///
    /// The default delegates to `unsupported_future` with a "not supported" message.
    fn image_normalize<'a>(
        &'a self,
        _input: &'a GpuTensorHandle,
        _desc: &'a ImageNormalizeDescriptor,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("image_normalize fusion not supported by provider")
    }
    /// Provider hook for a matmul of `_lhs` and `_rhs` followed by the `_epilogue` power-step
    /// normalization.
    ///
    /// The default delegates to `unsupported_future` with a "not supported" message.
    fn matmul_power_step<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _epilogue: &'a PowerStepEpilogue,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul_power_step normalization not supported by provider")
    }
    /// Provider hook for `linsolve` with `_lhs`, `_rhs` and `_options`.
    ///
    /// The default delegates to `unsupported_future` with a "not supported" message.
    fn linsolve<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _options: &'a ProviderLinsolveOptions,
    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
        unsupported_future("linsolve not supported by provider")
    }
    /// Provider hook for `inv` of `_matrix`.
    ///
    /// The default delegates to `unsupported_future` with a "not supported" message.
    fn inv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderInvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("inv not supported by provider")
    }
    /// Provider hook for `pinv` of `_matrix`.
    ///
    /// The default delegates to `unsupported_future` with a "not supported" message.
    fn pinv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderPinvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("pinv not supported by provider")
    }
1842    fn cond<'a>(
1843        &'a self,
1844        _matrix: &'a GpuTensorHandle,
1845        _norm: ProviderCondNorm,
1846    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1847        Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
1848    }
1849    fn norm<'a>(
1850        &'a self,
1851        _tensor: &'a GpuTensorHandle,
1852        _order: ProviderNormOrder,
1853    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1854        Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
1855    }
1856    fn rank<'a>(
1857        &'a self,
1858        _matrix: &'a GpuTensorHandle,
1859        _tolerance: Option<f64>,
1860    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1861        Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
1862    }
1863    fn rcond<'a>(
1864        &'a self,
1865        _matrix: &'a GpuTensorHandle,
1866    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1867        Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
1868    }
1869    fn mldivide<'a>(
1870        &'a self,
1871        _lhs: &'a GpuTensorHandle,
1872        _rhs: &'a GpuTensorHandle,
1873    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1874        Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
1875    }
1876    fn mrdivide<'a>(
1877        &'a self,
1878        _lhs: &'a GpuTensorHandle,
1879        _rhs: &'a GpuTensorHandle,
1880    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1881        Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
1882    }
1883    fn eig<'a>(
1884        &'a self,
1885        _a: &'a GpuTensorHandle,
1886        _compute_left: bool,
1887    ) -> AccelProviderFuture<'a, ProviderEigResult> {
1888        Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
1889    }
1890    fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
1891        Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
1892    }
1893
1894    fn chol<'a>(
1895        &'a self,
1896        _a: &'a GpuTensorHandle,
1897        _lower: bool,
1898    ) -> AccelProviderFuture<'a, ProviderCholResult> {
1899        Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
1900    }
1901    fn qr<'a>(
1902        &'a self,
1903        _a: &'a GpuTensorHandle,
1904        _options: ProviderQrOptions,
1905    ) -> AccelProviderFuture<'a, ProviderQrResult> {
1906        Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
1907    }
    /// Return the matmul source operands recorded for `_product`, if the provider tracks them.
    /// NOTE(review): semantics inferred from the name; confirm against overriding providers.
    ///
    /// The default always returns `None`.
    fn take_matmul_sources(
        &self,
        _product: &GpuTensorHandle,
    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
        None
    }
1914    fn qr_power_iter<'a>(
1915        &'a self,
1916        product: &'a GpuTensorHandle,
1917        _product_lhs: Option<&'a GpuTensorHandle>,
1918        q_handle: &'a GpuTensorHandle,
1919        options: &'a ProviderQrOptions,
1920    ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
1921        let _ = (product, q_handle, options);
1922        Box::pin(async move { Ok(None) })
1923    }
1924    fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1925        Err(anyhow::anyhow!("transpose not supported by provider"))
1926    }
1927    fn conv1d(
1928        &self,
1929        _signal: &GpuTensorHandle,
1930        _kernel: &GpuTensorHandle,
1931        _options: ProviderConv1dOptions,
1932    ) -> anyhow::Result<GpuTensorHandle> {
1933        Err(anyhow::anyhow!("conv1d not supported by provider"))
1934    }
1935    fn conv2d(
1936        &self,
1937        _signal: &GpuTensorHandle,
1938        _kernel: &GpuTensorHandle,
1939        _mode: ProviderConvMode,
1940    ) -> anyhow::Result<GpuTensorHandle> {
1941        Err(anyhow::anyhow!("conv2d not supported by provider"))
1942    }
1943    fn iir_filter<'a>(
1944        &'a self,
1945        _b: &'a GpuTensorHandle,
1946        _a: &'a GpuTensorHandle,
1947        _x: &'a GpuTensorHandle,
1948        _options: ProviderIirFilterOptions,
1949    ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
1950        Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
1951    }
1952    /// Reorder tensor dimensions according to `order`, expressed as zero-based indices.
1953    fn permute(
1954        &self,
1955        _handle: &GpuTensorHandle,
1956        _order: &[usize],
1957    ) -> anyhow::Result<GpuTensorHandle> {
1958        Err(anyhow::anyhow!("permute not supported by provider"))
1959    }
1960    fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1961        Err(anyhow::anyhow!("flip not supported by provider"))
1962    }
1963    fn circshift(
1964        &self,
1965        _handle: &GpuTensorHandle,
1966        _shifts: &[isize],
1967    ) -> anyhow::Result<GpuTensorHandle> {
1968        Err(anyhow::anyhow!("circshift not supported by provider"))
1969    }
1970    fn diff_dim(
1971        &self,
1972        _handle: &GpuTensorHandle,
1973        _order: usize,
1974        _dim: usize,
1975    ) -> anyhow::Result<GpuTensorHandle> {
1976        Err(anyhow::anyhow!("diff_dim not supported by provider"))
1977    }
1978    fn gradient_dim(
1979        &self,
1980        _handle: &GpuTensorHandle,
1981        _dim: usize,
1982        _spacing: f64,
1983    ) -> anyhow::Result<GpuTensorHandle> {
1984        Err(anyhow::anyhow!("gradient_dim not supported by provider"))
1985    }
    /// Perform an in-place FFT along a zero-based dimension, optionally padding/truncating to `len`.
    ///
    /// NOTE(review): the signature returns a new handle; confirm whether "in-place" means the
    /// input buffer may be reused.
    ///
    /// The default delegates to `unsupported_future` with a "not supported" message.
    fn fft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_dim not supported by provider")
    }
    /// Provider hook for the inverse FFT along `_dim`, with optional length `_len` (presumably
    /// mirroring `fft_dim`; confirm).
    ///
    /// The default delegates to `unsupported_future` with a "not supported" message.
    fn ifft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("ifft_dim not supported by provider")
    }
    /// Provider hook for `fft_extract_real` on `_handle`.
    ///
    /// The default delegates to `unsupported_future` with a "not supported" message.
    fn fft_extract_real<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_extract_real not supported by provider")
    }
2009    fn unique<'a>(
2010        &'a self,
2011        _handle: &'a GpuTensorHandle,
2012        _options: &'a UniqueOptions,
2013    ) -> AccelProviderFuture<'a, UniqueResult> {
2014        Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
2015    }
2016    fn union<'a>(
2017        &'a self,
2018        _a: &'a GpuTensorHandle,
2019        _b: &'a GpuTensorHandle,
2020        _options: &'a UnionOptions,
2021    ) -> AccelProviderFuture<'a, UnionResult> {
2022        Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
2023    }
2024    fn setdiff<'a>(
2025        &'a self,
2026        _a: &'a GpuTensorHandle,
2027        _b: &'a GpuTensorHandle,
2028        _options: &'a SetdiffOptions,
2029    ) -> AccelProviderFuture<'a, SetdiffResult> {
2030        Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
2031    }
2032    fn ismember<'a>(
2033        &'a self,
2034        _a: &'a GpuTensorHandle,
2035        _b: &'a GpuTensorHandle,
2036        _options: &'a IsMemberOptions,
2037    ) -> AccelProviderFuture<'a, IsMemberResult> {
2038        Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
2039    }
2040    fn reshape(
2041        &self,
2042        handle: &GpuTensorHandle,
2043        new_shape: &[usize],
2044    ) -> anyhow::Result<GpuTensorHandle> {
2045        let mut updated = handle.clone();
2046        updated.shape = new_shape.to_vec();
2047        Ok(updated)
2048    }
2049    /// Concatenate the provided tensors along the 1-based dimension `dim`.
2050    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
2051        Err(anyhow::anyhow!("cat not supported by provider"))
2052    }
2053    fn repmat(
2054        &self,
2055        _handle: &GpuTensorHandle,
2056        _reps: &[usize],
2057    ) -> anyhow::Result<GpuTensorHandle> {
2058        Err(anyhow::anyhow!("repmat not supported by provider"))
2059    }
2060    /// Compute the Kronecker product of two tensors, matching MATLAB semantics.
2061    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2062        Err(anyhow::anyhow!("kron not supported by provider"))
2063    }
2064    /// Compute the cross product of 3-element vectors along a matching dimension.
2065    fn cross(
2066        &self,
2067        _lhs: &GpuTensorHandle,
2068        _rhs: &GpuTensorHandle,
2069        _dim: Option<usize>,
2070    ) -> anyhow::Result<GpuTensorHandle> {
2071        Err(anyhow::anyhow!("cross not supported by provider"))
2072    }
    /// Provider hook for `reduce_sum`; the default delegates to `unsupported_future`.
    fn reduce_sum<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum not supported by provider")
    }
    /// Provider hook for `reduce_sum_dim` along `_dim`; the default delegates to
    /// `unsupported_future`.
    fn reduce_sum_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum_dim not supported by provider")
    }
    /// Provider hook for `dot` of `_lhs` and `_rhs` with an optional `_dim`; the default delegates
    /// to `unsupported_future`.
    fn dot<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _dim: Option<usize>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("dot not supported by provider")
    }
    /// Provider hook for `reduce_nnz`; the default delegates to `unsupported_future`.
    fn reduce_nnz<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz not supported by provider")
    }
    /// Provider hook for `reduce_nnz_dim` along `_dim`; the default delegates to
    /// `unsupported_future`.
    fn reduce_nnz_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz_dim not supported by provider")
    }
    /// Provider hook for `reduce_prod`; the default delegates to `unsupported_future`.
    fn reduce_prod<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod not supported by provider")
    }
    /// Provider hook for `reduce_prod_dim` along `_dim`; the default delegates to
    /// `unsupported_future`.
    fn reduce_prod_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod_dim not supported by provider")
    }
    /// Provider hook for `reduce_mean`; the default delegates to `unsupported_future`.
    fn reduce_mean<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean not supported by provider")
    }
    /// Reduce mean across multiple zero-based dimensions in one device pass.
    ///
    /// The default delegates to `unsupported_future` with a "not supported" message.
    fn reduce_mean_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_nd not supported by provider")
    }
    /// Reduce moments across multiple zero-based dimensions in one device pass.
    /// Returns mean (E[x]) and mean of squares (E[x^2]).
    ///
    /// The default delegates to `unsupported_future` with a "not supported" message.
    fn reduce_moments_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, ProviderMoments2> {
        unsupported_future("reduce_moments_nd not supported by provider")
    }
    /// Provider hook for `reduce_mean_dim` along `_dim`; the default delegates to
    /// `unsupported_future`.
    fn reduce_mean_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_dim not supported by provider")
    }
    /// Provider hook for `reduce_std` with the given normalization and NaN mode; the default
    /// delegates to `unsupported_future`.
    fn reduce_std<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std not supported by provider")
    }
    /// Provider hook for `reduce_std_dim` along `_dim` with the given normalization and NaN mode;
    /// the default delegates to `unsupported_future`.
    fn reduce_std_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std_dim not supported by provider")
    }
    /// Provider hook for `reduce_any` with the `_omit_nan` flag; the default delegates to
    /// `unsupported_future`.
    fn reduce_any<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any not supported by provider")
    }
    /// Provider hook for `reduce_any_dim` along `_dim` with the `_omit_nan` flag; the default
    /// delegates to `unsupported_future`.
    fn reduce_any_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any_dim not supported by provider")
    }
    /// Provider hook for `reduce_all` with the `_omit_nan` flag; the default delegates to
    /// `unsupported_future`.
    fn reduce_all<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all not supported by provider")
    }
    /// Provider hook for `reduce_all_dim` along `_dim` with the `_omit_nan` flag; the default
    /// delegates to `unsupported_future`.
    fn reduce_all_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all_dim not supported by provider")
    }
    /// Provider hook for `reduce_median`; the default delegates to `unsupported_future`.
    fn reduce_median<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median not supported by provider")
    }
    /// Provider hook for `reduce_median_dim` along `_dim`; the default delegates to
    /// `unsupported_future`.
    fn reduce_median_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median_dim not supported by provider")
    }
    /// Provider hook for `reduce_min`; the default delegates to `unsupported_future`.
    fn reduce_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_min not supported by provider")
    }
    /// Provider hook for `reduce_min_dim` along `_dim`, yielding a `ReduceDimResult`; the default
    /// delegates to `unsupported_future`.
    fn reduce_min_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_min_dim not supported by provider")
    }
    /// Provider hook for `reduce_max`; the default delegates to `unsupported_future`.
    fn reduce_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_max not supported by provider")
    }
    /// Provider hook for `reduce_max_dim` along `_dim`, yielding a `ReduceDimResult`; the default
    /// delegates to `unsupported_future`.
    fn reduce_max_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_max_dim not supported by provider")
    }
2236    fn cumsum_scan(
2237        &self,
2238        _input: &GpuTensorHandle,
2239        _dim: usize,
2240        _direction: ProviderScanDirection,
2241        _nan_mode: ProviderNanMode,
2242    ) -> anyhow::Result<GpuTensorHandle> {
2243        Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
2244    }
2245    fn cumprod_scan(
2246        &self,
2247        _input: &GpuTensorHandle,
2248        _dim: usize,
2249        _direction: ProviderScanDirection,
2250        _nan_mode: ProviderNanMode,
2251    ) -> anyhow::Result<GpuTensorHandle> {
2252        Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
2253    }
2254    fn cummin_scan(
2255        &self,
2256        _input: &GpuTensorHandle,
2257        _dim: usize,
2258        _direction: ProviderScanDirection,
2259        _nan_mode: ProviderNanMode,
2260    ) -> anyhow::Result<ProviderCumminResult> {
2261        Err(anyhow::anyhow!("cummin_scan not supported by provider"))
2262    }
2263    fn cummax_scan(
2264        &self,
2265        _input: &GpuTensorHandle,
2266        _dim: usize,
2267        _direction: ProviderScanDirection,
2268        _nan_mode: ProviderNanMode,
2269    ) -> anyhow::Result<ProviderCummaxResult> {
2270        Err(anyhow::anyhow!("cummax_scan not supported by provider"))
2271    }
2272
2273    fn find(
2274        &self,
2275        _a: &GpuTensorHandle,
2276        _limit: Option<usize>,
2277        _direction: FindDirection,
2278    ) -> anyhow::Result<ProviderFindResult> {
2279        Err(anyhow::anyhow!("find not supported by provider"))
2280    }
2281
2282    fn fused_elementwise(
2283        &self,
2284        _shader: &str,
2285        _inputs: &[GpuTensorHandle],
2286        _output_shape: &[usize],
2287        _len: usize,
2288    ) -> anyhow::Result<GpuTensorHandle> {
2289        Err(anyhow::anyhow!(
2290            "fused_elementwise not supported by provider"
2291        ))
2292    }
2293
2294    /// Execute a single fused elementwise kernel that writes `num_outputs` output buffers in one
2295    /// dispatch. The shader is expected to declare `output0`, `output1`, … `output{N-1}` storage
2296    /// bindings (at binding indices `inputs.len()` through `inputs.len() + num_outputs - 1`) and a
2297    /// uniform `params` binding at `inputs.len() + num_outputs`.
2298    ///
2299    /// Providers that do not override this method fall back to calling `fused_elementwise` once
2300    /// per output, which preserves correctness at the cost of the O(N²) dispatch overhead this
2301    /// method is designed to eliminate.
2302    fn fused_elementwise_multi(
2303        &self,
2304        _shader: &str,
2305        _inputs: &[GpuTensorHandle],
2306        _output_shape: &[usize],
2307        _len: usize,
2308        _num_outputs: usize,
2309    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2310        Err(anyhow::anyhow!(
2311            "fused_elementwise_multi not supported by provider"
2312        ))
2313    }
2314
2315    /// Build a numeric tensor where NaNs in `a` are replaced with 0.0 (device side).
2316    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2317        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
2318    }
2319
2320    /// Build a numeric mask tensor with 1.0 where value is not NaN and 0.0 where value is NaN.
2321    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2322        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
2323    }
2324
2325    /// Generic fused reduction entrypoint.
2326    ///
2327    /// The shader is expected to implement a column-major reduction across `reduce_len` with
2328    /// `num_slices` independent slices (e.g., columns). Providers should create a uniform buffer
2329    /// compatible with the expected `Params/MParams` struct in the shader and dispatch
2330    /// `num_slices` workgroups with `workgroup_size` threads, or an equivalent strategy.
2331    #[allow(clippy::too_many_arguments)]
2332    fn fused_reduction(
2333        &self,
2334        _shader: &str,
2335        _inputs: &[GpuTensorHandle],
2336        _output_shape: &[usize],
2337        _reduce_len: usize,
2338        _num_slices: usize,
2339        _workgroup_size: u32,
2340        _flavor: ReductionFlavor,
2341    ) -> anyhow::Result<GpuTensorHandle> {
2342        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
2343    }
2344
    /// Optionally pre-compile commonly used pipelines to amortize first-dispatch costs.
    ///
    /// The default does nothing.
    fn warmup(&self) {}
2347
    /// Returns (cache_hits, cache_misses) for fused pipeline cache, if supported.
    ///
    /// The default reports `(0, 0)`. The default `telemetry_snapshot` reads these values for its
    /// fusion cache fields.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }
2352
    /// Returns the duration of the last provider warmup in milliseconds, if known.
    ///
    /// The default returns `None` (unknown).
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }
2357
2358    /// Returns a snapshot of provider telemetry counters if supported.
2359    fn telemetry_snapshot(&self) -> ProviderTelemetry {
2360        let (hits, misses) = self.fused_cache_counters();
2361        ProviderTelemetry {
2362            fused_elementwise: ProviderDispatchStats::default(),
2363            fused_reduction: ProviderDispatchStats::default(),
2364            matmul: ProviderDispatchStats::default(),
2365            linsolve: ProviderDispatchStats::default(),
2366            mldivide: ProviderDispatchStats::default(),
2367            mrdivide: ProviderDispatchStats::default(),
2368            upload_bytes: 0,
2369            download_bytes: 0,
2370            solve_fallbacks: Vec::new(),
2371            fusion_cache_hits: hits,
2372            fusion_cache_misses: misses,
2373            bind_group_cache_hits: 0,
2374            bind_group_cache_misses: 0,
2375            bind_group_cache_by_layout: None,
2376            kernel_launches: Vec::new(),
2377        }
2378    }
2379
    /// Reset all telemetry counters maintained by the provider, if supported.
    ///
    /// The default does nothing.
    fn reset_telemetry(&self) {}
2382
    /// Default reduction workgroup size the provider prefers.
    ///
    /// The default is 256.
    fn default_reduction_workgroup_size(&self) -> u32 {
        256
    }
2387
    /// Threshold above which provider will prefer two-pass reduction.
    ///
    /// The default is 1024. NOTE(review): the unit (presumably the reduced length) is not visible
    /// here; confirm against the reduction planner.
    fn two_pass_threshold(&self) -> usize {
        1024
    }
2392
    /// Current two-pass mode preference (auto/forced on/off).
    ///
    /// The default is `ReductionTwoPassMode::Auto`.
    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
2397
2398    /// Fast-path: write a GPU column in a matrix from a GPU vector, returning a new handle.
2399    /// Expected: `values.shape == [rows, 1]` (or `[rows]`) and `col_index < cols`.
2400    fn scatter_column(
2401        &self,
2402        _matrix: &GpuTensorHandle,
2403        _col_index: usize,
2404        _values: &GpuTensorHandle,
2405    ) -> anyhow::Result<GpuTensorHandle> {
2406        Err(anyhow::anyhow!("scatter_column not supported by provider"))
2407    }
2408
2409    /// Fast-path: write a GPU row in a matrix from a GPU vector, returning a new handle.
2410    /// Expected: `values.shape == [1, cols]` (or `[cols]`) and `row_index < rows`.
2411    fn scatter_row(
2412        &self,
2413        _matrix: &GpuTensorHandle,
2414        _row_index: usize,
2415        _values: &GpuTensorHandle,
2416    ) -> anyhow::Result<GpuTensorHandle> {
2417        Err(anyhow::anyhow!("scatter_row not supported by provider"))
2418    }
2419
2420    fn sub2ind(
2421        &self,
2422        _dims: &[usize],
2423        _strides: &[usize],
2424        _inputs: &[&GpuTensorHandle],
2425        _scalar_mask: &[bool],
2426        _len: usize,
2427        _output_shape: &[usize],
2428    ) -> anyhow::Result<GpuTensorHandle> {
2429        Err(anyhow::anyhow!("sub2ind not supported by provider"))
2430    }
2431
    /// Returns true if the provider offers a device-side `ind2sub` implementation.
    ///
    /// The default returns `false`, consistent with the default `ind2sub`, which returns an error.
    fn supports_ind2sub(&self) -> bool {
        false
    }
2436
2437    /// Convert linear indices into per-dimension subscripts on the device.
2438    fn ind2sub(
2439        &self,
2440        _dims: &[usize],
2441        _strides: &[usize],
2442        _indices: &GpuTensorHandle,
2443        _total: usize,
2444        _len: usize,
2445        _output_shape: &[usize],
2446    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2447        Err(anyhow::anyhow!("ind2sub not supported by provider"))
2448    }
2449
2450    /// Determine if a matrix is symmetric (or skew-symmetric) without gathering it to the host.
2451    fn issymmetric(
2452        &self,
2453        _matrix: &GpuTensorHandle,
2454        _kind: ProviderSymmetryKind,
2455        _tolerance: f64,
2456    ) -> anyhow::Result<bool> {
2457        Err(anyhow::anyhow!(
2458            "issymmetric predicate not supported by provider"
2459        ))
2460    }
2461
2462    /// Determine if a matrix is Hermitian (or skew-Hermitian) without gathering it to the host.
2463    fn ishermitian<'a>(
2464        &'a self,
2465        _matrix: &'a GpuTensorHandle,
2466        _kind: ProviderHermitianKind,
2467        _tolerance: f64,
2468    ) -> AccelProviderFuture<'a, bool> {
2469        Box::pin(async move {
2470            Err(anyhow::anyhow!(
2471                "ishermitian predicate not supported by provider"
2472            ))
2473        })
2474    }
2475
2476    /// Inspect the bandwidth of a matrix without gathering it back to the host.
2477    fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2478        Err(anyhow::anyhow!("bandwidth not supported by provider"))
2479    }
2480
2481    /// Compute the symmetric reverse Cuthill-McKee permutation for the matrix.
2482    ///
2483    /// Implementations may execute on the device or gather to the host. The permutation should be
2484    /// returned as zero-based indices.
2485    fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
2486        Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
2487    }
2488}
2489
2490static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
2491    Lazy::new(|| RwLock::new(None));
2492static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
2493    Lazy::new(|| RwLock::new(HashMap::new()));
/// Next id handed out by `next_device_id`; the first id issued is 1.
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2495
2496#[cfg(not(target_arch = "wasm32"))]
2497thread_local! {
2498    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
2499}
2500
/// wasm32 stand-in for the thread-local override: a single process-wide slot behind a mutex.
/// Starts empty.
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(Default::default);
2504
2505#[cfg(not(target_arch = "wasm32"))]
2506fn replace_thread_provider(
2507    provider: Option<&'static dyn AccelProvider>,
2508) -> Option<&'static dyn AccelProvider> {
2509    THREAD_PROVIDER.with(|cell| {
2510        let prev = cell.get();
2511        cell.set(provider);
2512        prev
2513    })
2514}
2515
/// wasm32 variant: swap the shared override slot for `provider`, returning the previous value.
///
/// # Panics
///
/// Panics if the slot's mutex is poisoned.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut guard = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *guard, provider)
}
2527
2528#[cfg(not(target_arch = "wasm32"))]
2529fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2530    THREAD_PROVIDER.with(|cell| cell.get())
2531}
2532
/// wasm32 variant: read the shared override slot.
///
/// # Panics
///
/// Panics if the slot's mutex is poisoned.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    let slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    *slot
}
2541
2542/// Register a global acceleration provider.
2543///
2544/// # Safety
2545/// - The caller must guarantee that `p` is valid for the entire program lifetime
2546///   (e.g., a `'static` singleton), as the runtime stores a raw reference globally.
2547/// - Concurrent callers must ensure registration happens once or is properly
2548///   synchronized; this function does not enforce thread-safety for re-registration.
2549pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
2550    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2551        *guard = Some(p);
2552    }
2553    register_provider_for_device(p.device_id(), p);
2554}
2555
2556unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
2557    if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
2558        guard.insert(device_id, provider);
2559    }
2560}
2561
2562pub fn provider() -> Option<&'static dyn AccelProvider> {
2563    if let Some(p) = current_thread_provider() {
2564        return Some(p);
2565    }
2566    GLOBAL_PROVIDER
2567        .read()
2568        .ok()
2569        .and_then(|guard| guard.as_ref().copied())
2570}
2571
2572/// Clear the globally registered provider. Intended for tests to ensure deterministic behaviour.
2573pub fn clear_provider() {
2574    replace_thread_provider(None);
2575    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2576        *guard = None;
2577    }
2578    if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2579        map.clear();
2580    }
2581}
2582
2583pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2584    PROVIDER_REGISTRY
2585        .read()
2586        .ok()
2587        .and_then(|guard| guard.get(&device_id).copied())
2588        .or_else(|| provider())
2589}
2590
/// Resolve the provider for `handle` via its `device_id` (see `provider_for_device`).
pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
    provider_for_device(handle.device_id)
}
2594
/// Allocate a fresh device id. Ids start at 1 and increase by one per call.
///
/// `Relaxed` ordering is enough because only the uniqueness of each returned value matters, and
/// the atomic read-modify-write guarantees it. The counter wraps on overflow, so ids repeat
/// after 2^32 allocations.
pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
2598
2599pub struct ThreadProviderGuard {
2600    prev: Option<&'static dyn AccelProvider>,
2601}
2602
2603impl ThreadProviderGuard {
2604    pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2605        let prev = replace_thread_provider(provider);
2606        ThreadProviderGuard { prev }
2607    }
2608}
2609
2610impl Drop for ThreadProviderGuard {
2611    fn drop(&mut self) {
2612        let prev = self.prev.take();
2613        replace_thread_provider(prev);
2614    }
2615}
2616
/// Set the calling thread's provider override without a guard (on wasm32, the shared override
/// slot). The previous override is discarded; pass `None` to clear it.
pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
    replace_thread_provider(provider);
}
2620
2621/// Convenience: perform elementwise add via provider if possible; otherwise return None
2622pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2623    if let Some(p) = provider() {
2624        if let Ok(h) = p.elem_add(a, b).await {
2625            return Some(h);
2626        }
2627    }
2628    None
2629}
2630
2631/// Convenience: perform elementwise hypot via provider if possible; otherwise return None
2632pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2633    if let Some(p) = provider() {
2634        if let Ok(h) = p.elem_hypot(a, b).await {
2635            return Some(h);
2636        }
2637    }
2638    None
2639}
2640
2641/// Convenience: perform elementwise max via provider if possible; otherwise return None
2642pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2643    if let Some(p) = provider() {
2644        if let Ok(h) = p.elem_max(a, b).await {
2645            return Some(h);
2646        }
2647    }
2648    None
2649}
2650
2651/// Convenience: perform elementwise min via provider if possible; otherwise return None
2652pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2653    if let Some(p) = provider() {
2654        if let Ok(h) = p.elem_min(a, b).await {
2655            return Some(h);
2656        }
2657    }
2658    None
2659}
2660
2661/// Convenience: perform elementwise atan2 via provider if possible; otherwise return None
2662pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2663    if let Some(p) = provider() {
2664        if let Ok(h) = p.elem_atan2(y, x).await {
2665            return Some(h);
2666        }
2667    }
2668    None
2669}
2670
// Minimal host tensor types, defined here so this crate does not have to depend on
// runmat-builtins (and the dependency cycles that would bring).
/// Owned host-side tensor: element values in `data`, dimensions in `shape`, and the
/// `storage` kind that accompanies them.
///
/// NOTE(review): the element ordering of `data` (column- vs row-major) is not established
/// in this file — confirm against producers/consumers before relying on it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    /// Element values, stored as a flat `Vec`.
    pub data: Vec<f64>,
    /// Tensor dimensions. Presumably their product equals `data.len()`; this is not
    /// enforced by the type.
    pub shape: Vec<usize>,
    /// Storage kind carried alongside the values.
    pub storage: GpuTensorStorage,
}
2678
/// Borrowed host-side tensor: element values and dimensions as slices.
///
/// Unlike [`HostTensorOwned`], it carries no storage tag.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    /// Element values, borrowed as a flat slice.
    pub data: &'a [f64],
    /// Tensor dimensions.
    pub shape: &'a [usize],
}
2684
/// Lightweight 1-D axis view used by provider meshgrid hooks.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    /// Axis values, borrowed from the caller.
    pub data: &'a [f64],
}
2690
/// Provider-side meshgrid result containing coordinate tensor handles.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    /// Coordinate tensor handles produced by the provider.
    /// NOTE(review): presumably ordered to match the input axes — confirm with the
    /// provider meshgrid hooks.
    pub outputs: Vec<GpuTensorHandle>,
}
2696
2697/// Descriptor for GEMM epilogues applied to `C = A * B` before storing to `C`.
2698///
2699/// Supported operations:
2700/// - Scale by `alpha` and add scalar `beta`.
2701/// - Multiply output by per-row and/or per-column scale vectors (broadcasted).
2702#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
2703pub enum ScaleOp {
2704    Multiply,
2705    Divide,
2706}
2707
2708#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
2709pub struct MatmulEpilogue {
2710    /// Scalar multiply applied to each output element.
2711    pub alpha: f64,
2712    /// Scalar add applied to each output element after scaling.
2713    pub beta: f64,
2714    /// Optional per-row scale (length m). When present, output[row, col] *= row_scale[row].
2715    pub row_scale: Option<GpuTensorHandle>,
2716    /// Optional per-column scale (length n). When present, output[row, col] *= col_scale[col].
2717    pub col_scale: Option<GpuTensorHandle>,
2718    /// Row scale operation (multiply or divide). Ignored when `row_scale` is None.
2719    pub row_op: ScaleOp,
2720    /// Column scale operation (multiply or divide). Ignored when `col_scale` is None.
2721    pub col_op: ScaleOp,
2722    /// Optional lower clamp bound applied after scale/bias.
2723    #[serde(default)]
2724    pub clamp_min: Option<f64>,
2725    /// Optional upper clamp bound applied after scale/bias.
2726    #[serde(default)]
2727    pub clamp_max: Option<f64>,
2728    /// Optional power exponent applied after clamp (final operation in the epilogue).
2729    #[serde(default)]
2730    pub pow_exponent: Option<f64>,
2731    /// Optional output buffer for the diagonal of the result (length min(m, n)).
2732    #[serde(default)]
2733    pub diag_output: Option<GpuTensorHandle>,
2734}
2735
2736impl MatmulEpilogue {
2737    pub fn noop() -> Self {
2738        Self {
2739            alpha: 1.0,
2740            beta: 0.0,
2741            row_scale: None,
2742            col_scale: None,
2743            row_op: ScaleOp::Multiply,
2744            col_op: ScaleOp::Multiply,
2745            clamp_min: None,
2746            clamp_max: None,
2747            pow_exponent: None,
2748            diag_output: None,
2749        }
2750    }
2751    pub fn is_noop(&self) -> bool {
2752        self.alpha == 1.0
2753            && self.beta == 0.0
2754            && self.row_scale.is_none()
2755            && self.col_scale.is_none()
2756            && self.clamp_min.is_none()
2757            && self.clamp_max.is_none()
2758            && self.pow_exponent.is_none()
2759            && self.diag_output.is_none()
2760    }
2761}
2762
2763#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
2764pub struct PowerStepEpilogue {
2765    pub epsilon: f64,
2766}
2767
2768impl Default for PowerStepEpilogue {
2769    fn default() -> Self {
2770        Self { epsilon: 0.0 }
2771    }
2772}
2773
/// Parameters describing an image-normalization request.
///
/// NOTE(review): the normalization formula, and how `epsilon`, `gain`, `bias` and `gamma`
/// enter it, is defined by the provider implementation, not here — confirm there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    /// Batch dimension size (presumably the number of images — confirm).
    pub batch: usize,
    /// Image height (units/layout not established here — confirm).
    pub height: usize,
    /// Image width (units/layout not established here — confirm).
    pub width: usize,
    /// Epsilon term; presumably guards a division in the normalization — confirm.
    pub epsilon: f64,
    /// Optional gain. If absent from serialized input, it deserializes to `None`.
    #[serde(default)]
    pub gain: Option<f64>,
    /// Optional bias. If absent from serialized input, it deserializes to `None`.
    #[serde(default)]
    pub bias: Option<f64>,
    /// Optional gamma. If absent from serialized input, it deserializes to `None`.
    #[serde(default)]
    pub gamma: Option<f64>,
}