//! `runmat_accelerate_api` — `lib.rs`
//!
//! Acceleration provider API: GPU tensor handles, handle metadata registries,
//! and the option/result payload types exchanged with backend providers.
1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
/// Callback that clears residency metadata for a gathered tensor handle.
type ResidencyClearFn = fn(&GpuTensorHandle);
/// Callback yielding the auto-offload sequence-length threshold, if any.
type SequenceThresholdFn = fn() -> Option<usize>;
/// Callback yielding the provider's calibrated workgroup size, if any.
type WorkgroupSizeHintFn = fn() -> Option<u32>;

// Write-once registration slots for the hooks above; backends install them
// during initialization via the `register_*` functions in this module.
static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();
23
// Global side tables keyed by `GpuTensorHandle::buffer_id`. They track
// metadata that the handle struct itself does not encode: the logical
// (boolean) flag, how often that flag was (re)set, transpose provenance,
// and the recorded element precision.
static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
32
/// Shape of the un-transposed base matrix recorded via
/// [`record_handle_transpose`] for a transposed handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// Row count of the base (pre-transpose) matrix.
    pub base_rows: usize,
    /// Column count of the base (pre-transpose) matrix.
    pub base_cols: usize,
}
38
39/// Register a callback used to clear residency tracking when GPU tensors are
40/// gathered back to the host. Backends that maintain residency metadata should
41/// install this hook during initialization.
42pub fn register_residency_clear(handler: ResidencyClearFn) {
43    let _ = RESIDENCY_CLEAR.set(handler);
44}
45
46/// Clear residency metadata for the provided GPU tensor handle, if a backend
47/// has registered a handler via [`register_residency_clear`].
48pub fn clear_residency(handle: &GpuTensorHandle) {
49    if let Some(handler) = RESIDENCY_CLEAR.get() {
50        handler(handle);
51    }
52}
53
54/// Register a callback that exposes the current sequence length threshold
55/// derived from the auto-offload planner. Array constructors can use this hint
56/// to decide when to prefer GPU residency automatically.
57pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
58    let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
59}
60
61/// Query the currently registered sequence threshold hint, if any.
62pub fn sequence_threshold_hint() -> Option<usize> {
63    SEQUENCE_THRESHOLD_PROVIDER
64        .get()
65        .and_then(|provider| provider())
66}
67
68/// Register a callback that reports the calibrated workgroup size selected by
69/// the active acceleration provider (if any). Plotting kernels can reuse this
70/// hint to match backend tuning.
71pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
72    let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
73}
74
75/// Query the current workgroup size hint exposed by the provider.
76pub fn workgroup_size_hint() -> Option<u32> {
77    WORKGROUP_SIZE_HINT_PROVIDER
78        .get()
79        .and_then(|provider| provider())
80}
81
82/// Export a shared acceleration context (e.g., the active WGPU device) when the
83/// current provider exposes one.
84pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
85    provider().and_then(|p| p.export_context(kind))
86}
87
/// Request a provider-owned WGPU buffer for zero-copy consumers. Returns `None`
/// when there is no active provider, the provider does not expose buffers, or
/// it does not support the supplied handle.
#[cfg(feature = "wgpu")]
pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
    provider()?.export_wgpu_buffer(handle)
}
95
96/// Record the precision associated with a GPU tensor handle so host operations can
97/// reconstruct the original dtype when gathering back to the CPU.
98pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
99    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
100        guard.insert(handle.buffer_id, precision);
101    }
102}
103
104/// Look up the recorded precision for a GPU tensor handle, if any.
105pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
106    HANDLE_PRECISIONS
107        .read()
108        .ok()
109        .and_then(|guard| guard.get(&handle.buffer_id).copied())
110}
111
112/// Clear any recorded precision metadata for a GPU tensor handle.
113pub fn clear_handle_precision(handle: &GpuTensorHandle) {
114    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
115        guard.remove(&handle.buffer_id);
116    }
117}
118
119/// Annotate a GPU tensor handle as logically-typed (`logical` in MATLAB terms)
120/// or clear the logical flag when `logical` is `false`.
121pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
122    if let Ok(mut guard) = LOGICAL_HANDLES.write() {
123        if logical {
124            guard.insert(handle.buffer_id);
125            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
126                *hits.entry(handle.buffer_id).or_insert(0) += 1;
127            }
128        } else {
129            guard.remove(&handle.buffer_id);
130            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
131                hits.remove(&handle.buffer_id);
132            }
133        }
134    }
135}
136
137/// Convenience helper for clearing logical annotations explicitly.
138pub fn clear_handle_logical(handle: &GpuTensorHandle) {
139    set_handle_logical(handle, false);
140}
141
142/// Returns true when the supplied handle has been marked as logical.
143pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
144    LOGICAL_HANDLES
145        .read()
146        .map(|guard| guard.contains(&handle.buffer_id))
147        .unwrap_or(false)
148}
149
150pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
151    LOGICAL_HANDLE_HITS
152        .read()
153        .ok()
154        .and_then(|guard| guard.get(&buffer_id).copied())
155}
156
157pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
158    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
159        guard.insert(
160            handle.buffer_id,
161            TransposeInfo {
162                base_rows,
163                base_cols,
164            },
165        );
166    }
167}
168
169pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
170    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
171        guard.remove(&handle.buffer_id);
172    }
173}
174
175pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
176    TRANSPOSED_HANDLES
177        .read()
178        .ok()
179        .and_then(|guard| guard.get(&handle.buffer_id).copied())
180}
181
182pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
183    handle_transpose_info(handle).is_some()
184}
185
186#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
187pub struct GpuTensorHandle {
188    pub shape: Vec<usize>,
189    pub device_id: u32,
190    pub buffer_id: u64,
191}
192
/// Structured description of an acceleration device.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    pub device_id: u32,
    /// Human-readable device name.
    pub name: String,
    /// Vendor string; may be empty when unknown.
    pub vendor: String,
    /// Total device memory in bytes, when the backend can report it.
    pub memory_bytes: Option<u64>,
    /// Backend identifier, when known.
    pub backend: Option<String>,
}

/// Value/index pair returned by dimension-wise reductions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Result payload returned by provider-side `cummin` scans: running values
/// plus their indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Result payload returned by provider-side `cummax` scans.
///
/// Alias of [`ProviderCumminResult`] because both operations return the same pair of tensors
/// (running values and MATLAB-compatible indices).
pub type ProviderCummaxResult = ProviderCumminResult;
219
/// Names a shared acceleration context that callers may request (e.g. plotting).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelContextKind {
    /// Context shared with the plotting subsystem.
    Plotting,
}

/// Handle returned by [`export_context`] that describes a shared GPU context.
///
/// With the `wgpu` feature disabled this enum has no variants and cannot be
/// constructed.
#[derive(Clone)]
pub enum AccelContextHandle {
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}

impl AccelContextHandle {
    /// Returns the underlying WGPU context when available.
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        match self {
            AccelContextHandle::Wgpu(ctx) => Some(ctx),
        }
    }
}
242
/// Shared WGPU device/queue pair exported by the acceleration provider.
///
/// All members are `Arc`-shared so consumers can clone the handle cheaply and
/// reuse the same device without re-initialising WGPU.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    pub adapter_info: wgpu::AdapterInfo,
    pub limits: wgpu::Limits,
    pub features: wgpu::Features,
}

/// Borrowed reference to a provider-owned WGPU buffer corresponding to a `GpuTensorHandle`.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    pub buffer: Arc<wgpu::Buffer>,
    // NOTE(review): presumably the element count (byte size would then be
    // `len * element_size`) — confirm against the provider implementation.
    pub len: usize,
    pub shape: Vec<usize>,
    /// Size in bytes of one element, matching `precision`.
    pub element_size: usize,
    pub precision: ProviderPrecision,
}
266
/// Operations supported by provider-backed `pagefun` dispatch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    /// Page-wise matrix multiplication.
    Mtimes,
}

/// Request dispatched to acceleration providers for a batched (`pagefun`) operation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    pub op: PagefunOp,
    /// Operand tensors in the operation's argument order.
    pub inputs: Vec<GpuTensorHandle>,
    pub output_shape: Vec<usize>,
    // NOTE(review): exact page-dimension layout contract is defined by the
    // provider implementation — confirm there before relying on it.
    pub page_dims: Vec<usize>,
    /// Per-input page dimensions, parallel to `inputs`.
    pub input_page_dims: Vec<Vec<usize>>,
}

/// Search direction for provider-backed `find`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}

/// Outputs of provider-backed `find`: linear indices, row/column outputs, and
/// an optional values tensor (`None` when not computed).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    pub linear: GpuTensorHandle,
    pub rows: GpuTensorHandle,
    pub cols: GpuTensorHandle,
    pub values: Option<GpuTensorHandle>,
}

/// Lower/upper bandwidth pair reported for a matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    pub lower: u32,
    pub upper: u32,
}
300
/// Symmetry classification for real matrices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}

/// Hermitian classification for complex matrices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}

/// Outputs of a provider-backed LU factorization, including both the packed
/// form and the separated factors plus the permutation in matrix and vector
/// form.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    /// L and U packed into a single matrix.
    pub combined: GpuTensorHandle,
    pub lower: GpuTensorHandle,
    pub upper: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

/// Outputs of a provider-backed Cholesky factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    pub factor: GpuTensorHandle,
    /// MATLAB-compatible failure index (0 indicates success).
    pub info: u32,
}

/// Outputs of a provider-backed QR factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

/// Outputs of a provider-backed QR power-iteration; same layout as
/// [`ProviderQrResult`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}
344
/// Structural flags forwarded to provider-backed `linsolve`; each flag mirrors
/// the corresponding MATLAB `linsolve` option.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    /// Optional reciprocal-condition threshold supplied by the caller.
    pub rcond: Option<f64>,
}

/// Outputs of provider-backed `linsolve`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    /// Reciprocal condition estimate of the coefficient matrix.
    pub reciprocal_condition: f64,
}

/// Options for provider-backed `pinv`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    /// Singular-value cutoff; `None` requests the provider default.
    pub tolerance: Option<f64>,
}

/// Centring/scaling pair (`mu`) used by `polyval`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}

/// Options for provider-backed `polyval`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}

/// Options for provider-backed matrix inversion (currently none).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}

/// Host-side outputs of provider-backed `polyfit`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    pub coefficients: Vec<f64>,
    /// Triangular factor used for error estimation (flattened).
    pub r_matrix: Vec<f64>,
    /// Norm of the residuals.
    pub normr: f64,
    /// Degrees of freedom.
    pub df: f64,
    /// Centring/scaling pair `[mean, scale]`.
    pub mu: [f64; 2],
}

/// Numerator/denominator payload returned by provider-backed `polyder` quotient rule.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}
397
/// Supported norm specifications for the `cond` builtin.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}

/// Supported norm orders for the `norm` builtin.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    Nuc,
    /// General p-norm with the given exponent.
    P(f64),
}

/// Outputs of provider-backed eigendecomposition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    /// Eigenvalues as a vector.
    pub eigenvalues: GpuTensorHandle,
    /// Eigenvalues arranged on a diagonal matrix.
    pub diagonal: GpuTensorHandle,
    /// Right eigenvectors.
    pub right: GpuTensorHandle,
    /// Left eigenvectors, when requested.
    pub left: Option<GpuTensorHandle>,
}

/// How the QR column permutation is reported.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}

/// Options for provider-backed QR factorization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    /// Economy-size factorization when true.
    pub economy: bool,
    pub pivot: ProviderQrPivot,
}
439
440impl Default for ProviderQrOptions {
441    fn default() -> Self {
442        Self {
443            economy: false,
444            pivot: ProviderQrPivot::Matrix,
445        }
446    }
447}
448
/// Element precision used by an acceleration provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}

/// Policy for the two-pass reduction strategy: planner-chosen, always on, or
/// always off.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionTwoPassMode {
    Auto,
    ForceOn,
    ForceOff,
}
461
462impl ReductionTwoPassMode {
463    pub fn as_str(self) -> &'static str {
464        match self {
465            ReductionTwoPassMode::Auto => "auto",
466            ReductionTwoPassMode::ForceOn => "force_on",
467            ReductionTwoPassMode::ForceOff => "force_off",
468        }
469    }
470}
471
/// Post-scaling flavor of a sum-style reduction: plain sum, mean, or a
/// caller-supplied scale factor.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ReductionFlavor {
    Sum,
    Mean,
    CustomScale(f64),
}
478
479impl ReductionFlavor {
480    pub fn is_mean(self) -> bool {
481        matches!(self, ReductionFlavor::Mean)
482    }
483
484    pub fn scale(self, reduce_len: usize) -> f64 {
485        match self {
486            ReductionFlavor::Sum => 1.0,
487            ReductionFlavor::Mean => {
488                if reduce_len == 0 {
489                    1.0
490                } else {
491                    1.0 / reduce_len as f64
492                }
493            }
494            ReductionFlavor::CustomScale(scale) => scale,
495        }
496    }
497}
498
/// Normalisation mode for correlation coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}

/// Row-selection strategy for correlation coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}

/// Options controlling provider-backed correlation coefficient computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}
520
521impl Default for CorrcoefOptions {
522    fn default() -> Self {
523        Self {
524            normalization: CorrcoefNormalization::Unbiased,
525            rows: CorrcoefRows::All,
526        }
527    }
528}
529
/// Normalisation mode used by covariance computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}

/// Row handling strategy for covariance computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}

/// Options controlling provider-backed covariance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    /// True when the caller supplied an observation-weight vector.
    pub has_weight_vector: bool,
}
552
553impl Default for CovarianceOptions {
554    fn default() -> Self {
555        Self {
556            normalization: CovNormalization::Unbiased,
557            rows: CovRows::All,
558            has_weight_vector: false,
559        }
560    }
561}
562
/// Normalization strategy used by provider-backed standard deviation reductions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    /// Divide by `n - 1`.
    Sample,
    /// Divide by `n`.
    Population,
}

/// NaN handling mode for provider-backed reductions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}

/// Direction used when computing prefix sums on the device.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}

/// Sort direction used by acceleration providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}

/// Comparison strategy applied during sorting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    Auto,
    Real,
    Abs,
}

/// Host-resident outputs returned by provider-backed sort operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}

/// A single column/order pair for provider-backed `sortrows`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    /// Column index to sort by.
    pub index: usize,
    pub order: SortOrder,
}
611
/// Ordering applied by provider-backed `unique` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}

/// Occurrence selection for provider-backed `unique` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}

/// Options controlling provider-backed `unique` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    /// Operate on whole rows rather than individual elements.
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}

/// Host-resident outputs returned by provider-backed `unique` operations
/// (values plus the `ia`/`ic` index vectors).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ic: HostTensorOwned,
}

/// Ordering applied by provider-backed `union` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}

/// Options controlling provider-backed `union` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    /// Operate on whole rows rather than individual elements.
    pub rows: bool,
    pub order: UnionOrder,
}

/// Host-resident outputs returned by provider-backed `union` operations
/// (values plus the `ia`/`ib` index vectors).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ib: HostTensorOwned,
}
663
/// Parameterisation of 2-D filters generated by `fspecial`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    /// Uniform averaging filter of the given size.
    Average {
        rows: u32,
        cols: u32,
    },
    /// Circular averaging (pillbox) filter.
    Disk {
        radius: f64,
        size: u32,
    },
    /// Gaussian lowpass filter.
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// 3x3 Laplacian approximation controlled by `alpha`.
    Laplacian {
        alpha: f64,
    },
    /// Laplacian of Gaussian filter.
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// Linear-motion blur filter.
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        // NOTE(review): presumably a supersampling factor for rasterising the
        // motion line — confirm against the provider implementation.
        oversample: u32,
    },
    Prewitt,
    Sobel,
    /// Unsharp-masking filter controlled by `alpha`.
    Unsharp {
        alpha: f64,
    },
}

/// Request dispatched to acceleration providers for `fspecial` kernels.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}
706
/// Padding strategy used by `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    /// Pad with `ImfilterOptions::constant_value`.
    Constant,
    Replicate,
    Symmetric,
    Circular,
}

/// Output sizing mode used by `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}

/// Correlation vs convolution behaviour for `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}

/// Options supplied to acceleration providers for `imfilter`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    /// Fill value used when `padding` is [`ImfilterPadding::Constant`].
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}
739
740impl Default for ImfilterOptions {
741    fn default() -> Self {
742        Self {
743            padding: ImfilterPadding::Constant,
744            constant_value: 0.0,
745            shape: ImfilterShape::Same,
746            mode: ImfilterMode::Correlation,
747        }
748    }
749}
750
/// Ordering applied by provider-backed `setdiff` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}

/// Options controlling provider-backed `setdiff` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// Operate on whole rows rather than individual elements.
    pub rows: bool,
    pub order: SetdiffOrder,
}

/// Host-resident outputs returned by provider-backed `setdiff` operations
/// (values plus the `ia` index vector).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}

/// Options controlling provider-backed `ismember` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// Operate on whole rows rather than individual elements.
    pub rows: bool,
}

/// Host-resident logical output returned by providers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    /// One byte per element; shape given by `shape`.
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}

/// Host-resident outputs returned by provider-backed `ismember` operations
/// (membership mask plus location indices).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}
791
/// Output sizing mode for provider-backed 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

/// Orientation of the 1-D convolution result vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options supplied to acceleration providers for 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}

/// Options supplied to acceleration providers for IIR filtering (`filter`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Zero-based dimension along which filtering should be applied.
    pub dim: usize,
    /// Optional initial conditions (state vector) residing on the device.
    pub zi: Option<GpuTensorHandle>,
}

/// Outputs of provider-backed IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    /// Filtered output tensor, matching the input signal shape.
    pub output: GpuTensorHandle,
    /// Final conditions for the filter state (same shape as the requested `zi` layout).
    pub final_state: Option<GpuTensorHandle>,
}

/// First two raw moments (mean and E[x^2]) computed on the device.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    pub ex2: GpuTensorHandle,
}

/// Aggregate counters for one category of GPU dispatches.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of GPU dispatches recorded for this category.
    pub count: u64,
    /// Accumulated wall-clock time of dispatches in nanoseconds (host measured).
    pub total_wall_time_ns: u64,
}
840
/// Telemetry snapshot exposed by an acceleration provider: dispatch counters,
/// transfer byte totals, cache hit/miss counters, and a bounded launch log.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    /// Total bytes uploaded host -> device.
    pub upload_bytes: u64,
    /// Total bytes downloaded device -> host.
    pub download_bytes: u64,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Optional per-layout bind group cache counters (layout tags and their hit/miss counts)
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    /// Recent kernel launch metadata (bounded log; newest last)
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}

/// Hit/miss counters for one bind-group layout, identified by `tag`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}

/// A single named numeric attribute attached to a kernel launch record.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}

/// Metadata describing one recorded kernel launch.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    /// Kernel name.
    pub kernel: String,
    /// Precision tag, when recorded.
    pub precision: Option<String>,
    /// Shape-related attributes (key/value pairs).
    pub shape: Vec<KernelAttrTelemetry>,
    /// Tuning-related attributes (key/value pairs).
    pub tuning: Vec<KernelAttrTelemetry>,
}
878
/// Boxed, non-`Send` future type returned by async provider operations.
pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
/// Future resolving to a host-resident tensor (used by `download`).
pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
881
882fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
883    Box::pin(async move { Err(anyhow::anyhow!(message)) })
884}
885
886/// Device/provider interface that backends implement and register into the runtime layer
887pub trait AccelProvider: Send + Sync {
888    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
889    fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
890    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
891    fn device_info(&self) -> String;
892    fn device_id(&self) -> u32 {
893        0
894    }
895
896    /// Export a shared GPU context handle, allowing downstream systems (plotting, visualization)
897    /// to reuse the same device/queue without copying tensor data back to the host.
898    fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
899        None
900    }
901
902    /// Export a provider-owned WGPU buffer for zero-copy integrations.
903    #[cfg(feature = "wgpu")]
904    fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
905        let _ = _handle;
906        None
907    }
908
909    /// Gather elements from `source` at the provided zero-based linear `indices`, materialising
910    /// a dense tensor with the specified `output_shape`.
911    fn gather_linear(
912        &self,
913        _source: &GpuTensorHandle,
914        _indices: &[u32],
915        _output_shape: &[usize],
916    ) -> anyhow::Result<GpuTensorHandle> {
917        Err(anyhow::anyhow!("gather_linear not supported by provider"))
918    }
919
920    /// Scatter the contents of `values` into `target` at the provided zero-based linear `indices`.
921    ///
922    /// The provider must ensure `values.len() == indices.len()` and update `target` in place.
923    fn scatter_linear(
924        &self,
925        _target: &GpuTensorHandle,
926        _indices: &[u32],
927        _values: &GpuTensorHandle,
928    ) -> anyhow::Result<()> {
929        Err(anyhow::anyhow!("scatter_linear not supported by provider"))
930    }
931
932    /// Structured device information (optional to override). Default adapts from `device_info()`.
933    fn device_info_struct(&self) -> ApiDeviceInfo {
934        ApiDeviceInfo {
935            device_id: 0,
936            name: self.device_info(),
937            vendor: String::new(),
938            memory_bytes: None,
939            backend: None,
940        }
941    }
942
943    fn precision(&self) -> ProviderPrecision {
944        ProviderPrecision::F64
945    }
946
947    /// Read a single scalar at linear index from a device tensor, returning it as f64.
948    fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
949        Err(anyhow::anyhow!("read_scalar not supported by provider"))
950    }
951
952    /// Allocate a zero-initialised tensor with the provided shape on the device.
953    fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
954        Err(anyhow::anyhow!("zeros not supported by provider"))
955    }
956
957    /// Allocate a one-initialised tensor with the provided shape on the device.
958    fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
959        Err(anyhow::anyhow!("ones not supported by provider"))
960    }
961
962    /// Allocate a zero-initialised tensor matching the prototype tensor.
963    fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
964        self.zeros(&prototype.shape)
965    }
966
967    /// Allocate a tensor filled with a constant value on the device.
968    fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
969        if value == 0.0 {
970            return self.zeros(shape);
971        }
972        if let Ok(base) = self.zeros(shape) {
973            match self.scalar_add(&base, value) {
974                Ok(out) => {
975                    let _ = self.free(&base);
976                    return Ok(out);
977                }
978                Err(_) => {
979                    let _ = self.free(&base);
980                }
981            }
982        }
983        let len: usize = shape.iter().copied().product();
984        let data = vec![value; len];
985        let view = HostTensorView { data: &data, shape };
986        self.upload(&view)
987    }
988
989    /// Allocate a tensor filled with a constant value, matching a prototype's residency.
990    fn fill_like(
991        &self,
992        prototype: &GpuTensorHandle,
993        value: f64,
994    ) -> anyhow::Result<GpuTensorHandle> {
995        if value == 0.0 {
996            return self.zeros_like(prototype);
997        }
998        if let Ok(base) = self.zeros_like(prototype) {
999            match self.scalar_add(&base, value) {
1000                Ok(out) => {
1001                    let _ = self.free(&base);
1002                    return Ok(out);
1003                }
1004                Err(_) => {
1005                    let _ = self.free(&base);
1006                }
1007            }
1008        }
1009        self.fill(&prototype.shape, value)
1010    }
1011
1012    /// Allocate a one-initialised tensor matching the prototype tensor.
1013    fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1014        self.ones(&prototype.shape)
1015    }
1016
1017    /// Allocate an identity tensor with ones along the leading diagonal of the first two axes.
1018    fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1019        Err(anyhow::anyhow!("eye not supported by provider"))
1020    }
1021
1022    /// Allocate an identity tensor matching the prototype tensor's shape.
1023    fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1024        self.eye(&prototype.shape)
1025    }
1026
1027    /// Construct MATLAB-style coordinate grids from axis vectors.
1028    fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1029        Err(anyhow::anyhow!("meshgrid not supported by provider"))
1030    }
1031
1032    /// Construct a diagonal matrix from a vector-like tensor. `offset` matches MATLAB semantics.
1033    fn diag_from_vector(
1034        &self,
1035        _vector: &GpuTensorHandle,
1036        _offset: isize,
1037    ) -> anyhow::Result<GpuTensorHandle> {
1038        Err(anyhow::anyhow!(
1039            "diag_from_vector not supported by provider"
1040        ))
1041    }
1042
1043    /// Extract a diagonal from a matrix-like tensor. The result is always a column vector.
1044    fn diag_extract(
1045        &self,
1046        _matrix: &GpuTensorHandle,
1047        _offset: isize,
1048    ) -> anyhow::Result<GpuTensorHandle> {
1049        Err(anyhow::anyhow!("diag_extract not supported by provider"))
1050    }
1051
1052    /// Apply a lower-triangular mask to the first two dimensions of a tensor.
1053    fn tril<'a>(
1054        &'a self,
1055        _matrix: &'a GpuTensorHandle,
1056        _offset: isize,
1057    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1058        Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1059    }
1060
1061    /// Apply an upper-triangular mask to the first two dimensions of a tensor.
1062    fn triu<'a>(
1063        &'a self,
1064        _matrix: &'a GpuTensorHandle,
1065        _offset: isize,
1066    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1067        Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1068    }
1069
1070    /// Evaluate a polynomial expressed by `coefficients` at each element in `points`.
1071    fn polyval(
1072        &self,
1073        _coefficients: &GpuTensorHandle,
1074        _points: &GpuTensorHandle,
1075        _options: &ProviderPolyvalOptions,
1076    ) -> anyhow::Result<GpuTensorHandle> {
1077        Err(anyhow::anyhow!("polyval not supported by provider"))
1078    }
1079
1080    /// Fit a polynomial of degree `degree` to `(x, y)` samples. Optional weights must match `x`.
1081    fn polyfit<'a>(
1082        &'a self,
1083        _x: &'a GpuTensorHandle,
1084        _y: &'a GpuTensorHandle,
1085        _degree: usize,
1086        _weights: Option<&'a GpuTensorHandle>,
1087    ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1088        Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1089    }
1090
1091    /// Differentiate a polynomial represented as a vector of coefficients.
1092    fn polyder_single<'a>(
1093        &'a self,
1094        _polynomial: &'a GpuTensorHandle,
1095    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1096        Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1097    }
1098
1099    /// Apply the product rule to polynomials `p` and `q`.
1100    fn polyder_product<'a>(
1101        &'a self,
1102        _p: &'a GpuTensorHandle,
1103        _q: &'a GpuTensorHandle,
1104    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1105        Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1106    }
1107
1108    /// Apply the quotient rule to polynomials `u` and `v`.
1109    fn polyder_quotient<'a>(
1110        &'a self,
1111        _u: &'a GpuTensorHandle,
1112        _v: &'a GpuTensorHandle,
1113    ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1114        Box::pin(async move {
1115            Err(anyhow::anyhow!(
1116                "polyder_quotient not supported by provider"
1117            ))
1118        })
1119    }
1120
1121    /// Integrate a polynomial represented as a vector of coefficients and append a constant term.
1122    fn polyint(
1123        &self,
1124        _polynomial: &GpuTensorHandle,
1125        _constant: f64,
1126    ) -> anyhow::Result<GpuTensorHandle> {
1127        Err(anyhow::anyhow!("polyint not supported by provider"))
1128    }
1129
    /// Allocate a tensor filled with random values drawn from U(0, 1).
    fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("random_uniform not supported by provider"))
    }

    /// Allocate a tensor filled with random values matching the prototype shape.
    fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        self.random_uniform(&prototype.shape)
    }

    /// Allocate a tensor filled with standard normal (mean 0, stddev 1) random values.
    fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("random_normal not supported by provider"))
    }

    /// Allocate a tensor of standard normal values matching a prototype's shape.
    fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        self.random_normal(&prototype.shape)
    }

    /// Produce a tensor evolved from `_state` over `_steps` random iterations
    /// parameterised by `_drift` and `_scale`. NOTE(review): precise update
    /// rule is backend-defined and not visible here — confirm against a
    /// concrete provider. Default is unsupported.
    fn stochastic_evolution(
        &self,
        _state: &GpuTensorHandle,
        _drift: f64,
        _scale: f64,
        _steps: u32,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "stochastic_evolution not supported by provider"
        ))
    }

    /// Set the provider RNG state to align with the host RNG.
    fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
        Err(anyhow::anyhow!("set_rng_state not supported by provider"))
    }

    /// Generate a 2-D correlation kernel matching MATLAB's `fspecial` builtin.
    fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("fspecial not supported by provider"))
    }
1171
1172    /// Apply an N-D correlation/convolution with padding semantics matching MATLAB's `imfilter`.
1173    fn imfilter<'a>(
1174        &'a self,
1175        _image: &'a GpuTensorHandle,
1176        _kernel: &'a GpuTensorHandle,
1177        _options: &'a ImfilterOptions,
1178    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1179        unsupported_future("imfilter not supported by provider")
1180    }
1181
1182    /// Allocate a tensor filled with random integers over an inclusive range.
1183    fn random_integer_range(
1184        &self,
1185        _lower: i64,
1186        _upper: i64,
1187        _shape: &[usize],
1188    ) -> anyhow::Result<GpuTensorHandle> {
1189        Err(anyhow::anyhow!(
1190            "random_integer_range not supported by provider"
1191        ))
1192    }
1193
1194    /// Allocate a random integer tensor matching the prototype shape.
1195    fn random_integer_like(
1196        &self,
1197        prototype: &GpuTensorHandle,
1198        lower: i64,
1199        upper: i64,
1200    ) -> anyhow::Result<GpuTensorHandle> {
1201        self.random_integer_range(lower, upper, &prototype.shape)
1202    }
1203
1204    /// Allocate a random permutation of 1..=n, returning the first k elements.
1205    fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1206        Err(anyhow!("random_permutation not supported by provider"))
1207    }
1208
1209    /// Allocate a random permutation matching the prototype residency.
1210    fn random_permutation_like(
1211        &self,
1212        _prototype: &GpuTensorHandle,
1213        n: usize,
1214        k: usize,
1215    ) -> anyhow::Result<GpuTensorHandle> {
1216        self.random_permutation(n, k)
1217    }
1218
    /// Compute a covariance matrix across the columns of `matrix`.
    /// `_second` and `_weights` are optional (presumably a second sample set
    /// and observation weights — confirm against callers); `_options` carries
    /// the remaining settings.
    fn covariance<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _second: Option<&'a GpuTensorHandle>,
        _weights: Option<&'a GpuTensorHandle>,
        _options: &'a CovarianceOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("covariance not supported by provider")
    }

    /// Compute a correlation coefficient matrix across the columns of `matrix`.
    fn corrcoef<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: &'a CorrcoefOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("corrcoef not supported by provider")
    }

    // Optional operator hooks (default to unsupported)
    /// 1-D tensor of `_count` evenly spaced values from `_start` to `_stop`.
    fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("linspace not supported by provider"))
    }
    /// Elementwise addition `a + b`.
    fn elem_add<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_add not supported by provider")
    }
    /// Elementwise multiplication `a .* b`.
    fn elem_mul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_mul not supported by provider")
    }
    /// Elementwise maximum of `a` and `b`.
    fn elem_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_max not supported by provider")
    }
    /// Elementwise minimum of `a` and `b`.
    fn elem_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_min not supported by provider")
    }
    /// Elementwise subtraction `a - b`.
    fn elem_sub<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_sub not supported by provider")
    }
    /// Elementwise division `a ./ b`.
    fn elem_div<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_div not supported by provider")
    }
    /// Elementwise power `a .^ b`.
    fn elem_pow<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_pow not supported by provider")
    }

    /// Elementwise `hypot(a, b)` (sqrt(a² + b²) without intermediate overflow).
    fn elem_hypot<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_hypot not supported by provider")
    }
    /// Elementwise comparison `a >= b`.
    fn elem_ge<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ge not supported by provider")
    }
    /// Elementwise comparison `a <= b`.
    fn elem_le<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_le not supported by provider")
    }
    /// Elementwise comparison `a < b`.
    fn elem_lt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_lt not supported by provider")
    }
    /// Elementwise comparison `a > b`.
    fn elem_gt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_gt not supported by provider")
    }
    /// Elementwise equality `a == b`.
    fn elem_eq<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_eq not supported by provider")
    }
    /// Elementwise inequality `a != b`.
    fn elem_ne<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ne not supported by provider")
    }
    /// Elementwise logical AND of `a` and `b`.
    fn logical_and(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_and not supported by provider"))
    }
    /// Elementwise logical OR of `a` and `b`.
    fn logical_or(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_or not supported by provider"))
    }
    /// Elementwise logical XOR of `a` and `b`.
    fn logical_xor(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_xor not supported by provider"))
    }
    /// Elementwise logical NOT of `a`.
    fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_not not supported by provider"))
    }
    /// Report whether the handle is tracked as logical (delegates to `handle_is_logical`).
    fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
        Ok(handle_is_logical(a))
    }
    /// Report whether the tensor is purely real; default unsupported.
    fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
        Err(anyhow::anyhow!("logical_isreal not supported by provider"))
    }
    /// Elementwise finiteness test (MATLAB `isfinite`).
    fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "logical_isfinite not supported by provider"
        ))
    }
    /// Elementwise NaN test (MATLAB `isnan`).
    fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_isnan not supported by provider"))
    }
    /// Elementwise infinity test (MATLAB `isinf`).
    fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_isinf not supported by provider"))
    }
    /// Elementwise four-quadrant arctangent `atan2(y, x)`.
    fn elem_atan2<'a>(
        &'a self,
        _y: &'a GpuTensorHandle,
        _x: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_atan2 not supported by provider")
    }
    // Unary elementwise operations (optional)
    /// Elementwise sine (radians).
    fn unary_sin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sin not supported by provider")
    }
    /// Elementwise gamma function.
    fn unary_gamma<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_gamma not supported by provider")
    }
    /// Elementwise factorial.
    fn unary_factorial<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_factorial not supported by provider")
    }
    /// Elementwise inverse hyperbolic sine.
    fn unary_asinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asinh not supported by provider")
    }
    /// Elementwise hyperbolic sine.
    fn unary_sinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sinh not supported by provider")
    }
    /// Elementwise hyperbolic cosine.
    fn unary_cosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cosh not supported by provider")
    }
    /// Elementwise inverse sine.
    fn unary_asin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asin not supported by provider")
    }
    /// Elementwise inverse cosine.
    fn unary_acos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acos not supported by provider")
    }
    /// Elementwise inverse hyperbolic cosine.
    fn unary_acosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acosh not supported by provider")
    }
    /// Elementwise tangent (radians).
    fn unary_tan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tan not supported by provider")
    }
    /// Elementwise hyperbolic tangent.
    fn unary_tanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tanh not supported by provider")
    }
    /// Elementwise inverse tangent.
    fn unary_atan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atan not supported by provider")
    }
    /// Elementwise inverse hyperbolic tangent.
    fn unary_atanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atanh not supported by provider")
    }
    /// Elementwise round toward +infinity (`ceil`).
    fn unary_ceil<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_ceil not supported by provider")
    }
    /// Elementwise round toward -infinity (`floor`).
    fn unary_floor<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_floor not supported by provider")
    }
    /// Elementwise round to the nearest integer.
    fn unary_round<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_round not supported by provider")
    }
    /// Elementwise round toward zero (MATLAB `fix`).
    fn unary_fix<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_fix not supported by provider")
    }
    /// Elementwise cosine (radians).
    fn unary_cos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cos not supported by provider")
    }
    /// Elementwise phase angle (MATLAB `angle`).
    fn unary_angle<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_angle not supported by provider")
    }
    /// Imaginary part of each element.
    fn unary_imag<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_imag not supported by provider")
    }
    /// Real part of each element.
    fn unary_real<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_real not supported by provider")
    }
    /// Elementwise complex conjugate.
    fn unary_conj<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_conj not supported by provider")
    }
    /// Elementwise absolute value / magnitude.
    fn unary_abs<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_abs not supported by provider")
    }
    /// Elementwise sign (MATLAB `sign`).
    fn unary_sign<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sign not supported by provider")
    }
    /// Elementwise exponential `e^x`.
    fn unary_exp<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_exp not supported by provider")
    }
    /// Elementwise `e^x - 1` (accurate near zero).
    fn unary_expm1<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_expm1 not supported by provider")
    }
    /// Elementwise natural logarithm.
    fn unary_log<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log not supported by provider")
    }
    /// Elementwise base-2 logarithm.
    fn unary_log2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log2 not supported by provider")
    }
    /// Elementwise base-10 logarithm.
    fn unary_log10<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log10 not supported by provider")
    }
    /// Elementwise `ln(1 + x)` (accurate near zero).
    fn unary_log1p<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log1p not supported by provider")
    }
    /// Elementwise square root.
    fn unary_sqrt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sqrt not supported by provider")
    }
    /// Convert elements to double precision (MATLAB `double`).
    fn unary_double<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_double not supported by provider")
    }
    /// Convert elements to single precision (MATLAB `single`).
    fn unary_single<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_single not supported by provider")
    }
    /// Elementwise `2 .^ x` (one-argument MATLAB `pow2`).
    fn unary_pow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_pow2 not supported by provider")
    }
    /// Scale `mantissa` by powers of two taken from `exponent` (presumably
    /// MATLAB's two-argument `pow2` — confirm against callers).
    fn pow2_scale(
        &self,
        _mantissa: &GpuTensorHandle,
        _exponent: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pow2_scale not supported by provider"))
    }
    // Left-scalar operations (broadcast with scalar on the left)
    /// Compute `scalar - a` elementwise.
    fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
    }
    /// Compute `scalar ./ a` elementwise.
    fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
    }
    // Scalar operations: apply op with scalar right-hand side (broadcast over a)
    /// Compute `a + scalar` elementwise.
    fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_add not supported by provider"))
    }
    /// Compute `a - scalar` elementwise.
    fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_sub not supported by provider"))
    }
    /// Compute `a * scalar` elementwise.
    fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_mul not supported by provider"))
    }
    /// Elementwise maximum of `a` and `scalar`.
    fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_max not supported by provider"))
    }
    /// Elementwise minimum of `a` and `scalar`.
    fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_min not supported by provider"))
    }
    /// Compute `a / scalar` elementwise.
    fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_div not supported by provider"))
    }
    /// Sort along dimension `_dim` with the given ordering and comparison mode.
    fn sort_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _order: SortOrder,
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_dim not supported by provider")
    }
    /// Order matrix rows by the listed column specs (MATLAB `sortrows`-style).
    fn sort_rows<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _columns: &'a [SortRowsColumnSpec],
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_rows not supported by provider")
    }
    /// Matrix multiplication `a * b`.
    fn matmul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul not supported by provider")
    }
1652
    /// Symmetric rank-k product of `_a` (BLAS SYRK-style). NOTE(review): the
    /// orientation (A*Aᵀ vs Aᵀ*A) is not visible here — confirm against a
    /// backend implementation.
    fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("syrk not supported by provider"))
    }
    /// Execute a batched per-page operation described by `_request` (MATLAB `pagefun`).
    fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pagefun not supported by provider"))
    }
1659
1660    /// Optional: matrix multiplication with an epilogue applied before store.
1661    ///
1662    /// The default implementation falls back to `matmul` when the epilogue is effectively a no-op
1663    /// (alpha=1, beta=0, no row/col scales), and otherwise returns `Err`.
1664    fn matmul_epilogue<'a>(
1665        &'a self,
1666        a: &'a GpuTensorHandle,
1667        b: &'a GpuTensorHandle,
1668        epilogue: &'a MatmulEpilogue,
1669    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1670        Box::pin(async move {
1671            if epilogue.is_noop() {
1672                return self.matmul(a, b).await;
1673            }
1674            Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1675        })
1676    }
    /// Fused image-normalisation pass described by `_desc`; default unsupported.
    fn image_normalize<'a>(
        &'a self,
        _input: &'a GpuTensorHandle,
        _desc: &'a ImageNormalizeDescriptor,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("image_normalize fusion not supported by provider")
    }
    /// Fused matmul + power-iteration normalisation step; default unsupported.
    fn matmul_power_step<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _epilogue: &'a PowerStepEpilogue,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul_power_step normalization not supported by provider")
    }
1692    fn linsolve<'a>(
1693        &'a self,
1694        _lhs: &'a GpuTensorHandle,
1695        _rhs: &'a GpuTensorHandle,
1696        _options: &'a ProviderLinsolveOptions,
1697    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
1698        unsupported_future("linsolve not supported by provider")
1699    }
1700    fn inv<'a>(
1701        &'a self,
1702        _matrix: &'a GpuTensorHandle,
1703        _options: ProviderInvOptions,
1704    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1705        unsupported_future("inv not supported by provider")
1706    }
1707    fn pinv<'a>(
1708        &'a self,
1709        _matrix: &'a GpuTensorHandle,
1710        _options: ProviderPinvOptions,
1711    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1712        unsupported_future("pinv not supported by provider")
1713    }
1714    fn cond<'a>(
1715        &'a self,
1716        _matrix: &'a GpuTensorHandle,
1717        _norm: ProviderCondNorm,
1718    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1719        Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
1720    }
1721    fn norm<'a>(
1722        &'a self,
1723        _tensor: &'a GpuTensorHandle,
1724        _order: ProviderNormOrder,
1725    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1726        Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
1727    }
1728    fn rank<'a>(
1729        &'a self,
1730        _matrix: &'a GpuTensorHandle,
1731        _tolerance: Option<f64>,
1732    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1733        Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
1734    }
1735    fn rcond<'a>(
1736        &'a self,
1737        _matrix: &'a GpuTensorHandle,
1738    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1739        Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
1740    }
1741    fn mldivide<'a>(
1742        &'a self,
1743        _lhs: &'a GpuTensorHandle,
1744        _rhs: &'a GpuTensorHandle,
1745    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1746        Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
1747    }
1748    fn mrdivide<'a>(
1749        &'a self,
1750        _lhs: &'a GpuTensorHandle,
1751        _rhs: &'a GpuTensorHandle,
1752    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1753        Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
1754    }
1755    fn eig<'a>(
1756        &'a self,
1757        _a: &'a GpuTensorHandle,
1758        _compute_left: bool,
1759    ) -> AccelProviderFuture<'a, ProviderEigResult> {
1760        Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
1761    }
1762    fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
1763        Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
1764    }
1765
1766    fn chol<'a>(
1767        &'a self,
1768        _a: &'a GpuTensorHandle,
1769        _lower: bool,
1770    ) -> AccelProviderFuture<'a, ProviderCholResult> {
1771        Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
1772    }
1773    fn qr<'a>(
1774        &'a self,
1775        _a: &'a GpuTensorHandle,
1776        _options: ProviderQrOptions,
1777    ) -> AccelProviderFuture<'a, ProviderQrResult> {
1778        Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
1779    }
    /// Hand back the operand handles that produced `_product`, if this provider
    /// tracks matmul sources. Default: no tracking, so `None`.
    fn take_matmul_sources(
        &self,
        _product: &GpuTensorHandle,
    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
        None
    }
    /// Optional fused QR power-iteration step. Default: `Ok(None)`, telling the
    /// caller to fall back to the unfused path.
    fn qr_power_iter<'a>(
        &'a self,
        product: &'a GpuTensorHandle,
        _product_lhs: Option<&'a GpuTensorHandle>,
        q_handle: &'a GpuTensorHandle,
        options: &'a ProviderQrOptions,
    ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
        // Bind the non-underscore parameters so the default body compiles
        // without unused-variable warnings.
        let _ = (product, q_handle, options);
        Box::pin(async move { Ok(None) })
    }
    /// Device-side transpose of `_a`. Default: unsupported.
    fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("transpose not supported by provider"))
    }
    /// 1-D convolution of `_signal` with `_kernel`. Default: unsupported.
    fn conv1d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _options: ProviderConv1dOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv1d not supported by provider"))
    }
    /// 2-D convolution of `_signal` with `_kernel` in mode `_mode`. Default: unsupported.
    fn conv2d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _mode: ProviderConvMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv2d not supported by provider"))
    }
1815    fn iir_filter<'a>(
1816        &'a self,
1817        _b: &'a GpuTensorHandle,
1818        _a: &'a GpuTensorHandle,
1819        _x: &'a GpuTensorHandle,
1820        _options: ProviderIirFilterOptions,
1821    ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
1822        Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
1823    }
    /// Reorder tensor dimensions according to `order`, expressed as zero-based indices.
    fn permute(
        &self,
        _handle: &GpuTensorHandle,
        _order: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("permute not supported by provider"))
    }
    /// Reverse the tensor along each of the given axes. Default: unsupported.
    fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("flip not supported by provider"))
    }
    /// Circularly shift the tensor by `_shifts` per dimension. Default: unsupported.
    fn circshift(
        &self,
        _handle: &GpuTensorHandle,
        _shifts: &[isize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("circshift not supported by provider"))
    }
    /// `_order`-th discrete difference along `_dim`. Default: unsupported.
    fn diff_dim(
        &self,
        _handle: &GpuTensorHandle,
        _order: usize,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("diff_dim not supported by provider"))
    }
    /// Perform an in-place FFT along a zero-based dimension, optionally padding/truncating to `len`.
    fn fft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_dim not supported by provider")
    }
    /// Inverse FFT counterpart of [`AccelProvider::fft_dim`]. Default: unsupported.
    fn ifft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("ifft_dim not supported by provider")
    }
1867    fn unique<'a>(
1868        &'a self,
1869        _handle: &'a GpuTensorHandle,
1870        _options: &'a UniqueOptions,
1871    ) -> AccelProviderFuture<'a, UniqueResult> {
1872        Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
1873    }
1874    fn union<'a>(
1875        &'a self,
1876        _a: &'a GpuTensorHandle,
1877        _b: &'a GpuTensorHandle,
1878        _options: &'a UnionOptions,
1879    ) -> AccelProviderFuture<'a, UnionResult> {
1880        Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
1881    }
1882    fn setdiff<'a>(
1883        &'a self,
1884        _a: &'a GpuTensorHandle,
1885        _b: &'a GpuTensorHandle,
1886        _options: &'a SetdiffOptions,
1887    ) -> AccelProviderFuture<'a, SetdiffResult> {
1888        Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
1889    }
1890    fn ismember<'a>(
1891        &'a self,
1892        _a: &'a GpuTensorHandle,
1893        _b: &'a GpuTensorHandle,
1894        _options: &'a IsMemberOptions,
1895    ) -> AccelProviderFuture<'a, IsMemberResult> {
1896        Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
1897    }
1898    fn reshape(
1899        &self,
1900        handle: &GpuTensorHandle,
1901        new_shape: &[usize],
1902    ) -> anyhow::Result<GpuTensorHandle> {
1903        let mut updated = handle.clone();
1904        updated.shape = new_shape.to_vec();
1905        Ok(updated)
1906    }
    /// Concatenate the provided tensors along the 1-based dimension `dim`.
    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cat not supported by provider"))
    }
    /// Tile `_handle` by the replication counts in `_reps`. Default: unsupported.
    fn repmat(
        &self,
        _handle: &GpuTensorHandle,
        _reps: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("repmat not supported by provider"))
    }
    /// Compute the Kronecker product of two tensors, matching MATLAB semantics.
    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("kron not supported by provider"))
    }
    /// Full sum reduction over all elements. Default: unsupported.
    fn reduce_sum<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum not supported by provider")
    }
    /// Sum reduction along dimension `_dim`. Default: unsupported.
    fn reduce_sum_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum_dim not supported by provider")
    }
    /// Dot product of `_lhs` and `_rhs`, optionally along `_dim`. Default: unsupported.
    fn dot<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _dim: Option<usize>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("dot not supported by provider")
    }
    /// Count of nonzero elements. Default: unsupported.
    fn reduce_nnz<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz not supported by provider")
    }
    /// Nonzero count along dimension `_dim`. Default: unsupported.
    fn reduce_nnz_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz_dim not supported by provider")
    }
    /// Full product reduction. Default: unsupported.
    fn reduce_prod<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod not supported by provider")
    }
    /// Product reduction along dimension `_dim`. Default: unsupported.
    fn reduce_prod_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod_dim not supported by provider")
    }
    /// Full mean reduction. Default: unsupported.
    fn reduce_mean<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean not supported by provider")
    }
    /// Reduce mean across multiple zero-based dimensions in one device pass.
    fn reduce_mean_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_nd not supported by provider")
    }
    /// Reduce moments across multiple zero-based dimensions in one device pass.
    /// Returns mean (E[x]) and mean of squares (E[x^2]).
    fn reduce_moments_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, ProviderMoments2> {
        unsupported_future("reduce_moments_nd not supported by provider")
    }
    /// Mean reduction along dimension `_dim`. Default: unsupported.
    fn reduce_mean_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_dim not supported by provider")
    }
    /// Full standard-deviation reduction with the given normalization and NaN
    /// handling. Default: unsupported.
    fn reduce_std<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std not supported by provider")
    }
    /// Standard deviation along dimension `_dim`. Default: unsupported.
    fn reduce_std_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std_dim not supported by provider")
    }
    /// Logical `any` over all elements. Default: unsupported.
    fn reduce_any<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any not supported by provider")
    }
    /// Logical `any` along dimension `_dim`. Default: unsupported.
    fn reduce_any_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any_dim not supported by provider")
    }
    /// Logical `all` over all elements. Default: unsupported.
    fn reduce_all<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all not supported by provider")
    }
    /// Logical `all` along dimension `_dim`. Default: unsupported.
    fn reduce_all_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all_dim not supported by provider")
    }
    /// Full median reduction. Default: unsupported.
    fn reduce_median<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median not supported by provider")
    }
    /// Median along dimension `_dim`. Default: unsupported.
    fn reduce_median_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median_dim not supported by provider")
    }
    /// Full minimum reduction. Default: unsupported.
    fn reduce_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_min not supported by provider")
    }
    /// Minimum along dimension `_dim` (see [`ReduceDimResult`]). Default: unsupported.
    fn reduce_min_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_min_dim not supported by provider")
    }
    /// Full maximum reduction. Default: unsupported.
    fn reduce_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_max not supported by provider")
    }
    /// Maximum along dimension `_dim` (see [`ReduceDimResult`]). Default: unsupported.
    fn reduce_max_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_max_dim not supported by provider")
    }
    /// Cumulative sum scan along `_dim` in the given direction. Default: unsupported.
    fn cumsum_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
    }
    /// Cumulative product scan along `_dim`. Default: unsupported.
    fn cumprod_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
    }
    /// Cumulative minimum scan along `_dim`. Default: unsupported.
    fn cummin_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCumminResult> {
        Err(anyhow::anyhow!("cummin_scan not supported by provider"))
    }
    /// Cumulative maximum scan along `_dim`. Default: unsupported.
    fn cummax_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCummaxResult> {
        Err(anyhow::anyhow!("cummax_scan not supported by provider"))
    }
2121
    /// Locate nonzero elements of `_a`, optionally capped at `_limit` results,
    /// scanning in `_direction`. Default: unsupported.
    fn find(
        &self,
        _a: &GpuTensorHandle,
        _limit: Option<usize>,
        _direction: FindDirection,
    ) -> anyhow::Result<ProviderFindResult> {
        Err(anyhow::anyhow!("find not supported by provider"))
    }

    /// Run a pre-generated elementwise shader over `_inputs`, producing a
    /// tensor of `_output_shape` with `_len` elements. Default: unsupported.
    fn fused_elementwise(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _len: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "fused_elementwise not supported by provider"
        ))
    }
2142
    /// Build a numeric tensor where NaNs in `a` are replaced with 0.0 (device side).
    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
    }

    /// Build a numeric mask tensor with 1.0 where value is not NaN and 0.0 where value is NaN.
    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
    }

    /// Generic fused reduction entrypoint.
    ///
    /// The shader is expected to implement a column-major reduction across `reduce_len` with
    /// `num_slices` independent slices (e.g., columns). Providers should create a uniform buffer
    /// compatible with the expected `Params/MParams` struct in the shader and dispatch
    /// `num_slices` workgroups with `workgroup_size` threads, or an equivalent strategy.
    #[allow(clippy::too_many_arguments)]
    fn fused_reduction(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _reduce_len: usize,
        _num_slices: usize,
        _workgroup_size: u32,
        _flavor: ReductionFlavor,
    ) -> anyhow::Result<GpuTensorHandle> {
        // Default: no fused-reduction support.
        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
    }
2172
    /// Optionally pre-compile commonly used pipelines to amortize first-dispatch costs.
    fn warmup(&self) {}

    /// Returns (cache_hits, cache_misses) for fused pipeline cache, if supported.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }

    /// Returns the duration of the last provider warmup in milliseconds, if known.
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }

    /// Returns a snapshot of provider telemetry counters if supported.
    fn telemetry_snapshot(&self) -> ProviderTelemetry {
        // Only the fused-cache counters are wired up by default; every other
        // field reports an empty/zero value.
        let (hits, misses) = self.fused_cache_counters();
        ProviderTelemetry {
            fused_elementwise: ProviderDispatchStats::default(),
            fused_reduction: ProviderDispatchStats::default(),
            matmul: ProviderDispatchStats::default(),
            upload_bytes: 0,
            download_bytes: 0,
            fusion_cache_hits: hits,
            fusion_cache_misses: misses,
            bind_group_cache_hits: 0,
            bind_group_cache_misses: 0,
            bind_group_cache_by_layout: None,
            kernel_launches: Vec::new(),
        }
    }

    /// Reset all telemetry counters maintained by the provider, if supported.
    fn reset_telemetry(&self) {}

    /// Default reduction workgroup size the provider prefers.
    fn default_reduction_workgroup_size(&self) -> u32 {
        256
    }

    /// Threshold above which provider will prefer two-pass reduction.
    fn two_pass_threshold(&self) -> usize {
        1024
    }

    /// Current two-pass mode preference (auto/forced on/off).
    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
2221
    /// Fast-path: write a GPU column in a matrix from a GPU vector, returning a new handle.
    /// Expected: `values.shape == [rows, 1]` (or `[rows]`) and `col_index < cols`.
    fn scatter_column(
        &self,
        _matrix: &GpuTensorHandle,
        _col_index: usize,
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scatter_column not supported by provider"))
    }

    /// Fast-path: write a GPU row in a matrix from a GPU vector, returning a new handle.
    /// Expected: `values.shape == [1, cols]` (or `[cols]`) and `row_index < rows`.
    fn scatter_row(
        &self,
        _matrix: &GpuTensorHandle,
        _row_index: usize,
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scatter_row not supported by provider"))
    }

    /// Device-side `sub2ind`: combine per-dimension subscript tensors into
    /// linear indices. Default: unsupported.
    fn sub2ind(
        &self,
        _dims: &[usize],
        _strides: &[usize],
        _inputs: &[&GpuTensorHandle],
        _scalar_mask: &[bool],
        _len: usize,
        _output_shape: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("sub2ind not supported by provider"))
    }
2255
    /// Returns true if the provider offers a device-side `ind2sub` implementation.
    /// Callers should check this before invoking [`AccelProvider::ind2sub`].
    fn supports_ind2sub(&self) -> bool {
        false
    }

    /// Convert linear indices into per-dimension subscripts on the device.
    fn ind2sub(
        &self,
        _dims: &[usize],
        _strides: &[usize],
        _indices: &GpuTensorHandle,
        _total: usize,
        _len: usize,
        _output_shape: &[usize],
    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
        Err(anyhow::anyhow!("ind2sub not supported by provider"))
    }

    /// Determine if a matrix is symmetric (or skew-symmetric) without gathering it to the host.
    fn issymmetric(
        &self,
        _matrix: &GpuTensorHandle,
        _kind: ProviderSymmetryKind,
        _tolerance: f64,
    ) -> anyhow::Result<bool> {
        Err(anyhow::anyhow!(
            "issymmetric predicate not supported by provider"
        ))
    }
2285
2286    /// Determine if a matrix is Hermitian (or skew-Hermitian) without gathering it to the host.
2287    fn ishermitian<'a>(
2288        &'a self,
2289        _matrix: &'a GpuTensorHandle,
2290        _kind: ProviderHermitianKind,
2291        _tolerance: f64,
2292    ) -> AccelProviderFuture<'a, bool> {
2293        Box::pin(async move {
2294            Err(anyhow::anyhow!(
2295                "ishermitian predicate not supported by provider"
2296            ))
2297        })
2298    }
2299
2300    /// Inspect the bandwidth of a matrix without gathering it back to the host.
2301    fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2302        Err(anyhow::anyhow!("bandwidth not supported by provider"))
2303    }
2304
2305    /// Compute the symmetric reverse Cuthill-McKee permutation for the matrix.
2306    ///
2307    /// Implementations may execute on the device or gather to the host. The permutation should be
2308    /// returned as zero-based indices.
2309    fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
2310        Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
2311    }
2312}
2313
// Process-wide fallback provider, consulted when no thread-local override is set.
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
// Per-device-id provider registry used by `provider_for_device`.
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
// Monotonic source of device ids; starts at 1 (0 is never handed out).
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);

#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    // Per-thread provider override (native targets).
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}

// wasm32 has no thread_local support here, so a process-wide Mutex stands in
// for the "thread" provider slot.
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| Mutex::new(None));
2328
2329#[cfg(not(target_arch = "wasm32"))]
2330fn replace_thread_provider(
2331    provider: Option<&'static dyn AccelProvider>,
2332) -> Option<&'static dyn AccelProvider> {
2333    THREAD_PROVIDER.with(|cell| {
2334        let prev = cell.get();
2335        cell.set(provider);
2336        prev
2337    })
2338}
2339
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    // Swap the new binding into the shared slot, returning the previous one.
    let mut slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *slot, provider)
}
2351
2352#[cfg(not(target_arch = "wasm32"))]
2353fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2354    THREAD_PROVIDER.with(|cell| cell.get())
2355}
2356
2357#[cfg(target_arch = "wasm32")]
2358fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2359    WASM_THREAD_PROVIDER
2360        .lock()
2361        .expect("wasm provider mutex poisoned")
2362        .as_ref()
2363        .copied()
2364}
2365
/// Register a global acceleration provider.
///
/// # Safety
/// - The caller must guarantee that `p` is valid for the entire program lifetime
///   (e.g., a `'static` singleton), as the runtime stores a raw reference globally.
/// - Concurrent callers must ensure registration happens once or is properly
///   synchronized; this function does not enforce thread-safety for re-registration.
pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
    // A poisoned lock is silently skipped: registration is best-effort.
    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
        *guard = Some(p);
    }
    // Also make the provider discoverable by its device id.
    register_provider_for_device(p.device_id(), p);
}

/// Insert `provider` into the per-device registry (see `provider_for_device`).
///
/// # Safety
/// Same lifetime contract as [`register_provider`]: `provider` must outlive the program.
unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
    if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
        guard.insert(device_id, provider);
    }
}
2385
2386pub fn provider() -> Option<&'static dyn AccelProvider> {
2387    if let Some(p) = current_thread_provider() {
2388        return Some(p);
2389    }
2390    GLOBAL_PROVIDER
2391        .read()
2392        .ok()
2393        .and_then(|guard| guard.as_ref().copied())
2394}
2395
2396/// Clear the globally registered provider. Intended for tests to ensure deterministic behaviour.
2397pub fn clear_provider() {
2398    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2399        *guard = None;
2400    }
2401    if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2402        map.clear();
2403    }
2404}
2405
2406pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2407    PROVIDER_REGISTRY
2408        .read()
2409        .ok()
2410        .and_then(|guard| guard.get(&device_id).copied())
2411        .or_else(|| provider())
2412}
2413
2414pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
2415    provider_for_device(handle.device_id)
2416}
2417
/// Allocate a fresh, process-unique device id.
///
/// `Relaxed` ordering suffices: only uniqueness of the returned value matters,
/// not ordering relative to other memory operations.
pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
2421
2422pub struct ThreadProviderGuard {
2423    prev: Option<&'static dyn AccelProvider>,
2424}
2425
2426impl ThreadProviderGuard {
2427    pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2428        let prev = replace_thread_provider(provider);
2429        ThreadProviderGuard { prev }
2430    }
2431}
2432
2433impl Drop for ThreadProviderGuard {
2434    fn drop(&mut self) {
2435        let prev = self.prev.take();
2436        replace_thread_provider(prev);
2437    }
2438}
2439
/// Install (or clear, with `None`) the provider override for the current thread
/// without a restoring guard; prefer [`ThreadProviderGuard::set`] for scoped use.
pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
    replace_thread_provider(provider);
}
2443
2444/// Convenience: perform elementwise add via provider if possible; otherwise return None
2445pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2446    if let Some(p) = provider() {
2447        if let Ok(h) = p.elem_add(a, b).await {
2448            return Some(h);
2449        }
2450    }
2451    None
2452}
2453
2454/// Convenience: perform elementwise hypot via provider if possible; otherwise return None
2455pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2456    if let Some(p) = provider() {
2457        if let Ok(h) = p.elem_hypot(a, b).await {
2458            return Some(h);
2459        }
2460    }
2461    None
2462}
2463
2464/// Convenience: perform elementwise max via provider if possible; otherwise return None
2465pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2466    if let Some(p) = provider() {
2467        if let Ok(h) = p.elem_max(a, b).await {
2468            return Some(h);
2469        }
2470    }
2471    None
2472}
2473
2474/// Convenience: perform elementwise min via provider if possible; otherwise return None
2475pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2476    if let Some(p) = provider() {
2477        if let Ok(h) = p.elem_min(a, b).await {
2478            return Some(h);
2479        }
2480    }
2481    None
2482}
2483
2484/// Convenience: perform elementwise atan2 via provider if possible; otherwise return None
2485pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2486    if let Some(p) = provider() {
2487        if let Ok(h) = p.elem_atan2(y, x).await {
2488            return Some(h);
2489        }
2490    }
2491    None
2492}
2493
// Minimal host tensor views to avoid depending on runmat-builtins and cycles
/// Owned host-side tensor: flat f64 data plus its dimension sizes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    pub data: Vec<f64>,
    pub shape: Vec<usize>,
}

/// Borrowed host-side tensor view over caller-owned data and shape slices.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    pub data: &'a [f64],
    pub shape: &'a [usize],
}

/// Lightweight 1-D axis view used by provider meshgrid hooks.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    pub data: &'a [f64],
}

/// Provider-side meshgrid result containing coordinate tensor handles.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    pub outputs: Vec<GpuTensorHandle>,
}
2518
/// How a broadcasted row/column scale vector in a [`MatmulEpilogue`] is
/// combined with the GEMM output: each affected element is either multiplied
/// or divided by the corresponding scale entry.
///
/// NOTE(review): the previous doc comment here described `MatmulEpilogue`
/// itself, not this enum — it appeared to be attached to the wrong item.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    Multiply,
    Divide,
}
2529
/// Descriptor for GEMM epilogues applied to `C = A * B` before storing to `C`.
///
/// Per the field docs below, the operations apply in order: scale by `alpha`
/// and add `beta`, then the optional broadcasted row/column scales, then
/// clamping, then the optional power exponent (the final operation), with an
/// optional capture of the result's diagonal.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    /// Scalar multiply applied to each output element.
    pub alpha: f64,
    /// Scalar add applied to each output element after scaling.
    pub beta: f64,
    /// Optional per-row scale (length m). When present, output[row, col] *= row_scale[row].
    pub row_scale: Option<GpuTensorHandle>,
    /// Optional per-column scale (length n). When present, output[row, col] *= col_scale[col].
    pub col_scale: Option<GpuTensorHandle>,
    /// Row scale operation (multiply or divide). Ignored when `row_scale` is None.
    pub row_op: ScaleOp,
    /// Column scale operation (multiply or divide). Ignored when `col_scale` is None.
    pub col_op: ScaleOp,
    /// Optional lower clamp bound applied after scale/bias.
    #[serde(default)]
    pub clamp_min: Option<f64>,
    /// Optional upper clamp bound applied after scale/bias.
    #[serde(default)]
    pub clamp_max: Option<f64>,
    /// Optional power exponent applied after clamp (final operation in the epilogue).
    #[serde(default)]
    pub pow_exponent: Option<f64>,
    /// Optional output buffer for the diagonal of the result (length min(m, n)).
    #[serde(default)]
    pub diag_output: Option<GpuTensorHandle>,
}
2557
2558impl MatmulEpilogue {
2559    pub fn noop() -> Self {
2560        Self {
2561            alpha: 1.0,
2562            beta: 0.0,
2563            row_scale: None,
2564            col_scale: None,
2565            row_op: ScaleOp::Multiply,
2566            col_op: ScaleOp::Multiply,
2567            clamp_min: None,
2568            clamp_max: None,
2569            pow_exponent: None,
2570            diag_output: None,
2571        }
2572    }
2573    pub fn is_noop(&self) -> bool {
2574        self.alpha == 1.0
2575            && self.beta == 0.0
2576            && self.row_scale.is_none()
2577            && self.col_scale.is_none()
2578            && self.clamp_min.is_none()
2579            && self.clamp_max.is_none()
2580            && self.pow_exponent.is_none()
2581            && self.diag_output.is_none()
2582    }
2583}
2584
2585#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
2586pub struct PowerStepEpilogue {
2587    pub epsilon: f64,
2588}
2589
2590impl Default for PowerStepEpilogue {
2591    fn default() -> Self {
2592        Self { epsilon: 0.0 }
2593    }
2594}
2595
/// Parameters for a provider image-normalization pass over a
/// `batch x height x width` volume.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    pub batch: usize,
    pub height: usize,
    pub width: usize,
    // NOTE(review): presumably a divide-by-zero guard in the normalization
    // denominator — confirm against provider implementations.
    pub epsilon: f64,
    // Optional affine parameters; absent fields deserialize to None.
    #[serde(default)]
    pub gain: Option<f64>,
    #[serde(default)]
    pub bias: Option<f64>,
    #[serde(default)]
    pub gamma: Option<f64>,
}