1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
/// Signature of the hook invoked by `mark_residency`.
type ResidencyMarkFn = fn(&GpuTensorHandle);
/// Signature of the hook invoked by `clear_residency`.
type ResidencyClearFn = fn(&GpuTensorHandle);
/// Signature of the callback queried by `sequence_threshold_hint`.
type SequenceThresholdFn = fn() -> Option<usize>;
/// Signature of the callback queried by `workgroup_size_hint`.
type WorkgroupSizeHintFn = fn() -> Option<u32>;

// Process-wide hook slots. Each slot is a `OnceCell`, so it can be filled at most once;
// the matching `register_*` function ignores the error returned by later attempts.
static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();
25
/// Buffer ids currently flagged as logical (see `set_handle_logical`).
static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
/// Per-buffer counter, incremented every time a buffer id is flagged as logical and
/// removed when the flag is cleared.
static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Transpose metadata keyed by buffer id (see `record_handle_transpose`).
static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

/// Precision recorded per buffer id (see `set_handle_precision`).
static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Storage layout recorded per buffer id (see `set_handle_storage`).
static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
36
/// Metadata stored for a handle by `record_handle_transpose`.
///
/// NOTE(review): the field names suggest these are the dimensions of the tensor before
/// it was transposed — confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// Row count passed to `record_handle_transpose`.
    pub base_rows: usize,
    /// Column count passed to `record_handle_transpose`.
    pub base_cols: usize,
}
42
43pub fn register_residency_mark(handler: ResidencyMarkFn) {
46 let _ = RESIDENCY_MARK.set(handler);
47}
48
49pub fn mark_residency(handle: &GpuTensorHandle) {
52 if let Some(handler) = RESIDENCY_MARK.get() {
53 handler(handle);
54 }
55}
56
57pub fn register_residency_clear(handler: ResidencyClearFn) {
61 let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64pub fn clear_residency(handle: &GpuTensorHandle) {
67 if let Some(handler) = RESIDENCY_CLEAR.get() {
68 handler(handle);
69 }
70}
71
72pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76 let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79pub fn sequence_threshold_hint() -> Option<usize> {
81 SEQUENCE_THRESHOLD_PROVIDER
82 .get()
83 .and_then(|provider| provider())
84}
85
86pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90 let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93pub fn workgroup_size_hint() -> Option<u32> {
95 WORKGROUP_SIZE_HINT_PROVIDER
96 .get()
97 .and_then(|provider| provider())
98}
99
100pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103 provider().and_then(|p| p.export_context(kind))
104}
105
/// Asks the active provider for the wgpu buffer behind `handle`.
///
/// Returns `None` when no provider is available or the provider does not export the buffer.
#[cfg(feature = "wgpu")]
pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
    provider()?.export_wgpu_buffer(handle)
}
113
114pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118 guard.insert(handle.buffer_id, precision);
119 }
120}
121
122pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124 HANDLE_PRECISIONS
125 .read()
126 .ok()
127 .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133 guard.remove(&handle.buffer_id);
134 }
135}
136
137pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
140 if let Ok(mut guard) = LOGICAL_HANDLES.write() {
141 if logical {
142 guard.insert(handle.buffer_id);
143 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
144 *hits.entry(handle.buffer_id).or_insert(0) += 1;
145 }
146 } else {
147 guard.remove(&handle.buffer_id);
148 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
149 hits.remove(&handle.buffer_id);
150 }
151 }
152 }
153}
154
/// Clears the "logical" flag for `handle`; shorthand for `set_handle_logical(handle, false)`.
pub fn clear_handle_logical(handle: &GpuTensorHandle) {
    set_handle_logical(handle, false);
}
159
160pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162 LOGICAL_HANDLES
163 .read()
164 .map(|guard| guard.contains(&handle.buffer_id))
165 .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169 LOGICAL_HANDLE_HITS
170 .read()
171 .ok()
172 .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177 guard.insert(
178 handle.buffer_id,
179 TransposeInfo {
180 base_rows,
181 base_cols,
182 },
183 );
184 }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189 guard.remove(&handle.buffer_id);
190 }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194 TRANSPOSED_HANDLES
195 .read()
196 .ok()
197 .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
200pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
201 handle_transpose_info(handle).is_some()
202}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
205pub enum GpuTensorStorage {
206 Real,
207 ComplexInterleaved,
208}
209
210impl Default for GpuTensorStorage {
211 fn default() -> Self {
212 Self::Real
213 }
214}
215
/// Serializable description of a tensor buffer managed by an acceleration provider.
///
/// `buffer_id` is the key used by this module's per-handle registries (precision,
/// storage layout, logical flag, transpose metadata).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuTensorHandle {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Identifier of the device that holds the buffer.
    pub device_id: u32,
    /// Identifier of the buffer; used as the registry key throughout this module.
    pub buffer_id: u64,
}
222
/// Frequency range selector carried by `ProviderSpectralRequest`.
///
/// NOTE(review): variant names suggest one-sided, two-sided and centered spectra — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSpectralRange {
    Onesided,
    Twosided,
    Centered,
}

/// How frames are laid out in the input of a spectral request.
///
/// The layout rules below are the ones enforced by `validate_uniform_spectral_request`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSpectralFrameMode {
    /// Frame `i` starts at offset `i * hop`; `hop` must be non-zero.
    Sliding {
        hop: usize,
    },
    /// Frame `f` belongs to column `f / frames_per_column` and starts at
    /// `column * input_rows + (f % frames_per_column) * hop`. `hop`, `input_rows` and
    /// `frames_per_column` must all be non-zero.
    ColumnSliding {
        hop: usize,
        input_rows: usize,
        frames_per_column: usize,
    },
    /// `input_rows * frame_count` elements must fit in the input; `input_rows` must be
    /// non-zero.
    FoldedColumns {
        input_rows: usize,
    },
}

/// Parameters for `uniform_spectral_estimate`.
#[derive(Clone, Debug)]
pub struct ProviderSpectralRequest<'a> {
    /// Device tensor holding the input samples.
    pub input: &'a GpuTensorHandle,
    /// Number of input elements; the frame layout is validated against this value.
    pub input_len: usize,
    /// NOTE(review): presumably marks the input as complex-valued — confirm.
    pub input_complex: bool,
    /// Window coefficients; must be non-empty.
    pub window: &'a [f64],
    /// Transform length; must be non-zero.
    pub nfft: usize,
    /// Number of frames; must be non-zero.
    pub frame_count: usize,
    /// Frame layout within the input.
    pub frame_mode: ProviderSpectralFrameMode,
    /// Requested frequency range.
    pub range: ProviderSpectralRange,
    /// Normalization denominator; must be finite and strictly positive.
    pub denominator: f64,
}

/// Result of `uniform_spectral_estimate`.
///
/// NOTE(review): `s`/`ps` presumably hold the spectrum and the power spectrum, with
/// `rows`/`cols` describing their matrix shape — confirm against the provider.
#[derive(Clone, Debug)]
pub struct ProviderSpectralResult {
    pub s: GpuTensorHandle,
    pub ps: GpuTensorHandle,
    pub rows: usize,
    pub cols: usize,
}
265
/// Envelope algorithm selector for `signal_envelope`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderEnvelopeMethod {
    Analytic,
    /// `filter_len` must be non-zero (checked by `signal_envelope`).
    AnalyticFir { filter_len: usize },
    /// `window_len` must be non-zero (checked by `signal_envelope`).
    Rms { window_len: usize },
}

/// Parameters for `signal_envelope`.
///
/// `signal_envelope` requires `channel_len * channel_count` to equal the element count
/// of both `input.shape` and `output_shape`.
#[derive(Clone, Debug)]
pub struct ProviderEnvelopeRequest<'a> {
    /// Device tensor holding the input signal(s).
    pub input: &'a GpuTensorHandle,
    /// Samples per channel; must be non-zero.
    pub channel_len: usize,
    /// Number of channels; must be non-zero.
    pub channel_count: usize,
    /// Shape of the output tensors; must be non-empty.
    pub output_shape: &'a [usize],
    /// Envelope algorithm to apply.
    pub method: ProviderEnvelopeMethod,
}

/// Result of `signal_envelope`: handles named `upper` and `lower`.
#[derive(Clone, Debug)]
pub struct ProviderEnvelopeResult {
    pub upper: GpuTensorHandle,
    pub lower: GpuTensorHandle,
}
287
/// Parameters for `signal_hilbert`.
#[derive(Clone, Debug)]
pub struct ProviderHilbertRequest<'a> {
    /// Device tensor holding the input signal.
    pub input: &'a GpuTensorHandle,
    /// Optional length; `Some(0)` is rejected by `signal_hilbert`.
    pub length: Option<usize>,
    /// Index into `input.shape`; must be smaller than the number of dimensions.
    pub dim: usize,
}

/// Request carrying an input tensor and a constellation table.
///
/// NOTE(review): the layout of `constellation` is not visible here — confirm with the
/// provider implementation.
#[derive(Clone, Debug)]
pub struct ProviderModulationRequest<'a> {
    pub input: &'a GpuTensorHandle,
    pub constellation: &'a [f64],
}

/// Request carrying a bit-valued input tensor and a constellation table.
///
/// NOTE(review): the meaning of `input_rows`/`bits_per_symbol` and the layout of
/// `constellation` are inferred from names only — confirm.
#[derive(Clone, Debug)]
pub struct ProviderBitModulationRequest<'a> {
    pub input: &'a GpuTensorHandle,
    pub input_rows: usize,
    pub bits_per_symbol: usize,
    pub constellation: &'a [f64],
}
314
315pub async fn uniform_spectral_estimate(
316 request: ProviderSpectralRequest<'_>,
317) -> anyhow::Result<ProviderSpectralResult> {
318 validate_uniform_spectral_request(&request)?;
319
320 let provider =
321 provider().ok_or_else(|| anyhow!("uniform_spectral_estimate: GPU provider unavailable"))?;
322 provider.uniform_spectral_estimate(&request).await
323}
324
325fn validate_uniform_spectral_request(request: &ProviderSpectralRequest<'_>) -> anyhow::Result<()> {
326 let invalid_frame_mode = matches!(
327 request.frame_mode,
328 ProviderSpectralFrameMode::Sliding { hop: 0 }
329 | ProviderSpectralFrameMode::ColumnSliding { hop: 0, .. }
330 );
331 let invalid_input_coverage = match request.frame_mode {
332 ProviderSpectralFrameMode::Sliding { hop } => {
333 let required = request
334 .frame_count
335 .checked_sub(1)
336 .and_then(|frames| frames.checked_mul(hop))
337 .and_then(|offset| offset.checked_add(request.window.len()));
338 required.is_none_or(|required| required > request.input_len)
339 }
340 ProviderSpectralFrameMode::ColumnSliding {
341 hop,
342 input_rows,
343 frames_per_column,
344 } => {
345 if input_rows == 0 || frames_per_column == 0 {
346 true
347 } else {
348 let last_frame = request.frame_count - 1;
349 let source_col = last_frame / frames_per_column;
350 let segment = last_frame % frames_per_column;
351 source_col
352 .checked_mul(input_rows)
353 .and_then(|base| {
354 segment
355 .checked_mul(hop)
356 .and_then(|step| base.checked_add(step))
357 })
358 .and_then(|offset| offset.checked_add(request.window.len()))
359 .is_none_or(|required| required > request.input_len)
360 }
361 }
362 ProviderSpectralFrameMode::FoldedColumns { input_rows } => {
363 input_rows == 0
364 || input_rows
365 .checked_mul(request.frame_count)
366 .is_none_or(|required| required > request.input_len)
367 }
368 };
369 if request.window.is_empty()
370 || request.nfft == 0
371 || request.frame_count == 0
372 || invalid_frame_mode
373 || invalid_input_coverage
374 || !request.denominator.is_finite()
375 || request.denominator <= 0.0
376 {
377 return Err(anyhow!("uniform_spectral_estimate: invalid request"));
378 }
379
380 Ok(())
381}
382
383pub async fn signal_envelope(
384 request: ProviderEnvelopeRequest<'_>,
385) -> anyhow::Result<ProviderEnvelopeResult> {
386 let expected_len = request
387 .channel_len
388 .checked_mul(request.channel_count)
389 .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
390 let output_len = request
391 .output_shape
392 .iter()
393 .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
394 .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
395 let input_len = request
396 .input
397 .shape
398 .iter()
399 .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
400 .ok_or_else(|| anyhow!("signal_envelope: invalid request"))?;
401 if request.channel_len == 0
402 || request.channel_count == 0
403 || request.output_shape.is_empty()
404 || output_len != expected_len
405 || input_len != expected_len
406 || !provider_envelope_input_shape_matches(
407 &request.input.shape,
408 request.channel_len,
409 request.channel_count,
410 )
411 {
412 return Err(anyhow!("signal_envelope: invalid request"));
413 }
414
415 match request.method {
416 ProviderEnvelopeMethod::AnalyticFir { filter_len }
417 | ProviderEnvelopeMethod::Rms {
418 window_len: filter_len,
419 } if filter_len == 0 => return Err(anyhow!("signal_envelope: invalid request")),
420 _ => {}
421 }
422
423 let provider =
424 provider().ok_or_else(|| anyhow!("signal_envelope: GPU provider unavailable"))?;
425 provider.signal_envelope(&request).await
426}
427
/// Checks that `shape` is an accepted layout for `channel_count` channels of
/// `channel_len` samples each.
///
/// A single channel may be given as `[len]`, `[len, 1]` or `[1, len]`; several channels
/// must be given as `[channel_len, channel_count]`. Every other shape is rejected.
fn provider_envelope_input_shape_matches(
    shape: &[usize],
    channel_len: usize,
    channel_count: usize,
) -> bool {
    let single_channel = channel_count == 1;
    match *shape {
        [len] => single_channel && len == channel_len,
        [rows, cols] if single_channel => {
            (rows, cols) == (channel_len, 1) || (rows, cols) == (1, channel_len)
        }
        [rows, cols] => rows == channel_len && cols == channel_count,
        _ => false,
    }
}
445
446pub async fn signal_hilbert(
447 request: ProviderHilbertRequest<'_>,
448) -> anyhow::Result<GpuTensorHandle> {
449 if request.length == Some(0) {
450 return Err(anyhow!("signal_hilbert: invalid request"));
451 }
452 if request.dim >= request.input.shape.len() {
453 return Err(anyhow!("signal_hilbert: invalid request"));
454 }
455
456 let provider = provider().ok_or_else(|| anyhow!("signal_hilbert: GPU provider unavailable"))?;
457 provider.signal_hilbert(&request).await
458}
459
/// Structured device description returned by `AccelProvider::device_info_struct`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    /// Device identifier.
    pub device_id: u32,
    /// Device name.
    pub name: String,
    /// Vendor name (may be empty).
    pub vendor: String,
    /// Device memory size in bytes, when known.
    pub memory_bytes: Option<u64>,
    /// Backend name, when known.
    pub backend: Option<String>,
}

/// Pair of device tensors: reduced `values` and their `indices`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Pair of device tensors: cumulative `values` and their `indices`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Cumulative-maximum results share the layout of `ProviderCumminResult`.
pub type ProviderCummaxResult = ProviderCumminResult;
486
/// Kinds of context that can be requested through `export_context`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelContextKind {
    /// NOTE(review): presumably the context shared with the plotting subsystem — confirm.
    Plotting,
}

/// Context exported by a provider.
///
/// Without the `wgpu` feature this enum has no variants.
#[derive(Clone)]
pub enum AccelContextHandle {
    /// A wgpu-backed context; only present with the `wgpu` feature.
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}

impl AccelContextHandle {
    /// Returns the wrapped wgpu context.
    ///
    /// `Wgpu` is the only variant, so this currently always returns `Some`.
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        match self {
            AccelContextHandle::Wgpu(ctx) => Some(ctx),
        }
    }
}
509
/// Shared wgpu objects exported by a provider (only with the `wgpu` feature).
///
/// The handles are reference-counted, so cloning this struct shares the underlying
/// wgpu objects rather than recreating them.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    pub adapter_info: wgpu::AdapterInfo,
    pub limits: wgpu::Limits,
    pub features: wgpu::Features,
}

/// Reference to a provider-owned wgpu buffer plus the metadata describing its contents
/// (only with the `wgpu` feature).
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    /// Shared handle to the buffer.
    pub buffer: Arc<wgpu::Buffer>,
    /// NOTE(review): presumably the element count — confirm whether elements or bytes.
    pub len: usize,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// NOTE(review): presumably the size of one element in bytes — confirm.
    pub element_size: usize,
    /// Numeric precision of the stored elements.
    pub precision: ProviderPrecision,
}
537
538pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
539 if let Ok(mut guard) = HANDLE_STORAGES.write() {
540 guard.insert(handle.buffer_id, storage);
541 }
542}
543
544pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
545 HANDLE_STORAGES
546 .read()
547 .ok()
548 .and_then(|guard| guard.get(&handle.buffer_id).cloned())
549 .unwrap_or(GpuTensorStorage::Real)
550}
551
552pub fn clear_handle_storage(handle: &GpuTensorHandle) {
553 if let Ok(mut guard) = HANDLE_STORAGES.write() {
554 guard.remove(&handle.buffer_id);
555 }
556}
557
/// Operation selector for a `PagefunRequest`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    /// NOTE(review): presumably page-wise matrix multiplication — confirm.
    Mtimes,
}

/// Description of a page-wise operation over device tensors.
///
/// NOTE(review): the relationship between `page_dims` and `input_page_dims` is not
/// visible in this file — confirm with the provider implementation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    /// Operation to apply.
    pub op: PagefunOp,
    /// Input tensors.
    pub inputs: Vec<GpuTensorHandle>,
    /// Shape of the result.
    pub output_shape: Vec<usize>,
    pub page_dims: Vec<usize>,
    pub input_page_dims: Vec<Vec<usize>>,
}

/// Search direction selector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}

/// Device tensors produced by a find operation: linear indices, row and column
/// subscripts, and optionally the matching values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    pub linear: GpuTensorHandle,
    pub rows: GpuTensorHandle,
    pub cols: GpuTensorHandle,
    pub values: Option<GpuTensorHandle>,
}
585
/// Lower and upper bandwidth pair.
///
/// NOTE(review): presumably the matrix bandwidths below/above the main diagonal — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    pub lower: u32,
    pub upper: u32,
}

/// Symmetry selector: symmetric or skew.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}

/// Hermitian selector: Hermitian or skew.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}
603
/// Device tensors produced by an LU factorization.
///
/// NOTE(review): `combined` presumably packs both factors into one matrix — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    pub combined: GpuTensorHandle,
    pub lower: GpuTensorHandle,
    pub upper: GpuTensorHandle,
    /// Permutation in matrix form.
    pub perm_matrix: GpuTensorHandle,
    /// Permutation in vector form.
    pub perm_vector: GpuTensorHandle,
}

/// Device tensor and status code produced by a Cholesky factorization.
///
/// NOTE(review): the meaning of non-zero `info` values is not visible here — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    pub factor: GpuTensorHandle,
    pub info: u32,
}

/// Device tensors produced by a QR factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    /// Permutation in matrix form.
    pub perm_matrix: GpuTensorHandle,
    /// Permutation in vector form.
    pub perm_vector: GpuTensorHandle,
}

/// Device tensors produced by the QR power-iteration path; same fields as
/// `ProviderQrResult`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    /// Permutation in matrix form.
    pub perm_matrix: GpuTensorHandle,
    /// Permutation in vector form.
    pub perm_vector: GpuTensorHandle,
}
635
/// Options for a linear solve. The derived default is all flags `false` and
/// `rcond: None`.
///
/// NOTE(review): the flags presumably describe structural properties of the coefficient
/// matrix that the solver may exploit — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    pub need_rcond: bool,
    pub rcond: Option<f64>,
}

/// Result of a linear solve: the solution tensor and a reciprocal condition value.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    pub reciprocal_condition: f64,
}

/// Options for a pseudo-inverse; the derived default is `tolerance: None`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    pub tolerance: Option<f64>,
}

/// Mean/scale pair carried by `ProviderPolyvalOptions`.
///
/// NOTE(review): presumably the centering and scaling applied to the evaluation points —
/// confirm.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}

/// Options for polynomial evaluation; the derived default is `mu: None`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}

/// Options for matrix inversion; currently carries no fields.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}

/// Host-side result of a polynomial fit.
///
/// NOTE(review): the layout of `r_matrix` and the ordering of `mu` are not visible here —
/// confirm with the producer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    pub coefficients: Vec<f64>,
    pub r_matrix: Vec<f64>,
    pub normr: f64,
    pub df: f64,
    pub mu: [f64; 2],
}
682
/// Numerator and denominator tensors of a polynomial-derivative quotient.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}

/// Norm selector for condition-number computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}

/// Norm selector for norm computations; `P` carries an arbitrary order value.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    Nuc,
    P(f64),
}

/// Device tensors produced by an eigen-decomposition; `left` is optional.
///
/// NOTE(review): `diagonal` presumably holds the eigenvalues in diagonal-matrix form and
/// `right`/`left` the corresponding eigenvectors — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    pub eigenvalues: GpuTensorHandle,
    pub diagonal: GpuTensorHandle,
    pub right: GpuTensorHandle,
    pub left: Option<GpuTensorHandle>,
}
719
/// Pivot representation selector for a QR factorization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}

/// Options for a QR factorization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    /// NOTE(review): presumably requests the economy-size factorization — confirm.
    pub economy: bool,
    /// Requested pivot representation.
    pub pivot: ProviderQrPivot,
}

impl Default for ProviderQrOptions {
    /// Defaults to `economy: false` with a matrix pivot.
    fn default() -> Self {
        Self {
            economy: false,
            pivot: ProviderQrPivot::Matrix,
        }
    }
}
740
/// Floating-point precision tag; `AccelProvider::precision` returns `F64` by default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}
746
747#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
752pub enum SpawnHandleConcurrency {
753 ImmutableShare,
755 CopyOnWrite,
757 SynchronizedMutation,
759 Reject,
761}
762
763impl SpawnHandleConcurrency {
764 pub fn as_str(self) -> &'static str {
765 match self {
766 SpawnHandleConcurrency::ImmutableShare => "immutable_share",
767 SpawnHandleConcurrency::CopyOnWrite => "copy_on_write",
768 SpawnHandleConcurrency::SynchronizedMutation => "synchronized_mutation",
769 SpawnHandleConcurrency::Reject => "reject",
770 }
771 }
772}
773
774#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
775pub enum ReductionTwoPassMode {
776 Auto,
777 ForceOn,
778 ForceOff,
779}
780
781impl ReductionTwoPassMode {
782 pub fn as_str(self) -> &'static str {
783 match self {
784 ReductionTwoPassMode::Auto => "auto",
785 ReductionTwoPassMode::ForceOn => "force_on",
786 ReductionTwoPassMode::ForceOff => "force_off",
787 }
788 }
789}
790
791#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
792pub enum ReductionFlavor {
793 Sum,
794 Mean,
795 CustomScale(f64),
796}
797
798impl ReductionFlavor {
799 pub fn is_mean(self) -> bool {
800 matches!(self, ReductionFlavor::Mean)
801 }
802
803 pub fn scale(self, reduce_len: usize) -> f64 {
804 match self {
805 ReductionFlavor::Sum => 1.0,
806 ReductionFlavor::Mean => {
807 if reduce_len == 0 {
808 1.0
809 } else {
810 1.0 / reduce_len as f64
811 }
812 }
813 ReductionFlavor::CustomScale(scale) => scale,
814 }
815 }
816}
817
/// Normalization selector for correlation coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}

/// Row-handling selector for correlation coefficients.
///
/// NOTE(review): presumably controls how rows with missing values are treated — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}

/// Options for a correlation-coefficient computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}

impl Default for CorrcoefOptions {
    /// Defaults to unbiased normalization over all rows.
    fn default() -> Self {
        Self {
            normalization: CorrcoefNormalization::Unbiased,
            rows: CorrcoefRows::All,
        }
    }
}
848
/// Normalization selector for covariance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}

/// Row-handling selector for covariance.
///
/// NOTE(review): presumably controls how rows with missing values are treated — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}

/// Options for a covariance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    /// NOTE(review): presumably set when the caller supplies a weight vector — confirm.
    pub has_weight_vector: bool,
}

impl Default for CovarianceOptions {
    /// Defaults to unbiased normalization over all rows without a weight vector.
    fn default() -> Self {
        Self {
            normalization: CovNormalization::Unbiased,
            rows: CovRows::All,
            has_weight_vector: false,
        }
    }
}
881
/// Normalization selector for standard deviation: sample or population.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    Sample,
    Population,
}

/// NaN-handling selector: include or omit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}

/// Scan direction selector: forward or reverse.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}

/// Sort direction selector: ascending or descending.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}

/// Comparison rule selector for sorting.
///
/// NOTE(review): `Real`/`Abs` presumably choose between comparing by real part and by
/// magnitude — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    Auto,
    Real,
    Abs,
}
917
/// Host-side result of a sort: sorted `values` and the matching `indices`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}

/// One sort key for a row sort: a column `index` and its `order`.
///
/// NOTE(review): whether `index` is zero- or one-based is not visible here — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    pub index: usize,
    pub order: SortOrder,
}
930
/// Output ordering selector for a unique operation: sorted or stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}

/// Selects which occurrence of a repeated value is reported: first or last.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}

/// Options for a unique operation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    /// NOTE(review): presumably treats each row as one element — confirm.
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}

/// Host-side result of a unique operation.
///
/// NOTE(review): `ia`/`ic` are presumably the index vectors mapping between the input
/// and `values` — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ic: HostTensorOwned,
}

/// Output ordering selector for a union operation: sorted or stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}

/// Options for a union operation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    /// NOTE(review): presumably treats each row as one element — confirm.
    pub rows: bool,
    pub order: UnionOrder,
}

/// Host-side result of a union operation.
///
/// NOTE(review): `ia`/`ib` are presumably index vectors into the two inputs — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ib: HostTensorOwned,
}
982
/// Predefined filter kernel description together with its parameters.
///
/// NOTE(review): the variant set resembles MATLAB's `fspecial` filter types; parameter
/// units and ranges are not visible here — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    Average {
        rows: u32,
        cols: u32,
    },
    Disk {
        radius: f64,
        size: u32,
    },
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Laplacian {
        alpha: f64,
    },
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        oversample: u32,
    },
    Prewitt,
    Sobel,
    Unsharp {
        alpha: f64,
    },
}

/// Request wrapper carrying the filter to generate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}
1025
/// Boundary padding selector for image filtering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    Constant,
    Replicate,
    Symmetric,
    Circular,
}

/// Output size selector for image filtering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}

/// Selects correlation or convolution for image filtering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}

/// Options for image filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    /// NOTE(review): presumably the fill value used with `ImfilterPadding::Constant` —
    /// confirm.
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}

impl Default for ImfilterOptions {
    /// Defaults to constant padding with value `0.0`, `Same` output size, and correlation.
    fn default() -> Self {
        Self {
            padding: ImfilterPadding::Constant,
            constant_value: 0.0,
            shape: ImfilterShape::Same,
            mode: ImfilterMode::Correlation,
        }
    }
}
1069
/// Output ordering selector for a set-difference operation: sorted or stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}

/// Options for a set-difference operation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// NOTE(review): presumably treats each row as one element — confirm.
    pub rows: bool,
    pub order: SetdiffOrder,
}

/// Host-side result of a set-difference operation.
///
/// NOTE(review): `ia` is presumably an index vector into the first input — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}

/// Options for a membership test.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// NOTE(review): presumably treats each row as one element — confirm.
    pub rows: bool,
}

/// Owned host-side logical array: raw bytes plus shape.
///
/// NOTE(review): presumably one byte per element with non-zero meaning true — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}

/// Host-side result of a membership test: a logical `mask` and a `loc` tensor.
///
/// NOTE(review): `loc` presumably holds the matching positions in the second input —
/// confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}
1110
/// Output size selector for convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

/// Orientation selector for 1-D convolution: row or column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options for 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}

/// Options for an IIR filter run.
///
/// NOTE(review): `dim` is presumably the dimension to filter along and `zi` the optional
/// initial filter state — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    pub dim: usize,
    pub zi: Option<GpuTensorHandle>,
}

/// Result of an IIR filter run: the output tensor and an optional final state.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    pub output: GpuTensorHandle,
    pub final_state: Option<GpuTensorHandle>,
}

/// Pair of device tensors named `mean` and `ex2`.
///
/// NOTE(review): `ex2` presumably holds the mean of squares (E[x^2]) — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    pub ex2: GpuTensorHandle,
}
1151
/// Counter pair for one category of dispatches.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of dispatches.
    pub count: u64,
    /// Accumulated wall time in nanoseconds.
    pub total_wall_time_ns: u64,
}

/// Number of fallbacks recorded for a given reason.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderFallbackStat {
    pub reason: String,
    pub count: u64,
}

/// Aggregated provider telemetry; the derived default is all-zero counters and empty
/// lists.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    pub linsolve: ProviderDispatchStats,
    pub mldivide: ProviderDispatchStats,
    pub mrdivide: ProviderDispatchStats,
    pub upload_bytes: u64,
    pub download_bytes: u64,
    pub solve_fallbacks: Vec<ProviderFallbackStat>,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Optional per-layout breakdown of the bind-group cache counters.
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    /// Recorded kernel launches.
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}

/// Bind-group cache hit/miss counters for one layout tag.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}

/// Named integer attribute attached to a kernel launch record.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}

/// Record of one kernel launch: kernel name, optional precision label, and attribute
/// lists for shape and tuning.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    pub kernel: String,
    pub precision: Option<String>,
    pub shape: Vec<KernelAttrTelemetry>,
    pub tuning: Vec<KernelAttrTelemetry>,
}
1207
/// Boxed, pinned future resolving to `anyhow::Result<T>`; the trait object carries no
/// `Send` bound.
pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
/// Future returned by `AccelProvider::download`, resolving to an owned host tensor.
pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
1210
1211fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
1212 Box::pin(async move { Err(anyhow::anyhow!(message)) })
1213}
1214
1215pub trait AccelProvider: Send + Sync {
    /// Uploads the tensor described by `host` and returns a handle to the result.
    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
    /// Starts downloading `h`; the returned future resolves to an owned host tensor.
    fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
    /// Releases `h`. The default `fill`/`fill_like` implementations call this on their
    /// temporaries.
    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
    /// Human-readable device description; the default `device_info_struct` uses it as
    /// the device name.
    fn device_info(&self) -> String;
    /// Identifier of the device backing this provider; defaults to `0`.
    fn device_id(&self) -> u32 {
        0
    }

    /// Policy for using this provider's handles from spawned work; defaults to
    /// `SpawnHandleConcurrency::Reject`.
    fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
        SpawnHandleConcurrency::Reject
    }

    /// Exports a context of the requested kind; the default exports nothing.
    fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
        None
    }
1239
1240 #[cfg(feature = "wgpu")]
1242 fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
1243 let _ = _handle;
1244 None
1245 }
1246
1247 fn gather_linear(
1250 &self,
1251 _source: &GpuTensorHandle,
1252 _indices: &[u32],
1253 _output_shape: &[usize],
1254 ) -> anyhow::Result<GpuTensorHandle> {
1255 Err(anyhow::anyhow!("gather_linear not supported by provider"))
1256 }
1257
1258 fn scatter_linear(
1262 &self,
1263 _target: &GpuTensorHandle,
1264 _indices: &[u32],
1265 _values: &GpuTensorHandle,
1266 ) -> anyhow::Result<()> {
1267 Err(anyhow::anyhow!("scatter_linear not supported by provider"))
1268 }
1269
1270 fn device_info_struct(&self) -> ApiDeviceInfo {
1272 ApiDeviceInfo {
1273 device_id: 0,
1274 name: self.device_info(),
1275 vendor: String::new(),
1276 memory_bytes: None,
1277 backend: None,
1278 }
1279 }
1280
    /// Numeric precision used by this provider; defaults to `ProviderPrecision::F64`.
    fn precision(&self) -> ProviderPrecision {
        ProviderPrecision::F64
    }
1284
1285 fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
1287 Err(anyhow::anyhow!("read_scalar not supported by provider"))
1288 }
1289
1290 fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1292 Err(anyhow::anyhow!("zeros not supported by provider"))
1293 }
1294
1295 fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1297 Err(anyhow::anyhow!("ones not supported by provider"))
1298 }
1299
1300 fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1302 self.zeros(&prototype.shape)
1303 }
1304
1305 fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
1307 if value == 0.0 {
1308 return self.zeros(shape);
1309 }
1310 if let Ok(base) = self.zeros(shape) {
1311 match self.scalar_add(&base, value) {
1312 Ok(out) => {
1313 let _ = self.free(&base);
1314 return Ok(out);
1315 }
1316 Err(_) => {
1317 let _ = self.free(&base);
1318 }
1319 }
1320 }
1321 let len: usize = shape.iter().copied().product();
1322 let data = vec![value; len];
1323 let view = HostTensorView { data: &data, shape };
1324 self.upload(&view)
1325 }
1326
1327 fn fill_like(
1329 &self,
1330 prototype: &GpuTensorHandle,
1331 value: f64,
1332 ) -> anyhow::Result<GpuTensorHandle> {
1333 if value == 0.0 {
1334 return self.zeros_like(prototype);
1335 }
1336 if let Ok(base) = self.zeros_like(prototype) {
1337 match self.scalar_add(&base, value) {
1338 Ok(out) => {
1339 let _ = self.free(&base);
1340 return Ok(out);
1341 }
1342 Err(_) => {
1343 let _ = self.free(&base);
1344 }
1345 }
1346 }
1347 self.fill(&prototype.shape, value)
1348 }
1349
1350 fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1352 self.ones(&prototype.shape)
1353 }
1354
1355 fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1357 Err(anyhow::anyhow!("eye not supported by provider"))
1358 }
1359
1360 fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1362 self.eye(&prototype.shape)
1363 }
1364
1365 fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1367 Err(anyhow::anyhow!("meshgrid not supported by provider"))
1368 }
1369
1370 fn diag_from_vector(
1372 &self,
1373 _vector: &GpuTensorHandle,
1374 _offset: isize,
1375 ) -> anyhow::Result<GpuTensorHandle> {
1376 Err(anyhow::anyhow!(
1377 "diag_from_vector not supported by provider"
1378 ))
1379 }
1380
1381 fn diag_extract(
1383 &self,
1384 _matrix: &GpuTensorHandle,
1385 _offset: isize,
1386 ) -> anyhow::Result<GpuTensorHandle> {
1387 Err(anyhow::anyhow!("diag_extract not supported by provider"))
1388 }
1389
1390 fn tril<'a>(
1392 &'a self,
1393 _matrix: &'a GpuTensorHandle,
1394 _offset: isize,
1395 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1396 Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1397 }
1398
1399 fn triu<'a>(
1401 &'a self,
1402 _matrix: &'a GpuTensorHandle,
1403 _offset: isize,
1404 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1405 Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1406 }
1407
1408 fn polyval(
1410 &self,
1411 _coefficients: &GpuTensorHandle,
1412 _points: &GpuTensorHandle,
1413 _options: &ProviderPolyvalOptions,
1414 ) -> anyhow::Result<GpuTensorHandle> {
1415 Err(anyhow::anyhow!("polyval not supported by provider"))
1416 }
1417
1418 fn polyfit<'a>(
1420 &'a self,
1421 _x: &'a GpuTensorHandle,
1422 _y: &'a GpuTensorHandle,
1423 _degree: usize,
1424 _weights: Option<&'a GpuTensorHandle>,
1425 ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1426 Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1427 }
1428
1429 fn polyder_single<'a>(
1431 &'a self,
1432 _polynomial: &'a GpuTensorHandle,
1433 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1434 Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1435 }
1436
1437 fn polyder_product<'a>(
1439 &'a self,
1440 _p: &'a GpuTensorHandle,
1441 _q: &'a GpuTensorHandle,
1442 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1443 Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1444 }
1445
1446 fn polyder_quotient<'a>(
1448 &'a self,
1449 _u: &'a GpuTensorHandle,
1450 _v: &'a GpuTensorHandle,
1451 ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1452 Box::pin(async move {
1453 Err(anyhow::anyhow!(
1454 "polyder_quotient not supported by provider"
1455 ))
1456 })
1457 }
1458
1459 fn polyint(
1461 &self,
1462 _polynomial: &GpuTensorHandle,
1463 _constant: f64,
1464 ) -> anyhow::Result<GpuTensorHandle> {
1465 Err(anyhow::anyhow!("polyint not supported by provider"))
1466 }
1467
1468 fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1470 Err(anyhow::anyhow!("random_uniform not supported by provider"))
1471 }
1472
1473 fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1475 self.random_uniform(&prototype.shape)
1476 }
1477
1478 fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1480 Err(anyhow::anyhow!("random_normal not supported by provider"))
1481 }
1482
1483 fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1485 self.random_normal(&prototype.shape)
1486 }
1487
1488 fn random_exponential(&self, _mu: f64, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1490 Err(anyhow::anyhow!(
1491 "random_exponential not supported by provider"
1492 ))
1493 }
1494
1495 fn random_normrnd(
1497 &self,
1498 _mu: f64,
1499 _sigma: f64,
1500 _shape: &[usize],
1501 ) -> anyhow::Result<GpuTensorHandle> {
1502 Err(anyhow::anyhow!("random_normrnd not supported by provider"))
1503 }
1504
1505 fn random_unifrnd(
1507 &self,
1508 _a: f64,
1509 _b: f64,
1510 _shape: &[usize],
1511 ) -> anyhow::Result<GpuTensorHandle> {
1512 Err(anyhow::anyhow!("random_unifrnd not supported by provider"))
1513 }
1514
1515 fn stochastic_evolution(
1516 &self,
1517 _state: &GpuTensorHandle,
1518 _drift: f64,
1519 _scale: f64,
1520 _steps: u32,
1521 ) -> anyhow::Result<GpuTensorHandle> {
1522 Err(anyhow::anyhow!(
1523 "stochastic_evolution not supported by provider"
1524 ))
1525 }
1526
1527 fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1529 Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1530 }
1531
1532 fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1534 Err(anyhow::anyhow!("fspecial not supported by provider"))
1535 }
1536
1537 fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
1540 Err(anyhow::anyhow!("peaks not supported by provider"))
1541 }
1542
1543 fn peaks_xy(
1546 &self,
1547 _x: &GpuTensorHandle,
1548 _y: &GpuTensorHandle,
1549 ) -> anyhow::Result<GpuTensorHandle> {
1550 Err(anyhow::anyhow!("peaks_xy not supported by provider"))
1551 }
1552
1553 fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1554 Err(anyhow::anyhow!("hann_window not supported by provider"))
1555 }
1556
1557 fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1558 Err(anyhow::anyhow!("hamming_window not supported by provider"))
1559 }
1560
1561 fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1562 Err(anyhow::anyhow!("blackman_window not supported by provider"))
1563 }
1564
1565 fn imfilter<'a>(
1567 &'a self,
1568 _image: &'a GpuTensorHandle,
1569 _kernel: &'a GpuTensorHandle,
1570 _options: &'a ImfilterOptions,
1571 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1572 unsupported_future("imfilter not supported by provider")
1573 }
1574
1575 fn random_integer_range(
1577 &self,
1578 _lower: i64,
1579 _upper: i64,
1580 _shape: &[usize],
1581 ) -> anyhow::Result<GpuTensorHandle> {
1582 Err(anyhow::anyhow!(
1583 "random_integer_range not supported by provider"
1584 ))
1585 }
1586
1587 fn random_integer_like(
1589 &self,
1590 prototype: &GpuTensorHandle,
1591 lower: i64,
1592 upper: i64,
1593 ) -> anyhow::Result<GpuTensorHandle> {
1594 self.random_integer_range(lower, upper, &prototype.shape)
1595 }
1596
1597 fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1599 Err(anyhow!("random_permutation not supported by provider"))
1600 }
1601
1602 fn random_permutation_like(
1604 &self,
1605 _prototype: &GpuTensorHandle,
1606 n: usize,
1607 k: usize,
1608 ) -> anyhow::Result<GpuTensorHandle> {
1609 self.random_permutation(n, k)
1610 }
1611
1612 fn covariance<'a>(
1614 &'a self,
1615 _matrix: &'a GpuTensorHandle,
1616 _second: Option<&'a GpuTensorHandle>,
1617 _weights: Option<&'a GpuTensorHandle>,
1618 _options: &'a CovarianceOptions,
1619 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1620 unsupported_future("covariance not supported by provider")
1621 }
1622
1623 fn corrcoef<'a>(
1625 &'a self,
1626 _matrix: &'a GpuTensorHandle,
1627 _options: &'a CorrcoefOptions,
1628 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1629 unsupported_future("corrcoef not supported by provider")
1630 }
1631
1632 fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1634 Err(anyhow::anyhow!("linspace not supported by provider"))
1635 }
1636 fn elem_add<'a>(
1637 &'a self,
1638 _a: &'a GpuTensorHandle,
1639 _b: &'a GpuTensorHandle,
1640 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1641 unsupported_future("elem_add not supported by provider")
1642 }
1643 fn elem_mul<'a>(
1644 &'a self,
1645 _a: &'a GpuTensorHandle,
1646 _b: &'a GpuTensorHandle,
1647 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1648 unsupported_future("elem_mul not supported by provider")
1649 }
1650 fn elem_max<'a>(
1651 &'a self,
1652 _a: &'a GpuTensorHandle,
1653 _b: &'a GpuTensorHandle,
1654 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1655 unsupported_future("elem_max not supported by provider")
1656 }
1657 fn elem_min<'a>(
1658 &'a self,
1659 _a: &'a GpuTensorHandle,
1660 _b: &'a GpuTensorHandle,
1661 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1662 unsupported_future("elem_min not supported by provider")
1663 }
1664 fn elem_sub<'a>(
1665 &'a self,
1666 _a: &'a GpuTensorHandle,
1667 _b: &'a GpuTensorHandle,
1668 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1669 unsupported_future("elem_sub not supported by provider")
1670 }
1671 fn elem_div<'a>(
1672 &'a self,
1673 _a: &'a GpuTensorHandle,
1674 _b: &'a GpuTensorHandle,
1675 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1676 unsupported_future("elem_div not supported by provider")
1677 }
1678 fn elem_pow<'a>(
1679 &'a self,
1680 _a: &'a GpuTensorHandle,
1681 _b: &'a GpuTensorHandle,
1682 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1683 unsupported_future("elem_pow not supported by provider")
1684 }
1685
1686 fn complex_from_real<'a>(
1689 &'a self,
1690 _real: &'a GpuTensorHandle,
1691 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1692 unsupported_future("complex_from_real not supported by provider")
1693 }
1694
1695 fn complex_from_real_imag<'a>(
1700 &'a self,
1701 _real: &'a GpuTensorHandle,
1702 _imag: &'a GpuTensorHandle,
1703 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1704 unsupported_future("complex_from_real_imag not supported by provider")
1705 }
1706
1707 fn modulate_constellation<'a>(
1710 &'a self,
1711 _request: ProviderModulationRequest<'a>,
1712 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1713 unsupported_future("modulate_constellation not supported by provider")
1714 }
1715
1716 fn modulate_bits_constellation<'a>(
1719 &'a self,
1720 _request: ProviderBitModulationRequest<'a>,
1721 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1722 unsupported_future("modulate_bits_constellation not supported by provider")
1723 }
1724
1725 fn elem_hypot<'a>(
1726 &'a self,
1727 _a: &'a GpuTensorHandle,
1728 _b: &'a GpuTensorHandle,
1729 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1730 unsupported_future("elem_hypot not supported by provider")
1731 }
1732 fn elem_ge<'a>(
1733 &'a self,
1734 _a: &'a GpuTensorHandle,
1735 _b: &'a GpuTensorHandle,
1736 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1737 unsupported_future("elem_ge not supported by provider")
1738 }
1739 fn elem_le<'a>(
1740 &'a self,
1741 _a: &'a GpuTensorHandle,
1742 _b: &'a GpuTensorHandle,
1743 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1744 unsupported_future("elem_le not supported by provider")
1745 }
1746 fn elem_lt<'a>(
1747 &'a self,
1748 _a: &'a GpuTensorHandle,
1749 _b: &'a GpuTensorHandle,
1750 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1751 unsupported_future("elem_lt not supported by provider")
1752 }
1753 fn elem_gt<'a>(
1754 &'a self,
1755 _a: &'a GpuTensorHandle,
1756 _b: &'a GpuTensorHandle,
1757 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1758 unsupported_future("elem_gt not supported by provider")
1759 }
1760 fn elem_eq<'a>(
1761 &'a self,
1762 _a: &'a GpuTensorHandle,
1763 _b: &'a GpuTensorHandle,
1764 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1765 unsupported_future("elem_eq not supported by provider")
1766 }
1767 fn elem_ne<'a>(
1768 &'a self,
1769 _a: &'a GpuTensorHandle,
1770 _b: &'a GpuTensorHandle,
1771 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1772 unsupported_future("elem_ne not supported by provider")
1773 }
1774 fn logical_and(
1775 &self,
1776 _a: &GpuTensorHandle,
1777 _b: &GpuTensorHandle,
1778 ) -> anyhow::Result<GpuTensorHandle> {
1779 Err(anyhow::anyhow!("logical_and not supported by provider"))
1780 }
1781 fn logical_or(
1782 &self,
1783 _a: &GpuTensorHandle,
1784 _b: &GpuTensorHandle,
1785 ) -> anyhow::Result<GpuTensorHandle> {
1786 Err(anyhow::anyhow!("logical_or not supported by provider"))
1787 }
1788 fn logical_xor(
1789 &self,
1790 _a: &GpuTensorHandle,
1791 _b: &GpuTensorHandle,
1792 ) -> anyhow::Result<GpuTensorHandle> {
1793 Err(anyhow::anyhow!("logical_xor not supported by provider"))
1794 }
1795 fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1796 Err(anyhow::anyhow!("logical_not not supported by provider"))
1797 }
1798 fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
1799 Ok(handle_is_logical(a))
1800 }
1801 fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1802 Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1803 }
1804 fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1805 Err(anyhow::anyhow!(
1806 "logical_isfinite not supported by provider"
1807 ))
1808 }
1809 fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1810 Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1811 }
1812 fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1813 Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1814 }
1815 fn elem_atan2<'a>(
1816 &'a self,
1817 _y: &'a GpuTensorHandle,
1818 _x: &'a GpuTensorHandle,
1819 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1820 unsupported_future("elem_atan2 not supported by provider")
1821 }
1822 fn unary_sin<'a>(
1824 &'a self,
1825 _a: &'a GpuTensorHandle,
1826 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1827 unsupported_future("unary_sin not supported by provider")
1828 }
1829 fn unary_sinc<'a>(
1830 &'a self,
1831 _a: &'a GpuTensorHandle,
1832 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1833 unsupported_future("unary_sinc not supported by provider")
1834 }
1835 fn unary_gamma<'a>(
1836 &'a self,
1837 _a: &'a GpuTensorHandle,
1838 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1839 unsupported_future("unary_gamma not supported by provider")
1840 }
1841 fn unary_factorial<'a>(
1842 &'a self,
1843 _a: &'a GpuTensorHandle,
1844 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1845 unsupported_future("unary_factorial not supported by provider")
1846 }
1847 fn unary_asinh<'a>(
1848 &'a self,
1849 _a: &'a GpuTensorHandle,
1850 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1851 unsupported_future("unary_asinh not supported by provider")
1852 }
1853 fn unary_sinh<'a>(
1854 &'a self,
1855 _a: &'a GpuTensorHandle,
1856 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1857 unsupported_future("unary_sinh not supported by provider")
1858 }
1859 fn unary_cosh<'a>(
1860 &'a self,
1861 _a: &'a GpuTensorHandle,
1862 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1863 unsupported_future("unary_cosh not supported by provider")
1864 }
1865 fn unary_asin<'a>(
1866 &'a self,
1867 _a: &'a GpuTensorHandle,
1868 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1869 unsupported_future("unary_asin not supported by provider")
1870 }
1871 fn unary_acos<'a>(
1872 &'a self,
1873 _a: &'a GpuTensorHandle,
1874 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1875 unsupported_future("unary_acos not supported by provider")
1876 }
1877 fn unary_acosh<'a>(
1878 &'a self,
1879 _a: &'a GpuTensorHandle,
1880 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1881 unsupported_future("unary_acosh not supported by provider")
1882 }
1883 fn unary_tan<'a>(
1884 &'a self,
1885 _a: &'a GpuTensorHandle,
1886 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1887 unsupported_future("unary_tan not supported by provider")
1888 }
1889 fn unary_tanh<'a>(
1890 &'a self,
1891 _a: &'a GpuTensorHandle,
1892 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1893 unsupported_future("unary_tanh not supported by provider")
1894 }
1895 fn unary_atan<'a>(
1896 &'a self,
1897 _a: &'a GpuTensorHandle,
1898 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1899 unsupported_future("unary_atan not supported by provider")
1900 }
1901 fn unary_atanh<'a>(
1902 &'a self,
1903 _a: &'a GpuTensorHandle,
1904 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1905 unsupported_future("unary_atanh not supported by provider")
1906 }
1907 fn unary_ceil<'a>(
1908 &'a self,
1909 _a: &'a GpuTensorHandle,
1910 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1911 unsupported_future("unary_ceil not supported by provider")
1912 }
1913 fn unary_floor<'a>(
1914 &'a self,
1915 _a: &'a GpuTensorHandle,
1916 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1917 unsupported_future("unary_floor not supported by provider")
1918 }
1919 fn unary_round<'a>(
1920 &'a self,
1921 _a: &'a GpuTensorHandle,
1922 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1923 unsupported_future("unary_round not supported by provider")
1924 }
1925 fn unary_fix<'a>(
1926 &'a self,
1927 _a: &'a GpuTensorHandle,
1928 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1929 unsupported_future("unary_fix not supported by provider")
1930 }
1931 fn unary_cos<'a>(
1932 &'a self,
1933 _a: &'a GpuTensorHandle,
1934 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1935 unsupported_future("unary_cos not supported by provider")
1936 }
1937 fn unary_angle<'a>(
1938 &'a self,
1939 _a: &'a GpuTensorHandle,
1940 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1941 unsupported_future("unary_angle not supported by provider")
1942 }
1943 fn unary_imag<'a>(
1944 &'a self,
1945 _a: &'a GpuTensorHandle,
1946 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1947 unsupported_future("unary_imag not supported by provider")
1948 }
1949 fn unary_real<'a>(
1950 &'a self,
1951 _a: &'a GpuTensorHandle,
1952 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1953 unsupported_future("unary_real not supported by provider")
1954 }
1955 fn unary_conj<'a>(
1956 &'a self,
1957 _a: &'a GpuTensorHandle,
1958 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1959 unsupported_future("unary_conj not supported by provider")
1960 }
1961 fn unary_abs<'a>(
1962 &'a self,
1963 _a: &'a GpuTensorHandle,
1964 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1965 unsupported_future("unary_abs not supported by provider")
1966 }
1967 fn unary_sign<'a>(
1968 &'a self,
1969 _a: &'a GpuTensorHandle,
1970 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1971 unsupported_future("unary_sign not supported by provider")
1972 }
1973 fn unary_heaviside<'a>(
1974 &'a self,
1975 _a: &'a GpuTensorHandle,
1976 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1977 unsupported_future("unary_heaviside not supported by provider")
1978 }
1979 fn unary_exp<'a>(
1980 &'a self,
1981 _a: &'a GpuTensorHandle,
1982 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1983 unsupported_future("unary_exp not supported by provider")
1984 }
1985 fn unary_expm1<'a>(
1986 &'a self,
1987 _a: &'a GpuTensorHandle,
1988 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1989 unsupported_future("unary_expm1 not supported by provider")
1990 }
1991 fn unary_log<'a>(
1992 &'a self,
1993 _a: &'a GpuTensorHandle,
1994 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1995 unsupported_future("unary_log not supported by provider")
1996 }
1997 fn unary_log2<'a>(
1998 &'a self,
1999 _a: &'a GpuTensorHandle,
2000 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2001 unsupported_future("unary_log2 not supported by provider")
2002 }
2003 fn unary_log10<'a>(
2004 &'a self,
2005 _a: &'a GpuTensorHandle,
2006 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2007 unsupported_future("unary_log10 not supported by provider")
2008 }
2009 fn unary_log1p<'a>(
2010 &'a self,
2011 _a: &'a GpuTensorHandle,
2012 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2013 unsupported_future("unary_log1p not supported by provider")
2014 }
2015 fn unary_sqrt<'a>(
2016 &'a self,
2017 _a: &'a GpuTensorHandle,
2018 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2019 unsupported_future("unary_sqrt not supported by provider")
2020 }
2021 fn unary_double<'a>(
2022 &'a self,
2023 _a: &'a GpuTensorHandle,
2024 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2025 unsupported_future("unary_double not supported by provider")
2026 }
2027 fn unary_single<'a>(
2028 &'a self,
2029 _a: &'a GpuTensorHandle,
2030 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2031 unsupported_future("unary_single not supported by provider")
2032 }
2033 fn unary_pow2<'a>(
2034 &'a self,
2035 _a: &'a GpuTensorHandle,
2036 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2037 unsupported_future("unary_pow2 not supported by provider")
2038 }
2039 fn unary_nextpow2<'a>(
2040 &'a self,
2041 _a: &'a GpuTensorHandle,
2042 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2043 unsupported_future("unary_nextpow2 not supported by provider")
2044 }
2045 fn pow2_scale(
2046 &self,
2047 _mantissa: &GpuTensorHandle,
2048 _exponent: &GpuTensorHandle,
2049 ) -> anyhow::Result<GpuTensorHandle> {
2050 Err(anyhow::anyhow!("pow2_scale not supported by provider"))
2051 }
2052 fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2054 Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
2055 }
2056 fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2057 Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
2058 }
2059 fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2061 Err(anyhow::anyhow!("scalar_add not supported by provider"))
2062 }
2063 fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2064 Err(anyhow::anyhow!("scalar_sub not supported by provider"))
2065 }
2066 fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2067 Err(anyhow::anyhow!("scalar_mul not supported by provider"))
2068 }
2069 fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2070 Err(anyhow::anyhow!("scalar_max not supported by provider"))
2071 }
2072 fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2073 Err(anyhow::anyhow!("scalar_min not supported by provider"))
2074 }
2075 fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
2076 Err(anyhow::anyhow!("scalar_div not supported by provider"))
2077 }
2078 fn sort_dim<'a>(
2079 &'a self,
2080 _a: &'a GpuTensorHandle,
2081 _dim: usize,
2082 _order: SortOrder,
2083 _comparison: SortComparison,
2084 ) -> AccelProviderFuture<'a, SortResult> {
2085 unsupported_future("sort_dim not supported by provider")
2086 }
2087 fn sort_rows<'a>(
2088 &'a self,
2089 _a: &'a GpuTensorHandle,
2090 _columns: &'a [SortRowsColumnSpec],
2091 _comparison: SortComparison,
2092 ) -> AccelProviderFuture<'a, SortResult> {
2093 unsupported_future("sort_rows not supported by provider")
2094 }
2095 fn matmul<'a>(
2096 &'a self,
2097 _a: &'a GpuTensorHandle,
2098 _b: &'a GpuTensorHandle,
2099 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2100 unsupported_future("matmul not supported by provider")
2101 }
2102
2103 fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2104 Err(anyhow::anyhow!("syrk not supported by provider"))
2105 }
2106 fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
2107 Err(anyhow::anyhow!("pagefun not supported by provider"))
2108 }
2109
2110 fn matmul_epilogue<'a>(
2115 &'a self,
2116 a: &'a GpuTensorHandle,
2117 b: &'a GpuTensorHandle,
2118 epilogue: &'a MatmulEpilogue,
2119 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2120 Box::pin(async move {
2121 if epilogue.is_noop() {
2122 return self.matmul(a, b).await;
2123 }
2124 Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
2125 })
2126 }
2127 fn image_normalize<'a>(
2128 &'a self,
2129 _input: &'a GpuTensorHandle,
2130 _desc: &'a ImageNormalizeDescriptor,
2131 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2132 unsupported_future("image_normalize fusion not supported by provider")
2133 }
2134 fn matmul_power_step<'a>(
2135 &'a self,
2136 _lhs: &'a GpuTensorHandle,
2137 _rhs: &'a GpuTensorHandle,
2138 _epilogue: &'a PowerStepEpilogue,
2139 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2140 unsupported_future("matmul_power_step normalization not supported by provider")
2141 }
2142 fn linsolve<'a>(
2143 &'a self,
2144 _lhs: &'a GpuTensorHandle,
2145 _rhs: &'a GpuTensorHandle,
2146 _options: &'a ProviderLinsolveOptions,
2147 ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
2148 unsupported_future("linsolve not supported by provider")
2149 }
2150 fn inv<'a>(
2151 &'a self,
2152 _matrix: &'a GpuTensorHandle,
2153 _options: ProviderInvOptions,
2154 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2155 unsupported_future("inv not supported by provider")
2156 }
2157 fn pinv<'a>(
2158 &'a self,
2159 _matrix: &'a GpuTensorHandle,
2160 _options: ProviderPinvOptions,
2161 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2162 unsupported_future("pinv not supported by provider")
2163 }
2164 fn cond<'a>(
2165 &'a self,
2166 _matrix: &'a GpuTensorHandle,
2167 _norm: ProviderCondNorm,
2168 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2169 Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
2170 }
2171 fn norm<'a>(
2172 &'a self,
2173 _tensor: &'a GpuTensorHandle,
2174 _order: ProviderNormOrder,
2175 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2176 Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
2177 }
2178 fn rank<'a>(
2179 &'a self,
2180 _matrix: &'a GpuTensorHandle,
2181 _tolerance: Option<f64>,
2182 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2183 Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
2184 }
2185 fn rcond<'a>(
2186 &'a self,
2187 _matrix: &'a GpuTensorHandle,
2188 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2189 Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
2190 }
2191 fn mldivide<'a>(
2192 &'a self,
2193 _lhs: &'a GpuTensorHandle,
2194 _rhs: &'a GpuTensorHandle,
2195 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2196 Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
2197 }
2198 fn mrdivide<'a>(
2199 &'a self,
2200 _lhs: &'a GpuTensorHandle,
2201 _rhs: &'a GpuTensorHandle,
2202 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2203 Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
2204 }
2205 fn eig<'a>(
2206 &'a self,
2207 _a: &'a GpuTensorHandle,
2208 _compute_left: bool,
2209 ) -> AccelProviderFuture<'a, ProviderEigResult> {
2210 Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
2211 }
2212 fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
2213 Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
2214 }
2215
2216 fn chol<'a>(
2217 &'a self,
2218 _a: &'a GpuTensorHandle,
2219 _lower: bool,
2220 ) -> AccelProviderFuture<'a, ProviderCholResult> {
2221 Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
2222 }
2223 fn qr<'a>(
2224 &'a self,
2225 _a: &'a GpuTensorHandle,
2226 _options: ProviderQrOptions,
2227 ) -> AccelProviderFuture<'a, ProviderQrResult> {
2228 Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
2229 }
2230 fn take_matmul_sources(
2231 &self,
2232 _product: &GpuTensorHandle,
2233 ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
2234 None
2235 }
2236 fn qr_power_iter<'a>(
2237 &'a self,
2238 product: &'a GpuTensorHandle,
2239 _product_lhs: Option<&'a GpuTensorHandle>,
2240 q_handle: &'a GpuTensorHandle,
2241 options: &'a ProviderQrOptions,
2242 ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
2243 let _ = (product, q_handle, options);
2244 Box::pin(async move { Ok(None) })
2245 }
2246 fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2247 Err(anyhow::anyhow!("transpose not supported by provider"))
2248 }
2249 fn conv1d(
2250 &self,
2251 _signal: &GpuTensorHandle,
2252 _kernel: &GpuTensorHandle,
2253 _options: ProviderConv1dOptions,
2254 ) -> anyhow::Result<GpuTensorHandle> {
2255 Err(anyhow::anyhow!("conv1d not supported by provider"))
2256 }
2257 fn conv2d(
2258 &self,
2259 _signal: &GpuTensorHandle,
2260 _kernel: &GpuTensorHandle,
2261 _mode: ProviderConvMode,
2262 ) -> anyhow::Result<GpuTensorHandle> {
2263 Err(anyhow::anyhow!("conv2d not supported by provider"))
2264 }
2265 fn iir_filter<'a>(
2266 &'a self,
2267 _b: &'a GpuTensorHandle,
2268 _a: &'a GpuTensorHandle,
2269 _x: &'a GpuTensorHandle,
2270 _options: ProviderIirFilterOptions,
2271 ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
2272 Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
2273 }
2274 fn uniform_spectral_estimate<'a>(
2275 &'a self,
2276 _request: &'a ProviderSpectralRequest<'a>,
2277 ) -> AccelProviderFuture<'a, ProviderSpectralResult> {
2278 unsupported_future("uniform_spectral_estimate not supported by provider")
2279 }
2280 fn signal_envelope<'a>(
2281 &'a self,
2282 _request: &'a ProviderEnvelopeRequest<'a>,
2283 ) -> AccelProviderFuture<'a, ProviderEnvelopeResult> {
2284 unsupported_future("signal_envelope not supported by provider")
2285 }
2286 fn signal_hilbert<'a>(
2287 &'a self,
2288 _request: &'a ProviderHilbertRequest<'a>,
2289 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2290 unsupported_future("signal_hilbert not supported by provider")
2291 }
2292 fn permute(
2294 &self,
2295 _handle: &GpuTensorHandle,
2296 _order: &[usize],
2297 ) -> anyhow::Result<GpuTensorHandle> {
2298 Err(anyhow::anyhow!("permute not supported by provider"))
2299 }
2300 fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
2301 Err(anyhow::anyhow!("flip not supported by provider"))
2302 }
2303 fn circshift(
2304 &self,
2305 _handle: &GpuTensorHandle,
2306 _shifts: &[isize],
2307 ) -> anyhow::Result<GpuTensorHandle> {
2308 Err(anyhow::anyhow!("circshift not supported by provider"))
2309 }
2310 fn diff_dim(
2311 &self,
2312 _handle: &GpuTensorHandle,
2313 _order: usize,
2314 _dim: usize,
2315 ) -> anyhow::Result<GpuTensorHandle> {
2316 Err(anyhow::anyhow!("diff_dim not supported by provider"))
2317 }
2318 fn gradient_dim(
2319 &self,
2320 _handle: &GpuTensorHandle,
2321 _dim: usize,
2322 _spacing: f64,
2323 ) -> anyhow::Result<GpuTensorHandle> {
2324 Err(anyhow::anyhow!("gradient_dim not supported by provider"))
2325 }
2326 fn fft_dim<'a>(
2328 &'a self,
2329 _handle: &'a GpuTensorHandle,
2330 _len: Option<usize>,
2331 _dim: usize,
2332 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2333 unsupported_future("fft_dim not supported by provider")
2334 }
2335 fn ifft_dim<'a>(
2336 &'a self,
2337 _handle: &'a GpuTensorHandle,
2338 _len: Option<usize>,
2339 _dim: usize,
2340 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2341 unsupported_future("ifft_dim not supported by provider")
2342 }
2343 fn fft_extract_real<'a>(
2344 &'a self,
2345 _handle: &'a GpuTensorHandle,
2346 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2347 unsupported_future("fft_extract_real not supported by provider")
2348 }
2349 fn unique<'a>(
2350 &'a self,
2351 _handle: &'a GpuTensorHandle,
2352 _options: &'a UniqueOptions,
2353 ) -> AccelProviderFuture<'a, UniqueResult> {
2354 Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
2355 }
2356 fn union<'a>(
2357 &'a self,
2358 _a: &'a GpuTensorHandle,
2359 _b: &'a GpuTensorHandle,
2360 _options: &'a UnionOptions,
2361 ) -> AccelProviderFuture<'a, UnionResult> {
2362 Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
2363 }
2364 fn setdiff<'a>(
2365 &'a self,
2366 _a: &'a GpuTensorHandle,
2367 _b: &'a GpuTensorHandle,
2368 _options: &'a SetdiffOptions,
2369 ) -> AccelProviderFuture<'a, SetdiffResult> {
2370 Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
2371 }
2372 fn ismember<'a>(
2373 &'a self,
2374 _a: &'a GpuTensorHandle,
2375 _b: &'a GpuTensorHandle,
2376 _options: &'a IsMemberOptions,
2377 ) -> AccelProviderFuture<'a, IsMemberResult> {
2378 Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
2379 }
2380 fn reshape(
2381 &self,
2382 handle: &GpuTensorHandle,
2383 new_shape: &[usize],
2384 ) -> anyhow::Result<GpuTensorHandle> {
2385 let mut updated = handle.clone();
2386 updated.shape = new_shape.to_vec();
2387 Ok(updated)
2388 }
2389 fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
2391 Err(anyhow::anyhow!("cat not supported by provider"))
2392 }
2393 fn repmat(
2394 &self,
2395 _handle: &GpuTensorHandle,
2396 _reps: &[usize],
2397 ) -> anyhow::Result<GpuTensorHandle> {
2398 Err(anyhow::anyhow!("repmat not supported by provider"))
2399 }
2400 fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2402 Err(anyhow::anyhow!("kron not supported by provider"))
2403 }
2404 fn cross(
2406 &self,
2407 _lhs: &GpuTensorHandle,
2408 _rhs: &GpuTensorHandle,
2409 _dim: Option<usize>,
2410 ) -> anyhow::Result<GpuTensorHandle> {
2411 Err(anyhow::anyhow!("cross not supported by provider"))
2412 }
2413 fn reduce_sum<'a>(
2414 &'a self,
2415 _a: &'a GpuTensorHandle,
2416 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2417 unsupported_future("reduce_sum not supported by provider")
2418 }
2419 fn reduce_sum_dim<'a>(
2420 &'a self,
2421 _a: &'a GpuTensorHandle,
2422 _dim: usize,
2423 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2424 unsupported_future("reduce_sum_dim not supported by provider")
2425 }
2426 fn dot<'a>(
2427 &'a self,
2428 _lhs: &'a GpuTensorHandle,
2429 _rhs: &'a GpuTensorHandle,
2430 _dim: Option<usize>,
2431 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2432 unsupported_future("dot not supported by provider")
2433 }
2434 fn reduce_nnz<'a>(
2435 &'a self,
2436 _a: &'a GpuTensorHandle,
2437 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2438 unsupported_future("reduce_nnz not supported by provider")
2439 }
2440 fn reduce_nnz_dim<'a>(
2441 &'a self,
2442 _a: &'a GpuTensorHandle,
2443 _dim: usize,
2444 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2445 unsupported_future("reduce_nnz_dim not supported by provider")
2446 }
2447 fn reduce_prod<'a>(
2448 &'a self,
2449 _a: &'a GpuTensorHandle,
2450 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2451 unsupported_future("reduce_prod not supported by provider")
2452 }
2453 fn reduce_prod_dim<'a>(
2454 &'a self,
2455 _a: &'a GpuTensorHandle,
2456 _dim: usize,
2457 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2458 unsupported_future("reduce_prod_dim not supported by provider")
2459 }
2460 fn reduce_mean<'a>(
2461 &'a self,
2462 _a: &'a GpuTensorHandle,
2463 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2464 unsupported_future("reduce_mean not supported by provider")
2465 }
2466 fn reduce_mean_nd<'a>(
2468 &'a self,
2469 _a: &'a GpuTensorHandle,
2470 _dims_zero_based: &'a [usize],
2471 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2472 unsupported_future("reduce_mean_nd not supported by provider")
2473 }
2474 fn reduce_moments_nd<'a>(
2477 &'a self,
2478 _a: &'a GpuTensorHandle,
2479 _dims_zero_based: &'a [usize],
2480 ) -> AccelProviderFuture<'a, ProviderMoments2> {
2481 unsupported_future("reduce_moments_nd not supported by provider")
2482 }
2483 fn reduce_mean_dim<'a>(
2484 &'a self,
2485 _a: &'a GpuTensorHandle,
2486 _dim: usize,
2487 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2488 unsupported_future("reduce_mean_dim not supported by provider")
2489 }
2490 fn reduce_std<'a>(
2491 &'a self,
2492 _a: &'a GpuTensorHandle,
2493 _normalization: ProviderStdNormalization,
2494 _nan_mode: ProviderNanMode,
2495 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2496 unsupported_future("reduce_std not supported by provider")
2497 }
2498 fn reduce_std_dim<'a>(
2499 &'a self,
2500 _a: &'a GpuTensorHandle,
2501 _dim: usize,
2502 _normalization: ProviderStdNormalization,
2503 _nan_mode: ProviderNanMode,
2504 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2505 unsupported_future("reduce_std_dim not supported by provider")
2506 }
2507 fn reduce_any<'a>(
2508 &'a self,
2509 _a: &'a GpuTensorHandle,
2510 _omit_nan: bool,
2511 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2512 unsupported_future("reduce_any not supported by provider")
2513 }
2514 fn reduce_any_dim<'a>(
2515 &'a self,
2516 _a: &'a GpuTensorHandle,
2517 _dim: usize,
2518 _omit_nan: bool,
2519 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2520 unsupported_future("reduce_any_dim not supported by provider")
2521 }
2522 fn reduce_all<'a>(
2523 &'a self,
2524 _a: &'a GpuTensorHandle,
2525 _omit_nan: bool,
2526 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2527 unsupported_future("reduce_all not supported by provider")
2528 }
2529 fn reduce_all_dim<'a>(
2530 &'a self,
2531 _a: &'a GpuTensorHandle,
2532 _dim: usize,
2533 _omit_nan: bool,
2534 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2535 unsupported_future("reduce_all_dim not supported by provider")
2536 }
2537 fn reduce_median<'a>(
2538 &'a self,
2539 _a: &'a GpuTensorHandle,
2540 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2541 unsupported_future("reduce_median not supported by provider")
2542 }
2543 fn reduce_median_dim<'a>(
2544 &'a self,
2545 _a: &'a GpuTensorHandle,
2546 _dim: usize,
2547 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2548 unsupported_future("reduce_median_dim not supported by provider")
2549 }
2550 fn reduce_min<'a>(
2551 &'a self,
2552 _a: &'a GpuTensorHandle,
2553 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2554 unsupported_future("reduce_min not supported by provider")
2555 }
2556 fn reduce_min_dim<'a>(
2557 &'a self,
2558 _a: &'a GpuTensorHandle,
2559 _dim: usize,
2560 ) -> AccelProviderFuture<'a, ReduceDimResult> {
2561 unsupported_future("reduce_min_dim not supported by provider")
2562 }
2563 fn reduce_max<'a>(
2564 &'a self,
2565 _a: &'a GpuTensorHandle,
2566 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2567 unsupported_future("reduce_max not supported by provider")
2568 }
2569 fn reduce_max_dim<'a>(
2570 &'a self,
2571 _a: &'a GpuTensorHandle,
2572 _dim: usize,
2573 ) -> AccelProviderFuture<'a, ReduceDimResult> {
2574 unsupported_future("reduce_max_dim not supported by provider")
2575 }
2576 fn cumsum_scan(
2577 &self,
2578 _input: &GpuTensorHandle,
2579 _dim: usize,
2580 _direction: ProviderScanDirection,
2581 _nan_mode: ProviderNanMode,
2582 ) -> anyhow::Result<GpuTensorHandle> {
2583 Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
2584 }
2585 fn cumprod_scan(
2586 &self,
2587 _input: &GpuTensorHandle,
2588 _dim: usize,
2589 _direction: ProviderScanDirection,
2590 _nan_mode: ProviderNanMode,
2591 ) -> anyhow::Result<GpuTensorHandle> {
2592 Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
2593 }
2594 fn cummin_scan(
2595 &self,
2596 _input: &GpuTensorHandle,
2597 _dim: usize,
2598 _direction: ProviderScanDirection,
2599 _nan_mode: ProviderNanMode,
2600 ) -> anyhow::Result<ProviderCumminResult> {
2601 Err(anyhow::anyhow!("cummin_scan not supported by provider"))
2602 }
2603 fn cummax_scan(
2604 &self,
2605 _input: &GpuTensorHandle,
2606 _dim: usize,
2607 _direction: ProviderScanDirection,
2608 _nan_mode: ProviderNanMode,
2609 ) -> anyhow::Result<ProviderCummaxResult> {
2610 Err(anyhow::anyhow!("cummax_scan not supported by provider"))
2611 }
2612
2613 fn find(
2614 &self,
2615 _a: &GpuTensorHandle,
2616 _limit: Option<usize>,
2617 _direction: FindDirection,
2618 ) -> anyhow::Result<ProviderFindResult> {
2619 Err(anyhow::anyhow!("find not supported by provider"))
2620 }
2621
2622 fn fused_elementwise(
2623 &self,
2624 _shader: &str,
2625 _inputs: &[GpuTensorHandle],
2626 _output_shape: &[usize],
2627 _len: usize,
2628 ) -> anyhow::Result<GpuTensorHandle> {
2629 Err(anyhow::anyhow!(
2630 "fused_elementwise not supported by provider"
2631 ))
2632 }
2633
2634 fn fused_elementwise_multi(
2643 &self,
2644 _shader: &str,
2645 _inputs: &[GpuTensorHandle],
2646 _output_shape: &[usize],
2647 _len: usize,
2648 _num_outputs: usize,
2649 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2650 Err(anyhow::anyhow!(
2651 "fused_elementwise_multi not supported by provider"
2652 ))
2653 }
2654
2655 fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2657 Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
2658 }
2659
2660 fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2662 Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
2663 }
2664
2665 #[allow(clippy::too_many_arguments)]
2672 fn fused_reduction(
2673 &self,
2674 _shader: &str,
2675 _inputs: &[GpuTensorHandle],
2676 _output_shape: &[usize],
2677 _reduce_len: usize,
2678 _num_slices: usize,
2679 _workgroup_size: u32,
2680 _flavor: ReductionFlavor,
2681 ) -> anyhow::Result<GpuTensorHandle> {
2682 Err(anyhow::anyhow!("fused_reduction not supported by provider"))
2683 }
2684
2685 fn warmup(&self) {}
2687
2688 fn fused_cache_counters(&self) -> (u64, u64) {
2690 (0, 0)
2691 }
2692
2693 fn last_warmup_millis(&self) -> Option<u64> {
2695 None
2696 }
2697
2698 fn telemetry_snapshot(&self) -> ProviderTelemetry {
2700 let (hits, misses) = self.fused_cache_counters();
2701 ProviderTelemetry {
2702 fused_elementwise: ProviderDispatchStats::default(),
2703 fused_reduction: ProviderDispatchStats::default(),
2704 matmul: ProviderDispatchStats::default(),
2705 linsolve: ProviderDispatchStats::default(),
2706 mldivide: ProviderDispatchStats::default(),
2707 mrdivide: ProviderDispatchStats::default(),
2708 upload_bytes: 0,
2709 download_bytes: 0,
2710 solve_fallbacks: Vec::new(),
2711 fusion_cache_hits: hits,
2712 fusion_cache_misses: misses,
2713 bind_group_cache_hits: 0,
2714 bind_group_cache_misses: 0,
2715 bind_group_cache_by_layout: None,
2716 kernel_launches: Vec::new(),
2717 }
2718 }
2719
2720 fn reset_telemetry(&self) {}
2722
2723 fn default_reduction_workgroup_size(&self) -> u32 {
2725 256
2726 }
2727
2728 fn two_pass_threshold(&self) -> usize {
2730 1024
2731 }
2732
2733 fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
2735 ReductionTwoPassMode::Auto
2736 }
2737
2738 fn scatter_column(
2741 &self,
2742 _matrix: &GpuTensorHandle,
2743 _col_index: usize,
2744 _values: &GpuTensorHandle,
2745 ) -> anyhow::Result<GpuTensorHandle> {
2746 Err(anyhow::anyhow!("scatter_column not supported by provider"))
2747 }
2748
2749 fn scatter_row(
2752 &self,
2753 _matrix: &GpuTensorHandle,
2754 _row_index: usize,
2755 _values: &GpuTensorHandle,
2756 ) -> anyhow::Result<GpuTensorHandle> {
2757 Err(anyhow::anyhow!("scatter_row not supported by provider"))
2758 }
2759
2760 fn sub2ind(
2761 &self,
2762 _dims: &[usize],
2763 _strides: &[usize],
2764 _inputs: &[&GpuTensorHandle],
2765 _scalar_mask: &[bool],
2766 _len: usize,
2767 _output_shape: &[usize],
2768 ) -> anyhow::Result<GpuTensorHandle> {
2769 Err(anyhow::anyhow!("sub2ind not supported by provider"))
2770 }
2771
2772 fn supports_ind2sub(&self) -> bool {
2774 false
2775 }
2776
2777 fn ind2sub(
2779 &self,
2780 _dims: &[usize],
2781 _strides: &[usize],
2782 _indices: &GpuTensorHandle,
2783 _total: usize,
2784 _len: usize,
2785 _output_shape: &[usize],
2786 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2787 Err(anyhow::anyhow!("ind2sub not supported by provider"))
2788 }
2789
2790 fn issymmetric(
2792 &self,
2793 _matrix: &GpuTensorHandle,
2794 _kind: ProviderSymmetryKind,
2795 _tolerance: f64,
2796 ) -> anyhow::Result<bool> {
2797 Err(anyhow::anyhow!(
2798 "issymmetric predicate not supported by provider"
2799 ))
2800 }
2801
2802 fn ishermitian<'a>(
2804 &'a self,
2805 _matrix: &'a GpuTensorHandle,
2806 _kind: ProviderHermitianKind,
2807 _tolerance: f64,
2808 ) -> AccelProviderFuture<'a, bool> {
2809 Box::pin(async move {
2810 Err(anyhow::anyhow!(
2811 "ishermitian predicate not supported by provider"
2812 ))
2813 })
2814 }
2815
2816 fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2818 Err(anyhow::anyhow!("bandwidth not supported by provider"))
2819 }
2820
2821 fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
2826 Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
2827 }
2828}
2829
2830static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
2831 Lazy::new(|| RwLock::new(None));
2832static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
2833 Lazy::new(|| RwLock::new(HashMap::new()));
2834static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2835
2836#[cfg(not(target_arch = "wasm32"))]
2837thread_local! {
2838 static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
2839}
2840
2841#[cfg(target_arch = "wasm32")]
2842static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
2843 Lazy::new(|| Mutex::new(None));
2844
2845#[cfg(not(target_arch = "wasm32"))]
2846fn replace_thread_provider(
2847 provider: Option<&'static dyn AccelProvider>,
2848) -> Option<&'static dyn AccelProvider> {
2849 THREAD_PROVIDER.with(|cell| {
2850 let prev = cell.get();
2851 cell.set(provider);
2852 prev
2853 })
2854}
2855
/// Installs `provider` in the shared wasm slot and returns the value that
/// was stored before.
///
/// # Panics
///
/// Panics if the slot's mutex is poisoned.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut current = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *current, provider)
}
2867
2868#[cfg(not(target_arch = "wasm32"))]
2869fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2870 THREAD_PROVIDER.with(|cell| cell.get())
2871}
2872
/// Returns the provider stored in the shared wasm slot, if any.
///
/// # Panics
///
/// Panics if the slot's mutex is poisoned.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    let current = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    *current
}
2881
2882pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
2890 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2891 *guard = Some(p);
2892 }
2893 register_provider_for_device(p.device_id(), p);
2894}
2895
2896unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
2897 if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
2898 guard.insert(device_id, provider);
2899 }
2900}
2901
2902pub fn provider() -> Option<&'static dyn AccelProvider> {
2903 if let Some(p) = current_thread_provider() {
2904 return Some(p);
2905 }
2906 GLOBAL_PROVIDER
2907 .read()
2908 .ok()
2909 .and_then(|guard| guard.as_ref().copied())
2910}
2911
2912pub fn clear_provider() {
2914 replace_thread_provider(None);
2915 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2916 *guard = None;
2917 }
2918 if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2919 map.clear();
2920 }
2921}
2922
2923pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2924 if let Some(registered) = PROVIDER_REGISTRY
2925 .read()
2926 .ok()
2927 .and_then(|guard| guard.get(&device_id).copied())
2928 {
2929 return Some(registered);
2930 }
2931 if let Some(thread_provider) = current_thread_provider() {
2932 if thread_provider.device_id() == device_id {
2933 return Some(thread_provider);
2934 }
2935 }
2936 GLOBAL_PROVIDER
2939 .read()
2940 .ok()
2941 .and_then(|guard| guard.as_ref().copied())
2942}
2943
2944pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
2945 provider_for_device(handle.device_id)
2946}
2947
2948pub fn spawn_handle_concurrency_for(handle: &GpuTensorHandle) -> Option<SpawnHandleConcurrency> {
2949 provider_for_handle(handle).map(AccelProvider::spawn_handle_concurrency)
2950}
2951
2952pub fn next_device_id() -> u32 {
2953 DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
2954}
2955
2956pub struct ThreadProviderGuard {
2957 prev: Option<&'static dyn AccelProvider>,
2958}
2959
2960impl ThreadProviderGuard {
2961 pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2962 let prev = replace_thread_provider(provider);
2963 ThreadProviderGuard { prev }
2964 }
2965}
2966
2967impl Drop for ThreadProviderGuard {
2968 fn drop(&mut self) {
2969 let prev = self.prev.take();
2970 replace_thread_provider(prev);
2971 }
2972}
2973
2974pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
2975 replace_thread_provider(provider);
2976}
2977
2978pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2980 if let Some(p) = provider() {
2981 if let Ok(h) = p.elem_add(a, b).await {
2982 return Some(h);
2983 }
2984 }
2985 None
2986}
2987
2988pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2990 if let Some(p) = provider() {
2991 if let Ok(h) = p.elem_hypot(a, b).await {
2992 return Some(h);
2993 }
2994 }
2995 None
2996}
2997
2998pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
3000 if let Some(p) = provider() {
3001 if let Ok(h) = p.elem_max(a, b).await {
3002 return Some(h);
3003 }
3004 }
3005 None
3006}
3007
3008pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
3010 if let Some(p) = provider() {
3011 if let Ok(h) = p.elem_min(a, b).await {
3012 return Some(h);
3013 }
3014 }
3015 None
3016}
3017
3018pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
3020 if let Some(p) = provider() {
3021 if let Ok(h) = p.elem_atan2(y, x).await {
3022 return Some(h);
3023 }
3024 }
3025 None
3026}
3027
3028#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
3030pub struct HostTensorOwned {
3031 pub data: Vec<f64>,
3032 pub shape: Vec<usize>,
3033 pub storage: GpuTensorStorage,
3034}
3035
/// Borrowed view of host tensor data: a flat `f64` slice plus its shape.
/// This is the argument type of `AccelProvider::upload`.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    /// Element values as a flat slice.
    pub data: &'a [f64],
    /// Extent of each dimension.
    pub shape: &'a [usize],
}
3041
/// Borrowed values of a single meshgrid axis.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    /// Values along this axis.
    pub data: &'a [f64],
}
3047
3048#[derive(Debug, Clone)]
3050pub struct ProviderMeshgridResult {
3051 pub outputs: Vec<GpuTensorHandle>,
3052}
3053
3054#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
3060pub enum ScaleOp {
3061 Multiply,
3062 Divide,
3063}
3064
3065#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
3066pub struct MatmulEpilogue {
3067 pub alpha: f64,
3069 pub beta: f64,
3071 pub row_scale: Option<GpuTensorHandle>,
3073 pub col_scale: Option<GpuTensorHandle>,
3075 pub row_op: ScaleOp,
3077 pub col_op: ScaleOp,
3079 #[serde(default)]
3081 pub clamp_min: Option<f64>,
3082 #[serde(default)]
3084 pub clamp_max: Option<f64>,
3085 #[serde(default)]
3087 pub pow_exponent: Option<f64>,
3088 #[serde(default)]
3090 pub diag_output: Option<GpuTensorHandle>,
3091}
3092
3093impl MatmulEpilogue {
3094 pub fn noop() -> Self {
3095 Self {
3096 alpha: 1.0,
3097 beta: 0.0,
3098 row_scale: None,
3099 col_scale: None,
3100 row_op: ScaleOp::Multiply,
3101 col_op: ScaleOp::Multiply,
3102 clamp_min: None,
3103 clamp_max: None,
3104 pow_exponent: None,
3105 diag_output: None,
3106 }
3107 }
3108 pub fn is_noop(&self) -> bool {
3109 self.alpha == 1.0
3110 && self.beta == 0.0
3111 && self.row_scale.is_none()
3112 && self.col_scale.is_none()
3113 && self.clamp_min.is_none()
3114 && self.clamp_max.is_none()
3115 && self.pow_exponent.is_none()
3116 && self.diag_output.is_none()
3117 }
3118}
3119
3120#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
3121pub struct PowerStepEpilogue {
3122 pub epsilon: f64,
3123}
3124
3125impl Default for PowerStepEpilogue {
3126 fn default() -> Self {
3127 Self { epsilon: 0.0 }
3128 }
3129}
3130
3131#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
3132pub struct ImageNormalizeDescriptor {
3133 pub batch: usize,
3134 pub height: usize,
3135 pub width: usize,
3136 pub epsilon: f64,
3137 #[serde(default)]
3138 pub gain: Option<f64>,
3139 #[serde(default)]
3140 pub bias: Option<f64>,
3141 #[serde(default)]
3142 pub gamma: Option<f64>,
3143 #[serde(default = "default_image_normalize_clamp_zero")]
3144 pub clamp_zero: bool,
3145}
3146
/// Serde default for `ImageNormalizeDescriptor::clamp_zero`: descriptors
/// serialized without the field deserialize as clamped.
fn default_image_normalize_clamp_zero() -> bool {
    const CLAMP_BY_DEFAULT: bool = true;
    CLAMP_BY_DEFAULT
}
3150
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider used to observe which provider a lookup resolves to.
    /// Only identity-related methods return data; the data-path methods fail
    /// because the tests never expect them to be called.
    struct TestProvider {
        device_id: u32,
        name: &'static str,
        spawn_concurrency: SpawnHandleConcurrency,
    }

    impl AccelProvider for TestProvider {
        fn upload(&self, _host: &HostTensorView) -> anyhow::Result<GpuTensorHandle> {
            Err(anyhow!("test provider upload should not be called"))
        }

        fn download<'a>(&'a self, _h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a> {
            unsupported_future("test provider download should not be called")
        }

        fn free(&self, _h: &GpuTensorHandle) -> anyhow::Result<()> {
            Err(anyhow!("test provider free should not be called"))
        }

        fn device_info(&self) -> String {
            self.name.to_string()
        }

        fn device_id(&self) -> u32 {
            self.device_id
        }

        fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
            self.spawn_concurrency
        }
    }

    // The global provider and the device registry are process-wide, so tests
    // that touch them are serialized through this lock.
    static PROVIDER_TEST_LOCK: Lazy<std::sync::Mutex<()>> = Lazy::new(|| std::sync::Mutex::new(()));
    static PROVIDER_A: TestProvider = TestProvider {
        device_id: 101,
        name: "provider-a",
        spawn_concurrency: SpawnHandleConcurrency::ImmutableShare,
    };
    static PROVIDER_B: TestProvider = TestProvider {
        device_id: 202,
        name: "provider-b",
        spawn_concurrency: SpawnHandleConcurrency::Reject,
    };
    static PROVIDER_C: TestProvider = TestProvider {
        device_id: 303,
        name: "provider-c",
        spawn_concurrency: SpawnHandleConcurrency::CopyOnWrite,
    };

    /// Acquires the provider test lock.
    ///
    /// A poisoned lock is recovered instead of panicking: the mutex guards no
    /// data, and every test resets the provider state before using it, so a
    /// failed assertion in one test must not cascade into unrelated failures
    /// in the others.
    fn lock_provider_tests() -> std::sync::MutexGuard<'static, ()> {
        PROVIDER_TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Resets all provider state, then registers A followed by B; B ends up
    /// as the global provider because it is registered last.
    fn register_test_providers() {
        clear_provider();
        unsafe {
            register_provider(&PROVIDER_A);
            register_provider(&PROVIDER_B);
        }
    }

    /// Builds a one-element handle that claims to live on `device_id`.
    fn test_handle(device_id: u32) -> GpuTensorHandle {
        GpuTensorHandle {
            shape: vec![1],
            device_id,
            buffer_id: 42,
        }
    }

    /// Builds a spectral request over `input` with fixed sizes (16 input
    /// samples, 4-sample window, nfft 8, 3 frames) and the given frame mode.
    fn spectral_request<'a>(
        input: &'a GpuTensorHandle,
        frame_mode: ProviderSpectralFrameMode,
    ) -> ProviderSpectralRequest<'a> {
        static WINDOW: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        ProviderSpectralRequest {
            input,
            input_len: 16,
            input_complex: false,
            window: &WINDOW,
            nfft: 8,
            frame_count: 3,
            frame_mode,
            range: ProviderSpectralRange::Onesided,
            denominator: 1.0,
        }
    }

    #[test]
    fn provider_envelope_shape_guard_rejects_equal_len_layout_spoofing() {
        assert!(provider_envelope_input_shape_matches(&[2, 3], 2, 3));
        assert!(provider_envelope_input_shape_matches(&[6, 1], 6, 1));
        assert!(provider_envelope_input_shape_matches(&[1, 6], 6, 1));
        assert!(provider_envelope_input_shape_matches(&[6], 6, 1));

        assert!(!provider_envelope_input_shape_matches(&[3, 2], 2, 3));
        assert!(!provider_envelope_input_shape_matches(&[6], 2, 3));
        assert!(!provider_envelope_input_shape_matches(&[2, 1, 3], 2, 3));
    }

    #[test]
    fn provider_for_device_prefers_registered_device_over_thread_provider() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider = provider_for_device(PROVIDER_A.device_id()).expect("provider for device");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn provider_for_handle_uses_handle_device_owner() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider =
            provider_for_handle(&test_handle(PROVIDER_A.device_id())).expect("provider for handle");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn spawn_handle_concurrency_for_uses_registered_owner() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let concurrency = spawn_handle_concurrency_for(&test_handle(PROVIDER_A.device_id()))
            .expect("spawn concurrency");

        assert_eq!(concurrency, PROVIDER_A.spawn_concurrency);
        clear_provider();
    }

    #[test]
    fn provider_keeps_thread_local_active_provider_semantics() {
        let _lock = lock_provider_tests();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_A));

        let active = provider().expect("active provider");

        assert_eq!(active.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn unregistered_thread_provider_only_matches_own_device_before_global_fallback() {
        let _lock = lock_provider_tests();
        clear_provider();
        unsafe {
            register_provider(&PROVIDER_A);
        }
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_C));

        let own_device = provider_for_device(PROVIDER_C.device_id()).expect("own provider");
        let fallback = provider_for_device(404).expect("global fallback provider");

        assert_eq!(own_device.device_info(), PROVIDER_C.name);
        assert_eq!(fallback.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn uniform_spectral_request_validates_sliding_input_coverage() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(&input, ProviderSpectralFrameMode::Sliding { hop: 6 });
        assert!(validate_uniform_spectral_request(&request).is_ok());

        request.input_len = 15;
        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_rejects_sliding_coverage_overflow() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(&input, ProviderSpectralFrameMode::Sliding { hop: 2 });
        request.frame_count = usize::MAX;

        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_validates_folded_input_coverage() {
        let input = test_handle(PROVIDER_A.device_id());
        let mut request = spectral_request(
            &input,
            ProviderSpectralFrameMode::FoldedColumns { input_rows: 5 },
        );
        assert!(validate_uniform_spectral_request(&request).is_ok());

        request.frame_mode = ProviderSpectralFrameMode::FoldedColumns { input_rows: 0 };
        assert!(validate_uniform_spectral_request(&request).is_err());

        request.frame_mode = ProviderSpectralFrameMode::FoldedColumns { input_rows: 6 };
        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn uniform_spectral_request_rejects_folded_coverage_overflow() {
        let input = test_handle(PROVIDER_A.device_id());
        let request = spectral_request(
            &input,
            ProviderSpectralFrameMode::FoldedColumns {
                input_rows: usize::MAX,
            },
        );

        assert!(validate_uniform_spectral_request(&request).is_err());
    }

    #[test]
    fn image_normalize_descriptor_omitted_clamp_zero_defaults_true() {
        let payload = r#"{
            "batch": 2,
            "height": 4,
            "width": 5,
            "epsilon": 0.000001
        }"#;

        let desc: ImageNormalizeDescriptor =
            serde_json::from_str(payload).expect("deserialize descriptor");

        assert!(
            desc.clamp_zero,
            "legacy serialized descriptors should default to clamped image normalize"
        );
    }

    #[test]
    fn image_normalize_descriptor_explicit_false_preserves_unclamped() {
        let payload = r#"{
            "batch": 2,
            "height": 4,
            "width": 5,
            "epsilon": 0.000001,
            "clamp_zero": false
        }"#;

        let desc: ImageNormalizeDescriptor =
            serde_json::from_str(payload).expect("deserialize descriptor");

        assert!(
            !desc.clamp_zero,
            "explicit clamp_zero=false should preserve unclamped semantics"
        );
    }
}