1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
// Signatures of the pluggable global hooks registered by the runtime.
type ResidencyMarkFn = fn(&GpuTensorHandle);
type ResidencyClearFn = fn(&GpuTensorHandle);
type SequenceThresholdFn = fn() -> Option<usize>;
type WorkgroupSizeHintFn = fn() -> Option<u32>;

// Write-once registration slots: the first registered hook wins; later
// registrations are ignored (see the register_* functions below).
static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();

// Process-wide side tables keyed by `GpuTensorHandle::buffer_id`.
// Buffer ids currently flagged as "logical" tensors.
static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
// Times each buffer id was (re)marked logical; entry removed when unmarked.
static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
// Buffer ids known to hold transposed data, with their pre-transpose extents.
static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

// Per-buffer element precision, when a provider has recorded one.
static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
// Per-buffer storage layout; absent entries are treated as `Real`
// (see `handle_storage`).
static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
36
/// Pre-transpose extents recorded for a handle whose data is stored transposed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// Row count of the un-transposed (base) matrix.
    pub base_rows: usize,
    /// Column count of the un-transposed (base) matrix.
    pub base_cols: usize,
}
42
43pub fn register_residency_mark(handler: ResidencyMarkFn) {
46 let _ = RESIDENCY_MARK.set(handler);
47}
48
49pub fn mark_residency(handle: &GpuTensorHandle) {
52 if let Some(handler) = RESIDENCY_MARK.get() {
53 handler(handle);
54 }
55}
56
57pub fn register_residency_clear(handler: ResidencyClearFn) {
61 let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64pub fn clear_residency(handle: &GpuTensorHandle) {
67 if let Some(handler) = RESIDENCY_CLEAR.get() {
68 handler(handle);
69 }
70}
71
72pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76 let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79pub fn sequence_threshold_hint() -> Option<usize> {
81 SEQUENCE_THRESHOLD_PROVIDER
82 .get()
83 .and_then(|provider| provider())
84}
85
86pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90 let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93pub fn workgroup_size_hint() -> Option<u32> {
95 WORKGROUP_SIZE_HINT_PROVIDER
96 .get()
97 .and_then(|provider| provider())
98}
99
100pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103 provider().and_then(|p| p.export_context(kind))
104}
105
/// Asks the active provider for a wgpu view of `handle`'s backing buffer.
///
/// `None` when no provider is active or it cannot export the buffer.
#[cfg(feature = "wgpu")]
pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
    let active = provider()?;
    active.export_wgpu_buffer(handle)
}
113
114pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118 guard.insert(handle.buffer_id, precision);
119 }
120}
121
122pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124 HANDLE_PRECISIONS
125 .read()
126 .ok()
127 .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133 guard.remove(&handle.buffer_id);
134 }
135}
136
137pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
140 if let Ok(mut guard) = LOGICAL_HANDLES.write() {
141 if logical {
142 guard.insert(handle.buffer_id);
143 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
144 *hits.entry(handle.buffer_id).or_insert(0) += 1;
145 }
146 } else {
147 guard.remove(&handle.buffer_id);
148 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
149 hits.remove(&handle.buffer_id);
150 }
151 }
152 }
153}
154
/// Convenience wrapper that unmarks `handle` as logical (which also drops
/// its hit counter; see `set_handle_logical`).
pub fn clear_handle_logical(handle: &GpuTensorHandle) {
    set_handle_logical(handle, false);
}
159
160pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162 LOGICAL_HANDLES
163 .read()
164 .map(|guard| guard.contains(&handle.buffer_id))
165 .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169 LOGICAL_HANDLE_HITS
170 .read()
171 .ok()
172 .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177 guard.insert(
178 handle.buffer_id,
179 TransposeInfo {
180 base_rows,
181 base_cols,
182 },
183 );
184 }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189 guard.remove(&handle.buffer_id);
190 }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194 TRANSPOSED_HANDLES
195 .read()
196 .ok()
197 .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
200pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
201 handle_transpose_info(handle).is_some()
202}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
205pub enum GpuTensorStorage {
206 Real,
207 ComplexInterleaved,
208}
209
210impl Default for GpuTensorStorage {
211 fn default() -> Self {
212 Self::Real
213 }
214}
215
/// Opaque reference to a device-resident tensor buffer.
///
/// `buffer_id` keys all of this module's side tables (precision, storage,
/// logical flag, transpose info).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuTensorHandle {
    /// Logical extents of the tensor.
    pub shape: Vec<usize>,
    /// Device the buffer lives on.
    pub device_id: u32,
    /// Provider-assigned id of the backing buffer.
    pub buffer_id: u64,
}
222
/// Device description reported by a provider.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    pub device_id: u32,
    pub name: String,
    pub vendor: String,
    /// Total device memory, when the backend can report it.
    pub memory_bytes: Option<u64>,
    /// Backend label, when known.
    pub backend: Option<String>,
}
231
/// Paired value/index tensors produced by a dimension reduction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Cumulative-minimum output: running values plus source indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Cummax returns the same value/index pair shape as cummin.
pub type ProviderCummaxResult = ProviderCumminResult;
249
/// Consumers that can receive an exported provider context.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelContextKind {
    Plotting,
}

/// Backend-specific context exported by a provider (see `export_context`).
#[derive(Clone)]
pub enum AccelContextHandle {
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}
262
impl AccelContextHandle {
    /// Borrows the wgpu context when this handle wraps one.
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        match self {
            AccelContextHandle::Wgpu(ctx) => Some(ctx),
        }
    }
}
272
/// Shared wgpu device state exported to embedders; the `Arc` fields make
/// cloning this handle cheap.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    pub adapter_info: wgpu::AdapterInfo,
    pub limits: wgpu::Limits,
    pub features: wgpu::Features,
}
285
/// Shared view of a tensor's underlying wgpu buffer.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    pub buffer: Arc<wgpu::Buffer>,
    // NOTE(review): given the separate `element_size`, `len` appears to be
    // an element count rather than bytes — confirm with the provider impl.
    pub len: usize,
    pub shape: Vec<usize>,
    /// Size of one element in bytes.
    pub element_size: usize,
    pub precision: ProviderPrecision,
}
296
297pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
298 if let Ok(mut guard) = HANDLE_STORAGES.write() {
299 guard.insert(handle.buffer_id, storage);
300 }
301}
302
303pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
304 HANDLE_STORAGES
305 .read()
306 .ok()
307 .and_then(|guard| guard.get(&handle.buffer_id).cloned())
308 .unwrap_or(GpuTensorStorage::Real)
309}
310
311pub fn clear_handle_storage(handle: &GpuTensorHandle) {
312 if let Ok(mut guard) = HANDLE_STORAGES.write() {
313 guard.remove(&handle.buffer_id);
314 }
315}
316
/// Operations supported by the paged (`pagefun`) dispatch path.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    Mtimes,
}

/// A batched, page-wise operation request.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    pub op: PagefunOp,
    pub inputs: Vec<GpuTensorHandle>,
    pub output_shape: Vec<usize>,
    /// Page (batch) extents of the output.
    pub page_dims: Vec<usize>,
    /// Per-input page extents.
    pub input_page_dims: Vec<Vec<usize>>,
}
330
/// Search direction for find-style queries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}

/// Result of a find-style query: linear indices plus their row/column
/// decomposition, and optionally the matched values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    pub linear: GpuTensorHandle,
    pub rows: GpuTensorHandle,
    pub cols: GpuTensorHandle,
    /// Matched values, when the caller asked for them.
    pub values: Option<GpuTensorHandle>,
}
344
/// Lower/upper bandwidth pair of a banded matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    pub lower: u32,
    pub upper: u32,
}

/// Symmetry-test variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}

/// Hermitian-test variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}

/// LU factorization output in several layouts.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    // NOTE(review): presumably L and U packed into one matrix — confirm
    // against the provider implementation.
    pub combined: GpuTensorHandle,
    pub lower: GpuTensorHandle,
    pub upper: GpuTensorHandle,
    /// Row permutation in matrix form.
    pub perm_matrix: GpuTensorHandle,
    /// Row permutation in index-vector form.
    pub perm_vector: GpuTensorHandle,
}
371
/// Cholesky factorization output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    pub factor: GpuTensorHandle,
    // NOTE(review): diagnostic code from the factorization; 0 conventionally
    // means success — confirm against the provider implementation.
    pub info: u32,
}

/// QR factorization output with both permutation representations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}
386
/// QR-by-power-iteration output; mirrors `ProviderQrResult` field-for-field.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

/// Structure hints and tolerances steering a `linsolve` call.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    /// Request an estimate of the reciprocal condition number.
    pub need_rcond: bool,
    pub rcond: Option<f64>,
}

/// Linear-system solution plus its reciprocal condition estimate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    pub reciprocal_condition: f64,
}
413
/// Options for the pseudo-inverse.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    /// Cutoff tolerance; `None` lets the provider choose.
    pub tolerance: Option<f64>,
}

/// Centering/scaling pair for polynomial evaluation in normalized `x`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}

/// Options for polynomial evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}

/// Placeholder: matrix inversion currently takes no options.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}

/// Polynomial least-squares fit output (host-side scalars).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    pub coefficients: Vec<f64>,
    // NOTE(review): flattened triangular factor from the fit; element layout
    // not visible here — confirm against the provider implementation.
    pub r_matrix: Vec<f64>,
    /// Norm of the residuals.
    pub normr: f64,
    /// Degrees of freedom.
    pub df: f64,
    // NOTE(review): presumably the [center, scale] normalization pair —
    // confirm ordering with the provider implementation.
    pub mu: [f64; 2],
}
441
/// Quotient-rule derivative of a rational polynomial: numerator and
/// denominator coefficient tensors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}

/// Norm selector for condition-number estimation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}

/// Norm selector for norm-style requests, including arbitrary `p`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    Nuc,
    /// General p-norm with the given exponent.
    P(f64),
}

/// Eigendecomposition output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    pub eigenvalues: GpuTensorHandle,
    /// Eigenvalues arranged as a diagonal matrix.
    pub diagonal: GpuTensorHandle,
    /// Right eigenvectors.
    pub right: GpuTensorHandle,
    /// Left eigenvectors, when they were requested.
    pub left: Option<GpuTensorHandle>,
}
478
/// How a QR permutation should be materialized.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}

/// Options controlling a QR factorization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    /// Economy-size factorization when true.
    pub economy: bool,
    pub pivot: ProviderQrPivot,
}
490
491impl Default for ProviderQrOptions {
492 fn default() -> Self {
493 Self {
494 economy: false,
495 pivot: ProviderQrPivot::Matrix,
496 }
497 }
498}
499
/// Floating-point precision a provider computes in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}

/// Whether reductions use the two-pass algorithm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionTwoPassMode {
    /// Let the provider decide.
    Auto,
    ForceOn,
    ForceOff,
}
512
513impl ReductionTwoPassMode {
514 pub fn as_str(self) -> &'static str {
515 match self {
516 ReductionTwoPassMode::Auto => "auto",
517 ReductionTwoPassMode::ForceOn => "force_on",
518 ReductionTwoPassMode::ForceOff => "force_off",
519 }
520 }
521}
522
/// How a reduction's raw sum is post-scaled.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ReductionFlavor {
    Sum,
    Mean,
    /// Multiply the sum by an arbitrary factor.
    CustomScale(f64),
}
529
530impl ReductionFlavor {
531 pub fn is_mean(self) -> bool {
532 matches!(self, ReductionFlavor::Mean)
533 }
534
535 pub fn scale(self, reduce_len: usize) -> f64 {
536 match self {
537 ReductionFlavor::Sum => 1.0,
538 ReductionFlavor::Mean => {
539 if reduce_len == 0 {
540 1.0
541 } else {
542 1.0 / reduce_len as f64
543 }
544 }
545 ReductionFlavor::CustomScale(scale) => scale,
546 }
547 }
548}
549
/// Normalization divisor choice for correlation coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}

/// Row-handling policy for correlation coefficients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}

/// Options for correlation-coefficient computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}
571
572impl Default for CorrcoefOptions {
573 fn default() -> Self {
574 Self {
575 normalization: CorrcoefNormalization::Unbiased,
576 rows: CorrcoefRows::All,
577 }
578 }
579}
580
/// Normalization divisor choice for covariance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}

/// Row-handling policy for covariance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}

/// Options for covariance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    /// True when the caller supplies an observation-weight vector.
    pub has_weight_vector: bool,
}
603
604impl Default for CovarianceOptions {
605 fn default() -> Self {
606 Self {
607 normalization: CovNormalization::Unbiased,
608 rows: CovRows::All,
609 has_weight_vector: false,
610 }
611 }
612}
613
/// Divisor choice for standard deviation / variance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    Sample,
    Population,
}

/// Whether NaNs participate in or are omitted from a computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}

/// Direction of cumulative scans.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}

/// Sort direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}

/// Key used to compare elements while sorting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    Auto,
    Real,
    Abs,
}

/// Host-side sort output: sorted values plus their source indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}

/// One column/order pair for multi-key row sorting.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    pub index: usize,
    pub order: SortOrder,
}
662
/// Ordering of the deduplicated output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}

/// Which duplicate's index is reported.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}

/// Options for `unique`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    /// Treat whole rows as the unit of uniqueness.
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}

/// `unique` output: the deduplicated values plus its `ia`/`ic` index maps
/// (presumably the MATLAB-style input/output index maps — confirm with the
/// provider contract).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ic: HostTensorOwned,
}
692
/// Ordering of the set-union output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}

/// Options for set union.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    /// Operate on whole rows rather than scalar elements.
    pub rows: bool,
    pub order: UnionOrder,
}

/// Set-union output with its `ia`/`ib` index maps.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ib: HostTensorOwned,
}
714
/// Filter kernels constructible by `fspecial`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    /// Box (averaging) filter of the given extent.
    Average {
        rows: u32,
        cols: u32,
    },
    /// Circular averaging filter.
    Disk {
        radius: f64,
        size: u32,
    },
    /// Gaussian low-pass filter.
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// Laplacian approximation shaped by `alpha`.
    Laplacian {
        alpha: f64,
    },
    /// Laplacian-of-Gaussian filter.
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// Linear motion-blur kernel.
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        // NOTE(review): presumably a supersampling factor used while
        // rasterizing the motion line — confirm with the provider impl.
        oversample: u32,
    },
    Prewitt,
    Sobel,
    /// Unsharp-masking kernel shaped by `alpha`.
    Unsharp {
        alpha: f64,
    },
}

/// Wrapper request for `fspecial` kernel construction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}
757
/// Boundary handling for image filtering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    Constant,
    Replicate,
    Symmetric,
    Circular,
}

/// Output extent of the filtered image.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}

/// Whether the kernel is applied as correlation or convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}

/// Options for `imfilter`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    /// Fill value used with `ImfilterPadding::Constant`.
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}
790
791impl Default for ImfilterOptions {
792 fn default() -> Self {
793 Self {
794 padding: ImfilterPadding::Constant,
795 constant_value: 0.0,
796 shape: ImfilterShape::Same,
797 mode: ImfilterMode::Correlation,
798 }
799 }
800}
801
/// Ordering of the set-difference output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}

/// Options for set difference.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// Operate on whole rows rather than scalar elements.
    pub rows: bool,
    pub order: SetdiffOrder,
}

/// Set-difference output with its `ia` index map.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}

/// Options for membership testing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// Compare whole rows rather than scalar elements.
    pub rows: bool,
}
828
/// Host-side logical (mask) tensor backed by `u8` data.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}

/// Membership-test output: a logical mask plus match locations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}
842
/// Output extent of 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

/// Orientation of the 1-D signal (row vector vs. column vector).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options for 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}
861
/// Options for IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Dimension along which to filter.
    pub dim: usize,
    /// Optional initial filter state.
    pub zi: Option<GpuTensorHandle>,
}

/// IIR filtering output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    pub output: GpuTensorHandle,
    /// Final filter state, when the provider produced one.
    pub final_state: Option<GpuTensorHandle>,
}

/// First two raw moments of the data.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    // NOTE(review): name suggests E[x^2] — confirm with the provider impl.
    pub ex2: GpuTensorHandle,
}
883
/// Invocation count and accumulated wall time for one dispatch category.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    pub count: u64,
    pub total_wall_time_ns: u64,
}

/// Tally of fallbacks from the accelerated path, keyed by reason text.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderFallbackStat {
    pub reason: String,
    pub count: u64,
}

/// Aggregate runtime counters reported by a provider.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    pub linsolve: ProviderDispatchStats,
    pub mldivide: ProviderDispatchStats,
    pub mrdivide: ProviderDispatchStats,
    pub upload_bytes: u64,
    pub download_bytes: u64,
    pub solve_fallbacks: Vec<ProviderFallbackStat>,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Per-layout cache counters, when the provider breaks them out.
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}
918
/// Bind-group cache counters for a single layout tag.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}

/// One key/value attribute attached to a kernel-launch record.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}

/// Record of a single kernel launch with shape and tuning attributes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    pub kernel: String,
    pub precision: Option<String>,
    pub shape: Vec<KernelAttrTelemetry>,
    pub tuning: Vec<KernelAttrTelemetry>,
}
939
/// Boxed future returned by provider methods; note there is no `Send` bound.
pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
/// Future resolving to a downloaded host-side tensor.
pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
942
943fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
944 Box::pin(async move { Err(anyhow::anyhow!(message)) })
945}
946
947pub trait AccelProvider: Send + Sync {
949 fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
950 fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
951 fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
952 fn device_info(&self) -> String;
953 fn device_id(&self) -> u32 {
954 0
955 }
956
957 fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
960 None
961 }
962
963 #[cfg(feature = "wgpu")]
965 fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
966 let _ = _handle;
967 None
968 }
969
970 fn gather_linear(
973 &self,
974 _source: &GpuTensorHandle,
975 _indices: &[u32],
976 _output_shape: &[usize],
977 ) -> anyhow::Result<GpuTensorHandle> {
978 Err(anyhow::anyhow!("gather_linear not supported by provider"))
979 }
980
981 fn scatter_linear(
985 &self,
986 _target: &GpuTensorHandle,
987 _indices: &[u32],
988 _values: &GpuTensorHandle,
989 ) -> anyhow::Result<()> {
990 Err(anyhow::anyhow!("scatter_linear not supported by provider"))
991 }
992
993 fn device_info_struct(&self) -> ApiDeviceInfo {
995 ApiDeviceInfo {
996 device_id: 0,
997 name: self.device_info(),
998 vendor: String::new(),
999 memory_bytes: None,
1000 backend: None,
1001 }
1002 }
1003
1004 fn precision(&self) -> ProviderPrecision {
1005 ProviderPrecision::F64
1006 }
1007
1008 fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
1010 Err(anyhow::anyhow!("read_scalar not supported by provider"))
1011 }
1012
1013 fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1015 Err(anyhow::anyhow!("zeros not supported by provider"))
1016 }
1017
1018 fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1020 Err(anyhow::anyhow!("ones not supported by provider"))
1021 }
1022
1023 fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1025 self.zeros(&prototype.shape)
1026 }
1027
1028 fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
1030 if value == 0.0 {
1031 return self.zeros(shape);
1032 }
1033 if let Ok(base) = self.zeros(shape) {
1034 match self.scalar_add(&base, value) {
1035 Ok(out) => {
1036 let _ = self.free(&base);
1037 return Ok(out);
1038 }
1039 Err(_) => {
1040 let _ = self.free(&base);
1041 }
1042 }
1043 }
1044 let len: usize = shape.iter().copied().product();
1045 let data = vec![value; len];
1046 let view = HostTensorView { data: &data, shape };
1047 self.upload(&view)
1048 }
1049
1050 fn fill_like(
1052 &self,
1053 prototype: &GpuTensorHandle,
1054 value: f64,
1055 ) -> anyhow::Result<GpuTensorHandle> {
1056 if value == 0.0 {
1057 return self.zeros_like(prototype);
1058 }
1059 if let Ok(base) = self.zeros_like(prototype) {
1060 match self.scalar_add(&base, value) {
1061 Ok(out) => {
1062 let _ = self.free(&base);
1063 return Ok(out);
1064 }
1065 Err(_) => {
1066 let _ = self.free(&base);
1067 }
1068 }
1069 }
1070 self.fill(&prototype.shape, value)
1071 }
1072
1073 fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1075 self.ones(&prototype.shape)
1076 }
1077
1078 fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1080 Err(anyhow::anyhow!("eye not supported by provider"))
1081 }
1082
1083 fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1085 self.eye(&prototype.shape)
1086 }
1087
1088 fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1090 Err(anyhow::anyhow!("meshgrid not supported by provider"))
1091 }
1092
1093 fn diag_from_vector(
1095 &self,
1096 _vector: &GpuTensorHandle,
1097 _offset: isize,
1098 ) -> anyhow::Result<GpuTensorHandle> {
1099 Err(anyhow::anyhow!(
1100 "diag_from_vector not supported by provider"
1101 ))
1102 }
1103
1104 fn diag_extract(
1106 &self,
1107 _matrix: &GpuTensorHandle,
1108 _offset: isize,
1109 ) -> anyhow::Result<GpuTensorHandle> {
1110 Err(anyhow::anyhow!("diag_extract not supported by provider"))
1111 }
1112
1113 fn tril<'a>(
1115 &'a self,
1116 _matrix: &'a GpuTensorHandle,
1117 _offset: isize,
1118 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1119 Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1120 }
1121
1122 fn triu<'a>(
1124 &'a self,
1125 _matrix: &'a GpuTensorHandle,
1126 _offset: isize,
1127 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1128 Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1129 }
1130
1131 fn polyval(
1133 &self,
1134 _coefficients: &GpuTensorHandle,
1135 _points: &GpuTensorHandle,
1136 _options: &ProviderPolyvalOptions,
1137 ) -> anyhow::Result<GpuTensorHandle> {
1138 Err(anyhow::anyhow!("polyval not supported by provider"))
1139 }
1140
1141 fn polyfit<'a>(
1143 &'a self,
1144 _x: &'a GpuTensorHandle,
1145 _y: &'a GpuTensorHandle,
1146 _degree: usize,
1147 _weights: Option<&'a GpuTensorHandle>,
1148 ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1149 Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1150 }
1151
1152 fn polyder_single<'a>(
1154 &'a self,
1155 _polynomial: &'a GpuTensorHandle,
1156 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1157 Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1158 }
1159
1160 fn polyder_product<'a>(
1162 &'a self,
1163 _p: &'a GpuTensorHandle,
1164 _q: &'a GpuTensorHandle,
1165 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1166 Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1167 }
1168
1169 fn polyder_quotient<'a>(
1171 &'a self,
1172 _u: &'a GpuTensorHandle,
1173 _v: &'a GpuTensorHandle,
1174 ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1175 Box::pin(async move {
1176 Err(anyhow::anyhow!(
1177 "polyder_quotient not supported by provider"
1178 ))
1179 })
1180 }
1181
1182 fn polyint(
1184 &self,
1185 _polynomial: &GpuTensorHandle,
1186 _constant: f64,
1187 ) -> anyhow::Result<GpuTensorHandle> {
1188 Err(anyhow::anyhow!("polyint not supported by provider"))
1189 }
1190
1191 fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1193 Err(anyhow::anyhow!("random_uniform not supported by provider"))
1194 }
1195
1196 fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1198 self.random_uniform(&prototype.shape)
1199 }
1200
1201 fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1203 Err(anyhow::anyhow!("random_normal not supported by provider"))
1204 }
1205
1206 fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1208 self.random_normal(&prototype.shape)
1209 }
1210
1211 fn stochastic_evolution(
1212 &self,
1213 _state: &GpuTensorHandle,
1214 _drift: f64,
1215 _scale: f64,
1216 _steps: u32,
1217 ) -> anyhow::Result<GpuTensorHandle> {
1218 Err(anyhow::anyhow!(
1219 "stochastic_evolution not supported by provider"
1220 ))
1221 }
1222
1223 fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1225 Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1226 }
1227
1228 fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1230 Err(anyhow::anyhow!("fspecial not supported by provider"))
1231 }
1232
1233 fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
1236 Err(anyhow::anyhow!("peaks not supported by provider"))
1237 }
1238
1239 fn peaks_xy(
1242 &self,
1243 _x: &GpuTensorHandle,
1244 _y: &GpuTensorHandle,
1245 ) -> anyhow::Result<GpuTensorHandle> {
1246 Err(anyhow::anyhow!("peaks_xy not supported by provider"))
1247 }
1248
1249 fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1250 Err(anyhow::anyhow!("hann_window not supported by provider"))
1251 }
1252
1253 fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1254 Err(anyhow::anyhow!("hamming_window not supported by provider"))
1255 }
1256
1257 fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1258 Err(anyhow::anyhow!("blackman_window not supported by provider"))
1259 }
1260
1261 fn imfilter<'a>(
1263 &'a self,
1264 _image: &'a GpuTensorHandle,
1265 _kernel: &'a GpuTensorHandle,
1266 _options: &'a ImfilterOptions,
1267 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1268 unsupported_future("imfilter not supported by provider")
1269 }
1270
1271 fn random_integer_range(
1273 &self,
1274 _lower: i64,
1275 _upper: i64,
1276 _shape: &[usize],
1277 ) -> anyhow::Result<GpuTensorHandle> {
1278 Err(anyhow::anyhow!(
1279 "random_integer_range not supported by provider"
1280 ))
1281 }
1282
1283 fn random_integer_like(
1285 &self,
1286 prototype: &GpuTensorHandle,
1287 lower: i64,
1288 upper: i64,
1289 ) -> anyhow::Result<GpuTensorHandle> {
1290 self.random_integer_range(lower, upper, &prototype.shape)
1291 }
1292
1293 fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1295 Err(anyhow!("random_permutation not supported by provider"))
1296 }
1297
    // Convenience wrapper: permutation on the same device as `prototype`.
    // Delegates to `random_permutation` (the prototype is ignored by this
    // default), so overriding that method enables this one too.
    fn random_permutation_like(
        &self,
        _prototype: &GpuTensorHandle,
        n: usize,
        k: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        self.random_permutation(n, k)
    }

    // Covariance matrix; default: unsupported.
    fn covariance<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _second: Option<&'a GpuTensorHandle>,
        _weights: Option<&'a GpuTensorHandle>,
        _options: &'a CovarianceOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("covariance not supported by provider")
    }

    // Correlation coefficients; default: unsupported.
    fn corrcoef<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: &'a CorrcoefOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("corrcoef not supported by provider")
    }

    // Linearly spaced values; default: unsupported.
    fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("linspace not supported by provider"))
    }
    // --- Element-wise binary operations ------------------------------------
    // Arithmetic (add/mul/max/min/sub/div/pow/hypot) and comparisons
    // (ge/le/lt/gt/eq/ne). Every default reports the operation as
    // unsupported; providers override the ones they can run on-device.
    fn elem_add<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_add not supported by provider")
    }
    fn elem_mul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_mul not supported by provider")
    }
    fn elem_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_max not supported by provider")
    }
    fn elem_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_min not supported by provider")
    }
    fn elem_sub<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_sub not supported by provider")
    }
    fn elem_div<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_div not supported by provider")
    }
    fn elem_pow<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_pow not supported by provider")
    }

    fn elem_hypot<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_hypot not supported by provider")
    }
    fn elem_ge<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ge not supported by provider")
    }
    fn elem_le<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_le not supported by provider")
    }
    fn elem_lt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_lt not supported by provider")
    }
    fn elem_gt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_gt not supported by provider")
    }
    fn elem_eq<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_eq not supported by provider")
    }
    fn elem_ne<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ne not supported by provider")
    }
    // --- Logical operations and predicates ---------------------------------
    // Defaults are unsupported except `logical_islogical`, which is answered
    // from the handle metadata tracked in this module.
    fn logical_and(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_and not supported by provider"))
    }
    fn logical_or(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_or not supported by provider"))
    }
    fn logical_xor(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_xor not supported by provider"))
    }
    fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_not not supported by provider"))
    }
    // Answered locally via the module-level handle registry; no provider
    // work is required, so this default never errors.
    fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
        Ok(handle_is_logical(a))
    }
    fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
        Err(anyhow::anyhow!("logical_isreal not supported by provider"))
    }
    fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "logical_isfinite not supported by provider"
        ))
    }
    fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_isnan not supported by provider"))
    }
    fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_isinf not supported by provider"))
    }
    // --- Unary element-wise operations -------------------------------------
    // Trig/hyperbolic, rounding, complex-part extraction, exponential/log,
    // and type-conversion kernels. Every default reports the operation as
    // unsupported; providers override the ones they implement.
    fn elem_atan2<'a>(
        &'a self,
        _y: &'a GpuTensorHandle,
        _x: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_atan2 not supported by provider")
    }
    fn unary_sin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sin not supported by provider")
    }
    fn unary_gamma<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_gamma not supported by provider")
    }
    fn unary_factorial<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_factorial not supported by provider")
    }
    fn unary_asinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asinh not supported by provider")
    }
    fn unary_sinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sinh not supported by provider")
    }
    fn unary_cosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cosh not supported by provider")
    }
    fn unary_asin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asin not supported by provider")
    }
    fn unary_acos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acos not supported by provider")
    }
    fn unary_acosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acosh not supported by provider")
    }
    fn unary_tan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tan not supported by provider")
    }
    fn unary_tanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tanh not supported by provider")
    }
    fn unary_atan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atan not supported by provider")
    }
    fn unary_atanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atanh not supported by provider")
    }
    fn unary_ceil<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_ceil not supported by provider")
    }
    fn unary_floor<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_floor not supported by provider")
    }
    fn unary_round<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_round not supported by provider")
    }
    fn unary_fix<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_fix not supported by provider")
    }
    fn unary_cos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cos not supported by provider")
    }
    fn unary_angle<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_angle not supported by provider")
    }
    fn unary_imag<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_imag not supported by provider")
    }
    fn unary_real<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_real not supported by provider")
    }
    fn unary_conj<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_conj not supported by provider")
    }
    fn unary_abs<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_abs not supported by provider")
    }
    fn unary_sign<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sign not supported by provider")
    }
    fn unary_exp<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_exp not supported by provider")
    }
    fn unary_expm1<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_expm1 not supported by provider")
    }
    fn unary_log<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log not supported by provider")
    }
    fn unary_log2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log2 not supported by provider")
    }
    fn unary_log10<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log10 not supported by provider")
    }
    fn unary_log1p<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log1p not supported by provider")
    }
    fn unary_sqrt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sqrt not supported by provider")
    }
    fn unary_double<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_double not supported by provider")
    }
    fn unary_single<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_single not supported by provider")
    }
    fn unary_pow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_pow2 not supported by provider")
    }
    fn unary_nextpow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_nextpow2 not supported by provider")
    }
    // --- Scalar-operand operations -----------------------------------------
    // One tensor operand plus one scalar (the `r`-prefixed variants take the
    // scalar on the left-hand side). Defaults: unsupported.
    fn pow2_scale(
        &self,
        _mantissa: &GpuTensorHandle,
        _exponent: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pow2_scale not supported by provider"))
    }
    fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
    }
    fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
    }
    fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_add not supported by provider"))
    }
    fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_sub not supported by provider"))
    }
    fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_mul not supported by provider"))
    }
    fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_max not supported by provider"))
    }
    fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_min not supported by provider"))
    }
    fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_div not supported by provider"))
    }
    // --- Sorting and matrix products ---------------------------------------
    // Defaults: unsupported.
    fn sort_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _order: SortOrder,
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_dim not supported by provider")
    }
    fn sort_rows<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _columns: &'a [SortRowsColumnSpec],
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_rows not supported by provider")
    }
    fn matmul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul not supported by provider")
    }

    fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("syrk not supported by provider"))
    }
    fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pagefun not supported by provider"))
    }
1754
1755 fn matmul_epilogue<'a>(
1760 &'a self,
1761 a: &'a GpuTensorHandle,
1762 b: &'a GpuTensorHandle,
1763 epilogue: &'a MatmulEpilogue,
1764 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1765 Box::pin(async move {
1766 if epilogue.is_noop() {
1767 return self.matmul(a, b).await;
1768 }
1769 Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1770 })
1771 }
    // Fused image normalization; default: unsupported.
    fn image_normalize<'a>(
        &'a self,
        _input: &'a GpuTensorHandle,
        _desc: &'a ImageNormalizeDescriptor,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("image_normalize fusion not supported by provider")
    }
    // Power-iteration step with fused normalization; default: unsupported.
    fn matmul_power_step<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _epilogue: &'a PowerStepEpilogue,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul_power_step normalization not supported by provider")
    }
    // Linear-system solve; default: unsupported.
    fn linsolve<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _options: &'a ProviderLinsolveOptions,
    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
        unsupported_future("linsolve not supported by provider")
    }
    // Matrix inverse; default: unsupported.
    fn inv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderInvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("inv not supported by provider")
    }
    // Pseudo-inverse; default: unsupported.
    fn pinv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderPinvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("pinv not supported by provider")
    }
1809 fn cond<'a>(
1810 &'a self,
1811 _matrix: &'a GpuTensorHandle,
1812 _norm: ProviderCondNorm,
1813 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1814 Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
1815 }
1816 fn norm<'a>(
1817 &'a self,
1818 _tensor: &'a GpuTensorHandle,
1819 _order: ProviderNormOrder,
1820 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1821 Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
1822 }
1823 fn rank<'a>(
1824 &'a self,
1825 _matrix: &'a GpuTensorHandle,
1826 _tolerance: Option<f64>,
1827 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1828 Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
1829 }
1830 fn rcond<'a>(
1831 &'a self,
1832 _matrix: &'a GpuTensorHandle,
1833 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1834 Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
1835 }
1836 fn mldivide<'a>(
1837 &'a self,
1838 _lhs: &'a GpuTensorHandle,
1839 _rhs: &'a GpuTensorHandle,
1840 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1841 Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
1842 }
1843 fn mrdivide<'a>(
1844 &'a self,
1845 _lhs: &'a GpuTensorHandle,
1846 _rhs: &'a GpuTensorHandle,
1847 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1848 Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
1849 }
1850 fn eig<'a>(
1851 &'a self,
1852 _a: &'a GpuTensorHandle,
1853 _compute_left: bool,
1854 ) -> AccelProviderFuture<'a, ProviderEigResult> {
1855 Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
1856 }
1857 fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
1858 Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
1859 }
1860
1861 fn chol<'a>(
1862 &'a self,
1863 _a: &'a GpuTensorHandle,
1864 _lower: bool,
1865 ) -> AccelProviderFuture<'a, ProviderCholResult> {
1866 Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
1867 }
1868 fn qr<'a>(
1869 &'a self,
1870 _a: &'a GpuTensorHandle,
1871 _options: ProviderQrOptions,
1872 ) -> AccelProviderFuture<'a, ProviderQrResult> {
1873 Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
1874 }
    // Optional hook: hand back the two operands that produced `product`, if
    // the provider kept them. The default keeps nothing, so it returns None.
    fn take_matmul_sources(
        &self,
        _product: &GpuTensorHandle,
    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
        None
    }
1881 fn qr_power_iter<'a>(
1882 &'a self,
1883 product: &'a GpuTensorHandle,
1884 _product_lhs: Option<&'a GpuTensorHandle>,
1885 q_handle: &'a GpuTensorHandle,
1886 options: &'a ProviderQrOptions,
1887 ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
1888 let _ = (product, q_handle, options);
1889 Box::pin(async move { Ok(None) })
1890 }
    // Transpose; default: unsupported.
    fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("transpose not supported by provider"))
    }
    // 1-D convolution/filtering; default: unsupported.
    fn conv1d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _options: ProviderConv1dOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv1d not supported by provider"))
    }
    // 2-D convolution; default: unsupported.
    fn conv2d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _mode: ProviderConvMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv2d not supported by provider"))
    }
1910 fn iir_filter<'a>(
1911 &'a self,
1912 _b: &'a GpuTensorHandle,
1913 _a: &'a GpuTensorHandle,
1914 _x: &'a GpuTensorHandle,
1915 _options: ProviderIirFilterOptions,
1916 ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
1917 Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
1918 }
    // --- Shape / layout transforms and FFT ---------------------------------
    // Defaults: unsupported.
    fn permute(
        &self,
        _handle: &GpuTensorHandle,
        _order: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("permute not supported by provider"))
    }
    fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("flip not supported by provider"))
    }
    fn circshift(
        &self,
        _handle: &GpuTensorHandle,
        _shifts: &[isize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("circshift not supported by provider"))
    }
    fn diff_dim(
        &self,
        _handle: &GpuTensorHandle,
        _order: usize,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("diff_dim not supported by provider"))
    }
    fn gradient_dim(
        &self,
        _handle: &GpuTensorHandle,
        _dim: usize,
        _spacing: f64,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("gradient_dim not supported by provider"))
    }
    fn fft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_dim not supported by provider")
    }
    fn ifft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("ifft_dim not supported by provider")
    }
    fn fft_extract_real<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_extract_real not supported by provider")
    }
1976 fn unique<'a>(
1977 &'a self,
1978 _handle: &'a GpuTensorHandle,
1979 _options: &'a UniqueOptions,
1980 ) -> AccelProviderFuture<'a, UniqueResult> {
1981 Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
1982 }
1983 fn union<'a>(
1984 &'a self,
1985 _a: &'a GpuTensorHandle,
1986 _b: &'a GpuTensorHandle,
1987 _options: &'a UnionOptions,
1988 ) -> AccelProviderFuture<'a, UnionResult> {
1989 Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
1990 }
1991 fn setdiff<'a>(
1992 &'a self,
1993 _a: &'a GpuTensorHandle,
1994 _b: &'a GpuTensorHandle,
1995 _options: &'a SetdiffOptions,
1996 ) -> AccelProviderFuture<'a, SetdiffResult> {
1997 Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
1998 }
1999 fn ismember<'a>(
2000 &'a self,
2001 _a: &'a GpuTensorHandle,
2002 _b: &'a GpuTensorHandle,
2003 _options: &'a IsMemberOptions,
2004 ) -> AccelProviderFuture<'a, IsMemberResult> {
2005 Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
2006 }
2007 fn reshape(
2008 &self,
2009 handle: &GpuTensorHandle,
2010 new_shape: &[usize],
2011 ) -> anyhow::Result<GpuTensorHandle> {
2012 let mut updated = handle.clone();
2013 updated.shape = new_shape.to_vec();
2014 Ok(updated)
2015 }
    // --- Concatenation / tiling / products ---------------------------------
    // Defaults: unsupported.
    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cat not supported by provider"))
    }
    fn repmat(
        &self,
        _handle: &GpuTensorHandle,
        _reps: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("repmat not supported by provider"))
    }
    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("kron not supported by provider"))
    }
    fn cross(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
        _dim: Option<usize>,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cross not supported by provider"))
    }
    // --- Reductions ---------------------------------------------------------
    // Full-tensor variants and their `_dim` (single-axis) / `_nd`
    // (multi-axis) counterparts. Defaults: unsupported.
    fn reduce_sum<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum not supported by provider")
    }
    fn reduce_sum_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum_dim not supported by provider")
    }
    fn dot<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _dim: Option<usize>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("dot not supported by provider")
    }
    fn reduce_nnz<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz not supported by provider")
    }
    fn reduce_nnz_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz_dim not supported by provider")
    }
    fn reduce_prod<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod not supported by provider")
    }
    fn reduce_prod_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod_dim not supported by provider")
    }
    fn reduce_mean<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean not supported by provider")
    }
    fn reduce_mean_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_nd not supported by provider")
    }
    // Joint first/second moment reduction (see `ProviderMoments2`).
    fn reduce_moments_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, ProviderMoments2> {
        unsupported_future("reduce_moments_nd not supported by provider")
    }
    fn reduce_mean_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_dim not supported by provider")
    }
    fn reduce_std<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std not supported by provider")
    }
    fn reduce_std_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std_dim not supported by provider")
    }
    fn reduce_any<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any not supported by provider")
    }
    fn reduce_any_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any_dim not supported by provider")
    }
    fn reduce_all<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all not supported by provider")
    }
    fn reduce_all_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all_dim not supported by provider")
    }
    fn reduce_median<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median not supported by provider")
    }
    fn reduce_median_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median_dim not supported by provider")
    }
    fn reduce_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_min not supported by provider")
    }
    // Min/max `_dim` variants return `ReduceDimResult` (values plus extra
    // per-slice output) rather than a bare handle.
    fn reduce_min_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_min_dim not supported by provider")
    }
    fn reduce_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_max not supported by provider")
    }
    fn reduce_max_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_max_dim not supported by provider")
    }
    // --- Cumulative scans and search ---------------------------------------
    // Defaults: unsupported.
    fn cumsum_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
    }
    fn cumprod_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
    }
    fn cummin_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCumminResult> {
        Err(anyhow::anyhow!("cummin_scan not supported by provider"))
    }
    fn cummax_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCummaxResult> {
        Err(anyhow::anyhow!("cummax_scan not supported by provider"))
    }

    // Locate nonzero elements, optionally limited and directed.
    fn find(
        &self,
        _a: &GpuTensorHandle,
        _limit: Option<usize>,
        _direction: FindDirection,
    ) -> anyhow::Result<ProviderFindResult> {
        Err(anyhow::anyhow!("find not supported by provider"))
    }
2248
    // --- Fused-kernel entry points -----------------------------------------
    // `_shader` is provider-specific kernel source generated by the fusion
    // layer. Defaults: unsupported.

    // Single-output fused element-wise kernel.
    fn fused_elementwise(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _len: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "fused_elementwise not supported by provider"
        ))
    }

    // Multi-output variant: returns `_num_outputs` handles.
    fn fused_elementwise_multi(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _len: usize,
        _num_outputs: usize,
    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
        Err(anyhow::anyhow!(
            "fused_elementwise_multi not supported by provider"
        ))
    }

    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
    }

    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
    }
2291
    // Fused reduction kernel over `_num_slices` slices of `_reduce_len`
    // elements each; `_flavor` selects the reduction variant. Default:
    // unsupported.
    #[allow(clippy::too_many_arguments)]
    fn fused_reduction(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _reduce_len: usize,
        _num_slices: usize,
        _workgroup_size: u32,
        _flavor: ReductionFlavor,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
    }
2311
2312 fn warmup(&self) {}
2314
    /// Returns `(hits, misses)` for the provider's fused-kernel cache;
    /// the default reports no cache activity. Consumed by `telemetry_snapshot`.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }
2319
    /// Duration of the most recent `warmup` in milliseconds, if the provider
    /// tracks it; the default tracks nothing.
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }
2324
2325 fn telemetry_snapshot(&self) -> ProviderTelemetry {
2327 let (hits, misses) = self.fused_cache_counters();
2328 ProviderTelemetry {
2329 fused_elementwise: ProviderDispatchStats::default(),
2330 fused_reduction: ProviderDispatchStats::default(),
2331 matmul: ProviderDispatchStats::default(),
2332 linsolve: ProviderDispatchStats::default(),
2333 mldivide: ProviderDispatchStats::default(),
2334 mrdivide: ProviderDispatchStats::default(),
2335 upload_bytes: 0,
2336 download_bytes: 0,
2337 solve_fallbacks: Vec::new(),
2338 fusion_cache_hits: hits,
2339 fusion_cache_misses: misses,
2340 bind_group_cache_hits: 0,
2341 bind_group_cache_misses: 0,
2342 bind_group_cache_by_layout: None,
2343 kernel_launches: Vec::new(),
2344 }
2345 }
2346
    /// Optional hook to zero provider-side telemetry counters; the default does nothing.
    fn reset_telemetry(&self) {}
2349
    /// Preferred workgroup size for reduction kernels; defaults to 256.
    fn default_reduction_workgroup_size(&self) -> u32 {
        256
    }
2354
    /// Element count above which a two-pass reduction is preferred — presumably
    /// tuned per backend; the default is 1024.
    fn two_pass_threshold(&self) -> usize {
        1024
    }
2359
    /// Policy for selecting single- vs two-pass reductions; defaults to `Auto`.
    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
2364
2365 fn scatter_column(
2368 &self,
2369 _matrix: &GpuTensorHandle,
2370 _col_index: usize,
2371 _values: &GpuTensorHandle,
2372 ) -> anyhow::Result<GpuTensorHandle> {
2373 Err(anyhow::anyhow!("scatter_column not supported by provider"))
2374 }
2375
2376 fn scatter_row(
2379 &self,
2380 _matrix: &GpuTensorHandle,
2381 _row_index: usize,
2382 _values: &GpuTensorHandle,
2383 ) -> anyhow::Result<GpuTensorHandle> {
2384 Err(anyhow::anyhow!("scatter_row not supported by provider"))
2385 }
2386
2387 fn sub2ind(
2388 &self,
2389 _dims: &[usize],
2390 _strides: &[usize],
2391 _inputs: &[&GpuTensorHandle],
2392 _scalar_mask: &[bool],
2393 _len: usize,
2394 _output_shape: &[usize],
2395 ) -> anyhow::Result<GpuTensorHandle> {
2396 Err(anyhow::anyhow!("sub2ind not supported by provider"))
2397 }
2398
    /// Whether this provider implements `ind2sub`; the default stub does not.
    fn supports_ind2sub(&self) -> bool {
        false
    }
2403
2404 fn ind2sub(
2406 &self,
2407 _dims: &[usize],
2408 _strides: &[usize],
2409 _indices: &GpuTensorHandle,
2410 _total: usize,
2411 _len: usize,
2412 _output_shape: &[usize],
2413 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2414 Err(anyhow::anyhow!("ind2sub not supported by provider"))
2415 }
2416
2417 fn issymmetric(
2419 &self,
2420 _matrix: &GpuTensorHandle,
2421 _kind: ProviderSymmetryKind,
2422 _tolerance: f64,
2423 ) -> anyhow::Result<bool> {
2424 Err(anyhow::anyhow!(
2425 "issymmetric predicate not supported by provider"
2426 ))
2427 }
2428
2429 fn ishermitian<'a>(
2431 &'a self,
2432 _matrix: &'a GpuTensorHandle,
2433 _kind: ProviderHermitianKind,
2434 _tolerance: f64,
2435 ) -> AccelProviderFuture<'a, bool> {
2436 Box::pin(async move {
2437 Err(anyhow::anyhow!(
2438 "ishermitian predicate not supported by provider"
2439 ))
2440 })
2441 }
2442
2443 fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2445 Err(anyhow::anyhow!("bandwidth not supported by provider"))
2446 }
2447
2448 fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
2453 Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
2454 }
2455}
2456
/// Process-wide fallback provider, consulted by `provider()` when no
/// thread-local override is installed.
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
/// Maps a device id to its provider; see `provider_for_device`.
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Monotonic device-id source; starts at 1 (see `next_device_id`).
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2462
// Per-thread provider override; checked before GLOBAL_PROVIDER in `provider()`.
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}

// On wasm a single Mutex-guarded slot stands in for the per-thread override —
// presumably acceptable because wasm builds run single-threaded (TODO confirm).
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| Mutex::new(None));
2471
2472#[cfg(not(target_arch = "wasm32"))]
2473fn replace_thread_provider(
2474 provider: Option<&'static dyn AccelProvider>,
2475) -> Option<&'static dyn AccelProvider> {
2476 THREAD_PROVIDER.with(|cell| {
2477 let prev = cell.get();
2478 cell.set(provider);
2479 prev
2480 })
2481}
2482
/// Swaps the shared wasm provider override, returning the previous one.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *slot, provider)
}
2494
2495#[cfg(not(target_arch = "wasm32"))]
2496fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2497 THREAD_PROVIDER.with(|cell| cell.get())
2498}
2499
/// Reads the shared wasm provider override, if any.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    // The slot holds a Copy value, so dereferencing the guard yields a copy.
    *WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned")
}
2508
/// Installs `p` as the process-wide fallback provider and also registers it
/// under its own device id.
///
/// NOTE(review): a poisoned GLOBAL_PROVIDER lock is silently skipped, leaving
/// the global unchanged.
///
/// # Safety
///
/// The body performs no unsafe operations, so the `unsafe` marker encodes a
/// caller contract — presumably about when registration may race with lookups
/// or about the `'static` provider's validity; TODO confirm with the author.
pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
        *guard = Some(p);
    }
    register_provider_for_device(p.device_id(), p);
}
2522
/// Records `provider` in the device registry under `device_id`, replacing any
/// previous entry. A poisoned registry lock is silently skipped.
///
/// # Safety
///
/// Inherits the caller contract of `register_provider`; the body itself has no
/// unsafe operations — TODO confirm the intended invariant.
unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
    if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
        guard.insert(device_id, provider);
    }
}
2528
2529pub fn provider() -> Option<&'static dyn AccelProvider> {
2530 if let Some(p) = current_thread_provider() {
2531 return Some(p);
2532 }
2533 GLOBAL_PROVIDER
2534 .read()
2535 .ok()
2536 .and_then(|guard| guard.as_ref().copied())
2537}
2538
2539pub fn clear_provider() {
2541 replace_thread_provider(None);
2542 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2543 *guard = None;
2544 }
2545 if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2546 map.clear();
2547 }
2548}
2549
2550pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2551 PROVIDER_REGISTRY
2552 .read()
2553 .ok()
2554 .and_then(|guard| guard.get(&device_id).copied())
2555 .or_else(|| provider())
2556}
2557
/// Resolves the provider responsible for `handle`, keyed by its `device_id`.
pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
    provider_for_device(handle.device_id)
}
2561
/// Allocates the next device id. Relaxed ordering suffices: the counter is an
/// independent value with no other memory it must synchronize with.
pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
2565
/// RAII-style guard that installs a thread-local provider override and restores
/// the previous override when dropped (see `ThreadProviderGuard::set` / `Drop`).
pub struct ThreadProviderGuard {
    // Provider that was active before this guard was created.
    prev: Option<&'static dyn AccelProvider>,
}
2569
2570impl ThreadProviderGuard {
2571 pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2572 let prev = replace_thread_provider(provider);
2573 ThreadProviderGuard { prev }
2574 }
2575}
2576
2577impl Drop for ThreadProviderGuard {
2578 fn drop(&mut self) {
2579 let prev = self.prev.take();
2580 replace_thread_provider(prev);
2581 }
2582}
2583
/// Sets (or clears, with `None`) the current thread's provider override without
/// remembering the previous value; prefer `ThreadProviderGuard` for scoped use.
pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
    replace_thread_provider(provider);
}
2587
2588pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2590 if let Some(p) = provider() {
2591 if let Ok(h) = p.elem_add(a, b).await {
2592 return Some(h);
2593 }
2594 }
2595 None
2596}
2597
2598pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2600 if let Some(p) = provider() {
2601 if let Ok(h) = p.elem_hypot(a, b).await {
2602 return Some(h);
2603 }
2604 }
2605 None
2606}
2607
2608pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2610 if let Some(p) = provider() {
2611 if let Ok(h) = p.elem_max(a, b).await {
2612 return Some(h);
2613 }
2614 }
2615 None
2616}
2617
2618pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2620 if let Some(p) = provider() {
2621 if let Ok(h) = p.elem_min(a, b).await {
2622 return Some(h);
2623 }
2624 }
2625 None
2626}
2627
2628pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2630 if let Some(p) = provider() {
2631 if let Ok(h) = p.elem_atan2(y, x).await {
2632 return Some(h);
2633 }
2634 }
2635 None
2636}
2637
/// Owned host-side tensor: flat element buffer plus shape and storage tag.
/// Element layout relative to `shape` is not visible here — confirm
/// row- vs column-major against the provider's download path.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    pub data: Vec<f64>,
    pub shape: Vec<usize>,
    pub storage: GpuTensorStorage,
}
2645
2646#[derive(Debug)]
2647pub struct HostTensorView<'a> {
2648 pub data: &'a [f64],
2649 pub shape: &'a [usize],
2650}
2651
2652#[derive(Debug)]
2654pub struct MeshgridAxisView<'a> {
2655 pub data: &'a [f64],
2656}
2657
2658#[derive(Debug, Clone)]
2660pub struct ProviderMeshgridResult {
2661 pub outputs: Vec<GpuTensorHandle>,
2662}
2663
/// How a row/column scale vector is combined with matmul output in a
/// `MatmulEpilogue`: multiply by the scale, or divide by it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    Multiply,
    Divide,
}
2674
/// Post-processing applied to a matmul result. Identity values are documented
/// by `MatmulEpilogue::noop`: alpha 1.0, beta 0.0, no scales/clamps/exponent,
/// no diagonal output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    // Scalar factor on the product; 1.0 is the identity (see `noop`).
    pub alpha: f64,
    // Scalar accumulation coefficient; 0.0 is the identity (see `noop`).
    // Presumably blends in an existing output, BLAS-style — TODO confirm.
    pub beta: f64,
    // Optional per-row scale vector, applied with `row_op`.
    pub row_scale: Option<GpuTensorHandle>,
    // Optional per-column scale vector, applied with `col_op`.
    pub col_scale: Option<GpuTensorHandle>,
    pub row_op: ScaleOp,
    pub col_op: ScaleOp,
    // Optional clamp bounds on the result; `serde(default)` keeps old payloads readable.
    #[serde(default)]
    pub clamp_min: Option<f64>,
    #[serde(default)]
    pub clamp_max: Option<f64>,
    // Optional elementwise power applied to the result.
    #[serde(default)]
    pub pow_exponent: Option<f64>,
    // Optional handle that receives the result's diagonal — TODO confirm semantics.
    #[serde(default)]
    pub diag_output: Option<GpuTensorHandle>,
}
2702
2703impl MatmulEpilogue {
2704 pub fn noop() -> Self {
2705 Self {
2706 alpha: 1.0,
2707 beta: 0.0,
2708 row_scale: None,
2709 col_scale: None,
2710 row_op: ScaleOp::Multiply,
2711 col_op: ScaleOp::Multiply,
2712 clamp_min: None,
2713 clamp_max: None,
2714 pow_exponent: None,
2715 diag_output: None,
2716 }
2717 }
2718 pub fn is_noop(&self) -> bool {
2719 self.alpha == 1.0
2720 && self.beta == 0.0
2721 && self.row_scale.is_none()
2722 && self.col_scale.is_none()
2723 && self.clamp_min.is_none()
2724 && self.clamp_max.is_none()
2725 && self.pow_exponent.is_none()
2726 && self.diag_output.is_none()
2727 }
2728}
2729
/// Epilogue parameter for a fused power-iteration step; `epsilon` presumably
/// stabilises a normalisation divide — TODO confirm against the kernel.
/// Defaults to 0.0 (see the `Default` impl below).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct PowerStepEpilogue {
    pub epsilon: f64,
}
2734
2735impl Default for PowerStepEpilogue {
2736 fn default() -> Self {
2737 Self { epsilon: 0.0 }
2738 }
2739}
2740
2741#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
2742pub struct ImageNormalizeDescriptor {
2743 pub batch: usize,
2744 pub height: usize,
2745 pub width: usize,
2746 pub epsilon: f64,
2747 #[serde(default)]
2748 pub gain: Option<f64>,
2749 #[serde(default)]
2750 pub bias: Option<f64>,
2751 #[serde(default)]
2752 pub gamma: Option<f64>,
2753}