1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
// Callback signatures installable by the active backend (see the
// `register_*` functions below). All are plain `fn` pointers so they can be
// stored in `OnceCell` without allocation.
type ResidencyMarkFn = fn(&GpuTensorHandle);
type ResidencyClearFn = fn(&GpuTensorHandle);
type SequenceThresholdFn = fn() -> Option<usize>;
type WorkgroupSizeHintFn = fn() -> Option<u32>;

// Write-once registration slots: the first registration wins, later ones are
// silently ignored (see the `let _ = …set(…)` in the register functions).
static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();

// Per-handle metadata registries, keyed by `GpuTensorHandle::buffer_id`.
// Entries are added/removed by the `set_*`/`clear_*` helpers below; all
// accesses treat a poisoned lock as "no data" (best-effort bookkeeping).
static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
36
/// Dimensions of the matrix a handle was transposed FROM, recorded by
/// `record_handle_transpose` so consumers can recover the original layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// Row count of the pre-transpose (base) matrix.
    pub base_rows: usize,
    /// Column count of the pre-transpose (base) matrix.
    pub base_cols: usize,
}
42
43pub fn register_residency_mark(handler: ResidencyMarkFn) {
46 let _ = RESIDENCY_MARK.set(handler);
47}
48
49pub fn mark_residency(handle: &GpuTensorHandle) {
52 if let Some(handler) = RESIDENCY_MARK.get() {
53 handler(handle);
54 }
55}
56
57pub fn register_residency_clear(handler: ResidencyClearFn) {
61 let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64pub fn clear_residency(handle: &GpuTensorHandle) {
67 if let Some(handler) = RESIDENCY_CLEAR.get() {
68 handler(handle);
69 }
70}
71
72pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76 let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79pub fn sequence_threshold_hint() -> Option<usize> {
81 SEQUENCE_THRESHOLD_PROVIDER
82 .get()
83 .and_then(|provider| provider())
84}
85
86pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90 let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93pub fn workgroup_size_hint() -> Option<u32> {
95 WORKGROUP_SIZE_HINT_PROVIDER
96 .get()
97 .and_then(|provider| provider())
98}
99
100pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103 provider().and_then(|p| p.export_context(kind))
104}
105
/// Export the raw wgpu buffer backing `handle`, if the active provider can
/// supply one. Only available in `wgpu`-enabled builds.
#[cfg(feature = "wgpu")]
pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
    let active = provider()?;
    active.export_wgpu_buffer(handle)
}
113
114pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118 guard.insert(handle.buffer_id, precision);
119 }
120}
121
122pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124 HANDLE_PRECISIONS
125 .read()
126 .ok()
127 .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133 guard.remove(&handle.buffer_id);
134 }
135}
136
/// Mark or unmark `handle` as a "logical" tensor (presumably MATLAB-style
/// boolean arrays — confirm against provider usage).
///
/// When `logical` is true the buffer id is added to `LOGICAL_HANDLES` and its
/// hit counter is incremented — once per call, so re-marking an already
/// logical handle keeps counting. When false, both the flag and the counter
/// are removed.
///
/// Lock order: `LOGICAL_HANDLES` is acquired before `LOGICAL_HANDLE_HITS`;
/// any new code touching both maps must use the same order to avoid deadlock.
/// Poisoned locks are silently skipped (best-effort bookkeeping).
pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
    if let Ok(mut guard) = LOGICAL_HANDLES.write() {
        if logical {
            guard.insert(handle.buffer_id);
            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
                *hits.entry(handle.buffer_id).or_insert(0) += 1;
            }
        } else {
            guard.remove(&handle.buffer_id);
            if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
                hits.remove(&handle.buffer_id);
            }
        }
    }
}
154
/// Remove the logical flag (and its hit counter) for `handle`; shorthand for
/// `set_handle_logical(handle, false)`.
pub fn clear_handle_logical(handle: &GpuTensorHandle) {
    set_handle_logical(handle, false);
}
159
160pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162 LOGICAL_HANDLES
163 .read()
164 .map(|guard| guard.contains(&handle.buffer_id))
165 .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169 LOGICAL_HANDLE_HITS
170 .read()
171 .ok()
172 .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177 guard.insert(
178 handle.buffer_id,
179 TransposeInfo {
180 base_rows,
181 base_cols,
182 },
183 );
184 }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189 guard.remove(&handle.buffer_id);
190 }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194 TRANSPOSED_HANDLES
195 .read()
196 .ok()
197 .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
200pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
201 handle_transpose_info(handle).is_some()
202}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
205pub enum GpuTensorStorage {
206 Real,
207 ComplexInterleaved,
208}
209
210impl Default for GpuTensorStorage {
211 fn default() -> Self {
212 Self::Real
213 }
214}
215
/// Opaque reference to a tensor living in device memory.
///
/// `buffer_id` keys every per-handle registry in this module; the handle
/// itself carries no data and is only meaningful to the provider that
/// issued it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuTensorHandle {
    /// Logical tensor shape (dimension sizes).
    pub shape: Vec<usize>,
    /// Device the buffer lives on.
    pub device_id: u32,
    /// Provider-assigned buffer identifier; unique key for metadata maps.
    pub buffer_id: u64,
}

/// Structured device description returned by `AccelProvider::device_info_struct`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    pub device_id: u32,
    pub name: String,
    pub vendor: String,
    /// Total device memory, when the backend can report it.
    pub memory_bytes: Option<u64>,
    /// Backend label (e.g. API name), when available.
    pub backend: Option<String>,
}

/// Value/index pair produced by dimension-wise reductions (min/max-style).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Running values plus the indices where they occurred (cummin).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// cummax shares the exact result shape of cummin.
pub type ProviderCummaxResult = ProviderCumminResult;
249
/// Kinds of shared context a provider can export to other subsystems.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelContextKind {
    /// Context for the plotting subsystem.
    Plotting,
}

/// Backend-specific context exported via `export_context`.
/// Currently only a wgpu variant exists, gated on the `wgpu` feature.
#[derive(Clone)]
pub enum AccelContextHandle {
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}

impl AccelContextHandle {
    /// Borrow the wgpu context, if this handle wraps one.
    /// (With only one variant today this always returns `Some`; the
    /// `Option` keeps the API stable if other backends are added.)
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        match self {
            AccelContextHandle::Wgpu(ctx) => Some(ctx),
        }
    }
}
272
/// Shared wgpu runtime objects exported by a wgpu-backed provider.
/// Everything is `Arc`-wrapped so the context can be cloned cheaply across
/// subsystems.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    pub adapter_info: wgpu::AdapterInfo,
    pub limits: wgpu::Limits,
    pub features: wgpu::Features,
}

/// Zero-copy reference to the wgpu buffer backing a tensor handle, as
/// returned by `export_wgpu_buffer`.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    pub buffer: Arc<wgpu::Buffer>,
    /// Element count — presumably the product of `shape`; confirm with the
    /// exporting provider.
    pub len: usize,
    pub shape: Vec<usize>,
    /// Size in bytes of one element.
    pub element_size: usize,
    pub precision: ProviderPrecision,
}
296
297pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
298 if let Ok(mut guard) = HANDLE_STORAGES.write() {
299 guard.insert(handle.buffer_id, storage);
300 }
301}
302
303pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
304 HANDLE_STORAGES
305 .read()
306 .ok()
307 .and_then(|guard| guard.get(&handle.buffer_id).cloned())
308 .unwrap_or(GpuTensorStorage::Real)
309}
310
311pub fn clear_handle_storage(handle: &GpuTensorHandle) {
312 if let Ok(mut guard) = HANDLE_STORAGES.write() {
313 guard.remove(&handle.buffer_id);
314 }
315}
316
/// Operations supported by paged (batched) execution requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    /// Matrix multiply applied per page.
    Mtimes,
}

/// A batched per-page operation over one or more tensors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    pub op: PagefunOp,
    pub inputs: Vec<GpuTensorHandle>,
    pub output_shape: Vec<usize>,
    /// Page-dimension sizes of the output.
    pub page_dims: Vec<usize>,
    /// Page-dimension sizes for each input, parallel to `inputs`.
    pub input_page_dims: Vec<Vec<usize>>,
}

/// Which end of the data a `find` scans from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}

/// Outputs of a `find`-style search: linear indices plus row/column forms,
/// with the matched values included only when requested.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    pub linear: GpuTensorHandle,
    pub rows: GpuTensorHandle,
    pub cols: GpuTensorHandle,
    pub values: Option<GpuTensorHandle>,
}
344
/// Lower/upper bandwidth of a banded matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    pub lower: u32,
    pub upper: u32,
}

/// Symmetry classification for real matrices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}

/// Hermitian classification for complex matrices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}

/// Outputs of an LU factorization, offered in several layouts so callers can
/// pick whichever form they need.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    /// L and U packed into a single matrix.
    pub combined: GpuTensorHandle,
    pub lower: GpuTensorHandle,
    pub upper: GpuTensorHandle,
    /// Permutation as a full matrix.
    pub perm_matrix: GpuTensorHandle,
    /// Permutation as an index vector.
    pub perm_vector: GpuTensorHandle,
}
371
/// Outputs of a Cholesky factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    pub factor: GpuTensorHandle,
    /// Status code; 0 presumably means success (LAPACK-style `info`) —
    /// confirm against the implementing provider.
    pub info: u32,
}

/// Outputs of a QR factorization with permutation in both layouts.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

/// Outputs of the QR power-iteration variant; same shape as `ProviderQrResult`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

/// Structure/conditioning flags guiding a `linsolve` call.
/// `Default` gives all flags false and no explicit rcond.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    /// Request the reciprocal condition number alongside the solution.
    pub need_rcond: bool,
    pub rcond: Option<f64>,
}
407
/// Solution of a linear system plus its reciprocal condition estimate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    pub reciprocal_condition: f64,
}

/// Options for pseudo-inverse computation; `None` lets the provider choose
/// its own singular-value tolerance.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    pub tolerance: Option<f64>,
}

/// Centering/scaling pair used to evaluate polynomials in normalized
/// coordinates ((x - mean) / scale).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}

/// Options for `polyval`; `mu` enables normalized evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}

/// Options for matrix inversion; currently empty but kept as a struct so
/// future options stay backward compatible.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}

/// Host-side outputs of a polynomial fit.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    /// Fitted coefficients, highest degree first presumably — confirm with
    /// the implementing provider.
    pub coefficients: Vec<f64>,
    /// Triangular factor from the fit's QR decomposition.
    pub r_matrix: Vec<f64>,
    /// Norm of the residuals.
    pub normr: f64,
    /// Degrees of freedom.
    pub df: f64,
    /// [mean, scale] normalization applied to x.
    pub mu: [f64; 2],
}

/// Derivative of a polynomial quotient u/v, as numerator/denominator parts.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}

/// Norm selector for condition-number estimation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}
457
/// Norm selector for vector/matrix norms, including arbitrary p-norms.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    /// Nuclear norm.
    Nuc,
    /// General p-norm with the given exponent.
    P(f64),
}

/// Outputs of an eigendecomposition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    /// Eigenvalues as a vector.
    pub eigenvalues: GpuTensorHandle,
    /// Eigenvalues arranged on a diagonal matrix.
    pub diagonal: GpuTensorHandle,
    /// Right eigenvectors.
    pub right: GpuTensorHandle,
    /// Left eigenvectors, when requested.
    pub left: Option<GpuTensorHandle>,
}

/// Output layout for the QR permutation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}

/// Options for QR factorization; see `Default` (full QR, matrix pivot).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    /// Economy-size factorization when true.
    pub economy: bool,
    pub pivot: ProviderQrPivot,
}
490
491impl Default for ProviderQrOptions {
492 fn default() -> Self {
493 Self {
494 economy: false,
495 pivot: ProviderQrPivot::Matrix,
496 }
497 }
498}
499
/// Floating-point precision a provider computes in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}

/// Whether reductions use a two-pass algorithm: heuristically (`Auto`) or
/// forced on/off.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionTwoPassMode {
    Auto,
    ForceOn,
    ForceOff,
}
512
513impl ReductionTwoPassMode {
514 pub fn as_str(self) -> &'static str {
515 match self {
516 ReductionTwoPassMode::Auto => "auto",
517 ReductionTwoPassMode::ForceOn => "force_on",
518 ReductionTwoPassMode::ForceOff => "force_off",
519 }
520 }
521}
522
/// How a reduction's raw sum is post-scaled.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ReductionFlavor {
    /// No scaling.
    Sum,
    /// Divide by the reduced element count.
    Mean,
    /// Multiply by an arbitrary caller-supplied factor.
    CustomScale(f64),
}
529
530impl ReductionFlavor {
531 pub fn is_mean(self) -> bool {
532 matches!(self, ReductionFlavor::Mean)
533 }
534
535 pub fn scale(self, reduce_len: usize) -> f64 {
536 match self {
537 ReductionFlavor::Sum => 1.0,
538 ReductionFlavor::Mean => {
539 if reduce_len == 0 {
540 1.0
541 } else {
542 1.0 / reduce_len as f64
543 }
544 }
545 ReductionFlavor::CustomScale(scale) => scale,
546 }
547 }
548}
549
/// Normalization used by correlation computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}

/// Row-handling policy for correlation (how missing data is treated).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}

/// Options for `corrcoef`; see `Default` (unbiased, all rows).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}

impl Default for CorrcoefOptions {
    /// Unbiased normalization over all rows.
    fn default() -> Self {
        Self {
            normalization: CorrcoefNormalization::Unbiased,
            rows: CorrcoefRows::All,
        }
    }
}
580
/// Normalization used by covariance computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}

/// Row-handling policy for covariance (how missing data is treated).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}

/// Options for `covariance`; see `Default` (unbiased, all rows, no weights).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    /// True when a per-observation weight vector accompanies the call.
    pub has_weight_vector: bool,
}

impl Default for CovarianceOptions {
    /// Unbiased normalization over all rows, no weight vector.
    fn default() -> Self {
        Self {
            normalization: CovNormalization::Unbiased,
            rows: CovRows::All,
            has_weight_vector: false,
        }
    }
}
613
/// Standard-deviation normalization: sample (n-1) vs. population (n).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    Sample,
    Population,
}

/// Whether NaNs participate in a computation or are skipped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}

/// Direction for cumulative scans (cumsum/cummin/...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}

/// Sort direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}

/// Key used when comparing elements during a sort.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    /// Let the provider pick based on the data.
    Auto,
    /// Compare real parts.
    Real,
    /// Compare absolute values.
    Abs,
}
649
/// Host-side outputs of a sort: reordered values and their source indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}

/// Per-column sort specification for sortrows-style sorting.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    /// Column index to sort by.
    pub index: usize,
    pub order: SortOrder,
}

/// Ordering of `unique` output: sorted vs. first-appearance order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}

/// Which duplicate's index `unique` reports.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}

/// Options for `unique`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    /// Treat whole rows as the unit of uniqueness.
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}

/// Host-side outputs of `unique`: values plus the ia/ic index maps.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    /// Indices into the input selecting each unique value.
    pub ia: HostTensorOwned,
    /// Indices into `values` reconstructing the input.
    pub ic: HostTensorOwned,
}
692
/// Ordering of `union` output: sorted vs. first-appearance order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}

/// Options for set `union`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    /// Treat whole rows as set elements.
    pub rows: bool,
    pub order: UnionOrder,
}

/// Host-side outputs of `union`: values plus index maps into both inputs.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    /// Indices of contributed elements from the first input.
    pub ia: HostTensorOwned,
    /// Indices of contributed elements from the second input.
    pub ib: HostTensorOwned,
}
714
/// Filter kernels constructible via `fspecial` (image-processing style
/// predefined 2-D filters).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    /// Uniform averaging kernel of `rows` x `cols`.
    Average {
        rows: u32,
        cols: u32,
    },
    /// Circular averaging (pillbox) kernel.
    Disk {
        radius: f64,
        size: u32,
    },
    /// Rotationally symmetric Gaussian low-pass kernel.
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// 3x3 Laplacian approximation with shape parameter `alpha`.
    Laplacian {
        alpha: f64,
    },
    /// Laplacian-of-Gaussian kernel.
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// Linear motion-blur kernel.
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        /// Supersampling factor used when rasterizing the motion line.
        oversample: u32,
    },
    /// Prewitt edge-emphasis kernel.
    Prewitt,
    /// Sobel edge-emphasis kernel.
    Sobel,
    /// Unsharp-masking kernel with shape parameter `alpha`.
    Unsharp {
        alpha: f64,
    },
}

/// Request wrapper for `AccelProvider::fspecial`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}
757
/// Boundary padding strategy for image filtering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    /// Pad with a constant value (see `ImfilterOptions::constant_value`).
    Constant,
    /// Repeat the border pixel.
    Replicate,
    /// Mirror across the border.
    Symmetric,
    /// Wrap around (periodic).
    Circular,
}

/// Output-size policy for image filtering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}

/// Whether the kernel is applied as correlation or convolution
/// (convolution flips the kernel).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}

/// Options for `imfilter`; see `Default` (zero-padded same-size correlation).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    /// Pad value used only with `ImfilterPadding::Constant`.
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}

impl Default for ImfilterOptions {
    /// Zero-padded, same-size correlation.
    fn default() -> Self {
        Self {
            padding: ImfilterPadding::Constant,
            constant_value: 0.0,
            shape: ImfilterShape::Same,
            mode: ImfilterMode::Correlation,
        }
    }
}
801
/// Ordering of `setdiff` output: sorted vs. first-appearance order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}

/// Options for set difference.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// Treat whole rows as set elements.
    pub rows: bool,
    pub order: SetdiffOrder,
}

/// Host-side outputs of `setdiff`: surviving values and their source indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}

/// Options for `ismember`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// Treat whole rows as set elements.
    pub rows: bool,
}

/// Host-side logical (boolean) tensor; one byte per element.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}

/// Host-side outputs of `ismember`: membership mask plus match locations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}
842
/// Output-size policy for 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

/// Orientation of the 1-D convolution result vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options for 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}

/// Options for IIR filtering (`filter`-style).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Dimension along which to filter.
    pub dim: usize,
    /// Optional initial filter state.
    pub zi: Option<GpuTensorHandle>,
}

/// Outputs of IIR filtering: filtered signal plus final state when tracked.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    pub output: GpuTensorHandle,
    pub final_state: Option<GpuTensorHandle>,
}

/// First two raw moments (mean and E[x^2]) computed in one pass.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    pub ex2: GpuTensorHandle,
}

/// Telemetry counters for one dispatch category.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of dispatches observed.
    pub count: u64,
    /// Accumulated wall-clock time in nanoseconds.
    pub total_wall_time_ns: u64,
}
891
/// Count of solver fallbacks attributed to one reason string.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderFallbackStat {
    pub reason: String,
    pub count: u64,
}

/// Aggregated provider telemetry snapshot: dispatch stats per kernel family,
/// transfer byte counts, and cache hit/miss counters.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    pub linsolve: ProviderDispatchStats,
    pub mldivide: ProviderDispatchStats,
    pub mrdivide: ProviderDispatchStats,
    /// Bytes copied host -> device.
    pub upload_bytes: u64,
    /// Bytes copied device -> host.
    pub download_bytes: u64,
    pub solve_fallbacks: Vec<ProviderFallbackStat>,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Per-layout bind-group cache breakdown, when the backend tracks it.
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}

/// Bind-group cache counters for one layout tag.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}

/// One named numeric attribute attached to a kernel-launch record.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}

/// Record of a single kernel launch and its shape/tuning attributes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    pub kernel: String,
    pub precision: Option<String>,
    pub shape: Vec<KernelAttrTelemetry>,
    pub tuning: Vec<KernelAttrTelemetry>,
}
939
/// Boxed future returned by async provider operations. The trait object has
/// no `Send` bound, so these futures must be polled on the thread that
/// created them.
pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
/// Future resolving to a host-side copy of a downloaded tensor.
pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
942
943fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
944 Box::pin(async move { Err(anyhow::anyhow!(message)) })
945}
946
/// Backend-neutral GPU acceleration interface.
///
/// Only `upload`, `download`, `free`, and `device_info` are required; every
/// other method has a default that either reports "not supported" or is
/// composed from other trait methods, so backends can be brought up
/// incrementally.
pub trait AccelProvider: Send + Sync {
    /// Copy a host tensor to the device, returning an opaque handle.
    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
    /// Asynchronously copy a device tensor back to host memory.
    fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
    /// Release the device buffer behind `h`.
    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
    /// Human-readable device description.
    fn device_info(&self) -> String;
    /// Numeric device identifier; defaults to 0 for single-device backends.
    fn device_id(&self) -> u32 {
        0
    }

    /// Export a shared context of the given kind; default declines.
    fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
        None
    }

    /// Export the raw wgpu buffer backing `handle`; default declines.
    #[cfg(feature = "wgpu")]
    fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
        let _ = _handle;
        None
    }

    /// Gather elements of `_source` at linear `_indices` into a tensor of
    /// `_output_shape`. Default: unsupported.
    fn gather_linear(
        &self,
        _source: &GpuTensorHandle,
        _indices: &[u32],
        _output_shape: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("gather_linear not supported by provider"))
    }

    /// Scatter `_values` into `_target` at linear `_indices` (in-place).
    /// Default: unsupported.
    fn scatter_linear(
        &self,
        _target: &GpuTensorHandle,
        _indices: &[u32],
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<()> {
        Err(anyhow::anyhow!("scatter_linear not supported by provider"))
    }

    /// Structured device description; by default only `name` is populated.
    ///
    /// NOTE(review): `device_id` is hard-coded to 0 here instead of calling
    /// `self.device_id()`; providers that override `device_id` likely want to
    /// override this too — confirm intent.
    fn device_info_struct(&self) -> ApiDeviceInfo {
        ApiDeviceInfo {
            device_id: 0,
            name: self.device_info(),
            vendor: String::new(),
            memory_bytes: None,
            backend: None,
        }
    }
1003
    /// Compute precision of this backend; defaults to double precision.
    fn precision(&self) -> ProviderPrecision {
        ProviderPrecision::F64
    }

    /// Read a single element at a linear index. Default: unsupported.
    fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
        Err(anyhow::anyhow!("read_scalar not supported by provider"))
    }

    /// Allocate a zero-filled tensor. Default: unsupported.
    fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("zeros not supported by provider"))
    }

    /// Allocate a one-filled tensor. Default: unsupported.
    fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("ones not supported by provider"))
    }

    /// Zero tensor with the same shape as `prototype`; delegates to `zeros`.
    fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        self.zeros(&prototype.shape)
    }

    /// Create a tensor of `shape` filled with `value`.
    ///
    /// Strategy: use `zeros` directly for 0.0; otherwise try `zeros` +
    /// `scalar_add` on-device (freeing the intermediate either way); if that
    /// fails, build the data on the host and upload it.
    ///
    /// NOTE(review): `value == 0.0` also matches `-0.0`, which then takes the
    /// zeros path and stores `+0.0` — confirm this is acceptable.
    fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
        if value == 0.0 {
            return self.zeros(shape);
        }
        if let Ok(base) = self.zeros(shape) {
            match self.scalar_add(&base, value) {
                Ok(out) => {
                    let _ = self.free(&base);
                    return Ok(out);
                }
                Err(_) => {
                    let _ = self.free(&base);
                }
            }
        }
        let len: usize = shape.iter().copied().product();
        let data = vec![value; len];
        let view = HostTensorView { data: &data, shape };
        self.upload(&view)
    }

    /// `fill` with the shape of `prototype`; same strategy and fallbacks.
    fn fill_like(
        &self,
        prototype: &GpuTensorHandle,
        value: f64,
    ) -> anyhow::Result<GpuTensorHandle> {
        if value == 0.0 {
            return self.zeros_like(prototype);
        }
        if let Ok(base) = self.zeros_like(prototype) {
            match self.scalar_add(&base, value) {
                Ok(out) => {
                    let _ = self.free(&base);
                    return Ok(out);
                }
                Err(_) => {
                    let _ = self.free(&base);
                }
            }
        }
        self.fill(&prototype.shape, value)
    }

    /// One tensor with the same shape as `prototype`; delegates to `ones`.
    fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        self.ones(&prototype.shape)
    }

    /// Identity-like tensor of the given shape. Default: unsupported.
    fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("eye not supported by provider"))
    }

    /// `eye` with the shape of `prototype`.
    fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        self.eye(&prototype.shape)
    }

    /// Build coordinate grids from the given axes. Default: unsupported.
    fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
        Err(anyhow::anyhow!("meshgrid not supported by provider"))
    }

    /// Build a matrix with `_vector` on diagonal `_offset`. Default: unsupported.
    fn diag_from_vector(
        &self,
        _vector: &GpuTensorHandle,
        _offset: isize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "diag_from_vector not supported by provider"
        ))
    }

    /// Extract diagonal `_offset` of `_matrix` as a vector. Default: unsupported.
    fn diag_extract(
        &self,
        _matrix: &GpuTensorHandle,
        _offset: isize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("diag_extract not supported by provider"))
    }
1112
    /// Lower-triangular part of `_matrix` at diagonal `_offset` (async).
    /// Default: resolves to an unsupported error.
    fn tril<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _offset: isize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
    }

    /// Upper-triangular part of `_matrix` at diagonal `_offset` (async).
    /// Default: resolves to an unsupported error.
    fn triu<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _offset: isize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
    }

    /// Evaluate a polynomial at `_points`. Default: unsupported.
    fn polyval(
        &self,
        _coefficients: &GpuTensorHandle,
        _points: &GpuTensorHandle,
        _options: &ProviderPolyvalOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("polyval not supported by provider"))
    }

    /// Fit a degree-`_degree` polynomial to (x, y), optionally weighted
    /// (async). Default: resolves to an unsupported error.
    fn polyfit<'a>(
        &'a self,
        _x: &'a GpuTensorHandle,
        _y: &'a GpuTensorHandle,
        _degree: usize,
        _weights: Option<&'a GpuTensorHandle>,
    ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
        Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
    }

    /// Derivative of a single polynomial (async). Default: unsupported.
    fn polyder_single<'a>(
        &'a self,
        _polynomial: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
    }

    /// Derivative of a polynomial product p*q (async). Default: unsupported.
    fn polyder_product<'a>(
        &'a self,
        _p: &'a GpuTensorHandle,
        _q: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
    }

    /// Derivative of a polynomial quotient u/v (async). Default: unsupported.
    fn polyder_quotient<'a>(
        &'a self,
        _u: &'a GpuTensorHandle,
        _v: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
        Box::pin(async move {
            Err(anyhow::anyhow!(
                "polyder_quotient not supported by provider"
            ))
        })
    }

    /// Antiderivative of a polynomial with integration constant `_constant`.
    /// Default: unsupported.
    fn polyint(
        &self,
        _polynomial: &GpuTensorHandle,
        _constant: f64,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("polyint not supported by provider"))
    }
1190
    /// Uniform random tensor of `_shape`. Default: unsupported.
    fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("random_uniform not supported by provider"))
    }

    /// `random_uniform` with the shape of `prototype`.
    fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        self.random_uniform(&prototype.shape)
    }

    /// Normal random tensor of `_shape`. Default: unsupported.
    fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("random_normal not supported by provider"))
    }

    /// `random_normal` with the shape of `prototype`.
    fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        self.random_normal(&prototype.shape)
    }

    /// Evolve `_state` for `_steps` steps with the given drift and noise
    /// scale. Default: unsupported.
    fn stochastic_evolution(
        &self,
        _state: &GpuTensorHandle,
        _drift: f64,
        _scale: f64,
        _steps: u32,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "stochastic_evolution not supported by provider"
        ))
    }

    /// Seed/replace the provider's RNG state. Default: unsupported.
    fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
        Err(anyhow::anyhow!("set_rng_state not supported by provider"))
    }

    /// Build a predefined 2-D filter kernel. Default: unsupported.
    fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("fspecial not supported by provider"))
    }

    /// `peaks` sample surface of size `_n`. Default: unsupported.
    fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("peaks not supported by provider"))
    }

    /// `peaks` surface evaluated at explicit x/y grids. Default: unsupported.
    fn peaks_xy(
        &self,
        _x: &GpuTensorHandle,
        _y: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("peaks_xy not supported by provider"))
    }

    /// Hann window of `_len` points (periodic or symmetric). Default: unsupported.
    fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("hann_window not supported by provider"))
    }

    /// Hamming window of `_len` points. Default: unsupported.
    fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("hamming_window not supported by provider"))
    }

    /// Blackman window of `_len` points. Default: unsupported.
    fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("blackman_window not supported by provider"))
    }
1260
    /// Filter `_image` with `_kernel` per `_options` (async).
    /// Default: resolves to an unsupported error.
    fn imfilter<'a>(
        &'a self,
        _image: &'a GpuTensorHandle,
        _kernel: &'a GpuTensorHandle,
        _options: &'a ImfilterOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("imfilter not supported by provider")
    }

    /// Random integers in [_lower, _upper] of `_shape` — inclusive bounds
    /// presumed; confirm with the implementing provider. Default: unsupported.
    fn random_integer_range(
        &self,
        _lower: i64,
        _upper: i64,
        _shape: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "random_integer_range not supported by provider"
        ))
    }

    /// `random_integer_range` with the shape of `prototype`.
    fn random_integer_like(
        &self,
        prototype: &GpuTensorHandle,
        lower: i64,
        upper: i64,
    ) -> anyhow::Result<GpuTensorHandle> {
        self.random_integer_range(lower, upper, &prototype.shape)
    }

    /// `_k` values drawn from a random permutation of 1..=_n (randperm-style).
    /// Default: unsupported.
    fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow!("random_permutation not supported by provider"))
    }

    /// `random_permutation`; the prototype only signals that the result should
    /// live on the same device/provider.
    fn random_permutation_like(
        &self,
        _prototype: &GpuTensorHandle,
        n: usize,
        k: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        self.random_permutation(n, k)
    }

    /// Covariance of `_matrix` (optionally paired with `_second`, optionally
    /// weighted) (async). Default: resolves to an unsupported error.
    fn covariance<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _second: Option<&'a GpuTensorHandle>,
        _weights: Option<&'a GpuTensorHandle>,
        _options: &'a CovarianceOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("covariance not supported by provider")
    }

    /// Correlation coefficients of `_matrix` (async).
    /// Default: resolves to an unsupported error.
    fn corrcoef<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: &'a CorrcoefOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("corrcoef not supported by provider")
    }

    /// `_count` evenly spaced values from `_start` to `_stop`. Default: unsupported.
    fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("linspace not supported by provider"))
    }
    // --- Element-wise binary arithmetic ------------------------------------
    // Each default reports the op as unsupported; concrete providers override
    // the kernels they can run on-device.
    fn elem_add<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_add not supported by provider")
    }
    fn elem_mul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_mul not supported by provider")
    }
    fn elem_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_max not supported by provider")
    }
    fn elem_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_min not supported by provider")
    }
    fn elem_sub<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_sub not supported by provider")
    }
    fn elem_div<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_div not supported by provider")
    }
    fn elem_pow<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_pow not supported by provider")
    }

    fn elem_hypot<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_hypot not supported by provider")
    }
    // --- Element-wise comparisons (>=, <=, <, >, ==, !=) -------------------
    // All unsupported by default; overridden per-provider.
    fn elem_ge<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ge not supported by provider")
    }
    fn elem_le<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_le not supported by provider")
    }
    fn elem_lt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_lt not supported by provider")
    }
    fn elem_gt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_gt not supported by provider")
    }
    fn elem_eq<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_eq not supported by provider")
    }
    fn elem_ne<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ne not supported by provider")
    }
    // --- Logical operations and predicates ---------------------------------
    fn logical_and(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_and not supported by provider"))
    }
    fn logical_or(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_or not supported by provider"))
    }
    fn logical_xor(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_xor not supported by provider"))
    }
    fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_not not supported by provider"))
    }
    // Answered from host-side handle metadata, so it has a working default.
    fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
        Ok(handle_is_logical(a))
    }
    fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
        Err(anyhow::anyhow!("logical_isreal not supported by provider"))
    }
    fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "logical_isfinite not supported by provider"
        ))
    }
    fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_isnan not supported by provider"))
    }
    fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_isinf not supported by provider"))
    }
    /// Element-wise `atan2(y, x)`; unsupported by default.
    fn elem_atan2<'a>(
        &'a self,
        _y: &'a GpuTensorHandle,
        _x: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_atan2 not supported by provider")
    }
    // --- Element-wise unary kernels ----------------------------------------
    // Trig/hyperbolic, rounding, complex-part, exp/log, and precision-cast
    // operations. Every default reports the op as unsupported; providers
    // override the ones they implement on-device.
    fn unary_sin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sin not supported by provider")
    }
    fn unary_gamma<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_gamma not supported by provider")
    }
    fn unary_factorial<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_factorial not supported by provider")
    }
    fn unary_asinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asinh not supported by provider")
    }
    fn unary_sinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sinh not supported by provider")
    }
    fn unary_cosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cosh not supported by provider")
    }
    fn unary_asin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asin not supported by provider")
    }
    fn unary_acos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acos not supported by provider")
    }
    fn unary_acosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acosh not supported by provider")
    }
    fn unary_tan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tan not supported by provider")
    }
    fn unary_tanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tanh not supported by provider")
    }
    fn unary_atan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atan not supported by provider")
    }
    fn unary_atanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atanh not supported by provider")
    }
    fn unary_ceil<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_ceil not supported by provider")
    }
    fn unary_floor<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_floor not supported by provider")
    }
    fn unary_round<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_round not supported by provider")
    }
    fn unary_fix<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_fix not supported by provider")
    }
    fn unary_cos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cos not supported by provider")
    }
    fn unary_angle<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_angle not supported by provider")
    }
    fn unary_imag<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_imag not supported by provider")
    }
    fn unary_real<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_real not supported by provider")
    }
    fn unary_conj<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_conj not supported by provider")
    }
    fn unary_abs<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_abs not supported by provider")
    }
    fn unary_sign<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sign not supported by provider")
    }
    fn unary_exp<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_exp not supported by provider")
    }
    fn unary_expm1<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_expm1 not supported by provider")
    }
    fn unary_log<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log not supported by provider")
    }
    fn unary_log2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log2 not supported by provider")
    }
    fn unary_log10<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log10 not supported by provider")
    }
    fn unary_log1p<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log1p not supported by provider")
    }
    fn unary_sqrt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sqrt not supported by provider")
    }
    fn unary_double<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_double not supported by provider")
    }
    fn unary_single<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_single not supported by provider")
    }
    fn unary_pow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_pow2 not supported by provider")
    }
    fn unary_nextpow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_nextpow2 not supported by provider")
    }
    /// `mantissa * 2^exponent` element-wise (ldexp-style); unsupported by default.
    fn pow2_scale(
        &self,
        _mantissa: &GpuTensorHandle,
        _exponent: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pow2_scale not supported by provider"))
    }
1697 fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1699 Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
1700 }
1701 fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1702 Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
1703 }
1704 fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1706 Err(anyhow::anyhow!("scalar_add not supported by provider"))
1707 }
1708 fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1709 Err(anyhow::anyhow!("scalar_sub not supported by provider"))
1710 }
1711 fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1712 Err(anyhow::anyhow!("scalar_mul not supported by provider"))
1713 }
1714 fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1715 Err(anyhow::anyhow!("scalar_max not supported by provider"))
1716 }
1717 fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1718 Err(anyhow::anyhow!("scalar_min not supported by provider"))
1719 }
1720 fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1721 Err(anyhow::anyhow!("scalar_div not supported by provider"))
1722 }
    /// Sort along one dimension with the given order/comparison; unsupported
    /// by default.
    fn sort_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _order: SortOrder,
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_dim not supported by provider")
    }
    /// Lexicographic row sort keyed by `columns`; unsupported by default.
    fn sort_rows<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _columns: &'a [SortRowsColumnSpec],
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_rows not supported by provider")
    }
    /// Matrix multiplication; unsupported by default.
    fn matmul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul not supported by provider")
    }
1747
    /// Symmetric rank-k update (A·Aᵀ-style product); unsupported by default.
    fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("syrk not supported by provider"))
    }
    /// Batched page-wise function application; unsupported by default.
    fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pagefun not supported by provider"))
    }
1754
1755 fn matmul_epilogue<'a>(
1760 &'a self,
1761 a: &'a GpuTensorHandle,
1762 b: &'a GpuTensorHandle,
1763 epilogue: &'a MatmulEpilogue,
1764 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1765 Box::pin(async move {
1766 if epilogue.is_noop() {
1767 return self.matmul(a, b).await;
1768 }
1769 Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1770 })
1771 }
    /// Fused image normalization pipeline; unsupported by default.
    fn image_normalize<'a>(
        &'a self,
        _input: &'a GpuTensorHandle,
        _desc: &'a ImageNormalizeDescriptor,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("image_normalize fusion not supported by provider")
    }
    /// Fused matmul + power-iteration normalization step; unsupported by default.
    fn matmul_power_step<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _epilogue: &'a PowerStepEpilogue,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul_power_step normalization not supported by provider")
    }
    /// Linear-system solve with provider options; unsupported by default.
    fn linsolve<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _options: &'a ProviderLinsolveOptions,
    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
        unsupported_future("linsolve not supported by provider")
    }
    /// Matrix inverse; unsupported by default.
    fn inv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderInvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("inv not supported by provider")
    }
    /// Moore-Penrose pseudo-inverse; unsupported by default.
    fn pinv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderPinvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("pinv not supported by provider")
    }
1809 fn cond<'a>(
1810 &'a self,
1811 _matrix: &'a GpuTensorHandle,
1812 _norm: ProviderCondNorm,
1813 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1814 Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
1815 }
1816 fn norm<'a>(
1817 &'a self,
1818 _tensor: &'a GpuTensorHandle,
1819 _order: ProviderNormOrder,
1820 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1821 Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
1822 }
1823 fn rank<'a>(
1824 &'a self,
1825 _matrix: &'a GpuTensorHandle,
1826 _tolerance: Option<f64>,
1827 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1828 Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
1829 }
1830 fn rcond<'a>(
1831 &'a self,
1832 _matrix: &'a GpuTensorHandle,
1833 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1834 Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
1835 }
1836 fn mldivide<'a>(
1837 &'a self,
1838 _lhs: &'a GpuTensorHandle,
1839 _rhs: &'a GpuTensorHandle,
1840 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1841 Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
1842 }
1843 fn mrdivide<'a>(
1844 &'a self,
1845 _lhs: &'a GpuTensorHandle,
1846 _rhs: &'a GpuTensorHandle,
1847 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1848 Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
1849 }
1850 fn eig<'a>(
1851 &'a self,
1852 _a: &'a GpuTensorHandle,
1853 _compute_left: bool,
1854 ) -> AccelProviderFuture<'a, ProviderEigResult> {
1855 Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
1856 }
1857 fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
1858 Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
1859 }
1860
1861 fn chol<'a>(
1862 &'a self,
1863 _a: &'a GpuTensorHandle,
1864 _lower: bool,
1865 ) -> AccelProviderFuture<'a, ProviderCholResult> {
1866 Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
1867 }
1868 fn qr<'a>(
1869 &'a self,
1870 _a: &'a GpuTensorHandle,
1871 _options: ProviderQrOptions,
1872 ) -> AccelProviderFuture<'a, ProviderQrResult> {
1873 Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
1874 }
    /// Recover the operands that produced `product`, if the provider cached
    /// them; the default keeps no such cache and returns `None`.
    fn take_matmul_sources(
        &self,
        _product: &GpuTensorHandle,
    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
        None
    }
1881 fn qr_power_iter<'a>(
1882 &'a self,
1883 product: &'a GpuTensorHandle,
1884 _product_lhs: Option<&'a GpuTensorHandle>,
1885 q_handle: &'a GpuTensorHandle,
1886 options: &'a ProviderQrOptions,
1887 ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
1888 let _ = (product, q_handle, options);
1889 Box::pin(async move { Ok(None) })
1890 }
    /// Matrix transpose; unsupported by default.
    fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("transpose not supported by provider"))
    }
    /// 1-D convolution with provider options; unsupported by default.
    fn conv1d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _options: ProviderConv1dOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv1d not supported by provider"))
    }
    /// 2-D convolution (full/same/valid per `mode`); unsupported by default.
    fn conv2d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _mode: ProviderConvMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv2d not supported by provider"))
    }
1910 fn iir_filter<'a>(
1911 &'a self,
1912 _b: &'a GpuTensorHandle,
1913 _a: &'a GpuTensorHandle,
1914 _x: &'a GpuTensorHandle,
1915 _options: ProviderIirFilterOptions,
1916 ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
1917 Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
1918 }
    // --- Axis/layout manipulation ------------------------------------------
    /// Reorder dimensions by `order`; unsupported by default.
    fn permute(
        &self,
        _handle: &GpuTensorHandle,
        _order: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("permute not supported by provider"))
    }
    /// Reverse the listed axes; unsupported by default.
    fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("flip not supported by provider"))
    }
    /// Circular shift per dimension (negative = toward the start); unsupported by default.
    fn circshift(
        &self,
        _handle: &GpuTensorHandle,
        _shifts: &[isize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("circshift not supported by provider"))
    }
    /// `order`-th discrete difference along `dim`; unsupported by default.
    fn diff_dim(
        &self,
        _handle: &GpuTensorHandle,
        _order: usize,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("diff_dim not supported by provider"))
    }
    /// Forward FFT along `dim` (optionally zero-padded/truncated to `len`);
    /// unsupported by default.
    fn fft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_dim not supported by provider")
    }
    /// Inverse FFT along `dim`; unsupported by default.
    fn ifft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("ifft_dim not supported by provider")
    }
    /// Extract the real part of a complex FFT result; unsupported by default.
    fn fft_extract_real<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_extract_real not supported by provider")
    }
1968 fn unique<'a>(
1969 &'a self,
1970 _handle: &'a GpuTensorHandle,
1971 _options: &'a UniqueOptions,
1972 ) -> AccelProviderFuture<'a, UniqueResult> {
1973 Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
1974 }
1975 fn union<'a>(
1976 &'a self,
1977 _a: &'a GpuTensorHandle,
1978 _b: &'a GpuTensorHandle,
1979 _options: &'a UnionOptions,
1980 ) -> AccelProviderFuture<'a, UnionResult> {
1981 Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
1982 }
1983 fn setdiff<'a>(
1984 &'a self,
1985 _a: &'a GpuTensorHandle,
1986 _b: &'a GpuTensorHandle,
1987 _options: &'a SetdiffOptions,
1988 ) -> AccelProviderFuture<'a, SetdiffResult> {
1989 Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
1990 }
1991 fn ismember<'a>(
1992 &'a self,
1993 _a: &'a GpuTensorHandle,
1994 _b: &'a GpuTensorHandle,
1995 _options: &'a IsMemberOptions,
1996 ) -> AccelProviderFuture<'a, IsMemberResult> {
1997 Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
1998 }
    /// Metadata-only reshape: clones the handle and swaps in `new_shape`
    /// without touching device memory.
    /// NOTE(review): no check that `new_shape`'s element count matches the
    /// original shape — presumably callers validate this upstream; confirm.
    fn reshape(
        &self,
        handle: &GpuTensorHandle,
        new_shape: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        let mut updated = handle.clone();
        updated.shape = new_shape.to_vec();
        Ok(updated)
    }
    /// Concatenate tensors along `dim`; unsupported by default.
    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cat not supported by provider"))
    }
    /// Tile the tensor `reps[d]` times along each dimension; unsupported by default.
    fn repmat(
        &self,
        _handle: &GpuTensorHandle,
        _reps: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("repmat not supported by provider"))
    }
    /// Kronecker product; unsupported by default.
    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("kron not supported by provider"))
    }
    // --- Reductions: sum / nnz / prod / mean ------------------------------
    // `*_dim` variants reduce along one axis, `*_nd` along a set of
    // zero-based axes; the plain forms reduce the whole tensor. All
    // unsupported by default.
    fn reduce_sum<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum not supported by provider")
    }
    fn reduce_sum_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum_dim not supported by provider")
    }
    fn dot<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _dim: Option<usize>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("dot not supported by provider")
    }
    fn reduce_nnz<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz not supported by provider")
    }
    fn reduce_nnz_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz_dim not supported by provider")
    }
    fn reduce_prod<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod not supported by provider")
    }
    fn reduce_prod_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod_dim not supported by provider")
    }
    fn reduce_mean<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean not supported by provider")
    }
    fn reduce_mean_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_nd not supported by provider")
    }
    // Returns first two statistical moments (see `ProviderMoments2`).
    fn reduce_moments_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, ProviderMoments2> {
        unsupported_future("reduce_moments_nd not supported by provider")
    }
    fn reduce_mean_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_dim not supported by provider")
    }
2099 }
    // --- Reductions: std / any / all / median / min / max ------------------
    // All unsupported by default. Min/max `_dim` variants return a
    // `ReduceDimResult` (values plus indices).
    fn reduce_std<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std not supported by provider")
    }
    fn reduce_std_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std_dim not supported by provider")
    }
    fn reduce_any<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any not supported by provider")
    }
    fn reduce_any_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any_dim not supported by provider")
    }
    fn reduce_all<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all not supported by provider")
    }
    fn reduce_all_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all_dim not supported by provider")
    }
    fn reduce_median<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median not supported by provider")
    }
    fn reduce_median_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median_dim not supported by provider")
    }
    fn reduce_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_min not supported by provider")
    }
    fn reduce_min_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_min_dim not supported by provider")
    }
    fn reduce_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_max not supported by provider")
    }
    fn reduce_max_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_max_dim not supported by provider")
    }
    // --- Cumulative scans ---------------------------------------------------
    // Directional prefix scans along `dim` with NaN handling; cummin/cummax
    // also report argmin/argmax via their result structs. All unsupported by
    // default.
    fn cumsum_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
    }
    fn cumprod_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
    }
    fn cummin_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCumminResult> {
        Err(anyhow::anyhow!("cummin_scan not supported by provider"))
    }
    fn cummax_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCummaxResult> {
        Err(anyhow::anyhow!("cummax_scan not supported by provider"))
    }
2222
    /// Locate nonzero elements (optionally capped at `limit`, searching in
    /// `direction`); unsupported by default.
    fn find(
        &self,
        _a: &GpuTensorHandle,
        _limit: Option<usize>,
        _direction: FindDirection,
    ) -> anyhow::Result<ProviderFindResult> {
        Err(anyhow::anyhow!("find not supported by provider"))
    }
2231
    /// Execute a JIT-generated element-wise shader over `inputs`, producing a
    /// single output of `output_shape` with `len` elements; unsupported by default.
    fn fused_elementwise(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _len: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "fused_elementwise not supported by provider"
        ))
    }

    /// Multi-output variant of `fused_elementwise`: one shader dispatch
    /// producing `num_outputs` tensors; unsupported by default.
    fn fused_elementwise_multi(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _len: usize,
        _num_outputs: usize,
    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
        Err(anyhow::anyhow!(
            "fused_elementwise_multi not supported by provider"
        ))
    }
2264
    /// Replace NaN elements with zero; unsupported by default.
    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
    }

    /// Boolean mask of non-NaN elements; unsupported by default.
    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
    }
2274
    /// Execute a JIT-generated reduction shader: `num_slices` independent
    /// reductions of `reduce_len` elements each, dispatched with the given
    /// workgroup size and reduction flavor; unsupported by default.
    #[allow(clippy::too_many_arguments)]
    fn fused_reduction(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _reduce_len: usize,
        _num_slices: usize,
        _workgroup_size: u32,
        _flavor: ReductionFlavor,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
    }
2294
    /// Optional one-time device warm-up (e.g. shader precompilation); no-op by default.
    fn warmup(&self) {}

    /// Fusion-cache (hits, misses) counters; providers without a cache report zeros.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }

    /// Duration of the last `warmup` in milliseconds, if one was timed.
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }
2307
    /// Snapshot of provider telemetry. The default reports empty/zero stats,
    /// filling only the fusion-cache counters from `fused_cache_counters`.
    fn telemetry_snapshot(&self) -> ProviderTelemetry {
        let (hits, misses) = self.fused_cache_counters();
        ProviderTelemetry {
            fused_elementwise: ProviderDispatchStats::default(),
            fused_reduction: ProviderDispatchStats::default(),
            matmul: ProviderDispatchStats::default(),
            linsolve: ProviderDispatchStats::default(),
            mldivide: ProviderDispatchStats::default(),
            mrdivide: ProviderDispatchStats::default(),
            upload_bytes: 0,
            download_bytes: 0,
            solve_fallbacks: Vec::new(),
            fusion_cache_hits: hits,
            fusion_cache_misses: misses,
            bind_group_cache_hits: 0,
            bind_group_cache_misses: 0,
            bind_group_cache_by_layout: None,
            kernel_launches: Vec::new(),
        }
    }
2329
2330 fn reset_telemetry(&self) {}
2332
2333 fn default_reduction_workgroup_size(&self) -> u32 {
2335 256
2336 }
2337
2338 fn two_pass_threshold(&self) -> usize {
2340 1024
2341 }
2342
    /// Policy for selecting one- vs two-pass reductions; defaults to `Auto`.
    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
2347
2348 fn scatter_column(
2351 &self,
2352 _matrix: &GpuTensorHandle,
2353 _col_index: usize,
2354 _values: &GpuTensorHandle,
2355 ) -> anyhow::Result<GpuTensorHandle> {
2356 Err(anyhow::anyhow!("scatter_column not supported by provider"))
2357 }
2358
2359 fn scatter_row(
2362 &self,
2363 _matrix: &GpuTensorHandle,
2364 _row_index: usize,
2365 _values: &GpuTensorHandle,
2366 ) -> anyhow::Result<GpuTensorHandle> {
2367 Err(anyhow::anyhow!("scatter_row not supported by provider"))
2368 }
2369
2370 fn sub2ind(
2371 &self,
2372 _dims: &[usize],
2373 _strides: &[usize],
2374 _inputs: &[&GpuTensorHandle],
2375 _scalar_mask: &[bool],
2376 _len: usize,
2377 _output_shape: &[usize],
2378 ) -> anyhow::Result<GpuTensorHandle> {
2379 Err(anyhow::anyhow!("sub2ind not supported by provider"))
2380 }
2381
    /// Whether this provider implements `ind2sub`; the default implementation
    /// does not.
    fn supports_ind2sub(&self) -> bool {
        false
    }
2386
2387 fn ind2sub(
2389 &self,
2390 _dims: &[usize],
2391 _strides: &[usize],
2392 _indices: &GpuTensorHandle,
2393 _total: usize,
2394 _len: usize,
2395 _output_shape: &[usize],
2396 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2397 Err(anyhow::anyhow!("ind2sub not supported by provider"))
2398 }
2399
2400 fn issymmetric(
2402 &self,
2403 _matrix: &GpuTensorHandle,
2404 _kind: ProviderSymmetryKind,
2405 _tolerance: f64,
2406 ) -> anyhow::Result<bool> {
2407 Err(anyhow::anyhow!(
2408 "issymmetric predicate not supported by provider"
2409 ))
2410 }
2411
2412 fn ishermitian<'a>(
2414 &'a self,
2415 _matrix: &'a GpuTensorHandle,
2416 _kind: ProviderHermitianKind,
2417 _tolerance: f64,
2418 ) -> AccelProviderFuture<'a, bool> {
2419 Box::pin(async move {
2420 Err(anyhow::anyhow!(
2421 "ishermitian predicate not supported by provider"
2422 ))
2423 })
2424 }
2425
2426 fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2428 Err(anyhow::anyhow!("bandwidth not supported by provider"))
2429 }
2430
2431 fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
2436 Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
2437 }
2438}
2439
/// Process-wide fallback provider, consulted when no thread-local override is
/// installed (see `provider`).
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
/// Providers keyed by device id, populated via `register_provider_for_device`
/// and consulted by `provider_for_device`.
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Monotonic source of fresh device ids; starts at 1 (see `next_device_id`).
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2445
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    // Per-thread provider override; takes precedence over GLOBAL_PROVIDER.
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}

// On wasm32 a single Mutex-guarded slot stands in for the thread-local
// override used on native targets.
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| Mutex::new(None));
2454
2455#[cfg(not(target_arch = "wasm32"))]
2456fn replace_thread_provider(
2457 provider: Option<&'static dyn AccelProvider>,
2458) -> Option<&'static dyn AccelProvider> {
2459 THREAD_PROVIDER.with(|cell| {
2460 let prev = cell.get();
2461 cell.set(provider);
2462 prev
2463 })
2464}
2465
/// wasm32 variant: swap the shared provider slot and return the value it
/// replaced.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *slot, provider)
}
2477
2478#[cfg(not(target_arch = "wasm32"))]
2479fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2480 THREAD_PROVIDER.with(|cell| cell.get())
2481}
2482
/// wasm32 variant: read the shared provider slot.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    // Option<&'static dyn _> is Copy, so a plain deref copies it out.
    *WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned")
}
2491
/// Register `p` as the global fallback provider and as the provider for its
/// own `device_id`. A poisoned `GLOBAL_PROVIDER` lock is silently skipped.
///
/// # Safety
///
/// NOTE(review): the precise `unsafe` contract is not documented here; it is
/// inherited from `register_provider_for_device`. Presumably registration
/// must not race with concurrent provider use — confirm with callers.
pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
        *guard = Some(p);
    }
    register_provider_for_device(p.device_id(), p);
}
2505
/// Insert `provider` into the per-device registry under `device_id`,
/// overwriting any previous entry. A poisoned registry lock is silently
/// skipped.
///
/// # Safety
///
/// NOTE(review): no unsafe operation is visible in this body; the `unsafe`
/// marker appears to encode a caller-side contract that is documented
/// elsewhere — confirm before relying on it.
unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
    if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
        guard.insert(device_id, provider);
    }
}
2511
2512pub fn provider() -> Option<&'static dyn AccelProvider> {
2513 if let Some(p) = current_thread_provider() {
2514 return Some(p);
2515 }
2516 GLOBAL_PROVIDER
2517 .read()
2518 .ok()
2519 .and_then(|guard| guard.as_ref().copied())
2520}
2521
2522pub fn clear_provider() {
2524 replace_thread_provider(None);
2525 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2526 *guard = None;
2527 }
2528 if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2529 map.clear();
2530 }
2531}
2532
2533pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2534 PROVIDER_REGISTRY
2535 .read()
2536 .ok()
2537 .and_then(|guard| guard.get(&device_id).copied())
2538 .or_else(|| provider())
2539}
2540
/// Convenience wrapper: resolve the provider for the device that owns
/// `handle` (via its `device_id` field).
pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
    provider_for_device(handle.device_id)
}
2544
/// Allocate a fresh, process-unique device id (the counter starts at 1).
/// Relaxed ordering suffices: only uniqueness matters, not ordering with
/// other memory operations.
pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
2548
/// RAII guard that installs a thread-local provider override and restores the
/// previously installed one when dropped.
pub struct ThreadProviderGuard {
    // Provider that was active before this guard was created; reinstated on drop.
    prev: Option<&'static dyn AccelProvider>,
}
2552
2553impl ThreadProviderGuard {
2554 pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2555 let prev = replace_thread_provider(provider);
2556 ThreadProviderGuard { prev }
2557 }
2558}
2559
2560impl Drop for ThreadProviderGuard {
2561 fn drop(&mut self) {
2562 let prev = self.prev.take();
2563 replace_thread_provider(prev);
2564 }
2565}
2566
/// Set (or clear, with `None`) the current thread's provider override,
/// discarding the previous value. Prefer `ThreadProviderGuard::set` when the
/// override should be scoped.
pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
    replace_thread_provider(provider);
}
2570
2571pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2573 if let Some(p) = provider() {
2574 if let Ok(h) = p.elem_add(a, b).await {
2575 return Some(h);
2576 }
2577 }
2578 None
2579}
2580
2581pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2583 if let Some(p) = provider() {
2584 if let Ok(h) = p.elem_hypot(a, b).await {
2585 return Some(h);
2586 }
2587 }
2588 None
2589}
2590
2591pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2593 if let Some(p) = provider() {
2594 if let Ok(h) = p.elem_max(a, b).await {
2595 return Some(h);
2596 }
2597 }
2598 None
2599}
2600
2601pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2603 if let Some(p) = provider() {
2604 if let Ok(h) = p.elem_min(a, b).await {
2605 return Some(h);
2606 }
2607 }
2608 None
2609}
2610
2611pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2613 if let Some(p) = provider() {
2614 if let Ok(h) = p.elem_atan2(y, x).await {
2615 return Some(h);
2616 }
2617 }
2618 None
2619}
2620
/// Tensor contents held in host memory, together with shape and storage tag.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    /// Flattened element values. Element order (row- vs column-major) is not
    /// established here — confirm against producers.
    pub data: Vec<f64>,
    /// Dimension sizes of the tensor.
    pub shape: Vec<usize>,
    /// Storage tag carried alongside the data.
    pub storage: GpuTensorStorage,
}
2628
/// Borrowed view of host tensor data plus its shape (non-owning counterpart
/// of `HostTensorOwned`).
#[derive(Debug)]
pub struct HostTensorView<'a> {
    /// Flattened element values.
    pub data: &'a [f64],
    /// Dimension sizes of the tensor.
    pub shape: &'a [usize],
}
2634
/// Borrowed view over one axis's sample values for a meshgrid-style
/// operation (judging by the name; shape metadata is not carried here).
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    /// Sample values along this axis.
    pub data: &'a [f64],
}
2640
/// Result of a provider-side meshgrid operation: one output tensor handle
/// per requested grid.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    /// GPU handles for the generated grid tensors.
    pub outputs: Vec<GpuTensorHandle>,
}
2646
/// How a scale vector is applied in `MatmulEpilogue` (`row_op` / `col_op`):
/// element-wise multiply or divide.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    Multiply,
    Divide,
}
2657
/// Post-processing applied to a matmul result. `noop()` builds the identity
/// configuration (alpha=1, beta=0, no scales/clamps/power/diag output).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    /// Scalar factor; 1.0 is the identity value. Presumably the BLAS-style
    /// multiplier on the product — confirm with provider kernels.
    pub alpha: f64,
    /// Scalar factor; 0.0 is the identity value. Presumably the BLAS-style
    /// accumulation weight — confirm with provider kernels.
    pub beta: f64,
    /// Optional per-row scale vector, applied with `row_op`.
    pub row_scale: Option<GpuTensorHandle>,
    /// Optional per-column scale vector, applied with `col_op`.
    pub col_scale: Option<GpuTensorHandle>,
    /// Multiply or divide by `row_scale`.
    pub row_op: ScaleOp,
    /// Multiply or divide by `col_scale`.
    pub col_op: ScaleOp,
    /// Optional lower clamp bound on the output.
    #[serde(default)]
    pub clamp_min: Option<f64>,
    /// Optional upper clamp bound on the output.
    #[serde(default)]
    pub clamp_max: Option<f64>,
    /// Optional exponent applied to the output.
    #[serde(default)]
    pub pow_exponent: Option<f64>,
    /// Optional handle receiving the result's diagonal (semantics are
    /// provider-defined; not visible here).
    #[serde(default)]
    pub diag_output: Option<GpuTensorHandle>,
}
2685
2686impl MatmulEpilogue {
2687 pub fn noop() -> Self {
2688 Self {
2689 alpha: 1.0,
2690 beta: 0.0,
2691 row_scale: None,
2692 col_scale: None,
2693 row_op: ScaleOp::Multiply,
2694 col_op: ScaleOp::Multiply,
2695 clamp_min: None,
2696 clamp_max: None,
2697 pow_exponent: None,
2698 diag_output: None,
2699 }
2700 }
2701 pub fn is_noop(&self) -> bool {
2702 self.alpha == 1.0
2703 && self.beta == 0.0
2704 && self.row_scale.is_none()
2705 && self.col_scale.is_none()
2706 && self.clamp_min.is_none()
2707 && self.clamp_max.is_none()
2708 && self.pow_exponent.is_none()
2709 && self.diag_output.is_none()
2710 }
2711}
2712
/// Parameters for a power-iteration step epilogue (see provider
/// implementations for the kernel that consumes this).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct PowerStepEpilogue {
    // Stabilizing epsilon; defaults to 0.0. Exact use is provider-defined
    // and not visible here — confirm with the consuming kernels.
    pub epsilon: f64,
}
2717
// Default configuration uses epsilon = 0.0.
impl Default for PowerStepEpilogue {
    fn default() -> Self {
        Self { epsilon: 0.0 }
    }
}
2723
2724#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
2725pub struct ImageNormalizeDescriptor {
2726 pub batch: usize,
2727 pub height: usize,
2728 pub width: usize,
2729 pub epsilon: f64,
2730 #[serde(default)]
2731 pub gain: Option<f64>,
2732 #[serde(default)]
2733 pub bias: Option<f64>,
2734 #[serde(default)]
2735 pub gamma: Option<f64>,
2736}