1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
/// Hook invoked by [`mark_residency`]; installed via [`register_residency_mark`].
type ResidencyMarkFn = fn(&GpuTensorHandle);
/// Hook invoked by [`clear_residency`]; installed via [`register_residency_clear`].
type ResidencyClearFn = fn(&GpuTensorHandle);
/// Source of the value returned by [`sequence_threshold_hint`].
type SequenceThresholdFn = fn() -> Option<usize>;
/// Source of the value returned by [`workgroup_size_hint`].
type WorkgroupSizeHintFn = fn() -> Option<u32>;

// Write-once hook slots. The `register_*` functions discard the error returned by
// `OnceCell::set`, so the first registration stays in effect and later ones are ignored.
static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();

// Process-wide side tables keyed by `GpuTensorHandle::buffer_id`. They hold per-handle
// metadata that is not stored on the handle itself.

/// Buffer ids currently marked as logical (see [`set_handle_logical`]).
static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
/// Per-buffer count of how many times the id was marked logical; the entry is removed when
/// the mark is cleared.
static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Transpose metadata recorded via [`record_handle_transpose`].
static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

/// Per-buffer precision recorded via [`set_handle_precision`].
static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Per-buffer storage layout recorded via [`set_handle_storage`]; a missing entry is read
/// back as `GpuTensorStorage::Real` by [`handle_storage`].
static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
36
/// Row/column counts recorded for a handle by [`record_handle_transpose`].
///
/// NOTE(review): presumably the dimensions of the underlying (pre-transpose) buffer;
/// confirm against the provider code that records it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// Value passed as `base_rows` to `record_handle_transpose`.
    pub base_rows: usize,
    /// Value passed as `base_cols` to `record_handle_transpose`.
    pub base_cols: usize,
}
42
43pub fn register_residency_mark(handler: ResidencyMarkFn) {
46 let _ = RESIDENCY_MARK.set(handler);
47}
48
49pub fn mark_residency(handle: &GpuTensorHandle) {
52 if let Some(handler) = RESIDENCY_MARK.get() {
53 handler(handle);
54 }
55}
56
57pub fn register_residency_clear(handler: ResidencyClearFn) {
61 let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64pub fn clear_residency(handle: &GpuTensorHandle) {
67 if let Some(handler) = RESIDENCY_CLEAR.get() {
68 handler(handle);
69 }
70}
71
72pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76 let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79pub fn sequence_threshold_hint() -> Option<usize> {
81 SEQUENCE_THRESHOLD_PROVIDER
82 .get()
83 .and_then(|provider| provider())
84}
85
86pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90 let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93pub fn workgroup_size_hint() -> Option<u32> {
95 WORKGROUP_SIZE_HINT_PROVIDER
96 .get()
97 .and_then(|provider| provider())
98}
99
100pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103 provider().and_then(|p| p.export_context(kind))
104}
105
106#[cfg(feature = "wgpu")]
110pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
111 provider().and_then(|p| p.export_wgpu_buffer(handle))
112}
113
114pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118 guard.insert(handle.buffer_id, precision);
119 }
120}
121
122pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124 HANDLE_PRECISIONS
125 .read()
126 .ok()
127 .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133 guard.remove(&handle.buffer_id);
134 }
135}
136
137pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
140 if let Ok(mut guard) = LOGICAL_HANDLES.write() {
141 if logical {
142 guard.insert(handle.buffer_id);
143 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
144 *hits.entry(handle.buffer_id).or_insert(0) += 1;
145 }
146 } else {
147 guard.remove(&handle.buffer_id);
148 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
149 hits.remove(&handle.buffer_id);
150 }
151 }
152 }
153}
154
155pub fn clear_handle_logical(handle: &GpuTensorHandle) {
157 set_handle_logical(handle, false);
158}
159
160pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162 LOGICAL_HANDLES
163 .read()
164 .map(|guard| guard.contains(&handle.buffer_id))
165 .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169 LOGICAL_HANDLE_HITS
170 .read()
171 .ok()
172 .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177 guard.insert(
178 handle.buffer_id,
179 TransposeInfo {
180 base_rows,
181 base_cols,
182 },
183 );
184 }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189 guard.remove(&handle.buffer_id);
190 }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194 TRANSPOSED_HANDLES
195 .read()
196 .ok()
197 .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
200pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
201 handle_transpose_info(handle).is_some()
202}
203
/// Element layout of a device buffer, tracked per handle via [`set_handle_storage`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GpuTensorStorage {
    /// Real-valued elements.
    Real,
    /// Complex elements stored interleaved — NOTE(review): presumably (re, im) pairs; confirm.
    ComplexInterleaved,
}
209
210impl Default for GpuTensorStorage {
211 fn default() -> Self {
212 Self::Real
213 }
214}
215
/// Lightweight reference to a tensor living in a provider's device memory.
///
/// `buffer_id` is the key used by every per-handle side table in this module
/// (precision, storage, logical, transpose).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuTensorHandle {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Device that owns the buffer.
    pub device_id: u32,
    /// Provider-assigned buffer identifier.
    pub buffer_id: u64,
}

/// Structured device description returned by `AccelProvider::device_info_struct`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    /// Device identifier.
    pub device_id: u32,
    /// Human-readable device name.
    pub name: String,
    /// Vendor name; empty when unknown.
    pub vendor: String,
    /// Device memory size in bytes, when known.
    pub memory_bytes: Option<u64>,
    /// Backend name, when known.
    pub backend: Option<String>,
}

/// Values plus their positions from a dimension-wise reduction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    /// Reduced values.
    pub values: GpuTensorHandle,
    /// Indices associated with `values`.
    pub indices: GpuTensorHandle,
}

/// Running values plus their indices from a cumulative min scan.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    /// Cumulative values.
    pub values: GpuTensorHandle,
    /// Indices associated with `values`.
    pub indices: GpuTensorHandle,
}

/// Cumulative max scans share the cummin result layout.
pub type ProviderCummaxResult = ProviderCumminResult;
249
/// Purpose for which a backend context is requested via [`export_context`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelContextKind {
    /// Context intended for plotting / rendering consumers.
    Plotting,
}

/// Backend context exported by a provider.
///
/// Without the `wgpu` feature this enum has no variants (it cannot be constructed).
#[derive(Clone)]
pub enum AccelContextHandle {
    /// Shared wgpu device/queue context.
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}

impl AccelContextHandle {
    /// Returns the wgpu context. `Wgpu` is currently the only variant, so this always
    /// yields `Some` when the feature is enabled.
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        match self {
            AccelContextHandle::Wgpu(ctx) => Some(ctx),
        }
    }
}
272
/// Reference-counted wgpu objects (plus their metadata) shared with external consumers.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    /// Adapter description captured alongside the adapter.
    pub adapter_info: wgpu::AdapterInfo,
    /// Device limits.
    pub limits: wgpu::Limits,
    /// Enabled device features.
    pub features: wgpu::Features,
}

/// Raw wgpu buffer backing a tensor handle, as returned by `export_wgpu_buffer`.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    pub buffer: Arc<wgpu::Buffer>,
    /// NOTE(review): presumably the element count; confirm against the exporting provider.
    pub len: usize,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// NOTE(review): presumably bytes per element; confirm.
    pub element_size: usize,
    /// Element precision of the buffer contents.
    pub precision: ProviderPrecision,
}
296
297pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
298 if let Ok(mut guard) = HANDLE_STORAGES.write() {
299 guard.insert(handle.buffer_id, storage);
300 }
301}
302
303pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
304 HANDLE_STORAGES
305 .read()
306 .ok()
307 .and_then(|guard| guard.get(&handle.buffer_id).cloned())
308 .unwrap_or(GpuTensorStorage::Real)
309}
310
311pub fn clear_handle_storage(handle: &GpuTensorHandle) {
312 if let Ok(mut guard) = HANDLE_STORAGES.write() {
313 guard.remove(&handle.buffer_id);
314 }
315}
316
/// Operation applied page by page in a [`PagefunRequest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    /// Matrix multiplication (presumably `pagefun(@mtimes, ...)` — confirm).
    Mtimes,
}

/// Description of a page-wise operation over device-resident inputs.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    /// Operation to apply to each page.
    pub op: PagefunOp,
    /// Input tensors.
    pub inputs: Vec<GpuTensorHandle>,
    /// Shape of the output tensor.
    pub output_shape: Vec<usize>,
    /// NOTE(review): presumably the output's page (batch) dimensions; confirm.
    pub page_dims: Vec<usize>,
    /// NOTE(review): presumably page dimensions per entry of `inputs`, in the same order.
    pub input_page_dims: Vec<Vec<usize>>,
}

/// Which end of the data a `find`-style search starts from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}

/// Output of a provider-side `find`: linear indices, row/column subscripts, and
/// optionally the located values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    pub linear: GpuTensorHandle,
    pub rows: GpuTensorHandle,
    pub cols: GpuTensorHandle,
    /// Present only when the caller asked for the values.
    pub values: Option<GpuTensorHandle>,
}
344
/// Lower/upper bandwidth of a matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    pub lower: u32,
    pub upper: u32,
}

/// Kind of symmetry to test for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}

/// Kind of Hermitian symmetry to test for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}

/// Outputs of an LU factorization, covering the various output forms callers may request.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    /// NOTE(review): presumably L and U packed into one matrix; confirm.
    pub combined: GpuTensorHandle,
    pub lower: GpuTensorHandle,
    pub upper: GpuTensorHandle,
    /// Row permutation as a matrix.
    pub perm_matrix: GpuTensorHandle,
    /// Row permutation as an index vector.
    pub perm_vector: GpuTensorHandle,
}

/// Output of a Cholesky factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    pub factor: GpuTensorHandle,
    /// Status code — NOTE(review): presumably 0 on success, otherwise a failing-pivot
    /// position (MATLAB `[R, p] = chol(A)` style); confirm.
    pub info: u32,
}

/// Outputs of a QR factorization, with the column permutation in both forms.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}
386
/// QR outputs produced by a power-iteration based path (same layout as [`ProviderQrResult`]).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}

/// Structure hints for a linear solve. All flags default to `false` and `rcond` to `None`.
///
/// NOTE(review): the flags appear to mirror MATLAB `linsolve` `opts` fields
/// (LT, UT, RECT, TRANSA, SYM, POSDEF); confirm the exact mapping.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    /// Whether the caller wants a reciprocal condition estimate.
    pub need_rcond: bool,
    /// NOTE(review): presumably a caller-supplied rcond value/threshold; confirm.
    pub rcond: Option<f64>,
}

/// Solution of a linear solve plus its reciprocal condition estimate.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    pub reciprocal_condition: f64,
}

/// Options for pseudo-inverse computation.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    /// Explicit singular-value tolerance; `None` lets the provider choose.
    pub tolerance: Option<f64>,
}

/// Centering/scaling pair for polynomial evaluation (MATLAB `mu`-style — confirm).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}

/// Options for `AccelProvider::polyval`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}

/// Options for matrix inversion (currently none).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}

/// Host-side polynomial fit result returned by `AccelProvider::polyfit`.
///
/// NOTE(review): field names mirror MATLAB `polyfit` outputs (`p`, `S.R`, `S.normr`, `S.df`,
/// `mu`); confirm `r_matrix` element order.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    pub coefficients: Vec<f64>,
    pub r_matrix: Vec<f64>,
    pub normr: f64,
    pub df: f64,
    pub mu: [f64; 2],
}
441
/// Numerator and denominator returned by `AccelProvider::polyder_quotient`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}

/// Norm used for a condition-number computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}

/// Norm order selector; `P` carries an arbitrary order (hence no `Eq`, as it holds an `f64`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    Nuc,
    P(f64),
}

/// Outputs of an eigen-decomposition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    pub eigenvalues: GpuTensorHandle,
    pub diagonal: GpuTensorHandle,
    pub right: GpuTensorHandle,
    /// Left eigenvectors, when requested/available.
    pub left: Option<GpuTensorHandle>,
}

/// Form in which a QR column permutation is returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}

/// QR factorization options; see the `Default` impl for the defaults.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    /// Whether an economy-size factorization is requested.
    pub economy: bool,
    pub pivot: ProviderQrPivot,
}
490
491impl Default for ProviderQrOptions {
492 fn default() -> Self {
493 Self {
494 economy: false,
495 pivot: ProviderQrPivot::Matrix,
496 }
497 }
498}
499
/// Floating-point precision of device buffers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}

/// Policy a provider declares (via `AccelProvider::spawn_handle_concurrency`) for handles
/// shared with spawned tasks. The trait default is `Reject`.
///
/// NOTE(review): the per-variant notes below are inferred from the variant names; confirm
/// against the runtime code that consumes this policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SpawnHandleConcurrency {
    /// Handles may be shared read-only.
    ImmutableShare,
    /// Sharing is allowed; mutation presumably works on a copy.
    CopyOnWrite,
    /// Sharing and mutation allowed, presumably with provider-side synchronization.
    SynchronizedMutation,
    /// Handles must not be shared with spawned tasks.
    Reject,
}
521
522impl SpawnHandleConcurrency {
523 pub fn as_str(self) -> &'static str {
524 match self {
525 SpawnHandleConcurrency::ImmutableShare => "immutable_share",
526 SpawnHandleConcurrency::CopyOnWrite => "copy_on_write",
527 SpawnHandleConcurrency::SynchronizedMutation => "synchronized_mutation",
528 SpawnHandleConcurrency::Reject => "reject",
529 }
530 }
531}
532
/// Whether reductions use a two-pass strategy: decided automatically or forced on/off.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionTwoPassMode {
    Auto,
    ForceOn,
    ForceOff,
}
539
540impl ReductionTwoPassMode {
541 pub fn as_str(self) -> &'static str {
542 match self {
543 ReductionTwoPassMode::Auto => "auto",
544 ReductionTwoPassMode::ForceOn => "force_on",
545 ReductionTwoPassMode::ForceOff => "force_off",
546 }
547 }
548}
549
/// How a reduction's accumulated sum is scaled (see [`ReductionFlavor::scale`]).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ReductionFlavor {
    /// Unscaled sum (factor 1).
    Sum,
    /// Sum divided by the reduced length.
    Mean,
    /// Sum multiplied by the given factor.
    CustomScale(f64),
}
556
557impl ReductionFlavor {
558 pub fn is_mean(self) -> bool {
559 matches!(self, ReductionFlavor::Mean)
560 }
561
562 pub fn scale(self, reduce_len: usize) -> f64 {
563 match self {
564 ReductionFlavor::Sum => 1.0,
565 ReductionFlavor::Mean => {
566 if reduce_len == 0 {
567 1.0
568 } else {
569 1.0 / reduce_len as f64
570 }
571 }
572 ReductionFlavor::CustomScale(scale) => scale,
573 }
574 }
575}
576
/// Normalization used by correlation-coefficient computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}

/// Row-handling mode for `corrcoef` (names resemble MATLAB's `'Rows'` option values:
/// all / complete / pairwise — confirm NaN semantics in the provider).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}

/// Options passed to `AccelProvider::corrcoef`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}
598
599impl Default for CorrcoefOptions {
600 fn default() -> Self {
601 Self {
602 normalization: CorrcoefNormalization::Unbiased,
603 rows: CorrcoefRows::All,
604 }
605 }
606}
607
/// Normalization used by covariance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}

/// Row-handling mode for `cov` (names resemble MATLAB's `'all'` / `'omitrows'` /
/// `'partialrows'` — confirm NaN semantics in the provider).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}

/// Options passed to `AccelProvider::covariance`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    /// NOTE(review): presumably set when a weight vector accompanies the call; confirm.
    pub has_weight_vector: bool,
}
630
631impl Default for CovarianceOptions {
632 fn default() -> Self {
633 Self {
634 normalization: CovNormalization::Unbiased,
635 rows: CovRows::All,
636 has_weight_vector: false,
637 }
638 }
639}
640
/// Normalization for standard deviation: sample (N-1) or population (N), per the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    Sample,
    Population,
}

/// Whether NaNs participate in a computation or are skipped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}

/// Direction of a cumulative scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}

/// Sort direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}

/// Comparison key for sorting (names resemble MATLAB's `'ComparisonMethod'` values).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    Auto,
    Real,
    Abs,
}

/// Host-side sort output: sorted values and their source indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}

/// One sort key for row sorting: column `index` sorted in `order`.
/// NOTE(review): confirm whether `index` is 0- or 1-based.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    pub index: usize,
    pub order: SortOrder,
}
689
/// Output ordering for `unique`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}

/// Which occurrence of a repeated value is reported by `unique`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}

/// Options for `unique`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    /// Treat rows (rather than elements) as the unit of uniqueness.
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}

/// Host-side `unique` output (`C`, `ia`, `ic` in MATLAB naming).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ic: HostTensorOwned,
}

/// Output ordering for `union`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}

/// Options for `union`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    /// Treat rows (rather than elements) as set members.
    pub rows: bool,
    pub order: UnionOrder,
}

/// Host-side `union` output (`C`, `ia`, `ib` in MATLAB naming).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ib: HostTensorOwned,
}
741
/// Predefined 2-D filter kinds with their parameters (variant names follow MATLAB
/// `fspecial` filter types; parameter semantics should be confirmed in the provider).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    Average {
        rows: u32,
        cols: u32,
    },
    Disk {
        radius: f64,
        size: u32,
    },
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Laplacian {
        alpha: f64,
    },
    /// Laplacian of Gaussian.
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        oversample: u32,
    },
    Prewitt,
    Sobel,
    Unsharp {
        alpha: f64,
    },
}

/// Request passed to `AccelProvider::fspecial`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}

/// Border handling for `imfilter`; `Constant` uses `ImfilterOptions::constant_value`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    Constant,
    Replicate,
    Symmetric,
    Circular,
}

/// Output size of `imfilter` relative to the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}

/// Whether the kernel is applied as correlation or convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}

/// Options passed to `AccelProvider::imfilter`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    /// Fill value used with `ImfilterPadding::Constant`.
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}
817
818impl Default for ImfilterOptions {
819 fn default() -> Self {
820 Self {
821 padding: ImfilterPadding::Constant,
822 constant_value: 0.0,
823 shape: ImfilterShape::Same,
824 mode: ImfilterMode::Correlation,
825 }
826 }
827}
828
/// Output ordering for `setdiff`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}

/// Options for `setdiff`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// Treat rows (rather than elements) as set members.
    pub rows: bool,
    pub order: SetdiffOrder,
}

/// Host-side `setdiff` output (`C`, `ia` in MATLAB naming).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}

/// Options for `ismember`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// Treat rows (rather than elements) as set members.
    pub rows: bool,
}

/// Owned host-side logical array.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    /// One byte per element — NOTE(review): presumably 0/1; confirm.
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}

/// Host-side `ismember` output: membership mask and matching locations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}
869
/// Output size of a 1-D convolution relative to its inputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

/// Orientation of a 1-D convolution's output vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options for a 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}

/// Options for an IIR filter.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Dimension to filter along — NOTE(review): confirm 0- vs 1-based.
    pub dim: usize,
    /// Initial filter state, if provided.
    pub zi: Option<GpuTensorHandle>,
}

/// Output of an IIR filter.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    pub output: GpuTensorHandle,
    /// Final filter state, when produced.
    pub final_state: Option<GpuTensorHandle>,
}

/// First and second raw moments (mean and E[x²], per the field names).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    pub ex2: GpuTensorHandle,
}
910
/// Count and accumulated wall time for one class of dispatch.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of dispatches recorded.
    pub count: u64,
    /// Accumulated wall time in nanoseconds (per the `_ns` suffix).
    pub total_wall_time_ns: u64,
}

/// Number of times a solver fell back, grouped by reason.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderFallbackStat {
    pub reason: String,
    pub count: u64,
}

/// Snapshot of provider-side counters.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    pub linsolve: ProviderDispatchStats,
    pub mldivide: ProviderDispatchStats,
    pub mrdivide: ProviderDispatchStats,
    pub upload_bytes: u64,
    pub download_bytes: u64,
    pub solve_fallbacks: Vec<ProviderFallbackStat>,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Per-layout cache breakdown, when the provider tracks it.
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    /// Recorded kernel launches.
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}

/// Bind-group cache hits/misses for one layout tag.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}

/// Named integer attribute attached to a kernel launch record.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}

/// One kernel launch with its shape and tuning attributes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    pub kernel: String,
    pub precision: Option<String>,
    pub shape: Vec<KernelAttrTelemetry>,
    pub tuning: Vec<KernelAttrTelemetry>,
}
966
967pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
968pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
969
970fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
971 Box::pin(async move { Err(anyhow::anyhow!(message)) })
972}
973
974pub trait AccelProvider: Send + Sync {
    /// Copies a host tensor to the device and returns a handle to the new buffer.
    fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
    /// Returns a future resolving to the contents of `h` as an owned host tensor.
    fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
    /// Frees the device buffer behind `h`.
    fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
    /// Human-readable device description (used as `name` by the default
    /// `device_info_struct`).
    fn device_info(&self) -> String;
    /// Identifier of the device this provider targets; defaults to `0`.
    fn device_id(&self) -> u32 {
        0
    }
983
984 fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
990 SpawnHandleConcurrency::Reject
991 }
992
993 fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
996 None
997 }
998
999 #[cfg(feature = "wgpu")]
1001 fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
1002 let _ = _handle;
1003 None
1004 }
1005
1006 fn gather_linear(
1009 &self,
1010 _source: &GpuTensorHandle,
1011 _indices: &[u32],
1012 _output_shape: &[usize],
1013 ) -> anyhow::Result<GpuTensorHandle> {
1014 Err(anyhow::anyhow!("gather_linear not supported by provider"))
1015 }
1016
1017 fn scatter_linear(
1021 &self,
1022 _target: &GpuTensorHandle,
1023 _indices: &[u32],
1024 _values: &GpuTensorHandle,
1025 ) -> anyhow::Result<()> {
1026 Err(anyhow::anyhow!("scatter_linear not supported by provider"))
1027 }
1028
1029 fn device_info_struct(&self) -> ApiDeviceInfo {
1031 ApiDeviceInfo {
1032 device_id: 0,
1033 name: self.device_info(),
1034 vendor: String::new(),
1035 memory_bytes: None,
1036 backend: None,
1037 }
1038 }
1039
1040 fn precision(&self) -> ProviderPrecision {
1041 ProviderPrecision::F64
1042 }
1043
1044 fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
1046 Err(anyhow::anyhow!("read_scalar not supported by provider"))
1047 }
1048
1049 fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1051 Err(anyhow::anyhow!("zeros not supported by provider"))
1052 }
1053
1054 fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1056 Err(anyhow::anyhow!("ones not supported by provider"))
1057 }
1058
1059 fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1061 self.zeros(&prototype.shape)
1062 }
1063
1064 fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
1066 if value == 0.0 {
1067 return self.zeros(shape);
1068 }
1069 if let Ok(base) = self.zeros(shape) {
1070 match self.scalar_add(&base, value) {
1071 Ok(out) => {
1072 let _ = self.free(&base);
1073 return Ok(out);
1074 }
1075 Err(_) => {
1076 let _ = self.free(&base);
1077 }
1078 }
1079 }
1080 let len: usize = shape.iter().copied().product();
1081 let data = vec![value; len];
1082 let view = HostTensorView { data: &data, shape };
1083 self.upload(&view)
1084 }
1085
1086 fn fill_like(
1088 &self,
1089 prototype: &GpuTensorHandle,
1090 value: f64,
1091 ) -> anyhow::Result<GpuTensorHandle> {
1092 if value == 0.0 {
1093 return self.zeros_like(prototype);
1094 }
1095 if let Ok(base) = self.zeros_like(prototype) {
1096 match self.scalar_add(&base, value) {
1097 Ok(out) => {
1098 let _ = self.free(&base);
1099 return Ok(out);
1100 }
1101 Err(_) => {
1102 let _ = self.free(&base);
1103 }
1104 }
1105 }
1106 self.fill(&prototype.shape, value)
1107 }
1108
1109 fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1111 self.ones(&prototype.shape)
1112 }
1113
1114 fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1116 Err(anyhow::anyhow!("eye not supported by provider"))
1117 }
1118
1119 fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1121 self.eye(&prototype.shape)
1122 }
1123
1124 fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1126 Err(anyhow::anyhow!("meshgrid not supported by provider"))
1127 }
1128
1129 fn diag_from_vector(
1131 &self,
1132 _vector: &GpuTensorHandle,
1133 _offset: isize,
1134 ) -> anyhow::Result<GpuTensorHandle> {
1135 Err(anyhow::anyhow!(
1136 "diag_from_vector not supported by provider"
1137 ))
1138 }
1139
1140 fn diag_extract(
1142 &self,
1143 _matrix: &GpuTensorHandle,
1144 _offset: isize,
1145 ) -> anyhow::Result<GpuTensorHandle> {
1146 Err(anyhow::anyhow!("diag_extract not supported by provider"))
1147 }
1148
1149 fn tril<'a>(
1151 &'a self,
1152 _matrix: &'a GpuTensorHandle,
1153 _offset: isize,
1154 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1155 Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1156 }
1157
1158 fn triu<'a>(
1160 &'a self,
1161 _matrix: &'a GpuTensorHandle,
1162 _offset: isize,
1163 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1164 Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1165 }
1166
1167 fn polyval(
1169 &self,
1170 _coefficients: &GpuTensorHandle,
1171 _points: &GpuTensorHandle,
1172 _options: &ProviderPolyvalOptions,
1173 ) -> anyhow::Result<GpuTensorHandle> {
1174 Err(anyhow::anyhow!("polyval not supported by provider"))
1175 }
1176
1177 fn polyfit<'a>(
1179 &'a self,
1180 _x: &'a GpuTensorHandle,
1181 _y: &'a GpuTensorHandle,
1182 _degree: usize,
1183 _weights: Option<&'a GpuTensorHandle>,
1184 ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1185 Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1186 }
1187
1188 fn polyder_single<'a>(
1190 &'a self,
1191 _polynomial: &'a GpuTensorHandle,
1192 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1193 Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1194 }
1195
1196 fn polyder_product<'a>(
1198 &'a self,
1199 _p: &'a GpuTensorHandle,
1200 _q: &'a GpuTensorHandle,
1201 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1202 Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1203 }
1204
1205 fn polyder_quotient<'a>(
1207 &'a self,
1208 _u: &'a GpuTensorHandle,
1209 _v: &'a GpuTensorHandle,
1210 ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1211 Box::pin(async move {
1212 Err(anyhow::anyhow!(
1213 "polyder_quotient not supported by provider"
1214 ))
1215 })
1216 }
1217
1218 fn polyint(
1220 &self,
1221 _polynomial: &GpuTensorHandle,
1222 _constant: f64,
1223 ) -> anyhow::Result<GpuTensorHandle> {
1224 Err(anyhow::anyhow!("polyint not supported by provider"))
1225 }
1226
1227 fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1229 Err(anyhow::anyhow!("random_uniform not supported by provider"))
1230 }
1231
1232 fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1234 self.random_uniform(&prototype.shape)
1235 }
1236
1237 fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1239 Err(anyhow::anyhow!("random_normal not supported by provider"))
1240 }
1241
1242 fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1244 self.random_normal(&prototype.shape)
1245 }
1246
1247 fn random_exponential(&self, _mu: f64, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1249 Err(anyhow::anyhow!(
1250 "random_exponential not supported by provider"
1251 ))
1252 }
1253
1254 fn random_normrnd(
1256 &self,
1257 _mu: f64,
1258 _sigma: f64,
1259 _shape: &[usize],
1260 ) -> anyhow::Result<GpuTensorHandle> {
1261 Err(anyhow::anyhow!("random_normrnd not supported by provider"))
1262 }
1263
1264 fn random_unifrnd(
1266 &self,
1267 _a: f64,
1268 _b: f64,
1269 _shape: &[usize],
1270 ) -> anyhow::Result<GpuTensorHandle> {
1271 Err(anyhow::anyhow!("random_unifrnd not supported by provider"))
1272 }
1273
1274 fn stochastic_evolution(
1275 &self,
1276 _state: &GpuTensorHandle,
1277 _drift: f64,
1278 _scale: f64,
1279 _steps: u32,
1280 ) -> anyhow::Result<GpuTensorHandle> {
1281 Err(anyhow::anyhow!(
1282 "stochastic_evolution not supported by provider"
1283 ))
1284 }
1285
1286 fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1288 Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1289 }
1290
1291 fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1293 Err(anyhow::anyhow!("fspecial not supported by provider"))
1294 }
1295
1296 fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
1299 Err(anyhow::anyhow!("peaks not supported by provider"))
1300 }
1301
1302 fn peaks_xy(
1305 &self,
1306 _x: &GpuTensorHandle,
1307 _y: &GpuTensorHandle,
1308 ) -> anyhow::Result<GpuTensorHandle> {
1309 Err(anyhow::anyhow!("peaks_xy not supported by provider"))
1310 }
1311
1312 fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1313 Err(anyhow::anyhow!("hann_window not supported by provider"))
1314 }
1315
1316 fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1317 Err(anyhow::anyhow!("hamming_window not supported by provider"))
1318 }
1319
1320 fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1321 Err(anyhow::anyhow!("blackman_window not supported by provider"))
1322 }
1323
1324 fn imfilter<'a>(
1326 &'a self,
1327 _image: &'a GpuTensorHandle,
1328 _kernel: &'a GpuTensorHandle,
1329 _options: &'a ImfilterOptions,
1330 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1331 unsupported_future("imfilter not supported by provider")
1332 }
1333
1334 fn random_integer_range(
1336 &self,
1337 _lower: i64,
1338 _upper: i64,
1339 _shape: &[usize],
1340 ) -> anyhow::Result<GpuTensorHandle> {
1341 Err(anyhow::anyhow!(
1342 "random_integer_range not supported by provider"
1343 ))
1344 }
1345
1346 fn random_integer_like(
1348 &self,
1349 prototype: &GpuTensorHandle,
1350 lower: i64,
1351 upper: i64,
1352 ) -> anyhow::Result<GpuTensorHandle> {
1353 self.random_integer_range(lower, upper, &prototype.shape)
1354 }
1355
1356 fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1358 Err(anyhow!("random_permutation not supported by provider"))
1359 }
1360
1361 fn random_permutation_like(
1363 &self,
1364 _prototype: &GpuTensorHandle,
1365 n: usize,
1366 k: usize,
1367 ) -> anyhow::Result<GpuTensorHandle> {
1368 self.random_permutation(n, k)
1369 }
1370
1371 fn covariance<'a>(
1373 &'a self,
1374 _matrix: &'a GpuTensorHandle,
1375 _second: Option<&'a GpuTensorHandle>,
1376 _weights: Option<&'a GpuTensorHandle>,
1377 _options: &'a CovarianceOptions,
1378 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1379 unsupported_future("covariance not supported by provider")
1380 }
1381
1382 fn corrcoef<'a>(
1384 &'a self,
1385 _matrix: &'a GpuTensorHandle,
1386 _options: &'a CorrcoefOptions,
1387 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1388 unsupported_future("corrcoef not supported by provider")
1389 }
1390
1391 fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1393 Err(anyhow::anyhow!("linspace not supported by provider"))
1394 }
1395 fn elem_add<'a>(
1396 &'a self,
1397 _a: &'a GpuTensorHandle,
1398 _b: &'a GpuTensorHandle,
1399 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1400 unsupported_future("elem_add not supported by provider")
1401 }
1402 fn elem_mul<'a>(
1403 &'a self,
1404 _a: &'a GpuTensorHandle,
1405 _b: &'a GpuTensorHandle,
1406 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1407 unsupported_future("elem_mul not supported by provider")
1408 }
1409 fn elem_max<'a>(
1410 &'a self,
1411 _a: &'a GpuTensorHandle,
1412 _b: &'a GpuTensorHandle,
1413 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1414 unsupported_future("elem_max not supported by provider")
1415 }
1416 fn elem_min<'a>(
1417 &'a self,
1418 _a: &'a GpuTensorHandle,
1419 _b: &'a GpuTensorHandle,
1420 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1421 unsupported_future("elem_min not supported by provider")
1422 }
1423 fn elem_sub<'a>(
1424 &'a self,
1425 _a: &'a GpuTensorHandle,
1426 _b: &'a GpuTensorHandle,
1427 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1428 unsupported_future("elem_sub not supported by provider")
1429 }
1430 fn elem_div<'a>(
1431 &'a self,
1432 _a: &'a GpuTensorHandle,
1433 _b: &'a GpuTensorHandle,
1434 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1435 unsupported_future("elem_div not supported by provider")
1436 }
1437 fn elem_pow<'a>(
1438 &'a self,
1439 _a: &'a GpuTensorHandle,
1440 _b: &'a GpuTensorHandle,
1441 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1442 unsupported_future("elem_pow not supported by provider")
1443 }
1444
1445 fn elem_hypot<'a>(
1446 &'a self,
1447 _a: &'a GpuTensorHandle,
1448 _b: &'a GpuTensorHandle,
1449 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1450 unsupported_future("elem_hypot not supported by provider")
1451 }
1452 fn elem_ge<'a>(
1453 &'a self,
1454 _a: &'a GpuTensorHandle,
1455 _b: &'a GpuTensorHandle,
1456 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1457 unsupported_future("elem_ge not supported by provider")
1458 }
1459 fn elem_le<'a>(
1460 &'a self,
1461 _a: &'a GpuTensorHandle,
1462 _b: &'a GpuTensorHandle,
1463 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1464 unsupported_future("elem_le not supported by provider")
1465 }
1466 fn elem_lt<'a>(
1467 &'a self,
1468 _a: &'a GpuTensorHandle,
1469 _b: &'a GpuTensorHandle,
1470 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1471 unsupported_future("elem_lt not supported by provider")
1472 }
1473 fn elem_gt<'a>(
1474 &'a self,
1475 _a: &'a GpuTensorHandle,
1476 _b: &'a GpuTensorHandle,
1477 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1478 unsupported_future("elem_gt not supported by provider")
1479 }
1480 fn elem_eq<'a>(
1481 &'a self,
1482 _a: &'a GpuTensorHandle,
1483 _b: &'a GpuTensorHandle,
1484 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1485 unsupported_future("elem_eq not supported by provider")
1486 }
1487 fn elem_ne<'a>(
1488 &'a self,
1489 _a: &'a GpuTensorHandle,
1490 _b: &'a GpuTensorHandle,
1491 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1492 unsupported_future("elem_ne not supported by provider")
1493 }
1494 fn logical_and(
1495 &self,
1496 _a: &GpuTensorHandle,
1497 _b: &GpuTensorHandle,
1498 ) -> anyhow::Result<GpuTensorHandle> {
1499 Err(anyhow::anyhow!("logical_and not supported by provider"))
1500 }
1501 fn logical_or(
1502 &self,
1503 _a: &GpuTensorHandle,
1504 _b: &GpuTensorHandle,
1505 ) -> anyhow::Result<GpuTensorHandle> {
1506 Err(anyhow::anyhow!("logical_or not supported by provider"))
1507 }
1508 fn logical_xor(
1509 &self,
1510 _a: &GpuTensorHandle,
1511 _b: &GpuTensorHandle,
1512 ) -> anyhow::Result<GpuTensorHandle> {
1513 Err(anyhow::anyhow!("logical_xor not supported by provider"))
1514 }
1515 fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1516 Err(anyhow::anyhow!("logical_not not supported by provider"))
1517 }
1518 fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
1519 Ok(handle_is_logical(a))
1520 }
1521 fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1522 Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1523 }
1524 fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1525 Err(anyhow::anyhow!(
1526 "logical_isfinite not supported by provider"
1527 ))
1528 }
1529 fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1530 Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1531 }
1532 fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1533 Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1534 }
1535 fn elem_atan2<'a>(
1536 &'a self,
1537 _y: &'a GpuTensorHandle,
1538 _x: &'a GpuTensorHandle,
1539 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1540 unsupported_future("elem_atan2 not supported by provider")
1541 }
1542 fn unary_sin<'a>(
1544 &'a self,
1545 _a: &'a GpuTensorHandle,
1546 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1547 unsupported_future("unary_sin not supported by provider")
1548 }
1549 fn unary_sinc<'a>(
1550 &'a self,
1551 _a: &'a GpuTensorHandle,
1552 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1553 unsupported_future("unary_sinc not supported by provider")
1554 }
1555 fn unary_gamma<'a>(
1556 &'a self,
1557 _a: &'a GpuTensorHandle,
1558 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1559 unsupported_future("unary_gamma not supported by provider")
1560 }
1561 fn unary_factorial<'a>(
1562 &'a self,
1563 _a: &'a GpuTensorHandle,
1564 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1565 unsupported_future("unary_factorial not supported by provider")
1566 }
1567 fn unary_asinh<'a>(
1568 &'a self,
1569 _a: &'a GpuTensorHandle,
1570 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1571 unsupported_future("unary_asinh not supported by provider")
1572 }
1573 fn unary_sinh<'a>(
1574 &'a self,
1575 _a: &'a GpuTensorHandle,
1576 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1577 unsupported_future("unary_sinh not supported by provider")
1578 }
1579 fn unary_cosh<'a>(
1580 &'a self,
1581 _a: &'a GpuTensorHandle,
1582 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1583 unsupported_future("unary_cosh not supported by provider")
1584 }
1585 fn unary_asin<'a>(
1586 &'a self,
1587 _a: &'a GpuTensorHandle,
1588 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1589 unsupported_future("unary_asin not supported by provider")
1590 }
1591 fn unary_acos<'a>(
1592 &'a self,
1593 _a: &'a GpuTensorHandle,
1594 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1595 unsupported_future("unary_acos not supported by provider")
1596 }
1597 fn unary_acosh<'a>(
1598 &'a self,
1599 _a: &'a GpuTensorHandle,
1600 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1601 unsupported_future("unary_acosh not supported by provider")
1602 }
1603 fn unary_tan<'a>(
1604 &'a self,
1605 _a: &'a GpuTensorHandle,
1606 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1607 unsupported_future("unary_tan not supported by provider")
1608 }
1609 fn unary_tanh<'a>(
1610 &'a self,
1611 _a: &'a GpuTensorHandle,
1612 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1613 unsupported_future("unary_tanh not supported by provider")
1614 }
1615 fn unary_atan<'a>(
1616 &'a self,
1617 _a: &'a GpuTensorHandle,
1618 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1619 unsupported_future("unary_atan not supported by provider")
1620 }
1621 fn unary_atanh<'a>(
1622 &'a self,
1623 _a: &'a GpuTensorHandle,
1624 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1625 unsupported_future("unary_atanh not supported by provider")
1626 }
1627 fn unary_ceil<'a>(
1628 &'a self,
1629 _a: &'a GpuTensorHandle,
1630 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1631 unsupported_future("unary_ceil not supported by provider")
1632 }
1633 fn unary_floor<'a>(
1634 &'a self,
1635 _a: &'a GpuTensorHandle,
1636 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1637 unsupported_future("unary_floor not supported by provider")
1638 }
1639 fn unary_round<'a>(
1640 &'a self,
1641 _a: &'a GpuTensorHandle,
1642 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1643 unsupported_future("unary_round not supported by provider")
1644 }
1645 fn unary_fix<'a>(
1646 &'a self,
1647 _a: &'a GpuTensorHandle,
1648 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1649 unsupported_future("unary_fix not supported by provider")
1650 }
1651 fn unary_cos<'a>(
1652 &'a self,
1653 _a: &'a GpuTensorHandle,
1654 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1655 unsupported_future("unary_cos not supported by provider")
1656 }
1657 fn unary_angle<'a>(
1658 &'a self,
1659 _a: &'a GpuTensorHandle,
1660 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1661 unsupported_future("unary_angle not supported by provider")
1662 }
1663 fn unary_imag<'a>(
1664 &'a self,
1665 _a: &'a GpuTensorHandle,
1666 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1667 unsupported_future("unary_imag not supported by provider")
1668 }
1669 fn unary_real<'a>(
1670 &'a self,
1671 _a: &'a GpuTensorHandle,
1672 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1673 unsupported_future("unary_real not supported by provider")
1674 }
1675 fn unary_conj<'a>(
1676 &'a self,
1677 _a: &'a GpuTensorHandle,
1678 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1679 unsupported_future("unary_conj not supported by provider")
1680 }
1681 fn unary_abs<'a>(
1682 &'a self,
1683 _a: &'a GpuTensorHandle,
1684 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1685 unsupported_future("unary_abs not supported by provider")
1686 }
1687 fn unary_sign<'a>(
1688 &'a self,
1689 _a: &'a GpuTensorHandle,
1690 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1691 unsupported_future("unary_sign not supported by provider")
1692 }
1693 fn unary_heaviside<'a>(
1694 &'a self,
1695 _a: &'a GpuTensorHandle,
1696 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1697 unsupported_future("unary_heaviside not supported by provider")
1698 }
1699 fn unary_exp<'a>(
1700 &'a self,
1701 _a: &'a GpuTensorHandle,
1702 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1703 unsupported_future("unary_exp not supported by provider")
1704 }
1705 fn unary_expm1<'a>(
1706 &'a self,
1707 _a: &'a GpuTensorHandle,
1708 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1709 unsupported_future("unary_expm1 not supported by provider")
1710 }
1711 fn unary_log<'a>(
1712 &'a self,
1713 _a: &'a GpuTensorHandle,
1714 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1715 unsupported_future("unary_log not supported by provider")
1716 }
1717 fn unary_log2<'a>(
1718 &'a self,
1719 _a: &'a GpuTensorHandle,
1720 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1721 unsupported_future("unary_log2 not supported by provider")
1722 }
1723 fn unary_log10<'a>(
1724 &'a self,
1725 _a: &'a GpuTensorHandle,
1726 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1727 unsupported_future("unary_log10 not supported by provider")
1728 }
1729 fn unary_log1p<'a>(
1730 &'a self,
1731 _a: &'a GpuTensorHandle,
1732 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1733 unsupported_future("unary_log1p not supported by provider")
1734 }
1735 fn unary_sqrt<'a>(
1736 &'a self,
1737 _a: &'a GpuTensorHandle,
1738 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1739 unsupported_future("unary_sqrt not supported by provider")
1740 }
1741 fn unary_double<'a>(
1742 &'a self,
1743 _a: &'a GpuTensorHandle,
1744 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1745 unsupported_future("unary_double not supported by provider")
1746 }
1747 fn unary_single<'a>(
1748 &'a self,
1749 _a: &'a GpuTensorHandle,
1750 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1751 unsupported_future("unary_single not supported by provider")
1752 }
1753 fn unary_pow2<'a>(
1754 &'a self,
1755 _a: &'a GpuTensorHandle,
1756 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1757 unsupported_future("unary_pow2 not supported by provider")
1758 }
1759 fn unary_nextpow2<'a>(
1760 &'a self,
1761 _a: &'a GpuTensorHandle,
1762 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1763 unsupported_future("unary_nextpow2 not supported by provider")
1764 }
1765 fn pow2_scale(
1766 &self,
1767 _mantissa: &GpuTensorHandle,
1768 _exponent: &GpuTensorHandle,
1769 ) -> anyhow::Result<GpuTensorHandle> {
1770 Err(anyhow::anyhow!("pow2_scale not supported by provider"))
1771 }
1772 fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1774 Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
1775 }
1776 fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1777 Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
1778 }
1779 fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1781 Err(anyhow::anyhow!("scalar_add not supported by provider"))
1782 }
1783 fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1784 Err(anyhow::anyhow!("scalar_sub not supported by provider"))
1785 }
1786 fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1787 Err(anyhow::anyhow!("scalar_mul not supported by provider"))
1788 }
1789 fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1790 Err(anyhow::anyhow!("scalar_max not supported by provider"))
1791 }
1792 fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1793 Err(anyhow::anyhow!("scalar_min not supported by provider"))
1794 }
1795 fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1796 Err(anyhow::anyhow!("scalar_div not supported by provider"))
1797 }
1798 fn sort_dim<'a>(
1799 &'a self,
1800 _a: &'a GpuTensorHandle,
1801 _dim: usize,
1802 _order: SortOrder,
1803 _comparison: SortComparison,
1804 ) -> AccelProviderFuture<'a, SortResult> {
1805 unsupported_future("sort_dim not supported by provider")
1806 }
1807 fn sort_rows<'a>(
1808 &'a self,
1809 _a: &'a GpuTensorHandle,
1810 _columns: &'a [SortRowsColumnSpec],
1811 _comparison: SortComparison,
1812 ) -> AccelProviderFuture<'a, SortResult> {
1813 unsupported_future("sort_rows not supported by provider")
1814 }
1815 fn matmul<'a>(
1816 &'a self,
1817 _a: &'a GpuTensorHandle,
1818 _b: &'a GpuTensorHandle,
1819 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1820 unsupported_future("matmul not supported by provider")
1821 }
1822
1823 fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1824 Err(anyhow::anyhow!("syrk not supported by provider"))
1825 }
1826 fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
1827 Err(anyhow::anyhow!("pagefun not supported by provider"))
1828 }
1829
1830 fn matmul_epilogue<'a>(
1835 &'a self,
1836 a: &'a GpuTensorHandle,
1837 b: &'a GpuTensorHandle,
1838 epilogue: &'a MatmulEpilogue,
1839 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1840 Box::pin(async move {
1841 if epilogue.is_noop() {
1842 return self.matmul(a, b).await;
1843 }
1844 Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1845 })
1846 }
1847 fn image_normalize<'a>(
1848 &'a self,
1849 _input: &'a GpuTensorHandle,
1850 _desc: &'a ImageNormalizeDescriptor,
1851 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1852 unsupported_future("image_normalize fusion not supported by provider")
1853 }
1854 fn matmul_power_step<'a>(
1855 &'a self,
1856 _lhs: &'a GpuTensorHandle,
1857 _rhs: &'a GpuTensorHandle,
1858 _epilogue: &'a PowerStepEpilogue,
1859 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1860 unsupported_future("matmul_power_step normalization not supported by provider")
1861 }
1862 fn linsolve<'a>(
1863 &'a self,
1864 _lhs: &'a GpuTensorHandle,
1865 _rhs: &'a GpuTensorHandle,
1866 _options: &'a ProviderLinsolveOptions,
1867 ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
1868 unsupported_future("linsolve not supported by provider")
1869 }
1870 fn inv<'a>(
1871 &'a self,
1872 _matrix: &'a GpuTensorHandle,
1873 _options: ProviderInvOptions,
1874 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1875 unsupported_future("inv not supported by provider")
1876 }
1877 fn pinv<'a>(
1878 &'a self,
1879 _matrix: &'a GpuTensorHandle,
1880 _options: ProviderPinvOptions,
1881 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1882 unsupported_future("pinv not supported by provider")
1883 }
1884 fn cond<'a>(
1885 &'a self,
1886 _matrix: &'a GpuTensorHandle,
1887 _norm: ProviderCondNorm,
1888 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1889 Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
1890 }
1891 fn norm<'a>(
1892 &'a self,
1893 _tensor: &'a GpuTensorHandle,
1894 _order: ProviderNormOrder,
1895 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1896 Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
1897 }
1898 fn rank<'a>(
1899 &'a self,
1900 _matrix: &'a GpuTensorHandle,
1901 _tolerance: Option<f64>,
1902 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1903 Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
1904 }
1905 fn rcond<'a>(
1906 &'a self,
1907 _matrix: &'a GpuTensorHandle,
1908 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1909 Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
1910 }
1911 fn mldivide<'a>(
1912 &'a self,
1913 _lhs: &'a GpuTensorHandle,
1914 _rhs: &'a GpuTensorHandle,
1915 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1916 Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
1917 }
1918 fn mrdivide<'a>(
1919 &'a self,
1920 _lhs: &'a GpuTensorHandle,
1921 _rhs: &'a GpuTensorHandle,
1922 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1923 Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
1924 }
1925 fn eig<'a>(
1926 &'a self,
1927 _a: &'a GpuTensorHandle,
1928 _compute_left: bool,
1929 ) -> AccelProviderFuture<'a, ProviderEigResult> {
1930 Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
1931 }
1932 fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
1933 Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
1934 }
1935
1936 fn chol<'a>(
1937 &'a self,
1938 _a: &'a GpuTensorHandle,
1939 _lower: bool,
1940 ) -> AccelProviderFuture<'a, ProviderCholResult> {
1941 Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
1942 }
1943 fn qr<'a>(
1944 &'a self,
1945 _a: &'a GpuTensorHandle,
1946 _options: ProviderQrOptions,
1947 ) -> AccelProviderFuture<'a, ProviderQrResult> {
1948 Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
1949 }
1950 fn take_matmul_sources(
1951 &self,
1952 _product: &GpuTensorHandle,
1953 ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
1954 None
1955 }
1956 fn qr_power_iter<'a>(
1957 &'a self,
1958 product: &'a GpuTensorHandle,
1959 _product_lhs: Option<&'a GpuTensorHandle>,
1960 q_handle: &'a GpuTensorHandle,
1961 options: &'a ProviderQrOptions,
1962 ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
1963 let _ = (product, q_handle, options);
1964 Box::pin(async move { Ok(None) })
1965 }
1966 fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1967 Err(anyhow::anyhow!("transpose not supported by provider"))
1968 }
1969 fn conv1d(
1970 &self,
1971 _signal: &GpuTensorHandle,
1972 _kernel: &GpuTensorHandle,
1973 _options: ProviderConv1dOptions,
1974 ) -> anyhow::Result<GpuTensorHandle> {
1975 Err(anyhow::anyhow!("conv1d not supported by provider"))
1976 }
1977 fn conv2d(
1978 &self,
1979 _signal: &GpuTensorHandle,
1980 _kernel: &GpuTensorHandle,
1981 _mode: ProviderConvMode,
1982 ) -> anyhow::Result<GpuTensorHandle> {
1983 Err(anyhow::anyhow!("conv2d not supported by provider"))
1984 }
1985 fn iir_filter<'a>(
1986 &'a self,
1987 _b: &'a GpuTensorHandle,
1988 _a: &'a GpuTensorHandle,
1989 _x: &'a GpuTensorHandle,
1990 _options: ProviderIirFilterOptions,
1991 ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
1992 Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
1993 }
1994 fn permute(
1996 &self,
1997 _handle: &GpuTensorHandle,
1998 _order: &[usize],
1999 ) -> anyhow::Result<GpuTensorHandle> {
2000 Err(anyhow::anyhow!("permute not supported by provider"))
2001 }
2002 fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
2003 Err(anyhow::anyhow!("flip not supported by provider"))
2004 }
2005 fn circshift(
2006 &self,
2007 _handle: &GpuTensorHandle,
2008 _shifts: &[isize],
2009 ) -> anyhow::Result<GpuTensorHandle> {
2010 Err(anyhow::anyhow!("circshift not supported by provider"))
2011 }
2012 fn diff_dim(
2013 &self,
2014 _handle: &GpuTensorHandle,
2015 _order: usize,
2016 _dim: usize,
2017 ) -> anyhow::Result<GpuTensorHandle> {
2018 Err(anyhow::anyhow!("diff_dim not supported by provider"))
2019 }
2020 fn gradient_dim(
2021 &self,
2022 _handle: &GpuTensorHandle,
2023 _dim: usize,
2024 _spacing: f64,
2025 ) -> anyhow::Result<GpuTensorHandle> {
2026 Err(anyhow::anyhow!("gradient_dim not supported by provider"))
2027 }
2028 fn fft_dim<'a>(
2030 &'a self,
2031 _handle: &'a GpuTensorHandle,
2032 _len: Option<usize>,
2033 _dim: usize,
2034 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2035 unsupported_future("fft_dim not supported by provider")
2036 }
2037 fn ifft_dim<'a>(
2038 &'a self,
2039 _handle: &'a GpuTensorHandle,
2040 _len: Option<usize>,
2041 _dim: usize,
2042 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2043 unsupported_future("ifft_dim not supported by provider")
2044 }
2045 fn fft_extract_real<'a>(
2046 &'a self,
2047 _handle: &'a GpuTensorHandle,
2048 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2049 unsupported_future("fft_extract_real not supported by provider")
2050 }
2051 fn unique<'a>(
2052 &'a self,
2053 _handle: &'a GpuTensorHandle,
2054 _options: &'a UniqueOptions,
2055 ) -> AccelProviderFuture<'a, UniqueResult> {
2056 Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
2057 }
2058 fn union<'a>(
2059 &'a self,
2060 _a: &'a GpuTensorHandle,
2061 _b: &'a GpuTensorHandle,
2062 _options: &'a UnionOptions,
2063 ) -> AccelProviderFuture<'a, UnionResult> {
2064 Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
2065 }
2066 fn setdiff<'a>(
2067 &'a self,
2068 _a: &'a GpuTensorHandle,
2069 _b: &'a GpuTensorHandle,
2070 _options: &'a SetdiffOptions,
2071 ) -> AccelProviderFuture<'a, SetdiffResult> {
2072 Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
2073 }
2074 fn ismember<'a>(
2075 &'a self,
2076 _a: &'a GpuTensorHandle,
2077 _b: &'a GpuTensorHandle,
2078 _options: &'a IsMemberOptions,
2079 ) -> AccelProviderFuture<'a, IsMemberResult> {
2080 Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
2081 }
2082 fn reshape(
2083 &self,
2084 handle: &GpuTensorHandle,
2085 new_shape: &[usize],
2086 ) -> anyhow::Result<GpuTensorHandle> {
2087 let mut updated = handle.clone();
2088 updated.shape = new_shape.to_vec();
2089 Ok(updated)
2090 }
2091 fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
2093 Err(anyhow::anyhow!("cat not supported by provider"))
2094 }
2095 fn repmat(
2096 &self,
2097 _handle: &GpuTensorHandle,
2098 _reps: &[usize],
2099 ) -> anyhow::Result<GpuTensorHandle> {
2100 Err(anyhow::anyhow!("repmat not supported by provider"))
2101 }
2102 fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2104 Err(anyhow::anyhow!("kron not supported by provider"))
2105 }
2106 fn cross(
2108 &self,
2109 _lhs: &GpuTensorHandle,
2110 _rhs: &GpuTensorHandle,
2111 _dim: Option<usize>,
2112 ) -> anyhow::Result<GpuTensorHandle> {
2113 Err(anyhow::anyhow!("cross not supported by provider"))
2114 }
2115 fn reduce_sum<'a>(
2116 &'a self,
2117 _a: &'a GpuTensorHandle,
2118 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2119 unsupported_future("reduce_sum not supported by provider")
2120 }
2121 fn reduce_sum_dim<'a>(
2122 &'a self,
2123 _a: &'a GpuTensorHandle,
2124 _dim: usize,
2125 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2126 unsupported_future("reduce_sum_dim not supported by provider")
2127 }
2128 fn dot<'a>(
2129 &'a self,
2130 _lhs: &'a GpuTensorHandle,
2131 _rhs: &'a GpuTensorHandle,
2132 _dim: Option<usize>,
2133 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2134 unsupported_future("dot not supported by provider")
2135 }
2136 fn reduce_nnz<'a>(
2137 &'a self,
2138 _a: &'a GpuTensorHandle,
2139 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2140 unsupported_future("reduce_nnz not supported by provider")
2141 }
2142 fn reduce_nnz_dim<'a>(
2143 &'a self,
2144 _a: &'a GpuTensorHandle,
2145 _dim: usize,
2146 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2147 unsupported_future("reduce_nnz_dim not supported by provider")
2148 }
2149 fn reduce_prod<'a>(
2150 &'a self,
2151 _a: &'a GpuTensorHandle,
2152 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2153 unsupported_future("reduce_prod not supported by provider")
2154 }
2155 fn reduce_prod_dim<'a>(
2156 &'a self,
2157 _a: &'a GpuTensorHandle,
2158 _dim: usize,
2159 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2160 unsupported_future("reduce_prod_dim not supported by provider")
2161 }
2162 fn reduce_mean<'a>(
2163 &'a self,
2164 _a: &'a GpuTensorHandle,
2165 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2166 unsupported_future("reduce_mean not supported by provider")
2167 }
2168 fn reduce_mean_nd<'a>(
2170 &'a self,
2171 _a: &'a GpuTensorHandle,
2172 _dims_zero_based: &'a [usize],
2173 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2174 unsupported_future("reduce_mean_nd not supported by provider")
2175 }
2176 fn reduce_moments_nd<'a>(
2179 &'a self,
2180 _a: &'a GpuTensorHandle,
2181 _dims_zero_based: &'a [usize],
2182 ) -> AccelProviderFuture<'a, ProviderMoments2> {
2183 unsupported_future("reduce_moments_nd not supported by provider")
2184 }
2185 fn reduce_mean_dim<'a>(
2186 &'a self,
2187 _a: &'a GpuTensorHandle,
2188 _dim: usize,
2189 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2190 unsupported_future("reduce_mean_dim not supported by provider")
2191 }
2192 fn reduce_std<'a>(
2193 &'a self,
2194 _a: &'a GpuTensorHandle,
2195 _normalization: ProviderStdNormalization,
2196 _nan_mode: ProviderNanMode,
2197 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2198 unsupported_future("reduce_std not supported by provider")
2199 }
2200 fn reduce_std_dim<'a>(
2201 &'a self,
2202 _a: &'a GpuTensorHandle,
2203 _dim: usize,
2204 _normalization: ProviderStdNormalization,
2205 _nan_mode: ProviderNanMode,
2206 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2207 unsupported_future("reduce_std_dim not supported by provider")
2208 }
2209 fn reduce_any<'a>(
2210 &'a self,
2211 _a: &'a GpuTensorHandle,
2212 _omit_nan: bool,
2213 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2214 unsupported_future("reduce_any not supported by provider")
2215 }
2216 fn reduce_any_dim<'a>(
2217 &'a self,
2218 _a: &'a GpuTensorHandle,
2219 _dim: usize,
2220 _omit_nan: bool,
2221 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2222 unsupported_future("reduce_any_dim not supported by provider")
2223 }
2224 fn reduce_all<'a>(
2225 &'a self,
2226 _a: &'a GpuTensorHandle,
2227 _omit_nan: bool,
2228 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2229 unsupported_future("reduce_all not supported by provider")
2230 }
2231 fn reduce_all_dim<'a>(
2232 &'a self,
2233 _a: &'a GpuTensorHandle,
2234 _dim: usize,
2235 _omit_nan: bool,
2236 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2237 unsupported_future("reduce_all_dim not supported by provider")
2238 }
2239 fn reduce_median<'a>(
2240 &'a self,
2241 _a: &'a GpuTensorHandle,
2242 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2243 unsupported_future("reduce_median not supported by provider")
2244 }
2245 fn reduce_median_dim<'a>(
2246 &'a self,
2247 _a: &'a GpuTensorHandle,
2248 _dim: usize,
2249 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2250 unsupported_future("reduce_median_dim not supported by provider")
2251 }
2252 fn reduce_min<'a>(
2253 &'a self,
2254 _a: &'a GpuTensorHandle,
2255 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2256 unsupported_future("reduce_min not supported by provider")
2257 }
2258 fn reduce_min_dim<'a>(
2259 &'a self,
2260 _a: &'a GpuTensorHandle,
2261 _dim: usize,
2262 ) -> AccelProviderFuture<'a, ReduceDimResult> {
2263 unsupported_future("reduce_min_dim not supported by provider")
2264 }
2265 fn reduce_max<'a>(
2266 &'a self,
2267 _a: &'a GpuTensorHandle,
2268 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
2269 unsupported_future("reduce_max not supported by provider")
2270 }
2271 fn reduce_max_dim<'a>(
2272 &'a self,
2273 _a: &'a GpuTensorHandle,
2274 _dim: usize,
2275 ) -> AccelProviderFuture<'a, ReduceDimResult> {
2276 unsupported_future("reduce_max_dim not supported by provider")
2277 }
2278 fn cumsum_scan(
2279 &self,
2280 _input: &GpuTensorHandle,
2281 _dim: usize,
2282 _direction: ProviderScanDirection,
2283 _nan_mode: ProviderNanMode,
2284 ) -> anyhow::Result<GpuTensorHandle> {
2285 Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
2286 }
2287 fn cumprod_scan(
2288 &self,
2289 _input: &GpuTensorHandle,
2290 _dim: usize,
2291 _direction: ProviderScanDirection,
2292 _nan_mode: ProviderNanMode,
2293 ) -> anyhow::Result<GpuTensorHandle> {
2294 Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
2295 }
2296 fn cummin_scan(
2297 &self,
2298 _input: &GpuTensorHandle,
2299 _dim: usize,
2300 _direction: ProviderScanDirection,
2301 _nan_mode: ProviderNanMode,
2302 ) -> anyhow::Result<ProviderCumminResult> {
2303 Err(anyhow::anyhow!("cummin_scan not supported by provider"))
2304 }
2305 fn cummax_scan(
2306 &self,
2307 _input: &GpuTensorHandle,
2308 _dim: usize,
2309 _direction: ProviderScanDirection,
2310 _nan_mode: ProviderNanMode,
2311 ) -> anyhow::Result<ProviderCummaxResult> {
2312 Err(anyhow::anyhow!("cummax_scan not supported by provider"))
2313 }
2314
2315 fn find(
2316 &self,
2317 _a: &GpuTensorHandle,
2318 _limit: Option<usize>,
2319 _direction: FindDirection,
2320 ) -> anyhow::Result<ProviderFindResult> {
2321 Err(anyhow::anyhow!("find not supported by provider"))
2322 }
2323
2324 fn fused_elementwise(
2325 &self,
2326 _shader: &str,
2327 _inputs: &[GpuTensorHandle],
2328 _output_shape: &[usize],
2329 _len: usize,
2330 ) -> anyhow::Result<GpuTensorHandle> {
2331 Err(anyhow::anyhow!(
2332 "fused_elementwise not supported by provider"
2333 ))
2334 }
2335
2336 fn fused_elementwise_multi(
2345 &self,
2346 _shader: &str,
2347 _inputs: &[GpuTensorHandle],
2348 _output_shape: &[usize],
2349 _len: usize,
2350 _num_outputs: usize,
2351 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2352 Err(anyhow::anyhow!(
2353 "fused_elementwise_multi not supported by provider"
2354 ))
2355 }
2356
2357 fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2359 Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
2360 }
2361
2362 fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2364 Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
2365 }
2366
2367 #[allow(clippy::too_many_arguments)]
2374 fn fused_reduction(
2375 &self,
2376 _shader: &str,
2377 _inputs: &[GpuTensorHandle],
2378 _output_shape: &[usize],
2379 _reduce_len: usize,
2380 _num_slices: usize,
2381 _workgroup_size: u32,
2382 _flavor: ReductionFlavor,
2383 ) -> anyhow::Result<GpuTensorHandle> {
2384 Err(anyhow::anyhow!("fused_reduction not supported by provider"))
2385 }
2386
    /// Optional warm-up hook a provider may run before first use.
    ///
    /// The default implementation does nothing.
    fn warmup(&self) {}
2389
    /// `(hits, misses)` for the provider's fused-kernel cache.
    ///
    /// `telemetry_snapshot` reports these as `fusion_cache_hits` / `fusion_cache_misses`.
    /// The default implementation reports `(0, 0)`.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }
2394
    /// Duration of the most recent warm-up in milliseconds, if one was recorded (per the
    /// method name).
    ///
    /// The default implementation returns `None`.
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }
2399
2400 fn telemetry_snapshot(&self) -> ProviderTelemetry {
2402 let (hits, misses) = self.fused_cache_counters();
2403 ProviderTelemetry {
2404 fused_elementwise: ProviderDispatchStats::default(),
2405 fused_reduction: ProviderDispatchStats::default(),
2406 matmul: ProviderDispatchStats::default(),
2407 linsolve: ProviderDispatchStats::default(),
2408 mldivide: ProviderDispatchStats::default(),
2409 mrdivide: ProviderDispatchStats::default(),
2410 upload_bytes: 0,
2411 download_bytes: 0,
2412 solve_fallbacks: Vec::new(),
2413 fusion_cache_hits: hits,
2414 fusion_cache_misses: misses,
2415 bind_group_cache_hits: 0,
2416 bind_group_cache_misses: 0,
2417 bind_group_cache_by_layout: None,
2418 kernel_launches: Vec::new(),
2419 }
2420 }
2421
    /// Resets any telemetry the provider accumulates (per the method name).
    ///
    /// The default implementation keeps no telemetry state and does nothing.
    fn reset_telemetry(&self) {}
2424
    /// Preferred workgroup size for reduction kernels.
    ///
    /// The default implementation returns 256.
    fn default_reduction_workgroup_size(&self) -> u32 {
        256
    }
2429
    /// Reduction-length threshold associated with switching to a two-pass reduction.
    ///
    /// NOTE(review): the exact comparison (`>` vs `>=`) is applied by callers — confirm.
    /// The default implementation returns 1024.
    fn two_pass_threshold(&self) -> usize {
        1024
    }
2434
    /// Policy for choosing two-pass reductions.
    ///
    /// The default implementation returns `ReductionTwoPassMode::Auto`.
    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
2439
2440 fn scatter_column(
2443 &self,
2444 _matrix: &GpuTensorHandle,
2445 _col_index: usize,
2446 _values: &GpuTensorHandle,
2447 ) -> anyhow::Result<GpuTensorHandle> {
2448 Err(anyhow::anyhow!("scatter_column not supported by provider"))
2449 }
2450
2451 fn scatter_row(
2454 &self,
2455 _matrix: &GpuTensorHandle,
2456 _row_index: usize,
2457 _values: &GpuTensorHandle,
2458 ) -> anyhow::Result<GpuTensorHandle> {
2459 Err(anyhow::anyhow!("scatter_row not supported by provider"))
2460 }
2461
2462 fn sub2ind(
2463 &self,
2464 _dims: &[usize],
2465 _strides: &[usize],
2466 _inputs: &[&GpuTensorHandle],
2467 _scalar_mask: &[bool],
2468 _len: usize,
2469 _output_shape: &[usize],
2470 ) -> anyhow::Result<GpuTensorHandle> {
2471 Err(anyhow::anyhow!("sub2ind not supported by provider"))
2472 }
2473
    /// Whether this provider implements `ind2sub`.
    ///
    /// The default `false` matches the default `ind2sub`, which always returns an error.
    fn supports_ind2sub(&self) -> bool {
        false
    }
2478
2479 fn ind2sub(
2481 &self,
2482 _dims: &[usize],
2483 _strides: &[usize],
2484 _indices: &GpuTensorHandle,
2485 _total: usize,
2486 _len: usize,
2487 _output_shape: &[usize],
2488 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2489 Err(anyhow::anyhow!("ind2sub not supported by provider"))
2490 }
2491
2492 fn issymmetric(
2494 &self,
2495 _matrix: &GpuTensorHandle,
2496 _kind: ProviderSymmetryKind,
2497 _tolerance: f64,
2498 ) -> anyhow::Result<bool> {
2499 Err(anyhow::anyhow!(
2500 "issymmetric predicate not supported by provider"
2501 ))
2502 }
2503
2504 fn ishermitian<'a>(
2506 &'a self,
2507 _matrix: &'a GpuTensorHandle,
2508 _kind: ProviderHermitianKind,
2509 _tolerance: f64,
2510 ) -> AccelProviderFuture<'a, bool> {
2511 Box::pin(async move {
2512 Err(anyhow::anyhow!(
2513 "ishermitian predicate not supported by provider"
2514 ))
2515 })
2516 }
2517
2518 fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2520 Err(anyhow::anyhow!("bandwidth not supported by provider"))
2521 }
2522
2523 fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
2528 Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
2529 }
2530}
2531
// Process-wide default provider: set by `register_provider`, reset by `clear_provider`,
// and used by `provider()` / `provider_for_device` when no better match exists.
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
// Device id -> owning provider. Filled by `register_provider`, emptied by `clear_provider`,
// and consulted first by `provider_for_device`.
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
// Source of ids handed out by `next_device_id`; starts at 1, so the first id issued is 1.
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2537
// Native targets: per-thread provider override, consulted before `GLOBAL_PROVIDER`.
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}
2542
// wasm32 stand-in for `THREAD_PROVIDER`: a single mutex-guarded slot shared by the whole
// process (presumably because wasm32 builds run single-threaded — confirm).
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| Mutex::new(None));
2546
2547#[cfg(not(target_arch = "wasm32"))]
2548fn replace_thread_provider(
2549 provider: Option<&'static dyn AccelProvider>,
2550) -> Option<&'static dyn AccelProvider> {
2551 THREAD_PROVIDER.with(|cell| {
2552 let prev = cell.get();
2553 cell.set(provider);
2554 prev
2555 })
2556}
2557
/// wasm32 variant: swaps the single shared override slot and returns the previous value.
///
/// # Panics
///
/// Panics if the slot's mutex is poisoned.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *slot, provider)
}
2569
2570#[cfg(not(target_arch = "wasm32"))]
2571fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2572 THREAD_PROVIDER.with(|cell| cell.get())
2573}
2574
/// wasm32 variant: reads the single shared override slot.
///
/// # Panics
///
/// Panics if the slot's mutex is poisoned.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    let slot = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    *slot
}
2583
2584pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
2592 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2593 *guard = Some(p);
2594 }
2595 register_provider_for_device(p.device_id(), p);
2596}
2597
2598unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
2599 if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
2600 guard.insert(device_id, provider);
2601 }
2602}
2603
2604pub fn provider() -> Option<&'static dyn AccelProvider> {
2605 if let Some(p) = current_thread_provider() {
2606 return Some(p);
2607 }
2608 GLOBAL_PROVIDER
2609 .read()
2610 .ok()
2611 .and_then(|guard| guard.as_ref().copied())
2612}
2613
2614pub fn clear_provider() {
2616 replace_thread_provider(None);
2617 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2618 *guard = None;
2619 }
2620 if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2621 map.clear();
2622 }
2623}
2624
2625pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2626 if let Some(registered) = PROVIDER_REGISTRY
2627 .read()
2628 .ok()
2629 .and_then(|guard| guard.get(&device_id).copied())
2630 {
2631 return Some(registered);
2632 }
2633 if let Some(thread_provider) = current_thread_provider() {
2634 if thread_provider.device_id() == device_id {
2635 return Some(thread_provider);
2636 }
2637 }
2638 GLOBAL_PROVIDER
2641 .read()
2642 .ok()
2643 .and_then(|guard| guard.as_ref().copied())
2644}
2645
2646pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
2647 provider_for_device(handle.device_id)
2648}
2649
2650pub fn spawn_handle_concurrency_for(handle: &GpuTensorHandle) -> Option<SpawnHandleConcurrency> {
2651 provider_for_handle(handle).map(AccelProvider::spawn_handle_concurrency)
2652}
2653
/// Allocates a fresh device id.
///
/// The atomic `fetch_add` makes concurrent callers receive distinct values. `Relaxed`
/// ordering is enough because no other memory is published through this counter. The
/// counter is a `u32` and wraps after `u32::MAX` allocations.
pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
2657
2658pub struct ThreadProviderGuard {
2659 prev: Option<&'static dyn AccelProvider>,
2660}
2661
2662impl ThreadProviderGuard {
2663 pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2664 let prev = replace_thread_provider(provider);
2665 ThreadProviderGuard { prev }
2666 }
2667}
2668
2669impl Drop for ThreadProviderGuard {
2670 fn drop(&mut self) {
2671 let prev = self.prev.take();
2672 replace_thread_provider(prev);
2673 }
2674}
2675
2676pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
2677 replace_thread_provider(provider);
2678}
2679
2680pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2682 if let Some(p) = provider() {
2683 if let Ok(h) = p.elem_add(a, b).await {
2684 return Some(h);
2685 }
2686 }
2687 None
2688}
2689
2690pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2692 if let Some(p) = provider() {
2693 if let Ok(h) = p.elem_hypot(a, b).await {
2694 return Some(h);
2695 }
2696 }
2697 None
2698}
2699
2700pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2702 if let Some(p) = provider() {
2703 if let Ok(h) = p.elem_max(a, b).await {
2704 return Some(h);
2705 }
2706 }
2707 None
2708}
2709
2710pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2712 if let Some(p) = provider() {
2713 if let Ok(h) = p.elem_min(a, b).await {
2714 return Some(h);
2715 }
2716 }
2717 None
2718}
2719
2720pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2722 if let Some(p) = provider() {
2723 if let Ok(h) = p.elem_atan2(y, x).await {
2724 return Some(h);
2725 }
2726 }
2727 None
2728}
2729
/// Host-side tensor that owns its element buffer and shape.
///
/// NOTE(review): the layout of `data` relative to `shape` (row- vs column-major) is not
/// fixed by this type — confirm against producers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    /// Element values as `f64`.
    pub data: Vec<f64>,
    /// Dimension extents.
    pub shape: Vec<usize>,
    /// Storage classification carried alongside the data (`GpuTensorStorage`, defined elsewhere).
    pub storage: GpuTensorStorage,
}
2737
/// Borrowed host tensor, e.g. the input to `AccelProvider::upload`.
///
/// NOTE(review): layout of `data` relative to `shape` is not fixed by this type — confirm.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    /// Borrowed element values.
    pub data: &'a [f64],
    /// Borrowed dimension extents.
    pub shape: &'a [usize],
}
2743
/// Borrowed coordinate values for one `meshgrid` axis (per the type name).
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    /// The axis coordinates.
    pub data: &'a [f64],
}
2749
/// Device-resident outputs of a provider `meshgrid` operation.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    /// One handle per generated grid (presumably one per input axis — confirm).
    pub outputs: Vec<GpuTensorHandle>,
}
2755
/// How a scale vector in `MatmulEpilogue` (`row_scale` / `col_scale`) is combined with
/// the product.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    /// Multiply by the scale value.
    Multiply,
    /// Divide by the scale value.
    Divide,
}
2766
/// Optional post-processing fused onto a matrix multiply.
///
/// `MatmulEpilogue::noop()` builds the identity configuration, and `is_noop()` tests for
/// it. NOTE(review): the order in which a provider applies these stages is defined by the
/// kernel, not by this type — confirm there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    /// Multiplier on the product; `1.0` in the no-op epilogue (presumably `alpha * A·B`).
    pub alpha: f64,
    /// `0.0` in the no-op epilogue; presumably weights an existing output (`beta * C`) — confirm.
    pub beta: f64,
    /// Optional scale vector presumably applied per row, combined according to `row_op`.
    pub row_scale: Option<GpuTensorHandle>,
    /// Optional scale vector presumably applied per column, combined according to `col_op`.
    pub col_scale: Option<GpuTensorHandle>,
    /// Whether `row_scale` multiplies or divides.
    pub row_op: ScaleOp,
    /// Whether `col_scale` multiplies or divides.
    pub col_op: ScaleOp,
    /// Optional lower clamp; deserializes to `None` when absent.
    #[serde(default)]
    pub clamp_min: Option<f64>,
    /// Optional upper clamp; deserializes to `None` when absent.
    #[serde(default)]
    pub clamp_max: Option<f64>,
    /// Optional power applied to the result (per the name); deserializes to `None` when absent.
    #[serde(default)]
    pub pow_exponent: Option<f64>,
    /// Optional extra output handle (presumably receives the diagonal — confirm);
    /// deserializes to `None` when absent.
    #[serde(default)]
    pub diag_output: Option<GpuTensorHandle>,
}
2794
2795impl MatmulEpilogue {
2796 pub fn noop() -> Self {
2797 Self {
2798 alpha: 1.0,
2799 beta: 0.0,
2800 row_scale: None,
2801 col_scale: None,
2802 row_op: ScaleOp::Multiply,
2803 col_op: ScaleOp::Multiply,
2804 clamp_min: None,
2805 clamp_max: None,
2806 pow_exponent: None,
2807 diag_output: None,
2808 }
2809 }
2810 pub fn is_noop(&self) -> bool {
2811 self.alpha == 1.0
2812 && self.beta == 0.0
2813 && self.row_scale.is_none()
2814 && self.col_scale.is_none()
2815 && self.clamp_min.is_none()
2816 && self.clamp_max.is_none()
2817 && self.pow_exponent.is_none()
2818 && self.diag_output.is_none()
2819 }
2820}
2821
/// Parameters for a power-iteration step epilogue (per the type name).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct PowerStepEpilogue {
    /// Small constant, presumably guarding a normalization against division by zero —
    /// confirm. Defaults to `0.0`.
    pub epsilon: f64,
}
2826
2827impl Default for PowerStepEpilogue {
2828 fn default() -> Self {
2829 Self { epsilon: 0.0 }
2830 }
2831}
2832
/// Describes an image-normalization operation (per the type name).
///
/// NOTE(review): the exact formula combining `epsilon`, `gain`, `bias` and `gamma` is
/// defined by the provider kernel — confirm there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    /// Number of images (per the name).
    pub batch: usize,
    /// Image height (per the name).
    pub height: usize,
    /// Image width (per the name).
    pub width: usize,
    /// Stabilizing constant, presumably added to a variance — confirm.
    pub epsilon: f64,
    /// Optional multiplicative gain; deserializes to `None` when absent.
    #[serde(default)]
    pub gain: Option<f64>,
    /// Optional additive bias; deserializes to `None` when absent.
    #[serde(default)]
    pub bias: Option<f64>,
    /// Optional gamma exponent; deserializes to `None` when absent.
    #[serde(default)]
    pub gamma: Option<f64>,
}
2846
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider for exercising provider lookup; its data-path methods all error.
    struct TestProvider {
        device_id: u32,
        name: &'static str,
        spawn_concurrency: SpawnHandleConcurrency,
    }

    impl AccelProvider for TestProvider {
        fn upload(&self, _host: &HostTensorView) -> anyhow::Result<GpuTensorHandle> {
            Err(anyhow!("test provider upload should not be called"))
        }

        fn download<'a>(&'a self, _h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a> {
            unsupported_future("test provider download should not be called")
        }

        fn free(&self, _h: &GpuTensorHandle) -> anyhow::Result<()> {
            Err(anyhow!("test provider free should not be called"))
        }

        fn device_info(&self) -> String {
            self.name.to_string()
        }

        fn device_id(&self) -> u32 {
            self.device_id
        }

        fn spawn_handle_concurrency(&self) -> SpawnHandleConcurrency {
            self.spawn_concurrency
        }
    }

    /// Serializes the tests below; they all mutate process-wide provider state.
    static PROVIDER_TEST_LOCK: Lazy<std::sync::Mutex<()>> =
        Lazy::new(|| std::sync::Mutex::new(()));
    static PROVIDER_A: TestProvider = TestProvider {
        device_id: 101,
        name: "provider-a",
        spawn_concurrency: SpawnHandleConcurrency::ImmutableShare,
    };
    static PROVIDER_B: TestProvider = TestProvider {
        device_id: 202,
        name: "provider-b",
        spawn_concurrency: SpawnHandleConcurrency::Reject,
    };
    static PROVIDER_C: TestProvider = TestProvider {
        device_id: 303,
        name: "provider-c",
        spawn_concurrency: SpawnHandleConcurrency::CopyOnWrite,
    };

    /// Holds `PROVIDER_TEST_LOCK` for a test's duration and resets all provider state on drop.
    ///
    /// Because cleanup lives in `Drop`, it also runs while unwinding from a failed assertion,
    /// so one failure cannot leak registrations into the next test.
    struct ProviderTestScope {
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl ProviderTestScope {
        fn enter() -> Self {
            // A panicking test poisons the mutex. The guarded `()` carries no state, so
            // recovering the guard is sound and stops one failure from failing every later test.
            let lock = PROVIDER_TEST_LOCK
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            Self { _lock: lock }
        }
    }

    impl Drop for ProviderTestScope {
        fn drop(&mut self) {
            // Runs before the `_lock` field is released.
            clear_provider();
        }
    }

    fn register_test_providers() {
        clear_provider();
        unsafe {
            register_provider(&PROVIDER_A);
            register_provider(&PROVIDER_B);
        }
    }

    fn test_handle(device_id: u32) -> GpuTensorHandle {
        GpuTensorHandle {
            shape: vec![1],
            device_id,
            buffer_id: 42,
        }
    }

    #[test]
    fn provider_for_device_prefers_registered_device_over_thread_provider() {
        let _scope = ProviderTestScope::enter();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider = provider_for_device(PROVIDER_A.device_id()).expect("provider for device");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn provider_for_handle_uses_handle_device_owner() {
        let _scope = ProviderTestScope::enter();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let provider =
            provider_for_handle(&test_handle(PROVIDER_A.device_id())).expect("provider for handle");

        assert_eq!(provider.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn spawn_handle_concurrency_for_uses_registered_owner() {
        let _scope = ProviderTestScope::enter();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_B));

        let concurrency = spawn_handle_concurrency_for(&test_handle(PROVIDER_A.device_id()))
            .expect("spawn concurrency");

        assert_eq!(concurrency, PROVIDER_A.spawn_concurrency);
    }

    #[test]
    fn provider_keeps_thread_local_active_provider_semantics() {
        let _scope = ProviderTestScope::enter();
        register_test_providers();
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_A));

        let active = provider().expect("active provider");

        assert_eq!(active.device_info(), PROVIDER_A.name);
    }

    #[test]
    fn unregistered_thread_provider_only_matches_own_device_before_global_fallback() {
        let _scope = ProviderTestScope::enter();
        clear_provider();
        unsafe {
            register_provider(&PROVIDER_A);
        }
        let _thread_provider = ThreadProviderGuard::set(Some(&PROVIDER_C));

        let own_device = provider_for_device(PROVIDER_C.device_id()).expect("own provider");
        let fallback = provider_for_device(404).expect("global fallback provider");

        assert_eq!(own_device.device_info(), PROVIDER_C.name);
        assert_eq!(fallback.device_info(), PROVIDER_A.name);
    }
}