// runmat_accelerate_api/lib.rs

1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
16type ResidencyMarkFn = fn(&GpuTensorHandle);
17type ResidencyClearFn = fn(&GpuTensorHandle);
18type SequenceThresholdFn = fn() -> Option<usize>;
19type WorkgroupSizeHintFn = fn() -> Option<u32>;
20
21static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
22static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
23static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
24static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();
25
26static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
27static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
28    Lazy::new(|| RwLock::new(HashMap::new()));
29static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
30    Lazy::new(|| RwLock::new(HashMap::new()));
31
32static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
33    Lazy::new(|| RwLock::new(HashMap::new()));
34static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
35    Lazy::new(|| RwLock::new(HashMap::new()));
36
/// Row/column counts stored for a handle by [`record_handle_transpose`].
///
/// Presumably these are the dimensions of the tensor before it was transposed;
/// confirm against the callers that record the entry.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TransposeInfo {
    /// Row count supplied when the transpose was recorded.
    pub base_rows: usize,
    /// Column count supplied when the transpose was recorded.
    pub base_cols: usize,
}
42
43/// Register a callback used to mark residency tracking when GPU tensors are
44/// created or returned by device-side execution paths.
45pub fn register_residency_mark(handler: ResidencyMarkFn) {
46    let _ = RESIDENCY_MARK.set(handler);
47}
48
49/// Mark residency metadata for the provided GPU tensor handle, if a backend
50/// has registered a handler via [`register_residency_mark`].
51pub fn mark_residency(handle: &GpuTensorHandle) {
52    if let Some(handler) = RESIDENCY_MARK.get() {
53        handler(handle);
54    }
55}
56
57/// Register a callback used to clear residency tracking when GPU tensors are
58/// gathered back to the host. Backends that maintain residency metadata should
59/// install this hook during initialization.
60pub fn register_residency_clear(handler: ResidencyClearFn) {
61    let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64/// Clear residency metadata for the provided GPU tensor handle, if a backend
65/// has registered a handler via [`register_residency_clear`].
66pub fn clear_residency(handle: &GpuTensorHandle) {
67    if let Some(handler) = RESIDENCY_CLEAR.get() {
68        handler(handle);
69    }
70}
71
72/// Register a callback that exposes the current sequence length threshold
73/// derived from the auto-offload planner. Array constructors can use this hint
74/// to decide when to prefer GPU residency automatically.
75pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76    let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79/// Query the currently registered sequence threshold hint, if any.
80pub fn sequence_threshold_hint() -> Option<usize> {
81    SEQUENCE_THRESHOLD_PROVIDER
82        .get()
83        .and_then(|provider| provider())
84}
85
86/// Register a callback that reports the calibrated workgroup size selected by
87/// the active acceleration provider (if any). Plotting kernels can reuse this
88/// hint to match backend tuning.
89pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90    let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93/// Query the current workgroup size hint exposed by the provider.
94pub fn workgroup_size_hint() -> Option<u32> {
95    WORKGROUP_SIZE_HINT_PROVIDER
96        .get()
97        .and_then(|provider| provider())
98}
99
100/// Export a shared acceleration context (e.g., the active WGPU device) when the
101/// current provider exposes one.
102pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103    provider().and_then(|p| p.export_context(kind))
104}
105
/// Request the provider-owned WGPU buffer behind `handle` for zero-copy consumers.
///
/// Returns `None` when no provider is active, the provider does not expose
/// buffers, or it does not support the supplied handle.
#[cfg(feature = "wgpu")]
pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
    let active = provider()?;
    active.export_wgpu_buffer(handle)
}
113
114/// Record the precision associated with a GPU tensor handle so host operations can
115/// reconstruct the original dtype when gathering back to the CPU.
116pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118        guard.insert(handle.buffer_id, precision);
119    }
120}
121
122/// Look up the recorded precision for a GPU tensor handle, if any.
123pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124    HANDLE_PRECISIONS
125        .read()
126        .ok()
127        .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130/// Clear any recorded precision metadata for a GPU tensor handle.
131pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133        guard.remove(&handle.buffer_id);
134    }
135}
136
137/// Annotate a GPU tensor handle as logically-typed (`logical` in MATLAB terms)
138/// or clear the logical flag when `logical` is `false`.
139pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
140    if let Ok(mut guard) = LOGICAL_HANDLES.write() {
141        if logical {
142            guard.insert(handle.buffer_id);
143            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
144                *hits.entry(handle.buffer_id).or_insert(0) += 1;
145            }
146        } else {
147            guard.remove(&handle.buffer_id);
148            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
149                hits.remove(&handle.buffer_id);
150            }
151        }
152    }
153}
154
155/// Convenience helper for clearing logical annotations explicitly.
156pub fn clear_handle_logical(handle: &GpuTensorHandle) {
157    set_handle_logical(handle, false);
158}
159
160/// Returns true when the supplied handle has been marked as logical.
161pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162    LOGICAL_HANDLES
163        .read()
164        .map(|guard| guard.contains(&handle.buffer_id))
165        .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169    LOGICAL_HANDLE_HITS
170        .read()
171        .ok()
172        .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177        guard.insert(
178            handle.buffer_id,
179            TransposeInfo {
180                base_rows,
181                base_cols,
182            },
183        );
184    }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189        guard.remove(&handle.buffer_id);
190    }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194    TRANSPOSED_HANDLES
195        .read()
196        .ok()
197        .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
200pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
201    handle_transpose_info(handle).is_some()
202}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
205pub enum GpuTensorStorage {
206    Real,
207    ComplexInterleaved,
208}
209
210impl Default for GpuTensorStorage {
211    fn default() -> Self {
212        Self::Real
213    }
214}
215
216#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
217pub struct GpuTensorHandle {
218    pub shape: Vec<usize>,
219    pub device_id: u32,
220    pub buffer_id: u64,
221}
222
223#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
224pub struct ApiDeviceInfo {
225    pub device_id: u32,
226    pub name: String,
227    pub vendor: String,
228    pub memory_bytes: Option<u64>,
229    pub backend: Option<String>,
230}
231
232#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
233pub struct ReduceDimResult {
234    pub values: GpuTensorHandle,
235    pub indices: GpuTensorHandle,
236}
237
238#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
239pub struct ProviderCumminResult {
240    pub values: GpuTensorHandle,
241    pub indices: GpuTensorHandle,
242}
243
244/// Result payload returned by provider-side `cummax` scans.
245///
246/// Alias of [`ProviderCumminResult`] because both operations return the same pair of tensors
247/// (running values and MATLAB-compatible indices).
248pub type ProviderCummaxResult = ProviderCumminResult;
249
/// Identifies which shared acceleration context a caller is requesting
/// through [`export_context`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccelContextKind {
    /// Context intended for plotting consumers.
    Plotting,
}
255
/// Shared GPU context description handed out by [`export_context`].
///
/// The only variant exists when the `wgpu` feature is enabled; without it the
/// enum has no variants.
#[derive(Clone)]
pub enum AccelContextHandle {
    /// WGPU-backed context.
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}
262
263impl AccelContextHandle {
264    /// Returns the underlying WGPU context when available.
265    #[cfg(feature = "wgpu")]
266    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
267        match self {
268            AccelContextHandle::Wgpu(ctx) => Some(ctx),
269        }
270    }
271}
272
/// WGPU objects exported by the acceleration provider so other subsystems can
/// share the same device and queue.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    /// Shared WGPU instance.
    pub instance: Arc<wgpu::Instance>,
    /// Shared logical device.
    pub device: Arc<wgpu::Device>,
    /// Shared command queue.
    pub queue: Arc<wgpu::Queue>,
    /// Shared adapter.
    pub adapter: Arc<wgpu::Adapter>,
    /// Adapter description.
    pub adapter_info: wgpu::AdapterInfo,
    /// Limits value exported alongside the device.
    pub limits: wgpu::Limits,
    /// Feature set exported alongside the device.
    pub features: wgpu::Features,
}
285
/// Reference to a provider-owned WGPU buffer that backs a `GpuTensorHandle`.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    /// Shared pointer to the buffer.
    pub buffer: Arc<wgpu::Buffer>,
    /// Length of the buffer contents (presumably in elements — confirm with
    /// the provider).
    pub len: usize,
    /// Tensor shape associated with the buffer.
    pub shape: Vec<usize>,
    /// Size of one element (presumably in bytes — confirm with the provider).
    pub element_size: usize,
    /// Precision of the stored elements.
    pub precision: ProviderPrecision,
}
296
297pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
298    if let Ok(mut guard) = HANDLE_STORAGES.write() {
299        guard.insert(handle.buffer_id, storage);
300    }
301}
302
303pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
304    HANDLE_STORAGES
305        .read()
306        .ok()
307        .and_then(|guard| guard.get(&handle.buffer_id).cloned())
308        .unwrap_or(GpuTensorStorage::Real)
309}
310
311pub fn clear_handle_storage(handle: &GpuTensorHandle) {
312    if let Ok(mut guard) = HANDLE_STORAGES.write() {
313        guard.remove(&handle.buffer_id);
314    }
315}
316
317#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
318pub enum PagefunOp {
319    Mtimes,
320}
321
322#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
323pub struct PagefunRequest {
324    pub op: PagefunOp,
325    pub inputs: Vec<GpuTensorHandle>,
326    pub output_shape: Vec<usize>,
327    pub page_dims: Vec<usize>,
328    pub input_page_dims: Vec<Vec<usize>>,
329}
330
331#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
332pub enum FindDirection {
333    First,
334    Last,
335}
336
337#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
338pub struct ProviderFindResult {
339    pub linear: GpuTensorHandle,
340    pub rows: GpuTensorHandle,
341    pub cols: GpuTensorHandle,
342    pub values: Option<GpuTensorHandle>,
343}
344
345#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
346pub struct ProviderBandwidth {
347    pub lower: u32,
348    pub upper: u32,
349}
350
351#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
352pub enum ProviderSymmetryKind {
353    Symmetric,
354    Skew,
355}
356
357#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
358pub enum ProviderHermitianKind {
359    Hermitian,
360    Skew,
361}
362
363#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
364pub struct ProviderLuResult {
365    pub combined: GpuTensorHandle,
366    pub lower: GpuTensorHandle,
367    pub upper: GpuTensorHandle,
368    pub perm_matrix: GpuTensorHandle,
369    pub perm_vector: GpuTensorHandle,
370}
371
372#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
373pub struct ProviderCholResult {
374    pub factor: GpuTensorHandle,
375    /// MATLAB-compatible failure index (0 indicates success).
376    pub info: u32,
377}
378
379#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
380pub struct ProviderQrResult {
381    pub q: GpuTensorHandle,
382    pub r: GpuTensorHandle,
383    pub perm_matrix: GpuTensorHandle,
384    pub perm_vector: GpuTensorHandle,
385}
386
387#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
388pub struct ProviderQrPowerIterResult {
389    pub q: GpuTensorHandle,
390    pub r: GpuTensorHandle,
391    pub perm_matrix: GpuTensorHandle,
392    pub perm_vector: GpuTensorHandle,
393}
394
395#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
396pub struct ProviderLinsolveOptions {
397    pub lower: bool,
398    pub upper: bool,
399    pub rectangular: bool,
400    pub transposed: bool,
401    pub conjugate: bool,
402    pub symmetric: bool,
403    pub posdef: bool,
404    pub need_rcond: bool,
405    pub rcond: Option<f64>,
406}
407
408#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
409pub struct ProviderLinsolveResult {
410    pub solution: GpuTensorHandle,
411    pub reciprocal_condition: f64,
412}
413
414#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
415pub struct ProviderPinvOptions {
416    pub tolerance: Option<f64>,
417}
418
419#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
420pub struct ProviderPolyvalMu {
421    pub mean: f64,
422    pub scale: f64,
423}
424
425#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
426pub struct ProviderPolyvalOptions {
427    pub mu: Option<ProviderPolyvalMu>,
428}
429
430#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
431pub struct ProviderInvOptions {}
432
433#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
434pub struct ProviderPolyfitResult {
435    pub coefficients: Vec<f64>,
436    pub r_matrix: Vec<f64>,
437    pub normr: f64,
438    pub df: f64,
439    pub mu: [f64; 2],
440}
441
442/// Numerator/denominator payload returned by provider-backed `polyder` quotient rule.
443#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
444pub struct ProviderPolyderQuotient {
445    pub numerator: GpuTensorHandle,
446    pub denominator: GpuTensorHandle,
447}
448
449/// Supported norm specifications for the `cond` builtin.
450#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
451pub enum ProviderCondNorm {
452    Two,
453    One,
454    Inf,
455    Fro,
456}
457
458/// Supported norm orders for the `norm` builtin.
459#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
460pub enum ProviderNormOrder {
461    Two,
462    One,
463    Inf,
464    NegInf,
465    Zero,
466    Fro,
467    Nuc,
468    P(f64),
469}
470
471#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
472pub struct ProviderEigResult {
473    pub eigenvalues: GpuTensorHandle,
474    pub diagonal: GpuTensorHandle,
475    pub right: GpuTensorHandle,
476    pub left: Option<GpuTensorHandle>,
477}
478
479#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
480pub enum ProviderQrPivot {
481    Matrix,
482    Vector,
483}
484
485#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
486pub struct ProviderQrOptions {
487    pub economy: bool,
488    pub pivot: ProviderQrPivot,
489}
490
491impl Default for ProviderQrOptions {
492    fn default() -> Self {
493        Self {
494            economy: false,
495            pivot: ProviderQrPivot::Matrix,
496        }
497    }
498}
499
500#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
501pub enum ProviderPrecision {
502    F32,
503    F64,
504}
505
506/// Declares how provider-owned GPU handles may cross async spawn boundaries.
507///
508/// This is a runtime/provider policy surface (not a semantic type fact) used by
509/// VM/runtime spawn handling to prevent unsynchronized device-handle races.
510#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
511pub enum SpawnHandleConcurrency {
512    /// Provider supports immutable sharing of handle-backed values across spawned tasks.
513    ImmutableShare,
514    /// Provider supports copy-on-write semantics when spawned and parent tasks diverge.
515    CopyOnWrite,
516    /// Provider supports synchronized mutation for shared handles.
517    SynchronizedMutation,
518    /// Provider rejects spawned sharing of raw handles.
519    Reject,
520}
521
522impl SpawnHandleConcurrency {
523    pub fn as_str(self) -> &'static str {
524        match self {
525            SpawnHandleConcurrency::ImmutableShare => "immutable_share",
526            SpawnHandleConcurrency::CopyOnWrite => "copy_on_write",
527            SpawnHandleConcurrency::SynchronizedMutation => "synchronized_mutation",
528            SpawnHandleConcurrency::Reject => "reject",
529        }
530    }
531}
532
533#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
534pub enum ReductionTwoPassMode {
535    Auto,
536    ForceOn,
537    ForceOff,
538}
539
540impl ReductionTwoPassMode {
541    pub fn as_str(self) -> &'static str {
542        match self {
543            ReductionTwoPassMode::Auto => "auto",
544            ReductionTwoPassMode::ForceOn => "force_on",
545            ReductionTwoPassMode::ForceOff => "force_off",
546        }
547    }
548}
549
550#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
551pub enum ReductionFlavor {
552    Sum,
553    Mean,
554    CustomScale(f64),
555}
556
557impl ReductionFlavor {
558    pub fn is_mean(self) -> bool {
559        matches!(self, ReductionFlavor::Mean)
560    }
561
562    pub fn scale(self, reduce_len: usize) -> f64 {
563        match self {
564            ReductionFlavor::Sum => 1.0,
565            ReductionFlavor::Mean => {
566                if reduce_len == 0 {
567                    1.0
568                } else {
569                    1.0 / reduce_len as f64
570                }
571            }
572            ReductionFlavor::CustomScale(scale) => scale,
573        }
574    }
575}
576
577/// Normalisation mode for correlation coefficients.
578#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
579pub enum CorrcoefNormalization {
580    Unbiased,
581    Biased,
582}
583
584/// Row-selection strategy for correlation coefficients.
585#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
586pub enum CorrcoefRows {
587    All,
588    Complete,
589    Pairwise,
590}
591
592/// Options controlling provider-backed correlation coefficient computation.
593#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
594pub struct CorrcoefOptions {
595    pub normalization: CorrcoefNormalization,
596    pub rows: CorrcoefRows,
597}
598
599impl Default for CorrcoefOptions {
600    fn default() -> Self {
601        Self {
602            normalization: CorrcoefNormalization::Unbiased,
603            rows: CorrcoefRows::All,
604        }
605    }
606}
607
608/// Normalisation mode used by covariance computations.
609#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
610pub enum CovNormalization {
611    Unbiased,
612    Biased,
613}
614
615/// Row handling strategy for covariance computations.
616#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
617pub enum CovRows {
618    All,
619    OmitRows,
620    PartialRows,
621}
622
623/// Options controlling provider-backed covariance computation.
624#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
625pub struct CovarianceOptions {
626    pub normalization: CovNormalization,
627    pub rows: CovRows,
628    pub has_weight_vector: bool,
629}
630
631impl Default for CovarianceOptions {
632    fn default() -> Self {
633        Self {
634            normalization: CovNormalization::Unbiased,
635            rows: CovRows::All,
636            has_weight_vector: false,
637        }
638    }
639}
640
641/// Normalization strategy used by provider-backed standard deviation reductions.
642#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
643pub enum ProviderStdNormalization {
644    Sample,
645    Population,
646}
647
648/// NaN handling mode for provider-backed reductions.
649#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
650pub enum ProviderNanMode {
651    Include,
652    Omit,
653}
654
655/// Direction used when computing prefix sums on the device.
656#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
657pub enum ProviderScanDirection {
658    Forward,
659    Reverse,
660}
661
662/// Sort direction used by acceleration providers.
663#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
664pub enum SortOrder {
665    Ascend,
666    Descend,
667}
668
669/// Comparison strategy applied during sorting.
670#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
671pub enum SortComparison {
672    Auto,
673    Real,
674    Abs,
675}
676
677/// Host-resident outputs returned by provider-backed sort operations.
678#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
679pub struct SortResult {
680    pub values: HostTensorOwned,
681    pub indices: HostTensorOwned,
682}
683
684#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
685pub struct SortRowsColumnSpec {
686    pub index: usize,
687    pub order: SortOrder,
688}
689
690/// Ordering applied by provider-backed `unique` operations.
691#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
692pub enum UniqueOrder {
693    Sorted,
694    Stable,
695}
696
697/// Occurrence selection for provider-backed `unique` operations.
698#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
699pub enum UniqueOccurrence {
700    First,
701    Last,
702}
703
704/// Options controlling provider-backed `unique` operations.
705#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
706pub struct UniqueOptions {
707    pub rows: bool,
708    pub order: UniqueOrder,
709    pub occurrence: UniqueOccurrence,
710}
711
712/// Host-resident outputs returned by provider-backed `unique` operations.
713#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
714pub struct UniqueResult {
715    pub values: HostTensorOwned,
716    pub ia: HostTensorOwned,
717    pub ic: HostTensorOwned,
718}
719
720/// Ordering applied by provider-backed `union` operations.
721#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
722pub enum UnionOrder {
723    Sorted,
724    Stable,
725}
726
727/// Options controlling provider-backed `union` operations.
728#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
729pub struct UnionOptions {
730    pub rows: bool,
731    pub order: UnionOrder,
732}
733
734/// Host-resident outputs returned by provider-backed `union` operations.
735#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
736pub struct UnionResult {
737    pub values: HostTensorOwned,
738    pub ia: HostTensorOwned,
739    pub ib: HostTensorOwned,
740}
741
742/// Parameterisation of 2-D filters generated by `fspecial`.
743#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
744pub enum FspecialFilter {
745    Average {
746        rows: u32,
747        cols: u32,
748    },
749    Disk {
750        radius: f64,
751        size: u32,
752    },
753    Gaussian {
754        rows: u32,
755        cols: u32,
756        sigma: f64,
757    },
758    Laplacian {
759        alpha: f64,
760    },
761    Log {
762        rows: u32,
763        cols: u32,
764        sigma: f64,
765    },
766    Motion {
767        length: u32,
768        kernel_size: u32,
769        angle_degrees: f64,
770        oversample: u32,
771    },
772    Prewitt,
773    Sobel,
774    Unsharp {
775        alpha: f64,
776    },
777}
778
779/// Request dispatched to acceleration providers for `fspecial` kernels.
780#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
781pub struct FspecialRequest {
782    pub filter: FspecialFilter,
783}
784
785/// Padding strategy used by `imfilter`.
786#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
787pub enum ImfilterPadding {
788    Constant,
789    Replicate,
790    Symmetric,
791    Circular,
792}
793
794/// Output sizing mode used by `imfilter`.
795#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
796pub enum ImfilterShape {
797    Same,
798    Full,
799    Valid,
800}
801
802/// Correlation vs convolution behaviour for `imfilter`.
803#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
804pub enum ImfilterMode {
805    Correlation,
806    Convolution,
807}
808
809/// Options supplied to acceleration providers for `imfilter`.
810#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
811pub struct ImfilterOptions {
812    pub padding: ImfilterPadding,
813    pub constant_value: f64,
814    pub shape: ImfilterShape,
815    pub mode: ImfilterMode,
816}
817
818impl Default for ImfilterOptions {
819    fn default() -> Self {
820        Self {
821            padding: ImfilterPadding::Constant,
822            constant_value: 0.0,
823            shape: ImfilterShape::Same,
824            mode: ImfilterMode::Correlation,
825        }
826    }
827}
828
829/// Ordering applied by provider-backed `setdiff` operations.
830#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
831pub enum SetdiffOrder {
832    Sorted,
833    Stable,
834}
835
836/// Options controlling provider-backed `setdiff` operations.
837#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
838pub struct SetdiffOptions {
839    pub rows: bool,
840    pub order: SetdiffOrder,
841}
842
843/// Host-resident outputs returned by provider-backed `setdiff` operations.
844#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
845pub struct SetdiffResult {
846    pub values: HostTensorOwned,
847    pub ia: HostTensorOwned,
848}
849
850/// Options controlling provider-backed `ismember` operations.
851#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
852pub struct IsMemberOptions {
853    pub rows: bool,
854}
855
856/// Host-resident logical output returned by providers.
857#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
858pub struct HostLogicalOwned {
859    pub data: Vec<u8>,
860    pub shape: Vec<usize>,
861}
862
863/// Host-resident outputs returned by provider-backed `ismember` operations.
864#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
865pub struct IsMemberResult {
866    pub mask: HostLogicalOwned,
867    pub loc: HostTensorOwned,
868}
869
870#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
871pub enum ProviderConvMode {
872    Full,
873    Same,
874    Valid,
875}
876
877#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
878pub enum ProviderConvOrientation {
879    Row,
880    Column,
881}
882
883#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
884pub struct ProviderConv1dOptions {
885    pub mode: ProviderConvMode,
886    pub orientation: ProviderConvOrientation,
887}
888
889#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
890pub struct ProviderIirFilterOptions {
891    /// Zero-based dimension along which filtering should be applied.
892    pub dim: usize,
893    /// Optional initial conditions (state vector) residing on the device.
894    pub zi: Option<GpuTensorHandle>,
895}
896
897#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
898pub struct ProviderIirFilterResult {
899    /// Filtered output tensor, matching the input signal shape.
900    pub output: GpuTensorHandle,
901    /// Final conditions for the filter state (same shape as the requested `zi` layout).
902    pub final_state: Option<GpuTensorHandle>,
903}
904
905#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
906pub struct ProviderMoments2 {
907    pub mean: GpuTensorHandle,
908    pub ex2: GpuTensorHandle,
909}
910
911#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
912pub struct ProviderDispatchStats {
913    /// Number of GPU dispatches recorded for this category.
914    pub count: u64,
915    /// Accumulated wall-clock time of dispatches in nanoseconds (host measured).
916    pub total_wall_time_ns: u64,
917}
918
919#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
920pub struct ProviderFallbackStat {
921    pub reason: String,
922    pub count: u64,
923}
924
925#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
926pub struct ProviderTelemetry {
927    pub fused_elementwise: ProviderDispatchStats,
928    pub fused_reduction: ProviderDispatchStats,
929    pub matmul: ProviderDispatchStats,
930    pub linsolve: ProviderDispatchStats,
931    pub mldivide: ProviderDispatchStats,
932    pub mrdivide: ProviderDispatchStats,
933    pub upload_bytes: u64,
934    pub download_bytes: u64,
935    pub solve_fallbacks: Vec<ProviderFallbackStat>,
936    pub fusion_cache_hits: u64,
937    pub fusion_cache_misses: u64,
938    pub bind_group_cache_hits: u64,
939    pub bind_group_cache_misses: u64,
940    /// Optional per-layout bind group cache counters (layout tags and their hit/miss counts)
941    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
942    /// Recent kernel launch metadata (bounded log; newest last)
943    pub kernel_launches: Vec<KernelLaunchTelemetry>,
944}
945
946#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
947pub struct BindGroupLayoutTelemetry {
948    pub tag: String,
949    pub hits: u64,
950    pub misses: u64,
951}
952
953#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
954pub struct KernelAttrTelemetry {
955    pub key: String,
956    pub value: u64,
957}
958
959#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
960pub struct KernelLaunchTelemetry {
961    pub kernel: String,
962    pub precision: Option<String>,
963    pub shape: Vec<KernelAttrTelemetry>,
964    pub tuning: Vec<KernelAttrTelemetry>,
965}
966
967pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
968pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
969
970fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
971    Box::pin(async move { Err(anyhow::anyhow!(message)) })
972}
973
974/// Device/provider interface that backends implement and register into the runtime layer
975pub trait AccelProvider: Send + Sync {
976    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
977    fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
978    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
979    fn device_info(&self) -> String;
980    fn device_id(&self) -> u32 {
981        0
982    }
983
    /// Declares the provider's policy for sharing `GpuTensorHandle` values across spawned
    /// async boundaries.
    ///
    /// The default is the conservative [`SpawnHandleConcurrency::Reject`]. Providers that can
    /// safely support cross-task sharing should override this.
    fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
        SpawnHandleConcurrency::Reject
    }
992
    /// Export a shared GPU context handle, allowing downstream systems (plotting,
    /// visualization) to reuse the same device/queue without copying tensor data back to the
    /// host.
    ///
    /// The default exports nothing and returns `None` for every `_kind`.
    fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
        None
    }
998
999    /// Export a provider-owned WGPU buffer for zero-copy integrations.
1000    #[cfg(feature = "wgpu")]
1001    fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
1002        let _ = _handle;
1003        None
1004    }
1005
1006    /// Gather elements from `source` at the provided zero-based linear `indices`, materialising
1007    /// a dense tensor with the specified `output_shape`.
1008    fn gather_linear(
1009        &self,
1010        _source: &GpuTensorHandle,
1011        _indices: &[u32],
1012        _output_shape: &[usize],
1013    ) -> anyhow::Result<GpuTensorHandle> {
1014        Err(anyhow::anyhow!("gather_linear not supported by provider"))
1015    }
1016
1017    /// Scatter the contents of `values` into `target` at the provided zero-based linear `indices`.
1018    ///
1019    /// The provider must ensure `values.len() == indices.len()` and update `target` in place.
1020    fn scatter_linear(
1021        &self,
1022        _target: &GpuTensorHandle,
1023        _indices: &[u32],
1024        _values: &GpuTensorHandle,
1025    ) -> anyhow::Result<()> {
1026        Err(anyhow::anyhow!("scatter_linear not supported by provider"))
1027    }
1028
1029    /// Structured device information (optional to override). Default adapts from `device_info()`.
1030    fn device_info_struct(&self) -> ApiDeviceInfo {
1031        ApiDeviceInfo {
1032            device_id: 0,
1033            name: self.device_info(),
1034            vendor: String::new(),
1035            memory_bytes: None,
1036            backend: None,
1037        }
1038    }
1039
    /// Precision reported by the provider. Defaults to [`ProviderPrecision::F64`].
    fn precision(&self) -> ProviderPrecision {
        ProviderPrecision::F64
    }
1043
1044    /// Read a single scalar at linear index from a device tensor, returning it as f64.
1045    fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
1046        Err(anyhow::anyhow!("read_scalar not supported by provider"))
1047    }
1048
1049    /// Allocate a zero-initialised tensor with the provided shape on the device.
1050    fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1051        Err(anyhow::anyhow!("zeros not supported by provider"))
1052    }
1053
1054    /// Allocate a one-initialised tensor with the provided shape on the device.
1055    fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1056        Err(anyhow::anyhow!("ones not supported by provider"))
1057    }
1058
1059    /// Allocate a zero-initialised tensor matching the prototype tensor.
1060    fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1061        self.zeros(&prototype.shape)
1062    }
1063
1064    /// Allocate a tensor filled with a constant value on the device.
1065    fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
1066        if value == 0.0 {
1067            return self.zeros(shape);
1068        }
1069        if let Ok(base) = self.zeros(shape) {
1070            match self.scalar_add(&base, value) {
1071                Ok(out) => {
1072                    let _ = self.free(&base);
1073                    return Ok(out);
1074                }
1075                Err(_) => {
1076                    let _ = self.free(&base);
1077                }
1078            }
1079        }
1080        let len: usize = shape.iter().copied().product();
1081        let data = vec![value; len];
1082        let view = HostTensorView { data: &data, shape };
1083        self.upload(&view)
1084    }
1085
1086    /// Allocate a tensor filled with a constant value, matching a prototype's residency.
1087    fn fill_like(
1088        &self,
1089        prototype: &GpuTensorHandle,
1090        value: f64,
1091    ) -> anyhow::Result<GpuTensorHandle> {
1092        if value == 0.0 {
1093            return self.zeros_like(prototype);
1094        }
1095        if let Ok(base) = self.zeros_like(prototype) {
1096            match self.scalar_add(&base, value) {
1097                Ok(out) => {
1098                    let _ = self.free(&base);
1099                    return Ok(out);
1100                }
1101                Err(_) => {
1102                    let _ = self.free(&base);
1103                }
1104            }
1105        }
1106        self.fill(&prototype.shape, value)
1107    }
1108
1109    /// Allocate a one-initialised tensor matching the prototype tensor.
1110    fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1111        self.ones(&prototype.shape)
1112    }
1113
1114    /// Allocate an identity tensor with ones along the leading diagonal of the first two axes.
1115    fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1116        Err(anyhow::anyhow!("eye not supported by provider"))
1117    }
1118
1119    /// Allocate an identity tensor matching the prototype tensor's shape.
1120    fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1121        self.eye(&prototype.shape)
1122    }
1123
1124    /// Construct MATLAB-style coordinate grids from axis vectors.
1125    fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1126        Err(anyhow::anyhow!("meshgrid not supported by provider"))
1127    }
1128
1129    /// Construct a diagonal matrix from a vector-like tensor. `offset` matches MATLAB semantics.
1130    fn diag_from_vector(
1131        &self,
1132        _vector: &GpuTensorHandle,
1133        _offset: isize,
1134    ) -> anyhow::Result<GpuTensorHandle> {
1135        Err(anyhow::anyhow!(
1136            "diag_from_vector not supported by provider"
1137        ))
1138    }
1139
1140    /// Extract a diagonal from a matrix-like tensor. The result is always a column vector.
1141    fn diag_extract(
1142        &self,
1143        _matrix: &GpuTensorHandle,
1144        _offset: isize,
1145    ) -> anyhow::Result<GpuTensorHandle> {
1146        Err(anyhow::anyhow!("diag_extract not supported by provider"))
1147    }
1148
1149    /// Apply a lower-triangular mask to the first two dimensions of a tensor.
1150    fn tril<'a>(
1151        &'a self,
1152        _matrix: &'a GpuTensorHandle,
1153        _offset: isize,
1154    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1155        Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1156    }
1157
1158    /// Apply an upper-triangular mask to the first two dimensions of a tensor.
1159    fn triu<'a>(
1160        &'a self,
1161        _matrix: &'a GpuTensorHandle,
1162        _offset: isize,
1163    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1164        Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1165    }
1166
1167    /// Evaluate a polynomial expressed by `coefficients` at each element in `points`.
1168    fn polyval(
1169        &self,
1170        _coefficients: &GpuTensorHandle,
1171        _points: &GpuTensorHandle,
1172        _options: &ProviderPolyvalOptions,
1173    ) -> anyhow::Result<GpuTensorHandle> {
1174        Err(anyhow::anyhow!("polyval not supported by provider"))
1175    }
1176
1177    /// Fit a polynomial of degree `degree` to `(x, y)` samples. Optional weights must match `x`.
1178    fn polyfit<'a>(
1179        &'a self,
1180        _x: &'a GpuTensorHandle,
1181        _y: &'a GpuTensorHandle,
1182        _degree: usize,
1183        _weights: Option<&'a GpuTensorHandle>,
1184    ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1185        Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1186    }
1187
1188    /// Differentiate a polynomial represented as a vector of coefficients.
1189    fn polyder_single<'a>(
1190        &'a self,
1191        _polynomial: &'a GpuTensorHandle,
1192    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1193        Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1194    }
1195
1196    /// Apply the product rule to polynomials `p` and `q`.
1197    fn polyder_product<'a>(
1198        &'a self,
1199        _p: &'a GpuTensorHandle,
1200        _q: &'a GpuTensorHandle,
1201    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1202        Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1203    }
1204
1205    /// Apply the quotient rule to polynomials `u` and `v`.
1206    fn polyder_quotient<'a>(
1207        &'a self,
1208        _u: &'a GpuTensorHandle,
1209        _v: &'a GpuTensorHandle,
1210    ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1211        Box::pin(async move {
1212            Err(anyhow::anyhow!(
1213                "polyder_quotient not supported by provider"
1214            ))
1215        })
1216    }
1217
1218    /// Integrate a polynomial represented as a vector of coefficients and append a constant term.
1219    fn polyint(
1220        &self,
1221        _polynomial: &GpuTensorHandle,
1222        _constant: f64,
1223    ) -> anyhow::Result<GpuTensorHandle> {
1224        Err(anyhow::anyhow!("polyint not supported by provider"))
1225    }
1226
1227    /// Allocate a tensor filled with random values drawn from U(0, 1).
1228    fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1229        Err(anyhow::anyhow!("random_uniform not supported by provider"))
1230    }
1231
1232    /// Allocate a tensor filled with random values matching the prototype shape.
1233    fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1234        self.random_uniform(&prototype.shape)
1235    }
1236
1237    /// Allocate a tensor filled with standard normal (mean 0, stddev 1) random values.
1238    fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1239        Err(anyhow::anyhow!("random_normal not supported by provider"))
1240    }
1241
1242    /// Allocate a tensor of standard normal values matching a prototype's shape.
1243    fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1244        self.random_normal(&prototype.shape)
1245    }
1246
1247    /// Exponentially-distributed random values with mean `mu`.
1248    fn random_exponential(&self, _mu: f64, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1249        Err(anyhow::anyhow!(
1250            "random_exponential not supported by provider"
1251        ))
1252    }
1253
1254    /// Normal random values with mean `mu` and standard deviation `sigma`.
1255    fn random_normrnd(
1256        &self,
1257        _mu: f64,
1258        _sigma: f64,
1259        _shape: &[usize],
1260    ) -> anyhow::Result<GpuTensorHandle> {
1261        Err(anyhow::anyhow!("random_normrnd not supported by provider"))
1262    }
1263
1264    /// Uniform random values on the interval `[a, b)`.
1265    fn random_unifrnd(
1266        &self,
1267        _a: f64,
1268        _b: f64,
1269        _shape: &[usize],
1270    ) -> anyhow::Result<GpuTensorHandle> {
1271        Err(anyhow::anyhow!("random_unifrnd not supported by provider"))
1272    }
1273
1274    fn stochastic_evolution(
1275        &self,
1276        _state: &GpuTensorHandle,
1277        _drift: f64,
1278        _scale: f64,
1279        _steps: u32,
1280    ) -> anyhow::Result<GpuTensorHandle> {
1281        Err(anyhow::anyhow!(
1282            "stochastic_evolution not supported by provider"
1283        ))
1284    }
1285
1286    /// Set the provider RNG state to align with the host RNG.
1287    fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1288        Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1289    }
1290
1291    /// Generate a 2-D correlation kernel matching MATLAB's `fspecial` builtin.
1292    fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1293        Err(anyhow::anyhow!("fspecial not supported by provider"))
1294    }
1295
1296    /// Evaluate the `peaks` test surface on an n×n grid spanning [-3,3]×[-3,3].
1297    /// Returns the Z matrix (n×n) as a GPU tensor.
1298    fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
1299        Err(anyhow::anyhow!("peaks not supported by provider"))
1300    }
1301
1302    /// Evaluate the `peaks` formula element-wise on caller-supplied GPU coordinate tensors.
1303    /// X and Y must have the same shape. Returns a Z tensor of the same shape.
1304    fn peaks_xy(
1305        &self,
1306        _x: &GpuTensorHandle,
1307        _y: &GpuTensorHandle,
1308    ) -> anyhow::Result<GpuTensorHandle> {
1309        Err(anyhow::anyhow!("peaks_xy not supported by provider"))
1310    }
1311
1312    fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1313        Err(anyhow::anyhow!("hann_window not supported by provider"))
1314    }
1315
1316    fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1317        Err(anyhow::anyhow!("hamming_window not supported by provider"))
1318    }
1319
1320    fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1321        Err(anyhow::anyhow!("blackman_window not supported by provider"))
1322    }
1323
1324    /// Apply an N-D correlation/convolution with padding semantics matching MATLAB's `imfilter`.
1325    fn imfilter<'a>(
1326        &'a self,
1327        _image: &'a GpuTensorHandle,
1328        _kernel: &'a GpuTensorHandle,
1329        _options: &'a ImfilterOptions,
1330    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1331        unsupported_future("imfilter not supported by provider")
1332    }
1333
1334    /// Allocate a tensor filled with random integers over an inclusive range.
1335    fn random_integer_range(
1336        &self,
1337        _lower: i64,
1338        _upper: i64,
1339        _shape: &[usize],
1340    ) -> anyhow::Result<GpuTensorHandle> {
1341        Err(anyhow::anyhow!(
1342            "random_integer_range not supported by provider"
1343        ))
1344    }
1345
1346    /// Allocate a random integer tensor matching the prototype shape.
1347    fn random_integer_like(
1348        &self,
1349        prototype: &GpuTensorHandle,
1350        lower: i64,
1351        upper: i64,
1352    ) -> anyhow::Result<GpuTensorHandle> {
1353        self.random_integer_range(lower, upper, &prototype.shape)
1354    }
1355
1356    /// Allocate a random permutation of 1..=n, returning the first k elements.
1357    fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1358        Err(anyhow!("random_permutation not supported by provider"))
1359    }
1360
    /// Allocate a random permutation matching the prototype residency.
    ///
    /// The default ignores `_prototype` and simply forwards `n` and `k` to
    /// `random_permutation`.
    fn random_permutation_like(
        &self,
        _prototype: &GpuTensorHandle,
        n: usize,
        k: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        self.random_permutation(n, k)
    }
1370
1371    /// Compute a covariance matrix across the columns of `matrix`.
1372    fn covariance<'a>(
1373        &'a self,
1374        _matrix: &'a GpuTensorHandle,
1375        _second: Option<&'a GpuTensorHandle>,
1376        _weights: Option<&'a GpuTensorHandle>,
1377        _options: &'a CovarianceOptions,
1378    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1379        unsupported_future("covariance not supported by provider")
1380    }
1381
1382    /// Compute a correlation coefficient matrix across the columns of `matrix`.
1383    fn corrcoef<'a>(
1384        &'a self,
1385        _matrix: &'a GpuTensorHandle,
1386        _options: &'a CorrcoefOptions,
1387    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1388        unsupported_future("corrcoef not supported by provider")
1389    }
1390
1391    // Optional operator hooks (default to unsupported)
1392    fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1393        Err(anyhow::anyhow!("linspace not supported by provider"))
1394    }
1395    fn elem_add<'a>(
1396        &'a self,
1397        _a: &'a GpuTensorHandle,
1398        _b: &'a GpuTensorHandle,
1399    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1400        unsupported_future("elem_add not supported by provider")
1401    }
1402    fn elem_mul<'a>(
1403        &'a self,
1404        _a: &'a GpuTensorHandle,
1405        _b: &'a GpuTensorHandle,
1406    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1407        unsupported_future("elem_mul not supported by provider")
1408    }
1409    fn elem_max<'a>(
1410        &'a self,
1411        _a: &'a GpuTensorHandle,
1412        _b: &'a GpuTensorHandle,
1413    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1414        unsupported_future("elem_max not supported by provider")
1415    }
1416    fn elem_min<'a>(
1417        &'a self,
1418        _a: &'a GpuTensorHandle,
1419        _b: &'a GpuTensorHandle,
1420    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1421        unsupported_future("elem_min not supported by provider")
1422    }
1423    fn elem_sub<'a>(
1424        &'a self,
1425        _a: &'a GpuTensorHandle,
1426        _b: &'a GpuTensorHandle,
1427    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1428        unsupported_future("elem_sub not supported by provider")
1429    }
1430    fn elem_div<'a>(
1431        &'a self,
1432        _a: &'a GpuTensorHandle,
1433        _b: &'a GpuTensorHandle,
1434    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1435        unsupported_future("elem_div not supported by provider")
1436    }
1437    fn elem_pow<'a>(
1438        &'a self,
1439        _a: &'a GpuTensorHandle,
1440        _b: &'a GpuTensorHandle,
1441    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1442        unsupported_future("elem_pow not supported by provider")
1443    }
1444
1445    fn elem_hypot<'a>(
1446        &'a self,
1447        _a: &'a GpuTensorHandle,
1448        _b: &'a GpuTensorHandle,
1449    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1450        unsupported_future("elem_hypot not supported by provider")
1451    }
1452    fn elem_ge<'a>(
1453        &'a self,
1454        _a: &'a GpuTensorHandle,
1455        _b: &'a GpuTensorHandle,
1456    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1457        unsupported_future("elem_ge not supported by provider")
1458    }
1459    fn elem_le<'a>(
1460        &'a self,
1461        _a: &'a GpuTensorHandle,
1462        _b: &'a GpuTensorHandle,
1463    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1464        unsupported_future("elem_le not supported by provider")
1465    }
1466    fn elem_lt<'a>(
1467        &'a self,
1468        _a: &'a GpuTensorHandle,
1469        _b: &'a GpuTensorHandle,
1470    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1471        unsupported_future("elem_lt not supported by provider")
1472    }
1473    fn elem_gt<'a>(
1474        &'a self,
1475        _a: &'a GpuTensorHandle,
1476        _b: &'a GpuTensorHandle,
1477    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1478        unsupported_future("elem_gt not supported by provider")
1479    }
1480    fn elem_eq<'a>(
1481        &'a self,
1482        _a: &'a GpuTensorHandle,
1483        _b: &'a GpuTensorHandle,
1484    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1485        unsupported_future("elem_eq not supported by provider")
1486    }
1487    fn elem_ne<'a>(
1488        &'a self,
1489        _a: &'a GpuTensorHandle,
1490        _b: &'a GpuTensorHandle,
1491    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1492        unsupported_future("elem_ne not supported by provider")
1493    }
1494    fn logical_and(
1495        &self,
1496        _a: &GpuTensorHandle,
1497        _b: &GpuTensorHandle,
1498    ) -> anyhow::Result<GpuTensorHandle> {
1499        Err(anyhow::anyhow!("logical_and not supported by provider"))
1500    }
1501    fn logical_or(
1502        &self,
1503        _a: &GpuTensorHandle,
1504        _b: &GpuTensorHandle,
1505    ) -> anyhow::Result<GpuTensorHandle> {
1506        Err(anyhow::anyhow!("logical_or not supported by provider"))
1507    }
1508    fn logical_xor(
1509        &self,
1510        _a: &GpuTensorHandle,
1511        _b: &GpuTensorHandle,
1512    ) -> anyhow::Result<GpuTensorHandle> {
1513        Err(anyhow::anyhow!("logical_xor not supported by provider"))
1514    }
1515    fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1516        Err(anyhow::anyhow!("logical_not not supported by provider"))
1517    }
1518    fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
1519        Ok(handle_is_logical(a))
1520    }
1521    fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1522        Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1523    }
1524    fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1525        Err(anyhow::anyhow!(
1526            "logical_isfinite not supported by provider"
1527        ))
1528    }
1529    fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1530        Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1531    }
1532    fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1533        Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1534    }
1535    fn elem_atan2<'a>(
1536        &'a self,
1537        _y: &'a GpuTensorHandle,
1538        _x: &'a GpuTensorHandle,
1539    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1540        unsupported_future("elem_atan2 not supported by provider")
1541    }
1542    // Unary elementwise operations (optional)
1543    fn unary_sin<'a>(
1544        &'a self,
1545        _a: &'a GpuTensorHandle,
1546    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1547        unsupported_future("unary_sin not supported by provider")
1548    }
1549    fn unary_sinc<'a>(
1550        &'a self,
1551        _a: &'a GpuTensorHandle,
1552    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1553        unsupported_future("unary_sinc not supported by provider")
1554    }
1555    fn unary_gamma<'a>(
1556        &'a self,
1557        _a: &'a GpuTensorHandle,
1558    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1559        unsupported_future("unary_gamma not supported by provider")
1560    }
1561    fn unary_factorial<'a>(
1562        &'a self,
1563        _a: &'a GpuTensorHandle,
1564    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1565        unsupported_future("unary_factorial not supported by provider")
1566    }
1567    fn unary_asinh<'a>(
1568        &'a self,
1569        _a: &'a GpuTensorHandle,
1570    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1571        unsupported_future("unary_asinh not supported by provider")
1572    }
1573    fn unary_sinh<'a>(
1574        &'a self,
1575        _a: &'a GpuTensorHandle,
1576    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1577        unsupported_future("unary_sinh not supported by provider")
1578    }
1579    fn unary_cosh<'a>(
1580        &'a self,
1581        _a: &'a GpuTensorHandle,
1582    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1583        unsupported_future("unary_cosh not supported by provider")
1584    }
1585    fn unary_asin<'a>(
1586        &'a self,
1587        _a: &'a GpuTensorHandle,
1588    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1589        unsupported_future("unary_asin not supported by provider")
1590    }
1591    fn unary_acos<'a>(
1592        &'a self,
1593        _a: &'a GpuTensorHandle,
1594    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1595        unsupported_future("unary_acos not supported by provider")
1596    }
1597    fn unary_acosh<'a>(
1598        &'a self,
1599        _a: &'a GpuTensorHandle,
1600    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1601        unsupported_future("unary_acosh not supported by provider")
1602    }
1603    fn unary_tan<'a>(
1604        &'a self,
1605        _a: &'a GpuTensorHandle,
1606    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1607        unsupported_future("unary_tan not supported by provider")
1608    }
1609    fn unary_tanh<'a>(
1610        &'a self,
1611        _a: &'a GpuTensorHandle,
1612    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1613        unsupported_future("unary_tanh not supported by provider")
1614    }
1615    fn unary_atan<'a>(
1616        &'a self,
1617        _a: &'a GpuTensorHandle,
1618    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1619        unsupported_future("unary_atan not supported by provider")
1620    }
1621    fn unary_atanh<'a>(
1622        &'a self,
1623        _a: &'a GpuTensorHandle,
1624    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1625        unsupported_future("unary_atanh not supported by provider")
1626    }
1627    fn unary_ceil<'a>(
1628        &'a self,
1629        _a: &'a GpuTensorHandle,
1630    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1631        unsupported_future("unary_ceil not supported by provider")
1632    }
1633    fn unary_floor<'a>(
1634        &'a self,
1635        _a: &'a GpuTensorHandle,
1636    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1637        unsupported_future("unary_floor not supported by provider")
1638    }
1639    fn unary_round<'a>(
1640        &'a self,
1641        _a: &'a GpuTensorHandle,
1642    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1643        unsupported_future("unary_round not supported by provider")
1644    }
1645    fn unary_fix<'a>(
1646        &'a self,
1647        _a: &'a GpuTensorHandle,
1648    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1649        unsupported_future("unary_fix not supported by provider")
1650    }
1651    fn unary_cos<'a>(
1652        &'a self,
1653        _a: &'a GpuTensorHandle,
1654    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1655        unsupported_future("unary_cos not supported by provider")
1656    }
1657    fn unary_angle<'a>(
1658        &'a self,
1659        _a: &'a GpuTensorHandle,
1660    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1661        unsupported_future("unary_angle not supported by provider")
1662    }
1663    fn unary_imag<'a>(
1664        &'a self,
1665        _a: &'a GpuTensorHandle,
1666    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1667        unsupported_future("unary_imag not supported by provider")
1668    }
1669    fn unary_real<'a>(
1670        &'a self,
1671        _a: &'a GpuTensorHandle,
1672    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1673        unsupported_future("unary_real not supported by provider")
1674    }
1675    fn unary_conj<'a>(
1676        &'a self,
1677        _a: &'a GpuTensorHandle,
1678    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1679        unsupported_future("unary_conj not supported by provider")
1680    }
1681    fn unary_abs<'a>(
1682        &'a self,
1683        _a: &'a GpuTensorHandle,
1684    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1685        unsupported_future("unary_abs not supported by provider")
1686    }
1687    fn unary_sign<'a>(
1688        &'a self,
1689        _a: &'a GpuTensorHandle,
1690    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1691        unsupported_future("unary_sign not supported by provider")
1692    }
1693    fn unary_heaviside<'a>(
1694        &'a self,
1695        _a: &'a GpuTensorHandle,
1696    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1697        unsupported_future("unary_heaviside not supported by provider")
1698    }
1699    fn unary_exp<'a>(
1700        &'a self,
1701        _a: &'a GpuTensorHandle,
1702    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1703        unsupported_future("unary_exp not supported by provider")
1704    }
1705    fn unary_expm1<'a>(
1706        &'a self,
1707        _a: &'a GpuTensorHandle,
1708    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1709        unsupported_future("unary_expm1 not supported by provider")
1710    }
1711    fn unary_log<'a>(
1712        &'a self,
1713        _a: &'a GpuTensorHandle,
1714    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1715        unsupported_future("unary_log not supported by provider")
1716    }
1717    fn unary_log2<'a>(
1718        &'a self,
1719        _a: &'a GpuTensorHandle,
1720    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1721        unsupported_future("unary_log2 not supported by provider")
1722    }
1723    fn unary_log10<'a>(
1724        &'a self,
1725        _a: &'a GpuTensorHandle,
1726    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1727        unsupported_future("unary_log10 not supported by provider")
1728    }
1729    fn unary_log1p<'a>(
1730        &'a self,
1731        _a: &'a GpuTensorHandle,
1732    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1733        unsupported_future("unary_log1p not supported by provider")
1734    }
1735    fn unary_sqrt<'a>(
1736        &'a self,
1737        _a: &'a GpuTensorHandle,
1738    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1739        unsupported_future("unary_sqrt not supported by provider")
1740    }
1741    fn unary_double<'a>(
1742        &'a self,
1743        _a: &'a GpuTensorHandle,
1744    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1745        unsupported_future("unary_double not supported by provider")
1746    }
1747    fn unary_single<'a>(
1748        &'a self,
1749        _a: &'a GpuTensorHandle,
1750    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1751        unsupported_future("unary_single not supported by provider")
1752    }
1753    fn unary_pow2<'a>(
1754        &'a self,
1755        _a: &'a GpuTensorHandle,
1756    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1757        unsupported_future("unary_pow2 not supported by provider")
1758    }
1759    fn unary_nextpow2<'a>(
1760        &'a self,
1761        _a: &'a GpuTensorHandle,
1762    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1763        unsupported_future("unary_nextpow2 not supported by provider")
1764    }
1765    fn pow2_scale(
1766        &self,
1767        _mantissa: &GpuTensorHandle,
1768        _exponent: &GpuTensorHandle,
1769    ) -> anyhow::Result<GpuTensorHandle> {
1770        Err(anyhow::anyhow!("pow2_scale not supported by provider"))
1771    }
1772    // Left-scalar operations (broadcast with scalar on the left)
1773    fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1774        Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
1775    }
1776    fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1777        Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
1778    }
1779    // Scalar operations: apply op with scalar right-hand side (broadcast over a)
1780    fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1781        Err(anyhow::anyhow!("scalar_add not supported by provider"))
1782    }
1783    fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1784        Err(anyhow::anyhow!("scalar_sub not supported by provider"))
1785    }
1786    fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1787        Err(anyhow::anyhow!("scalar_mul not supported by provider"))
1788    }
1789    fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1790        Err(anyhow::anyhow!("scalar_max not supported by provider"))
1791    }
1792    fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1793        Err(anyhow::anyhow!("scalar_min not supported by provider"))
1794    }
1795    fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1796        Err(anyhow::anyhow!("scalar_div not supported by provider"))
1797    }
1798    fn sort_dim<'a>(
1799        &'a self,
1800        _a: &'a GpuTensorHandle,
1801        _dim: usize,
1802        _order: SortOrder,
1803        _comparison: SortComparison,
1804    ) -> AccelProviderFuture<'a, SortResult> {
1805        unsupported_future("sort_dim not supported by provider")
1806    }
1807    fn sort_rows<'a>(
1808        &'a self,
1809        _a: &'a GpuTensorHandle,
1810        _columns: &'a [SortRowsColumnSpec],
1811        _comparison: SortComparison,
1812    ) -> AccelProviderFuture<'a, SortResult> {
1813        unsupported_future("sort_rows not supported by provider")
1814    }
1815    fn matmul<'a>(
1816        &'a self,
1817        _a: &'a GpuTensorHandle,
1818        _b: &'a GpuTensorHandle,
1819    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1820        unsupported_future("matmul not supported by provider")
1821    }
1822
1823    fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1824        Err(anyhow::anyhow!("syrk not supported by provider"))
1825    }
1826    fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
1827        Err(anyhow::anyhow!("pagefun not supported by provider"))
1828    }
1829
1830    /// Optional: matrix multiplication with an epilogue applied before store.
1831    ///
1832    /// The default implementation falls back to `matmul` when the epilogue is effectively a no-op
1833    /// (alpha=1, beta=0, no row/col scales), and otherwise returns `Err`.
1834    fn matmul_epilogue<'a>(
1835        &'a self,
1836        a: &'a GpuTensorHandle,
1837        b: &'a GpuTensorHandle,
1838        epilogue: &'a MatmulEpilogue,
1839    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1840        Box::pin(async move {
1841            if epilogue.is_noop() {
1842                return self.matmul(a, b).await;
1843            }
1844            Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1845        })
1846    }
1847    fn image_normalize<'a>(
1848        &'a self,
1849        _input: &'a GpuTensorHandle,
1850        _desc: &'a ImageNormalizeDescriptor,
1851    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1852        unsupported_future("image_normalize fusion not supported by provider")
1853    }
1854    fn matmul_power_step<'a>(
1855        &'a self,
1856        _lhs: &'a GpuTensorHandle,
1857        _rhs: &'a GpuTensorHandle,
1858        _epilogue: &'a PowerStepEpilogue,
1859    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1860        unsupported_future("matmul_power_step normalization not supported by provider")
1861    }
1862    fn linsolve<'a>(
1863        &'a self,
1864        _lhs: &'a GpuTensorHandle,
1865        _rhs: &'a GpuTensorHandle,
1866        _options: &'a ProviderLinsolveOptions,
1867    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
1868        unsupported_future("linsolve not supported by provider")
1869    }
1870    fn inv<'a>(
1871        &'a self,
1872        _matrix: &'a GpuTensorHandle,
1873        _options: ProviderInvOptions,
1874    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1875        unsupported_future("inv not supported by provider")
1876    }
1877    fn pinv<'a>(
1878        &'a self,
1879        _matrix: &'a GpuTensorHandle,
1880        _options: ProviderPinvOptions,
1881    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1882        unsupported_future("pinv not supported by provider")
1883    }
1884    fn cond<'a>(
1885        &'a self,
1886        _matrix: &'a GpuTensorHandle,
1887        _norm: ProviderCondNorm,
1888    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1889        Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
1890    }
1891    fn norm<'a>(
1892        &'a self,
1893        _tensor: &'a GpuTensorHandle,
1894        _order: ProviderNormOrder,
1895    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1896        Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
1897    }
1898    fn rank<'a>(
1899        &'a self,
1900        _matrix: &'a GpuTensorHandle,
1901        _tolerance: Option<f64>,
1902    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1903        Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
1904    }
1905    fn rcond<'a>(
1906        &'a self,
1907        _matrix: &'a GpuTensorHandle,
1908    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1909        Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
1910    }
1911    fn mldivide<'a>(
1912        &'a self,
1913        _lhs: &'a GpuTensorHandle,
1914        _rhs: &'a GpuTensorHandle,
1915    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1916        Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
1917    }
1918    fn mrdivide<'a>(
1919        &'a self,
1920        _lhs: &'a GpuTensorHandle,
1921        _rhs: &'a GpuTensorHandle,
1922    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1923        Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
1924    }
1925    fn eig<'a>(
1926        &'a self,
1927        _a: &'a GpuTensorHandle,
1928        _compute_left: bool,
1929    ) -> AccelProviderFuture<'a, ProviderEigResult> {
1930        Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
1931    }
1932    fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
1933        Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
1934    }
1935
1936    fn chol<'a>(
1937        &'a self,
1938        _a: &'a GpuTensorHandle,
1939        _lower: bool,
1940    ) -> AccelProviderFuture<'a, ProviderCholResult> {
1941        Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
1942    }
1943    fn qr<'a>(
1944        &'a self,
1945        _a: &'a GpuTensorHandle,
1946        _options: ProviderQrOptions,
1947    ) -> AccelProviderFuture<'a, ProviderQrResult> {
1948        Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
1949    }
1950    fn take_matmul_sources(
1951        &self,
1952        _product: &GpuTensorHandle,
1953    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
1954        None
1955    }
1956    fn qr_power_iter<'a>(
1957        &'a self,
1958        product: &'a GpuTensorHandle,
1959        _product_lhs: Option<&'a GpuTensorHandle>,
1960        q_handle: &'a GpuTensorHandle,
1961        options: &'a ProviderQrOptions,
1962    ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
1963        let _ = (product, q_handle, options);
1964        Box::pin(async move { Ok(None) })
1965    }
1966    fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1967        Err(anyhow::anyhow!("transpose not supported by provider"))
1968    }
1969    fn conv1d(
1970        &self,
1971        _signal: &GpuTensorHandle,
1972        _kernel: &GpuTensorHandle,
1973        _options: ProviderConv1dOptions,
1974    ) -> anyhow::Result<GpuTensorHandle> {
1975        Err(anyhow::anyhow!("conv1d not supported by provider"))
1976    }
1977    fn conv2d(
1978        &self,
1979        _signal: &GpuTensorHandle,
1980        _kernel: &GpuTensorHandle,
1981        _mode: ProviderConvMode,
1982    ) -> anyhow::Result<GpuTensorHandle> {
1983        Err(anyhow::anyhow!("conv2d not supported by provider"))
1984    }
1985    fn iir_filter<'a>(
1986        &'a self,
1987        _b: &'a GpuTensorHandle,
1988        _a: &'a GpuTensorHandle,
1989        _x: &'a GpuTensorHandle,
1990        _options: ProviderIirFilterOptions,
1991    ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
1992        Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
1993    }
1994    /// Reorder tensor dimensions according to `order`, expressed as zero-based indices.
1995    fn permute(
1996        &self,
1997        _handle: &GpuTensorHandle,
1998        _order: &[usize],
1999    ) -> anyhow::Result<GpuTensorHandle> {
2000        Err(anyhow::anyhow!("permute not supported by provider"))
2001    }
2002    fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
2003        Err(anyhow::anyhow!("flip not supported by provider"))
2004    }
2005    fn circshift(
2006        &self,
2007        _handle: &GpuTensorHandle,
2008        _shifts: &[isize],
2009    ) -> anyhow::Result<GpuTensorHandle> {
2010        Err(anyhow::anyhow!("circshift not supported by provider"))
2011    }
2012    fn diff_dim(
2013        &self,
2014        _handle: &GpuTensorHandle,
2015        _order: usize,
2016        _dim: usize,
2017    ) -> anyhow::Result<GpuTensorHandle> {
2018        Err(anyhow::anyhow!("diff_dim not supported by provider"))
2019    }
2020    fn gradient_dim(
2021        &self,
2022        _handle: &GpuTensorHandle,
2023        _dim: usize,
2024        _spacing: f64,
2025    ) -> anyhow::Result<GpuTensorHandle> {
2026        Err(anyhow::anyhow!("gradient_dim not supported by provider"))
2027    }
2028    /// Perform an in-place FFT along a zero-based dimension, optionally padding/truncating to `len`.
2029    fn fft_dim<'a>(
2030        &'a self,
2031        _handle: &'a GpuTensorHandle,
2032        _len: Option<usize>,
2033        _dim: usize,
2034    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2035        unsupported_future("fft_dim not supported by provider")
2036    }
2037    fn ifft_dim<'a>(
2038        &'a self,
2039        _handle: &'a GpuTensorHandle,
2040        _len: Option<usize>,
2041        _dim: usize,
2042    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2043        unsupported_future("ifft_dim not supported by provider")
2044    }
2045    fn fft_extract_real<'a>(
2046        &'a self,
2047        _handle: &'a GpuTensorHandle,
2048    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2049        unsupported_future("fft_extract_real not supported by provider")
2050    }
2051    fn unique<'a>(
2052        &'a self,
2053        _handle: &'a GpuTensorHandle,
2054        _options: &'a UniqueOptions,
2055    ) -> AccelProviderFuture<'a, UniqueResult> {
2056        Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
2057    }
2058    fn union<'a>(
2059        &'a self,
2060        _a: &'a GpuTensorHandle,
2061        _b: &'a GpuTensorHandle,
2062        _options: &'a UnionOptions,
2063    ) -> AccelProviderFuture<'a, UnionResult> {
2064        Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
2065    }
2066    fn setdiff<'a>(
2067        &'a self,
2068        _a: &'a GpuTensorHandle,
2069        _b: &'a GpuTensorHandle,
2070        _options: &'a SetdiffOptions,
2071    ) -> AccelProviderFuture<'a, SetdiffResult> {
2072        Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
2073    }
2074    fn ismember<'a>(
2075        &'a self,
2076        _a: &'a GpuTensorHandle,
2077        _b: &'a GpuTensorHandle,
2078        _options: &'a IsMemberOptions,
2079    ) -> AccelProviderFuture<'a, IsMemberResult> {
2080        Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
2081    }
2082    fn reshape(
2083        &self,
2084        handle: &GpuTensorHandle,
2085        new_shape: &[usize],
2086    ) -> anyhow::Result<GpuTensorHandle> {
2087        let mut updated = handle.clone();
2088        updated.shape = new_shape.to_vec();
2089        Ok(updated)
2090    }
2091    /// Concatenate the provided tensors along the 1-based dimension `dim`.
2092    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
2093        Err(anyhow::anyhow!("cat not supported by provider"))
2094    }
2095    fn repmat(
2096        &self,
2097        _handle: &GpuTensorHandle,
2098        _reps: &[usize],
2099    ) -> anyhow::Result<GpuTensorHandle> {
2100        Err(anyhow::anyhow!("repmat not supported by provider"))
2101    }
2102    /// Compute the Kronecker product of two tensors, matching MATLAB semantics.
2103    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2104        Err(anyhow::anyhow!("kron not supported by provider"))
2105    }
2106    /// Compute the cross product of 3-element vectors along a matching dimension.
2107    fn cross(
2108        &self,
2109        _lhs: &GpuTensorHandle,
2110        _rhs: &GpuTensorHandle,
2111        _dim: Option<usize>,
2112    ) -> anyhow::Result<GpuTensorHandle> {
2113        Err(anyhow::anyhow!("cross not supported by provider"))
2114    }
2115    fn reduce_sum<'a>(
2116        &'a self,
2117        _a: &'a GpuTensorHandle,
2118    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2119        unsupported_future("reduce_sum not supported by provider")
2120    }
2121    fn reduce_sum_dim<'a>(
2122        &'a self,
2123        _a: &'a GpuTensorHandle,
2124        _dim: usize,
2125    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2126        unsupported_future("reduce_sum_dim not supported by provider")
2127    }
2128    fn dot<'a>(
2129        &'a self,
2130        _lhs: &'a GpuTensorHandle,
2131        _rhs: &'a GpuTensorHandle,
2132        _dim: Option<usize>,
2133    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2134        unsupported_future("dot not supported by provider")
2135    }
2136    fn reduce_nnz<'a>(
2137        &'a self,
2138        _a: &'a GpuTensorHandle,
2139    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2140        unsupported_future("reduce_nnz not supported by provider")
2141    }
2142    fn reduce_nnz_dim<'a>(
2143        &'a self,
2144        _a: &'a GpuTensorHandle,
2145        _dim: usize,
2146    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2147        unsupported_future("reduce_nnz_dim not supported by provider")
2148    }
2149    fn reduce_prod<'a>(
2150        &'a self,
2151        _a: &'a GpuTensorHandle,
2152    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2153        unsupported_future("reduce_prod not supported by provider")
2154    }
2155    fn reduce_prod_dim<'a>(
2156        &'a self,
2157        _a: &'a GpuTensorHandle,
2158        _dim: usize,
2159    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2160        unsupported_future("reduce_prod_dim not supported by provider")
2161    }
2162    fn reduce_mean<'a>(
2163        &'a self,
2164        _a: &'a GpuTensorHandle,
2165    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2166        unsupported_future("reduce_mean not supported by provider")
2167    }
2168    /// Reduce mean across multiple zero-based dimensions in one device pass.
2169    fn reduce_mean_nd<'a>(
2170        &'a self,
2171        _a: &'a GpuTensorHandle,
2172        _dims_zero_based: &'a [usize],
2173    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2174        unsupported_future("reduce_mean_nd not supported by provider")
2175    }
2176    /// Reduce moments across multiple zero-based dimensions in one device pass.
2177    /// Returns mean (E[x]) and mean of squares (E[x^2]).
2178    fn reduce_moments_nd<'a>(
2179        &'a self,
2180        _a: &'a GpuTensorHandle,
2181        _dims_zero_based: &'a [usize],
2182    ) -> AccelProviderFuture<'a, ProviderMoments2> {
2183        unsupported_future("reduce_moments_nd not supported by provider")
2184    }
2185    fn reduce_mean_dim<'a>(
2186        &'a self,
2187        _a: &'a GpuTensorHandle,
2188        _dim: usize,
2189    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2190        unsupported_future("reduce_mean_dim not supported by provider")
2191    }
2192    fn reduce_std<'a>(
2193        &'a self,
2194        _a: &'a GpuTensorHandle,
2195        _normalization: ProviderStdNormalization,
2196        _nan_mode: ProviderNanMode,
2197    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2198        unsupported_future("reduce_std not supported by provider")
2199    }
2200    fn reduce_std_dim<'a>(
2201        &'a self,
2202        _a: &'a GpuTensorHandle,
2203        _dim: usize,
2204        _normalization: ProviderStdNormalization,
2205        _nan_mode: ProviderNanMode,
2206    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2207        unsupported_future("reduce_std_dim not supported by provider")
2208    }
2209    fn reduce_any<'a>(
2210        &'a self,
2211        _a: &'a GpuTensorHandle,
2212        _omit_nan: bool,
2213    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2214        unsupported_future("reduce_any not supported by provider")
2215    }
2216    fn reduce_any_dim<'a>(
2217        &'a self,
2218        _a: &'a GpuTensorHandle,
2219        _dim: usize,
2220        _omit_nan: bool,
2221    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2222        unsupported_future("reduce_any_dim not supported by provider")
2223    }
2224    fn reduce_all<'a>(
2225        &'a self,
2226        _a: &'a GpuTensorHandle,
2227        _omit_nan: bool,
2228    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2229        unsupported_future("reduce_all not supported by provider")
2230    }
2231    fn reduce_all_dim<'a>(
2232        &'a self,
2233        _a: &'a GpuTensorHandle,
2234        _dim: usize,
2235        _omit_nan: bool,
2236    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2237        unsupported_future("reduce_all_dim not supported by provider")
2238    }
2239    fn reduce_median<'a>(
2240        &'a self,
2241        _a: &'a GpuTensorHandle,
2242    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2243        unsupported_future("reduce_median not supported by provider")
2244    }
2245    fn reduce_median_dim<'a>(
2246        &'a self,
2247        _a: &'a GpuTensorHandle,
2248        _dim: usize,
2249    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2250        unsupported_future("reduce_median_dim not supported by provider")
2251    }
2252    fn reduce_min<'a>(
2253        &'a self,
2254        _a: &'a GpuTensorHandle,
2255    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2256        unsupported_future("reduce_min not supported by provider")
2257    }
2258    fn reduce_min_dim<'a>(
2259        &'a self,
2260        _a: &'a GpuTensorHandle,
2261        _dim: usize,
2262    ) -> AccelProviderFuture<'a, ReduceDimResult> {
2263        unsupported_future("reduce_min_dim not supported by provider")
2264    }
2265    fn reduce_max<'a>(
2266        &'a self,
2267        _a: &'a GpuTensorHandle,
2268    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2269        unsupported_future("reduce_max not supported by provider")
2270    }
2271    fn reduce_max_dim<'a>(
2272        &'a self,
2273        _a: &'a GpuTensorHandle,
2274        _dim: usize,
2275    ) -> AccelProviderFuture<'a, ReduceDimResult> {
2276        unsupported_future("reduce_max_dim not supported by provider")
2277    }
2278    fn cumsum_scan(
2279        &self,
2280        _input: &GpuTensorHandle,
2281        _dim: usize,
2282        _direction: ProviderScanDirection,
2283        _nan_mode: ProviderNanMode,
2284    ) -> anyhow::Result<GpuTensorHandle> {
2285        Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
2286    }
2287    fn cumprod_scan(
2288        &self,
2289        _input: &GpuTensorHandle,
2290        _dim: usize,
2291        _direction: ProviderScanDirection,
2292        _nan_mode: ProviderNanMode,
2293    ) -> anyhow::Result<GpuTensorHandle> {
2294        Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
2295    }
2296    fn cummin_scan(
2297        &self,
2298        _input: &GpuTensorHandle,
2299        _dim: usize,
2300        _direction: ProviderScanDirection,
2301        _nan_mode: ProviderNanMode,
2302    ) -> anyhow::Result<ProviderCumminResult> {
2303        Err(anyhow::anyhow!("cummin_scan not supported by provider"))
2304    }
2305    fn cummax_scan(
2306        &self,
2307        _input: &GpuTensorHandle,
2308        _dim: usize,
2309        _direction: ProviderScanDirection,
2310        _nan_mode: ProviderNanMode,
2311    ) -> anyhow::Result<ProviderCummaxResult> {
2312        Err(anyhow::anyhow!("cummax_scan not supported by provider"))
2313    }
2314
2315    fn find(
2316        &self,
2317        _a: &GpuTensorHandle,
2318        _limit: Option<usize>,
2319        _direction: FindDirection,
2320    ) -> anyhow::Result<ProviderFindResult> {
2321        Err(anyhow::anyhow!("find not supported by provider"))
2322    }
2323
2324    fn fused_elementwise(
2325        &self,
2326        _shader: &str,
2327        _inputs: &[GpuTensorHandle],
2328        _output_shape: &[usize],
2329        _len: usize,
2330    ) -> anyhow::Result<GpuTensorHandle> {
2331        Err(anyhow::anyhow!(
2332            "fused_elementwise not supported by provider"
2333        ))
2334    }
2335
2336    /// Execute a single fused elementwise kernel that writes `num_outputs` output buffers in one
2337    /// dispatch. The shader is expected to declare `output0`, `output1`, … `output{N-1}` storage
2338    /// bindings (at binding indices `inputs.len()` through `inputs.len() + num_outputs - 1`) and a
2339    /// uniform `params` binding at `inputs.len() + num_outputs`.
2340    ///
2341    /// Providers that do not override this method fall back to calling `fused_elementwise` once
2342    /// per output, which preserves correctness at the cost of the O(N²) dispatch overhead this
2343    /// method is designed to eliminate.
2344    fn fused_elementwise_multi(
2345        &self,
2346        _shader: &str,
2347        _inputs: &[GpuTensorHandle],
2348        _output_shape: &[usize],
2349        _len: usize,
2350        _num_outputs: usize,
2351    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2352        Err(anyhow::anyhow!(
2353            "fused_elementwise_multi not supported by provider"
2354        ))
2355    }
2356
2357    /// Build a numeric tensor where NaNs in `a` are replaced with 0.0 (device side).
2358    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2359        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
2360    }
2361
2362    /// Build a numeric mask tensor with 1.0 where value is not NaN and 0.0 where value is NaN.
2363    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2364        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
2365    }
2366
2367    /// Generic fused reduction entrypoint.
2368    ///
2369    /// The shader is expected to implement a column-major reduction across `reduce_len` with
2370    /// `num_slices` independent slices (e.g., columns). Providers should create a uniform buffer
2371    /// compatible with the expected `Params/MParams` struct in the shader and dispatch
2372    /// `num_slices` workgroups with `workgroup_size` threads, or an equivalent strategy.
2373    #[allow(clippy::too_many_arguments)]
2374    fn fused_reduction(
2375        &self,
2376        _shader: &str,
2377        _inputs: &[GpuTensorHandle],
2378        _output_shape: &[usize],
2379        _reduce_len: usize,
2380        _num_slices: usize,
2381        _workgroup_size: u32,
2382        _flavor: ReductionFlavor,
2383    ) -> anyhow::Result<GpuTensorHandle> {
2384        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
2385    }
2386
2387    /// Optionally pre-compile commonly used pipelines to amortize first-dispatch costs.
2388    fn warmup(&self) {}
2389
2390    /// Returns (cache_hits, cache_misses) for fused pipeline cache, if supported.
2391    fn fused_cache_counters(&self) -> (u64, u64) {
2392        (0, 0)
2393    }
2394
2395    /// Returns the duration of the last provider warmup in milliseconds, if known.
2396    fn last_warmup_millis(&self) -> Option<u64> {
2397        None
2398    }
2399
2400    /// Returns a snapshot of provider telemetry counters if supported.
2401    fn telemetry_snapshot(&self) -> ProviderTelemetry {
2402        let (hits, misses) = self.fused_cache_counters();
2403        ProviderTelemetry {
2404            fused_elementwise: ProviderDispatchStats::default(),
2405            fused_reduction: ProviderDispatchStats::default(),
2406            matmul: ProviderDispatchStats::default(),
2407            linsolve: ProviderDispatchStats::default(),
2408            mldivide: ProviderDispatchStats::default(),
2409            mrdivide: ProviderDispatchStats::default(),
2410            upload_bytes: 0,
2411            download_bytes: 0,
2412            solve_fallbacks: Vec::new(),
2413            fusion_cache_hits: hits,
2414            fusion_cache_misses: misses,
2415            bind_group_cache_hits: 0,
2416            bind_group_cache_misses: 0,
2417            bind_group_cache_by_layout: None,
2418            kernel_launches: Vec::new(),
2419        }
2420    }
2421
2422    /// Reset all telemetry counters maintained by the provider, if supported.
2423    fn reset_telemetry(&self) {}
2424
2425    /// Default reduction workgroup size the provider prefers.
2426    fn default_reduction_workgroup_size(&self) -> u32 {
2427        256
2428    }
2429
2430    /// Threshold above which provider will prefer two-pass reduction.
2431    fn two_pass_threshold(&self) -> usize {
2432        1024
2433    }
2434
2435    /// Current two-pass mode preference (auto/forced on/off).
2436    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
2437        ReductionTwoPassMode::Auto
2438    }
2439
2440    /// Fast-path: write a GPU column in a matrix from a GPU vector, returning a new handle.
2441    /// Expected: `values.shape == [rows, 1]` (or `[rows]`) and `col_index < cols`.
2442    fn scatter_column(
2443        &self,
2444        _matrix: &GpuTensorHandle,
2445        _col_index: usize,
2446        _values: &GpuTensorHandle,
2447    ) -> anyhow::Result<GpuTensorHandle> {
2448        Err(anyhow::anyhow!("scatter_column not supported by provider"))
2449    }
2450
2451    /// Fast-path: write a GPU row in a matrix from a GPU vector, returning a new handle.
2452    /// Expected: `values.shape == [1, cols]` (or `[cols]`) and `row_index < rows`.
2453    fn scatter_row(
2454        &self,
2455        _matrix: &GpuTensorHandle,
2456        _row_index: usize,
2457        _values: &GpuTensorHandle,
2458    ) -> anyhow::Result<GpuTensorHandle> {
2459        Err(anyhow::anyhow!("scatter_row not supported by provider"))
2460    }
2461
2462    fn sub2ind(
2463        &self,
2464        _dims: &[usize],
2465        _strides: &[usize],
2466        _inputs: &[&GpuTensorHandle],
2467        _scalar_mask: &[bool],
2468        _len: usize,
2469        _output_shape: &[usize],
2470    ) -> anyhow::Result<GpuTensorHandle> {
2471        Err(anyhow::anyhow!("sub2ind not supported by provider"))
2472    }
2473
2474    /// Returns true if the provider offers a device-side `ind2sub` implementation.
2475    fn supports_ind2sub(&self) -> bool {
2476        false
2477    }
2478
2479    /// Convert linear indices into per-dimension subscripts on the device.
2480    fn ind2sub(
2481        &self,
2482        _dims: &[usize],
2483        _strides: &[usize],
2484        _indices: &GpuTensorHandle,
2485        _total: usize,
2486        _len: usize,
2487        _output_shape: &[usize],
2488    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2489        Err(anyhow::anyhow!("ind2sub not supported by provider"))
2490    }
2491
2492    /// Determine if a matrix is symmetric (or skew-symmetric) without gathering it to the host.
2493    fn issymmetric(
2494        &self,
2495        _matrix: &GpuTensorHandle,
2496        _kind: ProviderSymmetryKind,
2497        _tolerance: f64,
2498    ) -> anyhow::Result<bool> {
2499        Err(anyhow::anyhow!(
2500            "issymmetric predicate not supported by provider"
2501        ))
2502    }
2503
2504    /// Determine if a matrix is Hermitian (or skew-Hermitian) without gathering it to the host.
2505    fn ishermitian<'a>(
2506        &'a self,
2507        _matrix: &'a GpuTensorHandle,
2508        _kind: ProviderHermitianKind,
2509        _tolerance: f64,
2510    ) -> AccelProviderFuture<'a, bool> {
2511        Box::pin(async move {
2512            Err(anyhow::anyhow!(
2513                "ishermitian predicate not supported by provider"
2514            ))
2515        })
2516    }
2517
2518    /// Inspect the bandwidth of a matrix without gathering it back to the host.
2519    fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2520        Err(anyhow::anyhow!("bandwidth not supported by provider"))
2521    }
2522
2523    /// Compute the symmetric reverse Cuthill-McKee permutation for the matrix.
2524    ///
2525    /// Implementations may execute on the device or gather to the host. The permutation should be
2526    /// returned as zero-based indices.
2527    fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
2528        Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
2529    }
2530}
2531
2532static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
2533    Lazy::new(|| RwLock::new(None));
2534static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
2535    Lazy::new(|| RwLock::new(HashMap::new()));
/// Source of the values handed out by `next_device_id`; the first value issued is 1.
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2537
2538#[cfg(not(target_arch = "wasm32"))]
2539thread_local! {
2540    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
2541}
2542
/// wasm32 stand-in for the thread-local provider slot: one mutex-protected override that
/// starts out as `None`.
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(Mutex::default);
2546
2547#[cfg(not(target_arch = "wasm32"))]
2548fn replace_thread_provider(
2549    provider: Option<&'static dyn AccelProvider>,
2550) -> Option<&'static dyn AccelProvider> {
2551    THREAD_PROVIDER.with(|cell| {
2552        let prev = cell.get();
2553        cell.set(provider);
2554        prev
2555    })
2556}
2557
/// Stores `provider` in the wasm override slot and hands back whatever the slot held
/// before.
///
/// # Panics
/// Panics if the slot's mutex has been poisoned.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *slot, provider)
}
2569
2570#[cfg(not(target_arch = "wasm32"))]
2571fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2572    THREAD_PROVIDER.with(|cell| cell.get())
2573}
2574
/// Reads the wasm provider override without changing it.
///
/// # Panics
/// Panics if the slot's mutex has been poisoned.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    let slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    // `Option<&'static dyn AccelProvider>` is `Copy`, so dereferencing the guard copies it out.
    *slot
}
2583
2584/// Register a global acceleration provider.
2585///
2586/// # Safety
2587/// - The caller must guarantee that `p` is valid for the entire program lifetime
2588///   (e.g., a `'static` singleton), as the runtime stores a raw reference globally.
2589/// - Concurrent callers must ensure registration happens once or is properly
2590///   synchronized; this function does not enforce thread-safety for re-registration.
2591pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
2592    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2593        *guard = Some(p);
2594    }
2595    register_provider_for_device(p.device_id(), p);
2596}
2597
2598unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
2599    if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
2600        guard.insert(device_id, provider);
2601    }
2602}
2603
2604pub fn provider() -> Option<&'static dyn AccelProvider> {
2605    if let Some(p) = current_thread_provider() {
2606        return Some(p);
2607    }
2608    GLOBAL_PROVIDER
2609        .read()
2610        .ok()
2611        .and_then(|guard| guard.as_ref().copied())
2612}
2613
2614/// Clear the globally registered provider. Intended for tests to ensure deterministic behaviour.
2615pub fn clear_provider() {
2616    replace_thread_provider(None);
2617    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2618        *guard = None;
2619    }
2620    if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2621        map.clear();
2622    }
2623}
2624
2625pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2626    if let Some(registered) = PROVIDER_REGISTRY
2627        .read()
2628        .ok()
2629        .and_then(|guard| guard.get(&device_id).copied())
2630    {
2631        return Some(registered);
2632    }
2633    if let Some(thread_provider) = current_thread_provider() {
2634        if thread_provider.device_id() == device_id {
2635            return Some(thread_provider);
2636        }
2637    }
2638    // Preserve legacy behavior: when no explicit per-device registration exists,
2639    // fall back to the globally active provider regardless of handle device id.
2640    GLOBAL_PROVIDER
2641        .read()
2642        .ok()
2643        .and_then(|guard| guard.as_ref().copied())
2644}
2645
2646pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
2647    provider_for_device(handle.device_id)
2648}
2649
2650pub fn spawn_handle_concurrency_for(handle: &GpuTensorHandle) -> Option<SpawnHandleConcurrency> {
2651    provider_for_handle(handle).map(AccelProvider::spawn_handle_concurrency)
2652}
2653
2654pub fn next_device_id() -> u32 {
2655    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
2656}
2657
2658pub struct ThreadProviderGuard {
2659    prev: Option<&'static dyn AccelProvider>,
2660}
2661
2662impl ThreadProviderGuard {
2663    pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2664        let prev = replace_thread_provider(provider);
2665        ThreadProviderGuard { prev }
2666    }
2667}
2668
2669impl Drop for ThreadProviderGuard {
2670    fn drop(&mut self) {
2671        let prev = self.prev.take();
2672        replace_thread_provider(prev);
2673    }
2674}
2675
2676pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
2677    replace_thread_provider(provider);
2678}
2679
2680/// Convenience: perform elementwise add via provider if possible; otherwise return None
2681pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2682    if let Some(p) = provider() {
2683        if let Ok(h) = p.elem_add(a, b).await {
2684            return Some(h);
2685        }
2686    }
2687    None
2688}
2689
2690/// Convenience: perform elementwise hypot via provider if possible; otherwise return None
2691pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2692    if let Some(p) = provider() {
2693        if let Ok(h) = p.elem_hypot(a, b).await {
2694            return Some(h);
2695        }
2696    }
2697    None
2698}
2699
2700/// Convenience: perform elementwise max via provider if possible; otherwise return None
2701pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2702    if let Some(p) = provider() {
2703        if let Ok(h) = p.elem_max(a, b).await {
2704            return Some(h);
2705        }
2706    }
2707    None
2708}
2709
2710/// Convenience: perform elementwise min via provider if possible; otherwise return None
2711pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2712    if let Some(p) = provider() {
2713        if let Ok(h) = p.elem_min(a, b).await {
2714            return Some(h);
2715        }
2716    }
2717    None
2718}
2719
2720/// Convenience: perform elementwise atan2 via provider if possible; otherwise return None
2721pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2722    if let Some(p) = provider() {
2723        if let Ok(h) = p.elem_atan2(y, x).await {
2724            return Some(h);
2725        }
2726    }
2727    None
2728}
2729
// Minimal host tensor views to avoid depending on runmat-builtins and cycles
/// Owned host-side tensor: an `f64` element buffer, its shape, and a storage tag.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    /// Element values.
    /// NOTE(review): element ordering (row- vs column-major) is not visible here — confirm.
    pub data: Vec<f64>,
    /// Tensor shape; presumably one extent per dimension.
    pub shape: Vec<usize>,
    /// Storage tag carried alongside the data (see `GpuTensorStorage`).
    pub storage: GpuTensorStorage,
}
2737
/// Borrowed host-side tensor view: an element slice and a shape slice, both
/// valid for `'a`.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    /// Borrowed element values.
    /// NOTE(review): element ordering (row- vs column-major) is not visible here — confirm.
    pub data: &'a [f64],
    /// Borrowed tensor shape; presumably one extent per dimension.
    pub shape: &'a [usize],
}
2743
/// Lightweight 1-D axis view used by provider meshgrid hooks.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    /// Borrowed axis values, valid for `'a`.
    pub data: &'a [f64],
}
2749
/// Provider-side meshgrid result containing coordinate tensor handles.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    /// Handles of the coordinate tensors produced by the provider.
    /// NOTE(review): presumably one handle per requested output, in output order — confirm.
    pub outputs: Vec<GpuTensorHandle>,
}
2755
/// How a per-row or per-column scale vector of a [`MatmulEpilogue`] is combined
/// with the output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    /// Multiply the output by the scale value.
    Multiply,
    /// Divide the output by the scale value.
    Divide,
}

/// Descriptor for GEMM epilogues applied to `C = A * B` before storing to `C`.
///
/// Supported operations:
/// - Scale by `alpha` and add scalar `beta`.
/// - Scale output by per-row and/or per-column vectors (broadcasted), multiplying
///   or dividing as selected by `row_op` / `col_op`.
/// - Optionally clamp to `clamp_min` / `clamp_max`, then raise to `pow_exponent`.
/// - Optionally write the diagonal of the result to `diag_output`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    /// Scalar multiply applied to each output element.
    pub alpha: f64,
    /// Scalar add applied to each output element after scaling.
    pub beta: f64,
    /// Optional per-row scale (length m). When present, output[row, col] *= row_scale[row]
    /// (or the divide counterpart when `row_op` is `ScaleOp::Divide`).
    pub row_scale: Option<GpuTensorHandle>,
    /// Optional per-column scale (length n). When present, output[row, col] *= col_scale[col]
    /// (or the divide counterpart when `col_op` is `ScaleOp::Divide`).
    pub col_scale: Option<GpuTensorHandle>,
    /// Row scale operation (multiply or divide). Ignored when `row_scale` is None.
    pub row_op: ScaleOp,
    /// Column scale operation (multiply or divide). Ignored when `col_scale` is None.
    pub col_op: ScaleOp,
    /// Optional lower clamp bound applied after scale/bias.
    /// Deserializes to `None` when the field is absent.
    #[serde(default)]
    pub clamp_min: Option<f64>,
    /// Optional upper clamp bound applied after scale/bias.
    /// Deserializes to `None` when the field is absent.
    #[serde(default)]
    pub clamp_max: Option<f64>,
    /// Optional power exponent applied after clamp (final operation in the epilogue).
    /// Deserializes to `None` when the field is absent.
    #[serde(default)]
    pub pow_exponent: Option<f64>,
    /// Optional output buffer for the diagonal of the result (length min(m, n)).
    /// Deserializes to `None` when the field is absent.
    #[serde(default)]
    pub diag_output: Option<GpuTensorHandle>,
}
2794
2795impl MatmulEpilogue {
2796    pub fn noop() -> Self {
2797        Self {
2798            alpha: 1.0,
2799            beta: 0.0,
2800            row_scale: None,
2801            col_scale: None,
2802            row_op: ScaleOp::Multiply,
2803            col_op: ScaleOp::Multiply,
2804            clamp_min: None,
2805            clamp_max: None,
2806            pow_exponent: None,
2807            diag_output: None,
2808        }
2809    }
2810    pub fn is_noop(&self) -> bool {
2811        self.alpha == 1.0
2812            && self.beta == 0.0
2813            && self.row_scale.is_none()
2814            && self.col_scale.is_none()
2815            && self.clamp_min.is_none()
2816            && self.clamp_max.is_none()
2817            && self.pow_exponent.is_none()
2818            && self.diag_output.is_none()
2819    }
2820}
2821
2822#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
2823pub struct PowerStepEpilogue {
2824    pub epsilon: f64,
2825}
2826
2827impl Default for PowerStepEpilogue {
2828    fn default() -> Self {
2829        Self { epsilon: 0.0 }
2830    }
2831}
2832
/// Descriptor for an image-normalisation operation.
///
/// NOTE(review): the field meanings below are inferred from their names only —
/// confirm against the provider implementation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    /// Presumably the number of images in the batch.
    pub batch: usize,
    /// Presumably the image height.
    pub height: usize,
    /// Presumably the image width.
    pub width: usize,
    /// Epsilon value; presumably guards against division by a near-zero spread.
    pub epsilon: f64,
    /// Optional gain; deserializes to `None` when the field is absent.
    #[serde(default)]
    pub gain: Option<f64>,
    /// Optional bias; deserializes to `None` when the field is absent.
    #[serde(default)]
    pub bias: Option<f64>,
    /// Optional gamma; deserializes to `None` when the field is absent.
    #[serde(default)]
    pub gamma: Option<f64>,
}
2846
#[cfg(test)]
mod tests {
    use super::*;

    /// Stub provider for registry tests: it reports a fixed device id, name and
    /// spawn-concurrency mode, and fails every data-path call.
    struct TestProvider {
        device_id: u32,
        name: &'static str,
        spawn_concurrency: SpawnHandleConcurrency,
    }

    impl AccelProvider for TestProvider {
        fn upload(&self, _host: &HostTensorView) -> anyhow::Result<GpuTensorHandle> {
            Err(anyhow!("test provider upload should not be called"))
        }

        fn download<'a>(&'a self, _h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a> {
            unsupported_future("test provider download should not be called")
        }

        fn free(&self, _h: &GpuTensorHandle) -> anyhow::Result<()> {
            Err(anyhow!("test provider free should not be called"))
        }

        fn device_info(&self) -> String {
            self.name.to_string()
        }

        fn device_id(&self) -> u32 {
            self.device_id
        }

        fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
            self.spawn_concurrency
        }
    }

    // Serializes the tests below: they all mutate the shared provider statics.
    static PROVIDER_TEST_LOCK: Lazy<std::sync::Mutex<()>> = Lazy::new(|| std::sync::Mutex::new(()));
    static PROVIDER_A: TestProvider = TestProvider {
        device_id: 101,
        name: "provider-a",
        spawn_concurrency: SpawnHandleConcurrency::ImmutableShare,
    };
    static PROVIDER_B: TestProvider = TestProvider {
        device_id: 202,
        name: "provider-b",
        spawn_concurrency: SpawnHandleConcurrency::Reject,
    };
    static PROVIDER_C: TestProvider = TestProvider {
        device_id: 303,
        name: "provider-c",
        spawn_concurrency: SpawnHandleConcurrency::CopyOnWrite,
    };

    /// Acquire the lock that serializes provider tests.
    ///
    /// The mutex guards `()`, so a poisoned lock carries no broken state. The
    /// guard is recovered from the poison error so that one failing test does
    /// not make every later test panic at lock acquisition and hide the
    /// original failure. Each test resets the provider state before using it.
    fn lock_provider_tests() -> std::sync::MutexGuard<'static, ()> {
        PROVIDER_TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Start from a clean slate, then register providers A and B. B is
    /// registered last and therefore ends up as the global provider.
    fn register_test_providers() {
        clear_provider();
        unsafe {
            register_provider(&PROVIDER_A);
            register_provider(&PROVIDER_B);
        }
    }

    /// Build a one-element handle that claims to live on `device_id`.
    fn test_handle(device_id: u32) -> GpuTensorHandle {
        GpuTensorHandle {
            shape: vec![1],
            device_id,
            buffer_id: 42,
        }
    }

    #[test]
    fn provider_for_device_prefers_registered_device_over_thread_provider() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider = provider_for_device(PROVIDER_A.device_id()).expect("provider for device");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn provider_for_handle_uses_handle_device_owner() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider =
            provider_for_handle(&test_handle(PROVIDER_A.device_id())).expect("provider for handle");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn spawn_handle_concurrency_for_uses_registered_owner() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let concurrency = spawn_handle_concurrency_for(&test_handle(PROVIDER_A.device_id()))
            .expect("spawn concurrency");

        assert_eq!(concurrency, PROVIDER_A.spawn_concurrency);
        clear_provider();
    }

    #[test]
    fn provider_keeps_thread_local_active_provider_semantics() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_A));

        let active = provider().expect("active provider");

        assert_eq!(active.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn unregistered_thread_provider_only_matches_own_device_before_global_fallback() {
        let _lock = lock_provider_tests();
        clear_provider();
        unsafe {
            register_provider(&PROVIDER_A);
        }
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_C));

        let own_device = provider_for_device(PROVIDER_C.device_id()).expect("own provider");
        let fallback = provider_for_device(404).expect("global fallback provider");

        assert_eq!(own_device.device_info(), PROVIDER_C.name);
        assert_eq!(fallback.device_info(), PROVIDER_A.name);
        clear_provider();
    }
}