Skip to main content

runmat_accelerate_api/
lib.rs

1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
16type ResidencyMarkFn = fn(&GpuTensorHandle);
17type ResidencyClearFn = fn(&GpuTensorHandle);
18type SequenceThresholdFn = fn() -> Option<usize>;
19type WorkgroupSizeHintFn = fn() -> Option<u32>;
20
21static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
22static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
23static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
24static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();
25
26static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
27static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
28    Lazy::new(|| RwLock::new(HashMap::new()));
29static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
30    Lazy::new(|| RwLock::new(HashMap::new()));
31
32static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
33    Lazy::new(|| RwLock::new(HashMap::new()));
34static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
35    Lazy::new(|| RwLock::new(HashMap::new()));
36
/// Transpose metadata attached to a GPU tensor handle via
/// `record_handle_transpose`.
///
/// The two fields hold the `base_rows` / `base_cols` supplied by the caller
/// (presumably the dimensions of the untransposed source — confirm at call sites).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TransposeInfo {
    /// Row count recorded for the base tensor.
    pub base_rows: usize,
    /// Column count recorded for the base tensor.
    pub base_cols: usize,
}
42
43/// Register a callback used to mark residency tracking when GPU tensors are
44/// created or returned by device-side execution paths.
45pub fn register_residency_mark(handler: ResidencyMarkFn) {
46    let _ = RESIDENCY_MARK.set(handler);
47}
48
49/// Mark residency metadata for the provided GPU tensor handle, if a backend
50/// has registered a handler via [`register_residency_mark`].
51pub fn mark_residency(handle: &GpuTensorHandle) {
52    if let Some(handler) = RESIDENCY_MARK.get() {
53        handler(handle);
54    }
55}
56
57/// Register a callback used to clear residency tracking when GPU tensors are
58/// gathered back to the host. Backends that maintain residency metadata should
59/// install this hook during initialization.
60pub fn register_residency_clear(handler: ResidencyClearFn) {
61    let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64/// Clear residency metadata for the provided GPU tensor handle, if a backend
65/// has registered a handler via [`register_residency_clear`].
66pub fn clear_residency(handle: &GpuTensorHandle) {
67    if let Some(handler) = RESIDENCY_CLEAR.get() {
68        handler(handle);
69    }
70}
71
72/// Register a callback that exposes the current sequence length threshold
73/// derived from the auto-offload planner. Array constructors can use this hint
74/// to decide when to prefer GPU residency automatically.
75pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76    let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79/// Query the currently registered sequence threshold hint, if any.
80pub fn sequence_threshold_hint() -> Option<usize> {
81    SEQUENCE_THRESHOLD_PROVIDER
82        .get()
83        .and_then(|provider| provider())
84}
85
86/// Register a callback that reports the calibrated workgroup size selected by
87/// the active acceleration provider (if any). Plotting kernels can reuse this
88/// hint to match backend tuning.
89pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90    let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93/// Query the current workgroup size hint exposed by the provider.
94pub fn workgroup_size_hint() -> Option<u32> {
95    WORKGROUP_SIZE_HINT_PROVIDER
96        .get()
97        .and_then(|provider| provider())
98}
99
100/// Export a shared acceleration context (e.g., the active WGPU device) when the
101/// current provider exposes one.
102pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103    provider().and_then(|p| p.export_context(kind))
104}
105
106/// Request a provider-owned WGPU buffer for zero-copy consumers. Returns `None`
107/// when the active provider does not expose buffers or does not support the
108/// supplied handle.
109#[cfg(feature = "wgpu")]
110pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
111    provider().and_then(|p| p.export_wgpu_buffer(handle))
112}
113
114/// Record the precision associated with a GPU tensor handle so host operations can
115/// reconstruct the original dtype when gathering back to the CPU.
116pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118        guard.insert(handle.buffer_id, precision);
119    }
120}
121
122/// Look up the recorded precision for a GPU tensor handle, if any.
123pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124    HANDLE_PRECISIONS
125        .read()
126        .ok()
127        .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130/// Clear any recorded precision metadata for a GPU tensor handle.
131pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133        guard.remove(&handle.buffer_id);
134    }
135}
136
137/// Annotate a GPU tensor handle as logically-typed (`logical` in MATLAB terms)
138/// or clear the logical flag when `logical` is `false`.
139pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
140    if let Ok(mut guard) = LOGICAL_HANDLES.write() {
141        if logical {
142            guard.insert(handle.buffer_id);
143            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
144                *hits.entry(handle.buffer_id).or_insert(0) += 1;
145            }
146        } else {
147            guard.remove(&handle.buffer_id);
148            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
149                hits.remove(&handle.buffer_id);
150            }
151        }
152    }
153}
154
155/// Convenience helper for clearing logical annotations explicitly.
156pub fn clear_handle_logical(handle: &GpuTensorHandle) {
157    set_handle_logical(handle, false);
158}
159
160/// Returns true when the supplied handle has been marked as logical.
161pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162    LOGICAL_HANDLES
163        .read()
164        .map(|guard| guard.contains(&handle.buffer_id))
165        .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169    LOGICAL_HANDLE_HITS
170        .read()
171        .ok()
172        .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177        guard.insert(
178            handle.buffer_id,
179            TransposeInfo {
180                base_rows,
181                base_cols,
182            },
183        );
184    }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189        guard.remove(&handle.buffer_id);
190    }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194    TRANSPOSED_HANDLES
195        .read()
196        .ok()
197        .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
200pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
201    handle_transpose_info(handle).is_some()
202}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
205pub enum GpuTensorStorage {
206    Real,
207    ComplexInterleaved,
208}
209
210impl Default for GpuTensorStorage {
211    fn default() -> Self {
212        Self::Real
213    }
214}
215
216#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
217pub struct GpuTensorHandle {
218    pub shape: Vec<usize>,
219    pub device_id: u32,
220    pub buffer_id: u64,
221}
222
223#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
224pub enum ProviderSpectralRange {
225    Onesided,
226    Twosided,
227    Centered,
228}
229
230#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
231pub enum ProviderSpectralFrameMode {
232    Sliding {
233        hop: usize,
234    },
235    ColumnSliding {
236        hop: usize,
237        input_rows: usize,
238        frames_per_column: usize,
239    },
240    FoldedColumns {
241        input_rows: usize,
242    },
243}
244
245#[derive(Clone, Debug)]
246pub struct ProviderSpectralRequest<'a> {
247    pub input: &'a GpuTensorHandle,
248    pub input_len: usize,
249    pub input_complex: bool,
250    pub window: &'a [f64],
251    pub nfft: usize,
252    pub frame_count: usize,
253    pub frame_mode: ProviderSpectralFrameMode,
254    pub range: ProviderSpectralRange,
255    pub denominator: f64,
256}
257
258#[derive(Clone, Debug)]
259pub struct ProviderSpectralResult {
260    pub s: GpuTensorHandle,
261    pub ps: GpuTensorHandle,
262    pub rows: usize,
263    pub cols: usize,
264}
265
266#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
267pub enum ProviderEnvelopeMethod {
268    Analytic,
269    AnalyticFir { filter_len: usize },
270    Rms { window_len: usize },
271}
272
273#[derive(Clone, Debug)]
274pub struct ProviderEnvelopeRequest<'a> {
275    pub input: &'a GpuTensorHandle,
276    pub channel_len: usize,
277    pub channel_count: usize,
278    pub output_shape: &'a [usize],
279    pub method: ProviderEnvelopeMethod,
280}
281
282#[derive(Clone, Debug)]
283pub struct ProviderEnvelopeResult {
284    pub upper: GpuTensorHandle,
285    pub lower: GpuTensorHandle,
286}
287
288#[derive(Clone, Debug)]
289pub struct ProviderHilbertRequest<'a> {
290    pub input: &'a GpuTensorHandle,
291    /// Optional FFT length along `dim`.
292    pub length: Option<usize>,
293    /// Zero-based transform dimension.
294    pub dim: usize,
295}
296
297#[derive(Clone, Debug)]
298pub struct ProviderModulationRequest<'a> {
299    pub input: &'a GpuTensorHandle,
300    /// `(real, imag)` pairs interleaved by symbol index.
301    pub constellation: &'a [f64],
302}
303
304#[derive(Clone, Debug)]
305pub struct ProviderBitModulationRequest<'a> {
306    pub input: &'a GpuTensorHandle,
307    /// Number of bit rows in the input before grouping.
308    pub input_rows: usize,
309    /// Number of input bits that form one output symbol.
310    pub bits_per_symbol: usize,
311    /// `(real, imag)` pairs interleaved by symbol index.
312    pub constellation: &'a [f64],
313}
314
315pub async fn uniform_spectral_estimate(
316    request: ProviderSpectralRequest<'_>,
317) -> anyhow::Result<ProviderSpectralResult> {
318    validate_uniform_spectral_request(&request)?;
319
320    let provider =
321        provider().ok_or_else(|| anyhow!("uniform_spectral_estimate: GPU provider unavailable"))?;
322    provider.uniform_spectral_estimate(&request).await
323}
324
325fn validate_uniform_spectral_request(request: &ProviderSpectralRequest<'_>) -> anyhow::Result<()> {
326    let invalid_frame_mode = matches!(
327        request.frame_mode,
328        ProviderSpectralFrameMode::Sliding { hop: 0 }
329            | ProviderSpectralFrameMode::ColumnSliding { hop: 0, .. }
330    );
331    let invalid_input_coverage = match request.frame_mode {
332        ProviderSpectralFrameMode::Sliding { hop } => {
333            let required = request
334                .frame_count
335                .checked_sub(1)
336                .and_then(|frames| frames.checked_mul(hop))
337                .and_then(|offset| offset.checked_add(request.window.len()));
338            required.is_none_or(|required| required > request.input_len)
339        }
340        ProviderSpectralFrameMode::ColumnSliding {
341            hop,
342            input_rows,
343            frames_per_column,
344        } => {
345            if input_rows == 0 || frames_per_column == 0 {
346                true
347            } else {
348                let last_frame = request.frame_count - 1;
349                let source_col = last_frame / frames_per_column;
350                let segment = last_frame % frames_per_column;
351                source_col
352                    .checked_mul(input_rows)
353                    .and_then(|base| {
354                        segment
355                            .checked_mul(hop)
356                            .and_then(|step| base.checked_add(step))
357                    })
358                    .and_then(|offset| offset.checked_add(request.window.len()))
359                    .is_none_or(|required| required > request.input_len)
360            }
361        }
362        ProviderSpectralFrameMode::FoldedColumns { input_rows } => {
363            input_rows == 0
364                || input_rows
365                    .checked_mul(request.frame_count)
366                    .is_none_or(|required| required > request.input_len)
367        }
368    };
369    if request.window.is_empty()
370        || request.nfft == 0
371        || request.frame_count == 0
372        || invalid_frame_mode
373        || invalid_input_coverage
374        || !request.denominator.is_finite()
375        || request.denominator <= 0.0
376    {
377        return Err(anyhow!("uniform_spectral_estimate: invalid request"));
378    }
379
380    Ok(())
381}
382
383pub async fn signal_envelope(
384    request: ProviderEnvelopeRequest<'_>,
385) -> anyhow::Result<ProviderEnvelopeResult> {
386    let expected_len = request
387        .channel_len
388        .checked_mul(request.channel_count)
389        .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
390    let output_len = request
391        .output_shape
392        .iter()
393        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
394        .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
395    let input_len = request
396        .input
397        .shape
398        .iter()
399        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
400        .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
401    if request.channel_len == 0
402        || request.channel_count == 0
403        || request.output_shape.is_empty()
404        || output_len != expected_len
405        || input_len != expected_len
406        || !provider_envelope_input_shape_matches(
407            &request.input.shape,
408            request.channel_len,
409            request.channel_count,
410        )
411    {
412        return Err(anyhow!("signal_envelope: invalid request"));
413    }
414
415    match request.method {
416        ProviderEnvelopeMethod::AnalyticFir { filter_len }
417        | ProviderEnvelopeMethod::Rms {
418            window_len: filter_len,
419        } if filter_len == 0 => return Err(anyhow!("signal_envelope: invalid request")),
420        _ => {}
421    }
422
423    let provider =
424        provider().ok_or_else(|| anyhow!("signal_envelope: GPU provider unavailable"))?;
425    provider.signal_envelope(&request).await
426}
427
/// Check that an envelope input `shape` is laid out as `channel_count`
/// channels of `channel_len` samples.
///
/// A single channel may be given as `[N]`, `[N, 1]` or `[1, N]`. Several
/// channels must be a `[channel_len, channel_count]` matrix. Any other rank is
/// rejected.
fn provider_envelope_input_shape_matches(
    shape: &[usize],
    channel_len: usize,
    channel_count: usize,
) -> bool {
    let single_channel = channel_count == 1;
    match *shape {
        [len] if single_channel => len == channel_len,
        [rows, cols] if single_channel => {
            (rows == channel_len && cols == 1) || (rows == 1 && cols == channel_len)
        }
        [rows, cols] => rows == channel_len && cols == channel_count,
        _ => false,
    }
}
445
446pub async fn signal_hilbert(
447    request: ProviderHilbertRequest<'_>,
448) -> anyhow::Result<GpuTensorHandle> {
449    if request.length == Some(0) {
450        return Err(anyhow!("signal_hilbert: invalid request"));
451    }
452    if request.dim >= request.input.shape.len() {
453        return Err(anyhow!("signal_hilbert: invalid request"));
454    }
455
456    let provider = provider().ok_or_else(|| anyhow!("signal_hilbert: GPU provider unavailable"))?;
457    provider.signal_hilbert(&request).await
458}
459
460#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
461pub struct ApiDeviceInfo {
462    pub device_id: u32,
463    pub name: String,
464    pub vendor: String,
465    pub memory_bytes: Option<u64>,
466    pub backend: Option<String>,
467}
468
469#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
470pub struct ReduceDimResult {
471    pub values: GpuTensorHandle,
472    pub indices: GpuTensorHandle,
473}
474
475#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
476pub struct ProviderCumminResult {
477    pub values: GpuTensorHandle,
478    pub indices: GpuTensorHandle,
479}
480
481/// Result payload returned by provider-side `cummax` scans.
482///
483/// Alias of [`ProviderCumminResult`] because both operations return the same pair of tensors
484/// (running values and MATLAB-compatible indices).
485pub type ProviderCummaxResult = ProviderCumminResult;
486
/// Kind of shared acceleration context a caller (e.g. plotting) may request
/// through [`export_context`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccelContextKind {
    Plotting,
}

/// Shared GPU context returned by [`export_context`].
///
/// Without the `wgpu` feature this enum has no variants.
#[derive(Clone)]
pub enum AccelContextHandle {
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}

impl AccelContextHandle {
    /// Borrow the WGPU context carried by this handle.
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        match self {
            Self::Wgpu(context) => Some(context),
        }
    }
}
509
/// WGPU objects shared by the acceleration provider: instance, device, queue
/// and adapter, plus the adapter info, limits and features.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    pub adapter_info: wgpu::AdapterInfo,
    pub limits: wgpu::Limits,
    pub features: wgpu::Features,
}

/// Shared reference to the provider-owned WGPU buffer behind a
/// `GpuTensorHandle`.
///
/// Providers that hand out this reference must create the buffer with
/// `wgpu::BufferUsages::COPY_SRC`, so zero-copy consumers can export or
/// inspect tensor data without provider-specific readback APIs.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    pub buffer: Arc<wgpu::Buffer>,
    pub len: usize,
    pub shape: Vec<usize>,
    /// Size of one element in bytes (presumably matching `precision`).
    pub element_size: usize,
    pub precision: ProviderPrecision,
}
537
538pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
539    if let Ok(mut guard) = HANDLE_STORAGES.write() {
540        guard.insert(handle.buffer_id, storage);
541    }
542}
543
544pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
545    HANDLE_STORAGES
546        .read()
547        .ok()
548        .and_then(|guard| guard.get(&handle.buffer_id).cloned())
549        .unwrap_or(GpuTensorStorage::Real)
550}
551
552pub fn clear_handle_storage(handle: &GpuTensorHandle) {
553    if let Ok(mut guard) = HANDLE_STORAGES.write() {
554        guard.remove(&handle.buffer_id);
555    }
556}
557
558#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
559pub enum PagefunOp {
560    Mtimes,
561}
562
563#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
564pub struct PagefunRequest {
565    pub op: PagefunOp,
566    pub inputs: Vec<GpuTensorHandle>,
567    pub output_shape: Vec<usize>,
568    pub page_dims: Vec<usize>,
569    pub input_page_dims: Vec<Vec<usize>>,
570}
571
572#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
573pub enum FindDirection {
574    First,
575    Last,
576}
577
578#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
579pub struct ProviderFindResult {
580    pub linear: GpuTensorHandle,
581    pub rows: GpuTensorHandle,
582    pub cols: GpuTensorHandle,
583    pub values: Option<GpuTensorHandle>,
584}
585
586#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
587pub struct ProviderBandwidth {
588    pub lower: u32,
589    pub upper: u32,
590}
591
592#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
593pub enum ProviderSymmetryKind {
594    Symmetric,
595    Skew,
596}
597
598#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
599pub enum ProviderHermitianKind {
600    Hermitian,
601    Skew,
602}
603
604#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
605pub struct ProviderLuResult {
606    pub combined: GpuTensorHandle,
607    pub lower: GpuTensorHandle,
608    pub upper: GpuTensorHandle,
609    pub perm_matrix: GpuTensorHandle,
610    pub perm_vector: GpuTensorHandle,
611}
612
613#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
614pub struct ProviderCholResult {
615    pub factor: GpuTensorHandle,
616    /// MATLAB-compatible failure index (0 indicates success).
617    pub info: u32,
618}
619
620#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
621pub struct ProviderQrResult {
622    pub q: GpuTensorHandle,
623    pub r: GpuTensorHandle,
624    pub perm_matrix: GpuTensorHandle,
625    pub perm_vector: GpuTensorHandle,
626}
627
628#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
629pub struct ProviderQrPowerIterResult {
630    pub q: GpuTensorHandle,
631    pub r: GpuTensorHandle,
632    pub perm_matrix: GpuTensorHandle,
633    pub perm_vector: GpuTensorHandle,
634}
635
636#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
637pub struct ProviderLinsolveOptions {
638    pub lower: bool,
639    pub upper: bool,
640    pub rectangular: bool,
641    pub transposed: bool,
642    pub conjugate: bool,
643    pub symmetric: bool,
644    pub posdef: bool,
645    pub need_rcond: bool,
646    pub rcond: Option<f64>,
647}
648
649#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
650pub struct ProviderLinsolveResult {
651    pub solution: GpuTensorHandle,
652    pub reciprocal_condition: f64,
653}
654
655#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
656pub struct ProviderPinvOptions {
657    pub tolerance: Option<f64>,
658}
659
660#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
661pub struct ProviderPolyvalMu {
662    pub mean: f64,
663    pub scale: f64,
664}
665
666#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
667pub struct ProviderPolyvalOptions {
668    pub mu: Option<ProviderPolyvalMu>,
669}
670
671#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
672pub struct ProviderInvOptions {}
673
674#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
675pub struct ProviderPolyfitResult {
676    pub coefficients: Vec<f64>,
677    pub r_matrix: Vec<f64>,
678    pub normr: f64,
679    pub df: f64,
680    pub mu: [f64; 2],
681}
682
683/// Numerator/denominator payload returned by provider-backed `polyder` quotient rule.
684#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
685pub struct ProviderPolyderQuotient {
686    pub numerator: GpuTensorHandle,
687    pub denominator: GpuTensorHandle,
688}
689
690/// Supported norm specifications for the `cond` builtin.
691#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
692pub enum ProviderCondNorm {
693    Two,
694    One,
695    Inf,
696    Fro,
697}
698
699/// Supported norm orders for the `norm` builtin.
700#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
701pub enum ProviderNormOrder {
702    Two,
703    One,
704    Inf,
705    NegInf,
706    Zero,
707    Fro,
708    Nuc,
709    P(f64),
710}
711
712#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
713pub struct ProviderEigResult {
714    pub eigenvalues: GpuTensorHandle,
715    pub diagonal: GpuTensorHandle,
716    pub right: GpuTensorHandle,
717    pub left: Option<GpuTensorHandle>,
718}
719
720#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
721pub enum ProviderQrPivot {
722    Matrix,
723    Vector,
724}
725
726#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
727pub struct ProviderQrOptions {
728    pub economy: bool,
729    pub pivot: ProviderQrPivot,
730}
731
732impl Default for ProviderQrOptions {
733    fn default() -> Self {
734        Self {
735            economy: false,
736            pivot: ProviderQrPivot::Matrix,
737        }
738    }
739}
740
741#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
742pub enum ProviderPrecision {
743    F32,
744    F64,
745}
746
747/// Declares how provider-owned GPU handles may cross async spawn boundaries.
748///
749/// This is a runtime/provider policy surface (not a semantic type fact) used by
750/// VM/runtime spawn handling to prevent unsynchronized device-handle races.
751#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
752pub enum SpawnHandleConcurrency {
753    /// Provider supports immutable sharing of handle-backed values across spawned tasks.
754    ImmutableShare,
755    /// Provider supports copy-on-write semantics when spawned and parent tasks diverge.
756    CopyOnWrite,
757    /// Provider supports synchronized mutation for shared handles.
758    SynchronizedMutation,
759    /// Provider rejects spawned sharing of raw handles.
760    Reject,
761}
762
763impl SpawnHandleConcurrency {
764    pub fn as_str(self) -> &'static str {
765        match self {
766            SpawnHandleConcurrency::ImmutableShare => "immutable_share",
767            SpawnHandleConcurrency::CopyOnWrite => "copy_on_write",
768            SpawnHandleConcurrency::SynchronizedMutation => "synchronized_mutation",
769            SpawnHandleConcurrency::Reject => "reject",
770        }
771    }
772}
773
774#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
775pub enum ReductionTwoPassMode {
776    Auto,
777    ForceOn,
778    ForceOff,
779}
780
781impl ReductionTwoPassMode {
782    pub fn as_str(self) -> &'static str {
783        match self {
784            ReductionTwoPassMode::Auto => "auto",
785            ReductionTwoPassMode::ForceOn => "force_on",
786            ReductionTwoPassMode::ForceOff => "force_off",
787        }
788    }
789}
790
791#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
792pub enum ReductionFlavor {
793    Sum,
794    Mean,
795    CustomScale(f64),
796}
797
798impl ReductionFlavor {
799    pub fn is_mean(self) -> bool {
800        matches!(self, ReductionFlavor::Mean)
801    }
802
803    pub fn scale(self, reduce_len: usize) -> f64 {
804        match self {
805            ReductionFlavor::Sum => 1.0,
806            ReductionFlavor::Mean => {
807                if reduce_len == 0 {
808                    1.0
809                } else {
810                    1.0 / reduce_len as f64
811                }
812            }
813            ReductionFlavor::CustomScale(scale) => scale,
814        }
815    }
816}
817
818/// Normalisation mode for correlation coefficients.
819#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
820pub enum CorrcoefNormalization {
821    Unbiased,
822    Biased,
823}
824
825/// Row-selection strategy for correlation coefficients.
826#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
827pub enum CorrcoefRows {
828    All,
829    Complete,
830    Pairwise,
831}
832
833/// Options controlling provider-backed correlation coefficient computation.
834#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
835pub struct CorrcoefOptions {
836    pub normalization: CorrcoefNormalization,
837    pub rows: CorrcoefRows,
838}
839
840impl Default for CorrcoefOptions {
841    fn default() -> Self {
842        Self {
843            normalization: CorrcoefNormalization::Unbiased,
844            rows: CorrcoefRows::All,
845        }
846    }
847}
848
849/// Normalisation mode used by covariance computations.
850#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
851pub enum CovNormalization {
852    Unbiased,
853    Biased,
854}
855
856/// Row handling strategy for covariance computations.
857#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
858pub enum CovRows {
859    All,
860    OmitRows,
861    PartialRows,
862}
863
864/// Options controlling provider-backed covariance computation.
865#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
866pub struct CovarianceOptions {
867    pub normalization: CovNormalization,
868    pub rows: CovRows,
869    pub has_weight_vector: bool,
870}
871
872impl Default for CovarianceOptions {
873    fn default() -> Self {
874        Self {
875            normalization: CovNormalization::Unbiased,
876            rows: CovRows::All,
877            has_weight_vector: false,
878        }
879    }
880}
881
882/// Normalization strategy used by provider-backed standard deviation reductions.
883#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
884pub enum ProviderStdNormalization {
885    Sample,
886    Population,
887}
888
889/// NaN handling mode for provider-backed reductions.
890#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
891pub enum ProviderNanMode {
892    Include,
893    Omit,
894}
895
896/// Direction used when computing prefix sums on the device.
897#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
898pub enum ProviderScanDirection {
899    Forward,
900    Reverse,
901}
902
903/// Sort direction used by acceleration providers.
904#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
905pub enum SortOrder {
906    Ascend,
907    Descend,
908}
909
910/// Comparison strategy applied during sorting.
911#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
912pub enum SortComparison {
913    Auto,
914    Real,
915    Abs,
916}
917
918/// Host-resident outputs returned by provider-backed sort operations.
919#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
920pub struct SortResult {
921    pub values: HostTensorOwned,
922    pub indices: HostTensorOwned,
923}
924
925#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
926pub struct SortRowsColumnSpec {
927    pub index: usize,
928    pub order: SortOrder,
929}
930
931/// Ordering applied by provider-backed `unique` operations.
932#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
933pub enum UniqueOrder {
934    Sorted,
935    Stable,
936}
937
938/// Occurrence selection for provider-backed `unique` operations.
939#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
940pub enum UniqueOccurrence {
941    First,
942    Last,
943}
944
945/// Options controlling provider-backed `unique` operations.
946#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
947pub struct UniqueOptions {
948    pub rows: bool,
949    pub order: UniqueOrder,
950    pub occurrence: UniqueOccurrence,
951}
952
953/// Host-resident outputs returned by provider-backed `unique` operations.
954#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
955pub struct UniqueResult {
956    pub values: HostTensorOwned,
957    pub ia: HostTensorOwned,
958    pub ic: HostTensorOwned,
959}
960
/// Ordering applied by provider-backed `union` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    /// Results in sorted order.
    Sorted,
    /// Results in order of appearance in the inputs.
    Stable,
}

/// Options controlling provider-backed `union` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    /// Row-wise mode flag (presumably MATLAB's `'rows'` option).
    pub rows: bool,
    /// Output ordering.
    pub order: UnionOrder,
}

/// Host-resident outputs returned by provider-backed `union` operations.
///
/// NOTE(review): `ia`/`ib` are named after MATLAB's `[C, ia, ib] = union(A, B)` outputs
/// (indices into `A` and `B` contributing to `C`); confirm the index base used by providers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    /// Combined values.
    pub values: HostTensorOwned,
    /// Indices into the first input.
    pub ia: HostTensorOwned,
    /// Indices into the second input.
    pub ib: HostTensorOwned,
}
982
/// Parameterisation of 2-D filters generated by `fspecial`.
///
/// Each variant corresponds to the `fspecial` filter type of the same name. NOTE(review):
/// this type does not validate its parameters (e.g. positive sizes); confirm where that
/// validation happens.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    /// `'average'` filter.
    Average {
        /// Kernel height.
        rows: u32,
        /// Kernel width.
        cols: u32,
    },
    /// `'disk'` filter.
    Disk {
        /// Disk radius.
        radius: f64,
        /// Kernel side length (presumably derived from `radius` by the caller — confirm).
        size: u32,
    },
    /// `'gaussian'` filter.
    Gaussian {
        /// Kernel height.
        rows: u32,
        /// Kernel width.
        cols: u32,
        /// Standard deviation.
        sigma: f64,
    },
    /// `'laplacian'` filter.
    Laplacian {
        /// Shape parameter.
        alpha: f64,
    },
    /// `'log'` (Laplacian of Gaussian) filter.
    Log {
        /// Kernel height.
        rows: u32,
        /// Kernel width.
        cols: u32,
        /// Standard deviation.
        sigma: f64,
    },
    /// `'motion'` filter.
    Motion {
        /// Motion length.
        length: u32,
        /// Side length of the generated kernel.
        kernel_size: u32,
        /// Motion direction, in degrees.
        angle_degrees: f64,
        /// NOTE(review): presumably a sub-sampling factor used while rasterising the motion
        /// line — confirm.
        oversample: u32,
    },
    /// `'prewitt'` filter.
    Prewitt,
    /// `'sobel'` filter.
    Sobel,
    /// `'unsharp'` filter.
    Unsharp {
        /// Shape parameter.
        alpha: f64,
    },
}

/// Request dispatched to acceleration providers for `fspecial` kernels.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    /// Filter type and its parameters.
    pub filter: FspecialFilter,
}
1025
/// Padding strategy used by `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    /// Pad with a constant (see `ImfilterOptions::constant_value`).
    Constant,
    /// Repeat the nearest border element (MATLAB `'replicate'`).
    Replicate,
    /// Mirror the array across its border (MATLAB `'symmetric'`).
    Symmetric,
    /// Wrap around periodically (MATLAB `'circular'`).
    Circular,
}

/// Output sizing mode used by `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    /// Output has the same size as the input image.
    Same,
    /// Full filtered result, including the padded border.
    Full,
    /// Only the part computed without padding.
    Valid,
}

/// Correlation vs convolution behaviour for `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    /// Correlation (the `Default` mode).
    Correlation,
    /// Convolution.
    Convolution,
}

/// Options supplied to acceleration providers for `imfilter`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    /// Boundary padding strategy.
    pub padding: ImfilterPadding,
    /// Padding value for `ImfilterPadding::Constant` (presumably ignored otherwise).
    pub constant_value: f64,
    /// Output sizing mode.
    pub shape: ImfilterShape,
    /// Correlation or convolution.
    pub mode: ImfilterMode,
}
1058
1059impl Default for ImfilterOptions {
1060    fn default() -> Self {
1061        Self {
1062            padding: ImfilterPadding::Constant,
1063            constant_value: 0.0,
1064            shape: ImfilterShape::Same,
1065            mode: ImfilterMode::Correlation,
1066        }
1067    }
1068}
1069
/// Ordering applied by provider-backed `setdiff` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    /// Results in sorted order.
    Sorted,
    /// Results in order of appearance in the first input.
    Stable,
}

/// Options controlling provider-backed `setdiff` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// Row-wise mode flag (presumably MATLAB's `'rows'` option).
    pub rows: bool,
    /// Output ordering.
    pub order: SetdiffOrder,
}

/// Host-resident outputs returned by provider-backed `setdiff` operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    /// Values of the first input that are absent from the second.
    pub values: HostTensorOwned,
    /// Indices into the first input (presumably MATLAB's `ia` output — confirm index base).
    pub ia: HostTensorOwned,
}

/// Options controlling provider-backed `ismember` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// Row-wise mode flag (presumably MATLAB's `'rows'` option).
    pub rows: bool,
}
1096
1097/// Host-resident logical output returned by providers.
1098#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1099pub struct HostLogicalOwned {
1100    pub data: Vec<u8>,
1101    pub shape: Vec<usize>,
1102}
1103
/// Host-resident outputs returned by provider-backed `ismember` operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    /// Membership mask for each element of the first input.
    pub mask: HostLogicalOwned,
    /// Location output (presumably MATLAB's `Locb`: where each member was found in the
    /// second input, with a sentinel for non-members — confirm index base and sentinel).
    pub loc: HostTensorOwned,
}
1110
/// Output sizing for provider-backed 1-D convolution (named after MATLAB `conv` shapes).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    /// Full convolution result.
    Full,
    /// Central part, sized like the first input.
    Same,
    /// Only parts computed without zero-padded edges.
    Valid,
}

/// Vector orientation requested for a 1-D convolution result.
///
/// NOTE(review): presumably selects a row vs. column output vector — confirm with callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options for provider-backed 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    /// Output sizing mode.
    pub mode: ProviderConvMode,
    /// Output orientation.
    pub orientation: ProviderConvOrientation,
}
1129
/// Options for provider-backed IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Zero-based dimension along which filtering should be applied.
    pub dim: usize,
    /// Optional initial conditions (state vector) residing on the device.
    pub zi: Option<GpuTensorHandle>,
}

/// Device-resident outputs of provider-backed IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    /// Filtered output tensor, matching the input signal shape.
    pub output: GpuTensorHandle,
    /// Final conditions for the filter state (same shape as the requested `zi` layout).
    pub final_state: Option<GpuTensorHandle>,
}

/// Pair of device-resident moment tensors.
///
/// NOTE(review): by name, `mean` is the first moment and `ex2` presumably the raw second
/// moment E[x^2] — confirm with the producing kernel.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    pub ex2: GpuTensorHandle,
}
1151
1152#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
1153pub struct ProviderDispatchStats {
1154    /// Number of GPU dispatches recorded for this category.
1155    pub count: u64,
1156    /// Accumulated wall-clock time of dispatches in nanoseconds (host measured).
1157    pub total_wall_time_ns: u64,
1158}
1159
1160#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
1161pub struct ProviderFallbackStat {
1162    pub reason: String,
1163    pub count: u64,
1164}
1165
1166#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
1167pub struct ProviderTelemetry {
1168    pub fused_elementwise: ProviderDispatchStats,
1169    pub fused_reduction: ProviderDispatchStats,
1170    pub matmul: ProviderDispatchStats,
1171    pub linsolve: ProviderDispatchStats,
1172    pub mldivide: ProviderDispatchStats,
1173    pub mrdivide: ProviderDispatchStats,
1174    pub upload_bytes: u64,
1175    pub download_bytes: u64,
1176    pub solve_fallbacks: Vec<ProviderFallbackStat>,
1177    pub fusion_cache_hits: u64,
1178    pub fusion_cache_misses: u64,
1179    pub bind_group_cache_hits: u64,
1180    pub bind_group_cache_misses: u64,
1181    /// Optional per-layout bind group cache counters (layout tags and their hit/miss counts)
1182    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
1183    /// Recent kernel launch metadata (bounded log; newest last)
1184    pub kernel_launches: Vec<KernelLaunchTelemetry>,
1185}
1186
1187#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
1188pub struct BindGroupLayoutTelemetry {
1189    pub tag: String,
1190    pub hits: u64,
1191    pub misses: u64,
1192}
1193
1194#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
1195pub struct KernelAttrTelemetry {
1196    pub key: String,
1197    pub value: u64,
1198}
1199
1200#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1201pub struct KernelLaunchTelemetry {
1202    pub kernel: String,
1203    pub precision: Option<String>,
1204    pub shape: Vec<KernelAttrTelemetry>,
1205    pub tuning: Vec<KernelAttrTelemetry>,
1206}
1207
/// Boxed, pinned future returned by asynchronous provider operations.
///
/// Resolves to `anyhow::Result<T>` and may borrow from the provider and its arguments for
/// `'a`. The trait object carries no `Send` bound, so these futures are not `Send`.
pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
/// Future returned by `AccelProvider::download`, resolving to an owned host tensor.
pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
1210
1211fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
1212    Box::pin(async move { Err(anyhow::anyhow!(message)) })
1213}
1214
1215/// Device/provider interface that backends implement and register into the runtime layer
1216pub trait AccelProvider: Send + Sync {
    /// Upload a host tensor view to the device and return a handle to the device copy.
    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
    /// Download a device tensor; the future resolves to an owned host tensor.
    fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
    /// Release the device resources behind `h`.
    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
    /// Human-readable device description; the default `device_info_struct` uses it as `name`.
    fn device_info(&self) -> String;
    /// Identifier of the device backing this provider. Defaults to `0`.
    fn device_id(&self) -> u32 {
        0
    }
1224
1225    /// Declares provider policy for sharing `GpuTensorHandle` values across
1226    /// spawned async boundaries.
1227    ///
1228    /// Default is conservative rejection. Providers that can safely support
1229    /// cross-task sharing should override this.
1230    fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
1231        SpawnHandleConcurrency::Reject
1232    }
1233
1234    /// Export a shared GPU context handle, allowing downstream systems (plotting, visualization)
1235    /// to reuse the same device/queue without copying tensor data back to the host.
1236    fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
1237        None
1238    }
1239
1240    /// Export a provider-owned WGPU buffer for zero-copy integrations.
1241    #[cfg(feature = "wgpu")]
1242    fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
1243        let _ = _handle;
1244        None
1245    }
1246
1247    /// Gather elements from `source` at the provided zero-based linear `indices`, materialising
1248    /// a dense tensor with the specified `output_shape`.
1249    fn gather_linear(
1250        &self,
1251        _source: &GpuTensorHandle,
1252        _indices: &[u32],
1253        _output_shape: &[usize],
1254    ) -> anyhow::Result<GpuTensorHandle> {
1255        Err(anyhow::anyhow!("gather_linear not supported by provider"))
1256    }
1257
1258    /// Scatter the contents of `values` into `target` at the provided zero-based linear `indices`.
1259    ///
1260    /// The provider must ensure `values.len() == indices.len()` and update `target` in place.
1261    fn scatter_linear(
1262        &self,
1263        _target: &GpuTensorHandle,
1264        _indices: &[u32],
1265        _values: &GpuTensorHandle,
1266    ) -> anyhow::Result<()> {
1267        Err(anyhow::anyhow!("scatter_linear not supported by provider"))
1268    }
1269
1270    /// Structured device information (optional to override). Default adapts from `device_info()`.
1271    fn device_info_struct(&self) -> ApiDeviceInfo {
1272        ApiDeviceInfo {
1273            device_id: 0,
1274            name: self.device_info(),
1275            vendor: String::new(),
1276            memory_bytes: None,
1277            backend: None,
1278        }
1279    }
1280
1281    fn precision(&self) -> ProviderPrecision {
1282        ProviderPrecision::F64
1283    }
1284
1285    /// Read a single scalar at linear index from a device tensor, returning it as f64.
1286    fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
1287        Err(anyhow::anyhow!("read_scalar not supported by provider"))
1288    }
1289
1290    /// Allocate a zero-initialised tensor with the provided shape on the device.
1291    fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1292        Err(anyhow::anyhow!("zeros not supported by provider"))
1293    }
1294
1295    /// Allocate a one-initialised tensor with the provided shape on the device.
1296    fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1297        Err(anyhow::anyhow!("ones not supported by provider"))
1298    }
1299
1300    /// Allocate a zero-initialised tensor matching the prototype tensor.
1301    fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1302        self.zeros(&prototype.shape)
1303    }
1304
1305    /// Allocate a tensor filled with a constant value on the device.
1306    fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
1307        if value == 0.0 {
1308            return self.zeros(shape);
1309        }
1310        if let Ok(base) = self.zeros(shape) {
1311            match self.scalar_add(&base, value) {
1312                Ok(out) => {
1313                    let _ = self.free(&base);
1314                    return Ok(out);
1315                }
1316                Err(_) => {
1317                    let _ = self.free(&base);
1318                }
1319            }
1320        }
1321        let len: usize = shape.iter().copied().product();
1322        let data = vec![value; len];
1323        let view = HostTensorView { data: &data, shape };
1324        self.upload(&view)
1325    }
1326
1327    /// Allocate a tensor filled with a constant value, matching a prototype's residency.
1328    fn fill_like(
1329        &self,
1330        prototype: &GpuTensorHandle,
1331        value: f64,
1332    ) -> anyhow::Result<GpuTensorHandle> {
1333        if value == 0.0 {
1334            return self.zeros_like(prototype);
1335        }
1336        if let Ok(base) = self.zeros_like(prototype) {
1337            match self.scalar_add(&base, value) {
1338                Ok(out) => {
1339                    let _ = self.free(&base);
1340                    return Ok(out);
1341                }
1342                Err(_) => {
1343                    let _ = self.free(&base);
1344                }
1345            }
1346        }
1347        self.fill(&prototype.shape, value)
1348    }
1349
1350    /// Allocate a one-initialised tensor matching the prototype tensor.
1351    fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1352        self.ones(&prototype.shape)
1353    }
1354
1355    /// Allocate an identity tensor with ones along the leading diagonal of the first two axes.
1356    fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1357        Err(anyhow::anyhow!("eye not supported by provider"))
1358    }
1359
1360    /// Allocate an identity tensor matching the prototype tensor's shape.
1361    fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1362        self.eye(&prototype.shape)
1363    }
1364
1365    /// Construct MATLAB-style coordinate grids from axis vectors.
1366    fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1367        Err(anyhow::anyhow!("meshgrid not supported by provider"))
1368    }
1369
1370    /// Construct a diagonal matrix from a vector-like tensor. `offset` matches MATLAB semantics.
1371    fn diag_from_vector(
1372        &self,
1373        _vector: &GpuTensorHandle,
1374        _offset: isize,
1375    ) -> anyhow::Result<GpuTensorHandle> {
1376        Err(anyhow::anyhow!(
1377            "diag_from_vector not supported by provider"
1378        ))
1379    }
1380
1381    /// Extract a diagonal from a matrix-like tensor. The result is always a column vector.
1382    fn diag_extract(
1383        &self,
1384        _matrix: &GpuTensorHandle,
1385        _offset: isize,
1386    ) -> anyhow::Result<GpuTensorHandle> {
1387        Err(anyhow::anyhow!("diag_extract not supported by provider"))
1388    }
1389
1390    /// Apply a lower-triangular mask to the first two dimensions of a tensor.
1391    fn tril<'a>(
1392        &'a self,
1393        _matrix: &'a GpuTensorHandle,
1394        _offset: isize,
1395    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1396        Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1397    }
1398
1399    /// Apply an upper-triangular mask to the first two dimensions of a tensor.
1400    fn triu<'a>(
1401        &'a self,
1402        _matrix: &'a GpuTensorHandle,
1403        _offset: isize,
1404    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1405        Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1406    }
1407
1408    /// Evaluate a polynomial expressed by `coefficients` at each element in `points`.
1409    fn polyval(
1410        &self,
1411        _coefficients: &GpuTensorHandle,
1412        _points: &GpuTensorHandle,
1413        _options: &ProviderPolyvalOptions,
1414    ) -> anyhow::Result<GpuTensorHandle> {
1415        Err(anyhow::anyhow!("polyval not supported by provider"))
1416    }
1417
1418    /// Fit a polynomial of degree `degree` to `(x, y)` samples. Optional weights must match `x`.
1419    fn polyfit<'a>(
1420        &'a self,
1421        _x: &'a GpuTensorHandle,
1422        _y: &'a GpuTensorHandle,
1423        _degree: usize,
1424        _weights: Option<&'a GpuTensorHandle>,
1425    ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1426        Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1427    }
1428
1429    /// Differentiate a polynomial represented as a vector of coefficients.
1430    fn polyder_single<'a>(
1431        &'a self,
1432        _polynomial: &'a GpuTensorHandle,
1433    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1434        Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1435    }
1436
1437    /// Apply the product rule to polynomials `p` and `q`.
1438    fn polyder_product<'a>(
1439        &'a self,
1440        _p: &'a GpuTensorHandle,
1441        _q: &'a GpuTensorHandle,
1442    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1443        Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1444    }
1445
1446    /// Apply the quotient rule to polynomials `u` and `v`.
1447    fn polyder_quotient<'a>(
1448        &'a self,
1449        _u: &'a GpuTensorHandle,
1450        _v: &'a GpuTensorHandle,
1451    ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1452        Box::pin(async move {
1453            Err(anyhow::anyhow!(
1454                "polyder_quotient not supported by provider"
1455            ))
1456        })
1457    }
1458
1459    /// Integrate a polynomial represented as a vector of coefficients and append a constant term.
1460    fn polyint(
1461        &self,
1462        _polynomial: &GpuTensorHandle,
1463        _constant: f64,
1464    ) -> anyhow::Result<GpuTensorHandle> {
1465        Err(anyhow::anyhow!("polyint not supported by provider"))
1466    }
1467
1468    /// Allocate a tensor filled with random values drawn from U(0, 1).
1469    fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1470        Err(anyhow::anyhow!("random_uniform not supported by provider"))
1471    }
1472
1473    /// Allocate a tensor filled with random values matching the prototype shape.
1474    fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1475        self.random_uniform(&prototype.shape)
1476    }
1477
1478    /// Allocate a tensor filled with standard normal (mean 0, stddev 1) random values.
1479    fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1480        Err(anyhow::anyhow!("random_normal not supported by provider"))
1481    }
1482
1483    /// Allocate a tensor of standard normal values matching a prototype's shape.
1484    fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1485        self.random_normal(&prototype.shape)
1486    }
1487
1488    /// Exponentially-distributed random values with mean `mu`.
1489    fn random_exponential(&self, _mu: f64, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1490        Err(anyhow::anyhow!(
1491            "random_exponential not supported by provider"
1492        ))
1493    }
1494
1495    /// Normal random values with mean `mu` and standard deviation `sigma`.
1496    fn random_normrnd(
1497        &self,
1498        _mu: f64,
1499        _sigma: f64,
1500        _shape: &[usize],
1501    ) -> anyhow::Result<GpuTensorHandle> {
1502        Err(anyhow::anyhow!("random_normrnd not supported by provider"))
1503    }
1504
1505    /// Uniform random values on the interval `[a, b)`.
1506    fn random_unifrnd(
1507        &self,
1508        _a: f64,
1509        _b: f64,
1510        _shape: &[usize],
1511    ) -> anyhow::Result<GpuTensorHandle> {
1512        Err(anyhow::anyhow!("random_unifrnd not supported by provider"))
1513    }
1514
1515    fn stochastic_evolution(
1516        &self,
1517        _state: &GpuTensorHandle,
1518        _drift: f64,
1519        _scale: f64,
1520        _steps: u32,
1521    ) -> anyhow::Result<GpuTensorHandle> {
1522        Err(anyhow::anyhow!(
1523            "stochastic_evolution not supported by provider"
1524        ))
1525    }
1526
1527    /// Set the provider RNG state to align with the host RNG.
1528    fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1529        Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1530    }
1531
1532    /// Generate a 2-D correlation kernel matching MATLAB's `fspecial` builtin.
1533    fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1534        Err(anyhow::anyhow!("fspecial not supported by provider"))
1535    }
1536
1537    /// Evaluate the `peaks` test surface on an n×n grid spanning [-3,3]×[-3,3].
1538    /// Returns the Z matrix (n×n) as a GPU tensor.
1539    fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
1540        Err(anyhow::anyhow!("peaks not supported by provider"))
1541    }
1542
1543    /// Evaluate the `peaks` formula element-wise on caller-supplied GPU coordinate tensors.
1544    /// X and Y must have the same shape. Returns a Z tensor of the same shape.
1545    fn peaks_xy(
1546        &self,
1547        _x: &GpuTensorHandle,
1548        _y: &GpuTensorHandle,
1549    ) -> anyhow::Result<GpuTensorHandle> {
1550        Err(anyhow::anyhow!("peaks_xy not supported by provider"))
1551    }
1552
1553    fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1554        Err(anyhow::anyhow!("hann_window not supported by provider"))
1555    }
1556
1557    fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1558        Err(anyhow::anyhow!("hamming_window not supported by provider"))
1559    }
1560
1561    fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1562        Err(anyhow::anyhow!("blackman_window not supported by provider"))
1563    }
1564
1565    /// Apply an N-D correlation/convolution with padding semantics matching MATLAB's `imfilter`.
1566    fn imfilter<'a>(
1567        &'a self,
1568        _image: &'a GpuTensorHandle,
1569        _kernel: &'a GpuTensorHandle,
1570        _options: &'a ImfilterOptions,
1571    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1572        unsupported_future("imfilter not supported by provider")
1573    }
1574
1575    /// Allocate a tensor filled with random integers over an inclusive range.
1576    fn random_integer_range(
1577        &self,
1578        _lower: i64,
1579        _upper: i64,
1580        _shape: &[usize],
1581    ) -> anyhow::Result<GpuTensorHandle> {
1582        Err(anyhow::anyhow!(
1583            "random_integer_range not supported by provider"
1584        ))
1585    }
1586
1587    /// Allocate a random integer tensor matching the prototype shape.
1588    fn random_integer_like(
1589        &self,
1590        prototype: &GpuTensorHandle,
1591        lower: i64,
1592        upper: i64,
1593    ) -> anyhow::Result<GpuTensorHandle> {
1594        self.random_integer_range(lower, upper, &prototype.shape)
1595    }
1596
1597    /// Allocate a random permutation of 1..=n, returning the first k elements.
1598    fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1599        Err(anyhow!("random_permutation not supported by provider"))
1600    }
1601
1602    /// Allocate a random permutation matching the prototype residency.
1603    fn random_permutation_like(
1604        &self,
1605        _prototype: &GpuTensorHandle,
1606        n: usize,
1607        k: usize,
1608    ) -> anyhow::Result<GpuTensorHandle> {
1609        self.random_permutation(n, k)
1610    }
1611
1612    /// Compute a covariance matrix across the columns of `matrix`.
1613    fn covariance<'a>(
1614        &'a self,
1615        _matrix: &'a GpuTensorHandle,
1616        _second: Option<&'a GpuTensorHandle>,
1617        _weights: Option<&'a GpuTensorHandle>,
1618        _options: &'a CovarianceOptions,
1619    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1620        unsupported_future("covariance not supported by provider")
1621    }
1622
1623    /// Compute a correlation coefficient matrix across the columns of `matrix`.
1624    fn corrcoef<'a>(
1625        &'a self,
1626        _matrix: &'a GpuTensorHandle,
1627        _options: &'a CorrcoefOptions,
1628    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1629        unsupported_future("corrcoef not supported by provider")
1630    }
1631
1632    // Optional operator hooks (default to unsupported)
1633    fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1634        Err(anyhow::anyhow!("linspace not supported by provider"))
1635    }
1636    fn elem_add<'a>(
1637        &'a self,
1638        _a: &'a GpuTensorHandle,
1639        _b: &'a GpuTensorHandle,
1640    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1641        unsupported_future("elem_add not supported by provider")
1642    }
1643    fn elem_mul<'a>(
1644        &'a self,
1645        _a: &'a GpuTensorHandle,
1646        _b: &'a GpuTensorHandle,
1647    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1648        unsupported_future("elem_mul not supported by provider")
1649    }
1650    fn elem_max<'a>(
1651        &'a self,
1652        _a: &'a GpuTensorHandle,
1653        _b: &'a GpuTensorHandle,
1654    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1655        unsupported_future("elem_max not supported by provider")
1656    }
1657    fn elem_min<'a>(
1658        &'a self,
1659        _a: &'a GpuTensorHandle,
1660        _b: &'a GpuTensorHandle,
1661    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1662        unsupported_future("elem_min not supported by provider")
1663    }
1664    fn elem_sub<'a>(
1665        &'a self,
1666        _a: &'a GpuTensorHandle,
1667        _b: &'a GpuTensorHandle,
1668    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1669        unsupported_future("elem_sub not supported by provider")
1670    }
1671    fn elem_div<'a>(
1672        &'a self,
1673        _a: &'a GpuTensorHandle,
1674        _b: &'a GpuTensorHandle,
1675    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1676        unsupported_future("elem_div not supported by provider")
1677    }
1678    fn elem_pow<'a>(
1679        &'a self,
1680        _a: &'a GpuTensorHandle,
1681        _b: &'a GpuTensorHandle,
1682    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1683        unsupported_future("elem_pow not supported by provider")
1684    }
1685
1686    /// Construct complex-interleaved GPU storage from a real-valued tensor,
1687    /// using zero for the imaginary lane.
1688    fn complex_from_real<'a>(
1689        &'a self,
1690        _real: &'a GpuTensorHandle,
1691    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1692        unsupported_future("complex_from_real not supported by provider")
1693    }
1694
1695    /// Construct complex-interleaved GPU storage from real and imaginary tensors.
1696    ///
1697    /// Implementations should support equal shapes and scalar expansion for either
1698    /// operand, matching MATLAB's `complex(real, imag)` size rules.
1699    fn complex_from_real_imag<'a>(
1700        &'a self,
1701        _real: &'a GpuTensorHandle,
1702        _imag: &'a GpuTensorHandle,
1703    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1704        unsupported_future("complex_from_real_imag not supported by provider")
1705    }
1706
1707    /// Map a resident real-valued symbol tensor through a complex constellation
1708    /// table and return complex-interleaved GPU storage.
1709    fn modulate_constellation<'a>(
1710        &'a self,
1711        _request: ProviderModulationRequest<'a>,
1712    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1713        unsupported_future("modulate_constellation not supported by provider")
1714    }
1715
1716    /// Group a resident real/logical bit tensor into symbols, map through a complex
1717    /// constellation table, and return complex-interleaved GPU storage.
1718    fn modulate_bits_constellation<'a>(
1719        &'a self,
1720        _request: ProviderBitModulationRequest<'a>,
1721    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1722        unsupported_future("modulate_bits_constellation not supported by provider")
1723    }
1724
1725    fn elem_hypot<'a>(
1726        &'a self,
1727        _a: &'a GpuTensorHandle,
1728        _b: &'a GpuTensorHandle,
1729    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1730        unsupported_future("elem_hypot not supported by provider")
1731    }
1732    fn elem_ge<'a>(
1733        &'a self,
1734        _a: &'a GpuTensorHandle,
1735        _b: &'a GpuTensorHandle,
1736    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1737        unsupported_future("elem_ge not supported by provider")
1738    }
1739    fn elem_le<'a>(
1740        &'a self,
1741        _a: &'a GpuTensorHandle,
1742        _b: &'a GpuTensorHandle,
1743    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1744        unsupported_future("elem_le not supported by provider")
1745    }
1746    fn elem_lt<'a>(
1747        &'a self,
1748        _a: &'a GpuTensorHandle,
1749        _b: &'a GpuTensorHandle,
1750    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1751        unsupported_future("elem_lt not supported by provider")
1752    }
1753    fn elem_gt<'a>(
1754        &'a self,
1755        _a: &'a GpuTensorHandle,
1756        _b: &'a GpuTensorHandle,
1757    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1758        unsupported_future("elem_gt not supported by provider")
1759    }
1760    fn elem_eq<'a>(
1761        &'a self,
1762        _a: &'a GpuTensorHandle,
1763        _b: &'a GpuTensorHandle,
1764    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1765        unsupported_future("elem_eq not supported by provider")
1766    }
1767    fn elem_ne<'a>(
1768        &'a self,
1769        _a: &'a GpuTensorHandle,
1770        _b: &'a GpuTensorHandle,
1771    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1772        unsupported_future("elem_ne not supported by provider")
1773    }
1774    fn logical_and(
1775        &self,
1776        _a: &GpuTensorHandle,
1777        _b: &GpuTensorHandle,
1778    ) -> anyhow::Result<GpuTensorHandle> {
1779        Err(anyhow::anyhow!("logical_and not supported by provider"))
1780    }
1781    fn logical_or(
1782        &self,
1783        _a: &GpuTensorHandle,
1784        _b: &GpuTensorHandle,
1785    ) -> anyhow::Result<GpuTensorHandle> {
1786        Err(anyhow::anyhow!("logical_or not supported by provider"))
1787    }
1788    fn logical_xor(
1789        &self,
1790        _a: &GpuTensorHandle,
1791        _b: &GpuTensorHandle,
1792    ) -> anyhow::Result<GpuTensorHandle> {
1793        Err(anyhow::anyhow!("logical_xor not supported by provider"))
1794    }
1795    fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1796        Err(anyhow::anyhow!("logical_not not supported by provider"))
1797    }
1798    fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
1799        Ok(handle_is_logical(a))
1800    }
1801    fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1802        Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1803    }
1804    fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1805        Err(anyhow::anyhow!(
1806            "logical_isfinite not supported by provider"
1807        ))
1808    }
1809    fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1810        Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1811    }
1812    fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1813        Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1814    }
1815    fn elem_atan2<'a>(
1816        &'a self,
1817        _y: &'a GpuTensorHandle,
1818        _x: &'a GpuTensorHandle,
1819    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1820        unsupported_future("elem_atan2 not supported by provider")
1821    }
1822    // Unary elementwise operations (optional)
1823    fn unary_sin<'a>(
1824        &'a self,
1825        _a: &'a GpuTensorHandle,
1826    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1827        unsupported_future("unary_sin not supported by provider")
1828    }
1829    fn unary_sinc<'a>(
1830        &'a self,
1831        _a: &'a GpuTensorHandle,
1832    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1833        unsupported_future("unary_sinc not supported by provider")
1834    }
1835    fn unary_gamma<'a>(
1836        &'a self,
1837        _a: &'a GpuTensorHandle,
1838    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1839        unsupported_future("unary_gamma not supported by provider")
1840    }
1841    fn unary_factorial<'a>(
1842        &'a self,
1843        _a: &'a GpuTensorHandle,
1844    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1845        unsupported_future("unary_factorial not supported by provider")
1846    }
1847    fn unary_asinh<'a>(
1848        &'a self,
1849        _a: &'a GpuTensorHandle,
1850    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1851        unsupported_future("unary_asinh not supported by provider")
1852    }
1853    fn unary_sinh<'a>(
1854        &'a self,
1855        _a: &'a GpuTensorHandle,
1856    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1857        unsupported_future("unary_sinh not supported by provider")
1858    }
1859    fn unary_cosh<'a>(
1860        &'a self,
1861        _a: &'a GpuTensorHandle,
1862    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1863        unsupported_future("unary_cosh not supported by provider")
1864    }
1865    fn unary_asin<'a>(
1866        &'a self,
1867        _a: &'a GpuTensorHandle,
1868    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1869        unsupported_future("unary_asin not supported by provider")
1870    }
1871    fn unary_acos<'a>(
1872        &'a self,
1873        _a: &'a GpuTensorHandle,
1874    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1875        unsupported_future("unary_acos not supported by provider")
1876    }
1877    fn unary_acosh<'a>(
1878        &'a self,
1879        _a: &'a GpuTensorHandle,
1880    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1881        unsupported_future("unary_acosh not supported by provider")
1882    }
1883    fn unary_tan<'a>(
1884        &'a self,
1885        _a: &'a GpuTensorHandle,
1886    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1887        unsupported_future("unary_tan not supported by provider")
1888    }
1889    fn unary_tanh<'a>(
1890        &'a self,
1891        _a: &'a GpuTensorHandle,
1892    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1893        unsupported_future("unary_tanh not supported by provider")
1894    }
1895    fn unary_atan<'a>(
1896        &'a self,
1897        _a: &'a GpuTensorHandle,
1898    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1899        unsupported_future("unary_atan not supported by provider")
1900    }
1901    fn unary_atanh<'a>(
1902        &'a self,
1903        _a: &'a GpuTensorHandle,
1904    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1905        unsupported_future("unary_atanh not supported by provider")
1906    }
1907    fn unary_ceil<'a>(
1908        &'a self,
1909        _a: &'a GpuTensorHandle,
1910    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1911        unsupported_future("unary_ceil not supported by provider")
1912    }
1913    fn unary_floor<'a>(
1914        &'a self,
1915        _a: &'a GpuTensorHandle,
1916    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1917        unsupported_future("unary_floor not supported by provider")
1918    }
1919    fn unary_round<'a>(
1920        &'a self,
1921        _a: &'a GpuTensorHandle,
1922    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1923        unsupported_future("unary_round not supported by provider")
1924    }
1925    fn unary_fix<'a>(
1926        &'a self,
1927        _a: &'a GpuTensorHandle,
1928    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1929        unsupported_future("unary_fix not supported by provider")
1930    }
1931    fn unary_cos<'a>(
1932        &'a self,
1933        _a: &'a GpuTensorHandle,
1934    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1935        unsupported_future("unary_cos not supported by provider")
1936    }
1937    fn unary_angle<'a>(
1938        &'a self,
1939        _a: &'a GpuTensorHandle,
1940    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1941        unsupported_future("unary_angle not supported by provider")
1942    }
1943    fn unary_imag<'a>(
1944        &'a self,
1945        _a: &'a GpuTensorHandle,
1946    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1947        unsupported_future("unary_imag not supported by provider")
1948    }
1949    fn unary_real<'a>(
1950        &'a self,
1951        _a: &'a GpuTensorHandle,
1952    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1953        unsupported_future("unary_real not supported by provider")
1954    }
1955    fn unary_conj<'a>(
1956        &'a self,
1957        _a: &'a GpuTensorHandle,
1958    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1959        unsupported_future("unary_conj not supported by provider")
1960    }
1961    fn unary_abs<'a>(
1962        &'a self,
1963        _a: &'a GpuTensorHandle,
1964    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1965        unsupported_future("unary_abs not supported by provider")
1966    }
1967    fn unary_sign<'a>(
1968        &'a self,
1969        _a: &'a GpuTensorHandle,
1970    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1971        unsupported_future("unary_sign not supported by provider")
1972    }
1973    fn unary_heaviside<'a>(
1974        &'a self,
1975        _a: &'a GpuTensorHandle,
1976    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1977        unsupported_future("unary_heaviside not supported by provider")
1978    }
1979    fn unary_exp<'a>(
1980        &'a self,
1981        _a: &'a GpuTensorHandle,
1982    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1983        unsupported_future("unary_exp not supported by provider")
1984    }
1985    fn unary_expm1<'a>(
1986        &'a self,
1987        _a: &'a GpuTensorHandle,
1988    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1989        unsupported_future("unary_expm1 not supported by provider")
1990    }
1991    fn unary_log<'a>(
1992        &'a self,
1993        _a: &'a GpuTensorHandle,
1994    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1995        unsupported_future("unary_log not supported by provider")
1996    }
1997    fn unary_log2<'a>(
1998        &'a self,
1999        _a: &'a GpuTensorHandle,
2000    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2001        unsupported_future("unary_log2 not supported by provider")
2002    }
2003    fn unary_log10<'a>(
2004        &'a self,
2005        _a: &'a GpuTensorHandle,
2006    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2007        unsupported_future("unary_log10 not supported by provider")
2008    }
2009    fn unary_log1p<'a>(
2010        &'a self,
2011        _a: &'a GpuTensorHandle,
2012    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2013        unsupported_future("unary_log1p not supported by provider")
2014    }
2015    fn unary_sqrt<'a>(
2016        &'a self,
2017        _a: &'a GpuTensorHandle,
2018    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2019        unsupported_future("unary_sqrt not supported by provider")
2020    }
2021    fn unary_double<'a>(
2022        &'a self,
2023        _a: &'a GpuTensorHandle,
2024    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2025        unsupported_future("unary_double not supported by provider")
2026    }
2027    fn unary_single<'a>(
2028        &'a self,
2029        _a: &'a GpuTensorHandle,
2030    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2031        unsupported_future("unary_single not supported by provider")
2032    }
2033    fn unary_pow2<'a>(
2034        &'a self,
2035        _a: &'a GpuTensorHandle,
2036    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2037        unsupported_future("unary_pow2 not supported by provider")
2038    }
2039    fn unary_nextpow2<'a>(
2040        &'a self,
2041        _a: &'a GpuTensorHandle,
2042    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2043        unsupported_future("unary_nextpow2 not supported by provider")
2044    }
2045    fn pow2_scale(
2046        &self,
2047        _mantissa: &GpuTensorHandle,
2048        _exponent: &GpuTensorHandle,
2049    ) -> anyhow::Result<GpuTensorHandle> {
2050        Err(anyhow::anyhow!("pow2_scale not supported by provider"))
2051    }
2052    // Left-scalar operations (broadcast with scalar on the left)
2053    fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2054        Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
2055    }
2056    fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2057        Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
2058    }
2059    // Scalar operations: apply op with scalar right-hand side (broadcast over a)
2060    fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2061        Err(anyhow::anyhow!("scalar_add not supported by provider"))
2062    }
2063    fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2064        Err(anyhow::anyhow!("scalar_sub not supported by provider"))
2065    }
2066    fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2067        Err(anyhow::anyhow!("scalar_mul not supported by provider"))
2068    }
2069    fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2070        Err(anyhow::anyhow!("scalar_max not supported by provider"))
2071    }
2072    fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2073        Err(anyhow::anyhow!("scalar_min not supported by provider"))
2074    }
2075    fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2076        Err(anyhow::anyhow!("scalar_div not supported by provider"))
2077    }
2078    fn sort_dim<'a>(
2079        &'a self,
2080        _a: &'a GpuTensorHandle,
2081        _dim: usize,
2082        _order: SortOrder,
2083        _comparison: SortComparison,
2084    ) -> AccelProviderFuture<'a, SortResult> {
2085        unsupported_future("sort_dim not supported by provider")
2086    }
2087    fn sort_rows<'a>(
2088        &'a self,
2089        _a: &'a GpuTensorHandle,
2090        _columns: &'a [SortRowsColumnSpec],
2091        _comparison: SortComparison,
2092    ) -> AccelProviderFuture<'a, SortResult> {
2093        unsupported_future("sort_rows not supported by provider")
2094    }
2095    fn matmul<'a>(
2096        &'a self,
2097        _a: &'a GpuTensorHandle,
2098        _b: &'a GpuTensorHandle,
2099    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2100        unsupported_future("matmul not supported by provider")
2101    }
2102
2103    fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2104        Err(anyhow::anyhow!("syrk not supported by provider"))
2105    }
2106    fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
2107        Err(anyhow::anyhow!("pagefun not supported by provider"))
2108    }
2109
2110    /// Optional: matrix multiplication with an epilogue applied before store.
2111    ///
2112    /// The default implementation falls back to `matmul` when the epilogue is effectively a no-op
2113    /// (alpha=1, beta=0, no row/col scales), and otherwise returns `Err`.
2114    fn matmul_epilogue<'a>(
2115        &'a self,
2116        a: &'a GpuTensorHandle,
2117        b: &'a GpuTensorHandle,
2118        epilogue: &'a MatmulEpilogue,
2119    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2120        Box::pin(async move {
2121            if epilogue.is_noop() {
2122                return self.matmul(a, b).await;
2123            }
2124            Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
2125        })
2126    }
2127    fn image_normalize<'a>(
2128        &'a self,
2129        _input: &'a GpuTensorHandle,
2130        _desc: &'a ImageNormalizeDescriptor,
2131    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2132        unsupported_future("image_normalize fusion not supported by provider")
2133    }
2134    fn matmul_power_step<'a>(
2135        &'a self,
2136        _lhs: &'a GpuTensorHandle,
2137        _rhs: &'a GpuTensorHandle,
2138        _epilogue: &'a PowerStepEpilogue,
2139    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2140        unsupported_future("matmul_power_step normalization not supported by provider")
2141    }
2142    fn linsolve<'a>(
2143        &'a self,
2144        _lhs: &'a GpuTensorHandle,
2145        _rhs: &'a GpuTensorHandle,
2146        _options: &'a ProviderLinsolveOptions,
2147    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
2148        unsupported_future("linsolve not supported by provider")
2149    }
2150    fn inv<'a>(
2151        &'a self,
2152        _matrix: &'a GpuTensorHandle,
2153        _options: ProviderInvOptions,
2154    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2155        unsupported_future("inv not supported by provider")
2156    }
2157    fn pinv<'a>(
2158        &'a self,
2159        _matrix: &'a GpuTensorHandle,
2160        _options: ProviderPinvOptions,
2161    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2162        unsupported_future("pinv not supported by provider")
2163    }
2164    fn cond<'a>(
2165        &'a self,
2166        _matrix: &'a GpuTensorHandle,
2167        _norm: ProviderCondNorm,
2168    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2169        Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
2170    }
2171    fn norm<'a>(
2172        &'a self,
2173        _tensor: &'a GpuTensorHandle,
2174        _order: ProviderNormOrder,
2175    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2176        Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
2177    }
2178    fn rank<'a>(
2179        &'a self,
2180        _matrix: &'a GpuTensorHandle,
2181        _tolerance: Option<f64>,
2182    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2183        Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
2184    }
2185    fn rcond<'a>(
2186        &'a self,
2187        _matrix: &'a GpuTensorHandle,
2188    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2189        Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
2190    }
2191    fn mldivide<'a>(
2192        &'a self,
2193        _lhs: &'a GpuTensorHandle,
2194        _rhs: &'a GpuTensorHandle,
2195    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2196        Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
2197    }
2198    fn mrdivide<'a>(
2199        &'a self,
2200        _lhs: &'a GpuTensorHandle,
2201        _rhs: &'a GpuTensorHandle,
2202    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2203        Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
2204    }
2205    fn eig<'a>(
2206        &'a self,
2207        _a: &'a GpuTensorHandle,
2208        _compute_left: bool,
2209    ) -> AccelProviderFuture<'a, ProviderEigResult> {
2210        Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
2211    }
2212    fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
2213        Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
2214    }
2215
2216    fn chol<'a>(
2217        &'a self,
2218        _a: &'a GpuTensorHandle,
2219        _lower: bool,
2220    ) -> AccelProviderFuture<'a, ProviderCholResult> {
2221        Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
2222    }
2223    fn qr<'a>(
2224        &'a self,
2225        _a: &'a GpuTensorHandle,
2226        _options: ProviderQrOptions,
2227    ) -> AccelProviderFuture<'a, ProviderQrResult> {
2228        Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
2229    }
2230    fn take_matmul_sources(
2231        &self,
2232        _product: &GpuTensorHandle,
2233    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
2234        None
2235    }
2236    fn qr_power_iter<'a>(
2237        &'a self,
2238        product: &'a GpuTensorHandle,
2239        _product_lhs: Option<&'a GpuTensorHandle>,
2240        q_handle: &'a GpuTensorHandle,
2241        options: &'a ProviderQrOptions,
2242    ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
2243        let _ = (product, q_handle, options);
2244        Box::pin(async move { Ok(None) })
2245    }
2246    fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2247        Err(anyhow::anyhow!("transpose not supported by provider"))
2248    }
2249    fn conv1d(
2250        &self,
2251        _signal: &GpuTensorHandle,
2252        _kernel: &GpuTensorHandle,
2253        _options: ProviderConv1dOptions,
2254    ) -> anyhow::Result<GpuTensorHandle> {
2255        Err(anyhow::anyhow!("conv1d not supported by provider"))
2256    }
2257    fn conv2d(
2258        &self,
2259        _signal: &GpuTensorHandle,
2260        _kernel: &GpuTensorHandle,
2261        _mode: ProviderConvMode,
2262    ) -> anyhow::Result<GpuTensorHandle> {
2263        Err(anyhow::anyhow!("conv2d not supported by provider"))
2264    }
2265    fn iir_filter<'a>(
2266        &'a self,
2267        _b: &'a GpuTensorHandle,
2268        _a: &'a GpuTensorHandle,
2269        _x: &'a GpuTensorHandle,
2270        _options: ProviderIirFilterOptions,
2271    ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
2272        Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
2273    }
2274    fn uniform_spectral_estimate<'a>(
2275        &'a self,
2276        _request: &'a ProviderSpectralRequest<'a>,
2277    ) -> AccelProviderFuture<'a, ProviderSpectralResult> {
2278        unsupported_future("uniform_spectral_estimate not supported by provider")
2279    }
2280    fn signal_envelope<'a>(
2281        &'a self,
2282        _request: &'a ProviderEnvelopeRequest<'a>,
2283    ) -> AccelProviderFuture<'a, ProviderEnvelopeResult> {
2284        unsupported_future("signal_envelope not supported by provider")
2285    }
2286    fn signal_hilbert<'a>(
2287        &'a self,
2288        _request: &'a ProviderHilbertRequest<'a>,
2289    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2290        unsupported_future("signal_hilbert not supported by provider")
2291    }
2292    /// Reorder tensor dimensions according to `order`, expressed as zero-based indices.
2293    fn permute(
2294        &self,
2295        _handle: &GpuTensorHandle,
2296        _order: &[usize],
2297    ) -> anyhow::Result<GpuTensorHandle> {
2298        Err(anyhow::anyhow!("permute not supported by provider"))
2299    }
2300    fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
2301        Err(anyhow::anyhow!("flip not supported by provider"))
2302    }
2303    fn circshift(
2304        &self,
2305        _handle: &GpuTensorHandle,
2306        _shifts: &[isize],
2307    ) -> anyhow::Result<GpuTensorHandle> {
2308        Err(anyhow::anyhow!("circshift not supported by provider"))
2309    }
2310    fn diff_dim(
2311        &self,
2312        _handle: &GpuTensorHandle,
2313        _order: usize,
2314        _dim: usize,
2315    ) -> anyhow::Result<GpuTensorHandle> {
2316        Err(anyhow::anyhow!("diff_dim not supported by provider"))
2317    }
2318    fn gradient_dim(
2319        &self,
2320        _handle: &GpuTensorHandle,
2321        _dim: usize,
2322        _spacing: f64,
2323    ) -> anyhow::Result<GpuTensorHandle> {
2324        Err(anyhow::anyhow!("gradient_dim not supported by provider"))
2325    }
2326    /// Perform an in-place FFT along a zero-based dimension, optionally padding/truncating to `len`.
2327    fn fft_dim<'a>(
2328        &'a self,
2329        _handle: &'a GpuTensorHandle,
2330        _len: Option<usize>,
2331        _dim: usize,
2332    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2333        unsupported_future("fft_dim not supported by provider")
2334    }
2335    fn ifft_dim<'a>(
2336        &'a self,
2337        _handle: &'a GpuTensorHandle,
2338        _len: Option<usize>,
2339        _dim: usize,
2340    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2341        unsupported_future("ifft_dim not supported by provider")
2342    }
2343    fn fft_extract_real<'a>(
2344        &'a self,
2345        _handle: &'a GpuTensorHandle,
2346    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2347        unsupported_future("fft_extract_real not supported by provider")
2348    }
2349    fn unique<'a>(
2350        &'a self,
2351        _handle: &'a GpuTensorHandle,
2352        _options: &'a UniqueOptions,
2353    ) -> AccelProviderFuture<'a, UniqueResult> {
2354        Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
2355    }
2356    fn union<'a>(
2357        &'a self,
2358        _a: &'a GpuTensorHandle,
2359        _b: &'a GpuTensorHandle,
2360        _options: &'a UnionOptions,
2361    ) -> AccelProviderFuture<'a, UnionResult> {
2362        Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
2363    }
2364    fn setdiff<'a>(
2365        &'a self,
2366        _a: &'a GpuTensorHandle,
2367        _b: &'a GpuTensorHandle,
2368        _options: &'a SetdiffOptions,
2369    ) -> AccelProviderFuture<'a, SetdiffResult> {
2370        Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
2371    }
2372    fn ismember<'a>(
2373        &'a self,
2374        _a: &'a GpuTensorHandle,
2375        _b: &'a GpuTensorHandle,
2376        _options: &'a IsMemberOptions,
2377    ) -> AccelProviderFuture<'a, IsMemberResult> {
2378        Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
2379    }
2380    fn reshape(
2381        &self,
2382        handle: &GpuTensorHandle,
2383        new_shape: &[usize],
2384    ) -> anyhow::Result<GpuTensorHandle> {
2385        let mut updated = handle.clone();
2386        updated.shape = new_shape.to_vec();
2387        Ok(updated)
2388    }
2389    /// Concatenate the provided tensors along the 1-based dimension `dim`.
2390    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
2391        Err(anyhow::anyhow!("cat not supported by provider"))
2392    }
2393    fn repmat(
2394        &self,
2395        _handle: &GpuTensorHandle,
2396        _reps: &[usize],
2397    ) -> anyhow::Result<GpuTensorHandle> {
2398        Err(anyhow::anyhow!("repmat not supported by provider"))
2399    }
2400    /// Compute the Kronecker product of two tensors, matching MATLAB semantics.
2401    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2402        Err(anyhow::anyhow!("kron not supported by provider"))
2403    }
2404    /// Compute the cross product of 3-element vectors along a matching dimension.
2405    fn cross(
2406        &self,
2407        _lhs: &GpuTensorHandle,
2408        _rhs: &GpuTensorHandle,
2409        _dim: Option<usize>,
2410    ) -> anyhow::Result<GpuTensorHandle> {
2411        Err(anyhow::anyhow!("cross not supported by provider"))
2412    }
2413    fn reduce_sum<'a>(
2414        &'a self,
2415        _a: &'a GpuTensorHandle,
2416    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2417        unsupported_future("reduce_sum not supported by provider")
2418    }
2419    fn reduce_sum_dim<'a>(
2420        &'a self,
2421        _a: &'a GpuTensorHandle,
2422        _dim: usize,
2423    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2424        unsupported_future("reduce_sum_dim not supported by provider")
2425    }
2426    fn dot<'a>(
2427        &'a self,
2428        _lhs: &'a GpuTensorHandle,
2429        _rhs: &'a GpuTensorHandle,
2430        _dim: Option<usize>,
2431    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2432        unsupported_future("dot not supported by provider")
2433    }
2434    fn reduce_nnz<'a>(
2435        &'a self,
2436        _a: &'a GpuTensorHandle,
2437    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2438        unsupported_future("reduce_nnz not supported by provider")
2439    }
2440    fn reduce_nnz_dim<'a>(
2441        &'a self,
2442        _a: &'a GpuTensorHandle,
2443        _dim: usize,
2444    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2445        unsupported_future("reduce_nnz_dim not supported by provider")
2446    }
2447    fn reduce_prod<'a>(
2448        &'a self,
2449        _a: &'a GpuTensorHandle,
2450    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2451        unsupported_future("reduce_prod not supported by provider")
2452    }
2453    fn reduce_prod_dim<'a>(
2454        &'a self,
2455        _a: &'a GpuTensorHandle,
2456        _dim: usize,
2457    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2458        unsupported_future("reduce_prod_dim not supported by provider")
2459    }
2460    fn reduce_mean<'a>(
2461        &'a self,
2462        _a: &'a GpuTensorHandle,
2463    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2464        unsupported_future("reduce_mean not supported by provider")
2465    }
2466    /// Reduce mean across multiple zero-based dimensions in one device pass.
2467    fn reduce_mean_nd<'a>(
2468        &'a self,
2469        _a: &'a GpuTensorHandle,
2470        _dims_zero_based: &'a [usize],
2471    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2472        unsupported_future("reduce_mean_nd not supported by provider")
2473    }
2474    /// Reduce moments across multiple zero-based dimensions in one device pass.
2475    /// Returns mean (E[x]) and mean of squares (E[x^2]).
2476    fn reduce_moments_nd<'a>(
2477        &'a self,
2478        _a: &'a GpuTensorHandle,
2479        _dims_zero_based: &'a [usize],
2480    ) -> AccelProviderFuture<'a, ProviderMoments2> {
2481        unsupported_future("reduce_moments_nd not supported by provider")
2482    }
2483    fn reduce_mean_dim<'a>(
2484        &'a self,
2485        _a: &'a GpuTensorHandle,
2486        _dim: usize,
2487    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2488        unsupported_future("reduce_mean_dim not supported by provider")
2489    }
2490    fn reduce_std<'a>(
2491        &'a self,
2492        _a: &'a GpuTensorHandle,
2493        _normalization: ProviderStdNormalization,
2494        _nan_mode: ProviderNanMode,
2495    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2496        unsupported_future("reduce_std not supported by provider")
2497    }
2498    fn reduce_std_dim<'a>(
2499        &'a self,
2500        _a: &'a GpuTensorHandle,
2501        _dim: usize,
2502        _normalization: ProviderStdNormalization,
2503        _nan_mode: ProviderNanMode,
2504    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2505        unsupported_future("reduce_std_dim not supported by provider")
2506    }
2507    fn reduce_any<'a>(
2508        &'a self,
2509        _a: &'a GpuTensorHandle,
2510        _omit_nan: bool,
2511    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2512        unsupported_future("reduce_any not supported by provider")
2513    }
2514    fn reduce_any_dim<'a>(
2515        &'a self,
2516        _a: &'a GpuTensorHandle,
2517        _dim: usize,
2518        _omit_nan: bool,
2519    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2520        unsupported_future("reduce_any_dim not supported by provider")
2521    }
2522    fn reduce_all<'a>(
2523        &'a self,
2524        _a: &'a GpuTensorHandle,
2525        _omit_nan: bool,
2526    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2527        unsupported_future("reduce_all not supported by provider")
2528    }
2529    fn reduce_all_dim<'a>(
2530        &'a self,
2531        _a: &'a GpuTensorHandle,
2532        _dim: usize,
2533        _omit_nan: bool,
2534    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2535        unsupported_future("reduce_all_dim not supported by provider")
2536    }
2537    fn reduce_median<'a>(
2538        &'a self,
2539        _a: &'a GpuTensorHandle,
2540    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2541        unsupported_future("reduce_median not supported by provider")
2542    }
2543    fn reduce_median_dim<'a>(
2544        &'a self,
2545        _a: &'a GpuTensorHandle,
2546        _dim: usize,
2547    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2548        unsupported_future("reduce_median_dim not supported by provider")
2549    }
2550    fn reduce_min<'a>(
2551        &'a self,
2552        _a: &'a GpuTensorHandle,
2553    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2554        unsupported_future("reduce_min not supported by provider")
2555    }
2556    fn reduce_min_dim<'a>(
2557        &'a self,
2558        _a: &'a GpuTensorHandle,
2559        _dim: usize,
2560    ) -> AccelProviderFuture<'a, ReduceDimResult> {
2561        unsupported_future("reduce_min_dim not supported by provider")
2562    }
2563    fn reduce_max<'a>(
2564        &'a self,
2565        _a: &'a GpuTensorHandle,
2566    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2567        unsupported_future("reduce_max not supported by provider")
2568    }
2569    fn reduce_max_dim<'a>(
2570        &'a self,
2571        _a: &'a GpuTensorHandle,
2572        _dim: usize,
2573    ) -> AccelProviderFuture<'a, ReduceDimResult> {
2574        unsupported_future("reduce_max_dim not supported by provider")
2575    }
2576    fn cumsum_scan(
2577        &self,
2578        _input: &GpuTensorHandle,
2579        _dim: usize,
2580        _direction: ProviderScanDirection,
2581        _nan_mode: ProviderNanMode,
2582    ) -> anyhow::Result<GpuTensorHandle> {
2583        Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
2584    }
2585    fn cumprod_scan(
2586        &self,
2587        _input: &GpuTensorHandle,
2588        _dim: usize,
2589        _direction: ProviderScanDirection,
2590        _nan_mode: ProviderNanMode,
2591    ) -> anyhow::Result<GpuTensorHandle> {
2592        Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
2593    }
2594    fn cummin_scan(
2595        &self,
2596        _input: &GpuTensorHandle,
2597        _dim: usize,
2598        _direction: ProviderScanDirection,
2599        _nan_mode: ProviderNanMode,
2600    ) -> anyhow::Result<ProviderCumminResult> {
2601        Err(anyhow::anyhow!("cummin_scan not supported by provider"))
2602    }
2603    fn cummax_scan(
2604        &self,
2605        _input: &GpuTensorHandle,
2606        _dim: usize,
2607        _direction: ProviderScanDirection,
2608        _nan_mode: ProviderNanMode,
2609    ) -> anyhow::Result<ProviderCummaxResult> {
2610        Err(anyhow::anyhow!("cummax_scan not supported by provider"))
2611    }
2612
2613    fn find(
2614        &self,
2615        _a: &GpuTensorHandle,
2616        _limit: Option<usize>,
2617        _direction: FindDirection,
2618    ) -> anyhow::Result<ProviderFindResult> {
2619        Err(anyhow::anyhow!("find not supported by provider"))
2620    }
2621
2622    fn fused_elementwise(
2623        &self,
2624        _shader: &str,
2625        _inputs: &[GpuTensorHandle],
2626        _output_shape: &[usize],
2627        _len: usize,
2628    ) -> anyhow::Result<GpuTensorHandle> {
2629        Err(anyhow::anyhow!(
2630            "fused_elementwise not supported by provider"
2631        ))
2632    }
2633
2634    /// Execute a single fused elementwise kernel that writes `num_outputs` output buffers in one
2635    /// dispatch. The shader is expected to declare `output0`, `output1`, … `output{N-1}` storage
2636    /// bindings (at binding indices `inputs.len()` through `inputs.len() + num_outputs - 1`) and a
2637    /// uniform `params` binding at `inputs.len() + num_outputs`.
2638    ///
2639    /// Providers that do not override this method fall back to calling `fused_elementwise` once
2640    /// per output, which preserves correctness at the cost of the O(N²) dispatch overhead this
2641    /// method is designed to eliminate.
2642    fn fused_elementwise_multi(
2643        &self,
2644        _shader: &str,
2645        _inputs: &[GpuTensorHandle],
2646        _output_shape: &[usize],
2647        _len: usize,
2648        _num_outputs: usize,
2649    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2650        Err(anyhow::anyhow!(
2651            "fused_elementwise_multi not supported by provider"
2652        ))
2653    }
2654
2655    /// Build a numeric tensor where NaNs in `a` are replaced with 0.0 (device side).
2656    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2657        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
2658    }
2659
2660    /// Build a numeric mask tensor with 1.0 where value is not NaN and 0.0 where value is NaN.
2661    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2662        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
2663    }
2664
2665    /// Generic fused reduction entrypoint.
2666    ///
2667    /// The shader is expected to implement a column-major reduction across `reduce_len` with
2668    /// `num_slices` independent slices (e.g., columns). Providers should create a uniform buffer
2669    /// compatible with the expected `Params/MParams` struct in the shader and dispatch
2670    /// `num_slices` workgroups with `workgroup_size` threads, or an equivalent strategy.
2671    #[allow(clippy::too_many_arguments)]
2672    fn fused_reduction(
2673        &self,
2674        _shader: &str,
2675        _inputs: &[GpuTensorHandle],
2676        _output_shape: &[usize],
2677        _reduce_len: usize,
2678        _num_slices: usize,
2679        _workgroup_size: u32,
2680        _flavor: ReductionFlavor,
2681    ) -> anyhow::Result<GpuTensorHandle> {
2682        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
2683    }
2684
    /// Optionally pre-compile commonly used pipelines to amortize first-dispatch costs.
    ///
    /// The default implementation does nothing.
    fn warmup(&self) {}
2687
    /// Returns (cache_hits, cache_misses) for fused pipeline cache, if supported.
    ///
    /// The default reports `(0, 0)`. The default `telemetry_snapshot` copies these two values
    /// into its `fusion_cache_hits` / `fusion_cache_misses` fields.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }
2692
    /// Returns the duration of the last provider warmup in milliseconds, if known.
    ///
    /// The default returns `None`, i.e. "unknown".
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }
2697
2698    /// Returns a snapshot of provider telemetry counters if supported.
2699    fn telemetry_snapshot(&self) -> ProviderTelemetry {
2700        let (hits, misses) = self.fused_cache_counters();
2701        ProviderTelemetry {
2702            fused_elementwise: ProviderDispatchStats::default(),
2703            fused_reduction: ProviderDispatchStats::default(),
2704            matmul: ProviderDispatchStats::default(),
2705            linsolve: ProviderDispatchStats::default(),
2706            mldivide: ProviderDispatchStats::default(),
2707            mrdivide: ProviderDispatchStats::default(),
2708            upload_bytes: 0,
2709            download_bytes: 0,
2710            solve_fallbacks: Vec::new(),
2711            fusion_cache_hits: hits,
2712            fusion_cache_misses: misses,
2713            bind_group_cache_hits: 0,
2714            bind_group_cache_misses: 0,
2715            bind_group_cache_by_layout: None,
2716            kernel_launches: Vec::new(),
2717        }
2718    }
2719
    /// Reset all telemetry counters maintained by the provider, if supported.
    ///
    /// The default implementation does nothing.
    fn reset_telemetry(&self) {}
2722
    /// Default reduction workgroup size the provider prefers.
    ///
    /// The default implementation returns 256.
    fn default_reduction_workgroup_size(&self) -> u32 {
        256
    }
2727
    /// Threshold above which provider will prefer two-pass reduction.
    ///
    /// The default implementation returns 1024. The unit is presumably the length of the
    /// reduced dimension (TODO confirm against callers).
    fn two_pass_threshold(&self) -> usize {
        1024
    }
2732
    /// Current two-pass mode preference (auto/forced on/off).
    ///
    /// The default implementation returns `ReductionTwoPassMode::Auto`.
    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
2737
2738    /// Fast-path: write a GPU column in a matrix from a GPU vector, returning a new handle.
2739    /// Expected: `values.shape == [rows, 1]` (or `[rows]`) and `col_index < cols`.
2740    fn scatter_column(
2741        &self,
2742        _matrix: &GpuTensorHandle,
2743        _col_index: usize,
2744        _values: &GpuTensorHandle,
2745    ) -> anyhow::Result<GpuTensorHandle> {
2746        Err(anyhow::anyhow!("scatter_column not supported by provider"))
2747    }
2748
2749    /// Fast-path: write a GPU row in a matrix from a GPU vector, returning a new handle.
2750    /// Expected: `values.shape == [1, cols]` (or `[cols]`) and `row_index < rows`.
2751    fn scatter_row(
2752        &self,
2753        _matrix: &GpuTensorHandle,
2754        _row_index: usize,
2755        _values: &GpuTensorHandle,
2756    ) -> anyhow::Result<GpuTensorHandle> {
2757        Err(anyhow::anyhow!("scatter_row not supported by provider"))
2758    }
2759
2760    fn sub2ind(
2761        &self,
2762        _dims: &[usize],
2763        _strides: &[usize],
2764        _inputs: &[&GpuTensorHandle],
2765        _scalar_mask: &[bool],
2766        _len: usize,
2767        _output_shape: &[usize],
2768    ) -> anyhow::Result<GpuTensorHandle> {
2769        Err(anyhow::anyhow!("sub2ind not supported by provider"))
2770    }
2771
    /// Returns true if the provider offers a device-side `ind2sub` implementation.
    ///
    /// The default returns `false`, which matches the default `ind2sub`, which always errors.
    fn supports_ind2sub(&self) -> bool {
        false
    }
2776
2777    /// Convert linear indices into per-dimension subscripts on the device.
2778    fn ind2sub(
2779        &self,
2780        _dims: &[usize],
2781        _strides: &[usize],
2782        _indices: &GpuTensorHandle,
2783        _total: usize,
2784        _len: usize,
2785        _output_shape: &[usize],
2786    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2787        Err(anyhow::anyhow!("ind2sub not supported by provider"))
2788    }
2789
2790    /// Determine if a matrix is symmetric (or skew-symmetric) without gathering it to the host.
2791    fn issymmetric(
2792        &self,
2793        _matrix: &GpuTensorHandle,
2794        _kind: ProviderSymmetryKind,
2795        _tolerance: f64,
2796    ) -> anyhow::Result<bool> {
2797        Err(anyhow::anyhow!(
2798            "issymmetric predicate not supported by provider"
2799        ))
2800    }
2801
    /// Determine if a matrix is Hermitian (or skew-Hermitian) without gathering it to the host.
    ///
    /// The default implementation returns a boxed future that resolves to an
    /// "ishermitian predicate not supported by provider" error. The error value is created
    /// inside the `async` block, i.e. when the future is first polled, not when this method is
    /// called.
    fn ishermitian<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _kind: ProviderHermitianKind,
        _tolerance: f64,
    ) -> AccelProviderFuture<'a, bool> {
        Box::pin(async move {
            Err(anyhow::anyhow!(
                "ishermitian predicate not supported by provider"
            ))
        })
    }
2815
2816    /// Inspect the bandwidth of a matrix without gathering it back to the host.
2817    fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2818        Err(anyhow::anyhow!("bandwidth not supported by provider"))
2819    }
2820
    /// Compute the symmetric reverse Cuthill-McKee permutation for the matrix.
    ///
    /// Implementations may execute on the device or gather to the host. The permutation should be
    /// returned as zero-based indices.
    ///
    /// The default implementation returns a boxed future that resolves to a
    /// "sym_rcm not supported by provider" error when polled.
    fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
        Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
    }
2828}
2829
/// Process-wide default provider: written by `register_provider` / `clear_provider` and read by
/// `provider` and `provider_for_device` when no more specific provider applies.
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
/// Per-device registry keyed by `AccelProvider::device_id()`; consulted first by
/// `provider_for_device`. Emptied by `clear_provider`.
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Source of ids handed out by `next_device_id`. It starts at 1 and increments by one per call,
/// wrapping on `u32` overflow as `fetch_add` does.
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2835
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    // Per-thread provider override. When set, `provider()` returns it in preference to
    // GLOBAL_PROVIDER. Accessed only through `replace_thread_provider` /
    // `current_thread_provider`.
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}
2840
#[cfg(target_arch = "wasm32")]
/// wasm stand-in for `THREAD_PROVIDER`. It is a single process-wide slot behind a mutex rather
/// than a true thread-local, so the "thread" override is shared by all callers on this target.
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| Mutex::new(None));
2844
2845#[cfg(not(target_arch = "wasm32"))]
2846fn replace_thread_provider(
2847    provider: Option<&'static dyn AccelProvider>,
2848) -> Option<&'static dyn AccelProvider> {
2849    THREAD_PROVIDER.with(|cell| {
2850        let prev = cell.get();
2851        cell.set(provider);
2852        prev
2853    })
2854}
2855
#[cfg(target_arch = "wasm32")]
/// wasm variant: swap the process-wide override slot for `provider` and return the previous value.
///
/// # Panics
/// Panics with "wasm provider mutex poisoned" if the slot's mutex is poisoned.
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut guard = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *guard, provider)
}
2867
2868#[cfg(not(target_arch = "wasm32"))]
2869fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2870    THREAD_PROVIDER.with(|cell| cell.get())
2871}
2872
#[cfg(target_arch = "wasm32")]
/// wasm variant: read (without clearing) the process-wide override slot.
///
/// # Panics
/// Panics with "wasm provider mutex poisoned" if the slot's mutex is poisoned.
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    let slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    *slot
}
2881
2882/// Register a global acceleration provider.
2883///
2884/// # Safety
2885/// - The caller must guarantee that `p` is valid for the entire program lifetime
2886///   (e.g., a `'static` singleton), as the runtime stores a raw reference globally.
2887/// - Concurrent callers must ensure registration happens once or is properly
2888///   synchronized; this function does not enforce thread-safety for re-registration.
2889pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
2890    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2891        *guard = Some(p);
2892    }
2893    register_provider_for_device(p.device_id(), p);
2894}
2895
2896unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
2897    if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
2898        guard.insert(device_id, provider);
2899    }
2900}
2901
2902pub fn provider() -> Option<&'static dyn AccelProvider> {
2903    if let Some(p) = current_thread_provider() {
2904        return Some(p);
2905    }
2906    GLOBAL_PROVIDER
2907        .read()
2908        .ok()
2909        .and_then(|guard| guard.as_ref().copied())
2910}
2911
2912/// Clear the globally registered provider. Intended for tests to ensure deterministic behaviour.
2913pub fn clear_provider() {
2914    replace_thread_provider(None);
2915    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2916        *guard = None;
2917    }
2918    if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2919        map.clear();
2920    }
2921}
2922
2923pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2924    if let Some(registered) = PROVIDER_REGISTRY
2925        .read()
2926        .ok()
2927        .and_then(|guard| guard.get(&device_id).copied())
2928    {
2929        return Some(registered);
2930    }
2931    if let Some(thread_provider) = current_thread_provider() {
2932        if thread_provider.device_id() == device_id {
2933            return Some(thread_provider);
2934        }
2935    }
2936    // Preserve legacy behavior: when no explicit per-device registration exists,
2937    // fall back to the globally active provider regardless of handle device id.
2938    GLOBAL_PROVIDER
2939        .read()
2940        .ok()
2941        .and_then(|guard| guard.as_ref().copied())
2942}
2943
2944pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
2945    provider_for_device(handle.device_id)
2946}
2947
2948pub fn spawn_handle_concurrency_for(handle: &GpuTensorHandle) -> Option<SpawnHandleConcurrency> {
2949    provider_for_handle(handle).map(AccelProvider::spawn_handle_concurrency)
2950}
2951
/// Hand out a fresh device id. Ids start at 1 and increase by one per call, wrapping on `u32`
/// overflow.
///
/// `Relaxed` ordering suffices because only the uniqueness of the returned value matters. The
/// counter does not order any other memory.
pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
2955
2956pub struct ThreadProviderGuard {
2957    prev: Option<&'static dyn AccelProvider>,
2958}
2959
2960impl ThreadProviderGuard {
2961    pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2962        let prev = replace_thread_provider(provider);
2963        ThreadProviderGuard { prev }
2964    }
2965}
2966
2967impl Drop for ThreadProviderGuard {
2968    fn drop(&mut self) {
2969        let prev = self.prev.take();
2970        replace_thread_provider(prev);
2971    }
2972}
2973
2974pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
2975    replace_thread_provider(provider);
2976}
2977
2978/// Convenience: perform elementwise add via provider if possible; otherwise return None
2979pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2980    if let Some(p) = provider() {
2981        if let Ok(h) = p.elem_add(a, b).await {
2982            return Some(h);
2983        }
2984    }
2985    None
2986}
2987
2988/// Convenience: perform elementwise hypot via provider if possible; otherwise return None
2989pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2990    if let Some(p) = provider() {
2991        if let Ok(h) = p.elem_hypot(a, b).await {
2992            return Some(h);
2993        }
2994    }
2995    None
2996}
2997
2998/// Convenience: perform elementwise max via provider if possible; otherwise return None
2999pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
3000    if let Some(p) = provider() {
3001        if let Ok(h) = p.elem_max(a, b).await {
3002            return Some(h);
3003        }
3004    }
3005    None
3006}
3007
3008/// Convenience: perform elementwise min via provider if possible; otherwise return None
3009pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
3010    if let Some(p) = provider() {
3011        if let Ok(h) = p.elem_min(a, b).await {
3012            return Some(h);
3013        }
3014    }
3015    None
3016}
3017
3018/// Convenience: perform elementwise atan2 via provider if possible; otherwise return None
3019pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
3020    if let Some(p) = provider() {
3021        if let Ok(h) = p.elem_atan2(y, x).await {
3022            return Some(h);
3023        }
3024    }
3025    None
3026}
3027
// Minimal host tensor views to avoid depending on runmat-builtins and cycles
/// Owned host-side tensor: flat `f64` data plus its shape and storage kind.
///
/// NOTE(review): the element order of `data` is presumably column-major, as used by the fused
/// reduction hooks. Confirm against producers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    /// Flat element buffer.
    pub data: Vec<f64>,
    /// Dimension sizes.
    pub shape: Vec<usize>,
    /// Storage kind associated with the tensor.
    pub storage: GpuTensorStorage,
}
3035
/// Borrowed host-side tensor view (flat data plus shape), e.g. as passed to
/// `AccelProvider::upload`.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    /// Flat element buffer; presumably `shape.iter().product()` elements (TODO confirm).
    pub data: &'a [f64],
    /// Dimension sizes.
    pub shape: &'a [usize],
}
3041
/// Lightweight 1-D axis view used by provider meshgrid hooks.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    /// Axis coordinate values.
    pub data: &'a [f64],
}
3047
/// Provider-side meshgrid result containing coordinate tensor handles.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    /// One device handle per produced coordinate tensor.
    pub outputs: Vec<GpuTensorHandle>,
}
3053
3054/// Descriptor for GEMM epilogues applied to `C = A * B` before storing to `C`.
3055///
3056/// Supported operations:
3057/// - Scale by `alpha` and add scalar `beta`.
3058/// - Multiply output by per-row and/or per-column scale vectors (broadcasted).
3059#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
3060pub enum ScaleOp {
3061    Multiply,
3062    Divide,
3063}
3064
3065#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
3066pub struct MatmulEpilogue {
3067    /// Scalar multiply applied to each output element.
3068    pub alpha: f64,
3069    /// Scalar add applied to each output element after scaling.
3070    pub beta: f64,
3071    /// Optional per-row scale (length m). When present, output[row, col] *= row_scale[row].
3072    pub row_scale: Option<GpuTensorHandle>,
3073    /// Optional per-column scale (length n). When present, output[row, col] *= col_scale[col].
3074    pub col_scale: Option<GpuTensorHandle>,
3075    /// Row scale operation (multiply or divide). Ignored when `row_scale` is None.
3076    pub row_op: ScaleOp,
3077    /// Column scale operation (multiply or divide). Ignored when `col_scale` is None.
3078    pub col_op: ScaleOp,
3079    /// Optional lower clamp bound applied after scale/bias.
3080    #[serde(default)]
3081    pub clamp_min: Option<f64>,
3082    /// Optional upper clamp bound applied after scale/bias.
3083    #[serde(default)]
3084    pub clamp_max: Option<f64>,
3085    /// Optional power exponent applied after clamp (final operation in the epilogue).
3086    #[serde(default)]
3087    pub pow_exponent: Option<f64>,
3088    /// Optional output buffer for the diagonal of the result (length min(m, n)).
3089    #[serde(default)]
3090    pub diag_output: Option<GpuTensorHandle>,
3091}
3092
3093impl MatmulEpilogue {
3094    pub fn noop() -> Self {
3095        Self {
3096            alpha: 1.0,
3097            beta: 0.0,
3098            row_scale: None,
3099            col_scale: None,
3100            row_op: ScaleOp::Multiply,
3101            col_op: ScaleOp::Multiply,
3102            clamp_min: None,
3103            clamp_max: None,
3104            pow_exponent: None,
3105            diag_output: None,
3106        }
3107    }
3108    pub fn is_noop(&self) -> bool {
3109        self.alpha == 1.0
3110            && self.beta == 0.0
3111            && self.row_scale.is_none()
3112            && self.col_scale.is_none()
3113            && self.clamp_min.is_none()
3114            && self.clamp_max.is_none()
3115            && self.pow_exponent.is_none()
3116            && self.diag_output.is_none()
3117    }
3118}
3119
3120#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
3121pub struct PowerStepEpilogue {
3122    pub epsilon: f64,
3123}
3124
3125impl Default for PowerStepEpilogue {
3126    fn default() -> Self {
3127        Self { epsilon: 0.0 }
3128    }
3129}
3130
/// Parameters for a provider-side image normalize operation.
///
/// When deserializing, the optional fields default to `None` if absent. `clamp_zero` defaults
/// to `true`, so older serialized descriptors keep their clamped behaviour.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    /// Number of images; presumably the leading tensor dimension (TODO confirm).
    pub batch: usize,
    /// Image height.
    pub height: usize,
    /// Image width.
    pub width: usize,
    /// Epsilon parameter; presumably a stabilizer in the normalization denominator (confirm).
    pub epsilon: f64,
    /// Optional gain; `None` when omitted.
    #[serde(default)]
    pub gain: Option<f64>,
    /// Optional bias; `None` when omitted.
    #[serde(default)]
    pub bias: Option<f64>,
    /// Optional gamma; `None` when omitted.
    #[serde(default)]
    pub gamma: Option<f64>,
    /// Whether to clamp at zero; `true` when omitted from serialized input.
    #[serde(default = "default_image_normalize_clamp_zero")]
    pub clamp_zero: bool,
}

/// Serde default for `ImageNormalizeDescriptor::clamp_zero`.
fn default_image_normalize_clamp_zero() -> bool {
    true
}
3150
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry-only provider. Every data-path method fails loudly, so a test that accidentally
    /// routes real work through it is caught immediately.
    struct TestProvider {
        device_id: u32,
        name: &'static str,
        spawn_concurrency: SpawnHandleConcurrency,
    }

    impl AccelProvider for TestProvider {
        fn upload(&self, _host: &HostTensorView) -> anyhow::Result<GpuTensorHandle> {
            Err(anyhow!("test provider upload should not be called"))
        }

        fn download<'a>(&'a self, _h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a> {
            unsupported_future("test provider download should not be called")
        }

        fn free(&self, _h: &GpuTensorHandle) -> anyhow::Result<()> {
            Err(anyhow!("test provider free should not be called"))
        }

        fn device_info(&self) -> String {
            self.name.to_string()
        }

        fn device_id(&self) -> u32 {
            self.device_id
        }

        fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
            self.spawn_concurrency
        }
    }

    /// Serializes tests that mutate the process-wide provider state.
    static PROVIDER_TEST_LOCK: Lazy<std::sync::Mutex<()>> = Lazy::new(|| std::sync::Mutex::new(()));
    static PROVIDER_A: TestProvider = TestProvider {
        device_id: 101,
        name: "provider-a",
        spawn_concurrency: SpawnHandleConcurrency::ImmutableShare,
    };
    static PROVIDER_B: TestProvider = TestProvider {
        device_id: 202,
        name: "provider-b",
        spawn_concurrency: SpawnHandleConcurrency::Reject,
    };
    static PROVIDER_C: TestProvider = TestProvider {
        device_id: 303,
        name: "provider-c",
        spawn_concurrency: SpawnHandleConcurrency::CopyOnWrite,
    };

    /// Exclusive access to the global provider state for the duration of one test.
    ///
    /// Acquiring takes `PROVIDER_TEST_LOCK` and clears all provider state. The lock is recovered
    /// if an earlier test panicked while holding it. Dropping clears the state again *before*
    /// the lock is released. Together these ensure a failing assertion can neither poison later
    /// tests nor leak its registrations into them.
    struct ProviderStateScope {
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl ProviderStateScope {
        fn acquire() -> Self {
            let lock = PROVIDER_TEST_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            clear_provider();
            Self { _lock: lock }
        }
    }

    impl Drop for ProviderStateScope {
        fn drop(&mut self) {
            // `Drop::drop` runs before the `_lock` field is dropped, so the lock is still held.
            clear_provider();
        }
    }

    fn register_test_providers() {
        clear_provider();
        // SAFETY: the test providers are `static` items and therefore valid for the whole
        // program; registration is serialized by `ProviderStateScope`.
        unsafe {
            register_provider(&PROVIDER_A);
            register_provider(&PROVIDER_B);
        }
    }

    fn test_handle(device_id: u32) -> GpuTensorHandle {
        GpuTensorHandle {
            shape: vec![1],
            device_id,
            buffer_id: 42,
        }
    }

    fn spectral_request<'a>(
        input: &'a GpuTensorHandle,
        frame_mode: ProviderSpectralFrameMode,
    ) -> ProviderSpectralRequest<'a> {
        static WINDOW: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        ProviderSpectralRequest {
            input,
            input_len: 16,
            input_complex: false,
            window: &WINDOW,
            nfft: 8,
            frame_count: 3,
            frame_mode,
            range: ProviderSpectralRange::Onesided,
            denominator: 1.0,
        }
    }

    #[test]
    fn provider_envelope_shape_guard_rejects_equal_len_layout_spoofing() {
        assert!(provider_envelope_input_shape_matches(&[2, 3], 2, 3));
        assert!(provider_envelope_input_shape_matches(&[6, 1], 6, 1));
        assert!(provider_envelope_input_shape_matches(&[1, 6], 6, 1));
        assert!(provider_envelope_input_shape_matches(&[6], 6, 1));

        assert!(!provider_envelope_input_shape_matches(&[3, 2], 2, 3));
        assert!(!provider_envelope_input_shape_matches(&[6], 2, 3));
        assert!(!provider_envelope_input_shape_matches(&[2, 1, 3], 2, 3));
    }

    #[test]
    fn provider_for_device_prefers_registered_device_over_thread_provider() {
        let _state = ProviderStateScope::acquire();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider = provider_for_device(PROVIDER_A.device_id()).expect("provider for device");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn provider_for_handle_uses_handle_device_owner() {
        let _state = ProviderStateScope::acquire();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider =
            provider_for_handle(&test_handle(PROVIDER_A.device_id())).expect("provider for handle");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn spawn_handle_concurrency_for_uses_registered_owner() {
        let _state = ProviderStateScope::acquire();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let concurrency = spawn_handle_concurrency_for(&test_handle(PROVIDER_A.device_id()))
            .expect("spawn concurrency");

        assert_eq!(concurrency, PROVIDER_A.spawn_concurrency);
    }

    #[test]
    fn provider_keeps_thread_local_active_provider_semantics() {
        let _state = ProviderStateScope::acquire();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_A));

        let active = provider().expect("active provider");

        assert_eq!(active.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn unregistered_thread_provider_only_matches_own_device_before_global_fallback() {
        let _state = ProviderStateScope::acquire();
        // SAFETY: `PROVIDER_A` is a `static` item, valid for the whole program.
        unsafe {
            register_provider(&PROVIDER_A);
        }
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_C));

        let own_device = provider_for_device(PROVIDER_C.device_id()).expect("own provider");
        let fallback = provider_for_device(404).expect("global fallback provider");

        assert_eq!(own_device.device_info(), PROVIDER_C.name);
        assert_eq!(fallback.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn uniform_spectral_request_validates_sliding_input_coverage() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(&input, ProviderSpectralFrameMode::Sliding { hop: 6 });
        assert!(validate_uniform_spectral_request(&request).is_ok());

        request.input_len = 15;
        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_rejects_sliding_coverage_overflow() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(&input, ProviderSpectralFrameMode::Sliding { hop: 2 });
        request.frame_count = usize::MAX;

        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_validates_folded_input_coverage() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(
            &input,
            ProviderSpectralFrameMode::FoldedColumns { input_rows: 5 },
        );
        assert!(validate_uniform_spectral_request(&request).is_ok());

        request.frame_mode = ProviderSpectralFrameMode::FoldedColumns { input_rows: 0 };
        assert!(validate_uniform_spectral_request(&request).is_err());

        request.frame_mode = ProviderSpectralFrameMode::FoldedColumns { input_rows: 6 };
        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_rejects_folded_coverage_overflow() {
        let input = test_handle(PROVIDER_A.device_id());
        let request = spectral_request(
            &input,
            ProviderSpectralFrameMode::FoldedColumns {
                input_rows: usize::MAX,
            },
        );

        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn image_normalize_descriptor_omitted_clamp_zero_defaults_true() {
        let payload = r#"{
            "batch": 2,
            "height": 4,
            "width": 5,
            "epsilon": 0.000001
        }"#;

        let desc: ImageNormalizeDescriptor =
            serde_json::from_str(payload).expect("deserialize descriptor");

        assert!(
            desc.clamp_zero,
            "legacy serialized descriptors should default to clamped image normalize"
        );
    }

    #[test]
    fn image_normalize_descriptor_explicit_false_preserves_unclamped() {
        let payload = r#"{
            "batch": 2,
            "height": 4,
            "width": 5,
            "epsilon": 0.000001,
            "clamp_zero": false
        }"#;

        let desc: ImageNormalizeDescriptor =
            serde_json::from_str(payload).expect("deserialize descriptor");

        assert!(
            !desc.clamp_zero,
            "explicit clamp_zero=false should preserve unclamped semantics"
        );
    }
}