1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
/// Callback signature invoked by `mark_residency` with the handle being marked.
type ResidencyMarkFn = fn(&GpuTensorHandle);
/// Callback signature invoked by `clear_residency` with the handle being cleared.
type ResidencyClearFn = fn(&GpuTensorHandle);
/// Provider of an optional sequence-threshold hint; see `sequence_threshold_hint`.
type SequenceThresholdFn = fn() -> Option<usize>;
/// Provider of an optional workgroup-size hint; see `workgroup_size_hint`.
type WorkgroupSizeHintFn = fn() -> Option<u32>;

// Write-once hook slots. Each `register_*` function fills its slot through
// `OnceCell::set` and ignores the result, so only the first registration is kept.
static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();
25
// Per-handle metadata tables. Every table is keyed by `GpuTensorHandle::buffer_id`.

/// Buffer ids currently flagged as logical (maintained by `set_handle_logical`).
static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
/// How many times each buffer id has been flagged logical since its entry was last removed.
static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Base dimensions recorded by `record_handle_transpose`.
static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

/// Precision recorded by `set_handle_precision`.
static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Storage layout recorded by `set_handle_storage`; missing entries read back as `Real`.
static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
36
/// Base dimensions stored for a handle by `record_handle_transpose` and
/// returned by `handle_transpose_info`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// The `base_rows` value supplied when the transpose was recorded.
    pub base_rows: usize,
    /// The `base_cols` value supplied when the transpose was recorded.
    pub base_cols: usize,
}
42
43pub fn register_residency_mark(handler: ResidencyMarkFn) {
46 let _ = RESIDENCY_MARK.set(handler);
47}
48
49pub fn mark_residency(handle: &GpuTensorHandle) {
52 if let Some(handler) = RESIDENCY_MARK.get() {
53 handler(handle);
54 }
55}
56
57pub fn register_residency_clear(handler: ResidencyClearFn) {
61 let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64pub fn clear_residency(handle: &GpuTensorHandle) {
67 if let Some(handler) = RESIDENCY_CLEAR.get() {
68 handler(handle);
69 }
70}
71
72pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76 let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79pub fn sequence_threshold_hint() -> Option<usize> {
81 SEQUENCE_THRESHOLD_PROVIDER
82 .get()
83 .and_then(|provider| provider())
84}
85
86pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90 let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93pub fn workgroup_size_hint() -> Option<u32> {
95 WORKGROUP_SIZE_HINT_PROVIDER
96 .get()
97 .and_then(|provider| provider())
98}
99
100pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103 provider().and_then(|p| p.export_context(kind))
104}
105
/// Asks the active provider for a wgpu buffer reference backing `handle`.
///
/// Returns `None` when no provider is available or the provider declines.
#[cfg(feature = "wgpu")]
pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
    let active = provider()?;
    active.export_wgpu_buffer(handle)
}
113
114pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118 guard.insert(handle.buffer_id, precision);
119 }
120}
121
122pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124 HANDLE_PRECISIONS
125 .read()
126 .ok()
127 .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133 guard.remove(&handle.buffer_id);
134 }
135}
136
137pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
140 if let Ok(mut guard) = LOGICAL_HANDLES.write() {
141 if logical {
142 guard.insert(handle.buffer_id);
143 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
144 *hits.entry(handle.buffer_id).or_insert(0) += 1;
145 }
146 } else {
147 guard.remove(&handle.buffer_id);
148 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
149 hits.remove(&handle.buffer_id);
150 }
151 }
152 }
153}
154
/// Removes the logical flag (and its hit counter) for `handle`.
///
/// Shorthand for `set_handle_logical(handle, false)`.
pub fn clear_handle_logical(handle: &GpuTensorHandle) {
    set_handle_logical(handle, false);
}
159
160pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162 LOGICAL_HANDLES
163 .read()
164 .map(|guard| guard.contains(&handle.buffer_id))
165 .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169 LOGICAL_HANDLE_HITS
170 .read()
171 .ok()
172 .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177 guard.insert(
178 handle.buffer_id,
179 TransposeInfo {
180 base_rows,
181 base_cols,
182 },
183 );
184 }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189 guard.remove(&handle.buffer_id);
190 }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194 TRANSPOSED_HANDLES
195 .read()
196 .ok()
197 .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
/// Reports whether transpose info is currently recorded for `handle`.
///
/// Equivalent to `handle_transpose_info(handle).is_some()`.
pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
    handle_transpose_info(handle).is_some()
}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
205pub enum GpuTensorStorage {
206 Real,
207 ComplexInterleaved,
208}
209
210impl Default for GpuTensorStorage {
211 fn default() -> Self {
212 Self::Real
213 }
214}
215
/// Serializable descriptor of a provider-owned tensor.
///
/// `buffer_id` is the key used by the per-handle metadata tables in this
/// module (precision, logical flag, transpose info, storage).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuTensorHandle {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Identifier of the device the tensor belongs to.
    pub device_id: u32,
    /// Identifier of the underlying buffer.
    pub buffer_id: u64,
}
222
/// Spectral range selector carried in `ProviderSpectralRequest::range`.
// NOTE(review): the exact meaning of each variant is defined by the provider
// implementation, which is not visible here.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSpectralRange {
    Onesided,
    Twosided,
    Centered,
}
229
/// How frames are laid out in the input of a `ProviderSpectralRequest`.
///
/// The descriptions below reflect what `validate_uniform_spectral_request`
/// checks against `input_len`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSpectralFrameMode {
    /// The last frame starts at `(frame_count - 1) * hop`; `hop` must be non-zero.
    Sliding {
        hop: usize,
    },
    /// Frame `i` belongs to column `i / frames_per_column` and starts at
    /// `column * input_rows + (i % frames_per_column) * hop`. All three
    /// fields must be non-zero.
    ColumnSliding {
        hop: usize,
        input_rows: usize,
        frames_per_column: usize,
    },
    /// Requires `input_rows * frame_count <= input_len`; `input_rows` must be non-zero.
    FoldedColumns {
        input_rows: usize,
    },
}
244
/// Arguments for `uniform_spectral_estimate`.
///
/// The request is validated before it reaches the provider; the constraints
/// noted on the fields are the ones enforced by that validation.
#[derive(Clone, Debug)]
pub struct ProviderSpectralRequest<'a> {
    /// Input tensor handle.
    pub input: &'a GpuTensorHandle,
    /// Input length that the frame layout must fit into.
    pub input_len: usize,
    // NOTE(review): presumably marks the input as complex-valued — confirm with the provider.
    pub input_complex: bool,
    /// Window samples; must be non-empty.
    pub window: &'a [f64],
    /// Must be non-zero.
    pub nfft: usize,
    /// Number of frames; must be non-zero.
    pub frame_count: usize,
    /// Frame layout within the input.
    pub frame_mode: ProviderSpectralFrameMode,
    pub range: ProviderSpectralRange,
    /// Must be finite and strictly positive.
    pub denominator: f64,
}
257
/// Result returned by a provider's `uniform_spectral_estimate`.
// NOTE(review): `s` / `ps` look like the spectrum and power-spectrum outputs and
// `rows` / `cols` their shared dimensions — confirm against the provider.
#[derive(Clone, Debug)]
pub struct ProviderSpectralResult {
    pub s: GpuTensorHandle,
    pub ps: GpuTensorHandle,
    pub rows: usize,
    pub cols: usize,
}
265
/// Envelope method selector for `ProviderEnvelopeRequest`.
///
/// `signal_envelope` rejects `AnalyticFir` with `filter_len == 0` and `Rms`
/// with `window_len == 0`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderEnvelopeMethod {
    Analytic,
    AnalyticFir { filter_len: usize },
    Rms { window_len: usize },
}
272
/// Arguments for `signal_envelope`.
///
/// Validation requires `channel_len * channel_count` to equal both the input
/// element count and the product of `output_shape`.
#[derive(Clone, Debug)]
pub struct ProviderEnvelopeRequest<'a> {
    /// Input tensor handle; its shape must match the channel layout.
    pub input: &'a GpuTensorHandle,
    /// Samples per channel; must be non-zero.
    pub channel_len: usize,
    /// Number of channels; must be non-zero.
    pub channel_count: usize,
    /// Shape of the outputs; must be non-empty.
    pub output_shape: &'a [usize],
    pub method: ProviderEnvelopeMethod,
}
281
/// Pair of tensors returned by a provider's `signal_envelope`.
#[derive(Clone, Debug)]
pub struct ProviderEnvelopeResult {
    pub upper: GpuTensorHandle,
    pub lower: GpuTensorHandle,
}
287
/// Arguments for `signal_hilbert`.
#[derive(Clone, Debug)]
pub struct ProviderHilbertRequest<'a> {
    /// Input tensor handle.
    pub input: &'a GpuTensorHandle,
    /// Optional length; `Some(0)` is rejected by `signal_hilbert`.
    pub length: Option<usize>,
    /// Dimension index; must be smaller than the rank of `input.shape`.
    pub dim: usize,
}
296
/// Request bundling an input tensor with a constellation table.
// NOTE(review): the layout of `constellation` (e.g. interleaved re/im) is not
// visible in this chunk — confirm with the consuming provider method.
#[derive(Clone, Debug)]
pub struct ProviderModulationRequest<'a> {
    pub input: &'a GpuTensorHandle,
    pub constellation: &'a [f64],
}
303
/// Request bundling a bit-valued input tensor with a constellation table.
// NOTE(review): field semantics are inferred from names only; the consuming
// provider method is not visible in this chunk.
#[derive(Clone, Debug)]
pub struct ProviderBitModulationRequest<'a> {
    pub input: &'a GpuTensorHandle,
    pub input_rows: usize,
    pub bits_per_symbol: usize,
    pub constellation: &'a [f64],
}
314
315pub async fn uniform_spectral_estimate(
316 request: ProviderSpectralRequest<'_>,
317) -> anyhow::Result<ProviderSpectralResult> {
318 validate_uniform_spectral_request(&request)?;
319
320 let provider =
321 provider().ok_or_else(|| anyhow!("uniform_spectral_estimate: GPU provider unavailable"))?;
322 provider.uniform_spectral_estimate(&request).await
323}
324
325fn validate_uniform_spectral_request(request: &ProviderSpectralRequest<'_>) -> anyhow::Result<()> {
326 let invalid_frame_mode = matches!(
327 request.frame_mode,
328 ProviderSpectralFrameMode::Sliding { hop: 0 }
329 | ProviderSpectralFrameMode::ColumnSliding { hop: 0, .. }
330 );
331 let invalid_input_coverage = match request.frame_mode {
332 ProviderSpectralFrameMode::Sliding { hop } => {
333 let required = request
334 .frame_count
335 .checked_sub(1)
336 .and_then(|frames| frames.checked_mul(hop))
337 .and_then(|offset| offset.checked_add(request.window.len()));
338 required.is_none_or(|required| required > request.input_len)
339 }
340 ProviderSpectralFrameMode::ColumnSliding {
341 hop,
342 input_rows,
343 frames_per_column,
344 } => {
345 if input_rows == 0 || frames_per_column == 0 {
346 true
347 } else {
348 let last_frame = request.frame_count - 1;
349 let source_col = last_frame / frames_per_column;
350 let segment = last_frame % frames_per_column;
351 source_col
352 .checked_mul(input_rows)
353 .and_then(|base| {
354 segment
355 .checked_mul(hop)
356 .and_then(|step| base.checked_add(step))
357 })
358 .and_then(|offset| offset.checked_add(request.window.len()))
359 .is_none_or(|required| required > request.input_len)
360 }
361 }
362 ProviderSpectralFrameMode::FoldedColumns { input_rows } => {
363 input_rows == 0
364 || input_rows
365 .checked_mul(request.frame_count)
366 .is_none_or(|required| required > request.input_len)
367 }
368 };
369 if request.window.is_empty()
370 || request.nfft == 0
371 || request.frame_count == 0
372 || invalid_frame_mode
373 || invalid_input_coverage
374 || !request.denominator.is_finite()
375 || request.denominator <= 0.0
376 {
377 return Err(anyhow!("uniform_spectral_estimate: invalid request"));
378 }
379
380 Ok(())
381}
382
383pub async fn signal_envelope(
384 request: ProviderEnvelopeRequest<'_>,
385) -> anyhow::Result<ProviderEnvelopeResult> {
386 let expected_len = request
387 .channel_len
388 .checked_mul(request.channel_count)
389 .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
390 let output_len = request
391 .output_shape
392 .iter()
393 .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
394 .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
395 let input_len = request
396 .input
397 .shape
398 .iter()
399 .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
400 .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
401 if request.channel_len == 0
402 || request.channel_count == 0
403 || request.output_shape.is_empty()
404 || output_len != expected_len
405 || input_len != expected_len
406 || !provider_envelope_input_shape_matches(
407 &request.input.shape,
408 request.channel_len,
409 request.channel_count,
410 )
411 {
412 return Err(anyhow!("signal_envelope: invalid request"));
413 }
414
415 match request.method {
416 ProviderEnvelopeMethod::AnalyticFir { filter_len }
417 | ProviderEnvelopeMethod::Rms {
418 window_len: filter_len,
419 } if filter_len == 0 => return Err(anyhow!("signal_envelope: invalid request")),
420 _ => {}
421 }
422
423 let provider =
424 provider().ok_or_else(|| anyhow!("signal_envelope: GPU provider unavailable"))?;
425 provider.signal_envelope(&request).await
426}
427
/// Checks that `shape` is an accepted layout for the given channel geometry.
///
/// With a single channel the shape may be `[len]`, `[len, 1]` or `[1, len]`.
/// With any other channel count it must be exactly
/// `[channel_len, channel_count]`.
fn provider_envelope_input_shape_matches(
    shape: &[usize],
    channel_len: usize,
    channel_count: usize,
) -> bool {
    let single_channel = channel_count == 1;
    match *shape {
        [len] => single_channel && len == channel_len,
        [rows, cols] if single_channel => {
            (rows, cols) == (channel_len, 1) || (rows, cols) == (1, channel_len)
        }
        [rows, cols] => rows == channel_len && cols == channel_count,
        _ => false,
    }
}
445
446pub async fn signal_hilbert(
447 request: ProviderHilbertRequest<'_>,
448) -> anyhow::Result<GpuTensorHandle> {
449 if request.length == Some(0) {
450 return Err(anyhow!("signal_hilbert: invalid request"));
451 }
452 if request.dim >= request.input.shape.len() {
453 return Err(anyhow!("signal_hilbert: invalid request"));
454 }
455
456 let provider = provider().ok_or_else(|| anyhow!("signal_hilbert: GPU provider unavailable"))?;
457 provider.signal_hilbert(&request).await
458}
459
/// Structured device description returned by `AccelProvider::device_info_struct`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    pub device_id: u32,
    /// Device name; the trait default fills this from `device_info()`.
    pub name: String,
    /// Vendor string; empty in the trait default.
    pub vendor: String,
    /// Device memory in bytes, when known.
    pub memory_bytes: Option<u64>,
    /// Backend name, when known.
    pub backend: Option<String>,
}
468
/// Pair of tensors (values and their indices) produced by a dimension-wise reduction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}
474
/// Pair of tensors (values and their indices) produced by a cumulative-minimum scan.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Cumulative-maximum results share the exact layout of `ProviderCumminResult`.
pub type ProviderCummaxResult = ProviderCumminResult;
486
/// Kind of context requested through `export_context`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelContextKind {
    Plotting,
}
492
/// Context exported by a provider through `export_context`.
///
/// The only variant is gated on the `wgpu` feature, so the enum has no
/// variants when that feature is disabled.
#[derive(Clone)]
pub enum AccelContextHandle {
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}
499
500impl AccelContextHandle {
501 #[cfg(feature = "wgpu")]
503 pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
504 match self {
505 AccelContextHandle::Wgpu(ctx) => Some(ctx),
506 }
507 }
508}
509
/// Shared wgpu objects and adapter metadata exported by a provider.
///
/// The heavyweight objects are held in `Arc`s, so cloning the handle shares them.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    pub adapter_info: wgpu::AdapterInfo,
    pub limits: wgpu::Limits,
    pub features: wgpu::Features,
}
522
/// Shared reference to a wgpu buffer plus the metadata describing its contents.
// NOTE(review): `len` is presumably an element count and `element_size` a size in
// bytes — confirm against the provider that fills this in.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    pub buffer: Arc<wgpu::Buffer>,
    pub len: usize,
    pub shape: Vec<usize>,
    pub element_size: usize,
    pub precision: ProviderPrecision,
}
533
534pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
535 if let Ok(mut guard) = HANDLE_STORAGES.write() {
536 guard.insert(handle.buffer_id, storage);
537 }
538}
539
540pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
541 HANDLE_STORAGES
542 .read()
543 .ok()
544 .and_then(|guard| guard.get(&handle.buffer_id).cloned())
545 .unwrap_or(GpuTensorStorage::Real)
546}
547
548pub fn clear_handle_storage(handle: &GpuTensorHandle) {
549 if let Ok(mut guard) = HANDLE_STORAGES.write() {
550 guard.remove(&handle.buffer_id);
551 }
552}
553
/// Operation selector for a `PagefunRequest`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    Mtimes,
}
558
/// Description of a page-wise operation over one or more input tensors.
// NOTE(review): the relationship between `page_dims` and `input_page_dims`
// (presumably one entry per input) is not visible in this chunk — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    pub op: PagefunOp,
    pub inputs: Vec<GpuTensorHandle>,
    pub output_shape: Vec<usize>,
    pub page_dims: Vec<usize>,
    pub input_page_dims: Vec<Vec<usize>>,
}
567
/// Direction selector for find-style searches (first matches or last matches).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}
573
/// Tensors returned by a provider-side find: linear indices, row/column
/// subscripts, and optionally the matched values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    pub linear: GpuTensorHandle,
    pub rows: GpuTensorHandle,
    pub cols: GpuTensorHandle,
    /// Present only when the provider also returns the matched values.
    pub values: Option<GpuTensorHandle>,
}
581
/// Lower and upper bandwidth of a matrix, as reported by a provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    pub lower: u32,
    pub upper: u32,
}
587
/// Which symmetry property is being queried: symmetric or skew-symmetric.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}
593
/// Which Hermitian property is being queried: Hermitian or skew-Hermitian.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}
599
/// Tensors produced by a provider-side LU factorization.
// NOTE(review): `combined` is presumably the packed L/U matrix — confirm with the provider.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    pub combined: GpuTensorHandle,
    pub lower: GpuTensorHandle,
    pub upper: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}
608
/// Result of a provider-side Cholesky factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    pub factor: GpuTensorHandle,
    // NOTE(review): presumably 0 on success and a failure position otherwise — confirm.
    pub info: u32,
}
615
/// Tensors produced by a provider-side QR factorization, with the permutation
/// in both matrix and vector form.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}
623
/// Result of a QR power-iteration step; same field layout as `ProviderQrResult`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}
631
/// Structure hints and flags for a linear solve.
///
/// `Default` leaves every flag `false` and `rcond` unset.
// NOTE(review): the flag names mirror MATLAB `linsolve` options (LT, UT, RECT,
// TRANSA, ...) — confirm the mapping with the caller.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    pub need_rcond: bool,
    pub rcond: Option<f64>,
}
644
/// Solution tensor of a linear solve plus the reciprocal condition estimate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    pub reciprocal_condition: f64,
}
650
/// Options for a pseudo-inverse computation; `Default` leaves the tolerance unset.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    pub tolerance: Option<f64>,
}
655
/// Centering (`mean`) and scaling (`scale`) pair carried by `ProviderPolyvalOptions`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}
661
/// Options for polynomial evaluation; `Default` leaves `mu` unset.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}
666
/// Options for matrix inversion. Currently carries no fields.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}
669
/// Host-side result of a polynomial fit.
// NOTE(review): `r_matrix` layout (row- vs column-major) and the order of the two
// `mu` entries (presumably mean, then scale) are not visible here — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    pub coefficients: Vec<f64>,
    pub r_matrix: Vec<f64>,
    pub normr: f64,
    pub df: f64,
    pub mu: [f64; 2],
}
678
/// Numerator and denominator tensors of a polynomial-quotient derivative.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}
685
/// Norm selector for condition-number computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}
694
/// Norm order selector.
///
/// `P` carries an arbitrary `f64` order, which is why the type is
/// `PartialEq` but not `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    Nuc,
    P(f64),
}
707
/// Tensors produced by a provider-side eigendecomposition.
// NOTE(review): `diagonal` is presumably the eigenvalues as a diagonal matrix and
// `right` / `left` the eigenvector matrices — confirm with the provider.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    pub eigenvalues: GpuTensorHandle,
    pub diagonal: GpuTensorHandle,
    pub right: GpuTensorHandle,
    /// Present only when left eigenvectors were produced.
    pub left: Option<GpuTensorHandle>,
}
715
/// Pivot representation selector used by `ProviderQrOptions::pivot`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}
721
/// Options for a QR factorization.
///
/// The `Default` impl selects a non-economy factorization with `Matrix` pivoting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    pub economy: bool,
    pub pivot: ProviderQrPivot,
}
727
728impl Default for ProviderQrOptions {
729 fn default() -> Self {
730 Self {
731 economy: false,
732 pivot: ProviderQrPivot::Matrix,
733 }
734 }
735}
736
/// Floating-point precision tag: reported by `AccelProvider::precision` and
/// tracked per handle by `set_handle_precision`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}
742
/// Policy a provider reports from `AccelProvider::spawn_handle_concurrency`.
///
/// The trait default is `Reject`.
// NOTE(review): the precise contract of each policy is defined by the callers of
// `spawn_handle_concurrency`, which are not visible in this chunk.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SpawnHandleConcurrency {
    ImmutableShare,
    CopyOnWrite,
    SynchronizedMutation,
    Reject,
}
758
759impl SpawnHandleConcurrency {
760 pub fn as_str(self) -> &'static str {
761 match self {
762 SpawnHandleConcurrency::ImmutableShare => "immutable_share",
763 SpawnHandleConcurrency::CopyOnWrite => "copy_on_write",
764 SpawnHandleConcurrency::SynchronizedMutation => "synchronized_mutation",
765 SpawnHandleConcurrency::Reject => "reject",
766 }
767 }
768}
769
/// Tri-state switch for two-pass reductions: automatic, forced on, or forced off.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionTwoPassMode {
    Auto,
    ForceOn,
    ForceOff,
}
776
777impl ReductionTwoPassMode {
778 pub fn as_str(self) -> &'static str {
779 match self {
780 ReductionTwoPassMode::Auto => "auto",
781 ReductionTwoPassMode::ForceOn => "force_on",
782 ReductionTwoPassMode::ForceOff => "force_off",
783 }
784 }
785}
786
/// Kind of reduction, distinguished by the scale factor applied to the sum
/// (see `ReductionFlavor::scale`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ReductionFlavor {
    /// Scale factor 1.
    Sum,
    /// Scale factor `1 / reduce_len` (1 when `reduce_len` is zero).
    Mean,
    /// Caller-supplied scale factor.
    CustomScale(f64),
}
793
794impl ReductionFlavor {
795 pub fn is_mean(self) -> bool {
796 matches!(self, ReductionFlavor::Mean)
797 }
798
799 pub fn scale(self, reduce_len: usize) -> f64 {
800 match self {
801 ReductionFlavor::Sum => 1.0,
802 ReductionFlavor::Mean => {
803 if reduce_len == 0 {
804 1.0
805 } else {
806 1.0 / reduce_len as f64
807 }
808 }
809 ReductionFlavor::CustomScale(scale) => scale,
810 }
811 }
812}
813
/// Normalization selector for `CorrcoefOptions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}
820
/// Row-handling selector for `CorrcoefOptions`.
// NOTE(review): presumably mirrors MATLAB's `'Rows'` option (all / complete /
// pairwise) — confirm with the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}
828
/// Options for a correlation-coefficient computation.
///
/// The `Default` impl selects `Unbiased` normalization over `All` rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}
835
836impl Default for CorrcoefOptions {
837 fn default() -> Self {
838 Self {
839 normalization: CorrcoefNormalization::Unbiased,
840 rows: CorrcoefRows::All,
841 }
842 }
843}
844
/// Normalization selector for `CovarianceOptions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}
851
/// Row-handling selector for `CovarianceOptions`.
// NOTE(review): presumably mirrors MATLAB `cov` NaN handling (all / omitrows /
// partialrows) — confirm with the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}
859
/// Options for a covariance computation.
///
/// The `Default` impl selects `Unbiased` normalization over `All` rows with no
/// weight vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    pub has_weight_vector: bool,
}
867
868impl Default for CovarianceOptions {
869 fn default() -> Self {
870 Self {
871 normalization: CovNormalization::Unbiased,
872 rows: CovRows::All,
873 has_weight_vector: false,
874 }
875 }
876}
877
/// Normalization selector for standard-deviation computations (sample or population).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    Sample,
    Population,
}
884
/// Whether NaN values are included in or omitted from a computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}
891
/// Direction in which a scan (cumulative operation) walks its dimension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}
898
/// Sort direction; also used per column by `SortRowsColumnSpec`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}
905
/// Comparison mode selector for sorting.
// NOTE(review): presumably mirrors MATLAB's `'ComparisonMethod'` (auto / real /
// abs) — confirm with the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    Auto,
    Real,
    Abs,
}
913
/// Host-side result of a sort: the sorted values and the matching indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}
920
/// One sort key for a row sort: a column index and the direction to sort it in.
// NOTE(review): whether `index` is 0- or 1-based is not visible here — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    pub index: usize,
    pub order: SortOrder,
}
926
/// Output ordering selector for `UniqueOptions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}
933
/// Which occurrence of a repeated value `UniqueOptions` refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}
940
/// Options for a unique-values computation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    /// Operate on whole rows rather than individual elements.
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}
948
/// Host-side result of a unique-values computation.
// NOTE(review): `ia` / `ic` presumably follow MATLAB `unique` index outputs — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ic: HostTensorOwned,
}
956
/// Output ordering selector for `UnionOptions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}
963
/// Options for a set-union computation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    /// Operate on whole rows rather than individual elements.
    pub rows: bool,
    pub order: UnionOrder,
}
970
/// Host-side result of a set-union computation.
// NOTE(review): `ia` / `ib` presumably index into the first / second input — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ib: HostTensorOwned,
}
978
/// Filter kernel description carried by `FspecialRequest`, one variant per
/// kernel family with that family's parameters.
// NOTE(review): parameter semantics presumably follow MATLAB `fspecial` — confirm
// with the provider that builds the kernels.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    Average {
        rows: u32,
        cols: u32,
    },
    Disk {
        radius: f64,
        size: u32,
    },
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Laplacian {
        alpha: f64,
    },
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        oversample: u32,
    },
    Prewitt,
    Sobel,
    Unsharp {
        alpha: f64,
    },
}
1015
/// Request wrapper around a single `FspecialFilter` description.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}
1021
/// Boundary padding selector for `ImfilterOptions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    /// Pads with `ImfilterOptions::constant_value`.
    Constant,
    Replicate,
    Symmetric,
    Circular,
}
1030
/// Output-size selector for `ImfilterOptions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}
1038
/// Whether the filter is applied as a correlation or as a convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}
1045
/// Options for image filtering.
///
/// The `Default` impl selects constant padding with value `0.0`, `Same`
/// output shape, and correlation mode.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}
1054
1055impl Default for ImfilterOptions {
1056 fn default() -> Self {
1057 Self {
1058 padding: ImfilterPadding::Constant,
1059 constant_value: 0.0,
1060 shape: ImfilterShape::Same,
1061 mode: ImfilterMode::Correlation,
1062 }
1063 }
1064}
1065
/// Output ordering selector for `SetdiffOptions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}
1072
/// Options for a set-difference computation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// Operate on whole rows rather than individual elements.
    pub rows: bool,
    pub order: SetdiffOrder,
}
1079
/// Host-side result of a set-difference computation.
// NOTE(review): `ia` presumably indexes into the first input — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}
1086
/// Options for a set-membership test.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// Operate on whole rows rather than individual elements.
    pub rows: bool,
}
1092
/// Owned host-side logical array: one byte per element plus its shape.
// NOTE(review): presumably 0 = false and non-zero = true — confirm with producers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}
1099
/// Host-side result of a set-membership test: the membership mask and the
/// matching locations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}
1106
/// Output-size selector for `ProviderConv1dOptions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}
1113
/// Vector orientation selector for `ProviderConv1dOptions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}
1119
/// Options for a one-dimensional convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}
1125
/// Options for an IIR filter run.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Dimension to filter along.
    // NOTE(review): 0- vs 1-based indexing is not visible here — confirm.
    pub dim: usize,
    /// Optional initial filter state.
    pub zi: Option<GpuTensorHandle>,
}
1133
/// Result of an IIR filter run.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    pub output: GpuTensorHandle,
    /// Final filter state, when the provider returns one.
    pub final_state: Option<GpuTensorHandle>,
}
1141
/// Pair of moment tensors.
// NOTE(review): `ex2` is presumably the mean of squares, E[x^2] — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    pub ex2: GpuTensorHandle,
}
1147
/// Counter pair for one dispatch category in `ProviderTelemetry`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of dispatches.
    pub count: u64,
    /// Accumulated wall-clock time in nanoseconds.
    pub total_wall_time_ns: u64,
}
1155
/// Number of fallbacks recorded for one reason string; see
/// `ProviderTelemetry::solve_fallbacks`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderFallbackStat {
    pub reason: String,
    pub count: u64,
}
1161
/// Telemetry snapshot exposed by a provider: per-category dispatch stats,
/// transfer byte counts, cache hit/miss counters, and recorded kernel launches.
///
/// `Default` yields an all-zero / empty snapshot.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    pub linsolve: ProviderDispatchStats,
    pub mldivide: ProviderDispatchStats,
    pub mrdivide: ProviderDispatchStats,
    pub upload_bytes: u64,
    pub download_bytes: u64,
    pub solve_fallbacks: Vec<ProviderFallbackStat>,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Per-layout breakdown of the bind-group cache counters, when available.
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}
1182
/// Bind-group cache hit/miss counters for one layout, identified by `tag`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}
1189
/// Named integer attribute attached to a `KernelLaunchTelemetry` entry.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}
1195
/// Record of one kernel launch: kernel name, optional precision label, and
/// shape / tuning attributes as key-value lists.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    pub kernel: String,
    pub precision: Option<String>,
    pub shape: Vec<KernelAttrTelemetry>,
    pub tuning: Vec<KernelAttrTelemetry>,
}
1203
/// Boxed, pinned future yielding `anyhow::Result<T>`. Note that the trait
/// object carries no `Send` bound.
pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
/// Future returned by `AccelProvider::download`, resolving to an owned host tensor.
pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
1206
/// Builds a future that immediately resolves to an error carrying `message`.
fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
    Box::pin(async move { Err(anyhow::anyhow!(message)) })
}
1210
1211pub trait AccelProvider: Send + Sync {
    /// Creates a provider tensor from the host view and returns its handle.
    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
    /// Returns a future resolving to an owned host copy of the tensor behind `h`.
    fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
    /// Releases the tensor behind `h`.
    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
    /// Human-readable device description; the default `device_info_struct`
    /// uses it as the device name.
    fn device_info(&self) -> String;
    /// Identifier of the device this provider drives. Defaults to `0`.
    fn device_id(&self) -> u32 {
        0
    }
1220
    /// Concurrency policy this provider reports for spawned handles.
    /// Defaults to `SpawnHandleConcurrency::Reject`.
    fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
        SpawnHandleConcurrency::Reject
    }
1229
    /// Exports a context of the requested kind. The default exports nothing.
    fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
        None
    }
1235
1236 #[cfg(feature = "wgpu")]
1238 fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
1239 let _ = _handle;
1240 None
1241 }
1242
1243 fn gather_linear(
1246 &self,
1247 _source: &GpuTensorHandle,
1248 _indices: &[u32],
1249 _output_shape: &[usize],
1250 ) -> anyhow::Result<GpuTensorHandle> {
1251 Err(anyhow::anyhow!("gather_linear not supported by provider"))
1252 }
1253
1254 fn scatter_linear(
1258 &self,
1259 _target: &GpuTensorHandle,
1260 _indices: &[u32],
1261 _values: &GpuTensorHandle,
1262 ) -> anyhow::Result<()> {
1263 Err(anyhow::anyhow!("scatter_linear not supported by provider"))
1264 }
1265
1266 fn device_info_struct(&self) -> ApiDeviceInfo {
1268 ApiDeviceInfo {
1269 device_id: 0,
1270 name: self.device_info(),
1271 vendor: String::new(),
1272 memory_bytes: None,
1273 backend: None,
1274 }
1275 }
1276
    /// Precision this provider computes in. Defaults to `ProviderPrecision::F64`.
    fn precision(&self) -> ProviderPrecision {
        ProviderPrecision::F64
    }
1280
1281 fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
1283 Err(anyhow::anyhow!("read_scalar not supported by provider"))
1284 }
1285
1286 fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1288 Err(anyhow::anyhow!("zeros not supported by provider"))
1289 }
1290
1291 fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1293 Err(anyhow::anyhow!("ones not supported by provider"))
1294 }
1295
1296 fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1298 self.zeros(&prototype.shape)
1299 }
1300
1301 fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
1303 if value == 0.0 {
1304 return self.zeros(shape);
1305 }
1306 if let Ok(base) = self.zeros(shape) {
1307 match self.scalar_add(&base, value) {
1308 Ok(out) => {
1309 let _ = self.free(&base);
1310 return Ok(out);
1311 }
1312 Err(_) => {
1313 let _ = self.free(&base);
1314 }
1315 }
1316 }
1317 let len: usize = shape.iter().copied().product();
1318 let data = vec![value; len];
1319 let view = HostTensorView { data: &data, shape };
1320 self.upload(&view)
1321 }
1322
1323 fn fill_like(
1325 &self,
1326 prototype: &GpuTensorHandle,
1327 value: f64,
1328 ) -> anyhow::Result<GpuTensorHandle> {
1329 if value == 0.0 {
1330 return self.zeros_like(prototype);
1331 }
1332 if let Ok(base) = self.zeros_like(prototype) {
1333 match self.scalar_add(&base, value) {
1334 Ok(out) => {
1335 let _ = self.free(&base);
1336 return Ok(out);
1337 }
1338 Err(_) => {
1339 let _ = self.free(&base);
1340 }
1341 }
1342 }
1343 self.fill(&prototype.shape, value)
1344 }
1345
1346 fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1348 self.ones(&prototype.shape)
1349 }
1350
1351 fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1353 Err(anyhow::anyhow!("eye not supported by provider"))
1354 }
1355
1356 fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1358 self.eye(&prototype.shape)
1359 }
1360
1361 fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1363 Err(anyhow::anyhow!("meshgrid not supported by provider"))
1364 }
1365
1366 fn diag_from_vector(
1368 &self,
1369 _vector: &GpuTensorHandle,
1370 _offset: isize,
1371 ) -> anyhow::Result<GpuTensorHandle> {
1372 Err(anyhow::anyhow!(
1373 "diag_from_vector not supported by provider"
1374 ))
1375 }
1376
1377 fn diag_extract(
1379 &self,
1380 _matrix: &GpuTensorHandle,
1381 _offset: isize,
1382 ) -> anyhow::Result<GpuTensorHandle> {
1383 Err(anyhow::anyhow!("diag_extract not supported by provider"))
1384 }
1385
1386 fn tril<'a>(
1388 &'a self,
1389 _matrix: &'a GpuTensorHandle,
1390 _offset: isize,
1391 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1392 Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1393 }
1394
1395 fn triu<'a>(
1397 &'a self,
1398 _matrix: &'a GpuTensorHandle,
1399 _offset: isize,
1400 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1401 Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1402 }
1403
1404 fn polyval(
1406 &self,
1407 _coefficients: &GpuTensorHandle,
1408 _points: &GpuTensorHandle,
1409 _options: &ProviderPolyvalOptions,
1410 ) -> anyhow::Result<GpuTensorHandle> {
1411 Err(anyhow::anyhow!("polyval not supported by provider"))
1412 }
1413
1414 fn polyfit<'a>(
1416 &'a self,
1417 _x: &'a GpuTensorHandle,
1418 _y: &'a GpuTensorHandle,
1419 _degree: usize,
1420 _weights: Option<&'a GpuTensorHandle>,
1421 ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1422 Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1423 }
1424
1425 fn polyder_single<'a>(
1427 &'a self,
1428 _polynomial: &'a GpuTensorHandle,
1429 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1430 Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1431 }
1432
1433 fn polyder_product<'a>(
1435 &'a self,
1436 _p: &'a GpuTensorHandle,
1437 _q: &'a GpuTensorHandle,
1438 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1439 Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1440 }
1441
1442 fn polyder_quotient<'a>(
1444 &'a self,
1445 _u: &'a GpuTensorHandle,
1446 _v: &'a GpuTensorHandle,
1447 ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1448 Box::pin(async move {
1449 Err(anyhow::anyhow!(
1450 "polyder_quotient not supported by provider"
1451 ))
1452 })
1453 }
1454
1455 fn polyint(
1457 &self,
1458 _polynomial: &GpuTensorHandle,
1459 _constant: f64,
1460 ) -> anyhow::Result<GpuTensorHandle> {
1461 Err(anyhow::anyhow!("polyint not supported by provider"))
1462 }
1463
1464 fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1466 Err(anyhow::anyhow!("random_uniform not supported by provider"))
1467 }
1468
1469 fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1471 self.random_uniform(&prototype.shape)
1472 }
1473
1474 fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1476 Err(anyhow::anyhow!("random_normal not supported by provider"))
1477 }
1478
1479 fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1481 self.random_normal(&prototype.shape)
1482 }
1483
1484 fn random_exponential(&self, _mu: f64, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1486 Err(anyhow::anyhow!(
1487 "random_exponential not supported by provider"
1488 ))
1489 }
1490
1491 fn random_normrnd(
1493 &self,
1494 _mu: f64,
1495 _sigma: f64,
1496 _shape: &[usize],
1497 ) -> anyhow::Result<GpuTensorHandle> {
1498 Err(anyhow::anyhow!("random_normrnd not supported by provider"))
1499 }
1500
1501 fn random_unifrnd(
1503 &self,
1504 _a: f64,
1505 _b: f64,
1506 _shape: &[usize],
1507 ) -> anyhow::Result<GpuTensorHandle> {
1508 Err(anyhow::anyhow!("random_unifrnd not supported by provider"))
1509 }
1510
1511 fn stochastic_evolution(
1512 &self,
1513 _state: &GpuTensorHandle,
1514 _drift: f64,
1515 _scale: f64,
1516 _steps: u32,
1517 ) -> anyhow::Result<GpuTensorHandle> {
1518 Err(anyhow::anyhow!(
1519 "stochastic_evolution not supported by provider"
1520 ))
1521 }
1522
1523 fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1525 Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1526 }
1527
1528 fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1530 Err(anyhow::anyhow!("fspecial not supported by provider"))
1531 }
1532
1533 fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
1536 Err(anyhow::anyhow!("peaks not supported by provider"))
1537 }
1538
1539 fn peaks_xy(
1542 &self,
1543 _x: &GpuTensorHandle,
1544 _y: &GpuTensorHandle,
1545 ) -> anyhow::Result<GpuTensorHandle> {
1546 Err(anyhow::anyhow!("peaks_xy not supported by provider"))
1547 }
1548
1549 fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1550 Err(anyhow::anyhow!("hann_window not supported by provider"))
1551 }
1552
1553 fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1554 Err(anyhow::anyhow!("hamming_window not supported by provider"))
1555 }
1556
1557 fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1558 Err(anyhow::anyhow!("blackman_window not supported by provider"))
1559 }
1560
    /// Provider hook for `imfilter`: takes an image handle, a kernel handle and
    /// [`ImfilterOptions`], and returns a future yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("imfilter not supported by provider")`.
    fn imfilter<'a>(
        &'a self,
        _image: &'a GpuTensorHandle,
        _kernel: &'a GpuTensorHandle,
        _options: &'a ImfilterOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("imfilter not supported by provider")
    }
1570
1571 fn random_integer_range(
1573 &self,
1574 _lower: i64,
1575 _upper: i64,
1576 _shape: &[usize],
1577 ) -> anyhow::Result<GpuTensorHandle> {
1578 Err(anyhow::anyhow!(
1579 "random_integer_range not supported by provider"
1580 ))
1581 }
1582
1583 fn random_integer_like(
1585 &self,
1586 prototype: &GpuTensorHandle,
1587 lower: i64,
1588 upper: i64,
1589 ) -> anyhow::Result<GpuTensorHandle> {
1590 self.random_integer_range(lower, upper, &prototype.shape)
1591 }
1592
1593 fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1595 Err(anyhow!("random_permutation not supported by provider"))
1596 }
1597
    /// Forwards `n` and `k` to [`AccelProvider::random_permutation`].
    ///
    /// The default implementation does not consult `_prototype` at all; any error from
    /// `random_permutation` is returned unchanged.
    fn random_permutation_like(
        &self,
        _prototype: &GpuTensorHandle,
        n: usize,
        k: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        self.random_permutation(n, k)
    }
1607
    /// Provider hook for `covariance`: takes a matrix handle, an optional second handle,
    /// optional weights and [`CovarianceOptions`], and returns a future yielding a tensor
    /// handle.
    ///
    /// Default: returns `unsupported_future("covariance not supported by provider")`.
    fn covariance<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _second: Option<&'a GpuTensorHandle>,
        _weights: Option<&'a GpuTensorHandle>,
        _options: &'a CovarianceOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("covariance not supported by provider")
    }
1618
    /// Provider hook for `corrcoef`: takes a matrix handle and [`CorrcoefOptions`] and
    /// returns a future yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("corrcoef not supported by provider")`.
    fn corrcoef<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: &'a CorrcoefOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("corrcoef not supported by provider")
    }
1627
1628 fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1630 Err(anyhow::anyhow!("linspace not supported by provider"))
1631 }
    /// Provider hook for `elem_add`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_add not supported by provider")`.
    fn elem_add<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_add not supported by provider")
    }
    /// Provider hook for `elem_mul`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_mul not supported by provider")`.
    fn elem_mul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_mul not supported by provider")
    }
    /// Provider hook for `elem_max`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_max not supported by provider")`.
    fn elem_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_max not supported by provider")
    }
    /// Provider hook for `elem_min`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_min not supported by provider")`.
    fn elem_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_min not supported by provider")
    }
    /// Provider hook for `elem_sub`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_sub not supported by provider")`.
    fn elem_sub<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_sub not supported by provider")
    }
    /// Provider hook for `elem_div`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_div not supported by provider")`.
    fn elem_div<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_div not supported by provider")
    }
    /// Provider hook for `elem_pow`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_pow not supported by provider")`.
    fn elem_pow<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_pow not supported by provider")
    }
1681
    /// Provider hook for `complex_from_real`: takes one handle (`_real`) and returns a
    /// future yielding a tensor handle.
    ///
    /// Default: returns
    /// `unsupported_future("complex_from_real not supported by provider")`.
    fn complex_from_real<'a>(
        &'a self,
        _real: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("complex_from_real not supported by provider")
    }
1690
    /// Provider hook for `complex_from_real_imag`: takes two handles (`_real`, `_imag`)
    /// and returns a future yielding a tensor handle.
    ///
    /// Default: returns
    /// `unsupported_future("complex_from_real_imag not supported by provider")`.
    fn complex_from_real_imag<'a>(
        &'a self,
        _real: &'a GpuTensorHandle,
        _imag: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("complex_from_real_imag not supported by provider")
    }
1702
    /// Provider hook for `modulate_constellation`: takes a [`ProviderModulationRequest`]
    /// by value and returns a future yielding a tensor handle.
    ///
    /// Default: returns
    /// `unsupported_future("modulate_constellation not supported by provider")`.
    fn modulate_constellation<'a>(
        &'a self,
        _request: ProviderModulationRequest<'a>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("modulate_constellation not supported by provider")
    }
1711
    /// Provider hook for `modulate_bits_constellation`: takes a
    /// [`ProviderBitModulationRequest`] by value and returns a future yielding a tensor
    /// handle.
    ///
    /// Default: returns
    /// `unsupported_future("modulate_bits_constellation not supported by provider")`.
    fn modulate_bits_constellation<'a>(
        &'a self,
        _request: ProviderBitModulationRequest<'a>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("modulate_bits_constellation not supported by provider")
    }
1720
    /// Provider hook for `elem_hypot`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_hypot not supported by provider")`.
    fn elem_hypot<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_hypot not supported by provider")
    }
    /// Provider hook for `elem_ge`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_ge not supported by provider")`.
    fn elem_ge<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ge not supported by provider")
    }
    /// Provider hook for `elem_le`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_le not supported by provider")`.
    fn elem_le<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_le not supported by provider")
    }
    /// Provider hook for `elem_lt`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_lt not supported by provider")`.
    fn elem_lt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_lt not supported by provider")
    }
    /// Provider hook for `elem_gt`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_gt not supported by provider")`.
    fn elem_gt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_gt not supported by provider")
    }
    /// Provider hook for `elem_eq`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_eq not supported by provider")`.
    fn elem_eq<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_eq not supported by provider")
    }
    /// Provider hook for `elem_ne`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_ne not supported by provider")`.
    fn elem_ne<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ne not supported by provider")
    }
1770 fn logical_and(
1771 &self,
1772 _a: &GpuTensorHandle,
1773 _b: &GpuTensorHandle,
1774 ) -> anyhow::Result<GpuTensorHandle> {
1775 Err(anyhow::anyhow!("logical_and not supported by provider"))
1776 }
1777 fn logical_or(
1778 &self,
1779 _a: &GpuTensorHandle,
1780 _b: &GpuTensorHandle,
1781 ) -> anyhow::Result<GpuTensorHandle> {
1782 Err(anyhow::anyhow!("logical_or not supported by provider"))
1783 }
1784 fn logical_xor(
1785 &self,
1786 _a: &GpuTensorHandle,
1787 _b: &GpuTensorHandle,
1788 ) -> anyhow::Result<GpuTensorHandle> {
1789 Err(anyhow::anyhow!("logical_xor not supported by provider"))
1790 }
1791 fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1792 Err(anyhow::anyhow!("logical_not not supported by provider"))
1793 }
    /// Reports whether `a` is a logical tensor.
    ///
    /// The default implementation never fails: it wraps the result of the crate-level
    /// `handle_is_logical(a)` helper in `Ok`.
    fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
        Ok(handle_is_logical(a))
    }
1797 fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1798 Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1799 }
1800 fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1801 Err(anyhow::anyhow!(
1802 "logical_isfinite not supported by provider"
1803 ))
1804 }
1805 fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1806 Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1807 }
1808 fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1809 Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1810 }
    /// Provider hook for `elem_atan2`: takes two tensor handles (`_y` first, then `_x`)
    /// and returns a future yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("elem_atan2 not supported by provider")`.
    fn elem_atan2<'a>(
        &'a self,
        _y: &'a GpuTensorHandle,
        _x: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_atan2 not supported by provider")
    }
    /// Provider hook for `unary_sin`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_sin not supported by provider")`.
    fn unary_sin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sin not supported by provider")
    }
    /// Provider hook for `unary_sinc`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_sinc not supported by provider")`.
    fn unary_sinc<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sinc not supported by provider")
    }
    /// Provider hook for `unary_gamma`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_gamma not supported by provider")`.
    fn unary_gamma<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_gamma not supported by provider")
    }
    /// Provider hook for `unary_factorial`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_factorial not supported by provider")`.
    fn unary_factorial<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_factorial not supported by provider")
    }
    /// Provider hook for `unary_asinh`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_asinh not supported by provider")`.
    fn unary_asinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asinh not supported by provider")
    }
    /// Provider hook for `unary_sinh`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_sinh not supported by provider")`.
    fn unary_sinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sinh not supported by provider")
    }
    /// Provider hook for `unary_cosh`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_cosh not supported by provider")`.
    fn unary_cosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cosh not supported by provider")
    }
    /// Provider hook for `unary_asin`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_asin not supported by provider")`.
    fn unary_asin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asin not supported by provider")
    }
    /// Provider hook for `unary_acos`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_acos not supported by provider")`.
    fn unary_acos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acos not supported by provider")
    }
    /// Provider hook for `unary_acosh`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_acosh not supported by provider")`.
    fn unary_acosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acosh not supported by provider")
    }
    /// Provider hook for `unary_tan`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_tan not supported by provider")`.
    fn unary_tan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tan not supported by provider")
    }
    /// Provider hook for `unary_tanh`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_tanh not supported by provider")`.
    fn unary_tanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tanh not supported by provider")
    }
    /// Provider hook for `unary_atan`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_atan not supported by provider")`.
    fn unary_atan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atan not supported by provider")
    }
    /// Provider hook for `unary_atanh`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_atanh not supported by provider")`.
    fn unary_atanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atanh not supported by provider")
    }
    /// Provider hook for `unary_ceil`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_ceil not supported by provider")`.
    fn unary_ceil<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_ceil not supported by provider")
    }
    /// Provider hook for `unary_floor`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_floor not supported by provider")`.
    fn unary_floor<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_floor not supported by provider")
    }
    /// Provider hook for `unary_round`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_round not supported by provider")`.
    fn unary_round<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_round not supported by provider")
    }
    /// Provider hook for `unary_fix`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_fix not supported by provider")`.
    fn unary_fix<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_fix not supported by provider")
    }
    /// Provider hook for `unary_cos`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_cos not supported by provider")`.
    fn unary_cos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cos not supported by provider")
    }
    /// Provider hook for `unary_angle`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_angle not supported by provider")`.
    fn unary_angle<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_angle not supported by provider")
    }
    /// Provider hook for `unary_imag`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_imag not supported by provider")`.
    fn unary_imag<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_imag not supported by provider")
    }
    /// Provider hook for `unary_real`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_real not supported by provider")`.
    fn unary_real<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_real not supported by provider")
    }
    /// Provider hook for `unary_conj`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_conj not supported by provider")`.
    fn unary_conj<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_conj not supported by provider")
    }
    /// Provider hook for `unary_abs`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_abs not supported by provider")`.
    fn unary_abs<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_abs not supported by provider")
    }
    /// Provider hook for `unary_sign`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_sign not supported by provider")`.
    fn unary_sign<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sign not supported by provider")
    }
    /// Provider hook for `unary_heaviside`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_heaviside not supported by provider")`.
    fn unary_heaviside<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_heaviside not supported by provider")
    }
    /// Provider hook for `unary_exp`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_exp not supported by provider")`.
    fn unary_exp<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_exp not supported by provider")
    }
    /// Provider hook for `unary_expm1`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_expm1 not supported by provider")`.
    fn unary_expm1<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_expm1 not supported by provider")
    }
    /// Provider hook for `unary_log`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_log not supported by provider")`.
    fn unary_log<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log not supported by provider")
    }
    /// Provider hook for `unary_log2`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_log2 not supported by provider")`.
    fn unary_log2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log2 not supported by provider")
    }
    /// Provider hook for `unary_log10`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_log10 not supported by provider")`.
    fn unary_log10<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log10 not supported by provider")
    }
    /// Provider hook for `unary_log1p`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_log1p not supported by provider")`.
    fn unary_log1p<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log1p not supported by provider")
    }
    /// Provider hook for `unary_sqrt`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_sqrt not supported by provider")`.
    fn unary_sqrt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sqrt not supported by provider")
    }
    /// Provider hook for `unary_double`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_double not supported by provider")`.
    fn unary_double<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_double not supported by provider")
    }
    /// Provider hook for `unary_single`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_single not supported by provider")`.
    fn unary_single<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_single not supported by provider")
    }
    /// Provider hook for `unary_pow2`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_pow2 not supported by provider")`.
    fn unary_pow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_pow2 not supported by provider")
    }
    /// Provider hook for `unary_nextpow2`: takes one tensor handle and returns a future
    /// yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("unary_nextpow2 not supported by provider")`.
    fn unary_nextpow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_nextpow2 not supported by provider")
    }
2041 fn pow2_scale(
2042 &self,
2043 _mantissa: &GpuTensorHandle,
2044 _exponent: &GpuTensorHandle,
2045 ) -> anyhow::Result<GpuTensorHandle> {
2046 Err(anyhow::anyhow!("pow2_scale not supported by provider"))
2047 }
2048 fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2050 Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
2051 }
2052 fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2053 Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
2054 }
2055 fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2057 Err(anyhow::anyhow!("scalar_add not supported by provider"))
2058 }
2059 fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2060 Err(anyhow::anyhow!("scalar_sub not supported by provider"))
2061 }
2062 fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2063 Err(anyhow::anyhow!("scalar_mul not supported by provider"))
2064 }
2065 fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2066 Err(anyhow::anyhow!("scalar_max not supported by provider"))
2067 }
2068 fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2069 Err(anyhow::anyhow!("scalar_min not supported by provider"))
2070 }
2071 fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2072 Err(anyhow::anyhow!("scalar_div not supported by provider"))
2073 }
    /// Provider hook for `sort_dim`: takes a tensor handle, a dimension, a [`SortOrder`]
    /// and a [`SortComparison`], and returns a future yielding a [`SortResult`].
    ///
    /// Default: returns `unsupported_future("sort_dim not supported by provider")`.
    fn sort_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _order: SortOrder,
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_dim not supported by provider")
    }
    /// Provider hook for `sort_rows`: takes a tensor handle, a slice of
    /// [`SortRowsColumnSpec`] and a [`SortComparison`], and returns a future yielding a
    /// [`SortResult`].
    ///
    /// Default: returns `unsupported_future("sort_rows not supported by provider")`.
    fn sort_rows<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _columns: &'a [SortRowsColumnSpec],
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_rows not supported by provider")
    }
    /// Provider hook for `matmul`: takes two tensor handles and returns a future
    /// yielding a tensor handle.
    ///
    /// The default `matmul_epilogue` forwards to this method when its epilogue is a
    /// no-op. Default: returns `unsupported_future("matmul not supported by provider")`.
    fn matmul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul not supported by provider")
    }
2098
2099 fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2100 Err(anyhow::anyhow!("syrk not supported by provider"))
2101 }
2102 fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
2103 Err(anyhow::anyhow!("pagefun not supported by provider"))
2104 }
2105
2106 fn matmul_epilogue<'a>(
2111 &'a self,
2112 a: &'a GpuTensorHandle,
2113 b: &'a GpuTensorHandle,
2114 epilogue: &'a MatmulEpilogue,
2115 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2116 Box::pin(async move {
2117 if epilogue.is_noop() {
2118 return self.matmul(a, b).await;
2119 }
2120 Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
2121 })
2122 }
    /// Provider hook for `image_normalize`: takes an input handle and an
    /// [`ImageNormalizeDescriptor`] and returns a future yielding a tensor handle.
    ///
    /// Default: returns
    /// `unsupported_future("image_normalize fusion not supported by provider")`.
    fn image_normalize<'a>(
        &'a self,
        _input: &'a GpuTensorHandle,
        _desc: &'a ImageNormalizeDescriptor,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("image_normalize fusion not supported by provider")
    }
    /// Provider hook for `matmul_power_step`: takes left and right handles and a
    /// [`PowerStepEpilogue`], and returns a future yielding a tensor handle.
    ///
    /// Default: returns
    /// `unsupported_future("matmul_power_step normalization not supported by provider")`.
    fn matmul_power_step<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _epilogue: &'a PowerStepEpilogue,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul_power_step normalization not supported by provider")
    }
    /// Provider hook for `linsolve`: takes left and right handles and
    /// [`ProviderLinsolveOptions`], and returns a future yielding a
    /// [`ProviderLinsolveResult`].
    ///
    /// Default: returns `unsupported_future("linsolve not supported by provider")`.
    fn linsolve<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _options: &'a ProviderLinsolveOptions,
    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
        unsupported_future("linsolve not supported by provider")
    }
    /// Provider hook for `inv`: takes a matrix handle and [`ProviderInvOptions`] (by
    /// value) and returns a future yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("inv not supported by provider")`.
    fn inv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderInvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("inv not supported by provider")
    }
    /// Provider hook for `pinv`: takes a matrix handle and [`ProviderPinvOptions`] (by
    /// value) and returns a future yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("pinv not supported by provider")`.
    fn pinv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderPinvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("pinv not supported by provider")
    }
2160 fn cond<'a>(
2161 &'a self,
2162 _matrix: &'a GpuTensorHandle,
2163 _norm: ProviderCondNorm,
2164 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2165 Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
2166 }
2167 fn norm<'a>(
2168 &'a self,
2169 _tensor: &'a GpuTensorHandle,
2170 _order: ProviderNormOrder,
2171 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2172 Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
2173 }
2174 fn rank<'a>(
2175 &'a self,
2176 _matrix: &'a GpuTensorHandle,
2177 _tolerance: Option<f64>,
2178 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2179 Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
2180 }
2181 fn rcond<'a>(
2182 &'a self,
2183 _matrix: &'a GpuTensorHandle,
2184 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2185 Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
2186 }
2187 fn mldivide<'a>(
2188 &'a self,
2189 _lhs: &'a GpuTensorHandle,
2190 _rhs: &'a GpuTensorHandle,
2191 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2192 Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
2193 }
2194 fn mrdivide<'a>(
2195 &'a self,
2196 _lhs: &'a GpuTensorHandle,
2197 _rhs: &'a GpuTensorHandle,
2198 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2199 Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
2200 }
2201 fn eig<'a>(
2202 &'a self,
2203 _a: &'a GpuTensorHandle,
2204 _compute_left: bool,
2205 ) -> AccelProviderFuture<'a, ProviderEigResult> {
2206 Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
2207 }
2208 fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
2209 Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
2210 }
2211
2212 fn chol<'a>(
2213 &'a self,
2214 _a: &'a GpuTensorHandle,
2215 _lower: bool,
2216 ) -> AccelProviderFuture<'a, ProviderCholResult> {
2217 Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
2218 }
2219 fn qr<'a>(
2220 &'a self,
2221 _a: &'a GpuTensorHandle,
2222 _options: ProviderQrOptions,
2223 ) -> AccelProviderFuture<'a, ProviderQrResult> {
2224 Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
2225 }
    /// Returns an optional pair of handles associated with `_product`.
    ///
    /// The default implementation ignores `_product` and returns `None`.
    /// NOTE(review): presumably the pair is the two operands of the matmul that produced
    /// `_product` — inferred from the name only; confirm against implementors.
    fn take_matmul_sources(
        &self,
        _product: &GpuTensorHandle,
    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
        None
    }
2232 fn qr_power_iter<'a>(
2233 &'a self,
2234 product: &'a GpuTensorHandle,
2235 _product_lhs: Option<&'a GpuTensorHandle>,
2236 q_handle: &'a GpuTensorHandle,
2237 options: &'a ProviderQrOptions,
2238 ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
2239 let _ = (product, q_handle, options);
2240 Box::pin(async move { Ok(None) })
2241 }
2242 fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2243 Err(anyhow::anyhow!("transpose not supported by provider"))
2244 }
2245 fn conv1d(
2246 &self,
2247 _signal: &GpuTensorHandle,
2248 _kernel: &GpuTensorHandle,
2249 _options: ProviderConv1dOptions,
2250 ) -> anyhow::Result<GpuTensorHandle> {
2251 Err(anyhow::anyhow!("conv1d not supported by provider"))
2252 }
2253 fn conv2d(
2254 &self,
2255 _signal: &GpuTensorHandle,
2256 _kernel: &GpuTensorHandle,
2257 _mode: ProviderConvMode,
2258 ) -> anyhow::Result<GpuTensorHandle> {
2259 Err(anyhow::anyhow!("conv2d not supported by provider"))
2260 }
2261 fn iir_filter<'a>(
2262 &'a self,
2263 _b: &'a GpuTensorHandle,
2264 _a: &'a GpuTensorHandle,
2265 _x: &'a GpuTensorHandle,
2266 _options: ProviderIirFilterOptions,
2267 ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
2268 Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
2269 }
    /// Provider hook for `uniform_spectral_estimate`: takes a
    /// [`ProviderSpectralRequest`] and returns a future yielding a
    /// [`ProviderSpectralResult`].
    ///
    /// Default: returns
    /// `unsupported_future("uniform_spectral_estimate not supported by provider")`.
    fn uniform_spectral_estimate<'a>(
        &'a self,
        _request: &'a ProviderSpectralRequest<'a>,
    ) -> AccelProviderFuture<'a, ProviderSpectralResult> {
        unsupported_future("uniform_spectral_estimate not supported by provider")
    }
    /// Provider hook for `signal_envelope`: takes a [`ProviderEnvelopeRequest`] and
    /// returns a future yielding a [`ProviderEnvelopeResult`].
    ///
    /// Default: returns `unsupported_future("signal_envelope not supported by provider")`.
    fn signal_envelope<'a>(
        &'a self,
        _request: &'a ProviderEnvelopeRequest<'a>,
    ) -> AccelProviderFuture<'a, ProviderEnvelopeResult> {
        unsupported_future("signal_envelope not supported by provider")
    }
    /// Provider hook for `signal_hilbert`: takes a [`ProviderHilbertRequest`] and returns
    /// a future yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("signal_hilbert not supported by provider")`.
    fn signal_hilbert<'a>(
        &'a self,
        _request: &'a ProviderHilbertRequest<'a>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("signal_hilbert not supported by provider")
    }
2288 fn permute(
2290 &self,
2291 _handle: &GpuTensorHandle,
2292 _order: &[usize],
2293 ) -> anyhow::Result<GpuTensorHandle> {
2294 Err(anyhow::anyhow!("permute not supported by provider"))
2295 }
2296 fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
2297 Err(anyhow::anyhow!("flip not supported by provider"))
2298 }
2299 fn circshift(
2300 &self,
2301 _handle: &GpuTensorHandle,
2302 _shifts: &[isize],
2303 ) -> anyhow::Result<GpuTensorHandle> {
2304 Err(anyhow::anyhow!("circshift not supported by provider"))
2305 }
2306 fn diff_dim(
2307 &self,
2308 _handle: &GpuTensorHandle,
2309 _order: usize,
2310 _dim: usize,
2311 ) -> anyhow::Result<GpuTensorHandle> {
2312 Err(anyhow::anyhow!("diff_dim not supported by provider"))
2313 }
2314 fn gradient_dim(
2315 &self,
2316 _handle: &GpuTensorHandle,
2317 _dim: usize,
2318 _spacing: f64,
2319 ) -> anyhow::Result<GpuTensorHandle> {
2320 Err(anyhow::anyhow!("gradient_dim not supported by provider"))
2321 }
    /// Provider hook for `fft_dim`: takes a tensor handle, an optional length and a
    /// dimension, and returns a future yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("fft_dim not supported by provider")`.
    fn fft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_dim not supported by provider")
    }
    /// Provider hook for `ifft_dim`: takes a tensor handle, an optional length and a
    /// dimension, and returns a future yielding a tensor handle.
    ///
    /// Default: returns `unsupported_future("ifft_dim not supported by provider")`.
    fn ifft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("ifft_dim not supported by provider")
    }
    /// Default `fft_extract_real` hook: ignores `_handle` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn fft_extract_real<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_extract_real not supported by provider")
    }
    /// Default `unique` hook: ignores `_handle` and `_options` and returns a
    /// boxed future that resolves to a "not supported by provider" error.
    // NOTE(review): most sibling defaults build this future via
    // `unsupported_future`; consider using it here too if it is equivalent.
    fn unique<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _options: &'a UniqueOptions,
    ) -> AccelProviderFuture<'a, UniqueResult> {
        Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
    }
    /// Default `union` hook: ignores `_a`, `_b` and `_options` and returns a
    /// boxed future that resolves to a "not supported by provider" error.
    // NOTE(review): most sibling defaults build this future via
    // `unsupported_future`; consider using it here too if it is equivalent.
    fn union<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
        _options: &'a UnionOptions,
    ) -> AccelProviderFuture<'a, UnionResult> {
        Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
    }
    /// Default `setdiff` hook: ignores `_a`, `_b` and `_options` and returns a
    /// boxed future that resolves to a "not supported by provider" error.
    // NOTE(review): most sibling defaults build this future via
    // `unsupported_future`; consider using it here too if it is equivalent.
    fn setdiff<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
        _options: &'a SetdiffOptions,
    ) -> AccelProviderFuture<'a, SetdiffResult> {
        Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
    }
    /// Default `ismember` hook: ignores `_a`, `_b` and `_options` and returns a
    /// boxed future that resolves to a "not supported by provider" error.
    // NOTE(review): most sibling defaults build this future via
    // `unsupported_future`; consider using it here too if it is equivalent.
    fn ismember<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
        _options: &'a IsMemberOptions,
    ) -> AccelProviderFuture<'a, IsMemberResult> {
        Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
    }
2376 fn reshape(
2377 &self,
2378 handle: &GpuTensorHandle,
2379 new_shape: &[usize],
2380 ) -> anyhow::Result<GpuTensorHandle> {
2381 let mut updated = handle.clone();
2382 updated.shape = new_shape.to_vec();
2383 Ok(updated)
2384 }
    /// Default `cat` hook: ignores `_dim` and `_inputs` and always returns an
    /// error saying the operation is not supported by the provider.
    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cat not supported by provider"))
    }
    /// Default `repmat` hook: ignores `_handle` and `_reps` and always returns an
    /// error saying the operation is not supported by the provider.
    fn repmat(
        &self,
        _handle: &GpuTensorHandle,
        _reps: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("repmat not supported by provider"))
    }
    /// Default `kron` hook: ignores `_a` and `_b` and always returns an error
    /// saying the operation is not supported by the provider.
    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("kron not supported by provider"))
    }
    /// Default `cross` hook: ignores `_lhs`, `_rhs` and `_dim` and always returns
    /// an error saying the operation is not supported by the provider.
    fn cross(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
        _dim: Option<usize>,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cross not supported by provider"))
    }
    /// Default `reduce_sum` hook: ignores `_a` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_sum<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum not supported by provider")
    }
    /// Default `reduce_sum_dim` hook: ignores `_a` and `_dim` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_sum_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum_dim not supported by provider")
    }
    /// Default `dot` hook: ignores `_lhs`, `_rhs` and `_dim` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn dot<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _dim: Option<usize>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("dot not supported by provider")
    }
    /// Default `reduce_nnz` hook: ignores `_a` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_nnz<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz not supported by provider")
    }
    /// Default `reduce_nnz_dim` hook: ignores `_a` and `_dim` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_nnz_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz_dim not supported by provider")
    }
    /// Default `reduce_prod` hook: ignores `_a` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_prod<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod not supported by provider")
    }
    /// Default `reduce_prod_dim` hook: ignores `_a` and `_dim` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_prod_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod_dim not supported by provider")
    }
    /// Default `reduce_mean` hook: ignores `_a` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_mean<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean not supported by provider")
    }
    /// Default `reduce_mean_nd` hook: ignores `_a` and `_dims_zero_based` and
    /// returns `unsupported_future(..)` with a "not supported" message
    /// (presumably a future that fails with that message; the helper is defined
    /// elsewhere).
    fn reduce_mean_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_nd not supported by provider")
    }
    /// Default `reduce_moments_nd` hook: ignores `_a` and `_dims_zero_based` and
    /// returns `unsupported_future(..)` with a "not supported" message
    /// (presumably a future that fails with that message; the helper is defined
    /// elsewhere).
    fn reduce_moments_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, ProviderMoments2> {
        unsupported_future("reduce_moments_nd not supported by provider")
    }
    /// Default `reduce_mean_dim` hook: ignores `_a` and `_dim` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_mean_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_dim not supported by provider")
    }
    /// Default `reduce_std` hook: ignores `_a`, `_normalization` and `_nan_mode`
    /// and returns `unsupported_future(..)` with a "not supported" message
    /// (presumably a future that fails with that message; the helper is defined
    /// elsewhere).
    fn reduce_std<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std not supported by provider")
    }
    /// Default `reduce_std_dim` hook: ignores `_a`, `_dim`, `_normalization` and
    /// `_nan_mode` and returns `unsupported_future(..)` with a "not supported"
    /// message (presumably a future that fails with that message; the helper is
    /// defined elsewhere).
    fn reduce_std_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std_dim not supported by provider")
    }
    /// Default `reduce_any` hook: ignores `_a` and `_omit_nan` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_any<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any not supported by provider")
    }
    /// Default `reduce_any_dim` hook: ignores `_a`, `_dim` and `_omit_nan` and
    /// returns `unsupported_future(..)` with a "not supported" message
    /// (presumably a future that fails with that message; the helper is defined
    /// elsewhere).
    fn reduce_any_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any_dim not supported by provider")
    }
    /// Default `reduce_all` hook: ignores `_a` and `_omit_nan` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_all<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all not supported by provider")
    }
    /// Default `reduce_all_dim` hook: ignores `_a`, `_dim` and `_omit_nan` and
    /// returns `unsupported_future(..)` with a "not supported" message
    /// (presumably a future that fails with that message; the helper is defined
    /// elsewhere).
    fn reduce_all_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all_dim not supported by provider")
    }
    /// Default `reduce_median` hook: ignores `_a` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_median<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median not supported by provider")
    }
    /// Default `reduce_median_dim` hook: ignores `_a` and `_dim` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_median_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median_dim not supported by provider")
    }
    /// Default `reduce_min` hook: ignores `_a` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_min not supported by provider")
    }
    /// Default `reduce_min_dim` hook: ignores `_a` and `_dim` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    ///
    /// Unlike `reduce_min`, the declared output type is `ReduceDimResult`.
    fn reduce_min_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_min_dim not supported by provider")
    }
    /// Default `reduce_max` hook: ignores `_a` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    fn reduce_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_max not supported by provider")
    }
    /// Default `reduce_max_dim` hook: ignores `_a` and `_dim` and returns
    /// `unsupported_future(..)` with a "not supported" message (presumably a
    /// future that fails with that message; the helper is defined elsewhere).
    ///
    /// Unlike `reduce_max`, the declared output type is `ReduceDimResult`.
    fn reduce_max_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_max_dim not supported by provider")
    }
    /// Default `cumsum_scan` hook: ignores `_input`, `_dim`, `_direction` and
    /// `_nan_mode` and always returns an error saying the operation is not
    /// supported by the provider.
    fn cumsum_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
    }
    /// Default `cumprod_scan` hook: ignores `_input`, `_dim`, `_direction` and
    /// `_nan_mode` and always returns an error saying the operation is not
    /// supported by the provider.
    fn cumprod_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
    }
    /// Default `cummin_scan` hook: ignores `_input`, `_dim`, `_direction` and
    /// `_nan_mode` and always returns an error saying the operation is not
    /// supported by the provider.
    ///
    /// The declared success type is `ProviderCumminResult`.
    fn cummin_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCumminResult> {
        Err(anyhow::anyhow!("cummin_scan not supported by provider"))
    }
    /// Default `cummax_scan` hook: ignores `_input`, `_dim`, `_direction` and
    /// `_nan_mode` and always returns an error saying the operation is not
    /// supported by the provider.
    ///
    /// The declared success type is `ProviderCummaxResult`.
    fn cummax_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCummaxResult> {
        Err(anyhow::anyhow!("cummax_scan not supported by provider"))
    }
2608
    /// Default `find` hook: ignores `_a`, `_limit` and `_direction` and always
    /// returns an error saying the operation is not supported by the provider.
    fn find(
        &self,
        _a: &GpuTensorHandle,
        _limit: Option<usize>,
        _direction: FindDirection,
    ) -> anyhow::Result<ProviderFindResult> {
        Err(anyhow::anyhow!("find not supported by provider"))
    }
2617
    /// Default `fused_elementwise` hook: ignores `_shader`, `_inputs`,
    /// `_output_shape` and `_len` and always returns an error saying the
    /// operation is not supported by the provider.
    fn fused_elementwise(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _len: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "fused_elementwise not supported by provider"
        ))
    }
2629
    /// Default `fused_elementwise_multi` hook: ignores `_shader`, `_inputs`,
    /// `_output_shape`, `_len` and `_num_outputs` and always returns an error
    /// saying the operation is not supported by the provider.
    ///
    /// The declared success type is a `Vec` of handles rather than one handle.
    fn fused_elementwise_multi(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _len: usize,
        _num_outputs: usize,
    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
        Err(anyhow::anyhow!(
            "fused_elementwise_multi not supported by provider"
        ))
    }
2650
    /// Default `map_nan_to_zero` hook: ignores `_a` and always returns an error
    /// saying the operation is not supported by the provider.
    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
    }
2655
    /// Default `not_nan_mask` hook: ignores `_a` and always returns an error
    /// saying the operation is not supported by the provider.
    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
    }
2660
    /// Default `fused_reduction` hook: ignores every argument and always returns
    /// an error saying the operation is not supported by the provider.
    // The lint is silenced because the hook takes eight parameters counting
    // `self`.
    #[allow(clippy::too_many_arguments)]
    fn fused_reduction(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _reduce_len: usize,
        _num_slices: usize,
        _workgroup_size: u32,
        _flavor: ReductionFlavor,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
    }
2680
    /// Default `warmup` hook: does nothing.
    fn warmup(&self) {}
2683
    /// Default fused-cache counters: always `(0, 0)`.
    ///
    /// The default `telemetry_snapshot` reads the pair as `(hits, misses)`.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }
2688
    /// Default `last_warmup_millis` hook: always returns `None`.
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }
2693
2694 fn telemetry_snapshot(&self) -> ProviderTelemetry {
2696 let (hits, misses) = self.fused_cache_counters();
2697 ProviderTelemetry {
2698 fused_elementwise: ProviderDispatchStats::default(),
2699 fused_reduction: ProviderDispatchStats::default(),
2700 matmul: ProviderDispatchStats::default(),
2701 linsolve: ProviderDispatchStats::default(),
2702 mldivide: ProviderDispatchStats::default(),
2703 mrdivide: ProviderDispatchStats::default(),
2704 upload_bytes: 0,
2705 download_bytes: 0,
2706 solve_fallbacks: Vec::new(),
2707 fusion_cache_hits: hits,
2708 fusion_cache_misses: misses,
2709 bind_group_cache_hits: 0,
2710 bind_group_cache_misses: 0,
2711 bind_group_cache_by_layout: None,
2712 kernel_launches: Vec::new(),
2713 }
2714 }
2715
    /// Default `reset_telemetry` hook: does nothing.
    fn reset_telemetry(&self) {}
2718
    /// Default reduction workgroup size: always `256`.
    fn default_reduction_workgroup_size(&self) -> u32 {
        256
    }
2723
    /// Default two-pass threshold: always `1024`.
    // NOTE(review): presumably the reduction length at which a two-pass
    // reduction is preferred — confirm against the callers.
    fn two_pass_threshold(&self) -> usize {
        1024
    }
2728
    /// Default two-pass reduction mode: always `ReductionTwoPassMode::Auto`.
    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
2733
    /// Default `scatter_column` hook: ignores `_matrix`, `_col_index` and
    /// `_values` and always returns an error saying the operation is not
    /// supported by the provider.
    fn scatter_column(
        &self,
        _matrix: &GpuTensorHandle,
        _col_index: usize,
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scatter_column not supported by provider"))
    }
2744
    /// Default `scatter_row` hook: ignores `_matrix`, `_row_index` and `_values`
    /// and always returns an error saying the operation is not supported by the
    /// provider.
    fn scatter_row(
        &self,
        _matrix: &GpuTensorHandle,
        _row_index: usize,
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scatter_row not supported by provider"))
    }
2755
    /// Default `sub2ind` hook: ignores every argument and always returns an
    /// error saying the operation is not supported by the provider.
    fn sub2ind(
        &self,
        _dims: &[usize],
        _strides: &[usize],
        _inputs: &[&GpuTensorHandle],
        _scalar_mask: &[bool],
        _len: usize,
        _output_shape: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("sub2ind not supported by provider"))
    }
2767
    /// Default capability flag for `ind2sub`: always `false`, matching the
    /// default `ind2sub`, which always fails.
    fn supports_ind2sub(&self) -> bool {
        false
    }
2772
    /// Default `ind2sub` hook: ignores every argument and always returns an
    /// error saying the operation is not supported by the provider.
    ///
    /// The declared success type is a `Vec` of handles.
    fn ind2sub(
        &self,
        _dims: &[usize],
        _strides: &[usize],
        _indices: &GpuTensorHandle,
        _total: usize,
        _len: usize,
        _output_shape: &[usize],
    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
        Err(anyhow::anyhow!("ind2sub not supported by provider"))
    }
2785
    /// Default `issymmetric` hook: ignores `_matrix`, `_kind` and `_tolerance`
    /// and always returns an error saying the predicate is not supported by the
    /// provider (it never returns `Ok(false)`).
    fn issymmetric(
        &self,
        _matrix: &GpuTensorHandle,
        _kind: ProviderSymmetryKind,
        _tolerance: f64,
    ) -> anyhow::Result<bool> {
        Err(anyhow::anyhow!(
            "issymmetric predicate not supported by provider"
        ))
    }
2797
    /// Default `ishermitian` hook: ignores `_matrix`, `_kind` and `_tolerance`
    /// and returns a boxed future that resolves to a "not supported by provider"
    /// error (it never resolves to `Ok(false)`).
    // NOTE(review): most sibling defaults build this future via
    // `unsupported_future`; consider using it here too if it is equivalent.
    fn ishermitian<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _kind: ProviderHermitianKind,
        _tolerance: f64,
    ) -> AccelProviderFuture<'a, bool> {
        Box::pin(async move {
            Err(anyhow::anyhow!(
                "ishermitian predicate not supported by provider"
            ))
        })
    }
2811
    /// Default `bandwidth` hook: ignores `_matrix` and always returns an error
    /// saying the operation is not supported by the provider.
    fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
        Err(anyhow::anyhow!("bandwidth not supported by provider"))
    }
2816
    /// Default `sym_rcm` hook: ignores `_matrix` and returns a boxed future that
    /// resolves to a "not supported by provider" error.
    // NOTE(review): most sibling defaults build this future via
    // `unsupported_future`; consider using it here too if it is equivalent.
    fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
        Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
    }
2824}
2825
/// Process-wide default provider. Set by `register_provider`, emptied by
/// `clear_provider`, and used as the fallback by `provider()` and
/// `provider_for_device`.
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
/// Providers keyed by device id. Filled by `register_provider_for_device` and
/// consulted first by `provider_for_device`.
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Source of ids handed out by `next_device_id`; the first id returned is 1.
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2831
// Per-thread provider override used on non-wasm targets. `provider()` consults
// it before falling back to `GLOBAL_PROVIDER`.
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}
2836
/// Provider override slot used on wasm32 in place of the thread-local: a
/// single mutex-protected value shared by the whole process.
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| Mutex::new(None));
2840
2841#[cfg(not(target_arch = "wasm32"))]
2842fn replace_thread_provider(
2843 provider: Option<&'static dyn AccelProvider>,
2844) -> Option<&'static dyn AccelProvider> {
2845 THREAD_PROVIDER.with(|cell| {
2846 let prev = cell.get();
2847 cell.set(provider);
2848 prev
2849 })
2850}
2851
/// wasm32 variant: swaps the value held in `WASM_THREAD_PROVIDER` for
/// `provider` and returns the previous value.
///
/// # Panics
///
/// Panics if the mutex guarding the slot is poisoned.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut guard = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *guard, provider)
}
2863
2864#[cfg(not(target_arch = "wasm32"))]
2865fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2866 THREAD_PROVIDER.with(|cell| cell.get())
2867}
2868
/// wasm32 variant: returns the value currently held in `WASM_THREAD_PROVIDER`.
///
/// # Panics
///
/// Panics if the mutex guarding the slot is poisoned.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    let guard = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    // The slot holds a `Copy` value, so it can be read straight out of the guard.
    *guard
}
2877
2878pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
2886 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2887 *guard = Some(p);
2888 }
2889 register_provider_for_device(p.device_id(), p);
2890}
2891
2892unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
2893 if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
2894 guard.insert(device_id, provider);
2895 }
2896}
2897
2898pub fn provider() -> Option<&'static dyn AccelProvider> {
2899 if let Some(p) = current_thread_provider() {
2900 return Some(p);
2901 }
2902 GLOBAL_PROVIDER
2903 .read()
2904 .ok()
2905 .and_then(|guard| guard.as_ref().copied())
2906}
2907
2908pub fn clear_provider() {
2910 replace_thread_provider(None);
2911 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2912 *guard = None;
2913 }
2914 if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2915 map.clear();
2916 }
2917}
2918
2919pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2920 if let Some(registered) = PROVIDER_REGISTRY
2921 .read()
2922 .ok()
2923 .and_then(|guard| guard.get(&device_id).copied())
2924 {
2925 return Some(registered);
2926 }
2927 if let Some(thread_provider) = current_thread_provider() {
2928 if thread_provider.device_id() == device_id {
2929 return Some(thread_provider);
2930 }
2931 }
2932 GLOBAL_PROVIDER
2935 .read()
2936 .ok()
2937 .and_then(|guard| guard.as_ref().copied())
2938}
2939
/// Resolves the provider for `handle` by passing `handle.device_id` to
/// `provider_for_device`; all lookup and fallback rules are that function's.
pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
    provider_for_device(handle.device_id)
}
2943
2944pub fn spawn_handle_concurrency_for(handle: &GpuTensorHandle) -> Option<SpawnHandleConcurrency> {
2945 provider_for_handle(handle).map(AccelProvider::spawn_handle_concurrency)
2946}
2947
/// Hands out the next device id: returns the current value of
/// `DEVICE_ID_COUNTER` and increments it by one (relaxed ordering).
///
/// The counter starts at 1, so the first id returned is 1. `fetch_add` wraps
/// on overflow, so ids would repeat after `u32::MAX` calls.
pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
2951
2952pub struct ThreadProviderGuard {
2953 prev: Option<&'static dyn AccelProvider>,
2954}
2955
2956impl ThreadProviderGuard {
2957 pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2958 let prev = replace_thread_provider(provider);
2959 ThreadProviderGuard { prev }
2960 }
2961}
2962
2963impl Drop for ThreadProviderGuard {
2964 fn drop(&mut self) {
2965 let prev = self.prev.take();
2966 replace_thread_provider(prev);
2967 }
2968}
2969
/// Installs `provider` as the thread-level override (`None` removes it).
///
/// Unlike `ThreadProviderGuard::set`, the previously installed override is
/// discarded and nothing restores it later.
pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
    replace_thread_provider(provider);
}
2973
2974pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2976 if let Some(p) = provider() {
2977 if let Ok(h) = p.elem_add(a, b).await {
2978 return Some(h);
2979 }
2980 }
2981 None
2982}
2983
2984pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2986 if let Some(p) = provider() {
2987 if let Ok(h) = p.elem_hypot(a, b).await {
2988 return Some(h);
2989 }
2990 }
2991 None
2992}
2993
2994pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2996 if let Some(p) = provider() {
2997 if let Ok(h) = p.elem_max(a, b).await {
2998 return Some(h);
2999 }
3000 }
3001 None
3002}
3003
3004pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
3006 if let Some(p) = provider() {
3007 if let Ok(h) = p.elem_min(a, b).await {
3008 return Some(h);
3009 }
3010 }
3011 None
3012}
3013
3014pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
3016 if let Some(p) = provider() {
3017 if let Ok(h) = p.elem_atan2(y, x).await {
3018 return Some(h);
3019 }
3020 }
3021 None
3022}
3023
/// Host-side tensor that owns its element buffer and shape.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    /// Element values as `f64`.
    // NOTE(review): element ordering/layout is not visible here — confirm.
    pub data: Vec<f64>,
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Storage kind associated with the tensor (see `GpuTensorStorage`).
    pub storage: GpuTensorStorage,
}
3031
/// Borrowed view of a host-side tensor: element slice plus shape slice, both
/// tied to the lifetime `'a`.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    /// Borrowed element values.
    pub data: &'a [f64],
    /// Borrowed extent of each dimension.
    pub shape: &'a [usize],
}
3037
/// Borrowed values of a single meshgrid axis.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    /// Coordinate values along the axis.
    pub data: &'a [f64],
}
3043
/// Result of a provider-side meshgrid operation.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    /// Handles of the produced tensors.
    // NOTE(review): presumably one output per requested grid — confirm.
    pub outputs: Vec<GpuTensorHandle>,
}
3049
/// How a scale vector is applied; used by `MatmulEpilogue::row_op` and
/// `MatmulEpilogue::col_op`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    /// Apply the scale by multiplication.
    Multiply,
    /// Apply the scale by division.
    Divide,
}
3060
/// Optional post-processing attached to a matmul.
///
/// `MatmulEpilogue::noop()` builds the identity configuration, and
/// `is_noop()` reports whether a value is equivalent to it. The four trailing
/// optional fields are marked `#[serde(default)]`, so payloads that omit them
/// deserialize with `None`.
// NOTE(review): the exact formula (order in which alpha/beta, scales, clamp and
// pow are applied) is not visible in this part of the file — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    /// Scalar factor; `1.0` in the no-op configuration.
    pub alpha: f64,
    /// Scalar term; `0.0` in the no-op configuration.
    pub beta: f64,
    /// Optional per-row scale tensor, combined according to `row_op`.
    pub row_scale: Option<GpuTensorHandle>,
    /// Optional per-column scale tensor, combined according to `col_op`.
    pub col_scale: Option<GpuTensorHandle>,
    /// Operation used with `row_scale`; ignored by `is_noop`.
    pub row_op: ScaleOp,
    /// Operation used with `col_scale`; ignored by `is_noop`.
    pub col_op: ScaleOp,
    /// Optional lower clamp bound.
    #[serde(default)]
    pub clamp_min: Option<f64>,
    /// Optional upper clamp bound.
    #[serde(default)]
    pub clamp_max: Option<f64>,
    /// Optional exponent for a power step.
    #[serde(default)]
    pub pow_exponent: Option<f64>,
    /// Optional handle for a diagonal output.
    #[serde(default)]
    pub diag_output: Option<GpuTensorHandle>,
}
3088
3089impl MatmulEpilogue {
3090 pub fn noop() -> Self {
3091 Self {
3092 alpha: 1.0,
3093 beta: 0.0,
3094 row_scale: None,
3095 col_scale: None,
3096 row_op: ScaleOp::Multiply,
3097 col_op: ScaleOp::Multiply,
3098 clamp_min: None,
3099 clamp_max: None,
3100 pow_exponent: None,
3101 diag_output: None,
3102 }
3103 }
3104 pub fn is_noop(&self) -> bool {
3105 self.alpha == 1.0
3106 && self.beta == 0.0
3107 && self.row_scale.is_none()
3108 && self.col_scale.is_none()
3109 && self.clamp_min.is_none()
3110 && self.clamp_max.is_none()
3111 && self.pow_exponent.is_none()
3112 && self.diag_output.is_none()
3113 }
3114}
3115
3116#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
3117pub struct PowerStepEpilogue {
3118 pub epsilon: f64,
3119}
3120
3121impl Default for PowerStepEpilogue {
3122 fn default() -> Self {
3123 Self { epsilon: 0.0 }
3124 }
3125}
3126
/// Serializable description of an image-normalize operation.
///
/// `gain`, `bias` and `gamma` deserialize to `None` when omitted; `clamp_zero`
/// deserializes to `true` when omitted (see
/// `default_image_normalize_clamp_zero`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    /// Batch dimension.
    pub batch: usize,
    /// Image height.
    pub height: usize,
    /// Image width.
    pub width: usize,
    /// Epsilon term used by the normalization.
    pub epsilon: f64,
    /// Optional gain.
    #[serde(default)]
    pub gain: Option<f64>,
    /// Optional bias.
    #[serde(default)]
    pub bias: Option<f64>,
    /// Optional gamma.
    #[serde(default)]
    pub gamma: Option<f64>,
    /// Whether the result is clamped at zero; defaults to `true` when the field
    /// is missing from serialized input.
    #[serde(default = "default_image_normalize_clamp_zero")]
    pub clamp_zero: bool,
}
3142
/// Serde default for `ImageNormalizeDescriptor::clamp_zero`: payloads that omit
/// the field deserialize with clamping enabled.
fn default_image_normalize_clamp_zero() -> bool {
    true
}
3146
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider used to exercise registry/lookup logic. Only identity
    /// related methods return real values; data-path methods fail loudly.
    struct TestProvider {
        device_id: u32,
        name: &'static str,
        spawn_concurrency: SpawnHandleConcurrency,
    }

    impl AccelProvider for TestProvider {
        fn upload(&self, _host: &HostTensorView) -> anyhow::Result<GpuTensorHandle> {
            Err(anyhow!("test provider upload should not be called"))
        }

        fn download<'a>(&'a self, _h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a> {
            unsupported_future("test provider download should not be called")
        }

        fn free(&self, _h: &GpuTensorHandle) -> anyhow::Result<()> {
            Err(anyhow!("test provider free should not be called"))
        }

        fn device_info(&self) -> String {
            self.name.to_string()
        }

        fn device_id(&self) -> u32 {
            self.device_id
        }

        fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
            self.spawn_concurrency
        }
    }

    /// Serializes the tests that mutate the process-wide provider registry.
    static PROVIDER_TEST_LOCK: Lazy<std::sync::Mutex<()>> = Lazy::new(|| std::sync::Mutex::new(()));
    static PROVIDER_A: TestProvider = TestProvider {
        device_id: 101,
        name: "provider-a",
        spawn_concurrency: SpawnHandleConcurrency::ImmutableShare,
    };
    static PROVIDER_B: TestProvider = TestProvider {
        device_id: 202,
        name: "provider-b",
        spawn_concurrency: SpawnHandleConcurrency::Reject,
    };
    static PROVIDER_C: TestProvider = TestProvider {
        device_id: 303,
        name: "provider-c",
        spawn_concurrency: SpawnHandleConcurrency::CopyOnWrite,
    };

    /// Acquires `PROVIDER_TEST_LOCK`, recovering the guard if an earlier test
    /// panicked while holding it.
    ///
    /// The mutex protects no data of its own and every registry test resets
    /// the registry before use, so a poisoned lock is safe to reuse. Without
    /// this, one failing test would make every later registry test fail on
    /// the poison error instead of its own assertions.
    fn lock_provider_tests() -> std::sync::MutexGuard<'static, ()> {
        PROVIDER_TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Resets the registry, then registers A followed by B (so B ends up as
    /// the global default while both are reachable by device id).
    fn register_test_providers() {
        clear_provider();
        unsafe {
            register_provider(&PROVIDER_A);
            register_provider(&PROVIDER_B);
        }
    }

    fn test_handle(device_id: u32) -> GpuTensorHandle {
        GpuTensorHandle {
            shape: vec![1],
            device_id,
            buffer_id: 42,
        }
    }

    fn spectral_request<'a>(
        input: &'a GpuTensorHandle,
        frame_mode: ProviderSpectralFrameMode,
    ) -> ProviderSpectralRequest<'a> {
        static WINDOW: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        ProviderSpectralRequest {
            input,
            input_len: 16,
            input_complex: false,
            window: &WINDOW,
            nfft: 8,
            frame_count: 3,
            frame_mode,
            range: ProviderSpectralRange::Onesided,
            denominator: 1.0,
        }
    }

    #[test]
    fn provider_envelope_shape_guard_rejects_equal_len_layout_spoofing() {
        assert!(provider_envelope_input_shape_matches(&[2, 3], 2, 3));
        assert!(provider_envelope_input_shape_matches(&[6, 1], 6, 1));
        assert!(provider_envelope_input_shape_matches(&[1, 6], 6, 1));
        assert!(provider_envelope_input_shape_matches(&[6], 6, 1));

        assert!(!provider_envelope_input_shape_matches(&[3, 2], 2, 3));
        assert!(!provider_envelope_input_shape_matches(&[6], 2, 3));
        assert!(!provider_envelope_input_shape_matches(&[2, 1, 3], 2, 3));
    }

    #[test]
    fn provider_for_device_prefers_registered_device_over_thread_provider() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider = provider_for_device(PROVIDER_A.device_id()).expect("provider for device");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn provider_for_handle_uses_handle_device_owner() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider =
            provider_for_handle(&test_handle(PROVIDER_A.device_id())).expect("provider for handle");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn spawn_handle_concurrency_for_uses_registered_owner() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let concurrency = spawn_handle_concurrency_for(&test_handle(PROVIDER_A.device_id()))
            .expect("spawn concurrency");

        assert_eq!(concurrency, PROVIDER_A.spawn_concurrency);
        clear_provider();
    }

    #[test]
    fn provider_keeps_thread_local_active_provider_semantics() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_A));

        let active = provider().expect("active provider");

        assert_eq!(active.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn unregistered_thread_provider_only_matches_own_device_before_global_fallback() {
        let _lock = lock_provider_tests();
        clear_provider();
        unsafe {
            register_provider(&PROVIDER_A);
        }
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_C));

        let own_device = provider_for_device(PROVIDER_C.device_id()).expect("own provider");
        let fallback = provider_for_device(404).expect("global fallback provider");

        assert_eq!(own_device.device_info(), PROVIDER_C.name);
        assert_eq!(fallback.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn uniform_spectral_request_validates_sliding_input_coverage() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(&input, ProviderSpectralFrameMode::Sliding { hop: 6 });
        assert!(validate_uniform_spectral_request(&request).is_ok());

        request.input_len = 15;
        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_rejects_sliding_coverage_overflow() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(&input, ProviderSpectralFrameMode::Sliding { hop: 2 });
        request.frame_count = usize::MAX;

        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_validates_folded_input_coverage() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(
            &input,
            ProviderSpectralFrameMode::FoldedColumns { input_rows: 5 },
        );
        assert!(validate_uniform_spectral_request(&request).is_ok());

        request.frame_mode = ProviderSpectralFrameMode::FoldedColumns { input_rows: 0 };
        assert!(validate_uniform_spectral_request(&request).is_err());

        request.frame_mode = ProviderSpectralFrameMode::FoldedColumns { input_rows: 6 };
        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_rejects_folded_coverage_overflow() {
        let input = test_handle(PROVIDER_A.device_id());
        let request = spectral_request(
            &input,
            ProviderSpectralFrameMode::FoldedColumns {
                input_rows: usize::MAX,
            },
        );

        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn image_normalize_descriptor_omitted_clamp_zero_defaults_true() {
        let payload = r#"{
            "batch": 2,
            "height": 4,
            "width": 5,
            "epsilon": 0.000001
        }"#;

        let desc: ImageNormalizeDescriptor =
            serde_json::from_str(payload).expect("deserialize descriptor");

        assert!(
            desc.clamp_zero,
            "legacy serialized descriptors should default to clamped image normalize"
        );
    }

    #[test]
    fn image_normalize_descriptor_explicit_false_preserves_unclamped() {
        let payload = r#"{
            "batch": 2,
            "height": 4,
            "width": 5,
            "epsilon": 0.000001,
            "clamp_zero": false
        }"#;

        let desc: ImageNormalizeDescriptor =
            serde_json::from_str(payload).expect("deserialize descriptor");

        assert!(
            !desc.clamp_zero,
            "explicit clamp_zero=false should preserve unclamped semantics"
        );
    }
}