1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4#[cfg(not(target_arch = "wasm32"))]
5use std::cell::Cell;
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::atomic::{AtomicU32, Ordering};
10#[cfg(feature = "wgpu")]
11use std::sync::Arc;
12#[cfg(target_arch = "wasm32")]
13use std::sync::Mutex;
14use std::sync::RwLock;
15
/// Callback signature invoked by [`mark_residency`].
type ResidencyMarkFn = fn(&GpuTensorHandle);
/// Callback signature invoked by [`clear_residency`].
type ResidencyClearFn = fn(&GpuTensorHandle);
/// Provider signature backing [`sequence_threshold_hint`].
type SequenceThresholdFn = fn() -> Option<usize>;
/// Provider signature backing [`workgroup_size_hint`].
type WorkgroupSizeHintFn = fn() -> Option<u32>;

// Write-once hook slots. The `register_*` functions fill these with
// `OnceCell::set` and ignore the result, so only the first registration wins.
static RESIDENCY_MARK: OnceCell<ResidencyMarkFn> = OnceCell::new();
static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();
static WORKGROUP_SIZE_HINT_PROVIDER: OnceCell<WorkgroupSizeHintFn> = OnceCell::new();

// Per-buffer side tables keyed by `GpuTensorHandle::buffer_id`. The accessor
// functions in this module skip the update (or fall back to a default answer)
// when a lock is poisoned instead of panicking.

// Buffer ids currently flagged as logical (see `set_handle_logical`).
static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
// Number of times each buffer id has been flagged logical; the entry is removed
// when the flag is cleared.
static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
// Base dimensions for buffers recorded via `record_handle_transpose`.
static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

// Element precision recorded per buffer (see `set_handle_precision`).
static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
// Storage layout recorded per buffer; buffers without an entry read back as
// `GpuTensorStorage::Real` (see `handle_storage`).
static HANDLE_STORAGES: Lazy<RwLock<HashMap<u64, GpuTensorStorage>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

/// Base matrix dimensions recorded for a handle marked as transposed.
///
/// Written by [`record_handle_transpose`] and read back through
/// [`handle_transpose_info`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// `base_rows` as passed to [`record_handle_transpose`] — presumably the row
    /// count of the untransposed source; confirm with callers.
    pub base_rows: usize,
    /// `base_cols` as passed to [`record_handle_transpose`] — presumably the
    /// column count of the untransposed source; confirm with callers.
    pub base_cols: usize,
}
42
43pub fn register_residency_mark(handler: ResidencyMarkFn) {
46 let _ = RESIDENCY_MARK.set(handler);
47}
48
49pub fn mark_residency(handle: &GpuTensorHandle) {
52 if let Some(handler) = RESIDENCY_MARK.get() {
53 handler(handle);
54 }
55}
56
57pub fn register_residency_clear(handler: ResidencyClearFn) {
61 let _ = RESIDENCY_CLEAR.set(handler);
62}
63
64pub fn clear_residency(handle: &GpuTensorHandle) {
67 if let Some(handler) = RESIDENCY_CLEAR.get() {
68 handler(handle);
69 }
70}
71
72pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
76 let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
77}
78
79pub fn sequence_threshold_hint() -> Option<usize> {
81 SEQUENCE_THRESHOLD_PROVIDER
82 .get()
83 .and_then(|provider| provider())
84}
85
86pub fn register_workgroup_size_hint_provider(provider: WorkgroupSizeHintFn) {
90 let _ = WORKGROUP_SIZE_HINT_PROVIDER.set(provider);
91}
92
93pub fn workgroup_size_hint() -> Option<u32> {
95 WORKGROUP_SIZE_HINT_PROVIDER
96 .get()
97 .and_then(|provider| provider())
98}
99
100pub fn export_context(kind: AccelContextKind) -> Option<AccelContextHandle> {
103 provider().and_then(|p| p.export_context(kind))
104}
105
106#[cfg(feature = "wgpu")]
110pub fn export_wgpu_buffer(handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
111 provider().and_then(|p| p.export_wgpu_buffer(handle))
112}
113
114pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
117 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
118 guard.insert(handle.buffer_id, precision);
119 }
120}
121
122pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
124 HANDLE_PRECISIONS
125 .read()
126 .ok()
127 .and_then(|guard| guard.get(&handle.buffer_id).copied())
128}
129
130pub fn clear_handle_precision(handle: &GpuTensorHandle) {
132 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
133 guard.remove(&handle.buffer_id);
134 }
135}
136
137pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
140 if let Ok(mut guard) = LOGICAL_HANDLES.write() {
141 if logical {
142 guard.insert(handle.buffer_id);
143 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
144 *hits.entry(handle.buffer_id).or_insert(0) += 1;
145 }
146 } else {
147 guard.remove(&handle.buffer_id);
148 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
149 hits.remove(&handle.buffer_id);
150 }
151 }
152 }
153}
154
155pub fn clear_handle_logical(handle: &GpuTensorHandle) {
157 set_handle_logical(handle, false);
158}
159
160pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
162 LOGICAL_HANDLES
163 .read()
164 .map(|guard| guard.contains(&handle.buffer_id))
165 .unwrap_or(false)
166}
167
168pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
169 LOGICAL_HANDLE_HITS
170 .read()
171 .ok()
172 .and_then(|guard| guard.get(&buffer_id).copied())
173}
174
175pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
176 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
177 guard.insert(
178 handle.buffer_id,
179 TransposeInfo {
180 base_rows,
181 base_cols,
182 },
183 );
184 }
185}
186
187pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
188 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
189 guard.remove(&handle.buffer_id);
190 }
191}
192
193pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
194 TRANSPOSED_HANDLES
195 .read()
196 .ok()
197 .and_then(|guard| guard.get(&handle.buffer_id).copied())
198}
199
200pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
201 handle_transpose_info(handle).is_some()
202}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
205pub enum GpuTensorStorage {
206 Real,
207 ComplexInterleaved,
208}
209
210impl Default for GpuTensorStorage {
211 fn default() -> Self {
212 Self::Real
213 }
214}
215
/// Handle to a tensor buffer managed by an [`AccelProvider`].
///
/// The module's metadata tables (precision, logical flag, transpose info,
/// storage layout) are keyed by `buffer_id`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuTensorHandle {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Identifier of the device holding the buffer.
    pub device_id: u32,
    /// Provider-assigned buffer identifier.
    pub buffer_id: u64,
}

/// Structured device description (see `AccelProvider::device_info_struct`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    pub device_id: u32,
    /// Human-readable device name.
    pub name: String,
    /// Vendor name; the default trait implementation leaves it empty.
    pub vendor: String,
    /// Device memory in bytes, when the provider reports it.
    pub memory_bytes: Option<u64>,
    /// Backend name, when the provider reports it.
    pub backend: Option<String>,
}

/// Values and indices produced by a dimension-wise reduction — presumably
/// min/max-style reductions that also report positions; confirm with callers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Result of a cumulative-minimum scan: running values and their indices
/// (presumably where each running minimum occurs; confirm with the provider).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// The cumulative-maximum result shares the `cummin` result layout.
pub type ProviderCummaxResult = ProviderCumminResult;
249
/// Kind of backend context a caller can request through [`export_context`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelContextKind {
    /// Context for plotting/rendering integration — presumably shares the
    /// provider's device with a renderer; confirm with providers.
    Plotting,
}

/// Backend-specific context exported by a provider.
///
/// Every variant is feature-gated: without the `wgpu` feature this enum has no
/// variants, so no value of it can exist.
#[derive(Clone)]
pub enum AccelContextHandle {
    /// Shared wgpu objects of a wgpu-backed provider.
    #[cfg(feature = "wgpu")]
    Wgpu(WgpuContextHandle),
}

impl AccelContextHandle {
    /// Returns the wgpu context carried by this handle.
    ///
    /// Currently always `Some`, because `Wgpu` is the only variant.
    #[cfg(feature = "wgpu")]
    pub fn as_wgpu(&self) -> Option<&WgpuContextHandle> {
        match self {
            AccelContextHandle::Wgpu(ctx) => Some(ctx),
        }
    }
}

/// wgpu objects shared by a wgpu-backed provider; carried by
/// `AccelContextHandle::Wgpu`.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuContextHandle {
    pub instance: Arc<wgpu::Instance>,
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub adapter: Arc<wgpu::Adapter>,
    /// Adapter description — presumably obtained from `adapter`; confirm.
    pub adapter_info: wgpu::AdapterInfo,
    /// Device limits as reported by the provider.
    pub limits: wgpu::Limits,
    /// Device features as reported by the provider.
    pub features: wgpu::Features,
}

/// wgpu buffer backing a tensor handle, as returned by `export_wgpu_buffer`.
#[cfg(feature = "wgpu")]
#[derive(Clone)]
pub struct WgpuBufferRef {
    /// The underlying GPU buffer.
    pub buffer: Arc<wgpu::Buffer>,
    /// Tensor length — presumably an element count rather than bytes; confirm.
    pub len: usize,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Size of one element — presumably in bytes; confirm.
    pub element_size: usize,
    /// Precision of the buffer's elements.
    pub precision: ProviderPrecision,
}
296
297pub fn set_handle_storage(handle: &GpuTensorHandle, storage: GpuTensorStorage) {
298 if let Ok(mut guard) = HANDLE_STORAGES.write() {
299 guard.insert(handle.buffer_id, storage);
300 }
301}
302
303pub fn handle_storage(handle: &GpuTensorHandle) -> GpuTensorStorage {
304 HANDLE_STORAGES
305 .read()
306 .ok()
307 .and_then(|guard| guard.get(&handle.buffer_id).cloned())
308 .unwrap_or(GpuTensorStorage::Real)
309}
310
311pub fn clear_handle_storage(handle: &GpuTensorHandle) {
312 if let Ok(mut guard) = HANDLE_STORAGES.write() {
313 guard.remove(&handle.buffer_id);
314 }
315}
316
/// Operation applied page-by-page by a [`PagefunRequest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    /// Matrix multiplication on each page (presumably MATLAB `mtimes`).
    Mtimes,
}

/// Batched ("paged") operation request — presumably modelled on MATLAB
/// `pagefun`; confirm with the provider implementation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    /// Operation to apply to each page.
    pub op: PagefunOp,
    /// Operand tensors.
    pub inputs: Vec<GpuTensorHandle>,
    /// Shape of the result tensor.
    pub output_shape: Vec<usize>,
    /// Page (batch) dimensions of the output — TODO confirm exact meaning.
    pub page_dims: Vec<usize>,
    /// Page dimensions per input; presumably index-aligned with `inputs`.
    pub input_page_dims: Vec<Vec<usize>>,
}

/// End from which a `find`-style search starts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}

/// Outputs of a provider `find`: linear indices, row/column subscripts and,
/// optionally, the matching values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    pub linear: GpuTensorHandle,
    pub rows: GpuTensorHandle,
    pub cols: GpuTensorHandle,
    /// Present only when the caller asked for values.
    pub values: Option<GpuTensorHandle>,
}
344
/// Lower and upper bandwidth of a matrix (presumably MATLAB `bandwidth`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    pub lower: u32,
    pub upper: u32,
}

/// Symmetry variant to test for (symmetric or skew-symmetric).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}

/// Hermitian variant to test for (Hermitian or skew-Hermitian).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}

/// Outputs of an LU factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    /// Combined factor storage — presumably L and U packed in one matrix.
    pub combined: GpuTensorHandle,
    pub lower: GpuTensorHandle,
    pub upper: GpuTensorHandle,
    /// Row permutation as a matrix.
    pub perm_matrix: GpuTensorHandle,
    /// Row permutation as an index vector.
    pub perm_vector: GpuTensorHandle,
}

/// Outputs of a Cholesky factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    pub factor: GpuTensorHandle,
    /// Status value — presumably 0 on success and a failure position otherwise
    /// (MATLAB `chol` `p` output); confirm with the provider.
    pub info: u32,
}

/// Outputs of a QR factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    /// Column permutation as a matrix.
    pub perm_matrix: GpuTensorHandle,
    /// Column permutation as an index vector.
    pub perm_vector: GpuTensorHandle,
}

/// QR outputs from a power-iteration routine; same layout as
/// [`ProviderQrResult`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}
394
/// Structure hints and flags for a provider linear solve — presumably mirroring
/// MATLAB `linsolve` `opts` fields; confirm with the provider implementation.
/// `Default` sets every flag to `false` and `rcond` to `None`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    /// Caller wants a reciprocal-condition estimate.
    pub need_rcond: bool,
    /// Optional caller-supplied reciprocal-condition value — TODO confirm use.
    pub rcond: Option<f64>,
}

/// Outputs of a provider linear solve.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    pub reciprocal_condition: f64,
}

/// Options for a pseudo-inverse.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    /// Singular-value tolerance; `None` presumably lets the provider choose.
    pub tolerance: Option<f64>,
}

/// Centering/scaling for `polyval` — presumably x is evaluated as
/// `(x - mean) / scale` (MATLAB `mu`); confirm with the provider.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}

/// Options for `polyval`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}

/// Options for matrix inversion; currently has no fields.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}

/// Host-side outputs of `polyfit` — presumably mirroring MATLAB's `p`, `S`
/// (`R`, `normr`, `df`) and `mu` outputs; confirm with the provider.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    pub coefficients: Vec<f64>,
    pub r_matrix: Vec<f64>,
    pub normr: f64,
    pub df: f64,
    pub mu: [f64; 2],
}
441
/// Numerator and denominator of the derivative of a polynomial quotient
/// (see `AccelProvider::polyder_quotient`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}

/// Norm used for a condition-number computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}

/// Norm order requested from a provider — presumably MATLAB `norm` orders.
/// `P(p)` carries an arbitrary order `p`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    Nuc,
    P(f64),
}

/// Outputs of an eigendecomposition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    pub eigenvalues: GpuTensorHandle,
    /// Eigenvalues arranged as a diagonal matrix — presumably; confirm.
    pub diagonal: GpuTensorHandle,
    /// Right eigenvectors.
    pub right: GpuTensorHandle,
    /// Left eigenvectors, when computed.
    pub left: Option<GpuTensorHandle>,
}

/// Form in which a pivoted QR reports its permutation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}

/// QR options. Defaults: full (non-economy) factorization, matrix permutation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    pub economy: bool,
    pub pivot: ProviderQrPivot,
}

impl Default for ProviderQrOptions {
    fn default() -> Self {
        Self {
            economy: false,
            pivot: ProviderQrPivot::Matrix,
        }
    }
}
499
/// Floating-point precision of provider buffers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}

/// Policy for using a two-pass reduction strategy: let the provider decide,
/// or force it on/off.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionTwoPassMode {
    Auto,
    ForceOn,
    ForceOff,
}
512
513impl ReductionTwoPassMode {
514 pub fn as_str(self) -> &'static str {
515 match self {
516 ReductionTwoPassMode::Auto => "auto",
517 ReductionTwoPassMode::ForceOn => "force_on",
518 ReductionTwoPassMode::ForceOff => "force_off",
519 }
520 }
521}
522
/// How a reduced sum is post-scaled: not at all (`Sum`), by `1 / len`
/// (`Mean`), or by an explicit factor (`CustomScale`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ReductionFlavor {
    Sum,
    Mean,
    CustomScale(f64),
}
529
530impl ReductionFlavor {
531 pub fn is_mean(self) -> bool {
532 matches!(self, ReductionFlavor::Mean)
533 }
534
535 pub fn scale(self, reduce_len: usize) -> f64 {
536 match self {
537 ReductionFlavor::Sum => 1.0,
538 ReductionFlavor::Mean => {
539 if reduce_len == 0 {
540 1.0
541 } else {
542 1.0 / reduce_len as f64
543 }
544 }
545 ReductionFlavor::CustomScale(scale) => scale,
546 }
547 }
548}
549
/// Normalization for `corrcoef` — presumably the N-1 (`Unbiased`) versus N
/// (`Biased`) divisor; confirm with the provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}

/// Row-selection policy for `corrcoef` — presumably NaN handling: use all
/// rows, only complete rows, or pairwise-complete rows; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}

/// Options for `AccelProvider::corrcoef`.
/// Defaults: `Unbiased` normalization, `All` rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}

impl Default for CorrcoefOptions {
    fn default() -> Self {
        Self {
            normalization: CorrcoefNormalization::Unbiased,
            rows: CorrcoefRows::All,
        }
    }
}

/// Normalization for `cov` — presumably the N-1 (`Unbiased`) versus N
/// (`Biased`) divisor; confirm with the provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}

/// Row-selection policy for `cov` — presumably MATLAB's `"includenan"`,
/// `"omitrows"` and `"partialrows"`; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}

/// Options for `AccelProvider::covariance`.
/// Defaults: `Unbiased` normalization, `All` rows, no weight vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    /// Whether a weight vector accompanies the request.
    pub has_weight_vector: bool,
}

impl Default for CovarianceOptions {
    fn default() -> Self {
        Self {
            normalization: CovNormalization::Unbiased,
            rows: CovRows::All,
            has_weight_vector: false,
        }
    }
}
613
/// Standard-deviation normalization: sample versus population.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    Sample,
    Population,
}

/// Whether NaN elements participate in a computation or are omitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}

/// Direction of a cumulative scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}

/// Sort direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}

/// Comparison method for sorting — presumably MATLAB's `ComparisonMethod`
/// (`auto`, `real`, `abs`); confirm with the provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    Auto,
    Real,
    Abs,
}

/// Host-side outputs of a sort: sorted values and their source indices.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}

/// One column key for `sortrows`: column index plus its sort direction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    pub index: usize,
    pub order: SortOrder,
}
662
/// Output ordering for `unique`: sorted, or in order of first appearance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}

/// Which occurrence of a repeated value `unique` reports indices for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}

/// Options for `unique`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    /// Treat each row as one element.
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}

/// Host-side outputs of `unique` — presumably MATLAB's `C`, `ia`, `ic`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ic: HostTensorOwned,
}

/// Output ordering for `union`: sorted, or in order of first appearance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}

/// Options for `union`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    /// Treat each row as one element.
    pub rows: bool,
    pub order: UnionOrder,
}

/// Host-side outputs of `union` — presumably MATLAB's `C`, `ia`, `ib`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ib: HostTensorOwned,
}
714
/// Predefined 2-D filter kinds with their parameters — presumably matching
/// MATLAB `fspecial` filter types; confirm with the provider.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    Average {
        rows: u32,
        cols: u32,
    },
    Disk {
        radius: f64,
        size: u32,
    },
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Laplacian {
        alpha: f64,
    },
    /// Laplacian of Gaussian.
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        oversample: u32,
    },
    Prewitt,
    Sobel,
    Unsharp {
        alpha: f64,
    },
}

/// Request passed to `AccelProvider::fspecial`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}

/// Boundary padding for `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    /// Pad with `ImfilterOptions::constant_value`.
    Constant,
    Replicate,
    Symmetric,
    Circular,
}

/// Output size for `imfilter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}

/// Whether `imfilter` correlates with the kernel or convolves with it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}

/// Options for `imfilter`.
/// Defaults: constant padding with `0.0`, `Same` shape, correlation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    /// Fill value used when `padding` is `Constant`.
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}

impl Default for ImfilterOptions {
    fn default() -> Self {
        Self {
            padding: ImfilterPadding::Constant,
            constant_value: 0.0,
            shape: ImfilterShape::Same,
            mode: ImfilterMode::Correlation,
        }
    }
}
801
/// Output ordering for `setdiff`: sorted, or in order of appearance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}

/// Options for `setdiff`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// Treat each row as one element.
    pub rows: bool,
    pub order: SetdiffOrder,
}

/// Host-side outputs of `setdiff` — presumably MATLAB's `C` and `ia`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}

/// Options for `ismember`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// Treat each row as one element.
    pub rows: bool,
}

/// Owned host-side logical array: one byte per element (presumably 0/1).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}

/// Host-side outputs of `ismember`: membership mask and matching locations.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}
842
/// Output size of a 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

/// Orientation of a 1-D convolution result (row or column vector).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options for a 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}

/// Options for IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Dimension to filter along — TODO confirm zero- vs one-based.
    pub dim: usize,
    /// Optional initial filter state.
    pub zi: Option<GpuTensorHandle>,
}

/// Outputs of IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    pub output: GpuTensorHandle,
    /// Final filter state, when produced.
    pub final_state: Option<GpuTensorHandle>,
}

/// First and second moments — presumably the mean and the mean of squares
/// (`ex2` = E[x²]); confirm with the provider.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    pub ex2: GpuTensorHandle,
}
883
/// Count and cumulative wall time of one category of dispatches.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of dispatches recorded.
    pub count: u64,
    /// Summed wall time of those dispatches, in nanoseconds.
    pub total_wall_time_ns: u64,
}

/// Number of fallbacks recorded for one reason.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderFallbackStat {
    pub reason: String,
    pub count: u64,
}

/// Snapshot of provider counters: dispatch stats per operation family,
/// transfer volumes, cache hit/miss counts and kernel launch records.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    pub linsolve: ProviderDispatchStats,
    pub mldivide: ProviderDispatchStats,
    pub mrdivide: ProviderDispatchStats,
    /// Bytes uploaded to the device.
    pub upload_bytes: u64,
    /// Bytes downloaded from the device.
    pub download_bytes: u64,
    /// Solver fallbacks grouped by reason.
    pub solve_fallbacks: Vec<ProviderFallbackStat>,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Per-layout bind-group cache counters, when the provider tracks them.
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    /// Recorded kernel launches.
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}

/// Bind-group cache hit/miss counts for one layout tag.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}

/// Named integer attribute attached to a kernel launch record.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}

/// One recorded kernel launch with its shape and tuning attributes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    pub kernel: String,
    pub precision: Option<String>,
    pub shape: Vec<KernelAttrTelemetry>,
    pub tuning: Vec<KernelAttrTelemetry>,
}
939
/// Boxed, non-`Send` future returned by asynchronous provider operations.
pub type AccelProviderFuture<'a, T> = Pin<Box<dyn Future<Output = anyhow::Result<T>> + 'a>>;
/// Future returned by `AccelProvider::download`, resolving to host data.
pub type AccelDownloadFuture<'a> = AccelProviderFuture<'a, crate::HostTensorOwned>;
942
943fn unsupported_future<T>(message: &'static str) -> AccelProviderFuture<'static, T> {
944 Box::pin(async move { Err(anyhow::anyhow!(message)) })
945}
946
947pub trait AccelProvider: Send + Sync {
949 fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
950 fn download<'a>(&'a self, h: &'a GpuTensorHandle) -> AccelDownloadFuture<'a>;
951 fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
952 fn device_info(&self) -> String;
953 fn device_id(&self) -> u32 {
954 0
955 }
956
957 fn export_context(&self, _kind: AccelContextKind) -> Option<AccelContextHandle> {
960 None
961 }
962
963 #[cfg(feature = "wgpu")]
965 fn export_wgpu_buffer(&self, _handle: &GpuTensorHandle) -> Option<WgpuBufferRef> {
966 let _ = _handle;
967 None
968 }
969
970 fn gather_linear(
973 &self,
974 _source: &GpuTensorHandle,
975 _indices: &[u32],
976 _output_shape: &[usize],
977 ) -> anyhow::Result<GpuTensorHandle> {
978 Err(anyhow::anyhow!("gather_linear not supported by provider"))
979 }
980
981 fn scatter_linear(
985 &self,
986 _target: &GpuTensorHandle,
987 _indices: &[u32],
988 _values: &GpuTensorHandle,
989 ) -> anyhow::Result<()> {
990 Err(anyhow::anyhow!("scatter_linear not supported by provider"))
991 }
992
993 fn device_info_struct(&self) -> ApiDeviceInfo {
995 ApiDeviceInfo {
996 device_id: 0,
997 name: self.device_info(),
998 vendor: String::new(),
999 memory_bytes: None,
1000 backend: None,
1001 }
1002 }
1003
1004 fn precision(&self) -> ProviderPrecision {
1005 ProviderPrecision::F64
1006 }
1007
1008 fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
1010 Err(anyhow::anyhow!("read_scalar not supported by provider"))
1011 }
1012
1013 fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1015 Err(anyhow::anyhow!("zeros not supported by provider"))
1016 }
1017
1018 fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1020 Err(anyhow::anyhow!("ones not supported by provider"))
1021 }
1022
1023 fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1025 self.zeros(&prototype.shape)
1026 }
1027
1028 fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
1030 if value == 0.0 {
1031 return self.zeros(shape);
1032 }
1033 if let Ok(base) = self.zeros(shape) {
1034 match self.scalar_add(&base, value) {
1035 Ok(out) => {
1036 let _ = self.free(&base);
1037 return Ok(out);
1038 }
1039 Err(_) => {
1040 let _ = self.free(&base);
1041 }
1042 }
1043 }
1044 let len: usize = shape.iter().copied().product();
1045 let data = vec![value; len];
1046 let view = HostTensorView { data: &data, shape };
1047 self.upload(&view)
1048 }
1049
1050 fn fill_like(
1052 &self,
1053 prototype: &GpuTensorHandle,
1054 value: f64,
1055 ) -> anyhow::Result<GpuTensorHandle> {
1056 if value == 0.0 {
1057 return self.zeros_like(prototype);
1058 }
1059 if let Ok(base) = self.zeros_like(prototype) {
1060 match self.scalar_add(&base, value) {
1061 Ok(out) => {
1062 let _ = self.free(&base);
1063 return Ok(out);
1064 }
1065 Err(_) => {
1066 let _ = self.free(&base);
1067 }
1068 }
1069 }
1070 self.fill(&prototype.shape, value)
1071 }
1072
1073 fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1075 self.ones(&prototype.shape)
1076 }
1077
1078 fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1080 Err(anyhow::anyhow!("eye not supported by provider"))
1081 }
1082
1083 fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1085 self.eye(&prototype.shape)
1086 }
1087
1088 fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
1090 Err(anyhow::anyhow!("meshgrid not supported by provider"))
1091 }
1092
1093 fn diag_from_vector(
1095 &self,
1096 _vector: &GpuTensorHandle,
1097 _offset: isize,
1098 ) -> anyhow::Result<GpuTensorHandle> {
1099 Err(anyhow::anyhow!(
1100 "diag_from_vector not supported by provider"
1101 ))
1102 }
1103
1104 fn diag_extract(
1106 &self,
1107 _matrix: &GpuTensorHandle,
1108 _offset: isize,
1109 ) -> anyhow::Result<GpuTensorHandle> {
1110 Err(anyhow::anyhow!("diag_extract not supported by provider"))
1111 }
1112
1113 fn tril<'a>(
1115 &'a self,
1116 _matrix: &'a GpuTensorHandle,
1117 _offset: isize,
1118 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1119 Box::pin(async move { Err(anyhow!("tril not supported by provider")) })
1120 }
1121
1122 fn triu<'a>(
1124 &'a self,
1125 _matrix: &'a GpuTensorHandle,
1126 _offset: isize,
1127 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1128 Box::pin(async move { Err(anyhow!("triu not supported by provider")) })
1129 }
1130
1131 fn polyval(
1133 &self,
1134 _coefficients: &GpuTensorHandle,
1135 _points: &GpuTensorHandle,
1136 _options: &ProviderPolyvalOptions,
1137 ) -> anyhow::Result<GpuTensorHandle> {
1138 Err(anyhow::anyhow!("polyval not supported by provider"))
1139 }
1140
1141 fn polyfit<'a>(
1143 &'a self,
1144 _x: &'a GpuTensorHandle,
1145 _y: &'a GpuTensorHandle,
1146 _degree: usize,
1147 _weights: Option<&'a GpuTensorHandle>,
1148 ) -> AccelProviderFuture<'a, ProviderPolyfitResult> {
1149 Box::pin(async move { Err(anyhow::anyhow!("polyfit not supported by provider")) })
1150 }
1151
1152 fn polyder_single<'a>(
1154 &'a self,
1155 _polynomial: &'a GpuTensorHandle,
1156 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1157 Box::pin(async move { Err(anyhow::anyhow!("polyder_single not supported by provider")) })
1158 }
1159
1160 fn polyder_product<'a>(
1162 &'a self,
1163 _p: &'a GpuTensorHandle,
1164 _q: &'a GpuTensorHandle,
1165 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1166 Box::pin(async move { Err(anyhow::anyhow!("polyder_product not supported by provider")) })
1167 }
1168
1169 fn polyder_quotient<'a>(
1171 &'a self,
1172 _u: &'a GpuTensorHandle,
1173 _v: &'a GpuTensorHandle,
1174 ) -> AccelProviderFuture<'a, ProviderPolyderQuotient> {
1175 Box::pin(async move {
1176 Err(anyhow::anyhow!(
1177 "polyder_quotient not supported by provider"
1178 ))
1179 })
1180 }
1181
1182 fn polyint(
1184 &self,
1185 _polynomial: &GpuTensorHandle,
1186 _constant: f64,
1187 ) -> anyhow::Result<GpuTensorHandle> {
1188 Err(anyhow::anyhow!("polyint not supported by provider"))
1189 }
1190
1191 fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1193 Err(anyhow::anyhow!("random_uniform not supported by provider"))
1194 }
1195
1196 fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1198 self.random_uniform(&prototype.shape)
1199 }
1200
1201 fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1203 Err(anyhow::anyhow!("random_normal not supported by provider"))
1204 }
1205
1206 fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1208 self.random_normal(&prototype.shape)
1209 }
1210
1211 fn random_exponential(&self, _mu: f64, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1213 Err(anyhow::anyhow!(
1214 "random_exponential not supported by provider"
1215 ))
1216 }
1217
1218 fn random_normrnd(
1220 &self,
1221 _mu: f64,
1222 _sigma: f64,
1223 _shape: &[usize],
1224 ) -> anyhow::Result<GpuTensorHandle> {
1225 Err(anyhow::anyhow!("random_normrnd not supported by provider"))
1226 }
1227
1228 fn random_unifrnd(
1230 &self,
1231 _a: f64,
1232 _b: f64,
1233 _shape: &[usize],
1234 ) -> anyhow::Result<GpuTensorHandle> {
1235 Err(anyhow::anyhow!("random_unifrnd not supported by provider"))
1236 }
1237
1238 fn stochastic_evolution(
1239 &self,
1240 _state: &GpuTensorHandle,
1241 _drift: f64,
1242 _scale: f64,
1243 _steps: u32,
1244 ) -> anyhow::Result<GpuTensorHandle> {
1245 Err(anyhow::anyhow!(
1246 "stochastic_evolution not supported by provider"
1247 ))
1248 }
1249
1250 fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1252 Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1253 }
1254
1255 fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1257 Err(anyhow::anyhow!("fspecial not supported by provider"))
1258 }
1259
1260 fn peaks(&self, _n: usize) -> anyhow::Result<GpuTensorHandle> {
1263 Err(anyhow::anyhow!("peaks not supported by provider"))
1264 }
1265
1266 fn peaks_xy(
1269 &self,
1270 _x: &GpuTensorHandle,
1271 _y: &GpuTensorHandle,
1272 ) -> anyhow::Result<GpuTensorHandle> {
1273 Err(anyhow::anyhow!("peaks_xy not supported by provider"))
1274 }
1275
1276 fn hann_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1277 Err(anyhow::anyhow!("hann_window not supported by provider"))
1278 }
1279
1280 fn hamming_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1281 Err(anyhow::anyhow!("hamming_window not supported by provider"))
1282 }
1283
1284 fn blackman_window(&self, _len: usize, _periodic: bool) -> anyhow::Result<GpuTensorHandle> {
1285 Err(anyhow::anyhow!("blackman_window not supported by provider"))
1286 }
1287
1288 fn imfilter<'a>(
1290 &'a self,
1291 _image: &'a GpuTensorHandle,
1292 _kernel: &'a GpuTensorHandle,
1293 _options: &'a ImfilterOptions,
1294 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1295 unsupported_future("imfilter not supported by provider")
1296 }
1297
1298 fn random_integer_range(
1300 &self,
1301 _lower: i64,
1302 _upper: i64,
1303 _shape: &[usize],
1304 ) -> anyhow::Result<GpuTensorHandle> {
1305 Err(anyhow::anyhow!(
1306 "random_integer_range not supported by provider"
1307 ))
1308 }
1309
1310 fn random_integer_like(
1312 &self,
1313 prototype: &GpuTensorHandle,
1314 lower: i64,
1315 upper: i64,
1316 ) -> anyhow::Result<GpuTensorHandle> {
1317 self.random_integer_range(lower, upper, &prototype.shape)
1318 }
1319
1320 fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1322 Err(anyhow!("random_permutation not supported by provider"))
1323 }
1324
1325 fn random_permutation_like(
1327 &self,
1328 _prototype: &GpuTensorHandle,
1329 n: usize,
1330 k: usize,
1331 ) -> anyhow::Result<GpuTensorHandle> {
1332 self.random_permutation(n, k)
1333 }
1334
    /// Default implementation: returns `unsupported_future`, reporting `covariance`
    /// as unsupported. The optional `_second` and `_weights` inputs are not read.
    fn covariance<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _second: Option<&'a GpuTensorHandle>,
        _weights: Option<&'a GpuTensorHandle>,
        _options: &'a CovarianceOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("covariance not supported by provider")
    }

    /// Default implementation: returns `unsupported_future`, reporting `corrcoef`
    /// as unsupported.
    fn corrcoef<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: &'a CorrcoefOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("corrcoef not supported by provider")
    }
1354
1355 fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1357 Err(anyhow::anyhow!("linspace not supported by provider"))
1358 }
1359 fn elem_add<'a>(
1360 &'a self,
1361 _a: &'a GpuTensorHandle,
1362 _b: &'a GpuTensorHandle,
1363 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1364 unsupported_future("elem_add not supported by provider")
1365 }
1366 fn elem_mul<'a>(
1367 &'a self,
1368 _a: &'a GpuTensorHandle,
1369 _b: &'a GpuTensorHandle,
1370 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1371 unsupported_future("elem_mul not supported by provider")
1372 }
1373 fn elem_max<'a>(
1374 &'a self,
1375 _a: &'a GpuTensorHandle,
1376 _b: &'a GpuTensorHandle,
1377 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1378 unsupported_future("elem_max not supported by provider")
1379 }
1380 fn elem_min<'a>(
1381 &'a self,
1382 _a: &'a GpuTensorHandle,
1383 _b: &'a GpuTensorHandle,
1384 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1385 unsupported_future("elem_min not supported by provider")
1386 }
1387 fn elem_sub<'a>(
1388 &'a self,
1389 _a: &'a GpuTensorHandle,
1390 _b: &'a GpuTensorHandle,
1391 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1392 unsupported_future("elem_sub not supported by provider")
1393 }
1394 fn elem_div<'a>(
1395 &'a self,
1396 _a: &'a GpuTensorHandle,
1397 _b: &'a GpuTensorHandle,
1398 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1399 unsupported_future("elem_div not supported by provider")
1400 }
1401 fn elem_pow<'a>(
1402 &'a self,
1403 _a: &'a GpuTensorHandle,
1404 _b: &'a GpuTensorHandle,
1405 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1406 unsupported_future("elem_pow not supported by provider")
1407 }
1408
1409 fn elem_hypot<'a>(
1410 &'a self,
1411 _a: &'a GpuTensorHandle,
1412 _b: &'a GpuTensorHandle,
1413 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1414 unsupported_future("elem_hypot not supported by provider")
1415 }
1416 fn elem_ge<'a>(
1417 &'a self,
1418 _a: &'a GpuTensorHandle,
1419 _b: &'a GpuTensorHandle,
1420 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1421 unsupported_future("elem_ge not supported by provider")
1422 }
1423 fn elem_le<'a>(
1424 &'a self,
1425 _a: &'a GpuTensorHandle,
1426 _b: &'a GpuTensorHandle,
1427 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1428 unsupported_future("elem_le not supported by provider")
1429 }
1430 fn elem_lt<'a>(
1431 &'a self,
1432 _a: &'a GpuTensorHandle,
1433 _b: &'a GpuTensorHandle,
1434 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1435 unsupported_future("elem_lt not supported by provider")
1436 }
1437 fn elem_gt<'a>(
1438 &'a self,
1439 _a: &'a GpuTensorHandle,
1440 _b: &'a GpuTensorHandle,
1441 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1442 unsupported_future("elem_gt not supported by provider")
1443 }
1444 fn elem_eq<'a>(
1445 &'a self,
1446 _a: &'a GpuTensorHandle,
1447 _b: &'a GpuTensorHandle,
1448 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1449 unsupported_future("elem_eq not supported by provider")
1450 }
1451 fn elem_ne<'a>(
1452 &'a self,
1453 _a: &'a GpuTensorHandle,
1454 _b: &'a GpuTensorHandle,
1455 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1456 unsupported_future("elem_ne not supported by provider")
1457 }
1458 fn logical_and(
1459 &self,
1460 _a: &GpuTensorHandle,
1461 _b: &GpuTensorHandle,
1462 ) -> anyhow::Result<GpuTensorHandle> {
1463 Err(anyhow::anyhow!("logical_and not supported by provider"))
1464 }
1465 fn logical_or(
1466 &self,
1467 _a: &GpuTensorHandle,
1468 _b: &GpuTensorHandle,
1469 ) -> anyhow::Result<GpuTensorHandle> {
1470 Err(anyhow::anyhow!("logical_or not supported by provider"))
1471 }
1472 fn logical_xor(
1473 &self,
1474 _a: &GpuTensorHandle,
1475 _b: &GpuTensorHandle,
1476 ) -> anyhow::Result<GpuTensorHandle> {
1477 Err(anyhow::anyhow!("logical_xor not supported by provider"))
1478 }
1479 fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1480 Err(anyhow::anyhow!("logical_not not supported by provider"))
1481 }
1482 fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
1483 Ok(handle_is_logical(a))
1484 }
1485 fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1486 Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1487 }
1488 fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1489 Err(anyhow::anyhow!(
1490 "logical_isfinite not supported by provider"
1491 ))
1492 }
1493 fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1494 Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1495 }
1496 fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1497 Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1498 }
    /// Default: reports `elem_atan2` as unsupported via `unsupported_future`.
    fn elem_atan2<'a>(
        &'a self,
        _y: &'a GpuTensorHandle,
        _x: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("elem_atan2 not supported by provider")
    }

    // Unary element-wise operations. Every default below resolves to an error
    // through `unsupported_future`; providers with device kernels override them.

    /// Default: reports `unary_sin` as unsupported.
    fn unary_sin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sin not supported by provider")
    }

    /// Default: reports `unary_sinc` as unsupported.
    fn unary_sinc<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sinc not supported by provider")
    }

    /// Default: reports `unary_gamma` as unsupported.
    fn unary_gamma<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_gamma not supported by provider")
    }

    /// Default: reports `unary_factorial` as unsupported.
    fn unary_factorial<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_factorial not supported by provider")
    }

    /// Default: reports `unary_asinh` as unsupported.
    fn unary_asinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asinh not supported by provider")
    }

    /// Default: reports `unary_sinh` as unsupported.
    fn unary_sinh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sinh not supported by provider")
    }

    /// Default: reports `unary_cosh` as unsupported.
    fn unary_cosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cosh not supported by provider")
    }

    /// Default: reports `unary_asin` as unsupported.
    fn unary_asin<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_asin not supported by provider")
    }

    /// Default: reports `unary_acos` as unsupported.
    fn unary_acos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acos not supported by provider")
    }

    /// Default: reports `unary_acosh` as unsupported.
    fn unary_acosh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_acosh not supported by provider")
    }

    /// Default: reports `unary_tan` as unsupported.
    fn unary_tan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tan not supported by provider")
    }

    /// Default: reports `unary_tanh` as unsupported.
    fn unary_tanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_tanh not supported by provider")
    }

    /// Default: reports `unary_atan` as unsupported.
    fn unary_atan<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atan not supported by provider")
    }

    /// Default: reports `unary_atanh` as unsupported.
    fn unary_atanh<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_atanh not supported by provider")
    }

    /// Default: reports `unary_ceil` as unsupported.
    fn unary_ceil<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_ceil not supported by provider")
    }

    /// Default: reports `unary_floor` as unsupported.
    fn unary_floor<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_floor not supported by provider")
    }

    /// Default: reports `unary_round` as unsupported.
    fn unary_round<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_round not supported by provider")
    }

    /// Default: reports `unary_fix` as unsupported.
    fn unary_fix<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_fix not supported by provider")
    }

    /// Default: reports `unary_cos` as unsupported.
    fn unary_cos<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_cos not supported by provider")
    }

    /// Default: reports `unary_angle` as unsupported.
    fn unary_angle<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_angle not supported by provider")
    }

    /// Default: reports `unary_imag` as unsupported.
    fn unary_imag<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_imag not supported by provider")
    }

    /// Default: reports `unary_real` as unsupported.
    fn unary_real<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_real not supported by provider")
    }

    /// Default: reports `unary_conj` as unsupported.
    fn unary_conj<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_conj not supported by provider")
    }

    /// Default: reports `unary_abs` as unsupported.
    fn unary_abs<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_abs not supported by provider")
    }

    /// Default: reports `unary_sign` as unsupported.
    fn unary_sign<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sign not supported by provider")
    }

    /// Default: reports `unary_exp` as unsupported.
    fn unary_exp<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_exp not supported by provider")
    }

    /// Default: reports `unary_expm1` as unsupported.
    fn unary_expm1<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_expm1 not supported by provider")
    }

    /// Default: reports `unary_log` as unsupported.
    fn unary_log<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log not supported by provider")
    }

    /// Default: reports `unary_log2` as unsupported.
    fn unary_log2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log2 not supported by provider")
    }

    /// Default: reports `unary_log10` as unsupported.
    fn unary_log10<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log10 not supported by provider")
    }

    /// Default: reports `unary_log1p` as unsupported.
    fn unary_log1p<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_log1p not supported by provider")
    }

    /// Default: reports `unary_sqrt` as unsupported.
    fn unary_sqrt<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_sqrt not supported by provider")
    }

    /// Default: reports `unary_double` as unsupported.
    fn unary_double<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_double not supported by provider")
    }

    /// Default: reports `unary_single` as unsupported.
    fn unary_single<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_single not supported by provider")
    }

    /// Default: reports `unary_pow2` as unsupported.
    fn unary_pow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_pow2 not supported by provider")
    }

    /// Default: reports `unary_nextpow2` as unsupported.
    fn unary_nextpow2<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("unary_nextpow2 not supported by provider")
    }
1723 fn pow2_scale(
1724 &self,
1725 _mantissa: &GpuTensorHandle,
1726 _exponent: &GpuTensorHandle,
1727 ) -> anyhow::Result<GpuTensorHandle> {
1728 Err(anyhow::anyhow!("pow2_scale not supported by provider"))
1729 }
1730 fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1732 Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
1733 }
1734 fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1735 Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
1736 }
1737 fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1739 Err(anyhow::anyhow!("scalar_add not supported by provider"))
1740 }
1741 fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1742 Err(anyhow::anyhow!("scalar_sub not supported by provider"))
1743 }
1744 fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1745 Err(anyhow::anyhow!("scalar_mul not supported by provider"))
1746 }
1747 fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1748 Err(anyhow::anyhow!("scalar_max not supported by provider"))
1749 }
1750 fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1751 Err(anyhow::anyhow!("scalar_min not supported by provider"))
1752 }
1753 fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
1754 Err(anyhow::anyhow!("scalar_div not supported by provider"))
1755 }
    /// Default: reports `sort_dim` as unsupported via `unsupported_future`.
    fn sort_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _order: SortOrder,
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_dim not supported by provider")
    }

    /// Default: reports `sort_rows` as unsupported via `unsupported_future`.
    fn sort_rows<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _columns: &'a [SortRowsColumnSpec],
        _comparison: SortComparison,
    ) -> AccelProviderFuture<'a, SortResult> {
        unsupported_future("sort_rows not supported by provider")
    }

    /// Default: reports `matmul` as unsupported via `unsupported_future`.
    /// The default `matmul_epilogue` calls this for no-op epilogues.
    fn matmul<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _b: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul not supported by provider")
    }
1780
1781 fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1782 Err(anyhow::anyhow!("syrk not supported by provider"))
1783 }
1784 fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
1785 Err(anyhow::anyhow!("pagefun not supported by provider"))
1786 }
1787
1788 fn matmul_epilogue<'a>(
1793 &'a self,
1794 a: &'a GpuTensorHandle,
1795 b: &'a GpuTensorHandle,
1796 epilogue: &'a MatmulEpilogue,
1797 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1798 Box::pin(async move {
1799 if epilogue.is_noop() {
1800 return self.matmul(a, b).await;
1801 }
1802 Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1803 })
1804 }
    /// Default: reports image-normalize fusion as unsupported via `unsupported_future`.
    fn image_normalize<'a>(
        &'a self,
        _input: &'a GpuTensorHandle,
        _desc: &'a ImageNormalizeDescriptor,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("image_normalize fusion not supported by provider")
    }

    /// Default: reports `matmul_power_step` as unsupported via `unsupported_future`.
    fn matmul_power_step<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _epilogue: &'a PowerStepEpilogue,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("matmul_power_step normalization not supported by provider")
    }

    /// Default: reports `linsolve` as unsupported via `unsupported_future`.
    fn linsolve<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _options: &'a ProviderLinsolveOptions,
    ) -> AccelProviderFuture<'a, ProviderLinsolveResult> {
        unsupported_future("linsolve not supported by provider")
    }

    /// Default: reports `inv` as unsupported via `unsupported_future`.
    fn inv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderInvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("inv not supported by provider")
    }

    /// Default: reports `pinv` as unsupported via `unsupported_future`.
    fn pinv<'a>(
        &'a self,
        _matrix: &'a GpuTensorHandle,
        _options: ProviderPinvOptions,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("pinv not supported by provider")
    }
1842 fn cond<'a>(
1843 &'a self,
1844 _matrix: &'a GpuTensorHandle,
1845 _norm: ProviderCondNorm,
1846 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1847 Box::pin(async move { Err(anyhow::anyhow!("cond not supported by provider")) })
1848 }
1849 fn norm<'a>(
1850 &'a self,
1851 _tensor: &'a GpuTensorHandle,
1852 _order: ProviderNormOrder,
1853 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1854 Box::pin(async move { Err(anyhow::anyhow!("norm not supported by provider")) })
1855 }
1856 fn rank<'a>(
1857 &'a self,
1858 _matrix: &'a GpuTensorHandle,
1859 _tolerance: Option<f64>,
1860 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1861 Box::pin(async move { Err(anyhow::anyhow!("rank not supported by provider")) })
1862 }
1863 fn rcond<'a>(
1864 &'a self,
1865 _matrix: &'a GpuTensorHandle,
1866 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1867 Box::pin(async move { Err(anyhow::anyhow!("rcond not supported by provider")) })
1868 }
1869 fn mldivide<'a>(
1870 &'a self,
1871 _lhs: &'a GpuTensorHandle,
1872 _rhs: &'a GpuTensorHandle,
1873 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1874 Box::pin(async move { Err(anyhow::anyhow!("mldivide not supported by provider")) })
1875 }
1876 fn mrdivide<'a>(
1877 &'a self,
1878 _lhs: &'a GpuTensorHandle,
1879 _rhs: &'a GpuTensorHandle,
1880 ) -> AccelProviderFuture<'a, GpuTensorHandle> {
1881 Box::pin(async move { Err(anyhow::anyhow!("mrdivide not supported by provider")) })
1882 }
1883 fn eig<'a>(
1884 &'a self,
1885 _a: &'a GpuTensorHandle,
1886 _compute_left: bool,
1887 ) -> AccelProviderFuture<'a, ProviderEigResult> {
1888 Box::pin(async move { Err(anyhow::anyhow!("eig not supported by provider")) })
1889 }
1890 fn lu<'a>(&'a self, _a: &'a GpuTensorHandle) -> AccelProviderFuture<'a, ProviderLuResult> {
1891 Box::pin(async move { Err(anyhow::anyhow!("lu not supported by provider")) })
1892 }
1893
1894 fn chol<'a>(
1895 &'a self,
1896 _a: &'a GpuTensorHandle,
1897 _lower: bool,
1898 ) -> AccelProviderFuture<'a, ProviderCholResult> {
1899 Box::pin(async move { Err(anyhow::anyhow!("chol not supported by provider")) })
1900 }
1901 fn qr<'a>(
1902 &'a self,
1903 _a: &'a GpuTensorHandle,
1904 _options: ProviderQrOptions,
1905 ) -> AccelProviderFuture<'a, ProviderQrResult> {
1906 Box::pin(async move { Err(anyhow::anyhow!("qr not supported by provider")) })
1907 }
1908 fn take_matmul_sources(
1909 &self,
1910 _product: &GpuTensorHandle,
1911 ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
1912 None
1913 }
1914 fn qr_power_iter<'a>(
1915 &'a self,
1916 product: &'a GpuTensorHandle,
1917 _product_lhs: Option<&'a GpuTensorHandle>,
1918 q_handle: &'a GpuTensorHandle,
1919 options: &'a ProviderQrOptions,
1920 ) -> AccelProviderFuture<'a, Option<ProviderQrPowerIterResult>> {
1921 let _ = (product, q_handle, options);
1922 Box::pin(async move { Ok(None) })
1923 }
1924 fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1925 Err(anyhow::anyhow!("transpose not supported by provider"))
1926 }
1927 fn conv1d(
1928 &self,
1929 _signal: &GpuTensorHandle,
1930 _kernel: &GpuTensorHandle,
1931 _options: ProviderConv1dOptions,
1932 ) -> anyhow::Result<GpuTensorHandle> {
1933 Err(anyhow::anyhow!("conv1d not supported by provider"))
1934 }
1935 fn conv2d(
1936 &self,
1937 _signal: &GpuTensorHandle,
1938 _kernel: &GpuTensorHandle,
1939 _mode: ProviderConvMode,
1940 ) -> anyhow::Result<GpuTensorHandle> {
1941 Err(anyhow::anyhow!("conv2d not supported by provider"))
1942 }
1943 fn iir_filter<'a>(
1944 &'a self,
1945 _b: &'a GpuTensorHandle,
1946 _a: &'a GpuTensorHandle,
1947 _x: &'a GpuTensorHandle,
1948 _options: ProviderIirFilterOptions,
1949 ) -> AccelProviderFuture<'a, ProviderIirFilterResult> {
1950 Box::pin(async move { Err(anyhow::anyhow!("iir_filter not supported by provider")) })
1951 }
1952 fn permute(
1954 &self,
1955 _handle: &GpuTensorHandle,
1956 _order: &[usize],
1957 ) -> anyhow::Result<GpuTensorHandle> {
1958 Err(anyhow::anyhow!("permute not supported by provider"))
1959 }
1960 fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1961 Err(anyhow::anyhow!("flip not supported by provider"))
1962 }
1963 fn circshift(
1964 &self,
1965 _handle: &GpuTensorHandle,
1966 _shifts: &[isize],
1967 ) -> anyhow::Result<GpuTensorHandle> {
1968 Err(anyhow::anyhow!("circshift not supported by provider"))
1969 }
1970 fn diff_dim(
1971 &self,
1972 _handle: &GpuTensorHandle,
1973 _order: usize,
1974 _dim: usize,
1975 ) -> anyhow::Result<GpuTensorHandle> {
1976 Err(anyhow::anyhow!("diff_dim not supported by provider"))
1977 }
1978 fn gradient_dim(
1979 &self,
1980 _handle: &GpuTensorHandle,
1981 _dim: usize,
1982 _spacing: f64,
1983 ) -> anyhow::Result<GpuTensorHandle> {
1984 Err(anyhow::anyhow!("gradient_dim not supported by provider"))
1985 }
    /// Default: reports `fft_dim` as unsupported via `unsupported_future`.
    fn fft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_dim not supported by provider")
    }

    /// Default: reports `ifft_dim` as unsupported via `unsupported_future`.
    fn ifft_dim<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("ifft_dim not supported by provider")
    }

    /// Default: reports `fft_extract_real` as unsupported via `unsupported_future`.
    fn fft_extract_real<'a>(
        &'a self,
        _handle: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("fft_extract_real not supported by provider")
    }
2009 fn unique<'a>(
2010 &'a self,
2011 _handle: &'a GpuTensorHandle,
2012 _options: &'a UniqueOptions,
2013 ) -> AccelProviderFuture<'a, UniqueResult> {
2014 Box::pin(async move { Err(anyhow::anyhow!("unique not supported by provider")) })
2015 }
2016 fn union<'a>(
2017 &'a self,
2018 _a: &'a GpuTensorHandle,
2019 _b: &'a GpuTensorHandle,
2020 _options: &'a UnionOptions,
2021 ) -> AccelProviderFuture<'a, UnionResult> {
2022 Box::pin(async move { Err(anyhow::anyhow!("union not supported by provider")) })
2023 }
2024 fn setdiff<'a>(
2025 &'a self,
2026 _a: &'a GpuTensorHandle,
2027 _b: &'a GpuTensorHandle,
2028 _options: &'a SetdiffOptions,
2029 ) -> AccelProviderFuture<'a, SetdiffResult> {
2030 Box::pin(async move { Err(anyhow::anyhow!("setdiff not supported by provider")) })
2031 }
2032 fn ismember<'a>(
2033 &'a self,
2034 _a: &'a GpuTensorHandle,
2035 _b: &'a GpuTensorHandle,
2036 _options: &'a IsMemberOptions,
2037 ) -> AccelProviderFuture<'a, IsMemberResult> {
2038 Box::pin(async move { Err(anyhow::anyhow!("ismember not supported by provider")) })
2039 }
2040 fn reshape(
2041 &self,
2042 handle: &GpuTensorHandle,
2043 new_shape: &[usize],
2044 ) -> anyhow::Result<GpuTensorHandle> {
2045 let mut updated = handle.clone();
2046 updated.shape = new_shape.to_vec();
2047 Ok(updated)
2048 }
2049 fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
2051 Err(anyhow::anyhow!("cat not supported by provider"))
2052 }
2053 fn repmat(
2054 &self,
2055 _handle: &GpuTensorHandle,
2056 _reps: &[usize],
2057 ) -> anyhow::Result<GpuTensorHandle> {
2058 Err(anyhow::anyhow!("repmat not supported by provider"))
2059 }
2060 fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2062 Err(anyhow::anyhow!("kron not supported by provider"))
2063 }
2064 fn cross(
2066 &self,
2067 _lhs: &GpuTensorHandle,
2068 _rhs: &GpuTensorHandle,
2069 _dim: Option<usize>,
2070 ) -> anyhow::Result<GpuTensorHandle> {
2071 Err(anyhow::anyhow!("cross not supported by provider"))
2072 }
    // Reductions. Every default below resolves to an error through
    // `unsupported_future`; providers with reduction kernels override them.

    /// Default: reports `reduce_sum` as unsupported.
    fn reduce_sum<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum not supported by provider")
    }

    /// Default: reports `reduce_sum_dim` as unsupported.
    fn reduce_sum_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_sum_dim not supported by provider")
    }

    /// Default: reports `dot` as unsupported.
    fn dot<'a>(
        &'a self,
        _lhs: &'a GpuTensorHandle,
        _rhs: &'a GpuTensorHandle,
        _dim: Option<usize>,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("dot not supported by provider")
    }

    /// Default: reports `reduce_nnz` as unsupported.
    fn reduce_nnz<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz not supported by provider")
    }

    /// Default: reports `reduce_nnz_dim` as unsupported.
    fn reduce_nnz_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_nnz_dim not supported by provider")
    }

    /// Default: reports `reduce_prod` as unsupported.
    fn reduce_prod<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod not supported by provider")
    }

    /// Default: reports `reduce_prod_dim` as unsupported.
    fn reduce_prod_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_prod_dim not supported by provider")
    }

    /// Default: reports `reduce_mean` as unsupported.
    fn reduce_mean<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean not supported by provider")
    }

    /// Default: reports `reduce_mean_nd` as unsupported.
    /// `_dims_zero_based` is, per its name, a list of 0-based dimensions.
    fn reduce_mean_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_nd not supported by provider")
    }

    /// Default: reports `reduce_moments_nd` as unsupported.
    fn reduce_moments_nd<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dims_zero_based: &'a [usize],
    ) -> AccelProviderFuture<'a, ProviderMoments2> {
        unsupported_future("reduce_moments_nd not supported by provider")
    }

    /// Default: reports `reduce_mean_dim` as unsupported.
    fn reduce_mean_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_mean_dim not supported by provider")
    }

    /// Default: reports `reduce_std` as unsupported.
    fn reduce_std<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std not supported by provider")
    }

    /// Default: reports `reduce_std_dim` as unsupported.
    fn reduce_std_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_std_dim not supported by provider")
    }

    /// Default: reports `reduce_any` as unsupported.
    fn reduce_any<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any not supported by provider")
    }

    /// Default: reports `reduce_any_dim` as unsupported.
    fn reduce_any_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_any_dim not supported by provider")
    }

    /// Default: reports `reduce_all` as unsupported.
    fn reduce_all<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all not supported by provider")
    }

    /// Default: reports `reduce_all_dim` as unsupported.
    fn reduce_all_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_all_dim not supported by provider")
    }

    /// Default: reports `reduce_median` as unsupported.
    fn reduce_median<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median not supported by provider")
    }

    /// Default: reports `reduce_median_dim` as unsupported.
    fn reduce_median_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_median_dim not supported by provider")
    }

    /// Default: reports `reduce_min` as unsupported.
    fn reduce_min<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_min not supported by provider")
    }

    /// Default: reports `reduce_min_dim` as unsupported.
    fn reduce_min_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_min_dim not supported by provider")
    }

    /// Default: reports `reduce_max` as unsupported.
    fn reduce_max<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
    ) -> AccelProviderFuture<'a, GpuTensorHandle> {
        unsupported_future("reduce_max not supported by provider")
    }

    /// Default: reports `reduce_max_dim` as unsupported.
    fn reduce_max_dim<'a>(
        &'a self,
        _a: &'a GpuTensorHandle,
        _dim: usize,
    ) -> AccelProviderFuture<'a, ReduceDimResult> {
        unsupported_future("reduce_max_dim not supported by provider")
    }
2236 fn cumsum_scan(
2237 &self,
2238 _input: &GpuTensorHandle,
2239 _dim: usize,
2240 _direction: ProviderScanDirection,
2241 _nan_mode: ProviderNanMode,
2242 ) -> anyhow::Result<GpuTensorHandle> {
2243 Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
2244 }
2245 fn cumprod_scan(
2246 &self,
2247 _input: &GpuTensorHandle,
2248 _dim: usize,
2249 _direction: ProviderScanDirection,
2250 _nan_mode: ProviderNanMode,
2251 ) -> anyhow::Result<GpuTensorHandle> {
2252 Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
2253 }
2254 fn cummin_scan(
2255 &self,
2256 _input: &GpuTensorHandle,
2257 _dim: usize,
2258 _direction: ProviderScanDirection,
2259 _nan_mode: ProviderNanMode,
2260 ) -> anyhow::Result<ProviderCumminResult> {
2261 Err(anyhow::anyhow!("cummin_scan not supported by provider"))
2262 }
2263 fn cummax_scan(
2264 &self,
2265 _input: &GpuTensorHandle,
2266 _dim: usize,
2267 _direction: ProviderScanDirection,
2268 _nan_mode: ProviderNanMode,
2269 ) -> anyhow::Result<ProviderCummaxResult> {
2270 Err(anyhow::anyhow!("cummax_scan not supported by provider"))
2271 }
2272
2273 fn find(
2274 &self,
2275 _a: &GpuTensorHandle,
2276 _limit: Option<usize>,
2277 _direction: FindDirection,
2278 ) -> anyhow::Result<ProviderFindResult> {
2279 Err(anyhow::anyhow!("find not supported by provider"))
2280 }
2281
2282 fn fused_elementwise(
2283 &self,
2284 _shader: &str,
2285 _inputs: &[GpuTensorHandle],
2286 _output_shape: &[usize],
2287 _len: usize,
2288 ) -> anyhow::Result<GpuTensorHandle> {
2289 Err(anyhow::anyhow!(
2290 "fused_elementwise not supported by provider"
2291 ))
2292 }
2293
2294 fn fused_elementwise_multi(
2303 &self,
2304 _shader: &str,
2305 _inputs: &[GpuTensorHandle],
2306 _output_shape: &[usize],
2307 _len: usize,
2308 _num_outputs: usize,
2309 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2310 Err(anyhow::anyhow!(
2311 "fused_elementwise_multi not supported by provider"
2312 ))
2313 }
2314
2315 fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2317 Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
2318 }
2319
2320 fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
2322 Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
2323 }
2324
2325 #[allow(clippy::too_many_arguments)]
2332 fn fused_reduction(
2333 &self,
2334 _shader: &str,
2335 _inputs: &[GpuTensorHandle],
2336 _output_shape: &[usize],
2337 _reduce_len: usize,
2338 _num_slices: usize,
2339 _workgroup_size: u32,
2340 _flavor: ReductionFlavor,
2341 ) -> anyhow::Result<GpuTensorHandle> {
2342 Err(anyhow::anyhow!("fused_reduction not supported by provider"))
2343 }
2344
2345 fn warmup(&self) {}
2347
2348 fn fused_cache_counters(&self) -> (u64, u64) {
2350 (0, 0)
2351 }
2352
2353 fn last_warmup_millis(&self) -> Option<u64> {
2355 None
2356 }
2357
2358 fn telemetry_snapshot(&self) -> ProviderTelemetry {
2360 let (hits, misses) = self.fused_cache_counters();
2361 ProviderTelemetry {
2362 fused_elementwise: ProviderDispatchStats::default(),
2363 fused_reduction: ProviderDispatchStats::default(),
2364 matmul: ProviderDispatchStats::default(),
2365 linsolve: ProviderDispatchStats::default(),
2366 mldivide: ProviderDispatchStats::default(),
2367 mrdivide: ProviderDispatchStats::default(),
2368 upload_bytes: 0,
2369 download_bytes: 0,
2370 solve_fallbacks: Vec::new(),
2371 fusion_cache_hits: hits,
2372 fusion_cache_misses: misses,
2373 bind_group_cache_hits: 0,
2374 bind_group_cache_misses: 0,
2375 bind_group_cache_by_layout: None,
2376 kernel_launches: Vec::new(),
2377 }
2378 }
2379
2380 fn reset_telemetry(&self) {}
2382
2383 fn default_reduction_workgroup_size(&self) -> u32 {
2385 256
2386 }
2387
2388 fn two_pass_threshold(&self) -> usize {
2390 1024
2391 }
2392
2393 fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
2395 ReductionTwoPassMode::Auto
2396 }
2397
2398 fn scatter_column(
2401 &self,
2402 _matrix: &GpuTensorHandle,
2403 _col_index: usize,
2404 _values: &GpuTensorHandle,
2405 ) -> anyhow::Result<GpuTensorHandle> {
2406 Err(anyhow::anyhow!("scatter_column not supported by provider"))
2407 }
2408
2409 fn scatter_row(
2412 &self,
2413 _matrix: &GpuTensorHandle,
2414 _row_index: usize,
2415 _values: &GpuTensorHandle,
2416 ) -> anyhow::Result<GpuTensorHandle> {
2417 Err(anyhow::anyhow!("scatter_row not supported by provider"))
2418 }
2419
2420 fn sub2ind(
2421 &self,
2422 _dims: &[usize],
2423 _strides: &[usize],
2424 _inputs: &[&GpuTensorHandle],
2425 _scalar_mask: &[bool],
2426 _len: usize,
2427 _output_shape: &[usize],
2428 ) -> anyhow::Result<GpuTensorHandle> {
2429 Err(anyhow::anyhow!("sub2ind not supported by provider"))
2430 }
2431
2432 fn supports_ind2sub(&self) -> bool {
2434 false
2435 }
2436
2437 fn ind2sub(
2439 &self,
2440 _dims: &[usize],
2441 _strides: &[usize],
2442 _indices: &GpuTensorHandle,
2443 _total: usize,
2444 _len: usize,
2445 _output_shape: &[usize],
2446 ) -> anyhow::Result<Vec<GpuTensorHandle>> {
2447 Err(anyhow::anyhow!("ind2sub not supported by provider"))
2448 }
2449
2450 fn issymmetric(
2452 &self,
2453 _matrix: &GpuTensorHandle,
2454 _kind: ProviderSymmetryKind,
2455 _tolerance: f64,
2456 ) -> anyhow::Result<bool> {
2457 Err(anyhow::anyhow!(
2458 "issymmetric predicate not supported by provider"
2459 ))
2460 }
2461
2462 fn ishermitian<'a>(
2464 &'a self,
2465 _matrix: &'a GpuTensorHandle,
2466 _kind: ProviderHermitianKind,
2467 _tolerance: f64,
2468 ) -> AccelProviderFuture<'a, bool> {
2469 Box::pin(async move {
2470 Err(anyhow::anyhow!(
2471 "ishermitian predicate not supported by provider"
2472 ))
2473 })
2474 }
2475
2476 fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
2478 Err(anyhow::anyhow!("bandwidth not supported by provider"))
2479 }
2480
2481 fn sym_rcm<'a>(&'a self, _matrix: &'a GpuTensorHandle) -> AccelProviderFuture<'a, Vec<usize>> {
2486 Box::pin(async move { Err(anyhow::anyhow!("sym_rcm not supported by provider")) })
2487 }
2488}
2489
/// Process-wide default provider. Set by `register_provider`, cleared by
/// `clear_provider`, and read by `provider()` when the caller has no override.
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
/// Providers keyed by device id. `register_provider` inserts into it; it is read by
/// `provider_for_device` / `provider_for_handle` and emptied by `clear_provider`.
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Counter behind `next_device_id`; the first id it hands out is 1.
/// NOTE(review): presumably chosen so that 0 is never issued. Confirm.
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
2495
// Native targets: per-thread provider override. `None` means "no override", in which
// case `provider()` falls back to `GLOBAL_PROVIDER`. Accessed only through
// `replace_thread_provider` / `current_thread_provider`.
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}
2500
/// wasm32 replacement for the thread-local override: a single `Mutex`-guarded slot.
/// Because it is one static, the "thread" override is shared by all callers on this
/// target.
#[cfg(target_arch = "wasm32")]
static WASM_THREAD_PROVIDER: Lazy<Mutex<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| Mutex::new(None));
2504
2505#[cfg(not(target_arch = "wasm32"))]
2506fn replace_thread_provider(
2507 provider: Option<&'static dyn AccelProvider>,
2508) -> Option<&'static dyn AccelProvider> {
2509 THREAD_PROVIDER.with(|cell| {
2510 let prev = cell.get();
2511 cell.set(provider);
2512 prev
2513 })
2514}
2515
/// Installs `provider` in the shared wasm32 override slot and returns the previous
/// value.
///
/// # Panics
///
/// Panics if the slot's mutex is poisoned.
#[cfg(target_arch = "wasm32")]
fn replace_thread_provider(
    provider: Option<&'static dyn AccelProvider>,
) -> Option<&'static dyn AccelProvider> {
    let mut guard = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    std::mem::replace(&mut *guard, provider)
}
2527
2528#[cfg(not(target_arch = "wasm32"))]
2529fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
2530 THREAD_PROVIDER.with(|cell| cell.get())
2531}
2532
/// Returns the shared wasm32 provider override, if any.
///
/// # Panics
///
/// Panics if the slot's mutex is poisoned.
#[cfg(target_arch = "wasm32")]
fn current_thread_provider() -> Option<&'static dyn AccelProvider> {
    let guard = WASM_THREAD_PROVIDER
        .lock()
        .expect("wasm provider mutex poisoned");
    // The stored value is `Copy`, so dereferencing yields an owned copy.
    *guard
}
2541
2542pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
2550 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2551 *guard = Some(p);
2552 }
2553 register_provider_for_device(p.device_id(), p);
2554}
2555
2556unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
2557 if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
2558 guard.insert(device_id, provider);
2559 }
2560}
2561
2562pub fn provider() -> Option<&'static dyn AccelProvider> {
2563 if let Some(p) = current_thread_provider() {
2564 return Some(p);
2565 }
2566 GLOBAL_PROVIDER
2567 .read()
2568 .ok()
2569 .and_then(|guard| guard.as_ref().copied())
2570}
2571
2572pub fn clear_provider() {
2574 replace_thread_provider(None);
2575 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2576 *guard = None;
2577 }
2578 if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2579 map.clear();
2580 }
2581}
2582
2583pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2584 PROVIDER_REGISTRY
2585 .read()
2586 .ok()
2587 .and_then(|guard| guard.get(&device_id).copied())
2588 .or_else(|| provider())
2589}
2590
2591pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
2592 provider_for_device(handle.device_id)
2593}
2594
2595pub fn next_device_id() -> u32 {
2596 DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
2597}
2598
2599pub struct ThreadProviderGuard {
2600 prev: Option<&'static dyn AccelProvider>,
2601}
2602
2603impl ThreadProviderGuard {
2604 pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2605 let prev = replace_thread_provider(provider);
2606 ThreadProviderGuard { prev }
2607 }
2608}
2609
2610impl Drop for ThreadProviderGuard {
2611 fn drop(&mut self) {
2612 let prev = self.prev.take();
2613 replace_thread_provider(prev);
2614 }
2615}
2616
2617pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
2618 replace_thread_provider(provider);
2619}
2620
2621pub async fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2623 if let Some(p) = provider() {
2624 if let Ok(h) = p.elem_add(a, b).await {
2625 return Some(h);
2626 }
2627 }
2628 None
2629}
2630
2631pub async fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2633 if let Some(p) = provider() {
2634 if let Ok(h) = p.elem_hypot(a, b).await {
2635 return Some(h);
2636 }
2637 }
2638 None
2639}
2640
2641pub async fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2643 if let Some(p) = provider() {
2644 if let Ok(h) = p.elem_max(a, b).await {
2645 return Some(h);
2646 }
2647 }
2648 None
2649}
2650
2651pub async fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2653 if let Some(p) = provider() {
2654 if let Ok(h) = p.elem_min(a, b).await {
2655 return Some(h);
2656 }
2657 }
2658 None
2659}
2660
2661pub async fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2663 if let Some(p) = provider() {
2664 if let Ok(h) = p.elem_atan2(y, x).await {
2665 return Some(h);
2666 }
2667 }
2668 None
2669}
2670
/// Owned host-side tensor: a flat `f64` buffer together with its shape and storage
/// tag.
///
/// NOTE(review): the element ordering of `data` (column- vs row-major) is not
/// established in this section. Confirm before relying on it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    /// Flattened element values.
    pub data: Vec<f64>,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Storage kind associated with the data.
    pub storage: GpuTensorStorage,
}
2678
/// Borrowed host-side tensor: a flat `f64` slice plus its shape, both living for
/// `'a`.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    /// Flattened element values.
    pub data: &'a [f64],
    /// Tensor dimensions.
    pub shape: &'a [usize],
}
2684
/// Borrowed values for a single meshgrid axis.
/// NOTE(review): presumably the coordinate vector supplied for that axis. Confirm
/// with the meshgrid provider entry point.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    /// Axis values.
    pub data: &'a [f64],
}
2690
/// Device-resident outputs returned by a provider meshgrid operation.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    /// Output grid handles.
    /// NOTE(review): presumably one per requested output. Confirm.
    pub outputs: Vec<GpuTensorHandle>,
}
2696
/// How a row or column scale vector in `MatmulEpilogue` (`row_op` / `col_op`) is
/// applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    /// Multiply by the scale (the value `MatmulEpilogue::noop` uses).
    Multiply,
    /// Divide by the scale.
    Divide,
}
2707
/// Post-processing applied to a matmul result.
///
/// `MatmulEpilogue::noop()` gives the identity configuration, which is the one
/// `is_noop()` recognises.
/// NOTE(review): the exact formula (e.g. `alpha * A*B + beta * C` before
/// scaling/clamping/pow) is not visible in this section. Confirm in the kernel.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    /// Multiplicative coefficient; 1.0 in the no-op configuration.
    pub alpha: f64,
    /// Second coefficient; 0.0 in the no-op configuration.
    pub beta: f64,
    /// Optional per-row scale vector, applied according to `row_op`.
    pub row_scale: Option<GpuTensorHandle>,
    /// Optional per-column scale vector, applied according to `col_op`.
    pub col_scale: Option<GpuTensorHandle>,
    /// How `row_scale` is applied.
    pub row_op: ScaleOp,
    /// How `col_scale` is applied.
    pub col_op: ScaleOp,
    // The `#[serde(default)]` fields below deserialize to `None` when absent.
    /// Optional lower clamp bound.
    #[serde(default)]
    pub clamp_min: Option<f64>,
    /// Optional upper clamp bound.
    #[serde(default)]
    pub clamp_max: Option<f64>,
    /// Optional exponent for a power step.
    #[serde(default)]
    pub pow_exponent: Option<f64>,
    /// Optional extra output handle for diagonal values.
    /// NOTE(review): semantics inferred from the name. Confirm.
    #[serde(default)]
    pub diag_output: Option<GpuTensorHandle>,
}
2735
2736impl MatmulEpilogue {
2737 pub fn noop() -> Self {
2738 Self {
2739 alpha: 1.0,
2740 beta: 0.0,
2741 row_scale: None,
2742 col_scale: None,
2743 row_op: ScaleOp::Multiply,
2744 col_op: ScaleOp::Multiply,
2745 clamp_min: None,
2746 clamp_max: None,
2747 pow_exponent: None,
2748 diag_output: None,
2749 }
2750 }
2751 pub fn is_noop(&self) -> bool {
2752 self.alpha == 1.0
2753 && self.beta == 0.0
2754 && self.row_scale.is_none()
2755 && self.col_scale.is_none()
2756 && self.clamp_min.is_none()
2757 && self.clamp_max.is_none()
2758 && self.pow_exponent.is_none()
2759 && self.diag_output.is_none()
2760 }
2761}
2762
/// Parameters for a power-step epilogue. The `Default` impl sets `epsilon` to 0.0.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct PowerStepEpilogue {
    /// Epsilon value used by the epilogue.
    /// NOTE(review): presumably a stabiliser added before normalisation. Confirm.
    pub epsilon: f64,
}
2767
2768impl Default for PowerStepEpilogue {
2769 fn default() -> Self {
2770 Self { epsilon: 0.0 }
2771 }
2772}
2773
/// Describes an image-normalisation request: batch/height/width extents, an epsilon,
/// and optional gain, bias and gamma adjustments.
/// NOTE(review): tensor layout (e.g. which axis is batch) is not established here.
/// Confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    /// Number of images.
    pub batch: usize,
    /// Image height.
    pub height: usize,
    /// Image width.
    pub width: usize,
    /// Epsilon used by the normalisation.
    pub epsilon: f64,
    // The optional adjustments below deserialize to `None` when absent.
    /// Optional gain.
    #[serde(default)]
    pub gain: Option<f64>,
    /// Optional bias.
    #[serde(default)]
    pub bias: Option<f64>,
    /// Optional gamma.
    #[serde(default)]
    pub gamma: Option<f64>,
}