1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
// Callback signatures for optional hooks that downstream crates can register.
type ResidencyClearFn = fn(&GpuTensorHandle);
type SequenceThresholdFn = fn() -> Option<usize>;
type WorkgroupSizeHintFn = fn() -> Option<u32>;

// Write-once registration slots for the hooks above. The first successful
// `set` wins; later registrations are silently ignored (see `register_*`).
static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();
23
// Process-wide side tables keyed by `GpuTensorHandle::buffer_id`.
// LOGICAL_HANDLES: buffers currently flagged as logical (boolean) data.
static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
// LOGICAL_HANDLE_HITS: how many times each buffer has been (re-)flagged as
// logical; cleared when the flag is removed.
static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
// TRANSPOSED_HANDLES: buffers viewed as transposed, with their base shape.
static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

// HANDLE_PRECISIONS: per-buffer numeric precision override/record.
static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
32
/// Shape of the un-transposed matrix behind a handle that was recorded as
/// transposed via [`record_handle_transpose`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// Row count of the base (pre-transpose) matrix.
    pub base_rows: usize,
    /// Column count of the base (pre-transpose) matrix.
    pub base_cols: usize,
}
38
39pub fn register_residency_clear(handler: ResidencyClearFn) {
43 let _ = RESIDENCY_CLEAR.set(handler);
44}
45
46pub fn clear_residency(handle: &GpuTensorHandle) {
49 if let Some(handler) = RESIDENCY_CLEAR.get() {
50 handler(handle);
51 }
52}
53
54pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
58 let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
59}
60
61pub fn sequence_threshold_hint() -> Option<usize> {
63 SEQUENCE_THRESHOLD_PROVIDER
64 .get()
65 .and_then(|provider| provider())
66}
67
68pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
72 let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
73}
74
75pub fn workgroup_size_hint() -> Option<u32> {
77 WORKGROUP_SIZE_HINT_PROVIDER
78 .get()
79 .and_then(|provider| provider())
80}
81
82pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
85 provider().and_then(|p| p.export_context(kind))
86}
87
/// Ask the active provider for the raw wgpu buffer behind `handle`.
/// `None` when no provider is active or it cannot export the buffer.
#[cfg(feature = "wgpu")]
pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
    provider()?.export_wgpu_buffer(handle)
}
95
96pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
99 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
100 guard.insert(handle.buffer_id, precision);
101 }
102}
103
104pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
106 HANDLE_PRECISIONS
107 .read()
108 .ok()
109 .and_then(|guard| guard.get(&handle.buffer_id).copied())
110}
111
112pub fn clear_handle_precision(handle: &GpuTensorHandle) {
114 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
115 guard.remove(&handle.buffer_id);
116 }
117}
118
119pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
122 if let Ok(mut guard) = LOGICAL_HANDLES.write() {
123 if logical {
124 guard.insert(handle.buffer_id);
125 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
126 *hits.entry(handle.buffer_id).or_insert(0) += 1;
127 }
128 } else {
129 guard.remove(&handle.buffer_id);
130 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
131 hits.remove(&handle.buffer_id);
132 }
133 }
134 }
135}
136
/// Convenience wrapper: unmark `handle` as logical (see [`set_handle_logical`]).
pub fn clear_handle_logical(handle: &GpuTensorHandle) {
    set_handle_logical(handle, false);
}
141
142pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
144 LOGICAL_HANDLES
145 .read()
146 .map(|guard| guard.contains(&handle.buffer_id))
147 .unwrap_or(false)
148}
149
150pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
151 LOGICAL_HANDLE_HITS
152 .read()
153 .ok()
154 .and_then(|guard| guard.get(&buffer_id).copied())
155}
156
157pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
158 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
159 guard.insert(
160 handle.buffer_id,
161 TransposeInfo {
162 base_rows,
163 base_cols,
164 },
165 );
166 }
167}
168
169pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
170 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
171 guard.remove(&handle.buffer_id);
172 }
173}
174
175pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
176 TRANSPOSED_HANDLES
177 .read()
178 .ok()
179 .and_then(|guard| guard.get(&handle.buffer_id).copied())
180}
181
/// Whether a transpose record exists for `handle`'s buffer.
pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
    handle_transpose_info(handle).is_some()
}
185
/// Handle to a device-resident tensor. Carries only metadata; the actual
/// buffer is owned by the provider and identified by `buffer_id`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuTensorHandle {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Id of the device holding the buffer.
    pub device_id: u32,
    /// Provider-assigned buffer id; also the key for the logical /
    /// precision / transpose side tables in this module.
    pub buffer_id: u64,
}

/// Descriptive metadata for an acceleration device.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    pub device_id: u32,
    pub name: String,
    pub vendor: String,
    /// Total device memory, when the backend reports it.
    pub memory_bytes: Option<u64>,
    /// Backend name, when known.
    pub backend: Option<String>,
}

/// Value/index pair produced by a dimension-wise reduction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Values/indices pair returned by cumulative-min style scans
/// (cumulative max reuses this layout via `ProviderCummaxResult`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}
213
/// Cumulative max shares the result layout of cumulative min.
pub type ProviderCummaxResult = ProviderCumminResult;

/// Kinds of shared context a provider can export to other subsystems.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelContextKind {
    Plotting,
}

/// Backend-specific context exported by a provider. Only a wgpu-backed
/// variant exists, and only when the `wgpu` feature is enabled.
#[derive(Clone)]
pub enum AccelContextHandle {
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}
232
impl AccelContextHandle {
    /// Borrow the wgpu context if this handle wraps one. With the `wgpu`
    /// feature enabled it is currently the only variant, so this always
    /// returns `Some`.
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        match self {
            AccelContextHandle::Wgpu(ctx) => Some(ctx),
        }
    }
}
242
/// Cloneable bundle of the core wgpu objects a provider runs on, exported
/// so other subsystems (e.g. plotting) can share the same device/queue.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    pub adapter_info: wgpu::AdapterInfo,
    pub limits: wgpu::Limits,
    pub features: wgpu::Features,
}

/// Reference to the raw wgpu buffer backing a tensor handle.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    pub buffer: Arc<wgpu::Buffer>,
    // NOTE(review): presumed element count (not bytes), given the separate
    // `element_size` field — confirm against the exporting provider.
    pub len: usize,
    pub shape: Vec<usize>,
    /// Size in bytes of one element.
    pub element_size: usize,
    pub precision: ProviderPrecision,
}
266
/// Per-page operations supported by the paged-apply (`pagefun`) path.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    Mtimes,
}

/// Description of a paged operation over one or more input tensors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    pub op: PagefunOp,
    pub inputs: Vec<GpuTensorHandle>,
    pub output_shape: Vec<usize>,
    /// Page dimensions of the output.
    pub page_dims: Vec<usize>,
    /// Page dimensions of each input, parallel to `inputs`.
    pub input_page_dims: Vec<Vec<usize>>,
}

/// Whether a find-style search reports the first or the last matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}

/// Indices (linear plus row/column subscripts) and optional values of the
/// elements located by a find operation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    pub linear: GpuTensorHandle,
    pub rows: GpuTensorHandle,
    pub cols: GpuTensorHandle,
    pub values: Option<GpuTensorHandle>,
}
294
/// Lower/upper bandwidth pair of a banded matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    pub lower: u32,
    pub upper: u32,
}

/// Structural symmetry variants: symmetric or skew-symmetric.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}

/// Hermitian structure variants: Hermitian or skew-Hermitian.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}

/// Outputs of an LU factorization: the packed L+U factors, the separate
/// lower and upper factors, and the row permutation in both matrix and
/// vector form.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    pub combined: GpuTensorHandle,
    pub lower: GpuTensorHandle,
    pub upper: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

/// Output of a Cholesky factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    pub factor: GpuTensorHandle,
    // Provider status code. NOTE(review): presumably 0 = success, as with
    // LAPACK-style `info` values — confirm with the implementing provider.
    pub info: u32,
}

/// Outputs of a QR factorization with column-pivoting information in both
/// matrix and vector form.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}
336
/// Outputs of a QR-based power-iteration step; same layout as
/// [`ProviderQrResult`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

/// Structure flags a caller can assert about the coefficient matrix of a
/// linear solve. All default to `false`/`None` (no assumptions).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    /// Reciprocal-condition threshold, when the caller supplies one.
    pub rcond: Option<f64>,
}

/// Solution of a linear system plus the estimated reciprocal condition
/// number of the coefficient matrix.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    pub reciprocal_condition: f64,
}

/// Options for a pseudo-inverse computation.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    /// Singular-value cutoff; `None` lets the provider choose.
    pub tolerance: Option<f64>,
}

/// Centering/scaling pair used to evaluate a polynomial at
/// `(x - mean) / scale` (see [`ProviderPolyvalOptions`]).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}

/// Options for polynomial evaluation; `mu: None` evaluates at `x` directly.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}

/// Placeholder options for matrix inversion (currently no knobs).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}

/// Host-side outputs of a polynomial fit.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    /// Fitted coefficients.
    pub coefficients: Vec<f64>,
    /// Triangular factor from the fit, flattened row-major or column-major
    /// per the provider's convention (NOTE(review): confirm layout).
    pub r_matrix: Vec<f64>,
    /// Norm of the residuals.
    pub normr: f64,
    /// Degrees of freedom.
    pub df: f64,
    /// Centering/scaling pair applied to `x` during the fit.
    pub mu: [f64; 2],
}

/// Numerator/denominator pair from differentiating a rational polynomial.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}
397
/// Norms accepted by condition-number estimation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}

/// Norm orders accepted by general norm computation, including an
/// arbitrary p-norm via `P(f64)`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    Nuc,
    P(f64),
}

/// Outputs of an eigendecomposition: eigenvalues as a vector and in
/// diagonal-matrix form, plus right (and optionally left) eigenvectors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    pub eigenvalues: GpuTensorHandle,
    pub diagonal: GpuTensorHandle,
    pub right: GpuTensorHandle,
    pub left: Option<GpuTensorHandle>,
}

/// How QR column pivoting should be reported: permutation matrix or vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}

/// Options for QR factorization (see `Default`: full-size, matrix pivot).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    pub economy: bool,
    pub pivot: ProviderQrPivot,
}
439
440impl Default for ProviderQrOptions {
441 fn default() -> Self {
442 Self {
443 economy: false,
444 pivot: ProviderQrPivot::Matrix,
445 }
446 }
447}
448
/// Floating-point precision a provider computes in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}

/// Whether two-pass reduction is chosen automatically or forced on/off.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionTwoPassMode {
    Auto,
    ForceOn,
    ForceOff,
}
461
462impl ReductionTwoPassMode {
463 pub fn as_str(self) -> &'static str {
464 match self {
465 ReductionTwoPassMode::Auto => "auto",
466 ReductionTwoPassMode::ForceOn => "force_on",
467 ReductionTwoPassMode::ForceOff => "force_off",
468 }
469 }
470}
471
/// Post-scaling applied to a reduction: plain sum, mean (divide by the
/// reduced length), or an arbitrary caller-supplied factor.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ReductionFlavor {
    Sum,
    Mean,
    CustomScale(f64),
}
478
479impl ReductionFlavor {
480 pub fn is_mean(self) -> bool {
481 matches!(self, ReductionFlavor::Mean)
482 }
483
484 pub fn scale(self, reduce_len: usize) -> f64 {
485 match self {
486 ReductionFlavor::Sum => 1.0,
487 ReductionFlavor::Mean => {
488 if reduce_len == 0 {
489 1.0
490 } else {
491 1.0 / reduce_len as f64
492 }
493 }
494 ReductionFlavor::CustomScale(scale) => scale,
495 }
496 }
497}
498
/// Normalization used by correlation-coefficient computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}

/// Row-handling policy for correlation-coefficient computation
/// (e.g. how rows with missing data are treated).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}

/// Options for `corrcoef` (see `Default`: unbiased over all rows).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}
520
521impl Default for CorrcoefOptions {
522 fn default() -> Self {
523 Self {
524 normalization: CorrcoefNormalization::Unbiased,
525 rows: CorrcoefRows::All,
526 }
527 }
528}
529
/// Normalization used by covariance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}

/// Row-handling policy for covariance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}

/// Options for `cov` (see `Default`: unbiased, all rows, no weights).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    /// Whether the caller passes an observation-weight vector.
    pub has_weight_vector: bool,
}
552
553impl Default for CovarianceOptions {
554 fn default() -> Self {
555 Self {
556 normalization: CovNormalization::Unbiased,
557 rows: CovRows::All,
558 has_weight_vector: false,
559 }
560 }
561}
562
/// Sample (n-1) vs. population (n) normalization for std/var style ops.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    Sample,
    Population,
}

/// Whether NaNs participate in a reduction or are omitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}

/// Direction of a cumulative scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}

/// Sort direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}

/// Key used when comparing elements during a sort: provider-chosen,
/// real part, or absolute value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    Auto,
    Real,
    Abs,
}
598
/// Host-side sorted values together with the original element indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}

/// One column/order pair for a row-wise (`sortrows`-style) sort.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    pub index: usize,
    pub order: SortOrder,
}

/// Ordering of the deduplicated output: sorted or first-seen order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}

/// Which duplicate occurrence is reported in the index output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}
625
/// Options for a `unique` operation (element-wise or whole rows).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}

/// Outputs of `unique`: deduplicated values plus the two index maps
/// (`ia`: into the input, `ic`: input positions into `values`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ic: HostTensorOwned,
}

/// Ordering of a set-union output: sorted or first-seen order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}

/// Options for a set union (element-wise or whole rows).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    pub rows: bool,
    pub order: UnionOrder,
}

/// Outputs of a set union: the merged values plus index maps back into
/// each input (`ia` into the first, `ib` into the second).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ib: HostTensorOwned,
}
663
/// Filter kernels constructible via `fspecial` (the variant set mirrors
/// MATLAB's `fspecial` filter names).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    /// Box (moving-average) kernel of the given size.
    Average {
        rows: u32,
        cols: u32,
    },
    /// Circular averaging kernel.
    Disk {
        radius: f64,
        size: u32,
    },
    /// Gaussian low-pass kernel.
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// Laplacian kernel with shape parameter `alpha`.
    Laplacian {
        alpha: f64,
    },
    /// Laplacian-of-Gaussian kernel.
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// Linear motion-blur kernel.
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        /// Supersampling factor used while rasterizing the motion line
        /// (NOTE(review): inferred from the name — confirm in the provider).
        oversample: u32,
    },
    Prewitt,
    Sobel,
    /// Unsharp-masking kernel with shape parameter `alpha`.
    Unsharp {
        alpha: f64,
    },
}

/// Request wrapper for building one `fspecial` kernel.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}
706
/// Boundary padding mode for image filtering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    /// Pad with `ImfilterOptions::constant_value`.
    Constant,
    Replicate,
    Symmetric,
    Circular,
}

/// Output-size policy for image filtering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}

/// Whether the kernel is applied as correlation or convolution
/// (convolution flips the kernel).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}

/// Options for `imfilter` (see `Default`: zero-padded, same-size
/// correlation).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    /// Fill value used when `padding` is `Constant`.
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}
739
740impl Default for ImfilterOptions {
741 fn default() -> Self {
742 Self {
743 padding: ImfilterPadding::Constant,
744 constant_value: 0.0,
745 shape: ImfilterShape::Same,
746 mode: ImfilterMode::Correlation,
747 }
748 }
749}
750
/// Ordering of a set-difference output: sorted or first-seen order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}

/// Options for a set difference (element-wise or whole rows).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    pub rows: bool,
    pub order: SetdiffOrder,
}

/// Outputs of a set difference: surviving values plus their indices `ia`
/// into the first input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}

/// Options for a membership test (element-wise or whole rows).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    pub rows: bool,
}

/// Host-side logical (boolean) tensor, stored as one byte per element.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}

/// Outputs of a membership test: per-element mask plus `loc` locations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}
791
/// Output-size policy for 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

/// Orientation of the 1-D signal: row vector or column vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options for 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}

/// Options for IIR filtering along one dimension.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Dimension to filter along.
    pub dim: usize,
    /// Optional initial filter state.
    pub zi: Option<GpuTensorHandle>,
}

/// Outputs of IIR filtering: the filtered signal and, when requested,
/// the final filter state.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    pub output: GpuTensorHandle,
    pub final_state: Option<GpuTensorHandle>,
}

/// First two raw moments: mean and E[x^2].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    pub ex2: GpuTensorHandle,
}

/// Dispatch counters for one kernel category.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of dispatches.
    pub count: u64,
    /// Accumulated wall-clock time in nanoseconds.
    pub total_wall_time_ns: u64,
}
840
/// Aggregated runtime telemetry reported by a provider.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    /// Bytes moved host -> device.
    pub upload_bytes: u64,
    /// Bytes moved device -> host.
    pub download_bytes: u64,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Optional per-layout breakdown of the bind-group cache counters.
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    /// Per-launch kernel records.
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}

/// Cache hit/miss counters for one bind-group layout, identified by tag.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}

/// One named numeric attribute attached to a kernel launch record.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}

/// Record of a single kernel launch: name, precision, and shape/tuning
/// attributes as key/value pairs.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    pub kernel: String,
    pub precision: Option<String>,
    pub shape: Vec<KernelAttrTelemetry>,
    pub tuning: Vec<KernelAttrTelemetry>,
}
878
/// Boxed future type returned by asynchronous provider operations.
pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
/// Future resolving to a host-side copy of a tensor.
pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;

/// Ready future that immediately fails with `message`; used as the body of
/// default trait methods a provider does not implement.
fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
    Box::pin(async move { Err(anyhow::anyhow!(message)) })
}
885
886pub trait AccelProvider: Send + Sync {
888 fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
889 fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
890 fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
891 fn device_info(&self) -> String;
892 fn device_id(&self) -> u32 {
893 0
894 }
895
896 fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
899 None
900 }
901
902 #[cfg(feature = "wgpu")]
904 fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
905 let _ = _handle;
906 None
907 }
908
909 fn gather_linear(
912 &self,
913 _source: &GpuTensorHandle,
914 _indices: &[u32],
915 _output_shape: &[usize],
916 ) -> anyhow::Result<GpuTensorHandle> {
917 Err(anyhow::anyhow!("gather_linear not supported by provider"))
918 }
919
920 fn scatter_linear(
924 &self,
925 _target: &GpuTensorHandle,
926 _indices: &[u32],
927 _values: &GpuTensorHandle,
928 ) -> anyhow::Result<()> {
929 Err(anyhow::anyhow!("scatter_linear not supported by provider"))
930 }
931
932 fn device_info_struct(&self) -> ApiDeviceInfo {
934 ApiDeviceInfo {
935 device_id: 0,
936 name: self.device_info(),
937 vendor: String::new(),
938 memory_bytes: None,
939 backend: None,
940 }
941 }
942
943 fn precision(&self) -> ProviderPrecision {
944 ProviderPrecision::F64
945 }
946
947 fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
949 Err(anyhow::anyhow!("read_scalar not supported by provider"))
950 }
951
952 fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
954 Err(anyhow::anyhow!("zeros not supported by provider"))
955 }
956
957 fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
959 Err(anyhow::anyhow!("ones not supported by provider"))
960 }
961
962 fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
964 self.zeros(&prototype.shape)
965 }
966
    /// Create a tensor of `shape` filled with `value`.
    ///
    /// Strategy, in order: `0.0` delegates straight to `zeros`; otherwise
    /// try `zeros` + `scalar_add` on-device (freeing the intermediate
    /// buffer in both the success and failure paths); finally fall back to
    /// materializing the data on the host and uploading it.
    fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
        if value == 0.0 {
            return self.zeros(shape);
        }
        if let Ok(base) = self.zeros(shape) {
            match self.scalar_add(&base, value) {
                Ok(out) => {
                    // The zeros buffer is only an intermediate; free it.
                    let _ = self.free(&base);
                    return Ok(out);
                }
                Err(_) => {
                    // Free errors are ignored: we still have the host
                    // fallback below.
                    let _ = self.free(&base);
                }
            }
        }
        // Host fallback: build the constant tensor in memory and upload.
        let len: usize = shape.iter().copied().product();
        let data = vec![value; len];
        let view = HostTensorView { data: &data, shape };
        self.upload(&view)
    }
988
    /// Create a tensor shaped like `prototype`, filled with `value`.
    ///
    /// Mirrors [`AccelProvider::fill`]: `0.0` delegates to `zeros_like`,
    /// then a `zeros_like` + `scalar_add` attempt (freeing the intermediate
    /// buffer either way), and finally a plain `fill` on the prototype's
    /// shape as the fallback.
    fn fill_like(
        &self,
        prototype: &GpuTensorHandle,
        value: f64,
    ) -> anyhow::Result<GpuTensorHandle> {
        if value == 0.0 {
            return self.zeros_like(prototype);
        }
        if let Ok(base) = self.zeros_like(prototype) {
            match self.scalar_add(&base, value) {
                Ok(out) => {
                    // The zeros buffer is only an intermediate; free it.
                    let _ = self.free(&base);
                    return Ok(out);
                }
                Err(_) => {
                    let _ = self.free(&base);
                }
            }
        }
        self.fill(&prototype.shape, value)
    }
1011
1012 fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1014 self.ones(&prototype.shape)
1015 }
1016
1017 fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1019 Err(anyhow::anyhow!("eye not supported by provider"))
1020 }
1021
1022 fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1024 self.eye(&prototype.shape)
1025 }
1026
1027 fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1029 Err(anyhow::anyhow!("meshgrid not supported by provider"))
1030 }
1031
1032 fn diag_from_vector(
1034 &self,
1035 _vector: &GpuTensorHandle,
1036 _offset: isize,
1037 ) -> anyhow::Result<GpuTensorHandle> {
1038 Err(anyhow::anyhow!(
1039 "diag_from_vector not supported by provider"
1040 ))
1041 }
1042
1043 fn diag_extract(
1045 &self,
1046 _matrix: &GpuTensorHandle,
1047 _offset: isize,
1048 ) -> anyhow::Result<GpuTensorHandle> {
1049 Err(anyhow::anyhow!("diag_extract not supported by provider"))
1050 }
1051
1052 fn tril<'a>(
1054 &'a self,
1055 _matrix: &'a GpuTensorHandle,
1056 _offset: isize,
1057 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1058 Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1059 }
1060
1061 fn triu<'a>(
1063 &'a self,
1064 _matrix: &'a GpuTensorHandle,
1065 _offset: isize,
1066 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1067 Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1068 }
1069
1070 fn polyval(
1072 &self,
1073 _coefficients: &GpuTensorHandle,
1074 _points: &GpuTensorHandle,
1075 _options: &ProviderPolyvalOptions,
1076 ) -> anyhow::Result<GpuTensorHandle> {
1077 Err(anyhow::anyhow!("polyval not supported by provider"))
1078 }
1079
1080 fn polyfit<'a>(
1082 &'a self,
1083 _x: &'a GpuTensorHandle,
1084 _y: &'a GpuTensorHandle,
1085 _degree: usize,
1086 _weights: Option<&'a GpuTensorHandle>,
1087 ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1088 Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1089 }
1090
1091 fn polyder_single<'a>(
1093 &'a self,
1094 _polynomial: &'a GpuTensorHandle,
1095 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1096 Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1097 }
1098
1099 fn polyder_product<'a>(
1101 &'a self,
1102 _p: &'a GpuTensorHandle,
1103 _q: &'a GpuTensorHandle,
1104 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1105 Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1106 }
1107
1108 fn polyder_quotient<'a>(
1110 &'a self,
1111 _u: &'a GpuTensorHandle,
1112 _v: &'a GpuTensorHandle,
1113 ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1114 Box::pin(async move {
1115 Err(anyhow::anyhow!(
1116 "polyder_quotient not supported by provider"
1117 ))
1118 })
1119 }
1120
1121 fn polyint(
1123 &self,
1124 _polynomial: &GpuTensorHandle,
1125 _constant: f64,
1126 ) -> anyhow::Result<GpuTensorHandle> {
1127 Err(anyhow::anyhow!("polyint not supported by provider"))
1128 }
1129
1130 fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1132 Err(anyhow::anyhow!("random_uniform not supported by provider"))
1133 }
1134
1135 fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1137 self.random_uniform(&prototype.shape)
1138 }
1139
1140 fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1142 Err(anyhow::anyhow!("random_normal not supported by provider"))
1143 }
1144
1145 fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1147 self.random_normal(&prototype.shape)
1148 }
1149
1150 fn stochastic_evolution(
1151 &self,
1152 _state: &GpuTensorHandle,
1153 _drift: f64,
1154 _scale: f64,
1155 _steps: u32,
1156 ) -> anyhow::Result<GpuTensorHandle> {
1157 Err(anyhow::anyhow!(
1158 "stochastic_evolution not supported by provider"
1159 ))
1160 }
1161
1162 fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1164 Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1165 }
1166
1167 fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1169 Err(anyhow::anyhow!("fspecial not supported by provider"))
1170 }
1171
1172 fn imfilter<'a>(
1174 &'a self,
1175 _image: &'a GpuTensorHandle,
1176 _kernel: &'a GpuTensorHandle,
1177 _options: &'a ImfilterOptions,
1178 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1179 unsupported_future("imfilter not supported by provider")
1180 }
1181
1182 fn random_integer_range(
1184 &self,
1185 _lower: i64,
1186 _upper: i64,
1187 _shape: &[usize],
1188 ) -> anyhow::Result<GpuTensorHandle> {
1189 Err(anyhow::anyhow!(
1190 "random_integer_range not supported by provider"
1191 ))
1192 }
1193
1194 fn random_integer_like(
1196 &self,
1197 prototype: &GpuTensorHandle,
1198 lower: i64,
1199 upper: i64,
1200 ) -> anyhow::Result<GpuTensorHandle> {
1201 self.random_integer_range(lower, upper, &prototype.shape)
1202 }
1203
1204 fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1206 Err(anyhow!("random_permutation not supported by provider"))
1207 }
1208
1209 fn random_permutation_like(
1211 &self,
1212 _prototype: &GpuTensorHandle,
1213 n: usize,
1214 k: usize,
1215 ) -> anyhow::Result<GpuTensorHandle> {
1216 self.random_permutation(n, k)
1217 }
1218
1219 fn covariance<'a>(
1221 &'a self,
1222 _matrix: &'a GpuTensorHandle,
1223 _second: Option<&'a GpuTensorHandle>,
1224 _weights: Option<&'a GpuTensorHandle>,
1225 _options: &'a CovarianceOptions,
1226 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1227 unsupported_future("covariance not supported by provider")
1228 }
1229
1230 fn corrcoef<'a>(
1232 &'a self,
1233 _matrix: &'a GpuTensorHandle,
1234 _options: &'a CorrcoefOptions,
1235 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1236 unsupported_future("corrcoef not supported by provider")
1237 }
1238
1239 fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1241 Err(anyhow::anyhow!("linspace not supported by provider"))
1242 }
1243 fn elem_add<'a>(
1244 &'a self,
1245 _a: &'a GpuTensorHandle,
1246 _b: &'a GpuTensorHandle,
1247 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1248 unsupported_future("elem_add not supported by provider")
1249 }
1250 fn elem_mul<'a>(
1251 &'a self,
1252 _a: &'a GpuTensorHandle,
1253 _b: &'a GpuTensorHandle,
1254 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1255 unsupported_future("elem_mul not supported by provider")
1256 }
1257 fn elem_max<'a>(
1258 &'a self,
1259 _a: &'a GpuTensorHandle,
1260 _b: &'a GpuTensorHandle,
1261 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1262 unsupported_future("elem_max not supported by provider")
1263 }
1264 fn elem_min<'a>(
1265 &'a self,
1266 _a: &'a GpuTensorHandle,
1267 _b: &'a GpuTensorHandle,
1268 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1269 unsupported_future("elem_min not supported by provider")
1270 }
1271 fn elem_sub<'a>(
1272 &'a self,
1273 _a: &'a GpuTensorHandle,
1274 _b: &'a GpuTensorHandle,
1275 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1276 unsupported_future("elem_sub not supported by provider")
1277 }
1278 fn elem_div<'a>(
1279 &'a self,
1280 _a: &'a GpuTensorHandle,
1281 _b: &'a GpuTensorHandle,
1282 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1283 unsupported_future("elem_div not supported by provider")
1284 }
1285 fn elem_pow<'a>(
1286 &'a self,
1287 _a: &'a GpuTensorHandle,
1288 _b: &'a GpuTensorHandle,
1289 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1290 unsupported_future("elem_pow not supported by provider")
1291 }
1292
    // ---- hypot and elementwise comparisons ----------------------------------
    // All stubs: providers that can evaluate these on-device override them.
    fn elem_hypot<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_hypot not supported by provider")
    }
    fn elem_ge<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ge not supported by provider")
    }
    fn elem_le<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_le not supported by provider")
    }
    fn elem_lt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_lt not supported by provider")
    }
    fn elem_gt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_gt not supported by provider")
    }
    fn elem_eq<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_eq not supported by provider")
    }
    fn elem_ne<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ne not supported by provider")
    }
    // ---- Logical operations and predicates ----------------------------------
    fn logical_and(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_and not supported by provider"))
    }
    fn logical_or(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_or not supported by provider"))
    }
    fn logical_xor(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_xor not supported by provider"))
    }
    fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_not not supported by provider"))
    }
    /// Unlike its siblings this default is functional: it consults the
    /// handle-side logical flag via `handle_is_logical` (defined elsewhere in
    /// this file), so providers normally do not need to override it.
    fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
        Ok(handle_is_logical(a))
    }
    fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
        Err(anyhow::anyhow!("logical_isreal not supported by provider"))
    }
    fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "logical_isfinite not supported by provider"
        ))
    }
    fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_isnan not supported by provider"))
    }
    fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("logical_isinf not supported by provider"))
    }
    // Two-argument arctangent; grouped here with the other binary math hooks.
    fn elem_atan2<'a>(
        &'a self,
        _y: &'a GpuTensorHandle,
        _x: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_atan2 not supported by provider")
    }
    // ---- Elementwise unary operations ---------------------------------------
    // One stub per unary kernel (trig, hyperbolic, rounding, complex parts,
    // exp/log family, precision casts). Providers override what they support.
    fn unary_sin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sin not supported by provider")
    }
    fn unary_gamma<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_gamma not supported by provider")
    }
    fn unary_factorial<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_factorial not supported by provider")
    }
    fn unary_asinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asinh not supported by provider")
    }
    fn unary_sinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sinh not supported by provider")
    }
    fn unary_cosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cosh not supported by provider")
    }
    fn unary_asin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asin not supported by provider")
    }
    fn unary_acos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acos not supported by provider")
    }
    fn unary_acosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acosh not supported by provider")
    }
    fn unary_tan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tan not supported by provider")
    }
    fn unary_tanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tanh not supported by provider")
    }
    fn unary_atan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atan not supported by provider")
    }
    fn unary_atanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atanh not supported by provider")
    }
    fn unary_ceil<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_ceil not supported by provider")
    }
    fn unary_floor<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_floor not supported by provider")
    }
    fn unary_round<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_round not supported by provider")
    }
    fn unary_fix<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_fix not supported by provider")
    }
    fn unary_cos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cos not supported by provider")
    }
    fn unary_angle<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_angle not supported by provider")
    }
    fn unary_imag<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_imag not supported by provider")
    }
    fn unary_real<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_real not supported by provider")
    }
    fn unary_conj<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_conj not supported by provider")
    }
    fn unary_abs<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_abs not supported by provider")
    }
    fn unary_sign<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sign not supported by provider")
    }
    fn unary_exp<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_exp not supported by provider")
    }
    fn unary_expm1<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_expm1 not supported by provider")
    }
    fn unary_log<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log not supported by provider")
    }
    fn unary_log2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log2 not supported by provider")
    }
    fn unary_log10<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log10 not supported by provider")
    }
    fn unary_log1p<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log1p not supported by provider")
    }
    fn unary_sqrt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sqrt not supported by provider")
    }
    fn unary_double<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_double not supported by provider")
    }
    fn unary_single<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_single not supported by provider")
    }
    fn unary_pow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_pow2 not supported by provider")
    }
    // ---- Scalar-broadcast operations ----------------------------------------
    // Tensor op scalar (and the reversed `r`-variants: scalar - tensor,
    // scalar / tensor). All default to "unsupported".
    fn pow2_scale(
        &self,
        _mantissa: &GpuTensorHandle,
        _exponent: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pow2_scale not supported by provider"))
    }
    fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
    }
    fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
    }
    fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_add not supported by provider"))
    }
    fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_sub not supported by provider"))
    }
    fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_mul not supported by provider"))
    }
    fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_max not supported by provider"))
    }
    fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_min not supported by provider"))
    }
    fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_div not supported by provider"))
    }
    // ---- Sorting and matrix products ----------------------------------------
    fn sort_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _order: SortOrder,
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_dim not supported by provider")
    }
    fn sort_rows<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _columns: &'a [SortRowsColumnSpec],
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_rows not supported by provider")
    }
    fn matmul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul not supported by provider")
    }

    // Symmetric rank-k update (A * Aᵀ-style product) — optional fast path.
    fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("syrk not supported by provider"))
    }
    fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pagefun not supported by provider"))
    }
1659
1660 fn matmul_epilogue<'a>(
1665 &'a self,
1666 a: &'a GpuTensorHandle,
1667 b: &'a GpuTensorHandle,
1668 epilogue: &'a MatmulEpilogue,
1669 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1670 Box::pin(async move {
1671 if epilogue.is_noop() {
1672 return self.matmul(a, b).await;
1673 }
1674 Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1675 })
1676 }
    // ---- Fused image/power-iteration kernels and dense linear algebra -------
    fn image_normalize<'a>(
        &'a self,
        _input: &'a GpuTensorHandle,
        _desc: &'a ImageNormalizeDescriptor,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("image_normalize fusion not supported by provider")
    }
    fn matmul_power_step<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _epilogue: &'a PowerStepEpilogue,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul_power_step normalization not supported by provider")
    }
    fn linsolve<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _options: &'a ProviderLinsolveOptions,
    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
        unsupported_future("linsolve not supported by provider")
    }
    fn inv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderInvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("inv not supported by provider")
    }
    fn pinv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderPinvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("pinv not supported by provider")
    }
    fn cond<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _norm: ProviderCondNorm,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
    }
    fn norm<'a>(
        &'a self,
        _tensor: &'a GpuTensorHandle,
        _order: ProviderNormOrder,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
    }
    fn rank<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _tolerance: Option<f64>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
    }
    fn rcond<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
    }
    // ---- Matrix division and factorizations ---------------------------------
    fn mldivide<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
    }
    fn mrdivide<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
    }
    fn eig<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _compute_left: bool,
    ) -> AccelProviderFuture<'a, ProviderEigResult> {
        Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
    }
    fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
        Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
    }

    fn chol<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _lower: bool,
    ) -> AccelProviderFuture<'a, ProviderCholResult> {
        Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
    }
    fn qr<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _options: ProviderQrOptions,
    ) -> AccelProviderFuture<'a, ProviderQrResult> {
        Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
    }
    /// Optional introspection hook: providers that remember the operands of a
    /// matmul product may surrender them here. Default: nothing to take.
    fn take_matmul_sources(
        &self,
        _product: &GpuTensorHandle,
    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
        None
    }
    /// Optional fused power-iteration QR step. Returning `Ok(None)` means
    /// "no fused path available" and lets the caller fall back to the
    /// unfused sequence, so the default is a successful no-op rather than an
    /// error. Parameters keep their real names (with a `let _ = …` to silence
    /// unused warnings) — presumably so rustdoc shows meaningful names.
    fn qr_power_iter<'a>(
        &'a self,
        product: &'a GpuTensorHandle,
        _product_lhs: Option<&'a GpuTensorHandle>,
        q_handle: &'a GpuTensorHandle,
        options: &'a ProviderQrOptions,
    ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
        let _ = (product, q_handle, options);
        Box::pin(async move { Ok(None) })
    }
    // ---- Transpose, convolution, and filtering ------------------------------
    fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("transpose not supported by provider"))
    }
    fn conv1d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _options: ProviderConv1dOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv1d not supported by provider"))
    }
    fn conv2d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _mode: ProviderConvMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv2d not supported by provider"))
    }
    fn iir_filter<'a>(
        &'a self,
        _b: &'a GpuTensorHandle,
        _a: &'a GpuTensorHandle,
        _x: &'a GpuTensorHandle,
        _options: ProviderIirFilterOptions,
    ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
        Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
    }
1824 fn permute(
1826 &self,
1827 _handle: &GpuTensorHandle,
1828 _order: &[usize],
1829 ) -> anyhow::Result<GpuTensorHandle> {
1830 Err(anyhow::anyhow!("permute not supported by provider"))
1831 }
1832 fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1833 Err(anyhow::anyhow!("flip not supported by provider"))
1834 }
1835 fn circshift(
1836 &self,
1837 _handle: &GpuTensorHandle,
1838 _shifts: &[isize],
1839 ) -> anyhow::Result<GpuTensorHandle> {
1840 Err(anyhow::anyhow!("circshift not supported by provider"))
1841 }
1842 fn diff_dim(
1843 &self,
1844 _handle: &GpuTensorHandle,
1845 _order: usize,
1846 _dim: usize,
1847 ) -> anyhow::Result<GpuTensorHandle> {
1848 Err(anyhow::anyhow!("diff_dim not supported by provider"))
1849 }
1850 fn fft_dim<'a>(
1852 &'a self,
1853 _handle: &'a GpuTensorHandle,
1854 _len: Option<usize>,
1855 _dim: usize,
1856 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1857 unsupported_future("fft_dim not supported by provider")
1858 }
1859 fn ifft_dim<'a>(
1860 &'a self,
1861 _handle: &'a GpuTensorHandle,
1862 _len: Option<usize>,
1863 _dim: usize,
1864 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1865 unsupported_future("ifft_dim not supported by provider")
1866 }
1867 fn unique<'a>(
1868 &'a self,
1869 _handle: &'a GpuTensorHandle,
1870 _options: &'a UniqueOptions,
1871 ) -> AccelProviderFuture<'a, UniqueResult> {
1872 Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
1873 }
1874 fn union<'a>(
1875 &'a self,
1876 _a: &'a GpuTensorHandle,
1877 _b: &'a GpuTensorHandle,
1878 _options: &'a UnionOptions,
1879 ) -> AccelProviderFuture<'a, UnionResult> {
1880 Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
1881 }
1882 fn setdiff<'a>(
1883 &'a self,
1884 _a: &'a GpuTensorHandle,
1885 _b: &'a GpuTensorHandle,
1886 _options: &'a SetdiffOptions,
1887 ) -> AccelProviderFuture<'a, SetdiffResult> {
1888 Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
1889 }
1890 fn ismember<'a>(
1891 &'a self,
1892 _a: &'a GpuTensorHandle,
1893 _b: &'a GpuTensorHandle,
1894 _options: &'a IsMemberOptions,
1895 ) -> AccelProviderFuture<'a, IsMemberResult> {
1896 Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
1897 }
1898 fn reshape(
1899 &self,
1900 handle: &GpuTensorHandle,
1901 new_shape: &[usize],
1902 ) -> anyhow::Result<GpuTensorHandle> {
1903 let mut updated = handle.clone();
1904 updated.shape = new_shape.to_vec();
1905 Ok(updated)
1906 }
    // ---- Concatenation and replication --------------------------------------
    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cat not supported by provider"))
    }
    fn repmat(
        &self,
        _handle: &GpuTensorHandle,
        _reps: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("repmat not supported by provider"))
    }
    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("kron not supported by provider"))
    }
1922 fn reduce_sum<'a>(
1923 &'a self,
1924 _a: &'a GpuTensorHandle,
1925 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1926 unsupported_future("reduce_sum not supported by provider")
1927 }
1928 fn reduce_sum_dim<'a>(
1929 &'a self,
1930 _a: &'a GpuTensorHandle,
1931 _dim: usize,
1932 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1933 unsupported_future("reduce_sum_dim not supported by provider")
1934 }
1935 fn dot<'a>(
1936 &'a self,
1937 _lhs: &'a GpuTensorHandle,
1938 _rhs: &'a GpuTensorHandle,
1939 _dim: Option<usize>,
1940 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1941 unsupported_future("dot not supported by provider")
1942 }
1943 fn reduce_nnz<'a>(
1944 &'a self,
1945 _a: &'a GpuTensorHandle,
1946 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1947 unsupported_future("reduce_nnz not supported by provider")
1948 }
1949 fn reduce_nnz_dim<'a>(
1950 &'a self,
1951 _a: &'a GpuTensorHandle,
1952 _dim: usize,
1953 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1954 unsupported_future("reduce_nnz_dim not supported by provider")
1955 }
1956 fn reduce_prod<'a>(
1957 &'a self,
1958 _a: &'a GpuTensorHandle,
1959 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1960 unsupported_future("reduce_prod not supported by provider")
1961 }
1962 fn reduce_prod_dim<'a>(
1963 &'a self,
1964 _a: &'a GpuTensorHandle,
1965 _dim: usize,
1966 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1967 unsupported_future("reduce_prod_dim not supported by provider")
1968 }
1969 fn reduce_mean<'a>(
1970 &'a self,
1971 _a: &'a GpuTensorHandle,
1972 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1973 unsupported_future("reduce_mean not supported by provider")
1974 }
1975 fn reduce_mean_nd<'a>(
1977 &'a self,
1978 _a: &'a GpuTensorHandle,
1979 _dims_zero_based: &'a [usize],
1980 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1981 unsupported_future("reduce_mean_nd not supported by provider")
1982 }
1983 fn reduce_moments_nd<'a>(
1986 &'a self,
1987 _a: &'a GpuTensorHandle,
1988 _dims_zero_based: &'a [usize],
1989 ) -> AccelProviderFuture<'a, ProviderMoments2> {
1990 unsupported_future("reduce_moments_nd not supported by provider")
1991 }
1992 fn reduce_mean_dim<'a>(
1993 &'a self,
1994 _a: &'a GpuTensorHandle,
1995 _dim: usize,
1996 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1997 unsupported_future("reduce_mean_dim not supported by provider")
1998 }
1999 fn reduce_std<'a>(
2000 &'a self,
2001 _a: &'a GpuTensorHandle,
2002 _normalization: ProviderStdNormalization,
2003 _nan_mode: ProviderNanMode,
2004 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2005 unsupported_future("reduce_std not supported by provider")
2006 }
2007 fn reduce_std_dim<'a>(
2008 &'a self,
2009 _a: &'a GpuTensorHandle,
2010 _dim: usize,
2011 _normalization: ProviderStdNormalization,
2012 _nan_mode: ProviderNanMode,
2013 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2014 unsupported_future("reduce_std_dim not supported by provider")
2015 }
2016 fn reduce_any<'a>(
2017 &'a self,
2018 _a: &'a GpuTensorHandle,
2019 _omit_nan: bool,
2020 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2021 unsupported_future("reduce_any not supported by provider")
2022 }
2023 fn reduce_any_dim<'a>(
2024 &'a self,
2025 _a: &'a GpuTensorHandle,
2026 _dim: usize,
2027 _omit_nan: bool,
2028 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2029 unsupported_future("reduce_any_dim not supported by provider")
2030 }
2031 fn reduce_all<'a>(
2032 &'a self,
2033 _a: &'a GpuTensorHandle,
2034 _omit_nan: bool,
2035 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2036 unsupported_future("reduce_all not supported by provider")
2037 }
2038 fn reduce_all_dim<'a>(
2039 &'a self,
2040 _a: &'a GpuTensorHandle,
2041 _dim: usize,
2042 _omit_nan: bool,
2043 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2044 unsupported_future("reduce_all_dim not supported by provider")
2045 }
2046 fn reduce_median<'a>(
2047 &'a self,
2048 _a: &'a GpuTensorHandle,
2049 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2050 unsupported_future("reduce_median not supported by provider")
2051 }
2052 fn reduce_median_dim<'a>(
2053 &'a self,
2054 _a: &'a GpuTensorHandle,
2055 _dim: usize,
2056 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2057 unsupported_future("reduce_median_dim not supported by provider")
2058 }
2059 fn reduce_min<'a>(
2060 &'a self,
2061 _a: &'a GpuTensorHandle,
2062 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2063 unsupported_future("reduce_min not supported by provider")
2064 }
2065 fn reduce_min_dim<'a>(
2066 &'a self,
2067 _a: &'a GpuTensorHandle,
2068 _dim: usize,
2069 ) -> AccelProviderFuture<'a, ReduceDimResult> {
2070 unsupported_future("reduce_min_dim not supported by provider")
2071 }
2072 fn reduce_max<'a>(
2073 &'a self,
2074 _a: &'a GpuTensorHandle,
2075 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2076 unsupported_future("reduce_max not supported by provider")
2077 }
2078 fn reduce_max_dim<'a>(
2079 &'a self,
2080 _a: &'a GpuTensorHandle,
2081 _dim: usize,
2082 ) -> AccelProviderFuture<'a, ReduceDimResult> {
2083 unsupported_future("reduce_max_dim not supported by provider")
2084 }
2085 fn cumsum_scan(
2086 &self,
2087 _input: &GpuTensorHandle,
2088 _dim: usize,
2089 _direction: ProviderScanDirection,
2090 _nan_mode: ProviderNanMode,
2091 ) -> anyhow::Result<GpuTensorHandle> {
2092 Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
2093 }
2094 fn cumprod_scan(
2095 &self,
2096 _input: &GpuTensorHandle,
2097 _dim: usize,
2098 _direction: ProviderScanDirection,
2099 _nan_mode: ProviderNanMode,
2100 ) -> anyhow::Result<GpuTensorHandle> {
2101 Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
2102 }
2103 fn cummin_scan(
2104 &self,
2105 _input: &GpuTensorHandle,
2106 _dim: usize,
2107 _direction: ProviderScanDirection,
2108 _nan_mode: ProviderNanMode,
2109 ) -> anyhow::Result<ProviderCumminResult> {
2110 Err(anyhow::anyhow!("cummin_scan not supported by provider"))
2111 }
2112 fn cummax_scan(
2113 &self,
2114 _input: &GpuTensorHandle,
2115 _dim: usize,
2116 _direction: ProviderScanDirection,
2117 _nan_mode: ProviderNanMode,
2118 ) -> anyhow::Result<ProviderCummaxResult> {
2119 Err(anyhow::anyhow!("cummax_scan not supported by provider"))
2120 }
2121
2122 fn find(
2123 &self,
2124 _a: &GpuTensorHandle,
2125 _limit: Option<usize>,
2126 _direction: FindDirection,
2127 ) -> anyhow::Result<ProviderFindResult> {
2128 Err(anyhow::anyhow!("find not supported by provider"))
2129 }
2130
    // ---- Fused-kernel entry points ------------------------------------------
    // These take a pre-generated shader source string; providers that can
    // compile and dispatch it override, everyone else reports "unsupported".
    fn fused_elementwise(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _len: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "fused_elementwise not supported by provider"
        ))
    }

    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
    }

    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
    }

    #[allow(clippy::too_many_arguments)]
    fn fused_reduction(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _reduce_len: usize,
        _num_slices: usize,
        _workgroup_size: u32,
        _flavor: ReductionFlavor,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
    }
2172
    // ---- Warmup, telemetry, and tuning knobs --------------------------------

    /// Optional pre-compilation/priming hook; the default does nothing.
    fn warmup(&self) {}

    /// (fusion-cache hits, misses). Default: no cache, so both are zero.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }

    /// Duration of the last `warmup` in milliseconds, if the provider tracks it.
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }

    /// Default telemetry: all counters zeroed except the fusion-cache numbers,
    /// which are seeded from `fused_cache_counters` so overriding that one
    /// method is enough to surface cache stats.
    fn telemetry_snapshot(&self) -> ProviderTelemetry {
        let (hits, misses) = self.fused_cache_counters();
        ProviderTelemetry {
            fused_elementwise: ProviderDispatchStats::default(),
            fused_reduction: ProviderDispatchStats::default(),
            matmul: ProviderDispatchStats::default(),
            upload_bytes: 0,
            download_bytes: 0,
            fusion_cache_hits: hits,
            fusion_cache_misses: misses,
            bind_group_cache_hits: 0,
            bind_group_cache_misses: 0,
            bind_group_cache_by_layout: None,
            kernel_launches: Vec::new(),
        }
    }

    fn reset_telemetry(&self) {}

    /// Workgroup size used for reductions when the caller gives no hint.
    fn default_reduction_workgroup_size(&self) -> u32 {
        256
    }

    /// Reduction length above which a two-pass strategy is considered.
    fn two_pass_threshold(&self) -> usize {
        1024
    }

    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
2221
2222 fn scatter_column(
2225 &self,
2226 _matrix: &GpuTensorHandle,
2227 _col_index: usize,
2228 _values: &GpuTensorHandle,
2229 ) -> anyhow::Result<GpuTensorHandle> {
2230 Err(anyhow::anyhow!("scatter_column not supported by provider"))
2231 }
2232
2233 fn scatter_row(
2236 &self,
2237 _matrix: &GpuTensorHandle,
2238 _row_index: usize,
2239 _values: &GpuTensorHandle,
2240 ) -> anyhow::Result<GpuTensorHandle> {
2241 Err(anyhow::anyhow!("scatter_row not supported by provider"))
2242 }
2243
2244 fn sub2ind(
2245 &self,
2246 _dims: &[usize],
2247 _strides: &[usize],
2248 _inputs: &[&GpuTensorHandle],
2249 _scalar_mask: &[bool],
2250 _len: usize,
2251 _output_shape: &[usize],
2252 ) -> anyhow::Result<GpuTensorHandle> {
2253 Err(anyhow::anyhow!("sub2ind not supported by provider"))
2254 }
2255
2256 fn supports_ind2sub(&self) -> bool {
2258 false
2259 }
2260
2261 fn ind2sub(
2263 &self,
2264 _dims: &[usize],
2265 _strides: &[usize],
2266 _indices: &GpuTensorHandle,
2267 _total: usize,
2268 _len: usize,
2269 _output_shape: &[usize],
2270 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2271 Err(anyhow::anyhow!("ind2sub not supported by provider"))
2272 }
2273
2274 fn issymmetric(
2276 &self,
2277 _matrix: &GpuTensorHandle,
2278 _kind: ProviderSymmetryKind,
2279 _tolerance: f64,
2280 ) -> anyhow::Result<bool> {
2281 Err(anyhow::anyhow!(
2282 "issymmetric predicate not supported by provider"
2283 ))
2284 }
2285
2286 fn ishermitian<'a>(
2288 &'a self,
2289 _matrix: &'a GpuTensorHandle,
2290 _kind: ProviderHermitianKind,
2291 _tolerance: f64,
2292 ) -> AccelProviderFuture<'a, bool> {
2293 Box::pin(async move {
2294 Err(anyhow::anyhow!(
2295 "ishermitian predicate not supported by provider"
2296 ))
2297 })
2298 }
2299
2300 fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2302 Err(anyhow::anyhow!("bandwidth not supported by provider"))
2303 }
2304
2305 fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
2310 Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
2311 }
2312}
2313
/// Process-wide default provider; consulted when no thread-local override is set.
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
/// Registered providers keyed by device id (see `register_provider_for_device`).
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Monotonic counter for allocating device ids; starts at 1.
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2319
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    // Per-thread provider override; `None` falls through to GLOBAL_PROVIDER.
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}
2324
// On wasm32 the per-"thread" override lives in a global mutex-guarded slot
// instead of a thread-local.
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| Mutex::new(None));
2328
2329#[cfg(not(target_arch = "wasm32"))]
2330fn replace_thread_provider(
2331 provider: Option<&'static dyn AccelProvider>,
2332) -> Option<&'static dyn AccelProvider> {
2333 THREAD_PROVIDER.with(|cell| {
2334 let prev = cell.get();
2335 cell.set(provider);
2336 prev
2337 })
2338}
2339
/// Swaps the wasm "thread" provider override for `provider`, returning the
/// previous value.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    // One-call swap instead of copy-out + write-back.
    std::mem::replace(&mut *slot, provider)
}
2351
2352#[cfg(not(target_arch = "wasm32"))]
2353fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2354 THREAD_PROVIDER.with(|cell| cell.get())
2355}
2356
/// Reads the wasm "thread" provider override without modifying it.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    let slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    // `Option<&'static dyn _>` is `Copy`; deref the guard directly.
    *slot
}
2365
/// Installs `p` as the process-wide fallback provider and also registers it
/// in the per-device registry under its own `device_id()`.
///
/// A poisoned `GLOBAL_PROVIDER` lock is silently skipped, leaving the global
/// slot unchanged.
///
/// # Safety
///
/// NOTE(review): the reason this is `unsafe` is not visible from this chunk —
/// the reference is already `'static`, so presumably callers must uphold an
/// external invariant (e.g. single registration, or provider initialization
/// ordering). Confirm and document the actual contract.
pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
        *guard = Some(p);
    }
    register_provider_for_device(p.device_id(), p);
}
2379
/// Maps `device_id` to `provider` in the registry, overwriting any previous
/// entry for that id. A poisoned lock makes this a silent no-op.
unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
    if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
        guard.insert(device_id, provider);
    }
}
2385
2386pub fn provider() -> Option<&'static dyn AccelProvider> {
2387 if let Some(p) = current_thread_provider() {
2388 return Some(p);
2389 }
2390 GLOBAL_PROVIDER
2391 .read()
2392 .ok()
2393 .and_then(|guard| guard.as_ref().copied())
2394}
2395
/// Clears the global provider slot and empties the per-device registry.
///
/// NOTE(review): thread-local overrides installed via `set_thread_provider`
/// / `ThreadProviderGuard` are NOT cleared here — confirm callers expect
/// that. Poisoned locks are silently skipped.
pub fn clear_provider() {
    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
        *guard = None;
    }
    if let Ok(mut map) = PROVIDER_REGISTRY.write() {
        map.clear();
    }
}
2405
2406pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2407 PROVIDER_REGISTRY
2408 .read()
2409 .ok()
2410 .and_then(|guard| guard.get(&device_id).copied())
2411 .or_else(|| provider())
2412}
2413
/// Routes a tensor handle back to the provider that owns its device
/// (thin wrapper over `provider_for_device`).
pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
    provider_for_device(handle.device_id)
}
2417
/// Hands out the next unique device id (starting at 1). `Relaxed` ordering
/// is sufficient: only uniqueness of the counter value matters, not
/// ordering relative to other memory operations.
pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
2421
/// RAII guard that scopes a thread-local provider override: `set` installs
/// a provider, `Drop` restores whatever was installed before.
pub struct ThreadProviderGuard {
    // Provider that was active when the guard was created.
    prev: Option<&'static dyn AccelProvider>,
}
2425
2426impl ThreadProviderGuard {
2427 pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2428 let prev = replace_thread_provider(provider);
2429 ThreadProviderGuard { prev }
2430 }
2431}
2432
2433impl Drop for ThreadProviderGuard {
2434 fn drop(&mut self) {
2435 let prev = self.prev.take();
2436 replace_thread_provider(prev);
2437 }
2438}
2439
/// Unscoped variant of [`ThreadProviderGuard::set`]: installs (or clears,
/// with `None`) the current thread's provider override permanently,
/// discarding the previous value.
pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
    replace_thread_provider(provider);
}
2443
2444pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2446 if let Some(p) = provider() {
2447 if let Ok(h) = p.elem_add(a, b).await {
2448 return Some(h);
2449 }
2450 }
2451 None
2452}
2453
2454pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2456 if let Some(p) = provider() {
2457 if let Ok(h) = p.elem_hypot(a, b).await {
2458 return Some(h);
2459 }
2460 }
2461 None
2462}
2463
2464pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2466 if let Some(p) = provider() {
2467 if let Ok(h) = p.elem_max(a, b).await {
2468 return Some(h);
2469 }
2470 }
2471 None
2472}
2473
2474pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2476 if let Some(p) = provider() {
2477 if let Ok(h) = p.elem_min(a, b).await {
2478 return Some(h);
2479 }
2480 }
2481 None
2482}
2483
2484pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2486 if let Some(p) = provider() {
2487 if let Ok(h) = p.elem_atan2(y, x).await {
2488 return Some(h);
2489 }
2490 }
2491 None
2492}
2493
/// Host-resident tensor with owned storage: flat `f64` data plus its shape.
///
/// NOTE(review): element ordering (row- vs column-major) is not established
/// by this chunk — confirm against the providers that produce these.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    pub data: Vec<f64>,
    pub shape: Vec<usize>,
}
2500
/// Borrowed counterpart of [`HostTensorOwned`]: a view over caller-owned
/// flat data and shape.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    pub data: &'a [f64],
    pub shape: &'a [usize],
}
2506
/// Borrowed 1-D axis samples used as one input axis for a meshgrid call.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    pub data: &'a [f64],
}
2512
/// Provider output of a meshgrid operation: one device tensor per input axis.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    pub outputs: Vec<GpuTensorHandle>,
}
2518
/// How a row/column scale vector is applied in a matmul epilogue:
/// multiply by the scale or divide by it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    Multiply,
    Divide,
}
2529
/// Fused post-processing applied to a matmul result.
///
/// Neutral values (see [`MatmulEpilogue::noop`]): `alpha = 1.0`,
/// `beta = 0.0`, everything else `None`/`Multiply`. The `#[serde(default)]`
/// fields keep older serialized epilogues (without clamp/pow/diag)
/// deserializable.
///
/// NOTE(review): presumably `alpha`/`beta` follow GEMM conventions
/// (`alpha * A*B + beta * C`) and `diag_output` receives the result's
/// diagonal — not provable from this chunk; confirm against the kernels.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    pub alpha: f64,
    pub beta: f64,
    // Optional per-row / per-column scale vectors, applied with the ops below.
    pub row_scale: Option<GpuTensorHandle>,
    pub col_scale: Option<GpuTensorHandle>,
    pub row_op: ScaleOp,
    pub col_op: ScaleOp,
    #[serde(default)]
    pub clamp_min: Option<f64>,
    #[serde(default)]
    pub clamp_max: Option<f64>,
    #[serde(default)]
    pub pow_exponent: Option<f64>,
    #[serde(default)]
    pub diag_output: Option<GpuTensorHandle>,
}
2557
2558impl MatmulEpilogue {
2559 pub fn noop() -> Self {
2560 Self {
2561 alpha: 1.0,
2562 beta: 0.0,
2563 row_scale: None,
2564 col_scale: None,
2565 row_op: ScaleOp::Multiply,
2566 col_op: ScaleOp::Multiply,
2567 clamp_min: None,
2568 clamp_max: None,
2569 pow_exponent: None,
2570 diag_output: None,
2571 }
2572 }
2573 pub fn is_noop(&self) -> bool {
2574 self.alpha == 1.0
2575 && self.beta == 0.0
2576 && self.row_scale.is_none()
2577 && self.col_scale.is_none()
2578 && self.clamp_min.is_none()
2579 && self.clamp_max.is_none()
2580 && self.pow_exponent.is_none()
2581 && self.diag_output.is_none()
2582 }
2583}
2584
/// Epilogue parameters for a power-iteration step.
///
/// NOTE(review): presumably `epsilon` stabilizes the normalization divisor —
/// not visible from this chunk; confirm against the kernel that consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct PowerStepEpilogue {
    pub epsilon: f64,
}
2589
2590impl Default for PowerStepEpilogue {
2591 fn default() -> Self {
2592 Self { epsilon: 0.0 }
2593 }
2594}
2595
/// Parameters for a batched image-normalization pass.
///
/// `gain`, `bias`, and `gamma` are optional post-normalization adjustments;
/// `#[serde(default)]` keeps older serialized descriptors (without them)
/// deserializable.
///
/// NOTE(review): presumably `epsilon` guards the variance denominator and
/// normalization is per-image over `height * width` — not provable from this
/// chunk; confirm against the kernel.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    pub batch: usize,
    pub height: usize,
    pub width: usize,
    pub epsilon: f64,
    #[serde(default)]
    pub gain: Option<f64>,
    #[serde(default)]
    pub bias: Option<f64>,
    #[serde(default)]
    pub gamma: Option<f64>,
}