Skip to main content

runmat_accelerate_api/
lib.rs

1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
16type ResidencyMarkFn = fn(&GpuTensorHandle);
17type ResidencyClearFn = fn(&GpuTensorHandle);
18type SequenceThresholdFn = fn() -> Option<usize>;
19type WorkgroupSizeHintFn = fn() -> Option<u32>;
20
21static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
22static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
23static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
24static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();
25
26static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
27static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
28    Lazy::new(|| RwLock::new(HashMap::new()));
29static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
30    Lazy::new(|| RwLock::new(HashMap::new()));
31
32static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
33    Lazy::new(|| RwLock::new(HashMap::new()));
34static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
35    Lazy::new(|| RwLock::new(HashMap::new()));
36
/// Base dimensions remembered for a handle flagged via [`record_handle_transpose`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TransposeInfo {
    /// `base_rows` value supplied when the transpose was recorded.
    pub base_rows: usize,
    /// `base_cols` value supplied when the transpose was recorded.
    pub base_cols: usize,
}
42
43/// Register a callback used to mark residency tracking when GPU tensors are
44/// created or returned by device-side execution paths.
45pub fn register_residency_mark(handler: ResidencyMarkFn) {
46    let _ = RESIDENCY_MARK.set(handler);
47}
48
49/// Mark residency metadata for the provided GPU tensor handle, if a backend
50/// has registered a handler via [`register_residency_mark`].
51pub fn mark_residency(handle: &GpuTensorHandle) {
52    if let Some(handler) = RESIDENCY_MARK.get() {
53        handler(handle);
54    }
55}
56
57/// Register a callback used to clear residency tracking when GPU tensors are
58/// gathered back to the host. Backends that maintain residency metadata should
59/// install this hook during initialization.
60pub fn register_residency_clear(handler: ResidencyClearFn) {
61    let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64/// Clear residency metadata for the provided GPU tensor handle, if a backend
65/// has registered a handler via [`register_residency_clear`].
66pub fn clear_residency(handle: &GpuTensorHandle) {
67    if let Some(handler) = RESIDENCY_CLEAR.get() {
68        handler(handle);
69    }
70}
71
72/// Register a callback that exposes the current sequence length threshold
73/// derived from the auto-offload planner. Array constructors can use this hint
74/// to decide when to prefer GPU residency automatically.
75pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76    let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79/// Query the currently registered sequence threshold hint, if any.
80pub fn sequence_threshold_hint() -> Option<usize> {
81    SEQUENCE_THRESHOLD_PROVIDER
82        .get()
83        .and_then(|provider| provider())
84}
85
86/// Register a callback that reports the calibrated workgroup size selected by
87/// the active acceleration provider (if any). Plotting kernels can reuse this
88/// hint to match backend tuning.
89pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90    let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93/// Query the current workgroup size hint exposed by the provider.
94pub fn workgroup_size_hint() -> Option<u32> {
95    WORKGROUP_SIZE_HINT_PROVIDER
96        .get()
97        .and_then(|provider| provider())
98}
99
100/// Export a shared acceleration context (e.g., the active WGPU device) when the
101/// current provider exposes one.
102pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103    provider().and_then(|p| p.export_context(kind))
104}
105
106/// Request a provider-owned WGPU buffer for zero-copy consumers. Returns `None`
107/// when the active provider does not expose buffers or does not support the
108/// supplied handle.
109#[cfg(feature = "wgpu")]
110pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
111    provider().and_then(|p| p.export_wgpu_buffer(handle))
112}
113
114/// Record the precision associated with a GPU tensor handle so host operations can
115/// reconstruct the original dtype when gathering back to the CPU.
116pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118        guard.insert(handle.buffer_id, precision);
119    }
120}
121
122/// Look up the recorded precision for a GPU tensor handle, if any.
123pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124    HANDLE_PRECISIONS
125        .read()
126        .ok()
127        .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130/// Clear any recorded precision metadata for a GPU tensor handle.
131pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132    if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133        guard.remove(&handle.buffer_id);
134    }
135}
136
137/// Annotate a GPU tensor handle as logically-typed (`logical` in MATLAB terms)
138/// or clear the logical flag when `logical` is `false`.
139pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
140    if let Ok(mut guard) = LOGICAL_HANDLES.write() {
141        if logical {
142            guard.insert(handle.buffer_id);
143            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
144                *hits.entry(handle.buffer_id).or_insert(0) += 1;
145            }
146        } else {
147            guard.remove(&handle.buffer_id);
148            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
149                hits.remove(&handle.buffer_id);
150            }
151        }
152    }
153}
154
155/// Convenience helper for clearing logical annotations explicitly.
156pub fn clear_handle_logical(handle: &GpuTensorHandle) {
157    set_handle_logical(handle, false);
158}
159
160/// Returns true when the supplied handle has been marked as logical.
161pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162    LOGICAL_HANDLES
163        .read()
164        .map(|guard| guard.contains(&handle.buffer_id))
165        .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169    LOGICAL_HANDLE_HITS
170        .read()
171        .ok()
172        .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177        guard.insert(
178            handle.buffer_id,
179            TransposeInfo {
180                base_rows,
181                base_cols,
182            },
183        );
184    }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188    if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189        guard.remove(&handle.buffer_id);
190    }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194    TRANSPOSED_HANDLES
195        .read()
196        .ok()
197        .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
200pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
201    handle_transpose_info(handle).is_some()
202}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
205pub enum GpuTensorStorage {
206    Real,
207    ComplexInterleaved,
208}
209
210impl Default for GpuTensorStorage {
211    fn default() -> Self {
212        Self::Real
213    }
214}
215
216#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
217pub struct GpuTensorHandle {
218    pub shape: Vec<usize>,
219    pub device_id: u32,
220    pub buffer_id: u64,
221}
222
223#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
224pub enum ProviderSpectralRange {
225    Onesided,
226    Twosided,
227    Centered,
228}
229
230#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
231pub enum ProviderSpectralFrameMode {
232    Sliding {
233        hop: usize,
234    },
235    ColumnSliding {
236        hop: usize,
237        input_rows: usize,
238        frames_per_column: usize,
239    },
240    FoldedColumns {
241        input_rows: usize,
242    },
243}
244
245#[derive(Clone, Debug)]
246pub struct ProviderSpectralRequest<'a> {
247    pub input: &'a GpuTensorHandle,
248    pub input_len: usize,
249    pub input_complex: bool,
250    pub window: &'a [f64],
251    pub nfft: usize,
252    pub frame_count: usize,
253    pub frame_mode: ProviderSpectralFrameMode,
254    pub range: ProviderSpectralRange,
255    pub denominator: f64,
256}
257
258#[derive(Clone, Debug)]
259pub struct ProviderSpectralResult {
260    pub s: GpuTensorHandle,
261    pub ps: GpuTensorHandle,
262    pub rows: usize,
263    pub cols: usize,
264}
265
266#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
267pub enum ProviderEnvelopeMethod {
268    Analytic,
269    AnalyticFir { filter_len: usize },
270    Rms { window_len: usize },
271}
272
273#[derive(Clone, Debug)]
274pub struct ProviderEnvelopeRequest<'a> {
275    pub input: &'a GpuTensorHandle,
276    pub channel_len: usize,
277    pub channel_count: usize,
278    pub output_shape: &'a [usize],
279    pub method: ProviderEnvelopeMethod,
280}
281
282#[derive(Clone, Debug)]
283pub struct ProviderEnvelopeResult {
284    pub upper: GpuTensorHandle,
285    pub lower: GpuTensorHandle,
286}
287
288#[derive(Clone, Debug)]
289pub struct ProviderHilbertRequest<'a> {
290    pub input: &'a GpuTensorHandle,
291    /// Optional FFT length along `dim`.
292    pub length: Option<usize>,
293    /// Zero-based transform dimension.
294    pub dim: usize,
295}
296
297#[derive(Clone, Debug)]
298pub struct ProviderModulationRequest<'a> {
299    pub input: &'a GpuTensorHandle,
300    /// `(real, imag)` pairs interleaved by symbol index.
301    pub constellation: &'a [f64],
302}
303
304#[derive(Clone, Debug)]
305pub struct ProviderBitModulationRequest<'a> {
306    pub input: &'a GpuTensorHandle,
307    /// Number of bit rows in the input before grouping.
308    pub input_rows: usize,
309    /// Number of input bits that form one output symbol.
310    pub bits_per_symbol: usize,
311    /// `(real, imag)` pairs interleaved by symbol index.
312    pub constellation: &'a [f64],
313}
314
315pub async fn uniform_spectral_estimate(
316    request: ProviderSpectralRequest<'_>,
317) -> anyhow::Result<ProviderSpectralResult> {
318    validate_uniform_spectral_request(&request)?;
319
320    let provider =
321        provider().ok_or_else(|| anyhow!("uniform_spectral_estimate: GPU provider unavailable"))?;
322    provider.uniform_spectral_estimate(&request).await
323}
324
325fn validate_uniform_spectral_request(request: &ProviderSpectralRequest<'_>) -> anyhow::Result<()> {
326    let invalid_frame_mode = matches!(
327        request.frame_mode,
328        ProviderSpectralFrameMode::Sliding { hop: 0 }
329            | ProviderSpectralFrameMode::ColumnSliding { hop: 0, .. }
330    );
331    let invalid_input_coverage = match request.frame_mode {
332        ProviderSpectralFrameMode::Sliding { hop } => {
333            let required = request
334                .frame_count
335                .checked_sub(1)
336                .and_then(|frames| frames.checked_mul(hop))
337                .and_then(|offset| offset.checked_add(request.window.len()));
338            required.is_none_or(|required| required > request.input_len)
339        }
340        ProviderSpectralFrameMode::ColumnSliding {
341            hop,
342            input_rows,
343            frames_per_column,
344        } => {
345            if input_rows == 0 || frames_per_column == 0 {
346                true
347            } else {
348                let last_frame = request.frame_count - 1;
349                let source_col = last_frame / frames_per_column;
350                let segment = last_frame % frames_per_column;
351                source_col
352                    .checked_mul(input_rows)
353                    .and_then(|base| {
354                        segment
355                            .checked_mul(hop)
356                            .and_then(|step| base.checked_add(step))
357                    })
358                    .and_then(|offset| offset.checked_add(request.window.len()))
359                    .is_none_or(|required| required > request.input_len)
360            }
361        }
362        ProviderSpectralFrameMode::FoldedColumns { input_rows } => {
363            input_rows == 0
364                || input_rows
365                    .checked_mul(request.frame_count)
366                    .is_none_or(|required| required > request.input_len)
367        }
368    };
369    if request.window.is_empty()
370        || request.nfft == 0
371        || request.frame_count == 0
372        || invalid_frame_mode
373        || invalid_input_coverage
374        || !request.denominator.is_finite()
375        || request.denominator <= 0.0
376    {
377        return Err(anyhow!("uniform_spectral_estimate: invalid request"));
378    }
379
380    Ok(())
381}
382
383pub async fn signal_envelope(
384    request: ProviderEnvelopeRequest<'_>,
385) -> anyhow::Result<ProviderEnvelopeResult> {
386    let expected_len = request
387        .channel_len
388        .checked_mul(request.channel_count)
389        .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
390    let output_len = request
391        .output_shape
392        .iter()
393        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
394        .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
395    let input_len = request
396        .input
397        .shape
398        .iter()
399        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
400        .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
401    if request.channel_len == 0
402        || request.channel_count == 0
403        || request.output_shape.is_empty()
404        || output_len != expected_len
405        || input_len != expected_len
406        || !provider_envelope_input_shape_matches(
407            &request.input.shape,
408            request.channel_len,
409            request.channel_count,
410        )
411    {
412        return Err(anyhow!("signal_envelope: invalid request"));
413    }
414
415    match request.method {
416        ProviderEnvelopeMethod::AnalyticFir { filter_len }
417        | ProviderEnvelopeMethod::Rms {
418            window_len: filter_len,
419        } if filter_len == 0 => return Err(anyhow!("signal_envelope: invalid request")),
420        _ => {}
421    }
422
423    let provider =
424        provider().ok_or_else(|| anyhow!("signal_envelope: GPU provider unavailable"))?;
425    provider.signal_envelope(&request).await
426}
427
/// Check that `shape` is an acceptable envelope input layout.
///
/// A single channel may be given as `[channel_len]`, `[channel_len, 1]` or
/// `[1, channel_len]`. Multiple channels must be exactly `[channel_len, channel_count]`.
fn provider_envelope_input_shape_matches(
    shape: &[usize],
    channel_len: usize,
    channel_count: usize,
) -> bool {
    match (channel_count, shape) {
        (1, &[len]) => len == channel_len,
        (1, &[rows, cols]) => {
            (rows == channel_len && cols == 1) || (rows == 1 && cols == channel_len)
        }
        (1, _) => false,
        (_, &[rows, cols]) => rows == channel_len && cols == channel_count,
        _ => false,
    }
}
445
446pub async fn signal_hilbert(
447    request: ProviderHilbertRequest<'_>,
448) -> anyhow::Result<GpuTensorHandle> {
449    if request.length == Some(0) {
450        return Err(anyhow!("signal_hilbert: invalid request"));
451    }
452    if request.dim >= request.input.shape.len() {
453        return Err(anyhow!("signal_hilbert: invalid request"));
454    }
455
456    let provider = provider().ok_or_else(|| anyhow!("signal_hilbert: GPU provider unavailable"))?;
457    provider.signal_hilbert(&request).await
458}
459
460#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
461pub struct ApiDeviceInfo {
462    pub device_id: u32,
463    pub name: String,
464    pub vendor: String,
465    pub memory_bytes: Option<u64>,
466    pub backend: Option<String>,
467}
468
469#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
470pub struct ReduceDimResult {
471    pub values: GpuTensorHandle,
472    pub indices: GpuTensorHandle,
473}
474
475#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
476pub struct ProviderCumminResult {
477    pub values: GpuTensorHandle,
478    pub indices: GpuTensorHandle,
479}
480
481/// Result payload returned by provider-side `cummax` scans.
482///
483/// Alias of [`ProviderCumminResult`] because both operations return the same pair of tensors
484/// (running values and MATLAB-compatible indices).
485pub type ProviderCummaxResult = ProviderCumminResult;
486
/// Kind of shared acceleration context a caller may request (e.g. plotting).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccelContextKind {
    /// Context for plotting consumers.
    Plotting,
}

/// Shared GPU context returned by [`export_context`].
///
/// Without the `wgpu` feature this enum has no variants.
#[derive(Clone)]
pub enum AccelContextHandle {
    /// Shared WGPU device/queue bundle.
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}

impl AccelContextHandle {
    /// Borrow the WGPU context carried by this handle, if it is one.
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        match self {
            Self::Wgpu(context) => Some(context),
        }
    }
}
509
/// Shared WGPU device/queue pair, plus instance, adapter and capability data, exported by
/// the acceleration provider.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    /// WGPU instance.
    pub instance: Arc<wgpu::Instance>,
    /// Logical device.
    pub device: Arc<wgpu::Device>,
    /// Submission queue for `device`.
    pub queue: Arc<wgpu::Queue>,
    /// Adapter the device was created from.
    pub adapter: Arc<wgpu::Adapter>,
    /// Adapter description.
    pub adapter_info: wgpu::AdapterInfo,
    /// Device limits.
    pub limits: wgpu::Limits,
    /// Device features.
    pub features: wgpu::Features,
}

/// Borrowed reference to a provider-owned WGPU buffer backing a `GpuTensorHandle`.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    /// Underlying buffer.
    pub buffer: Arc<wgpu::Buffer>,
    /// Buffer length (presumably in elements; confirm with the provider).
    pub len: usize,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Size of one element (presumably in bytes).
    pub element_size: usize,
    /// Element precision.
    pub precision: ProviderPrecision,
}
533
534pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
535    if let Ok(mut guard) = HANDLE_STORAGES.write() {
536        guard.insert(handle.buffer_id, storage);
537    }
538}
539
540pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
541    HANDLE_STORAGES
542        .read()
543        .ok()
544        .and_then(|guard| guard.get(&handle.buffer_id).cloned())
545        .unwrap_or(GpuTensorStorage::Real)
546}
547
548pub fn clear_handle_storage(handle: &GpuTensorHandle) {
549    if let Ok(mut guard) = HANDLE_STORAGES.write() {
550        guard.remove(&handle.buffer_id);
551    }
552}
553
554#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
555pub enum PagefunOp {
556    Mtimes,
557}
558
559#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
560pub struct PagefunRequest {
561    pub op: PagefunOp,
562    pub inputs: Vec<GpuTensorHandle>,
563    pub output_shape: Vec<usize>,
564    pub page_dims: Vec<usize>,
565    pub input_page_dims: Vec<Vec<usize>>,
566}
567
568#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
569pub enum FindDirection {
570    First,
571    Last,
572}
573
574#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
575pub struct ProviderFindResult {
576    pub linear: GpuTensorHandle,
577    pub rows: GpuTensorHandle,
578    pub cols: GpuTensorHandle,
579    pub values: Option<GpuTensorHandle>,
580}
581
582#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
583pub struct ProviderBandwidth {
584    pub lower: u32,
585    pub upper: u32,
586}
587
588#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
589pub enum ProviderSymmetryKind {
590    Symmetric,
591    Skew,
592}
593
594#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
595pub enum ProviderHermitianKind {
596    Hermitian,
597    Skew,
598}
599
600#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
601pub struct ProviderLuResult {
602    pub combined: GpuTensorHandle,
603    pub lower: GpuTensorHandle,
604    pub upper: GpuTensorHandle,
605    pub perm_matrix: GpuTensorHandle,
606    pub perm_vector: GpuTensorHandle,
607}
608
609#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
610pub struct ProviderCholResult {
611    pub factor: GpuTensorHandle,
612    /// MATLAB-compatible failure index (0 indicates success).
613    pub info: u32,
614}
615
616#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
617pub struct ProviderQrResult {
618    pub q: GpuTensorHandle,
619    pub r: GpuTensorHandle,
620    pub perm_matrix: GpuTensorHandle,
621    pub perm_vector: GpuTensorHandle,
622}
623
624#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
625pub struct ProviderQrPowerIterResult {
626    pub q: GpuTensorHandle,
627    pub r: GpuTensorHandle,
628    pub perm_matrix: GpuTensorHandle,
629    pub perm_vector: GpuTensorHandle,
630}
631
632#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
633pub struct ProviderLinsolveOptions {
634    pub lower: bool,
635    pub upper: bool,
636    pub rectangular: bool,
637    pub transposed: bool,
638    pub conjugate: bool,
639    pub symmetric: bool,
640    pub posdef: bool,
641    pub need_rcond: bool,
642    pub rcond: Option<f64>,
643}
644
645#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
646pub struct ProviderLinsolveResult {
647    pub solution: GpuTensorHandle,
648    pub reciprocal_condition: f64,
649}
650
651#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
652pub struct ProviderPinvOptions {
653    pub tolerance: Option<f64>,
654}
655
656#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
657pub struct ProviderPolyvalMu {
658    pub mean: f64,
659    pub scale: f64,
660}
661
662#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
663pub struct ProviderPolyvalOptions {
664    pub mu: Option<ProviderPolyvalMu>,
665}
666
667#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
668pub struct ProviderInvOptions {}
669
670#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
671pub struct ProviderPolyfitResult {
672    pub coefficients: Vec<f64>,
673    pub r_matrix: Vec<f64>,
674    pub normr: f64,
675    pub df: f64,
676    pub mu: [f64; 2],
677}
678
679/// Numerator/denominator payload returned by provider-backed `polyder` quotient rule.
680#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
681pub struct ProviderPolyderQuotient {
682    pub numerator: GpuTensorHandle,
683    pub denominator: GpuTensorHandle,
684}
685
686/// Supported norm specifications for the `cond` builtin.
687#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
688pub enum ProviderCondNorm {
689    Two,
690    One,
691    Inf,
692    Fro,
693}
694
695/// Supported norm orders for the `norm` builtin.
696#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
697pub enum ProviderNormOrder {
698    Two,
699    One,
700    Inf,
701    NegInf,
702    Zero,
703    Fro,
704    Nuc,
705    P(f64),
706}
707
708#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
709pub struct ProviderEigResult {
710    pub eigenvalues: GpuTensorHandle,
711    pub diagonal: GpuTensorHandle,
712    pub right: GpuTensorHandle,
713    pub left: Option<GpuTensorHandle>,
714}
715
716#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
717pub enum ProviderQrPivot {
718    Matrix,
719    Vector,
720}
721
722#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
723pub struct ProviderQrOptions {
724    pub economy: bool,
725    pub pivot: ProviderQrPivot,
726}
727
728impl Default for ProviderQrOptions {
729    fn default() -> Self {
730        Self {
731            economy: false,
732            pivot: ProviderQrPivot::Matrix,
733        }
734    }
735}
736
737#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
738pub enum ProviderPrecision {
739    F32,
740    F64,
741}
742
743/// Declares how provider-owned GPU handles may cross async spawn boundaries.
744///
745/// This is a runtime/provider policy surface (not a semantic type fact) used by
746/// VM/runtime spawn handling to prevent unsynchronized device-handle races.
747#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
748pub enum SpawnHandleConcurrency {
749    /// Provider supports immutable sharing of handle-backed values across spawned tasks.
750    ImmutableShare,
751    /// Provider supports copy-on-write semantics when spawned and parent tasks diverge.
752    CopyOnWrite,
753    /// Provider supports synchronized mutation for shared handles.
754    SynchronizedMutation,
755    /// Provider rejects spawned sharing of raw handles.
756    Reject,
757}
758
759impl SpawnHandleConcurrency {
760    pub fn as_str(self) -> &'static str {
761        match self {
762            SpawnHandleConcurrency::ImmutableShare => "immutable_share",
763            SpawnHandleConcurrency::CopyOnWrite => "copy_on_write",
764            SpawnHandleConcurrency::SynchronizedMutation => "synchronized_mutation",
765            SpawnHandleConcurrency::Reject => "reject",
766        }
767    }
768}
769
770#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
771pub enum ReductionTwoPassMode {
772    Auto,
773    ForceOn,
774    ForceOff,
775}
776
777impl ReductionTwoPassMode {
778    pub fn as_str(self) -> &'static str {
779        match self {
780            ReductionTwoPassMode::Auto => "auto",
781            ReductionTwoPassMode::ForceOn => "force_on",
782            ReductionTwoPassMode::ForceOff => "force_off",
783        }
784    }
785}
786
787#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
788pub enum ReductionFlavor {
789    Sum,
790    Mean,
791    CustomScale(f64),
792}
793
794impl ReductionFlavor {
795    pub fn is_mean(self) -> bool {
796        matches!(self, ReductionFlavor::Mean)
797    }
798
799    pub fn scale(self, reduce_len: usize) -> f64 {
800        match self {
801            ReductionFlavor::Sum => 1.0,
802            ReductionFlavor::Mean => {
803                if reduce_len == 0 {
804                    1.0
805                } else {
806                    1.0 / reduce_len as f64
807                }
808            }
809            ReductionFlavor::CustomScale(scale) => scale,
810        }
811    }
812}
813
814/// Normalisation mode for correlation coefficients.
815#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
816pub enum CorrcoefNormalization {
817    Unbiased,
818    Biased,
819}
820
821/// Row-selection strategy for correlation coefficients.
822#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
823pub enum CorrcoefRows {
824    All,
825    Complete,
826    Pairwise,
827}
828
829/// Options controlling provider-backed correlation coefficient computation.
830#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
831pub struct CorrcoefOptions {
832    pub normalization: CorrcoefNormalization,
833    pub rows: CorrcoefRows,
834}
835
836impl Default for CorrcoefOptions {
837    fn default() -> Self {
838        Self {
839            normalization: CorrcoefNormalization::Unbiased,
840            rows: CorrcoefRows::All,
841        }
842    }
843}
844
845/// Normalisation mode used by covariance computations.
846#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
847pub enum CovNormalization {
848    Unbiased,
849    Biased,
850}
851
852/// Row handling strategy for covariance computations.
853#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
854pub enum CovRows {
855    All,
856    OmitRows,
857    PartialRows,
858}
859
860/// Options controlling provider-backed covariance computation.
861#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
862pub struct CovarianceOptions {
863    pub normalization: CovNormalization,
864    pub rows: CovRows,
865    pub has_weight_vector: bool,
866}
867
868impl Default for CovarianceOptions {
869    fn default() -> Self {
870        Self {
871            normalization: CovNormalization::Unbiased,
872            rows: CovRows::All,
873            has_weight_vector: false,
874        }
875    }
876}
877
878/// Normalization strategy used by provider-backed standard deviation reductions.
879#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
880pub enum ProviderStdNormalization {
881    Sample,
882    Population,
883}
884
885/// NaN handling mode for provider-backed reductions.
886#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
887pub enum ProviderNanMode {
888    Include,
889    Omit,
890}
891
892/// Direction used when computing prefix sums on the device.
893#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
894pub enum ProviderScanDirection {
895    Forward,
896    Reverse,
897}
898
899/// Sort direction used by acceleration providers.
900#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
901pub enum SortOrder {
902    Ascend,
903    Descend,
904}
905
906/// Comparison strategy applied during sorting.
907#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
908pub enum SortComparison {
909    Auto,
910    Real,
911    Abs,
912}
913
914/// Host-resident outputs returned by provider-backed sort operations.
915#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
916pub struct SortResult {
917    pub values: HostTensorOwned,
918    pub indices: HostTensorOwned,
919}
920
921#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
922pub struct SortRowsColumnSpec {
923    pub index: usize,
924    pub order: SortOrder,
925}
926
927/// Ordering applied by provider-backed `unique` operations.
928#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
929pub enum UniqueOrder {
930    Sorted,
931    Stable,
932}
933
934/// Occurrence selection for provider-backed `unique` operations.
935#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
936pub enum UniqueOccurrence {
937    First,
938    Last,
939}
940
941/// Options controlling provider-backed `unique` operations.
942#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
943pub struct UniqueOptions {
944    pub rows: bool,
945    pub order: UniqueOrder,
946    pub occurrence: UniqueOccurrence,
947}
948
949/// Host-resident outputs returned by provider-backed `unique` operations.
950#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
951pub struct UniqueResult {
952    pub values: HostTensorOwned,
953    pub ia: HostTensorOwned,
954    pub ic: HostTensorOwned,
955}
956
/// Ordering applied by provider-backed `union` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    /// Return the combined values in sorted order.
    Sorted,
    /// Return the combined values in input order (presumably MATLAB's `'stable'`).
    Stable,
}

/// Options controlling provider-backed `union` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    /// Treat each row as one element (presumably MATLAB's `'rows'` flag).
    pub rows: bool,
    /// Ordering of the returned values.
    pub order: UnionOrder,
}

/// Host-resident outputs returned by provider-backed `union` operations.
///
/// NOTE(review): field names follow MATLAB's `[C, ia, ib] = union(A, B)` —
/// confirm index base with providers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    /// Union of both inputs (`C`).
    pub values: HostTensorOwned,
    /// Indices into the first input contributing to `values` (`ia`).
    pub ia: HostTensorOwned,
    /// Indices into the second input contributing to `values` (`ib`).
    pub ib: HostTensorOwned,
}
978
/// Parameterisation of 2-D filters generated by `fspecial`.
///
/// Each variant corresponds to one `fspecial` filter type; the fields are that
/// type's parameters. NOTE(review): parameter semantics below are inferred from
/// field names — confirm against the provider kernels.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    /// Box-averaging kernel of `rows` × `cols`.
    Average {
        rows: u32,
        cols: u32,
    },
    /// Circular averaging kernel of the given `radius`; `size` presumably the kernel side length.
    Disk {
        radius: f64,
        size: u32,
    },
    /// Gaussian low-pass kernel of `rows` × `cols` with standard deviation `sigma`.
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// Laplacian approximation with shape parameter `alpha`.
    Laplacian {
        alpha: f64,
    },
    /// Laplacian-of-Gaussian kernel of `rows` × `cols` with standard deviation `sigma`.
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// Linear motion-blur kernel. `oversample` is presumably a sub-sampling
    /// factor used when rasterising the line — confirm.
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        oversample: u32,
    },
    /// Prewitt edge-emphasis kernel (no parameters).
    Prewitt,
    /// Sobel edge-emphasis kernel (no parameters).
    Sobel,
    /// Unsharp-masking kernel with shape parameter `alpha`.
    Unsharp {
        alpha: f64,
    },
}

/// Request dispatched to acceleration providers for `fspecial` kernels.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    /// Filter type and its parameters.
    pub filter: FspecialFilter,
}
1021
/// Padding strategy used by `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    /// Pad with a constant (see `ImfilterOptions::constant_value`).
    Constant,
    /// Repeat the nearest border element (presumably MATLAB's `'replicate'`).
    Replicate,
    /// Mirror across the border (presumably MATLAB's `'symmetric'`).
    Symmetric,
    /// Wrap around periodically (presumably MATLAB's `'circular'`).
    Circular,
}

/// Output sizing mode used by `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    /// Output presumably sized like the input image (`'same'`).
    Same,
    /// Full filter response (`'full'`).
    Full,
    /// Only positions computed without padding (`'valid'`).
    Valid,
}

/// Correlation vs convolution behaviour for `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    /// Apply the kernel as-is (correlation).
    Correlation,
    /// Apply the kernel flipped (convolution).
    Convolution,
}

/// Options supplied to acceleration providers for `imfilter`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    /// How samples outside the image are synthesised.
    pub padding: ImfilterPadding,
    /// Fill value used when `padding` is `ImfilterPadding::Constant`.
    pub constant_value: f64,
    /// Output sizing mode.
    pub shape: ImfilterShape,
    /// Correlation or convolution.
    pub mode: ImfilterMode,
}
1054
1055impl Default for ImfilterOptions {
1056    fn default() -> Self {
1057        Self {
1058            padding: ImfilterPadding::Constant,
1059            constant_value: 0.0,
1060            shape: ImfilterShape::Same,
1061            mode: ImfilterMode::Correlation,
1062        }
1063    }
1064}
1065
/// Ordering applied by provider-backed `setdiff` operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    /// Return the difference in sorted order.
    Sorted,
    /// Return the difference in input order (presumably MATLAB's `'stable'`).
    Stable,
}

/// Options controlling provider-backed `setdiff` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// Treat each row as one element (presumably MATLAB's `'rows'` flag).
    pub rows: bool,
    /// Ordering of the returned values.
    pub order: SetdiffOrder,
}

/// Host-resident outputs returned by provider-backed `setdiff` operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    /// Values of the first input absent from the second.
    pub values: HostTensorOwned,
    /// Indices into the first input selecting `values` (index base not fixed here — confirm).
    pub ia: HostTensorOwned,
}
1086
/// Options controlling provider-backed `ismember` operations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// Match whole rows instead of individual elements (presumably MATLAB's `'rows'`).
    pub rows: bool,
}

/// Host-resident logical output returned by providers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    /// One byte per element; presumably 0 = false and 1 = true — confirm.
    pub data: Vec<u8>,
    /// Dimensions of the logical array.
    pub shape: Vec<usize>,
}

/// Host-resident outputs returned by provider-backed `ismember` operations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    /// Membership flags (presumably MATLAB's `tf`).
    pub mask: HostLogicalOwned,
    /// Matching locations in the second operand (presumably MATLAB's `loc`).
    pub loc: HostTensorOwned,
}
1106
/// Output sizing mode for provider-backed 1-D convolution
/// (names mirror MATLAB `conv`'s `'full'`/`'same'`/`'valid'`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

/// Orientation for provider-backed 1-D convolution
/// (presumably the orientation of the result vector — confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options for provider-backed 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    /// Output sizing mode.
    pub mode: ProviderConvMode,
    /// Result orientation.
    pub orientation: ProviderConvOrientation,
}
1125
/// Options for provider-backed IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Zero-based dimension along which filtering should be applied.
    pub dim: usize,
    /// Optional initial conditions (state vector) residing on the device.
    pub zi: Option<GpuTensorHandle>,
}

/// Device-resident outputs of provider-backed IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    /// Filtered output tensor, matching the input signal shape.
    pub output: GpuTensorHandle,
    /// Final conditions for the filter state (same shape as the requested `zi` layout).
    pub final_state: Option<GpuTensorHandle>,
}

/// Pair of device-resident moment tensors.
///
/// NOTE(review): `ex2` presumably holds E[x²] alongside the mean, so variance
/// can be derived as `ex2 - mean²` — confirm with the producing kernel.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    pub ex2: GpuTensorHandle,
}
1147
/// Dispatch counters for one category of GPU work.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of GPU dispatches recorded for this category.
    pub count: u64,
    /// Accumulated wall-clock time of dispatches in nanoseconds (host measured).
    pub total_wall_time_ns: u64,
}

/// Count of fallbacks taken for one reason.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderFallbackStat {
    /// Free-form description of why the fallback happened.
    pub reason: String,
    /// How many times this reason was recorded.
    pub count: u64,
}
1161
1162#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
1163pub struct ProviderTelemetry {
1164    pub fused_elementwise: ProviderDispatchStats,
1165    pub fused_reduction: ProviderDispatchStats,
1166    pub matmul: ProviderDispatchStats,
1167    pub linsolve: ProviderDispatchStats,
1168    pub mldivide: ProviderDispatchStats,
1169    pub mrdivide: ProviderDispatchStats,
1170    pub upload_bytes: u64,
1171    pub download_bytes: u64,
1172    pub solve_fallbacks: Vec<ProviderFallbackStat>,
1173    pub fusion_cache_hits: u64,
1174    pub fusion_cache_misses: u64,
1175    pub bind_group_cache_hits: u64,
1176    pub bind_group_cache_misses: u64,
1177    /// Optional per-layout bind group cache counters (layout tags and their hit/miss counts)
1178    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
1179    /// Recent kernel launch metadata (bounded log; newest last)
1180    pub kernel_launches: Vec<KernelLaunchTelemetry>,
1181}
1182
1183#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
1184pub struct BindGroupLayoutTelemetry {
1185    pub tag: String,
1186    pub hits: u64,
1187    pub misses: u64,
1188}
1189
1190#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
1191pub struct KernelAttrTelemetry {
1192    pub key: String,
1193    pub value: u64,
1194}
1195
1196#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1197pub struct KernelLaunchTelemetry {
1198    pub kernel: String,
1199    pub precision: Option<String>,
1200    pub shape: Vec<KernelAttrTelemetry>,
1201    pub tuning: Vec<KernelAttrTelemetry>,
1202}
1203
1204pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
1205pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
1206
1207fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
1208    Box::pin(async move { Err(anyhow::anyhow!(message)) })
1209}
1210
1211/// Device/provider interface that backends implement and register into the runtime layer
1212pub trait AccelProvider: Send + Sync {
    /// Copy a host tensor view to the device and return a handle to the new tensor.
    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
    /// Read the tensor behind `h` back to the host. Returns a future so the
    /// transfer may complete asynchronously.
    fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
    /// Release the device storage behind `h`.
    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
    /// Device description; the default `device_info_struct` reports it as the device name.
    fn device_info(&self) -> String;
    /// Numeric device identifier. The default reports `0`.
    fn device_id(&self) -> u32 {
        0
    }
1220
1221    /// Declares provider policy for sharing `GpuTensorHandle` values across
1222    /// spawned async boundaries.
1223    ///
1224    /// Default is conservative rejection. Providers that can safely support
1225    /// cross-task sharing should override this.
1226    fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
1227        SpawnHandleConcurrency::Reject
1228    }
1229
1230    /// Export a shared GPU context handle, allowing downstream systems (plotting, visualization)
1231    /// to reuse the same device/queue without copying tensor data back to the host.
1232    fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
1233        None
1234    }
1235
1236    /// Export a provider-owned WGPU buffer for zero-copy integrations.
1237    #[cfg(feature = "wgpu")]
1238    fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
1239        let _ = _handle;
1240        None
1241    }
1242
1243    /// Gather elements from `source` at the provided zero-based linear `indices`, materialising
1244    /// a dense tensor with the specified `output_shape`.
1245    fn gather_linear(
1246        &self,
1247        _source: &GpuTensorHandle,
1248        _indices: &[u32],
1249        _output_shape: &[usize],
1250    ) -> anyhow::Result<GpuTensorHandle> {
1251        Err(anyhow::anyhow!("gather_linear not supported by provider"))
1252    }
1253
1254    /// Scatter the contents of `values` into `target` at the provided zero-based linear `indices`.
1255    ///
1256    /// The provider must ensure `values.len() == indices.len()` and update `target` in place.
1257    fn scatter_linear(
1258        &self,
1259        _target: &GpuTensorHandle,
1260        _indices: &[u32],
1261        _values: &GpuTensorHandle,
1262    ) -> anyhow::Result<()> {
1263        Err(anyhow::anyhow!("scatter_linear not supported by provider"))
1264    }
1265
1266    /// Structured device information (optional to override). Default adapts from `device_info()`.
1267    fn device_info_struct(&self) -> ApiDeviceInfo {
1268        ApiDeviceInfo {
1269            device_id: 0,
1270            name: self.device_info(),
1271            vendor: String::new(),
1272            memory_bytes: None,
1273            backend: None,
1274        }
1275    }
1276
1277    fn precision(&self) -> ProviderPrecision {
1278        ProviderPrecision::F64
1279    }
1280
1281    /// Read a single scalar at linear index from a device tensor, returning it as f64.
1282    fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
1283        Err(anyhow::anyhow!("read_scalar not supported by provider"))
1284    }
1285
1286    /// Allocate a zero-initialised tensor with the provided shape on the device.
1287    fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1288        Err(anyhow::anyhow!("zeros not supported by provider"))
1289    }
1290
1291    /// Allocate a one-initialised tensor with the provided shape on the device.
1292    fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1293        Err(anyhow::anyhow!("ones not supported by provider"))
1294    }
1295
1296    /// Allocate a zero-initialised tensor matching the prototype tensor.
1297    fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1298        self.zeros(&prototype.shape)
1299    }
1300
1301    /// Allocate a tensor filled with a constant value on the device.
1302    fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
1303        if value == 0.0 {
1304            return self.zeros(shape);
1305        }
1306        if let Ok(base) = self.zeros(shape) {
1307            match self.scalar_add(&base, value) {
1308                Ok(out) => {
1309                    let _ = self.free(&base);
1310                    return Ok(out);
1311                }
1312                Err(_) => {
1313                    let _ = self.free(&base);
1314                }
1315            }
1316        }
1317        let len: usize = shape.iter().copied().product();
1318        let data = vec![value; len];
1319        let view = HostTensorView { data: &data, shape };
1320        self.upload(&view)
1321    }
1322
1323    /// Allocate a tensor filled with a constant value, matching a prototype's residency.
1324    fn fill_like(
1325        &self,
1326        prototype: &GpuTensorHandle,
1327        value: f64,
1328    ) -> anyhow::Result<GpuTensorHandle> {
1329        if value == 0.0 {
1330            return self.zeros_like(prototype);
1331        }
1332        if let Ok(base) = self.zeros_like(prototype) {
1333            match self.scalar_add(&base, value) {
1334                Ok(out) => {
1335                    let _ = self.free(&base);
1336                    return Ok(out);
1337                }
1338                Err(_) => {
1339                    let _ = self.free(&base);
1340                }
1341            }
1342        }
1343        self.fill(&prototype.shape, value)
1344    }
1345
1346    /// Allocate a one-initialised tensor matching the prototype tensor.
1347    fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1348        self.ones(&prototype.shape)
1349    }
1350
1351    /// Allocate an identity tensor with ones along the leading diagonal of the first two axes.
1352    fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1353        Err(anyhow::anyhow!("eye not supported by provider"))
1354    }
1355
1356    /// Allocate an identity tensor matching the prototype tensor's shape.
1357    fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1358        self.eye(&prototype.shape)
1359    }
1360
1361    /// Construct MATLAB-style coordinate grids from axis vectors.
1362    fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1363        Err(anyhow::anyhow!("meshgrid not supported by provider"))
1364    }
1365
1366    /// Construct a diagonal matrix from a vector-like tensor. `offset` matches MATLAB semantics.
1367    fn diag_from_vector(
1368        &self,
1369        _vector: &GpuTensorHandle,
1370        _offset: isize,
1371    ) -> anyhow::Result<GpuTensorHandle> {
1372        Err(anyhow::anyhow!(
1373            "diag_from_vector not supported by provider"
1374        ))
1375    }
1376
1377    /// Extract a diagonal from a matrix-like tensor. The result is always a column vector.
1378    fn diag_extract(
1379        &self,
1380        _matrix: &GpuTensorHandle,
1381        _offset: isize,
1382    ) -> anyhow::Result<GpuTensorHandle> {
1383        Err(anyhow::anyhow!("diag_extract not supported by provider"))
1384    }
1385
1386    /// Apply a lower-triangular mask to the first two dimensions of a tensor.
1387    fn tril<'a>(
1388        &'a self,
1389        _matrix: &'a GpuTensorHandle,
1390        _offset: isize,
1391    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1392        Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1393    }
1394
1395    /// Apply an upper-triangular mask to the first two dimensions of a tensor.
1396    fn triu<'a>(
1397        &'a self,
1398        _matrix: &'a GpuTensorHandle,
1399        _offset: isize,
1400    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1401        Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1402    }
1403
1404    /// Evaluate a polynomial expressed by `coefficients` at each element in `points`.
1405    fn polyval(
1406        &self,
1407        _coefficients: &GpuTensorHandle,
1408        _points: &GpuTensorHandle,
1409        _options: &ProviderPolyvalOptions,
1410    ) -> anyhow::Result<GpuTensorHandle> {
1411        Err(anyhow::anyhow!("polyval not supported by provider"))
1412    }
1413
1414    /// Fit a polynomial of degree `degree` to `(x, y)` samples. Optional weights must match `x`.
1415    fn polyfit<'a>(
1416        &'a self,
1417        _x: &'a GpuTensorHandle,
1418        _y: &'a GpuTensorHandle,
1419        _degree: usize,
1420        _weights: Option<&'a GpuTensorHandle>,
1421    ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1422        Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1423    }
1424
1425    /// Differentiate a polynomial represented as a vector of coefficients.
1426    fn polyder_single<'a>(
1427        &'a self,
1428        _polynomial: &'a GpuTensorHandle,
1429    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1430        Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1431    }
1432
1433    /// Apply the product rule to polynomials `p` and `q`.
1434    fn polyder_product<'a>(
1435        &'a self,
1436        _p: &'a GpuTensorHandle,
1437        _q: &'a GpuTensorHandle,
1438    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1439        Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1440    }
1441
1442    /// Apply the quotient rule to polynomials `u` and `v`.
1443    fn polyder_quotient<'a>(
1444        &'a self,
1445        _u: &'a GpuTensorHandle,
1446        _v: &'a GpuTensorHandle,
1447    ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1448        Box::pin(async move {
1449            Err(anyhow::anyhow!(
1450                "polyder_quotient not supported by provider"
1451            ))
1452        })
1453    }
1454
1455    /// Integrate a polynomial represented as a vector of coefficients and append a constant term.
1456    fn polyint(
1457        &self,
1458        _polynomial: &GpuTensorHandle,
1459        _constant: f64,
1460    ) -> anyhow::Result<GpuTensorHandle> {
1461        Err(anyhow::anyhow!("polyint not supported by provider"))
1462    }
1463
1464    /// Allocate a tensor filled with random values drawn from U(0, 1).
1465    fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1466        Err(anyhow::anyhow!("random_uniform not supported by provider"))
1467    }
1468
1469    /// Allocate a tensor filled with random values matching the prototype shape.
1470    fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1471        self.random_uniform(&prototype.shape)
1472    }
1473
1474    /// Allocate a tensor filled with standard normal (mean 0, stddev 1) random values.
1475    fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1476        Err(anyhow::anyhow!("random_normal not supported by provider"))
1477    }
1478
1479    /// Allocate a tensor of standard normal values matching a prototype's shape.
1480    fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1481        self.random_normal(&prototype.shape)
1482    }
1483
1484    /// Exponentially-distributed random values with mean `mu`.
1485    fn random_exponential(&self, _mu: f64, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1486        Err(anyhow::anyhow!(
1487            "random_exponential not supported by provider"
1488        ))
1489    }
1490
1491    /// Normal random values with mean `mu` and standard deviation `sigma`.
1492    fn random_normrnd(
1493        &self,
1494        _mu: f64,
1495        _sigma: f64,
1496        _shape: &[usize],
1497    ) -> anyhow::Result<GpuTensorHandle> {
1498        Err(anyhow::anyhow!("random_normrnd not supported by provider"))
1499    }
1500
1501    /// Uniform random values on the interval `[a, b)`.
1502    fn random_unifrnd(
1503        &self,
1504        _a: f64,
1505        _b: f64,
1506        _shape: &[usize],
1507    ) -> anyhow::Result<GpuTensorHandle> {
1508        Err(anyhow::anyhow!("random_unifrnd not supported by provider"))
1509    }
1510
1511    fn stochastic_evolution(
1512        &self,
1513        _state: &GpuTensorHandle,
1514        _drift: f64,
1515        _scale: f64,
1516        _steps: u32,
1517    ) -> anyhow::Result<GpuTensorHandle> {
1518        Err(anyhow::anyhow!(
1519            "stochastic_evolution not supported by provider"
1520        ))
1521    }
1522
1523    /// Set the provider RNG state to align with the host RNG.
1524    fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1525        Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1526    }
1527
1528    /// Generate a 2-D correlation kernel matching MATLAB's `fspecial` builtin.
1529    fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1530        Err(anyhow::anyhow!("fspecial not supported by provider"))
1531    }
1532
1533    /// Evaluate the `peaks` test surface on an n×n grid spanning [-3,3]×[-3,3].
1534    /// Returns the Z matrix (n×n) as a GPU tensor.
1535    fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
1536        Err(anyhow::anyhow!("peaks not supported by provider"))
1537    }
1538
1539    /// Evaluate the `peaks` formula element-wise on caller-supplied GPU coordinate tensors.
1540    /// X and Y must have the same shape. Returns a Z tensor of the same shape.
1541    fn peaks_xy(
1542        &self,
1543        _x: &GpuTensorHandle,
1544        _y: &GpuTensorHandle,
1545    ) -> anyhow::Result<GpuTensorHandle> {
1546        Err(anyhow::anyhow!("peaks_xy not supported by provider"))
1547    }
1548
1549    fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1550        Err(anyhow::anyhow!("hann_window not supported by provider"))
1551    }
1552
1553    fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1554        Err(anyhow::anyhow!("hamming_window not supported by provider"))
1555    }
1556
1557    fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1558        Err(anyhow::anyhow!("blackman_window not supported by provider"))
1559    }
1560
1561    /// Apply an N-D correlation/convolution with padding semantics matching MATLAB's `imfilter`.
1562    fn imfilter<'a>(
1563        &'a self,
1564        _image: &'a GpuTensorHandle,
1565        _kernel: &'a GpuTensorHandle,
1566        _options: &'a ImfilterOptions,
1567    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1568        unsupported_future("imfilter not supported by provider")
1569    }
1570
1571    /// Allocate a tensor filled with random integers over an inclusive range.
1572    fn random_integer_range(
1573        &self,
1574        _lower: i64,
1575        _upper: i64,
1576        _shape: &[usize],
1577    ) -> anyhow::Result<GpuTensorHandle> {
1578        Err(anyhow::anyhow!(
1579            "random_integer_range not supported by provider"
1580        ))
1581    }
1582
1583    /// Allocate a random integer tensor matching the prototype shape.
1584    fn random_integer_like(
1585        &self,
1586        prototype: &GpuTensorHandle,
1587        lower: i64,
1588        upper: i64,
1589    ) -> anyhow::Result<GpuTensorHandle> {
1590        self.random_integer_range(lower, upper, &prototype.shape)
1591    }
1592
1593    /// Allocate a random permutation of 1..=n, returning the first k elements.
1594    fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1595        Err(anyhow!("random_permutation not supported by provider"))
1596    }
1597
1598    /// Allocate a random permutation matching the prototype residency.
1599    fn random_permutation_like(
1600        &self,
1601        _prototype: &GpuTensorHandle,
1602        n: usize,
1603        k: usize,
1604    ) -> anyhow::Result<GpuTensorHandle> {
1605        self.random_permutation(n, k)
1606    }
1607
1608    /// Compute a covariance matrix across the columns of `matrix`.
1609    fn covariance<'a>(
1610        &'a self,
1611        _matrix: &'a GpuTensorHandle,
1612        _second: Option<&'a GpuTensorHandle>,
1613        _weights: Option<&'a GpuTensorHandle>,
1614        _options: &'a CovarianceOptions,
1615    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1616        unsupported_future("covariance not supported by provider")
1617    }
1618
1619    /// Compute a correlation coefficient matrix across the columns of `matrix`.
1620    fn corrcoef<'a>(
1621        &'a self,
1622        _matrix: &'a GpuTensorHandle,
1623        _options: &'a CorrcoefOptions,
1624    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1625        unsupported_future("corrcoef not supported by provider")
1626    }
1627
1628    // Optional operator hooks (default to unsupported)
1629    fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1630        Err(anyhow::anyhow!("linspace not supported by provider"))
1631    }
    /// Elementwise addition hook (`a + b`).
    fn elem_add<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_add not supported by provider")
    }
    /// Elementwise multiplication hook (`a .* b`).
    fn elem_mul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_mul not supported by provider")
    }
    /// Elementwise maximum hook.
    fn elem_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_max not supported by provider")
    }
    /// Elementwise minimum hook.
    fn elem_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_min not supported by provider")
    }
    /// Elementwise subtraction hook (`a - b`).
    fn elem_sub<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_sub not supported by provider")
    }
    /// Elementwise division hook (`a ./ b`).
    fn elem_div<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_div not supported by provider")
    }
    /// Elementwise power hook (`a .^ b`).
    fn elem_pow<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_pow not supported by provider")
    }

    /// Construct complex-interleaved GPU storage from a real-valued tensor,
    /// using zero for the imaginary lane.
    ///
    /// The default resolves to an "unsupported" error.
    fn complex_from_real<'a>(
        &'a self,
        _real: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("complex_from_real not supported by provider")
    }

    /// Construct complex-interleaved GPU storage from real and imaginary tensors.
    ///
    /// Implementations should support equal shapes and scalar expansion for either
    /// operand, matching MATLAB's `complex(real, imag)` size rules.
    /// The default resolves to an "unsupported" error.
    fn complex_from_real_imag<'a>(
        &'a self,
        _real: &'a GpuTensorHandle,
        _imag: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("complex_from_real_imag not supported by provider")
    }

    /// Map a resident real-valued symbol tensor through a complex constellation
    /// table and return complex-interleaved GPU storage.
    ///
    /// The default resolves to an "unsupported" error.
    fn modulate_constellation<'a>(
        &'a self,
        _request: ProviderModulationRequest<'a>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("modulate_constellation not supported by provider")
    }

    /// Group a resident real/logical bit tensor into symbols, map through a complex
    /// constellation table, and return complex-interleaved GPU storage.
    ///
    /// The default resolves to an "unsupported" error.
    fn modulate_bits_constellation<'a>(
        &'a self,
        _request: ProviderBitModulationRequest<'a>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("modulate_bits_constellation not supported by provider")
    }
1720
    /// Elementwise `hypot(a, b)` hook.
    fn elem_hypot<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_hypot not supported by provider")
    }
    /// Elementwise `a >= b` comparison hook.
    fn elem_ge<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ge not supported by provider")
    }
    /// Elementwise `a <= b` comparison hook.
    fn elem_le<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_le not supported by provider")
    }
    /// Elementwise `a < b` comparison hook.
    fn elem_lt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_lt not supported by provider")
    }
    /// Elementwise `a > b` comparison hook.
    fn elem_gt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_gt not supported by provider")
    }
    /// Elementwise `a == b` comparison hook.
    fn elem_eq<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_eq not supported by provider")
    }
    /// Elementwise `a ~= b` comparison hook.
    fn elem_ne<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ne not supported by provider")
    }
1770    fn logical_and(
1771        &self,
1772        _a: &GpuTensorHandle,
1773        _b: &GpuTensorHandle,
1774    ) -> anyhow::Result<GpuTensorHandle> {
1775        Err(anyhow::anyhow!("logical_and not supported by provider"))
1776    }
1777    fn logical_or(
1778        &self,
1779        _a: &GpuTensorHandle,
1780        _b: &GpuTensorHandle,
1781    ) -> anyhow::Result<GpuTensorHandle> {
1782        Err(anyhow::anyhow!("logical_or not supported by provider"))
1783    }
1784    fn logical_xor(
1785        &self,
1786        _a: &GpuTensorHandle,
1787        _b: &GpuTensorHandle,
1788    ) -> anyhow::Result<GpuTensorHandle> {
1789        Err(anyhow::anyhow!("logical_xor not supported by provider"))
1790    }
1791    fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1792        Err(anyhow::anyhow!("logical_not not supported by provider"))
1793    }
1794    fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
1795        Ok(handle_is_logical(a))
1796    }
1797    fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1798        Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1799    }
1800    fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1801        Err(anyhow::anyhow!(
1802            "logical_isfinite not supported by provider"
1803        ))
1804    }
1805    fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1806        Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1807    }
1808    fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1809        Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1810    }
1811    fn elem_atan2<'a>(
1812        &'a self,
1813        _y: &'a GpuTensorHandle,
1814        _x: &'a GpuTensorHandle,
1815    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1816        unsupported_future("elem_atan2 not supported by provider")
1817    }
1818    // Unary elementwise operations (optional)
1819    fn unary_sin<'a>(
1820        &'a self,
1821        _a: &'a GpuTensorHandle,
1822    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1823        unsupported_future("unary_sin not supported by provider")
1824    }
1825    fn unary_sinc<'a>(
1826        &'a self,
1827        _a: &'a GpuTensorHandle,
1828    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1829        unsupported_future("unary_sinc not supported by provider")
1830    }
1831    fn unary_gamma<'a>(
1832        &'a self,
1833        _a: &'a GpuTensorHandle,
1834    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1835        unsupported_future("unary_gamma not supported by provider")
1836    }
1837    fn unary_factorial<'a>(
1838        &'a self,
1839        _a: &'a GpuTensorHandle,
1840    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1841        unsupported_future("unary_factorial not supported by provider")
1842    }
1843    fn unary_asinh<'a>(
1844        &'a self,
1845        _a: &'a GpuTensorHandle,
1846    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1847        unsupported_future("unary_asinh not supported by provider")
1848    }
1849    fn unary_sinh<'a>(
1850        &'a self,
1851        _a: &'a GpuTensorHandle,
1852    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1853        unsupported_future("unary_sinh not supported by provider")
1854    }
1855    fn unary_cosh<'a>(
1856        &'a self,
1857        _a: &'a GpuTensorHandle,
1858    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1859        unsupported_future("unary_cosh not supported by provider")
1860    }
1861    fn unary_asin<'a>(
1862        &'a self,
1863        _a: &'a GpuTensorHandle,
1864    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1865        unsupported_future("unary_asin not supported by provider")
1866    }
1867    fn unary_acos<'a>(
1868        &'a self,
1869        _a: &'a GpuTensorHandle,
1870    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1871        unsupported_future("unary_acos not supported by provider")
1872    }
1873    fn unary_acosh<'a>(
1874        &'a self,
1875        _a: &'a GpuTensorHandle,
1876    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1877        unsupported_future("unary_acosh not supported by provider")
1878    }
1879    fn unary_tan<'a>(
1880        &'a self,
1881        _a: &'a GpuTensorHandle,
1882    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1883        unsupported_future("unary_tan not supported by provider")
1884    }
1885    fn unary_tanh<'a>(
1886        &'a self,
1887        _a: &'a GpuTensorHandle,
1888    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1889        unsupported_future("unary_tanh not supported by provider")
1890    }
1891    fn unary_atan<'a>(
1892        &'a self,
1893        _a: &'a GpuTensorHandle,
1894    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1895        unsupported_future("unary_atan not supported by provider")
1896    }
1897    fn unary_atanh<'a>(
1898        &'a self,
1899        _a: &'a GpuTensorHandle,
1900    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1901        unsupported_future("unary_atanh not supported by provider")
1902    }
1903    fn unary_ceil<'a>(
1904        &'a self,
1905        _a: &'a GpuTensorHandle,
1906    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1907        unsupported_future("unary_ceil not supported by provider")
1908    }
1909    fn unary_floor<'a>(
1910        &'a self,
1911        _a: &'a GpuTensorHandle,
1912    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1913        unsupported_future("unary_floor not supported by provider")
1914    }
1915    fn unary_round<'a>(
1916        &'a self,
1917        _a: &'a GpuTensorHandle,
1918    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1919        unsupported_future("unary_round not supported by provider")
1920    }
1921    fn unary_fix<'a>(
1922        &'a self,
1923        _a: &'a GpuTensorHandle,
1924    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1925        unsupported_future("unary_fix not supported by provider")
1926    }
1927    fn unary_cos<'a>(
1928        &'a self,
1929        _a: &'a GpuTensorHandle,
1930    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1931        unsupported_future("unary_cos not supported by provider")
1932    }
1933    fn unary_angle<'a>(
1934        &'a self,
1935        _a: &'a GpuTensorHandle,
1936    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1937        unsupported_future("unary_angle not supported by provider")
1938    }
1939    fn unary_imag<'a>(
1940        &'a self,
1941        _a: &'a GpuTensorHandle,
1942    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1943        unsupported_future("unary_imag not supported by provider")
1944    }
1945    fn unary_real<'a>(
1946        &'a self,
1947        _a: &'a GpuTensorHandle,
1948    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1949        unsupported_future("unary_real not supported by provider")
1950    }
1951    fn unary_conj<'a>(
1952        &'a self,
1953        _a: &'a GpuTensorHandle,
1954    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1955        unsupported_future("unary_conj not supported by provider")
1956    }
1957    fn unary_abs<'a>(
1958        &'a self,
1959        _a: &'a GpuTensorHandle,
1960    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1961        unsupported_future("unary_abs not supported by provider")
1962    }
1963    fn unary_sign<'a>(
1964        &'a self,
1965        _a: &'a GpuTensorHandle,
1966    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1967        unsupported_future("unary_sign not supported by provider")
1968    }
1969    fn unary_heaviside<'a>(
1970        &'a self,
1971        _a: &'a GpuTensorHandle,
1972    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1973        unsupported_future("unary_heaviside not supported by provider")
1974    }
1975    fn unary_exp<'a>(
1976        &'a self,
1977        _a: &'a GpuTensorHandle,
1978    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1979        unsupported_future("unary_exp not supported by provider")
1980    }
1981    fn unary_expm1<'a>(
1982        &'a self,
1983        _a: &'a GpuTensorHandle,
1984    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1985        unsupported_future("unary_expm1 not supported by provider")
1986    }
1987    fn unary_log<'a>(
1988        &'a self,
1989        _a: &'a GpuTensorHandle,
1990    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1991        unsupported_future("unary_log not supported by provider")
1992    }
1993    fn unary_log2<'a>(
1994        &'a self,
1995        _a: &'a GpuTensorHandle,
1996    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1997        unsupported_future("unary_log2 not supported by provider")
1998    }
1999    fn unary_log10<'a>(
2000        &'a self,
2001        _a: &'a GpuTensorHandle,
2002    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2003        unsupported_future("unary_log10 not supported by provider")
2004    }
2005    fn unary_log1p<'a>(
2006        &'a self,
2007        _a: &'a GpuTensorHandle,
2008    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2009        unsupported_future("unary_log1p not supported by provider")
2010    }
2011    fn unary_sqrt<'a>(
2012        &'a self,
2013        _a: &'a GpuTensorHandle,
2014    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2015        unsupported_future("unary_sqrt not supported by provider")
2016    }
2017    fn unary_double<'a>(
2018        &'a self,
2019        _a: &'a GpuTensorHandle,
2020    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2021        unsupported_future("unary_double not supported by provider")
2022    }
2023    fn unary_single<'a>(
2024        &'a self,
2025        _a: &'a GpuTensorHandle,
2026    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2027        unsupported_future("unary_single not supported by provider")
2028    }
2029    fn unary_pow2<'a>(
2030        &'a self,
2031        _a: &'a GpuTensorHandle,
2032    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2033        unsupported_future("unary_pow2 not supported by provider")
2034    }
2035    fn unary_nextpow2<'a>(
2036        &'a self,
2037        _a: &'a GpuTensorHandle,
2038    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2039        unsupported_future("unary_nextpow2 not supported by provider")
2040    }
2041    fn pow2_scale(
2042        &self,
2043        _mantissa: &GpuTensorHandle,
2044        _exponent: &GpuTensorHandle,
2045    ) -> anyhow::Result<GpuTensorHandle> {
2046        Err(anyhow::anyhow!("pow2_scale not supported by provider"))
2047    }
2048    // Left-scalar operations (broadcast with scalar on the left)
2049    fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2050        Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
2051    }
2052    fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2053        Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
2054    }
2055    // Scalar operations: apply op with scalar right-hand side (broadcast over a)
2056    fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2057        Err(anyhow::anyhow!("scalar_add not supported by provider"))
2058    }
2059    fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2060        Err(anyhow::anyhow!("scalar_sub not supported by provider"))
2061    }
2062    fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2063        Err(anyhow::anyhow!("scalar_mul not supported by provider"))
2064    }
2065    fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2066        Err(anyhow::anyhow!("scalar_max not supported by provider"))
2067    }
2068    fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2069        Err(anyhow::anyhow!("scalar_min not supported by provider"))
2070    }
2071    fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2072        Err(anyhow::anyhow!("scalar_div not supported by provider"))
2073    }
2074    fn sort_dim<'a>(
2075        &'a self,
2076        _a: &'a GpuTensorHandle,
2077        _dim: usize,
2078        _order: SortOrder,
2079        _comparison: SortComparison,
2080    ) -> AccelProviderFuture<'a, SortResult> {
2081        unsupported_future("sort_dim not supported by provider")
2082    }
2083    fn sort_rows<'a>(
2084        &'a self,
2085        _a: &'a GpuTensorHandle,
2086        _columns: &'a [SortRowsColumnSpec],
2087        _comparison: SortComparison,
2088    ) -> AccelProviderFuture<'a, SortResult> {
2089        unsupported_future("sort_rows not supported by provider")
2090    }
2091    fn matmul<'a>(
2092        &'a self,
2093        _a: &'a GpuTensorHandle,
2094        _b: &'a GpuTensorHandle,
2095    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2096        unsupported_future("matmul not supported by provider")
2097    }
2098
2099    fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2100        Err(anyhow::anyhow!("syrk not supported by provider"))
2101    }
2102    fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
2103        Err(anyhow::anyhow!("pagefun not supported by provider"))
2104    }
2105
2106    /// Optional: matrix multiplication with an epilogue applied before store.
2107    ///
2108    /// The default implementation falls back to `matmul` when the epilogue is effectively a no-op
2109    /// (alpha=1, beta=0, no row/col scales), and otherwise returns `Err`.
2110    fn matmul_epilogue<'a>(
2111        &'a self,
2112        a: &'a GpuTensorHandle,
2113        b: &'a GpuTensorHandle,
2114        epilogue: &'a MatmulEpilogue,
2115    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2116        Box::pin(async move {
2117            if epilogue.is_noop() {
2118                return self.matmul(a, b).await;
2119            }
2120            Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
2121        })
2122    }
2123    fn image_normalize<'a>(
2124        &'a self,
2125        _input: &'a GpuTensorHandle,
2126        _desc: &'a ImageNormalizeDescriptor,
2127    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2128        unsupported_future("image_normalize fusion not supported by provider")
2129    }
2130    fn matmul_power_step<'a>(
2131        &'a self,
2132        _lhs: &'a GpuTensorHandle,
2133        _rhs: &'a GpuTensorHandle,
2134        _epilogue: &'a PowerStepEpilogue,
2135    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2136        unsupported_future("matmul_power_step normalization not supported by provider")
2137    }
2138    fn linsolve<'a>(
2139        &'a self,
2140        _lhs: &'a GpuTensorHandle,
2141        _rhs: &'a GpuTensorHandle,
2142        _options: &'a ProviderLinsolveOptions,
2143    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
2144        unsupported_future("linsolve not supported by provider")
2145    }
2146    fn inv<'a>(
2147        &'a self,
2148        _matrix: &'a GpuTensorHandle,
2149        _options: ProviderInvOptions,
2150    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2151        unsupported_future("inv not supported by provider")
2152    }
2153    fn pinv<'a>(
2154        &'a self,
2155        _matrix: &'a GpuTensorHandle,
2156        _options: ProviderPinvOptions,
2157    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2158        unsupported_future("pinv not supported by provider")
2159    }
2160    fn cond<'a>(
2161        &'a self,
2162        _matrix: &'a GpuTensorHandle,
2163        _norm: ProviderCondNorm,
2164    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2165        Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
2166    }
2167    fn norm<'a>(
2168        &'a self,
2169        _tensor: &'a GpuTensorHandle,
2170        _order: ProviderNormOrder,
2171    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2172        Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
2173    }
2174    fn rank<'a>(
2175        &'a self,
2176        _matrix: &'a GpuTensorHandle,
2177        _tolerance: Option<f64>,
2178    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2179        Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
2180    }
2181    fn rcond<'a>(
2182        &'a self,
2183        _matrix: &'a GpuTensorHandle,
2184    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2185        Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
2186    }
2187    fn mldivide<'a>(
2188        &'a self,
2189        _lhs: &'a GpuTensorHandle,
2190        _rhs: &'a GpuTensorHandle,
2191    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2192        Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
2193    }
2194    fn mrdivide<'a>(
2195        &'a self,
2196        _lhs: &'a GpuTensorHandle,
2197        _rhs: &'a GpuTensorHandle,
2198    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2199        Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
2200    }
2201    fn eig<'a>(
2202        &'a self,
2203        _a: &'a GpuTensorHandle,
2204        _compute_left: bool,
2205    ) -> AccelProviderFuture<'a, ProviderEigResult> {
2206        Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
2207    }
2208    fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
2209        Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
2210    }
2211
2212    fn chol<'a>(
2213        &'a self,
2214        _a: &'a GpuTensorHandle,
2215        _lower: bool,
2216    ) -> AccelProviderFuture<'a, ProviderCholResult> {
2217        Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
2218    }
2219    fn qr<'a>(
2220        &'a self,
2221        _a: &'a GpuTensorHandle,
2222        _options: ProviderQrOptions,
2223    ) -> AccelProviderFuture<'a, ProviderQrResult> {
2224        Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
2225    }
2226    fn take_matmul_sources(
2227        &self,
2228        _product: &GpuTensorHandle,
2229    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
2230        None
2231    }
2232    fn qr_power_iter<'a>(
2233        &'a self,
2234        product: &'a GpuTensorHandle,
2235        _product_lhs: Option<&'a GpuTensorHandle>,
2236        q_handle: &'a GpuTensorHandle,
2237        options: &'a ProviderQrOptions,
2238    ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
2239        let _ = (product, q_handle, options);
2240        Box::pin(async move { Ok(None) })
2241    }
2242    fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2243        Err(anyhow::anyhow!("transpose not supported by provider"))
2244    }
2245    fn conv1d(
2246        &self,
2247        _signal: &GpuTensorHandle,
2248        _kernel: &GpuTensorHandle,
2249        _options: ProviderConv1dOptions,
2250    ) -> anyhow::Result<GpuTensorHandle> {
2251        Err(anyhow::anyhow!("conv1d not supported by provider"))
2252    }
2253    fn conv2d(
2254        &self,
2255        _signal: &GpuTensorHandle,
2256        _kernel: &GpuTensorHandle,
2257        _mode: ProviderConvMode,
2258    ) -> anyhow::Result<GpuTensorHandle> {
2259        Err(anyhow::anyhow!("conv2d not supported by provider"))
2260    }
2261    fn iir_filter<'a>(
2262        &'a self,
2263        _b: &'a GpuTensorHandle,
2264        _a: &'a GpuTensorHandle,
2265        _x: &'a GpuTensorHandle,
2266        _options: ProviderIirFilterOptions,
2267    ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
2268        Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
2269    }
2270    fn uniform_spectral_estimate<'a>(
2271        &'a self,
2272        _request: &'a ProviderSpectralRequest<'a>,
2273    ) -> AccelProviderFuture<'a, ProviderSpectralResult> {
2274        unsupported_future("uniform_spectral_estimate not supported by provider")
2275    }
2276    fn signal_envelope<'a>(
2277        &'a self,
2278        _request: &'a ProviderEnvelopeRequest<'a>,
2279    ) -> AccelProviderFuture<'a, ProviderEnvelopeResult> {
2280        unsupported_future("signal_envelope not supported by provider")
2281    }
2282    fn signal_hilbert<'a>(
2283        &'a self,
2284        _request: &'a ProviderHilbertRequest<'a>,
2285    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2286        unsupported_future("signal_hilbert not supported by provider")
2287    }
2288    /// Reorder tensor dimensions according to `order`, expressed as zero-based indices.
2289    fn permute(
2290        &self,
2291        _handle: &GpuTensorHandle,
2292        _order: &[usize],
2293    ) -> anyhow::Result<GpuTensorHandle> {
2294        Err(anyhow::anyhow!("permute not supported by provider"))
2295    }
2296    fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
2297        Err(anyhow::anyhow!("flip not supported by provider"))
2298    }
2299    fn circshift(
2300        &self,
2301        _handle: &GpuTensorHandle,
2302        _shifts: &[isize],
2303    ) -> anyhow::Result<GpuTensorHandle> {
2304        Err(anyhow::anyhow!("circshift not supported by provider"))
2305    }
2306    fn diff_dim(
2307        &self,
2308        _handle: &GpuTensorHandle,
2309        _order: usize,
2310        _dim: usize,
2311    ) -> anyhow::Result<GpuTensorHandle> {
2312        Err(anyhow::anyhow!("diff_dim not supported by provider"))
2313    }
2314    fn gradient_dim(
2315        &self,
2316        _handle: &GpuTensorHandle,
2317        _dim: usize,
2318        _spacing: f64,
2319    ) -> anyhow::Result<GpuTensorHandle> {
2320        Err(anyhow::anyhow!("gradient_dim not supported by provider"))
2321    }
2322    /// Perform an in-place FFT along a zero-based dimension, optionally padding/truncating to `len`.
2323    fn fft_dim<'a>(
2324        &'a self,
2325        _handle: &'a GpuTensorHandle,
2326        _len: Option<usize>,
2327        _dim: usize,
2328    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2329        unsupported_future("fft_dim not supported by provider")
2330    }
2331    fn ifft_dim<'a>(
2332        &'a self,
2333        _handle: &'a GpuTensorHandle,
2334        _len: Option<usize>,
2335        _dim: usize,
2336    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2337        unsupported_future("ifft_dim not supported by provider")
2338    }
2339    fn fft_extract_real<'a>(
2340        &'a self,
2341        _handle: &'a GpuTensorHandle,
2342    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2343        unsupported_future("fft_extract_real not supported by provider")
2344    }
2345    fn unique<'a>(
2346        &'a self,
2347        _handle: &'a GpuTensorHandle,
2348        _options: &'a UniqueOptions,
2349    ) -> AccelProviderFuture<'a, UniqueResult> {
2350        Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
2351    }
2352    fn union<'a>(
2353        &'a self,
2354        _a: &'a GpuTensorHandle,
2355        _b: &'a GpuTensorHandle,
2356        _options: &'a UnionOptions,
2357    ) -> AccelProviderFuture<'a, UnionResult> {
2358        Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
2359    }
2360    fn setdiff<'a>(
2361        &'a self,
2362        _a: &'a GpuTensorHandle,
2363        _b: &'a GpuTensorHandle,
2364        _options: &'a SetdiffOptions,
2365    ) -> AccelProviderFuture<'a, SetdiffResult> {
2366        Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
2367    }
2368    fn ismember<'a>(
2369        &'a self,
2370        _a: &'a GpuTensorHandle,
2371        _b: &'a GpuTensorHandle,
2372        _options: &'a IsMemberOptions,
2373    ) -> AccelProviderFuture<'a, IsMemberResult> {
2374        Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
2375    }
2376    fn reshape(
2377        &self,
2378        handle: &GpuTensorHandle,
2379        new_shape: &[usize],
2380    ) -> anyhow::Result<GpuTensorHandle> {
2381        let mut updated = handle.clone();
2382        updated.shape = new_shape.to_vec();
2383        Ok(updated)
2384    }
2385    /// Concatenate the provided tensors along the 1-based dimension `dim`.
2386    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
2387        Err(anyhow::anyhow!("cat not supported by provider"))
2388    }
2389    fn repmat(
2390        &self,
2391        _handle: &GpuTensorHandle,
2392        _reps: &[usize],
2393    ) -> anyhow::Result<GpuTensorHandle> {
2394        Err(anyhow::anyhow!("repmat not supported by provider"))
2395    }
2396    /// Compute the Kronecker product of two tensors, matching MATLAB semantics.
2397    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2398        Err(anyhow::anyhow!("kron not supported by provider"))
2399    }
2400    /// Compute the cross product of 3-element vectors along a matching dimension.
2401    fn cross(
2402        &self,
2403        _lhs: &GpuTensorHandle,
2404        _rhs: &GpuTensorHandle,
2405        _dim: Option<usize>,
2406    ) -> anyhow::Result<GpuTensorHandle> {
2407        Err(anyhow::anyhow!("cross not supported by provider"))
2408    }
2409    fn reduce_sum<'a>(
2410        &'a self,
2411        _a: &'a GpuTensorHandle,
2412    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2413        unsupported_future("reduce_sum not supported by provider")
2414    }
2415    fn reduce_sum_dim<'a>(
2416        &'a self,
2417        _a: &'a GpuTensorHandle,
2418        _dim: usize,
2419    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2420        unsupported_future("reduce_sum_dim not supported by provider")
2421    }
2422    fn dot<'a>(
2423        &'a self,
2424        _lhs: &'a GpuTensorHandle,
2425        _rhs: &'a GpuTensorHandle,
2426        _dim: Option<usize>,
2427    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2428        unsupported_future("dot not supported by provider")
2429    }
2430    fn reduce_nnz<'a>(
2431        &'a self,
2432        _a: &'a GpuTensorHandle,
2433    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2434        unsupported_future("reduce_nnz not supported by provider")
2435    }
2436    fn reduce_nnz_dim<'a>(
2437        &'a self,
2438        _a: &'a GpuTensorHandle,
2439        _dim: usize,
2440    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2441        unsupported_future("reduce_nnz_dim not supported by provider")
2442    }
2443    fn reduce_prod<'a>(
2444        &'a self,
2445        _a: &'a GpuTensorHandle,
2446    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2447        unsupported_future("reduce_prod not supported by provider")
2448    }
2449    fn reduce_prod_dim<'a>(
2450        &'a self,
2451        _a: &'a GpuTensorHandle,
2452        _dim: usize,
2453    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2454        unsupported_future("reduce_prod_dim not supported by provider")
2455    }
2456    fn reduce_mean<'a>(
2457        &'a self,
2458        _a: &'a GpuTensorHandle,
2459    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2460        unsupported_future("reduce_mean not supported by provider")
2461    }
2462    /// Reduce mean across multiple zero-based dimensions in one device pass.
2463    fn reduce_mean_nd<'a>(
2464        &'a self,
2465        _a: &'a GpuTensorHandle,
2466        _dims_zero_based: &'a [usize],
2467    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2468        unsupported_future("reduce_mean_nd not supported by provider")
2469    }
2470    /// Reduce moments across multiple zero-based dimensions in one device pass.
2471    /// Returns mean (E[x]) and mean of squares (E[x^2]).
2472    fn reduce_moments_nd<'a>(
2473        &'a self,
2474        _a: &'a GpuTensorHandle,
2475        _dims_zero_based: &'a [usize],
2476    ) -> AccelProviderFuture<'a, ProviderMoments2> {
2477        unsupported_future("reduce_moments_nd not supported by provider")
2478    }
2479    fn reduce_mean_dim<'a>(
2480        &'a self,
2481        _a: &'a GpuTensorHandle,
2482        _dim: usize,
2483    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2484        unsupported_future("reduce_mean_dim not supported by provider")
2485    }
2486    fn reduce_std<'a>(
2487        &'a self,
2488        _a: &'a GpuTensorHandle,
2489        _normalization: ProviderStdNormalization,
2490        _nan_mode: ProviderNanMode,
2491    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2492        unsupported_future("reduce_std not supported by provider")
2493    }
2494    fn reduce_std_dim<'a>(
2495        &'a self,
2496        _a: &'a GpuTensorHandle,
2497        _dim: usize,
2498        _normalization: ProviderStdNormalization,
2499        _nan_mode: ProviderNanMode,
2500    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2501        unsupported_future("reduce_std_dim not supported by provider")
2502    }
2503    fn reduce_any<'a>(
2504        &'a self,
2505        _a: &'a GpuTensorHandle,
2506        _omit_nan: bool,
2507    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2508        unsupported_future("reduce_any not supported by provider")
2509    }
2510    fn reduce_any_dim<'a>(
2511        &'a self,
2512        _a: &'a GpuTensorHandle,
2513        _dim: usize,
2514        _omit_nan: bool,
2515    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2516        unsupported_future("reduce_any_dim not supported by provider")
2517    }
2518    fn reduce_all<'a>(
2519        &'a self,
2520        _a: &'a GpuTensorHandle,
2521        _omit_nan: bool,
2522    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2523        unsupported_future("reduce_all not supported by provider")
2524    }
2525    fn reduce_all_dim<'a>(
2526        &'a self,
2527        _a: &'a GpuTensorHandle,
2528        _dim: usize,
2529        _omit_nan: bool,
2530    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2531        unsupported_future("reduce_all_dim not supported by provider")
2532    }
2533    fn reduce_median<'a>(
2534        &'a self,
2535        _a: &'a GpuTensorHandle,
2536    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2537        unsupported_future("reduce_median not supported by provider")
2538    }
2539    fn reduce_median_dim<'a>(
2540        &'a self,
2541        _a: &'a GpuTensorHandle,
2542        _dim: usize,
2543    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2544        unsupported_future("reduce_median_dim not supported by provider")
2545    }
2546    fn reduce_min<'a>(
2547        &'a self,
2548        _a: &'a GpuTensorHandle,
2549    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2550        unsupported_future("reduce_min not supported by provider")
2551    }
2552    fn reduce_min_dim<'a>(
2553        &'a self,
2554        _a: &'a GpuTensorHandle,
2555        _dim: usize,
2556    ) -> AccelProviderFuture<'a, ReduceDimResult> {
2557        unsupported_future("reduce_min_dim not supported by provider")
2558    }
2559    fn reduce_max<'a>(
2560        &'a self,
2561        _a: &'a GpuTensorHandle,
2562    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2563        unsupported_future("reduce_max not supported by provider")
2564    }
2565    fn reduce_max_dim<'a>(
2566        &'a self,
2567        _a: &'a GpuTensorHandle,
2568        _dim: usize,
2569    ) -> AccelProviderFuture<'a, ReduceDimResult> {
2570        unsupported_future("reduce_max_dim not supported by provider")
2571    }
2572    fn cumsum_scan(
2573        &self,
2574        _input: &GpuTensorHandle,
2575        _dim: usize,
2576        _direction: ProviderScanDirection,
2577        _nan_mode: ProviderNanMode,
2578    ) -> anyhow::Result<GpuTensorHandle> {
2579        Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
2580    }
2581    fn cumprod_scan(
2582        &self,
2583        _input: &GpuTensorHandle,
2584        _dim: usize,
2585        _direction: ProviderScanDirection,
2586        _nan_mode: ProviderNanMode,
2587    ) -> anyhow::Result<GpuTensorHandle> {
2588        Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
2589    }
2590    fn cummin_scan(
2591        &self,
2592        _input: &GpuTensorHandle,
2593        _dim: usize,
2594        _direction: ProviderScanDirection,
2595        _nan_mode: ProviderNanMode,
2596    ) -> anyhow::Result<ProviderCumminResult> {
2597        Err(anyhow::anyhow!("cummin_scan not supported by provider"))
2598    }
2599    fn cummax_scan(
2600        &self,
2601        _input: &GpuTensorHandle,
2602        _dim: usize,
2603        _direction: ProviderScanDirection,
2604        _nan_mode: ProviderNanMode,
2605    ) -> anyhow::Result<ProviderCummaxResult> {
2606        Err(anyhow::anyhow!("cummax_scan not supported by provider"))
2607    }
2608
2609    fn find(
2610        &self,
2611        _a: &GpuTensorHandle,
2612        _limit: Option<usize>,
2613        _direction: FindDirection,
2614    ) -> anyhow::Result<ProviderFindResult> {
2615        Err(anyhow::anyhow!("find not supported by provider"))
2616    }
2617
2618    fn fused_elementwise(
2619        &self,
2620        _shader: &str,
2621        _inputs: &[GpuTensorHandle],
2622        _output_shape: &[usize],
2623        _len: usize,
2624    ) -> anyhow::Result<GpuTensorHandle> {
2625        Err(anyhow::anyhow!(
2626            "fused_elementwise not supported by provider"
2627        ))
2628    }
2629
2630    /// Execute a single fused elementwise kernel that writes `num_outputs` output buffers in one
2631    /// dispatch. The shader is expected to declare `output0`, `output1`, … `output{N-1}` storage
2632    /// bindings (at binding indices `inputs.len()` through `inputs.len() + num_outputs - 1`) and a
2633    /// uniform `params` binding at `inputs.len() + num_outputs`.
2634    ///
2635    /// Providers that do not override this method fall back to calling `fused_elementwise` once
2636    /// per output, which preserves correctness at the cost of the O(N²) dispatch overhead this
2637    /// method is designed to eliminate.
2638    fn fused_elementwise_multi(
2639        &self,
2640        _shader: &str,
2641        _inputs: &[GpuTensorHandle],
2642        _output_shape: &[usize],
2643        _len: usize,
2644        _num_outputs: usize,
2645    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2646        Err(anyhow::anyhow!(
2647            "fused_elementwise_multi not supported by provider"
2648        ))
2649    }
2650
2651    /// Build a numeric tensor where NaNs in `a` are replaced with 0.0 (device side).
2652    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2653        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
2654    }
2655
2656    /// Build a numeric mask tensor with 1.0 where value is not NaN and 0.0 where value is NaN.
2657    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2658        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
2659    }
2660
2661    /// Generic fused reduction entrypoint.
2662    ///
2663    /// The shader is expected to implement a column-major reduction across `reduce_len` with
2664    /// `num_slices` independent slices (e.g., columns). Providers should create a uniform buffer
2665    /// compatible with the expected `Params/MParams` struct in the shader and dispatch
2666    /// `num_slices` workgroups with `workgroup_size` threads, or an equivalent strategy.
2667    #[allow(clippy::too_many_arguments)]
2668    fn fused_reduction(
2669        &self,
2670        _shader: &str,
2671        _inputs: &[GpuTensorHandle],
2672        _output_shape: &[usize],
2673        _reduce_len: usize,
2674        _num_slices: usize,
2675        _workgroup_size: u32,
2676        _flavor: ReductionFlavor,
2677    ) -> anyhow::Result<GpuTensorHandle> {
2678        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
2679    }
2680
    /// Optionally pre-compile commonly used pipelines to amortize first-dispatch costs.
    ///
    /// The default implementation does nothing.
    fn warmup(&self) {}
2683
    /// Returns (cache_hits, cache_misses) for fused pipeline cache, if supported.
    ///
    /// The default reports `(0, 0)`; it also feeds the fusion-cache fields of the default
    /// `telemetry_snapshot`.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }
2688
    /// Returns the duration of the last provider warmup in milliseconds, if known.
    ///
    /// The default returns `None` (no warmup timing recorded).
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }
2693
2694    /// Returns a snapshot of provider telemetry counters if supported.
2695    fn telemetry_snapshot(&self) -> ProviderTelemetry {
2696        let (hits, misses) = self.fused_cache_counters();
2697        ProviderTelemetry {
2698            fused_elementwise: ProviderDispatchStats::default(),
2699            fused_reduction: ProviderDispatchStats::default(),
2700            matmul: ProviderDispatchStats::default(),
2701            linsolve: ProviderDispatchStats::default(),
2702            mldivide: ProviderDispatchStats::default(),
2703            mrdivide: ProviderDispatchStats::default(),
2704            upload_bytes: 0,
2705            download_bytes: 0,
2706            solve_fallbacks: Vec::new(),
2707            fusion_cache_hits: hits,
2708            fusion_cache_misses: misses,
2709            bind_group_cache_hits: 0,
2710            bind_group_cache_misses: 0,
2711            bind_group_cache_by_layout: None,
2712            kernel_launches: Vec::new(),
2713        }
2714    }
2715
    /// Reset all telemetry counters maintained by the provider, if supported.
    ///
    /// The default implementation does nothing.
    fn reset_telemetry(&self) {}
2718
2719    /// Default reduction workgroup size the provider prefers.
2720    fn default_reduction_workgroup_size(&self) -> u32 {
2721        256
2722    }
2723
2724    /// Threshold above which provider will prefer two-pass reduction.
2725    fn two_pass_threshold(&self) -> usize {
2726        1024
2727    }
2728
    /// Current two-pass mode preference (auto/forced on/off).
    ///
    /// The default is `ReductionTwoPassMode::Auto`, leaving the choice to the caller
    /// (presumably via `two_pass_threshold` — confirm at call sites).
    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
2733
2734    /// Fast-path: write a GPU column in a matrix from a GPU vector, returning a new handle.
2735    /// Expected: `values.shape == [rows, 1]` (or `[rows]`) and `col_index < cols`.
2736    fn scatter_column(
2737        &self,
2738        _matrix: &GpuTensorHandle,
2739        _col_index: usize,
2740        _values: &GpuTensorHandle,
2741    ) -> anyhow::Result<GpuTensorHandle> {
2742        Err(anyhow::anyhow!("scatter_column not supported by provider"))
2743    }
2744
2745    /// Fast-path: write a GPU row in a matrix from a GPU vector, returning a new handle.
2746    /// Expected: `values.shape == [1, cols]` (or `[cols]`) and `row_index < rows`.
2747    fn scatter_row(
2748        &self,
2749        _matrix: &GpuTensorHandle,
2750        _row_index: usize,
2751        _values: &GpuTensorHandle,
2752    ) -> anyhow::Result<GpuTensorHandle> {
2753        Err(anyhow::anyhow!("scatter_row not supported by provider"))
2754    }
2755
2756    fn sub2ind(
2757        &self,
2758        _dims: &[usize],
2759        _strides: &[usize],
2760        _inputs: &[&GpuTensorHandle],
2761        _scalar_mask: &[bool],
2762        _len: usize,
2763        _output_shape: &[usize],
2764    ) -> anyhow::Result<GpuTensorHandle> {
2765        Err(anyhow::anyhow!("sub2ind not supported by provider"))
2766    }
2767
    /// Returns true if the provider offers a device-side `ind2sub` implementation.
    ///
    /// The default returns `false`, matching the default `ind2sub`, which always errors.
    /// Providers overriding `ind2sub` should override this too.
    fn supports_ind2sub(&self) -> bool {
        false
    }
2772
2773    /// Convert linear indices into per-dimension subscripts on the device.
2774    fn ind2sub(
2775        &self,
2776        _dims: &[usize],
2777        _strides: &[usize],
2778        _indices: &GpuTensorHandle,
2779        _total: usize,
2780        _len: usize,
2781        _output_shape: &[usize],
2782    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2783        Err(anyhow::anyhow!("ind2sub not supported by provider"))
2784    }
2785
2786    /// Determine if a matrix is symmetric (or skew-symmetric) without gathering it to the host.
2787    fn issymmetric(
2788        &self,
2789        _matrix: &GpuTensorHandle,
2790        _kind: ProviderSymmetryKind,
2791        _tolerance: f64,
2792    ) -> anyhow::Result<bool> {
2793        Err(anyhow::anyhow!(
2794            "issymmetric predicate not supported by provider"
2795        ))
2796    }
2797
2798    /// Determine if a matrix is Hermitian (or skew-Hermitian) without gathering it to the host.
2799    fn ishermitian<'a>(
2800        &'a self,
2801        _matrix: &'a GpuTensorHandle,
2802        _kind: ProviderHermitianKind,
2803        _tolerance: f64,
2804    ) -> AccelProviderFuture<'a, bool> {
2805        Box::pin(async move {
2806            Err(anyhow::anyhow!(
2807                "ishermitian predicate not supported by provider"
2808            ))
2809        })
2810    }
2811
2812    /// Inspect the bandwidth of a matrix without gathering it back to the host.
2813    fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2814        Err(anyhow::anyhow!("bandwidth not supported by provider"))
2815    }
2816
2817    /// Compute the symmetric reverse Cuthill-McKee permutation for the matrix.
2818    ///
2819    /// Implementations may execute on the device or gather to the host. The permutation should be
2820    /// returned as zero-based indices.
2821    fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
2822        Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
2823    }
2824}
2825
/// Process-wide fallback provider. `provider()` returns it when the current thread has no
/// override; it is set by `register_provider` and reset by `clear_provider`.
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
/// Providers keyed by device id; consulted first by `provider_for_device`.
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Counter behind `next_device_id`; the first id handed out is 1.
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2831
// Per-thread provider override on non-wasm targets. `provider()` checks it before
// `GLOBAL_PROVIDER`; it is mutated via `replace_thread_provider` (e.g. by `ThreadProviderGuard`).
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}
2836
// wasm32 stand-in for `THREAD_PROVIDER`: a single process-wide slot behind a mutex, read and
// written by the wasm variants of `current_thread_provider` / `replace_thread_provider`.
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| Mutex::new(None));
2840
2841#[cfg(not(target_arch = "wasm32"))]
2842fn replace_thread_provider(
2843    provider: Option<&'static dyn AccelProvider>,
2844) -> Option<&'static dyn AccelProvider> {
2845    THREAD_PROVIDER.with(|cell| {
2846        let prev = cell.get();
2847        cell.set(provider);
2848        prev
2849    })
2850}
2851
/// Swap the wasm provider slot for `provider`, handing back the previous occupant.
///
/// # Panics
/// Panics if the slot's mutex has been poisoned.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut occupant = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *occupant, provider)
}
2863
2864#[cfg(not(target_arch = "wasm32"))]
2865fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2866    THREAD_PROVIDER.with(|cell| cell.get())
2867}
2868
/// Read the wasm provider slot without changing it.
///
/// # Panics
/// Panics if the slot's mutex has been poisoned.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    let occupant = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    *occupant
}
2877
2878/// Register a global acceleration provider.
2879///
2880/// # Safety
2881/// - The caller must guarantee that `p` is valid for the entire program lifetime
2882///   (e.g., a `'static` singleton), as the runtime stores a raw reference globally.
2883/// - Concurrent callers must ensure registration happens once or is properly
2884///   synchronized; this function does not enforce thread-safety for re-registration.
2885pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
2886    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2887        *guard = Some(p);
2888    }
2889    register_provider_for_device(p.device_id(), p);
2890}
2891
2892unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
2893    if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
2894        guard.insert(device_id, provider);
2895    }
2896}
2897
2898pub fn provider() -> Option<&'static dyn AccelProvider> {
2899    if let Some(p) = current_thread_provider() {
2900        return Some(p);
2901    }
2902    GLOBAL_PROVIDER
2903        .read()
2904        .ok()
2905        .and_then(|guard| guard.as_ref().copied())
2906}
2907
2908/// Clear the globally registered provider. Intended for tests to ensure deterministic behaviour.
2909pub fn clear_provider() {
2910    replace_thread_provider(None);
2911    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2912        *guard = None;
2913    }
2914    if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2915        map.clear();
2916    }
2917}
2918
2919pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2920    if let Some(registered) = PROVIDER_REGISTRY
2921        .read()
2922        .ok()
2923        .and_then(|guard| guard.get(&device_id).copied())
2924    {
2925        return Some(registered);
2926    }
2927    if let Some(thread_provider) = current_thread_provider() {
2928        if thread_provider.device_id() == device_id {
2929            return Some(thread_provider);
2930        }
2931    }
2932    // Preserve legacy behavior: when no explicit per-device registration exists,
2933    // fall back to the globally active provider regardless of handle device id.
2934    GLOBAL_PROVIDER
2935        .read()
2936        .ok()
2937        .and_then(|guard| guard.as_ref().copied())
2938}
2939
2940pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
2941    provider_for_device(handle.device_id)
2942}
2943
2944pub fn spawn_handle_concurrency_for(handle: &GpuTensorHandle) -> Option<SpawnHandleConcurrency> {
2945    provider_for_handle(handle).map(AccelProvider::spawn_handle_concurrency)
2946}
2947
2948pub fn next_device_id() -> u32 {
2949    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
2950}
2951
2952pub struct ThreadProviderGuard {
2953    prev: Option<&'static dyn AccelProvider>,
2954}
2955
2956impl ThreadProviderGuard {
2957    pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2958        let prev = replace_thread_provider(provider);
2959        ThreadProviderGuard { prev }
2960    }
2961}
2962
2963impl Drop for ThreadProviderGuard {
2964    fn drop(&mut self) {
2965        let prev = self.prev.take();
2966        replace_thread_provider(prev);
2967    }
2968}
2969
2970pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
2971    replace_thread_provider(provider);
2972}
2973
2974/// Convenience: perform elementwise add via provider if possible; otherwise return None
2975pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2976    if let Some(p) = provider() {
2977        if let Ok(h) = p.elem_add(a, b).await {
2978            return Some(h);
2979        }
2980    }
2981    None
2982}
2983
2984/// Convenience: perform elementwise hypot via provider if possible; otherwise return None
2985pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2986    if let Some(p) = provider() {
2987        if let Ok(h) = p.elem_hypot(a, b).await {
2988            return Some(h);
2989        }
2990    }
2991    None
2992}
2993
2994/// Convenience: perform elementwise max via provider if possible; otherwise return None
2995pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2996    if let Some(p) = provider() {
2997        if let Ok(h) = p.elem_max(a, b).await {
2998            return Some(h);
2999        }
3000    }
3001    None
3002}
3003
3004/// Convenience: perform elementwise min via provider if possible; otherwise return None
3005pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
3006    if let Some(p) = provider() {
3007        if let Ok(h) = p.elem_min(a, b).await {
3008            return Some(h);
3009        }
3010    }
3011    None
3012}
3013
3014/// Convenience: perform elementwise atan2 via provider if possible; otherwise return None
3015pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
3016    if let Some(p) = provider() {
3017        if let Ok(h) = p.elem_atan2(y, x).await {
3018            return Some(h);
3019        }
3020    }
3021    None
3022}
3023
// Minimal host tensor views to avoid depending on runmat-builtins and cycles
/// Owned host-side tensor: a flat `f64` buffer together with its shape and storage kind.
///
/// NOTE(review): the element order of `data` (presumably column-major, like the reductions
/// described elsewhere in this file) is not established here — confirm with producers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    /// Flat element buffer.
    pub data: Vec<f64>,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Storage kind the tensor carries.
    pub storage: GpuTensorStorage,
}
3031
/// Borrowed host-side tensor: a flat `f64` slice plus its shape. This is, for example, the
/// input type of `AccelProvider::upload`.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    /// Flat element slice.
    pub data: &'a [f64],
    /// Tensor dimensions.
    pub shape: &'a [usize],
}
3037
/// Lightweight 1-D axis view used by provider meshgrid hooks.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    /// Coordinate values along this axis.
    pub data: &'a [f64],
}
3043
/// Provider-side meshgrid result containing coordinate tensor handles.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    /// Coordinate grids, presumably one per requested output (X, Y[, Z]) — confirm with the
    /// meshgrid hook that fills this.
    pub outputs: Vec<GpuTensorHandle>,
}
3049
3050/// Descriptor for GEMM epilogues applied to `C = A * B` before storing to `C`.
3051///
3052/// Supported operations:
3053/// - Scale by `alpha` and add scalar `beta`.
3054/// - Multiply output by per-row and/or per-column scale vectors (broadcasted).
3055#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
3056pub enum ScaleOp {
3057    Multiply,
3058    Divide,
3059}
3060
3061#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
3062pub struct MatmulEpilogue {
3063    /// Scalar multiply applied to each output element.
3064    pub alpha: f64,
3065    /// Scalar add applied to each output element after scaling.
3066    pub beta: f64,
3067    /// Optional per-row scale (length m). When present, output[row, col] *= row_scale[row].
3068    pub row_scale: Option<GpuTensorHandle>,
3069    /// Optional per-column scale (length n). When present, output[row, col] *= col_scale[col].
3070    pub col_scale: Option<GpuTensorHandle>,
3071    /// Row scale operation (multiply or divide). Ignored when `row_scale` is None.
3072    pub row_op: ScaleOp,
3073    /// Column scale operation (multiply or divide). Ignored when `col_scale` is None.
3074    pub col_op: ScaleOp,
3075    /// Optional lower clamp bound applied after scale/bias.
3076    #[serde(default)]
3077    pub clamp_min: Option<f64>,
3078    /// Optional upper clamp bound applied after scale/bias.
3079    #[serde(default)]
3080    pub clamp_max: Option<f64>,
3081    /// Optional power exponent applied after clamp (final operation in the epilogue).
3082    #[serde(default)]
3083    pub pow_exponent: Option<f64>,
3084    /// Optional output buffer for the diagonal of the result (length min(m, n)).
3085    #[serde(default)]
3086    pub diag_output: Option<GpuTensorHandle>,
3087}
3088
3089impl MatmulEpilogue {
3090    pub fn noop() -> Self {
3091        Self {
3092            alpha: 1.0,
3093            beta: 0.0,
3094            row_scale: None,
3095            col_scale: None,
3096            row_op: ScaleOp::Multiply,
3097            col_op: ScaleOp::Multiply,
3098            clamp_min: None,
3099            clamp_max: None,
3100            pow_exponent: None,
3101            diag_output: None,
3102        }
3103    }
3104    pub fn is_noop(&self) -> bool {
3105        self.alpha == 1.0
3106            && self.beta == 0.0
3107            && self.row_scale.is_none()
3108            && self.col_scale.is_none()
3109            && self.clamp_min.is_none()
3110            && self.clamp_max.is_none()
3111            && self.pow_exponent.is_none()
3112            && self.diag_output.is_none()
3113    }
3114}
3115
3116#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
3117pub struct PowerStepEpilogue {
3118    pub epsilon: f64,
3119}
3120
3121impl Default for PowerStepEpilogue {
3122    fn default() -> Self {
3123        Self { epsilon: 0.0 }
3124    }
3125}
3126
/// Serializable parameters for a provider image-normalize operation over `batch` images of
/// `height` x `width`.
///
/// Fields omitted from a serialized payload fall back as follows: `gain`, `bias` and
/// `gamma` become `None`, and `clamp_zero` becomes `true`, so older descriptors keep clamped
/// semantics. NOTE(review): the exact normalization formula (how `epsilon`, `gain`, `bias`,
/// `gamma` are applied) lives in the provider kernel — confirm there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    pub batch: usize,
    pub height: usize,
    pub width: usize,
    pub epsilon: f64,
    #[serde(default)]
    pub gain: Option<f64>,
    #[serde(default)]
    pub bias: Option<f64>,
    #[serde(default)]
    pub gamma: Option<f64>,
    // Explicit default function: a plain `#[serde(default)]` would yield `false`.
    #[serde(default = "default_image_normalize_clamp_zero")]
    pub clamp_zero: bool,
}
3142
/// Serde default for `ImageNormalizeDescriptor::clamp_zero`. Descriptors serialized without
/// the field deserialize as clamped (`true`).
fn default_image_normalize_clamp_zero() -> bool {
    true
}
3146
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider used to exercise routing; every data-path method errors out.
    struct TestProvider {
        device_id: u32,
        name: &'static str,
        spawn_concurrency: SpawnHandleConcurrency,
    }

    impl AccelProvider for TestProvider {
        fn upload(&self, _host: &HostTensorView) -> anyhow::Result<GpuTensorHandle> {
            Err(anyhow!("test provider upload should not be called"))
        }

        fn download<'a>(&'a self, _h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a> {
            unsupported_future("test provider download should not be called")
        }

        fn free(&self, _h: &GpuTensorHandle) -> anyhow::Result<()> {
            Err(anyhow!("test provider free should not be called"))
        }

        fn device_info(&self) -> String {
            self.name.to_string()
        }

        fn device_id(&self) -> u32 {
            self.device_id
        }

        fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
            self.spawn_concurrency
        }
    }

    /// Serializes tests that mutate the process-wide provider state.
    static PROVIDER_TEST_LOCK: Lazy<std::sync::Mutex<()>> = Lazy::new(|| std::sync::Mutex::new(()));
    static PROVIDER_A: TestProvider = TestProvider {
        device_id: 101,
        name: "provider-a",
        spawn_concurrency: SpawnHandleConcurrency::ImmutableShare,
    };
    static PROVIDER_B: TestProvider = TestProvider {
        device_id: 202,
        name: "provider-b",
        spawn_concurrency: SpawnHandleConcurrency::Reject,
    };
    static PROVIDER_C: TestProvider = TestProvider {
        device_id: 303,
        name: "provider-c",
        spawn_concurrency: SpawnHandleConcurrency::CopyOnWrite,
    };

    /// Fixture for tests that touch global provider state.
    ///
    /// Acquiring it takes `PROVIDER_TEST_LOCK`, recovering from poisoning so that one failed
    /// test cannot cascade into "lock poisoned" failures in the others, and starts from a
    /// clean slate. Dropping it calls `clear_provider` again, even when the test panics, and
    /// only then releases the lock (fields drop after `Drop::drop`).
    struct ProviderTestEnv {
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl ProviderTestEnv {
        fn acquire() -> Self {
            let lock = PROVIDER_TEST_LOCK
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            clear_provider();
            Self { _lock: lock }
        }
    }

    impl Drop for ProviderTestEnv {
        fn drop(&mut self) {
            clear_provider();
        }
    }

    /// Reset provider state, then register `PROVIDER_A` and `PROVIDER_B` (B ends up global).
    fn register_test_providers() {
        clear_provider();
        // SAFETY: the providers are `static` items, valid for the whole program; callers hold
        // a `ProviderTestEnv`, which serializes registration across tests.
        unsafe {
            register_provider(&PROVIDER_A);
            register_provider(&PROVIDER_B);
        }
    }

    fn test_handle(device_id: u32) -> GpuTensorHandle {
        GpuTensorHandle {
            shape: vec![1],
            device_id,
            buffer_id: 42,
        }
    }

    fn spectral_request<'a>(
        input: &'a GpuTensorHandle,
        frame_mode: ProviderSpectralFrameMode,
    ) -> ProviderSpectralRequest<'a> {
        static WINDOW: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        ProviderSpectralRequest {
            input,
            input_len: 16,
            input_complex: false,
            window: &WINDOW,
            nfft: 8,
            frame_count: 3,
            frame_mode,
            range: ProviderSpectralRange::Onesided,
            denominator: 1.0,
        }
    }

    #[test]
    fn provider_envelope_shape_guard_rejects_equal_len_layout_spoofing() {
        assert!(provider_envelope_input_shape_matches(&[2, 3], 2, 3));
        assert!(provider_envelope_input_shape_matches(&[6, 1], 6, 1));
        assert!(provider_envelope_input_shape_matches(&[1, 6], 6, 1));
        assert!(provider_envelope_input_shape_matches(&[6], 6, 1));

        assert!(!provider_envelope_input_shape_matches(&[3, 2], 2, 3));
        assert!(!provider_envelope_input_shape_matches(&[6], 2, 3));
        assert!(!provider_envelope_input_shape_matches(&[2, 1, 3], 2, 3));
    }

    #[test]
    fn provider_for_device_prefers_registered_device_over_thread_provider() {
        let _env = ProviderTestEnv::acquire();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider = provider_for_device(PROVIDER_A.device_id()).expect("provider for device");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn provider_for_handle_uses_handle_device_owner() {
        let _env = ProviderTestEnv::acquire();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider =
            provider_for_handle(&test_handle(PROVIDER_A.device_id())).expect("provider for handle");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn spawn_handle_concurrency_for_uses_registered_owner() {
        let _env = ProviderTestEnv::acquire();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let concurrency = spawn_handle_concurrency_for(&test_handle(PROVIDER_A.device_id()))
            .expect("spawn concurrency");

        assert_eq!(concurrency, PROVIDER_A.spawn_concurrency);
    }

    #[test]
    fn provider_keeps_thread_local_active_provider_semantics() {
        let _env = ProviderTestEnv::acquire();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_A));

        let active = provider().expect("active provider");

        assert_eq!(active.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn unregistered_thread_provider_only_matches_own_device_before_global_fallback() {
        let _env = ProviderTestEnv::acquire();
        // SAFETY: `PROVIDER_A` is a `static`, valid for the whole program; `_env` serializes
        // registration across tests.
        unsafe {
            register_provider(&PROVIDER_A);
        }
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_C));

        let own_device = provider_for_device(PROVIDER_C.device_id()).expect("own provider");
        let fallback = provider_for_device(404).expect("global fallback provider");

        assert_eq!(own_device.device_info(), PROVIDER_C.name);
        assert_eq!(fallback.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn uniform_spectral_request_validates_sliding_input_coverage() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(&input, ProviderSpectralFrameMode::Sliding { hop: 6 });
        assert!(validate_uniform_spectral_request(&request).is_ok());

        request.input_len = 15;
        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_rejects_sliding_coverage_overflow() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(&input, ProviderSpectralFrameMode::Sliding { hop: 2 });
        request.frame_count = usize::MAX;

        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_validates_folded_input_coverage() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(
            &input,
            ProviderSpectralFrameMode::FoldedColumns { input_rows: 5 },
        );
        assert!(validate_uniform_spectral_request(&request).is_ok());

        request.frame_mode = ProviderSpectralFrameMode::FoldedColumns { input_rows: 0 };
        assert!(validate_uniform_spectral_request(&request).is_err());

        request.frame_mode = ProviderSpectralFrameMode::FoldedColumns { input_rows: 6 };
        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_rejects_folded_coverage_overflow() {
        let input = test_handle(PROVIDER_A.device_id());
        let request = spectral_request(
            &input,
            ProviderSpectralFrameMode::FoldedColumns {
                input_rows: usize::MAX,
            },
        );

        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn image_normalize_descriptor_omitted_clamp_zero_defaults_true() {
        let payload = r#"{
            "batch": 2,
            "height": 4,
            "width": 5,
            "epsilon": 0.000001
        }"#;

        let desc: ImageNormalizeDescriptor =
            serde_json::from_str(payload).expect("deserialize descriptor");

        assert!(
            desc.clamp_zero,
            "legacy serialized descriptors should default to clamped image normalize"
        );
    }

    #[test]
    fn image_normalize_descriptor_explicit_false_preserves_unclamped() {
        let payload = r#"{
            "batch": 2,
            "height": 4,
            "width": 5,
            "epsilon": 0.000001,
            "clamp_zero": false
        }"#;

        let desc: ImageNormalizeDescriptor =
            serde_json::from_str(payload).expect("deserialize descriptor");

        assert!(
            !desc.clamp_zero,
            "explicit clamp_zero=false should preserve unclamped semantics"
        );
    }
}