1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
16type ResidencyMarkFn = fn(&GpuTensorHandle);
17type ResidencyClearFn = fn(&GpuTensorHandle);
18type SequenceThresholdFn = fn() -> Option<usize>;
19type WorkgroupSizeHintFn = fn() -> Option<u32>;
20
21static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
22static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
23static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
24static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();
25
26static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
27static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
28 Lazy::new(|| RwLock::new(HashMap::new()));
29static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
30 Lazy::new(|| RwLock::new(HashMap::new()));
31
32static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
33 Lazy::new(|| RwLock::new(HashMap::new()));
34static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
35 Lazy::new(|| RwLock::new(HashMap::new()));
36
/// Dimensions recorded for a handle by `record_handle_transpose`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// Row count supplied when the transpose was recorded.
    pub base_rows: usize,
    /// Column count supplied when the transpose was recorded.
    pub base_cols: usize,
}
42
43pub fn register_residency_mark(handler: ResidencyMarkFn) {
46 let _ = RESIDENCY_MARK.set(handler);
47}
48
49pub fn mark_residency(handle: &GpuTensorHandle) {
52 if let Some(handler) = RESIDENCY_MARK.get() {
53 handler(handle);
54 }
55}
56
57pub fn register_residency_clear(handler: ResidencyClearFn) {
61 let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64pub fn clear_residency(handle: &GpuTensorHandle) {
67 if let Some(handler) = RESIDENCY_CLEAR.get() {
68 handler(handle);
69 }
70}
71
72pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76 let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79pub fn sequence_threshold_hint() -> Option<usize> {
81 SEQUENCE_THRESHOLD_PROVIDER
82 .get()
83 .and_then(|provider| provider())
84}
85
86pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90 let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93pub fn workgroup_size_hint() -> Option<u32> {
95 WORKGROUP_SIZE_HINT_PROVIDER
96 .get()
97 .and_then(|provider| provider())
98}
99
100pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103 provider().and_then(|p| p.export_context(kind))
104}
105
/// Asks the current provider (as returned by `provider()`) for the wgpu
/// buffer behind `handle`.
///
/// Returns `None` when there is no provider or the provider exports nothing.
#[cfg(feature = "wgpu")]
pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
    let active = provider()?;
    active.export_wgpu_buffer(handle)
}
113
114pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118 guard.insert(handle.buffer_id, precision);
119 }
120}
121
122pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124 HANDLE_PRECISIONS
125 .read()
126 .ok()
127 .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133 guard.remove(&handle.buffer_id);
134 }
135}
136
137pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
140 if let Ok(mut guard) = LOGICAL_HANDLES.write() {
141 if logical {
142 guard.insert(handle.buffer_id);
143 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
144 *hits.entry(handle.buffer_id).or_insert(0) += 1;
145 }
146 } else {
147 guard.remove(&handle.buffer_id);
148 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
149 hits.remove(&handle.buffer_id);
150 }
151 }
152 }
153}
154
155pub fn clear_handle_logical(handle: &GpuTensorHandle) {
157 set_handle_logical(handle, false);
158}
159
160pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162 LOGICAL_HANDLES
163 .read()
164 .map(|guard| guard.contains(&handle.buffer_id))
165 .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169 LOGICAL_HANDLE_HITS
170 .read()
171 .ok()
172 .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177 guard.insert(
178 handle.buffer_id,
179 TransposeInfo {
180 base_rows,
181 base_cols,
182 },
183 );
184 }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189 guard.remove(&handle.buffer_id);
190 }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194 TRANSPOSED_HANDLES
195 .read()
196 .ok()
197 .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
200pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
201 handle_transpose_info(handle).is_some()
202}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
205pub enum GpuTensorStorage {
206 Real,
207 ComplexInterleaved,
208}
209
210impl Default for GpuTensorStorage {
211 fn default() -> Self {
212 Self::Real
213 }
214}
215
216#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
217pub struct GpuTensorHandle {
218 pub shape: Vec<usize>,
219 pub device_id: u32,
220 pub buffer_id: u64,
221}
222
223#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
224pub struct ApiDeviceInfo {
225 pub device_id: u32,
226 pub name: String,
227 pub vendor: String,
228 pub memory_bytes: Option<u64>,
229 pub backend: Option<String>,
230}
231
232#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
233pub struct ReduceDimResult {
234 pub values: GpuTensorHandle,
235 pub indices: GpuTensorHandle,
236}
237
238#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
239pub struct ProviderCumminResult {
240 pub values: GpuTensorHandle,
241 pub indices: GpuTensorHandle,
242}
243
244pub type ProviderCummaxResult = ProviderCumminResult;
249
/// Selects which kind of context `export_context` should hand out.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelContextKind {
    Plotting,
}
255
/// Context exported by a provider.
///
/// The only variant is gated on the `wgpu` feature, so the enum has no
/// variants when that feature is disabled.
#[derive(Clone)]
pub enum AccelContextHandle {
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}

impl AccelContextHandle {
    /// Borrows the wrapped wgpu context.
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        // `Wgpu` is the only variant, so this pattern cannot fail.
        let AccelContextHandle::Wgpu(ctx) = self;
        Some(ctx)
    }
}
272
/// Shared wgpu objects exported by a provider (requires the `wgpu` feature).
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    pub adapter_info: wgpu::AdapterInfo,
    pub limits: wgpu::Limits,
    pub features: wgpu::Features,
}
285
/// Shared wgpu buffer plus the metadata exported with it (requires the
/// `wgpu` feature).
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    pub buffer: Arc<wgpu::Buffer>,
    // NOTE(review): presumably an element count rather than bytes; confirm
    // against the provider that fills this in.
    pub len: usize,
    pub shape: Vec<usize>,
    pub element_size: usize,
    pub precision: ProviderPrecision,
}
296
297pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
298 if let Ok(mut guard) = HANDLE_STORAGES.write() {
299 guard.insert(handle.buffer_id, storage);
300 }
301}
302
303pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
304 HANDLE_STORAGES
305 .read()
306 .ok()
307 .and_then(|guard| guard.get(&handle.buffer_id).cloned())
308 .unwrap_or(GpuTensorStorage::Real)
309}
310
311pub fn clear_handle_storage(handle: &GpuTensorHandle) {
312 if let Ok(mut guard) = HANDLE_STORAGES.write() {
313 guard.remove(&handle.buffer_id);
314 }
315}
316
317#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
318pub enum PagefunOp {
319 Mtimes,
320}
321
322#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
323pub struct PagefunRequest {
324 pub op: PagefunOp,
325 pub inputs: Vec<GpuTensorHandle>,
326 pub output_shape: Vec<usize>,
327 pub page_dims: Vec<usize>,
328 pub input_page_dims: Vec<Vec<usize>>,
329}
330
331#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
332pub enum FindDirection {
333 First,
334 Last,
335}
336
337#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
338pub struct ProviderFindResult {
339 pub linear: GpuTensorHandle,
340 pub rows: GpuTensorHandle,
341 pub cols: GpuTensorHandle,
342 pub values: Option<GpuTensorHandle>,
343}
344
345#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
346pub struct ProviderBandwidth {
347 pub lower: u32,
348 pub upper: u32,
349}
350
351#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
352pub enum ProviderSymmetryKind {
353 Symmetric,
354 Skew,
355}
356
357#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
358pub enum ProviderHermitianKind {
359 Hermitian,
360 Skew,
361}
362
363#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
364pub struct ProviderLuResult {
365 pub combined: GpuTensorHandle,
366 pub lower: GpuTensorHandle,
367 pub upper: GpuTensorHandle,
368 pub perm_matrix: GpuTensorHandle,
369 pub perm_vector: GpuTensorHandle,
370}
371
372#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
373pub struct ProviderCholResult {
374 pub factor: GpuTensorHandle,
375 pub info: u32,
377}
378
379#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
380pub struct ProviderQrResult {
381 pub q: GpuTensorHandle,
382 pub r: GpuTensorHandle,
383 pub perm_matrix: GpuTensorHandle,
384 pub perm_vector: GpuTensorHandle,
385}
386
387#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
388pub struct ProviderQrPowerIterResult {
389 pub q: GpuTensorHandle,
390 pub r: GpuTensorHandle,
391 pub perm_matrix: GpuTensorHandle,
392 pub perm_vector: GpuTensorHandle,
393}
394
395#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
396pub struct ProviderLinsolveOptions {
397 pub lower: bool,
398 pub upper: bool,
399 pub rectangular: bool,
400 pub transposed: bool,
401 pub conjugate: bool,
402 pub symmetric: bool,
403 pub posdef: bool,
404 pub need_rcond: bool,
405 pub rcond: Option<f64>,
406}
407
408#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
409pub struct ProviderLinsolveResult {
410 pub solution: GpuTensorHandle,
411 pub reciprocal_condition: f64,
412}
413
414#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
415pub struct ProviderPinvOptions {
416 pub tolerance: Option<f64>,
417}
418
419#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
420pub struct ProviderPolyvalMu {
421 pub mean: f64,
422 pub scale: f64,
423}
424
425#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
426pub struct ProviderPolyvalOptions {
427 pub mu: Option<ProviderPolyvalMu>,
428}
429
430#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
431pub struct ProviderInvOptions {}
432
433#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
434pub struct ProviderPolyfitResult {
435 pub coefficients: Vec<f64>,
436 pub r_matrix: Vec<f64>,
437 pub normr: f64,
438 pub df: f64,
439 pub mu: [f64; 2],
440}
441
442#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
444pub struct ProviderPolyderQuotient {
445 pub numerator: GpuTensorHandle,
446 pub denominator: GpuTensorHandle,
447}
448
449#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
451pub enum ProviderCondNorm {
452 Two,
453 One,
454 Inf,
455 Fro,
456}
457
458#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
460pub enum ProviderNormOrder {
461 Two,
462 One,
463 Inf,
464 NegInf,
465 Zero,
466 Fro,
467 Nuc,
468 P(f64),
469}
470
471#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
472pub struct ProviderEigResult {
473 pub eigenvalues: GpuTensorHandle,
474 pub diagonal: GpuTensorHandle,
475 pub right: GpuTensorHandle,
476 pub left: Option<GpuTensorHandle>,
477}
478
479#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
480pub enum ProviderQrPivot {
481 Matrix,
482 Vector,
483}
484
485#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
486pub struct ProviderQrOptions {
487 pub economy: bool,
488 pub pivot: ProviderQrPivot,
489}
490
491impl Default for ProviderQrOptions {
492 fn default() -> Self {
493 Self {
494 economy: false,
495 pivot: ProviderQrPivot::Matrix,
496 }
497 }
498}
499
500#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
501pub enum ProviderPrecision {
502 F32,
503 F64,
504}
505
506#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
511pub enum SpawnHandleConcurrency {
512 ImmutableShare,
514 CopyOnWrite,
516 SynchronizedMutation,
518 Reject,
520}
521
522impl SpawnHandleConcurrency {
523 pub fn as_str(self) -> &'static str {
524 match self {
525 SpawnHandleConcurrency::ImmutableShare => "immutable_share",
526 SpawnHandleConcurrency::CopyOnWrite => "copy_on_write",
527 SpawnHandleConcurrency::SynchronizedMutation => "synchronized_mutation",
528 SpawnHandleConcurrency::Reject => "reject",
529 }
530 }
531}
532
533#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
534pub enum ReductionTwoPassMode {
535 Auto,
536 ForceOn,
537 ForceOff,
538}
539
540impl ReductionTwoPassMode {
541 pub fn as_str(self) -> &'static str {
542 match self {
543 ReductionTwoPassMode::Auto => "auto",
544 ReductionTwoPassMode::ForceOn => "force_on",
545 ReductionTwoPassMode::ForceOff => "force_off",
546 }
547 }
548}
549
550#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
551pub enum ReductionFlavor {
552 Sum,
553 Mean,
554 CustomScale(f64),
555}
556
557impl ReductionFlavor {
558 pub fn is_mean(self) -> bool {
559 matches!(self, ReductionFlavor::Mean)
560 }
561
562 pub fn scale(self, reduce_len: usize) -> f64 {
563 match self {
564 ReductionFlavor::Sum => 1.0,
565 ReductionFlavor::Mean => {
566 if reduce_len == 0 {
567 1.0
568 } else {
569 1.0 / reduce_len as f64
570 }
571 }
572 ReductionFlavor::CustomScale(scale) => scale,
573 }
574 }
575}
576
577#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
579pub enum CorrcoefNormalization {
580 Unbiased,
581 Biased,
582}
583
584#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
586pub enum CorrcoefRows {
587 All,
588 Complete,
589 Pairwise,
590}
591
592#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
594pub struct CorrcoefOptions {
595 pub normalization: CorrcoefNormalization,
596 pub rows: CorrcoefRows,
597}
598
599impl Default for CorrcoefOptions {
600 fn default() -> Self {
601 Self {
602 normalization: CorrcoefNormalization::Unbiased,
603 rows: CorrcoefRows::All,
604 }
605 }
606}
607
608#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
610pub enum CovNormalization {
611 Unbiased,
612 Biased,
613}
614
615#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
617pub enum CovRows {
618 All,
619 OmitRows,
620 PartialRows,
621}
622
623#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
625pub struct CovarianceOptions {
626 pub normalization: CovNormalization,
627 pub rows: CovRows,
628 pub has_weight_vector: bool,
629}
630
631impl Default for CovarianceOptions {
632 fn default() -> Self {
633 Self {
634 normalization: CovNormalization::Unbiased,
635 rows: CovRows::All,
636 has_weight_vector: false,
637 }
638 }
639}
640
641#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
643pub enum ProviderStdNormalization {
644 Sample,
645 Population,
646}
647
648#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
650pub enum ProviderNanMode {
651 Include,
652 Omit,
653}
654
655#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
657pub enum ProviderScanDirection {
658 Forward,
659 Reverse,
660}
661
662#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
664pub enum SortOrder {
665 Ascend,
666 Descend,
667}
668
669#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
671pub enum SortComparison {
672 Auto,
673 Real,
674 Abs,
675}
676
677#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
679pub struct SortResult {
680 pub values: HostTensorOwned,
681 pub indices: HostTensorOwned,
682}
683
684#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
685pub struct SortRowsColumnSpec {
686 pub index: usize,
687 pub order: SortOrder,
688}
689
690#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
692pub enum UniqueOrder {
693 Sorted,
694 Stable,
695}
696
697#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
699pub enum UniqueOccurrence {
700 First,
701 Last,
702}
703
704#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
706pub struct UniqueOptions {
707 pub rows: bool,
708 pub order: UniqueOrder,
709 pub occurrence: UniqueOccurrence,
710}
711
712#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
714pub struct UniqueResult {
715 pub values: HostTensorOwned,
716 pub ia: HostTensorOwned,
717 pub ic: HostTensorOwned,
718}
719
720#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
722pub enum UnionOrder {
723 Sorted,
724 Stable,
725}
726
727#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
729pub struct UnionOptions {
730 pub rows: bool,
731 pub order: UnionOrder,
732}
733
734#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
736pub struct UnionResult {
737 pub values: HostTensorOwned,
738 pub ia: HostTensorOwned,
739 pub ib: HostTensorOwned,
740}
741
742#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
744pub enum FspecialFilter {
745 Average {
746 rows: u32,
747 cols: u32,
748 },
749 Disk {
750 radius: f64,
751 size: u32,
752 },
753 Gaussian {
754 rows: u32,
755 cols: u32,
756 sigma: f64,
757 },
758 Laplacian {
759 alpha: f64,
760 },
761 Log {
762 rows: u32,
763 cols: u32,
764 sigma: f64,
765 },
766 Motion {
767 length: u32,
768 kernel_size: u32,
769 angle_degrees: f64,
770 oversample: u32,
771 },
772 Prewitt,
773 Sobel,
774 Unsharp {
775 alpha: f64,
776 },
777}
778
779#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
781pub struct FspecialRequest {
782 pub filter: FspecialFilter,
783}
784
785#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
787pub enum ImfilterPadding {
788 Constant,
789 Replicate,
790 Symmetric,
791 Circular,
792}
793
794#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
796pub enum ImfilterShape {
797 Same,
798 Full,
799 Valid,
800}
801
802#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
804pub enum ImfilterMode {
805 Correlation,
806 Convolution,
807}
808
809#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
811pub struct ImfilterOptions {
812 pub padding: ImfilterPadding,
813 pub constant_value: f64,
814 pub shape: ImfilterShape,
815 pub mode: ImfilterMode,
816}
817
818impl Default for ImfilterOptions {
819 fn default() -> Self {
820 Self {
821 padding: ImfilterPadding::Constant,
822 constant_value: 0.0,
823 shape: ImfilterShape::Same,
824 mode: ImfilterMode::Correlation,
825 }
826 }
827}
828
829#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
831pub enum SetdiffOrder {
832 Sorted,
833 Stable,
834}
835
836#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
838pub struct SetdiffOptions {
839 pub rows: bool,
840 pub order: SetdiffOrder,
841}
842
843#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
845pub struct SetdiffResult {
846 pub values: HostTensorOwned,
847 pub ia: HostTensorOwned,
848}
849
850#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
852pub struct IsMemberOptions {
853 pub rows: bool,
854}
855
856#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
858pub struct HostLogicalOwned {
859 pub data: Vec<u8>,
860 pub shape: Vec<usize>,
861}
862
863#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
865pub struct IsMemberResult {
866 pub mask: HostLogicalOwned,
867 pub loc: HostTensorOwned,
868}
869
870#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
871pub enum ProviderConvMode {
872 Full,
873 Same,
874 Valid,
875}
876
877#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
878pub enum ProviderConvOrientation {
879 Row,
880 Column,
881}
882
883#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
884pub struct ProviderConv1dOptions {
885 pub mode: ProviderConvMode,
886 pub orientation: ProviderConvOrientation,
887}
888
889#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
890pub struct ProviderIirFilterOptions {
891 pub dim: usize,
893 pub zi: Option<GpuTensorHandle>,
895}
896
897#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
898pub struct ProviderIirFilterResult {
899 pub output: GpuTensorHandle,
901 pub final_state: Option<GpuTensorHandle>,
903}
904
905#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
906pub struct ProviderMoments2 {
907 pub mean: GpuTensorHandle,
908 pub ex2: GpuTensorHandle,
909}
910
911#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
912pub struct ProviderDispatchStats {
913 pub count: u64,
915 pub total_wall_time_ns: u64,
917}
918
919#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
920pub struct ProviderFallbackStat {
921 pub reason: String,
922 pub count: u64,
923}
924
925#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
926pub struct ProviderTelemetry {
927 pub fused_elementwise: ProviderDispatchStats,
928 pub fused_reduction: ProviderDispatchStats,
929 pub matmul: ProviderDispatchStats,
930 pub linsolve: ProviderDispatchStats,
931 pub mldivide: ProviderDispatchStats,
932 pub mrdivide: ProviderDispatchStats,
933 pub upload_bytes: u64,
934 pub download_bytes: u64,
935 pub solve_fallbacks: Vec<ProviderFallbackStat>,
936 pub fusion_cache_hits: u64,
937 pub fusion_cache_misses: u64,
938 pub bind_group_cache_hits: u64,
939 pub bind_group_cache_misses: u64,
940 pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
942 pub kernel_launches: Vec<KernelLaunchTelemetry>,
944}
945
946#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
947pub struct BindGroupLayoutTelemetry {
948 pub tag: String,
949 pub hits: u64,
950 pub misses: u64,
951}
952
953#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
954pub struct KernelAttrTelemetry {
955 pub key: String,
956 pub value: u64,
957}
958
959#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
960pub struct KernelLaunchTelemetry {
961 pub kernel: String,
962 pub precision: Option<String>,
963 pub shape: Vec<KernelAttrTelemetry>,
964 pub tuning: Vec<KernelAttrTelemetry>,
965}
966
967pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
968pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
969
970fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
971 Box::pin(async move { Err(anyhow::anyhow!(message)) })
972}
973
974pub trait AccelProvider: Send + Sync {
976 fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
977 fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
978 fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
979 fn device_info(&self) -> String;
980 fn device_id(&self) -> u32 {
981 0
982 }
983
984 fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
990 SpawnHandleConcurrency::Reject
991 }
992
993 fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
996 None
997 }
998
/// Exports the wgpu buffer behind a handle (requires the `wgpu` feature).
///
/// The default implementation exports nothing and returns `None`.
#[cfg(feature = "wgpu")]
fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
    // The underscore prefix already marks the parameter as unused.
    None
}
1005
1006 fn gather_linear(
1009 &self,
1010 _source: &GpuTensorHandle,
1011 _indices: &[u32],
1012 _output_shape: &[usize],
1013 ) -> anyhow::Result<GpuTensorHandle> {
1014 Err(anyhow::anyhow!("gather_linear not supported by provider"))
1015 }
1016
1017 fn scatter_linear(
1021 &self,
1022 _target: &GpuTensorHandle,
1023 _indices: &[u32],
1024 _values: &GpuTensorHandle,
1025 ) -> anyhow::Result<()> {
1026 Err(anyhow::anyhow!("scatter_linear not supported by provider"))
1027 }
1028
1029 fn device_info_struct(&self) -> ApiDeviceInfo {
1031 ApiDeviceInfo {
1032 device_id: 0,
1033 name: self.device_info(),
1034 vendor: String::new(),
1035 memory_bytes: None,
1036 backend: None,
1037 }
1038 }
1039
1040 fn precision(&self) -> ProviderPrecision {
1041 ProviderPrecision::F64
1042 }
1043
1044 fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
1046 Err(anyhow::anyhow!("read_scalar not supported by provider"))
1047 }
1048
1049 fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1051 Err(anyhow::anyhow!("zeros not supported by provider"))
1052 }
1053
1054 fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1056 Err(anyhow::anyhow!("ones not supported by provider"))
1057 }
1058
1059 fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1061 self.zeros(&prototype.shape)
1062 }
1063
1064 fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
1066 if value == 0.0 {
1067 return self.zeros(shape);
1068 }
1069 if let Ok(base) = self.zeros(shape) {
1070 match self.scalar_add(&base, value) {
1071 Ok(out) => {
1072 let _ = self.free(&base);
1073 return Ok(out);
1074 }
1075 Err(_) => {
1076 let _ = self.free(&base);
1077 }
1078 }
1079 }
1080 let len: usize = shape.iter().copied().product();
1081 let data = vec![value; len];
1082 let view = HostTensorView { data: &data, shape };
1083 self.upload(&view)
1084 }
1085
1086 fn fill_like(
1088 &self,
1089 prototype: &GpuTensorHandle,
1090 value: f64,
1091 ) -> anyhow::Result<GpuTensorHandle> {
1092 if value == 0.0 {
1093 return self.zeros_like(prototype);
1094 }
1095 if let Ok(base) = self.zeros_like(prototype) {
1096 match self.scalar_add(&base, value) {
1097 Ok(out) => {
1098 let _ = self.free(&base);
1099 return Ok(out);
1100 }
1101 Err(_) => {
1102 let _ = self.free(&base);
1103 }
1104 }
1105 }
1106 self.fill(&prototype.shape, value)
1107 }
1108
1109 fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1111 self.ones(&prototype.shape)
1112 }
1113
1114 fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1116 Err(anyhow::anyhow!("eye not supported by provider"))
1117 }
1118
1119 fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1121 self.eye(&prototype.shape)
1122 }
1123
1124 fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1126 Err(anyhow::anyhow!("meshgrid not supported by provider"))
1127 }
1128
1129 fn diag_from_vector(
1131 &self,
1132 _vector: &GpuTensorHandle,
1133 _offset: isize,
1134 ) -> anyhow::Result<GpuTensorHandle> {
1135 Err(anyhow::anyhow!(
1136 "diag_from_vector not supported by provider"
1137 ))
1138 }
1139
1140 fn diag_extract(
1142 &self,
1143 _matrix: &GpuTensorHandle,
1144 _offset: isize,
1145 ) -> anyhow::Result<GpuTensorHandle> {
1146 Err(anyhow::anyhow!("diag_extract not supported by provider"))
1147 }
1148
1149 fn tril<'a>(
1151 &'a self,
1152 _matrix: &'a GpuTensorHandle,
1153 _offset: isize,
1154 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1155 Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1156 }
1157
1158 fn triu<'a>(
1160 &'a self,
1161 _matrix: &'a GpuTensorHandle,
1162 _offset: isize,
1163 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1164 Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1165 }
1166
1167 fn polyval(
1169 &self,
1170 _coefficients: &GpuTensorHandle,
1171 _points: &GpuTensorHandle,
1172 _options: &ProviderPolyvalOptions,
1173 ) -> anyhow::Result<GpuTensorHandle> {
1174 Err(anyhow::anyhow!("polyval not supported by provider"))
1175 }
1176
1177 fn polyfit<'a>(
1179 &'a self,
1180 _x: &'a GpuTensorHandle,
1181 _y: &'a GpuTensorHandle,
1182 _degree: usize,
1183 _weights: Option<&'a GpuTensorHandle>,
1184 ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1185 Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1186 }
1187
1188 fn polyder_single<'a>(
1190 &'a self,
1191 _polynomial: &'a GpuTensorHandle,
1192 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1193 Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1194 }
1195
1196 fn polyder_product<'a>(
1198 &'a self,
1199 _p: &'a GpuTensorHandle,
1200 _q: &'a GpuTensorHandle,
1201 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1202 Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1203 }
1204
1205 fn polyder_quotient<'a>(
1207 &'a self,
1208 _u: &'a GpuTensorHandle,
1209 _v: &'a GpuTensorHandle,
1210 ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1211 Box::pin(async move {
1212 Err(anyhow::anyhow!(
1213 "polyder_quotient not supported by provider"
1214 ))
1215 })
1216 }
1217
1218 fn polyint(
1220 &self,
1221 _polynomial: &GpuTensorHandle,
1222 _constant: f64,
1223 ) -> anyhow::Result<GpuTensorHandle> {
1224 Err(anyhow::anyhow!("polyint not supported by provider"))
1225 }
1226
1227 fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1229 Err(anyhow::anyhow!("random_uniform not supported by provider"))
1230 }
1231
1232 fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1234 self.random_uniform(&prototype.shape)
1235 }
1236
1237 fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1239 Err(anyhow::anyhow!("random_normal not supported by provider"))
1240 }
1241
1242 fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1244 self.random_normal(&prototype.shape)
1245 }
1246
1247 fn random_exponential(&self, _mu: f64, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1249 Err(anyhow::anyhow!(
1250 "random_exponential not supported by provider"
1251 ))
1252 }
1253
1254 fn random_normrnd(
1256 &self,
1257 _mu: f64,
1258 _sigma: f64,
1259 _shape: &[usize],
1260 ) -> anyhow::Result<GpuTensorHandle> {
1261 Err(anyhow::anyhow!("random_normrnd not supported by provider"))
1262 }
1263
1264 fn random_unifrnd(
1266 &self,
1267 _a: f64,
1268 _b: f64,
1269 _shape: &[usize],
1270 ) -> anyhow::Result<GpuTensorHandle> {
1271 Err(anyhow::anyhow!("random_unifrnd not supported by provider"))
1272 }
1273
1274 fn stochastic_evolution(
1275 &self,
1276 _state: &GpuTensorHandle,
1277 _drift: f64,
1278 _scale: f64,
1279 _steps: u32,
1280 ) -> anyhow::Result<GpuTensorHandle> {
1281 Err(anyhow::anyhow!(
1282 "stochastic_evolution not supported by provider"
1283 ))
1284 }
1285
1286 fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1288 Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1289 }
1290
1291 fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1293 Err(anyhow::anyhow!("fspecial not supported by provider"))
1294 }
1295
1296 fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
1299 Err(anyhow::anyhow!("peaks not supported by provider"))
1300 }
1301
1302 fn peaks_xy(
1305 &self,
1306 _x: &GpuTensorHandle,
1307 _y: &GpuTensorHandle,
1308 ) -> anyhow::Result<GpuTensorHandle> {
1309 Err(anyhow::anyhow!("peaks_xy not supported by provider"))
1310 }
1311
1312 fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1313 Err(anyhow::anyhow!("hann_window not supported by provider"))
1314 }
1315
1316 fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1317 Err(anyhow::anyhow!("hamming_window not supported by provider"))
1318 }
1319
1320 fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1321 Err(anyhow::anyhow!("blackman_window not supported by provider"))
1322 }
1323
    /// Stub `imfilter`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn imfilter<'a>(
        &'a self,
        _image: &'a GpuTensorHandle,
        _kernel: &'a GpuTensorHandle,
        _options: &'a ImfilterOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("imfilter not supported by provider")
    }
1333
1334 fn random_integer_range(
1336 &self,
1337 _lower: i64,
1338 _upper: i64,
1339 _shape: &[usize],
1340 ) -> anyhow::Result<GpuTensorHandle> {
1341 Err(anyhow::anyhow!(
1342 "random_integer_range not supported by provider"
1343 ))
1344 }
1345
1346 fn random_integer_like(
1348 &self,
1349 prototype: &GpuTensorHandle,
1350 lower: i64,
1351 upper: i64,
1352 ) -> anyhow::Result<GpuTensorHandle> {
1353 self.random_integer_range(lower, upper, &prototype.shape)
1354 }
1355
1356 fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1358 Err(anyhow!("random_permutation not supported by provider"))
1359 }
1360
1361 fn random_permutation_like(
1363 &self,
1364 _prototype: &GpuTensorHandle,
1365 n: usize,
1366 k: usize,
1367 ) -> anyhow::Result<GpuTensorHandle> {
1368 self.random_permutation(n, k)
1369 }
1370
    /// Stub `covariance`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn covariance<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _second: Option<&'a GpuTensorHandle>,
        _weights: Option<&'a GpuTensorHandle>,
        _options: &'a CovarianceOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("covariance not supported by provider")
    }
1381
    /// Stub `corrcoef`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn corrcoef<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: &'a CorrcoefOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("corrcoef not supported by provider")
    }
1390
1391 fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1393 Err(anyhow::anyhow!("linspace not supported by provider"))
1394 }
    /// Stub `elem_add`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_add<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_add not supported by provider")
    }
    /// Stub `elem_mul`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_mul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_mul not supported by provider")
    }
    /// Stub `elem_max`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_max not supported by provider")
    }
    /// Stub `elem_min`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_min not supported by provider")
    }
    /// Stub `elem_sub`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_sub<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_sub not supported by provider")
    }
    /// Stub `elem_div`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_div<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_div not supported by provider")
    }
    /// Stub `elem_pow`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_pow<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_pow not supported by provider")
    }
1444
    /// Stub `elem_hypot`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_hypot<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_hypot not supported by provider")
    }
    /// Stub `elem_ge`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_ge<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ge not supported by provider")
    }
    /// Stub `elem_le`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_le<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_le not supported by provider")
    }
    /// Stub `elem_lt`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_lt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_lt not supported by provider")
    }
    /// Stub `elem_gt`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_gt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_gt not supported by provider")
    }
    /// Stub `elem_eq`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_eq<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_eq not supported by provider")
    }
    /// Stub `elem_ne`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_ne<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_ne not supported by provider")
    }
1494 fn logical_and(
1495 &self,
1496 _a: &GpuTensorHandle,
1497 _b: &GpuTensorHandle,
1498 ) -> anyhow::Result<GpuTensorHandle> {
1499 Err(anyhow::anyhow!("logical_and not supported by provider"))
1500 }
1501 fn logical_or(
1502 &self,
1503 _a: &GpuTensorHandle,
1504 _b: &GpuTensorHandle,
1505 ) -> anyhow::Result<GpuTensorHandle> {
1506 Err(anyhow::anyhow!("logical_or not supported by provider"))
1507 }
1508 fn logical_xor(
1509 &self,
1510 _a: &GpuTensorHandle,
1511 _b: &GpuTensorHandle,
1512 ) -> anyhow::Result<GpuTensorHandle> {
1513 Err(anyhow::anyhow!("logical_xor not supported by provider"))
1514 }
1515 fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1516 Err(anyhow::anyhow!("logical_not not supported by provider"))
1517 }
1518 fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
1519 Ok(handle_is_logical(a))
1520 }
1521 fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1522 Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1523 }
1524 fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1525 Err(anyhow::anyhow!(
1526 "logical_isfinite not supported by provider"
1527 ))
1528 }
1529 fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1530 Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1531 }
1532 fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1533 Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1534 }
    /// Stub `elem_atan2`: `_y` and `_x` are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn elem_atan2<'a>(
        &'a self,
        _y: &'a GpuTensorHandle,
        _x: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_atan2 not supported by provider")
    }
    /// Stub `unary_sin`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_sin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sin not supported by provider")
    }
    /// Stub `unary_sinc`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_sinc<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sinc not supported by provider")
    }
    /// Stub `unary_gamma`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_gamma<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_gamma not supported by provider")
    }
    /// Stub `unary_factorial`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_factorial<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_factorial not supported by provider")
    }
    /// Stub `unary_asinh`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_asinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asinh not supported by provider")
    }
    /// Stub `unary_sinh`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_sinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sinh not supported by provider")
    }
    /// Stub `unary_cosh`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_cosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cosh not supported by provider")
    }
    /// Stub `unary_asin`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_asin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asin not supported by provider")
    }
    /// Stub `unary_acos`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_acos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acos not supported by provider")
    }
    /// Stub `unary_acosh`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_acosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acosh not supported by provider")
    }
    /// Stub `unary_tan`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_tan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tan not supported by provider")
    }
    /// Stub `unary_tanh`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_tanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tanh not supported by provider")
    }
    /// Stub `unary_atan`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_atan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atan not supported by provider")
    }
    /// Stub `unary_atanh`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_atanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atanh not supported by provider")
    }
    /// Stub `unary_ceil`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_ceil<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_ceil not supported by provider")
    }
    /// Stub `unary_floor`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_floor<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_floor not supported by provider")
    }
    /// Stub `unary_round`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_round<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_round not supported by provider")
    }
    /// Stub `unary_fix`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_fix<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_fix not supported by provider")
    }
    /// Stub `unary_cos`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_cos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cos not supported by provider")
    }
    /// Stub `unary_angle`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_angle<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_angle not supported by provider")
    }
    /// Stub `unary_imag`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_imag<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_imag not supported by provider")
    }
    /// Stub `unary_real`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_real<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_real not supported by provider")
    }
    /// Stub `unary_conj`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_conj<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_conj not supported by provider")
    }
    /// Stub `unary_abs`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_abs<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_abs not supported by provider")
    }
    /// Stub `unary_sign`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_sign<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sign not supported by provider")
    }
    /// Stub `unary_exp`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_exp<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_exp not supported by provider")
    }
    /// Stub `unary_expm1`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_expm1<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_expm1 not supported by provider")
    }
    /// Stub `unary_log`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_log<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log not supported by provider")
    }
    /// Stub `unary_log2`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_log2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log2 not supported by provider")
    }
    /// Stub `unary_log10`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_log10<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log10 not supported by provider")
    }
    /// Stub `unary_log1p`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_log1p<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log1p not supported by provider")
    }
    /// Stub `unary_sqrt`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_sqrt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sqrt not supported by provider")
    }
    /// Stub `unary_double`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_double<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_double not supported by provider")
    }
    /// Stub `unary_single`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_single<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_single not supported by provider")
    }
    /// Stub `unary_pow2`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_pow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_pow2 not supported by provider")
    }
    /// Stub `unary_nextpow2`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn unary_nextpow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_nextpow2 not supported by provider")
    }
1759 fn pow2_scale(
1760 &self,
1761 _mantissa: &GpuTensorHandle,
1762 _exponent: &GpuTensorHandle,
1763 ) -> anyhow::Result<GpuTensorHandle> {
1764 Err(anyhow::anyhow!("pow2_scale not supported by provider"))
1765 }
1766 fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1768 Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
1769 }
1770 fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1771 Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
1772 }
1773 fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1775 Err(anyhow::anyhow!("scalar_add not supported by provider"))
1776 }
1777 fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1778 Err(anyhow::anyhow!("scalar_sub not supported by provider"))
1779 }
1780 fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1781 Err(anyhow::anyhow!("scalar_mul not supported by provider"))
1782 }
1783 fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1784 Err(anyhow::anyhow!("scalar_max not supported by provider"))
1785 }
1786 fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1787 Err(anyhow::anyhow!("scalar_min not supported by provider"))
1788 }
1789 fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1790 Err(anyhow::anyhow!("scalar_div not supported by provider"))
1791 }
    /// Stub `sort_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn sort_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _order: SortOrder,
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_dim not supported by provider")
    }
    /// Stub `sort_rows`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn sort_rows<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _columns: &'a [SortRowsColumnSpec],
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_rows not supported by provider")
    }
    /// Stub `matmul`: operands are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn matmul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul not supported by provider")
    }
1816
1817 fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1818 Err(anyhow::anyhow!("syrk not supported by provider"))
1819 }
1820 fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
1821 Err(anyhow::anyhow!("pagefun not supported by provider"))
1822 }
1823
1824 fn matmul_epilogue<'a>(
1829 &'a self,
1830 a: &'a GpuTensorHandle,
1831 b: &'a GpuTensorHandle,
1832 epilogue: &'a MatmulEpilogue,
1833 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1834 Box::pin(async move {
1835 if epilogue.is_noop() {
1836 return self.matmul(a, b).await;
1837 }
1838 Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1839 })
1840 }
    /// Stub `image_normalize`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn image_normalize<'a>(
        &'a self,
        _input: &'a GpuTensorHandle,
        _desc: &'a ImageNormalizeDescriptor,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("image_normalize fusion not supported by provider")
    }
    /// Stub `matmul_power_step`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn matmul_power_step<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _epilogue: &'a PowerStepEpilogue,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul_power_step normalization not supported by provider")
    }
    /// Stub `linsolve`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn linsolve<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _options: &'a ProviderLinsolveOptions,
    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
        unsupported_future("linsolve not supported by provider")
    }
    /// Stub `inv`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn inv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderInvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("inv not supported by provider")
    }
    /// Stub `pinv`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn pinv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderPinvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("pinv not supported by provider")
    }
1878 fn cond<'a>(
1879 &'a self,
1880 _matrix: &'a GpuTensorHandle,
1881 _norm: ProviderCondNorm,
1882 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1883 Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
1884 }
1885 fn norm<'a>(
1886 &'a self,
1887 _tensor: &'a GpuTensorHandle,
1888 _order: ProviderNormOrder,
1889 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1890 Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
1891 }
1892 fn rank<'a>(
1893 &'a self,
1894 _matrix: &'a GpuTensorHandle,
1895 _tolerance: Option<f64>,
1896 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1897 Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
1898 }
1899 fn rcond<'a>(
1900 &'a self,
1901 _matrix: &'a GpuTensorHandle,
1902 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1903 Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
1904 }
1905 fn mldivide<'a>(
1906 &'a self,
1907 _lhs: &'a GpuTensorHandle,
1908 _rhs: &'a GpuTensorHandle,
1909 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1910 Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
1911 }
1912 fn mrdivide<'a>(
1913 &'a self,
1914 _lhs: &'a GpuTensorHandle,
1915 _rhs: &'a GpuTensorHandle,
1916 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1917 Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
1918 }
1919 fn eig<'a>(
1920 &'a self,
1921 _a: &'a GpuTensorHandle,
1922 _compute_left: bool,
1923 ) -> AccelProviderFuture<'a, ProviderEigResult> {
1924 Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
1925 }
1926 fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
1927 Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
1928 }
1929
1930 fn chol<'a>(
1931 &'a self,
1932 _a: &'a GpuTensorHandle,
1933 _lower: bool,
1934 ) -> AccelProviderFuture<'a, ProviderCholResult> {
1935 Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
1936 }
1937 fn qr<'a>(
1938 &'a self,
1939 _a: &'a GpuTensorHandle,
1940 _options: ProviderQrOptions,
1941 ) -> AccelProviderFuture<'a, ProviderQrResult> {
1942 Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
1943 }
    /// Stub `take_matmul_sources`: ignores `_product` and always returns `None`.
    fn take_matmul_sources(
        &self,
        _product: &GpuTensorHandle,
    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
        None
    }
1950 fn qr_power_iter<'a>(
1951 &'a self,
1952 product: &'a GpuTensorHandle,
1953 _product_lhs: Option<&'a GpuTensorHandle>,
1954 q_handle: &'a GpuTensorHandle,
1955 options: &'a ProviderQrOptions,
1956 ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
1957 let _ = (product, q_handle, options);
1958 Box::pin(async move { Ok(None) })
1959 }
1960 fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1961 Err(anyhow::anyhow!("transpose not supported by provider"))
1962 }
1963 fn conv1d(
1964 &self,
1965 _signal: &GpuTensorHandle,
1966 _kernel: &GpuTensorHandle,
1967 _options: ProviderConv1dOptions,
1968 ) -> anyhow::Result<GpuTensorHandle> {
1969 Err(anyhow::anyhow!("conv1d not supported by provider"))
1970 }
1971 fn conv2d(
1972 &self,
1973 _signal: &GpuTensorHandle,
1974 _kernel: &GpuTensorHandle,
1975 _mode: ProviderConvMode,
1976 ) -> anyhow::Result<GpuTensorHandle> {
1977 Err(anyhow::anyhow!("conv2d not supported by provider"))
1978 }
1979 fn iir_filter<'a>(
1980 &'a self,
1981 _b: &'a GpuTensorHandle,
1982 _a: &'a GpuTensorHandle,
1983 _x: &'a GpuTensorHandle,
1984 _options: ProviderIirFilterOptions,
1985 ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
1986 Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
1987 }
1988 fn permute(
1990 &self,
1991 _handle: &GpuTensorHandle,
1992 _order: &[usize],
1993 ) -> anyhow::Result<GpuTensorHandle> {
1994 Err(anyhow::anyhow!("permute not supported by provider"))
1995 }
1996 fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1997 Err(anyhow::anyhow!("flip not supported by provider"))
1998 }
1999 fn circshift(
2000 &self,
2001 _handle: &GpuTensorHandle,
2002 _shifts: &[isize],
2003 ) -> anyhow::Result<GpuTensorHandle> {
2004 Err(anyhow::anyhow!("circshift not supported by provider"))
2005 }
2006 fn diff_dim(
2007 &self,
2008 _handle: &GpuTensorHandle,
2009 _order: usize,
2010 _dim: usize,
2011 ) -> anyhow::Result<GpuTensorHandle> {
2012 Err(anyhow::anyhow!("diff_dim not supported by provider"))
2013 }
2014 fn gradient_dim(
2015 &self,
2016 _handle: &GpuTensorHandle,
2017 _dim: usize,
2018 _spacing: f64,
2019 ) -> anyhow::Result<GpuTensorHandle> {
2020 Err(anyhow::anyhow!("gradient_dim not supported by provider"))
2021 }
    /// Stub `fft_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn fft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_dim not supported by provider")
    }
    /// Stub `ifft_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn ifft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("ifft_dim not supported by provider")
    }
    /// Stub `fft_extract_real`: the handle is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn fft_extract_real<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_extract_real not supported by provider")
    }
2045 fn unique<'a>(
2046 &'a self,
2047 _handle: &'a GpuTensorHandle,
2048 _options: &'a UniqueOptions,
2049 ) -> AccelProviderFuture<'a, UniqueResult> {
2050 Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
2051 }
2052 fn union<'a>(
2053 &'a self,
2054 _a: &'a GpuTensorHandle,
2055 _b: &'a GpuTensorHandle,
2056 _options: &'a UnionOptions,
2057 ) -> AccelProviderFuture<'a, UnionResult> {
2058 Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
2059 }
2060 fn setdiff<'a>(
2061 &'a self,
2062 _a: &'a GpuTensorHandle,
2063 _b: &'a GpuTensorHandle,
2064 _options: &'a SetdiffOptions,
2065 ) -> AccelProviderFuture<'a, SetdiffResult> {
2066 Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
2067 }
2068 fn ismember<'a>(
2069 &'a self,
2070 _a: &'a GpuTensorHandle,
2071 _b: &'a GpuTensorHandle,
2072 _options: &'a IsMemberOptions,
2073 ) -> AccelProviderFuture<'a, IsMemberResult> {
2074 Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
2075 }
2076 fn reshape(
2077 &self,
2078 handle: &GpuTensorHandle,
2079 new_shape: &[usize],
2080 ) -> anyhow::Result<GpuTensorHandle> {
2081 let mut updated = handle.clone();
2082 updated.shape = new_shape.to_vec();
2083 Ok(updated)
2084 }
2085 fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
2087 Err(anyhow::anyhow!("cat not supported by provider"))
2088 }
2089 fn repmat(
2090 &self,
2091 _handle: &GpuTensorHandle,
2092 _reps: &[usize],
2093 ) -> anyhow::Result<GpuTensorHandle> {
2094 Err(anyhow::anyhow!("repmat not supported by provider"))
2095 }
2096 fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2098 Err(anyhow::anyhow!("kron not supported by provider"))
2099 }
2100 fn cross(
2102 &self,
2103 _lhs: &GpuTensorHandle,
2104 _rhs: &GpuTensorHandle,
2105 _dim: Option<usize>,
2106 ) -> anyhow::Result<GpuTensorHandle> {
2107 Err(anyhow::anyhow!("cross not supported by provider"))
2108 }
    /// Stub `reduce_sum`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_sum<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum not supported by provider")
    }
    /// Stub `reduce_sum_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_sum_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum_dim not supported by provider")
    }
    /// Stub `dot`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn dot<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _dim: Option<usize>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("dot not supported by provider")
    }
    /// Stub `reduce_nnz`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_nnz<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz not supported by provider")
    }
    /// Stub `reduce_nnz_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_nnz_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz_dim not supported by provider")
    }
    /// Stub `reduce_prod`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_prod<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod not supported by provider")
    }
    /// Stub `reduce_prod_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_prod_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod_dim not supported by provider")
    }
    /// Stub `reduce_mean`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_mean<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean not supported by provider")
    }
    /// Stub `reduce_mean_nd`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_mean_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_nd not supported by provider")
    }
    /// Stub `reduce_moments_nd`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_moments_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, ProviderMoments2> {
        unsupported_future("reduce_moments_nd not supported by provider")
    }
    /// Stub `reduce_mean_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_mean_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_dim not supported by provider")
    }
    /// Stub `reduce_std`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_std<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std not supported by provider")
    }
    /// Stub `reduce_std_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_std_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std_dim not supported by provider")
    }
    /// Stub `reduce_any`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_any<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any not supported by provider")
    }
    /// Stub `reduce_any_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_any_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any_dim not supported by provider")
    }
    /// Stub `reduce_all`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_all<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all not supported by provider")
    }
    /// Stub `reduce_all_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_all_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all_dim not supported by provider")
    }
    /// Stub `reduce_median`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_median<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median not supported by provider")
    }
    /// Stub `reduce_median_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_median_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median_dim not supported by provider")
    }
    /// Stub `reduce_min`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_min not supported by provider")
    }
    /// Stub `reduce_min_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_min_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_min_dim not supported by provider")
    }
    /// Stub `reduce_max`: the operand is ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_max not supported by provider")
    }
    /// Stub `reduce_max_dim`: arguments are ignored; the message below is handed to
    /// `unsupported_future` (presumably yielding a failing future — TODO confirm).
    fn reduce_max_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_max_dim not supported by provider")
    }
2272 fn cumsum_scan(
2273 &self,
2274 _input: &GpuTensorHandle,
2275 _dim: usize,
2276 _direction: ProviderScanDirection,
2277 _nan_mode: ProviderNanMode,
2278 ) -> anyhow::Result<GpuTensorHandle> {
2279 Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
2280 }
2281 fn cumprod_scan(
2282 &self,
2283 _input: &GpuTensorHandle,
2284 _dim: usize,
2285 _direction: ProviderScanDirection,
2286 _nan_mode: ProviderNanMode,
2287 ) -> anyhow::Result<GpuTensorHandle> {
2288 Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
2289 }
2290 fn cummin_scan(
2291 &self,
2292 _input: &GpuTensorHandle,
2293 _dim: usize,
2294 _direction: ProviderScanDirection,
2295 _nan_mode: ProviderNanMode,
2296 ) -> anyhow::Result<ProviderCumminResult> {
2297 Err(anyhow::anyhow!("cummin_scan not supported by provider"))
2298 }
2299 fn cummax_scan(
2300 &self,
2301 _input: &GpuTensorHandle,
2302 _dim: usize,
2303 _direction: ProviderScanDirection,
2304 _nan_mode: ProviderNanMode,
2305 ) -> anyhow::Result<ProviderCummaxResult> {
2306 Err(anyhow::anyhow!("cummax_scan not supported by provider"))
2307 }
2308
2309 fn find(
2310 &self,
2311 _a: &GpuTensorHandle,
2312 _limit: Option<usize>,
2313 _direction: FindDirection,
2314 ) -> anyhow::Result<ProviderFindResult> {
2315 Err(anyhow::anyhow!("find not supported by provider"))
2316 }
2317
2318 fn fused_elementwise(
2319 &self,
2320 _shader: &str,
2321 _inputs: &[GpuTensorHandle],
2322 _output_shape: &[usize],
2323 _len: usize,
2324 ) -> anyhow::Result<GpuTensorHandle> {
2325 Err(anyhow::anyhow!(
2326 "fused_elementwise not supported by provider"
2327 ))
2328 }
2329
2330 fn fused_elementwise_multi(
2339 &self,
2340 _shader: &str,
2341 _inputs: &[GpuTensorHandle],
2342 _output_shape: &[usize],
2343 _len: usize,
2344 _num_outputs: usize,
2345 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2346 Err(anyhow::anyhow!(
2347 "fused_elementwise_multi not supported by provider"
2348 ))
2349 }
2350
2351 fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2353 Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
2354 }
2355
2356 fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2358 Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
2359 }
2360
2361 #[allow(clippy::too_many_arguments)]
2368 fn fused_reduction(
2369 &self,
2370 _shader: &str,
2371 _inputs: &[GpuTensorHandle],
2372 _output_shape: &[usize],
2373 _reduce_len: usize,
2374 _num_slices: usize,
2375 _workgroup_size: u32,
2376 _flavor: ReductionFlavor,
2377 ) -> anyhow::Result<GpuTensorHandle> {
2378 Err(anyhow::anyhow!("fused_reduction not supported by provider"))
2379 }
2380
    /// Provider warm-up hook; the default implementation does nothing.
    fn warmup(&self) {}
2383
    /// Returns the fused-cache counters as a `(hits, misses)` pair, the order
    /// in which `telemetry_snapshot` destructures them.
    ///
    /// The default implementation reports `(0, 0)`.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }
2388
    /// Provider hook for `last_warmup_millis`; the default implementation
    /// returns `None`.
    ///
    /// NOTE(review): presumably the duration of the most recent `warmup` call
    /// in milliseconds; confirm against a concrete provider.
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }
2393
2394 fn telemetry_snapshot(&self) -> ProviderTelemetry {
2396 let (hits, misses) = self.fused_cache_counters();
2397 ProviderTelemetry {
2398 fused_elementwise: ProviderDispatchStats::default(),
2399 fused_reduction: ProviderDispatchStats::default(),
2400 matmul: ProviderDispatchStats::default(),
2401 linsolve: ProviderDispatchStats::default(),
2402 mldivide: ProviderDispatchStats::default(),
2403 mrdivide: ProviderDispatchStats::default(),
2404 upload_bytes: 0,
2405 download_bytes: 0,
2406 solve_fallbacks: Vec::new(),
2407 fusion_cache_hits: hits,
2408 fusion_cache_misses: misses,
2409 bind_group_cache_hits: 0,
2410 bind_group_cache_misses: 0,
2411 bind_group_cache_by_layout: None,
2412 kernel_launches: Vec::new(),
2413 }
2414 }
2415
    /// Telemetry reset hook; the default implementation does nothing.
    fn reset_telemetry(&self) {}
2418
2419 fn default_reduction_workgroup_size(&self) -> u32 {
2421 256
2422 }
2423
2424 fn two_pass_threshold(&self) -> usize {
2426 1024
2427 }
2428
    /// Returns the provider's two-pass reduction mode.
    ///
    /// The default implementation returns `ReductionTwoPassMode::Auto`.
    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
2433
2434 fn scatter_column(
2437 &self,
2438 _matrix: &GpuTensorHandle,
2439 _col_index: usize,
2440 _values: &GpuTensorHandle,
2441 ) -> anyhow::Result<GpuTensorHandle> {
2442 Err(anyhow::anyhow!("scatter_column not supported by provider"))
2443 }
2444
2445 fn scatter_row(
2448 &self,
2449 _matrix: &GpuTensorHandle,
2450 _row_index: usize,
2451 _values: &GpuTensorHandle,
2452 ) -> anyhow::Result<GpuTensorHandle> {
2453 Err(anyhow::anyhow!("scatter_row not supported by provider"))
2454 }
2455
2456 fn sub2ind(
2457 &self,
2458 _dims: &[usize],
2459 _strides: &[usize],
2460 _inputs: &[&GpuTensorHandle],
2461 _scalar_mask: &[bool],
2462 _len: usize,
2463 _output_shape: &[usize],
2464 ) -> anyhow::Result<GpuTensorHandle> {
2465 Err(anyhow::anyhow!("sub2ind not supported by provider"))
2466 }
2467
    /// Capability flag for `ind2sub`; the default implementation returns `false`.
    fn supports_ind2sub(&self) -> bool {
        false
    }
2472
2473 fn ind2sub(
2475 &self,
2476 _dims: &[usize],
2477 _strides: &[usize],
2478 _indices: &GpuTensorHandle,
2479 _total: usize,
2480 _len: usize,
2481 _output_shape: &[usize],
2482 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2483 Err(anyhow::anyhow!("ind2sub not supported by provider"))
2484 }
2485
2486 fn issymmetric(
2488 &self,
2489 _matrix: &GpuTensorHandle,
2490 _kind: ProviderSymmetryKind,
2491 _tolerance: f64,
2492 ) -> anyhow::Result<bool> {
2493 Err(anyhow::anyhow!(
2494 "issymmetric predicate not supported by provider"
2495 ))
2496 }
2497
    /// Provider hook for the `ishermitian` predicate.
    ///
    /// The default implementation ignores `_matrix`, `_kind` and `_tolerance`
    /// and returns a boxed future that resolves to the error
    /// "ishermitian predicate not supported by provider".
    fn ishermitian<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _kind: ProviderHermitianKind,
        _tolerance: f64,
    ) -> AccelProviderFuture<'a, bool> {
        Box::pin(async move {
            Err(anyhow::anyhow!(
                "ishermitian predicate not supported by provider"
            ))
        })
    }
2511
2512 fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2514 Err(anyhow::anyhow!("bandwidth not supported by provider"))
2515 }
2516
    /// Provider hook for `sym_rcm`.
    ///
    /// The default implementation ignores `_matrix` and returns a boxed future
    /// that resolves to the error "sym_rcm not supported by provider".
    fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
        Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
    }
2524}
2525
/// Process-wide provider slot: written by `register_provider`, emptied by
/// `clear_provider`, and read as the fallback in `provider` and
/// `provider_for_device`.
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
/// Providers keyed by the device id they were registered under.
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Source of the ids handed out by `next_device_id`; the first id issued is 1.
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2531
// Per-thread provider override for non-wasm targets. `provider` consults it
// before falling back to `GLOBAL_PROVIDER`; it starts out empty on every thread.
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}
2536
// wasm32 counterpart of the thread-local override: a single mutex-protected
// slot shared by the whole process instead of one slot per thread.
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| Mutex::new(None));
2540
2541#[cfg(not(target_arch = "wasm32"))]
2542fn replace_thread_provider(
2543 provider: Option<&'static dyn AccelProvider>,
2544) -> Option<&'static dyn AccelProvider> {
2545 THREAD_PROVIDER.with(|cell| {
2546 let prev = cell.get();
2547 cell.set(provider);
2548 prev
2549 })
2550}
2551
/// Installs `provider` in the shared wasm32 override slot and returns the
/// override that was in place before the call.
///
/// # Panics
///
/// Panics with "wasm provider mutex poisoned" if the slot's mutex is poisoned.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut guard = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *guard, provider)
}
2563
2564#[cfg(not(target_arch = "wasm32"))]
2565fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2566 THREAD_PROVIDER.with(|cell| cell.get())
2567}
2568
/// Returns the provider stored in the shared wasm32 override slot, if any.
///
/// # Panics
///
/// Panics with "wasm provider mutex poisoned" if the slot's mutex is poisoned.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    let guard = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    // The slot holds a `Copy` option of a `'static` reference, so copy it out.
    *guard
}
2577
2578pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
2586 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2587 *guard = Some(p);
2588 }
2589 register_provider_for_device(p.device_id(), p);
2590}
2591
2592unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
2593 if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
2594 guard.insert(device_id, provider);
2595 }
2596}
2597
2598pub fn provider() -> Option<&'static dyn AccelProvider> {
2599 if let Some(p) = current_thread_provider() {
2600 return Some(p);
2601 }
2602 GLOBAL_PROVIDER
2603 .read()
2604 .ok()
2605 .and_then(|guard| guard.as_ref().copied())
2606}
2607
2608pub fn clear_provider() {
2610 replace_thread_provider(None);
2611 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2612 *guard = None;
2613 }
2614 if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2615 map.clear();
2616 }
2617}
2618
2619pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2620 if let Some(registered) = PROVIDER_REGISTRY
2621 .read()
2622 .ok()
2623 .and_then(|guard| guard.get(&device_id).copied())
2624 {
2625 return Some(registered);
2626 }
2627 if let Some(thread_provider) = current_thread_provider() {
2628 if thread_provider.device_id() == device_id {
2629 return Some(thread_provider);
2630 }
2631 }
2632 GLOBAL_PROVIDER
2635 .read()
2636 .ok()
2637 .and_then(|guard| guard.as_ref().copied())
2638}
2639
2640pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
2641 provider_for_device(handle.device_id)
2642}
2643
2644pub fn spawn_handle_concurrency_for(handle: &GpuTensorHandle) -> Option<SpawnHandleConcurrency> {
2645 provider_for_handle(handle).map(AccelProvider::spawn_handle_concurrency)
2646}
2647
/// Hands out the next device id from the process-wide counter.
///
/// Returns the counter's value before the increment, so the first call yields
/// 1. The counter is a `u32` updated with `fetch_add`, which wraps around on
/// overflow.
pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
2651
2652pub struct ThreadProviderGuard {
2653 prev: Option<&'static dyn AccelProvider>,
2654}
2655
2656impl ThreadProviderGuard {
2657 pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2658 let prev = replace_thread_provider(provider);
2659 ThreadProviderGuard { prev }
2660 }
2661}
2662
2663impl Drop for ThreadProviderGuard {
2664 fn drop(&mut self) {
2665 let prev = self.prev.take();
2666 replace_thread_provider(prev);
2667 }
2668}
2669
2670pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
2671 replace_thread_provider(provider);
2672}
2673
2674pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2676 if let Some(p) = provider() {
2677 if let Ok(h) = p.elem_add(a, b).await {
2678 return Some(h);
2679 }
2680 }
2681 None
2682}
2683
2684pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2686 if let Some(p) = provider() {
2687 if let Ok(h) = p.elem_hypot(a, b).await {
2688 return Some(h);
2689 }
2690 }
2691 None
2692}
2693
2694pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2696 if let Some(p) = provider() {
2697 if let Ok(h) = p.elem_max(a, b).await {
2698 return Some(h);
2699 }
2700 }
2701 None
2702}
2703
2704pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2706 if let Some(p) = provider() {
2707 if let Ok(h) = p.elem_min(a, b).await {
2708 return Some(h);
2709 }
2710 }
2711 None
2712}
2713
2714pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2716 if let Some(p) = provider() {
2717 if let Ok(h) = p.elem_atan2(y, x).await {
2718 return Some(h);
2719 }
2720 }
2721 None
2722}
2723
/// Host-side tensor that owns its element data and shape.
///
/// NOTE(review): the element ordering of `data` relative to `shape` is not
/// visible in this part of the file; confirm before relying on a layout.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    /// Element values as `f64`.
    pub data: Vec<f64>,
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Storage descriptor carried alongside the data.
    pub storage: GpuTensorStorage,
}
2731
/// Borrowed host-side tensor: element data and shape slices that live for `'a`.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    /// Element values as `f64`.
    pub data: &'a [f64],
    /// Extent of each dimension.
    pub shape: &'a [usize],
}
2737
/// Borrowed view of the `f64` values of a single axis.
///
/// NOTE(review): presumably one input axis of a meshgrid operation; confirm
/// against the provider method that consumes it.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    /// Values along the axis.
    pub data: &'a [f64],
}
2743
/// Result of a provider meshgrid call: the handles of the produced tensors.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    /// Output tensor handles.
    pub outputs: Vec<GpuTensorHandle>,
}
2749
/// Operation selector used by `MatmulEpilogue::row_op` and
/// `MatmulEpilogue::col_op`.
///
/// NOTE(review): presumably chooses whether the matching `row_scale` /
/// `col_scale` tensor multiplies or divides the result; confirm in a provider
/// implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    /// Scale by multiplication; the value used by `MatmulEpilogue::noop`.
    Multiply,
    /// Scale by division.
    Divide,
}
2760
/// Describes the optional post-processing attached to a matmul.
///
/// `MatmulEpilogue::noop` builds the descriptor with nothing enabled and
/// `MatmulEpilogue::is_noop` tests for it. Fields marked `#[serde(default)]`
/// deserialize to `None` when absent from the input.
///
/// NOTE(review): the exact arithmetic each field triggers is defined by the
/// provider implementations, not by this file section.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    /// Scalar factor; `1.0` in the no-op descriptor.
    pub alpha: f64,
    /// Scalar term; `0.0` in the no-op descriptor.
    pub beta: f64,
    /// Optional per-row scale tensor.
    pub row_scale: Option<GpuTensorHandle>,
    /// Optional per-column scale tensor.
    pub col_scale: Option<GpuTensorHandle>,
    /// Operation paired with `row_scale`.
    pub row_op: ScaleOp,
    /// Operation paired with `col_scale`.
    pub col_op: ScaleOp,
    /// Optional lower clamp bound.
    #[serde(default)]
    pub clamp_min: Option<f64>,
    /// Optional upper clamp bound.
    #[serde(default)]
    pub clamp_max: Option<f64>,
    /// Optional power exponent.
    #[serde(default)]
    pub pow_exponent: Option<f64>,
    /// Optional handle for a diagonal output.
    #[serde(default)]
    pub diag_output: Option<GpuTensorHandle>,
}
2788
2789impl MatmulEpilogue {
2790 pub fn noop() -> Self {
2791 Self {
2792 alpha: 1.0,
2793 beta: 0.0,
2794 row_scale: None,
2795 col_scale: None,
2796 row_op: ScaleOp::Multiply,
2797 col_op: ScaleOp::Multiply,
2798 clamp_min: None,
2799 clamp_max: None,
2800 pow_exponent: None,
2801 diag_output: None,
2802 }
2803 }
2804 pub fn is_noop(&self) -> bool {
2805 self.alpha == 1.0
2806 && self.beta == 0.0
2807 && self.row_scale.is_none()
2808 && self.col_scale.is_none()
2809 && self.clamp_min.is_none()
2810 && self.clamp_max.is_none()
2811 && self.pow_exponent.is_none()
2812 && self.diag_output.is_none()
2813 }
2814}
2815
/// Parameters for a power-step epilogue.
///
/// NOTE(review): how `epsilon` enters the computation is defined by the
/// provider implementations; confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct PowerStepEpilogue {
    /// Epsilon value; `0.0` in the `Default` instance.
    pub epsilon: f64,
}
2820
2821impl Default for PowerStepEpilogue {
2822 fn default() -> Self {
2823 Self { epsilon: 0.0 }
2824 }
2825}
2826
/// Parameters describing an image-normalize operation.
///
/// The optional fields are marked `#[serde(default)]`, so they deserialize to
/// `None` when absent from the input.
///
/// NOTE(review): the normalization formula itself lives in the provider
/// implementations; the field notes below only restate names and types.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    /// Batch extent.
    pub batch: usize,
    /// Image height extent.
    pub height: usize,
    /// Image width extent.
    pub width: usize,
    /// Epsilon value.
    pub epsilon: f64,
    /// Optional gain.
    #[serde(default)]
    pub gain: Option<f64>,
    /// Optional bias.
    #[serde(default)]
    pub bias: Option<f64>,
    /// Optional gamma.
    #[serde(default)]
    pub gamma: Option<f64>,
}
2840
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider for the registry tests. Only identity-style methods
    /// return real values; `upload`, `download` and `free` report an error
    /// because the tests never expect them to run.
    struct TestProvider {
        device_id: u32,
        name: &'static str,
        spawn_concurrency: SpawnHandleConcurrency,
    }

    impl AccelProvider for TestProvider {
        fn upload(&self, _host: &HostTensorView) -> anyhow::Result<GpuTensorHandle> {
            Err(anyhow!("test provider upload should not be called"))
        }

        fn download<'a>(&'a self, _h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a> {
            unsupported_future("test provider download should not be called")
        }

        fn free(&self, _h: &GpuTensorHandle) -> anyhow::Result<()> {
            Err(anyhow!("test provider free should not be called"))
        }

        fn device_info(&self) -> String {
            self.name.to_string()
        }

        fn device_id(&self) -> u32 {
            self.device_id
        }

        fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
            self.spawn_concurrency
        }
    }

    // Serializes the tests below: they all mutate the process-wide provider
    // registry and must not interleave.
    static PROVIDER_TEST_LOCK: Lazy<std::sync::Mutex<()>> = Lazy::new(|| std::sync::Mutex::new(()));
    static PROVIDER_A: TestProvider = TestProvider {
        device_id: 101,
        name: "provider-a",
        spawn_concurrency: SpawnHandleConcurrency::ImmutableShare,
    };
    static PROVIDER_B: TestProvider = TestProvider {
        device_id: 202,
        name: "provider-b",
        spawn_concurrency: SpawnHandleConcurrency::Reject,
    };
    static PROVIDER_C: TestProvider = TestProvider {
        device_id: 303,
        name: "provider-c",
        spawn_concurrency: SpawnHandleConcurrency::CopyOnWrite,
    };

    /// Acquires the test serialization lock.
    ///
    /// A poisoned lock is recovered instead of panicking: the mutex protects no
    /// data and every test resets the registry before using it, so one failing
    /// test must not make all the others fail with an unrelated panic.
    fn lock_provider_state() -> std::sync::MutexGuard<'static, ()> {
        PROVIDER_TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Resets the registry, then registers A followed by B, which leaves B as
    /// the process-wide provider.
    fn register_test_providers() {
        clear_provider();
        unsafe {
            register_provider(&PROVIDER_A);
            register_provider(&PROVIDER_B);
        }
    }

    /// Builds a one-element handle owned by `device_id`.
    fn test_handle(device_id: u32) -> GpuTensorHandle {
        GpuTensorHandle {
            shape: vec![1],
            device_id,
            buffer_id: 42,
        }
    }

    #[test]
    fn provider_for_device_prefers_registered_device_over_thread_provider() {
        let _lock = lock_provider_state();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider = provider_for_device(PROVIDER_A.device_id()).expect("provider for device");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn provider_for_handle_uses_handle_device_owner() {
        let _lock = lock_provider_state();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider =
            provider_for_handle(&test_handle(PROVIDER_A.device_id())).expect("provider for handle");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn spawn_handle_concurrency_for_uses_registered_owner() {
        let _lock = lock_provider_state();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let concurrency = spawn_handle_concurrency_for(&test_handle(PROVIDER_A.device_id()))
            .expect("spawn concurrency");

        assert_eq!(concurrency, PROVIDER_A.spawn_concurrency);
        clear_provider();
    }

    #[test]
    fn provider_keeps_thread_local_active_provider_semantics() {
        let _lock = lock_provider_state();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_A));

        let active = provider().expect("active provider");

        assert_eq!(active.device_info(), PROVIDER_A.name);
        clear_provider();
    }

    #[test]
    fn unregistered_thread_provider_only_matches_own_device_before_global_fallback() {
        let _lock = lock_provider_state();
        clear_provider();
        unsafe {
            register_provider(&PROVIDER_A);
        }
        // PROVIDER_C is deliberately not registered: it is only reachable as
        // the thread-level override.
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_C));

        let own_device = provider_for_device(PROVIDER_C.device_id()).expect("own provider");
        let fallback = provider_for_device(404).expect("global fallback provider");

        assert_eq!(own_device.device_info(), PROVIDER_C.name);
        assert_eq!(fallback.device_info(), PROVIDER_A.name);
        clear_provider();
    }
}