1use anyhow::anyhow;
2use once_cell::sync::{Lazy, OnceCell};
3use serde::{Deserialize, Serialize};
4use std::cell::Cell;
5use std::collections::{HashMap, HashSet};
6use std::sync::atomic::{AtomicU32, Ordering};
7use std::sync::RwLock;
8
/// Callback a provider registers to be told a handle's residency
/// bookkeeping should be dropped (see [`register_residency_clear`]).
type ResidencyClearFn = fn(&GpuTensorHandle);
/// Provider-supplied hint for a sequence-length threshold; `None` means
/// the provider has no opinion.
type SequenceThresholdFn = fn() -> Option<usize>;

// Write-once registration slots: the first handler installed wins.
static RESIDENCY_CLEAR: OnceCell<ResidencyClearFn> = OnceCell::new();
static SEQUENCE_THRESHOLD_PROVIDER: OnceCell<SequenceThresholdFn> = OnceCell::new();

// Process-wide side tables keyed by `GpuTensorHandle::buffer_id`.
// Buffer ids currently flagged as holding logical (boolean) data.
static LOGICAL_HANDLES: Lazy<RwLock<HashSet<u64>>> = Lazy::new(|| RwLock::new(HashSet::new()));
// How many times each buffer id has been (re)marked logical.
static LOGICAL_HANDLE_HITS: Lazy<RwLock<HashMap<u64, u64>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
// Buffer ids tracked as logical transposes, with pre-transpose dimensions.
static TRANSPOSED_HANDLES: Lazy<RwLock<HashMap<u64, TransposeInfo>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

// Per-buffer element precision overrides (F32/F64).
static HANDLE_PRECISIONS: Lazy<RwLock<HashMap<u64, ProviderPrecision>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
23
/// Base (pre-transpose) matrix dimensions recorded for a handle that is
/// tracked as a logical transpose (see [`record_handle_transpose`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeInfo {
    /// Rows of the underlying buffer before the transpose.
    pub base_rows: usize,
    /// Columns of the underlying buffer before the transpose.
    pub base_cols: usize,
}
29
30pub fn register_residency_clear(handler: ResidencyClearFn) {
34 let _ = RESIDENCY_CLEAR.set(handler);
35}
36
37pub fn clear_residency(handle: &GpuTensorHandle) {
40 if let Some(handler) = RESIDENCY_CLEAR.get() {
41 handler(handle);
42 }
43}
44
45pub fn register_sequence_threshold_provider(provider: SequenceThresholdFn) {
49 let _ = SEQUENCE_THRESHOLD_PROVIDER.set(provider);
50}
51
52pub fn sequence_threshold_hint() -> Option<usize> {
54 SEQUENCE_THRESHOLD_PROVIDER
55 .get()
56 .and_then(|provider| provider())
57}
58
59pub fn set_handle_precision(handle: &GpuTensorHandle, precision: ProviderPrecision) {
62 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
63 guard.insert(handle.buffer_id, precision);
64 }
65}
66
67pub fn handle_precision(handle: &GpuTensorHandle) -> Option<ProviderPrecision> {
69 HANDLE_PRECISIONS
70 .read()
71 .ok()
72 .and_then(|guard| guard.get(&handle.buffer_id).copied())
73}
74
75pub fn clear_handle_precision(handle: &GpuTensorHandle) {
77 if let Ok(mut guard) = HANDLE_PRECISIONS.write() {
78 guard.remove(&handle.buffer_id);
79 }
80}
81
82pub fn set_handle_logical(handle: &GpuTensorHandle, logical: bool) {
85 if let Ok(mut guard) = LOGICAL_HANDLES.write() {
86 if logical {
87 guard.insert(handle.buffer_id);
88 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
89 *hits.entry(handle.buffer_id).or_insert(0) += 1;
90 }
91 } else {
92 guard.remove(&handle.buffer_id);
93 if let Ok(mut hits) = LOGICAL_HANDLE_HITS.write() {
94 hits.remove(&handle.buffer_id);
95 }
96 }
97 }
98}
99
/// Convenience wrapper: remove the logical flag (and its hit counter)
/// for `handle`'s buffer.
pub fn clear_handle_logical(handle: &GpuTensorHandle) {
    set_handle_logical(handle, false);
}
104
105pub fn handle_is_logical(handle: &GpuTensorHandle) -> bool {
107 LOGICAL_HANDLES
108 .read()
109 .map(|guard| guard.contains(&handle.buffer_id))
110 .unwrap_or(false)
111}
112
113pub fn handle_logical_hits(buffer_id: u64) -> Option<u64> {
114 LOGICAL_HANDLE_HITS
115 .read()
116 .ok()
117 .and_then(|guard| guard.get(&buffer_id).copied())
118}
119
120pub fn record_handle_transpose(handle: &GpuTensorHandle, base_rows: usize, base_cols: usize) {
121 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
122 guard.insert(
123 handle.buffer_id,
124 TransposeInfo {
125 base_rows,
126 base_cols,
127 },
128 );
129 }
130}
131
132pub fn clear_handle_transpose(handle: &GpuTensorHandle) {
133 if let Ok(mut guard) = TRANSPOSED_HANDLES.write() {
134 guard.remove(&handle.buffer_id);
135 }
136}
137
138pub fn handle_transpose_info(handle: &GpuTensorHandle) -> Option<TransposeInfo> {
139 TRANSPOSED_HANDLES
140 .read()
141 .ok()
142 .and_then(|guard| guard.get(&handle.buffer_id).copied())
143}
144
/// Whether `handle` currently has a transpose record.
pub fn handle_is_transposed(handle: &GpuTensorHandle) -> bool {
    handle_transpose_info(handle).is_some()
}
148
/// Identifies a tensor resident on an acceleration device.
///
/// The handle does not own the underlying buffer; freeing goes through the
/// provider (`AccelProvider::free`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GpuTensorHandle {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Device the buffer lives on.
    pub device_id: u32,
    /// Unique buffer id; also the key for this module's side tables.
    pub buffer_id: u64,
}
155
/// Descriptive information about an acceleration device.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiDeviceInfo {
    /// Identifier matching [`GpuTensorHandle::device_id`].
    pub device_id: u32,
    /// Human-readable device name.
    pub name: String,
    /// Vendor string; may be empty when unknown.
    pub vendor: String,
    /// Device memory in bytes, when the backend reports it.
    pub memory_bytes: Option<u64>,
    /// Backend label, when known.
    pub backend: Option<String>,
}

/// Value/index pair produced by dimension-wise reductions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ReduceDimResult {
    pub values: GpuTensorHandle,
    /// Indices paired with `values`.
    pub indices: GpuTensorHandle,
}

/// Running values and paired indices from a cumulative-minimum scan.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCumminResult {
    pub values: GpuTensorHandle,
    pub indices: GpuTensorHandle,
}

/// Cumulative-maximum shares the cummin value/index layout.
pub type ProviderCummaxResult = ProviderCumminResult;
182
/// Operations supported by the paged (`pagefun`-style) dispatch path.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PagefunOp {
    /// Per-page matrix multiply.
    Mtimes,
}

/// A paged operation request: apply `op` page-by-page over `inputs`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PagefunRequest {
    pub op: PagefunOp,
    /// Operand tensors.
    pub inputs: Vec<GpuTensorHandle>,
    /// Full shape of the result tensor.
    pub output_shape: Vec<usize>,
    /// Page-dimension sizes of the output.
    pub page_dims: Vec<usize>,
    /// Page-dimension sizes for each input, parallel to `inputs`.
    pub input_page_dims: Vec<Vec<usize>>,
}

/// Which end of the data a `find`-style search scans from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FindDirection {
    First,
    Last,
}

/// Outputs of a `find` operation: match positions in several encodings.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderFindResult {
    /// Linear indices of matches.
    pub linear: GpuTensorHandle,
    /// Row indices of matches.
    pub rows: GpuTensorHandle,
    /// Column indices of matches.
    pub cols: GpuTensorHandle,
    /// Matched values, when requested by the caller.
    pub values: Option<GpuTensorHandle>,
}
210
/// Lower/upper bandwidth of a banded matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderBandwidth {
    /// Sub-diagonal count.
    pub lower: u32,
    /// Super-diagonal count.
    pub upper: u32,
}

/// Symmetry test variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderSymmetryKind {
    Symmetric,
    Skew,
}

/// Hermitian test variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderHermitianKind {
    Hermitian,
    Skew,
}

/// Outputs of an LU factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLuResult {
    /// L and U factors packed in one matrix.
    pub combined: GpuTensorHandle,
    /// Lower-triangular factor.
    pub lower: GpuTensorHandle,
    /// Upper-triangular factor.
    pub upper: GpuTensorHandle,
    /// Permutation in matrix form.
    pub perm_matrix: GpuTensorHandle,
    /// Permutation in index-vector form.
    pub perm_vector: GpuTensorHandle,
}
237
/// Outputs of a Cholesky factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderCholResult {
    /// Triangular Cholesky factor.
    pub factor: GpuTensorHandle,
    /// Status code — presumably 0 on success (LAPACK-style); confirm with
    /// the provider implementation.
    pub info: u32,
}

/// Outputs of a QR factorization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    /// Permutation in matrix form (see [`ProviderQrPivot`]).
    pub perm_matrix: GpuTensorHandle,
    /// Permutation in index-vector form.
    pub perm_vector: GpuTensorHandle,
}

/// QR outputs from the power-iteration variant; same field layout as
/// [`ProviderQrResult`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderQrPowerIterResult {
    pub q: GpuTensorHandle,
    pub r: GpuTensorHandle,
    pub perm_matrix: GpuTensorHandle,
    pub perm_vector: GpuTensorHandle,
}
260
/// Structure/assumption flags for a `linsolve` call (all off by default).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderLinsolveOptions {
    pub lower: bool,
    pub upper: bool,
    pub rectangular: bool,
    pub transposed: bool,
    pub conjugate: bool,
    pub symmetric: bool,
    pub posdef: bool,
    /// Optional reciprocal-condition cutoff supplied by the caller.
    pub rcond: Option<f64>,
}

/// Solution of a linear system plus its estimated reciprocal condition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderLinsolveResult {
    pub solution: GpuTensorHandle,
    pub reciprocal_condition: f64,
}

/// Options for the pseudo-inverse; `None` tolerance lets the provider
/// pick its own cutoff.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPinvOptions {
    pub tolerance: Option<f64>,
}
283
/// Centering/scaling parameters for polynomial evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyvalMu {
    pub mean: f64,
    pub scale: f64,
}

/// Options for `polyval`; optional centering/scaling of the points.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct ProviderPolyvalOptions {
    pub mu: Option<ProviderPolyvalMu>,
}

/// Placeholder options for matrix inversion (currently no knobs).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderInvOptions {}

/// Outputs of a polynomial least-squares fit.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyfitResult {
    /// Fitted coefficients — ordering (highest-degree-first?) set by the
    /// provider; confirm against the implementation.
    pub coefficients: Vec<f64>,
    /// Triangular factor from the fit, flattened.
    pub r_matrix: Vec<f64>,
    /// Residual norm.
    pub normr: f64,
    /// Degrees of freedom.
    pub df: f64,
    /// Centering/scaling pair `[mean, scale]`.
    pub mu: [f64; 2],
}
306
/// Numerator/denominator pair from `polyder` applied to a quotient.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderPolyderQuotient {
    pub numerator: GpuTensorHandle,
    pub denominator: GpuTensorHandle,
}

/// Norms accepted by condition-number estimation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderCondNorm {
    Two,
    One,
    Inf,
    Fro,
}

/// Norm orders accepted by `norm`, including an arbitrary p-norm.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ProviderNormOrder {
    Two,
    One,
    Inf,
    NegInf,
    Zero,
    Fro,
    Nuc,
    /// General p-norm with the given exponent.
    P(f64),
}
335
/// Outputs of an eigendecomposition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderEigResult {
    /// Eigenvalues as a vector.
    pub eigenvalues: GpuTensorHandle,
    /// Eigenvalues arranged on a diagonal matrix.
    pub diagonal: GpuTensorHandle,
    /// Right eigenvectors.
    pub right: GpuTensorHandle,
    /// Left eigenvectors, when requested.
    pub left: Option<GpuTensorHandle>,
}

/// How a QR column permutation is returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderQrPivot {
    Matrix,
    Vector,
}

/// Options for QR factorization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderQrOptions {
    /// Economy-size factorization when true.
    pub economy: bool,
    /// Requested permutation representation.
    pub pivot: ProviderQrPivot,
}
355
/// Defaults: full-size factorization with a matrix-form permutation.
impl Default for ProviderQrOptions {
    fn default() -> Self {
        Self {
            economy: false,
            pivot: ProviderQrPivot::Matrix,
        }
    }
}
364
/// Element precision a provider computes in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderPrecision {
    F32,
    F64,
}

/// Whether reductions use the two-pass algorithm: chosen heuristically
/// (`Auto`) or forced on/off.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionTwoPassMode {
    Auto,
    ForceOn,
    ForceOff,
}
377
378impl ReductionTwoPassMode {
379 pub fn as_str(self) -> &'static str {
380 match self {
381 ReductionTwoPassMode::Auto => "auto",
382 ReductionTwoPassMode::ForceOn => "force_on",
383 ReductionTwoPassMode::ForceOff => "force_off",
384 }
385 }
386}
387
/// How a reduction's raw sum is scaled (see `ReductionFlavor::scale`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ReductionFlavor {
    Sum,
    Mean,
    /// Fixed, caller-supplied scale factor.
    CustomScale(f64),
}
394
395impl ReductionFlavor {
396 pub fn is_mean(self) -> bool {
397 matches!(self, ReductionFlavor::Mean)
398 }
399
400 pub fn scale(self, reduce_len: usize) -> f64 {
401 match self {
402 ReductionFlavor::Sum => 1.0,
403 ReductionFlavor::Mean => {
404 if reduce_len == 0 {
405 1.0
406 } else {
407 1.0 / reduce_len as f64
408 }
409 }
410 ReductionFlavor::CustomScale(scale) => scale,
411 }
412 }
413}
414
/// Normalization choice for correlation computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefNormalization {
    Unbiased,
    Biased,
}

/// Row-handling policy for `corrcoef`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorrcoefRows {
    All,
    Complete,
    Pairwise,
}

/// Options for `corrcoef`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorrcoefOptions {
    pub normalization: CorrcoefNormalization,
    pub rows: CorrcoefRows,
}

/// Defaults: unbiased normalization over all rows.
impl Default for CorrcoefOptions {
    fn default() -> Self {
        Self {
            normalization: CorrcoefNormalization::Unbiased,
            rows: CorrcoefRows::All,
        }
    }
}
445
/// Normalization choice for covariance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovNormalization {
    Unbiased,
    Biased,
}

/// Row-handling policy for covariance with missing data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovRows {
    All,
    OmitRows,
    PartialRows,
}

/// Options for covariance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CovarianceOptions {
    pub normalization: CovNormalization,
    pub rows: CovRows,
    /// True when the caller supplies an observation-weight vector.
    pub has_weight_vector: bool,
}

/// Defaults: unbiased normalization, all rows, no weight vector.
impl Default for CovarianceOptions {
    fn default() -> Self {
        Self {
            normalization: CovNormalization::Unbiased,
            rows: CovRows::All,
            has_weight_vector: false,
        }
    }
}
478
/// Sample vs population normalization for std/var-style reductions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderStdNormalization {
    Sample,
    Population,
}

/// Whether NaNs participate in a reduction or are skipped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderNanMode {
    Include,
    Omit,
}

/// Direction of a cumulative scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderScanDirection {
    Forward,
    Reverse,
}

/// Sort direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortOrder {
    Ascend,
    Descend,
}

/// Comparison key used while sorting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortComparison {
    Auto,
    Real,
    Abs,
}
514
/// Sorted values plus the indices that produced them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortResult {
    pub values: HostTensorOwned,
    pub indices: HostTensorOwned,
}

/// One column/direction pair for `sortrows`-style sorting.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SortRowsColumnSpec {
    /// Column index to sort by.
    pub index: usize,
    /// Direction for this column.
    pub order: SortOrder,
}

/// Ordering of `unique` output: sorted values or first-appearance order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOrder {
    Sorted,
    Stable,
}

/// Which duplicate occurrence `unique` reports an index for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UniqueOccurrence {
    First,
    Last,
}

/// Options for `unique`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UniqueOptions {
    /// Treat whole rows as the unit of uniqueness.
    pub rows: bool,
    pub order: UniqueOrder,
    pub occurrence: UniqueOccurrence,
}

/// Outputs of `unique`: values plus the `ia`/`ic` index maps
/// (MATLAB-style naming — confirm exact semantics with the provider).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UniqueResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ic: HostTensorOwned,
}

/// Ordering of `union` output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnionOrder {
    Sorted,
    Stable,
}

/// Options for `union`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct UnionOptions {
    pub rows: bool,
    pub order: UnionOrder,
}

/// Outputs of `union`: merged values plus source index maps `ia`/`ib`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UnionResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
    pub ib: HostTensorOwned,
}
579
/// Kernel families for `fspecial`-style filter construction, with their
/// shape parameters.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FspecialFilter {
    /// Box-average kernel of the given extent.
    Average {
        rows: u32,
        cols: u32,
    },
    /// Circular averaging kernel.
    Disk {
        radius: f64,
        size: u32,
    },
    /// Gaussian kernel.
    Gaussian {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// Laplacian kernel controlled by `alpha`.
    Laplacian {
        alpha: f64,
    },
    /// Laplacian-of-Gaussian kernel.
    Log {
        rows: u32,
        cols: u32,
        sigma: f64,
    },
    /// Linear motion-blur kernel.
    Motion {
        length: u32,
        kernel_size: u32,
        angle_degrees: f64,
        oversample: u32,
    },
    Prewitt,
    Sobel,
    /// Unsharp-masking kernel controlled by `alpha`.
    Unsharp {
        alpha: f64,
    },
}

/// Request wrapper for filter construction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FspecialRequest {
    pub filter: FspecialFilter,
}
622
/// Boundary padding strategy for image filtering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterPadding {
    /// Pad with a constant value (see `ImfilterOptions::constant_value`).
    Constant,
    Replicate,
    Symmetric,
    Circular,
}

/// Output extent of the filtered image.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterShape {
    Same,
    Full,
    Valid,
}

/// Whether the kernel is applied as correlation or convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ImfilterMode {
    Correlation,
    Convolution,
}

/// Options for `imfilter`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImfilterOptions {
    pub padding: ImfilterPadding,
    /// Pad value used when `padding` is `Constant`.
    pub constant_value: f64,
    pub shape: ImfilterShape,
    pub mode: ImfilterMode,
}

/// Defaults: zero-padded, same-size correlation.
impl Default for ImfilterOptions {
    fn default() -> Self {
        Self {
            padding: ImfilterPadding::Constant,
            constant_value: 0.0,
            shape: ImfilterShape::Same,
            mode: ImfilterMode::Correlation,
        }
    }
}
666
/// Ordering of `setdiff` output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SetdiffOrder {
    Sorted,
    Stable,
}

/// Options for `setdiff`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SetdiffOptions {
    /// Treat whole rows as set elements.
    pub rows: bool,
    pub order: SetdiffOrder,
}

/// Outputs of `setdiff`: the difference plus the `ia` index map.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetdiffResult {
    pub values: HostTensorOwned,
    pub ia: HostTensorOwned,
}

/// Options for `ismember`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IsMemberOptions {
    /// Compare whole rows instead of scalars.
    pub rows: bool,
}

/// Host-side logical (boolean) tensor: one `u8` per element plus shape.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostLogicalOwned {
    pub data: Vec<u8>,
    pub shape: Vec<usize>,
}

/// Outputs of `ismember`: membership mask plus location tensor.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IsMemberResult {
    pub mask: HostLogicalOwned,
    pub loc: HostTensorOwned,
}
707
/// Output extent for 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvMode {
    Full,
    Same,
    Valid,
}

/// Orientation of the 1-D signal (row vs column vector).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderConvOrientation {
    Row,
    Column,
}

/// Options for 1-D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderConv1dOptions {
    pub mode: ProviderConvMode,
    pub orientation: ProviderConvOrientation,
}

/// Options for IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterOptions {
    /// Dimension along which to filter.
    pub dim: usize,
    /// Optional initial filter state.
    pub zi: Option<GpuTensorHandle>,
}

/// Outputs of IIR filtering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderIirFilterResult {
    /// Filtered signal.
    pub output: GpuTensorHandle,
    /// Final filter state, when tracked.
    pub final_state: Option<GpuTensorHandle>,
}

/// First two raw moments of a tensor.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderMoments2 {
    pub mean: GpuTensorHandle,
    /// Mean of the squared values.
    pub ex2: GpuTensorHandle,
}

/// Dispatch counters for one operation family.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderDispatchStats {
    /// Number of dispatches recorded.
    pub count: u64,
    /// Accumulated wall-clock time in nanoseconds.
    pub total_wall_time_ns: u64,
}
756
/// Aggregate provider telemetry: dispatch stats, transfer volumes, cache
/// counters, and optional per-layout / per-kernel breakdowns.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderTelemetry {
    pub fused_elementwise: ProviderDispatchStats,
    pub fused_reduction: ProviderDispatchStats,
    pub matmul: ProviderDispatchStats,
    /// Bytes copied host -> device.
    pub upload_bytes: u64,
    /// Bytes copied device -> host.
    pub download_bytes: u64,
    pub fusion_cache_hits: u64,
    pub fusion_cache_misses: u64,
    pub bind_group_cache_hits: u64,
    pub bind_group_cache_misses: u64,
    /// Per-layout bind-group cache breakdown, when collected.
    pub bind_group_cache_by_layout: Option<Vec<BindGroupLayoutTelemetry>>,
    /// Individual kernel launch records, when collected.
    pub kernel_launches: Vec<KernelLaunchTelemetry>,
}

/// Bind-group cache counters for a single layout tag.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BindGroupLayoutTelemetry {
    pub tag: String,
    pub hits: u64,
    pub misses: u64,
}

/// One key/value attribute attached to a kernel launch record.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KernelAttrTelemetry {
    pub key: String,
    pub value: u64,
}

/// Telemetry for one kernel launch.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KernelLaunchTelemetry {
    /// Kernel name.
    pub kernel: String,
    /// Precision label, when known.
    pub precision: Option<String>,
    /// Shape-related attributes.
    pub shape: Vec<KernelAttrTelemetry>,
    /// Tuning-related attributes.
    pub tuning: Vec<KernelAttrTelemetry>,
}
794
795pub trait AccelProvider: Send + Sync {
797 fn upload(&self, host: &crate::HostTensorView) -> anyhow::Result<GpuTensorHandle>;
798 fn download(&self, h: &GpuTensorHandle) -> anyhow::Result<crate::HostTensorOwned>;
799 fn free(&self, h: &GpuTensorHandle) -> anyhow::Result<()>;
800 fn device_info(&self) -> String;
801 fn device_id(&self) -> u32 {
802 0
803 }
804
    /// Gather elements of `_source` at the given linear `_indices` into a
    /// new tensor with `_output_shape`.
    ///
    /// Default: unsupported; providers opt in by overriding.
    fn gather_linear(
        &self,
        _source: &GpuTensorHandle,
        _indices: &[u32],
        _output_shape: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("gather_linear not supported by provider"))
    }
815
    /// Scatter `_values` into `_target` at the given linear `_indices`,
    /// mutating `_target` in place.
    ///
    /// Default: unsupported; providers opt in by overriding.
    fn scatter_linear(
        &self,
        _target: &GpuTensorHandle,
        _indices: &[u32],
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<()> {
        Err(anyhow::anyhow!("scatter_linear not supported by provider"))
    }
827
828 fn device_info_struct(&self) -> ApiDeviceInfo {
830 ApiDeviceInfo {
831 device_id: 0,
832 name: self.device_info(),
833 vendor: String::new(),
834 memory_bytes: None,
835 backend: None,
836 }
837 }
838
839 fn precision(&self) -> ProviderPrecision {
840 ProviderPrecision::F64
841 }
842
843 fn read_scalar(&self, _h: &GpuTensorHandle, _linear_index: usize) -> anyhow::Result<f64> {
845 Err(anyhow::anyhow!("read_scalar not supported by provider"))
846 }
847
848 fn zeros(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
850 Err(anyhow::anyhow!("zeros not supported by provider"))
851 }
852
853 fn ones(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
855 Err(anyhow::anyhow!("ones not supported by provider"))
856 }
857
858 fn zeros_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
860 self.zeros(&prototype.shape)
861 }
862
    /// Create a tensor of `shape` filled with `value`.
    ///
    /// Strategy, in order: (1) `value == 0.0` delegates to `zeros`;
    /// (2) otherwise try `zeros` + `scalar_add(value)` on-device, freeing
    /// the intermediate buffer on both success and failure; (3) as a last
    /// resort, materialize the data host-side and upload it.
    ///
    /// NOTE(review): `value == 0.0` is also true for `-0.0`, so a `-0.0`
    /// fill takes the `zeros` fast-path; confirm the sign of zero is
    /// irrelevant to callers.
    fn fill(&self, shape: &[usize], value: f64) -> anyhow::Result<GpuTensorHandle> {
        if value == 0.0 {
            return self.zeros(shape);
        }
        if let Ok(base) = self.zeros(shape) {
            match self.scalar_add(&base, value) {
                Ok(out) => {
                    // Success: the shifted copy is the result; drop the zeros.
                    let _ = self.free(&base);
                    return Ok(out);
                }
                Err(_) => {
                    // scalar_add failed/unsupported: free and fall through.
                    let _ = self.free(&base);
                }
            }
        }
        // Host fallback: build the buffer CPU-side and upload it.
        let len: usize = shape.iter().copied().product();
        let data = vec![value; len];
        let view = HostTensorView { data: &data, shape };
        self.upload(&view)
    }
884
    /// Create a tensor shaped like `prototype`, filled with `value`.
    ///
    /// Mirrors [`Self::fill`]: zero fast-path via `zeros_like`, then an
    /// on-device `zeros_like` + `scalar_add` attempt (intermediate freed
    /// either way), finally delegating to `fill(&prototype.shape, value)`.
    fn fill_like(
        &self,
        prototype: &GpuTensorHandle,
        value: f64,
    ) -> anyhow::Result<GpuTensorHandle> {
        if value == 0.0 {
            return self.zeros_like(prototype);
        }
        if let Ok(base) = self.zeros_like(prototype) {
            match self.scalar_add(&base, value) {
                Ok(out) => {
                    let _ = self.free(&base);
                    return Ok(out);
                }
                Err(_) => {
                    let _ = self.free(&base);
                }
            }
        }
        // Fall back to the shape-based fill (which itself may fall back to
        // a host upload).
        self.fill(&prototype.shape, value)
    }
907
908 fn ones_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
910 self.ones(&prototype.shape)
911 }
912
913 fn eye(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
915 Err(anyhow::anyhow!("eye not supported by provider"))
916 }
917
918 fn eye_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
920 self.eye(&prototype.shape)
921 }
922
923 fn meshgrid(&self, _axes: &[MeshgridAxisView<'_>]) -> anyhow::Result<ProviderMeshgridResult> {
925 Err(anyhow::anyhow!("meshgrid not supported by provider"))
926 }
927
928 fn diag_from_vector(
930 &self,
931 _vector: &GpuTensorHandle,
932 _offset: isize,
933 ) -> anyhow::Result<GpuTensorHandle> {
934 Err(anyhow::anyhow!(
935 "diag_from_vector not supported by provider"
936 ))
937 }
938
939 fn diag_extract(
941 &self,
942 _matrix: &GpuTensorHandle,
943 _offset: isize,
944 ) -> anyhow::Result<GpuTensorHandle> {
945 Err(anyhow::anyhow!("diag_extract not supported by provider"))
946 }
947
948 fn tril(&self, _matrix: &GpuTensorHandle, _offset: isize) -> anyhow::Result<GpuTensorHandle> {
950 Err(anyhow!("tril not supported by provider"))
951 }
952
953 fn triu(&self, _matrix: &GpuTensorHandle, _offset: isize) -> anyhow::Result<GpuTensorHandle> {
955 Err(anyhow!("triu not supported by provider"))
956 }
957
958 fn polyval(
960 &self,
961 _coefficients: &GpuTensorHandle,
962 _points: &GpuTensorHandle,
963 _options: &ProviderPolyvalOptions,
964 ) -> anyhow::Result<GpuTensorHandle> {
965 Err(anyhow::anyhow!("polyval not supported by provider"))
966 }
967
968 fn polyfit(
970 &self,
971 _x: &GpuTensorHandle,
972 _y: &GpuTensorHandle,
973 _degree: usize,
974 _weights: Option<&GpuTensorHandle>,
975 ) -> anyhow::Result<ProviderPolyfitResult> {
976 Err(anyhow::anyhow!("polyfit not supported by provider"))
977 }
978
979 fn polyder_single(&self, _polynomial: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
981 Err(anyhow::anyhow!("polyder_single not supported by provider"))
982 }
983
984 fn polyder_product(
986 &self,
987 _p: &GpuTensorHandle,
988 _q: &GpuTensorHandle,
989 ) -> anyhow::Result<GpuTensorHandle> {
990 Err(anyhow::anyhow!("polyder_product not supported by provider"))
991 }
992
993 fn polyder_quotient(
995 &self,
996 _u: &GpuTensorHandle,
997 _v: &GpuTensorHandle,
998 ) -> anyhow::Result<ProviderPolyderQuotient> {
999 Err(anyhow::anyhow!(
1000 "polyder_quotient not supported by provider"
1001 ))
1002 }
1003
1004 fn polyint(
1006 &self,
1007 _polynomial: &GpuTensorHandle,
1008 _constant: f64,
1009 ) -> anyhow::Result<GpuTensorHandle> {
1010 Err(anyhow::anyhow!("polyint not supported by provider"))
1011 }
1012
1013 fn random_uniform(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1015 Err(anyhow::anyhow!("random_uniform not supported by provider"))
1016 }
1017
1018 fn random_uniform_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1020 self.random_uniform(&prototype.shape)
1021 }
1022
1023 fn random_normal(&self, _shape: &[usize]) -> anyhow::Result<GpuTensorHandle> {
1025 Err(anyhow::anyhow!("random_normal not supported by provider"))
1026 }
1027
1028 fn random_normal_like(&self, prototype: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1030 self.random_normal(&prototype.shape)
1031 }
1032
1033 fn stochastic_evolution(
1034 &self,
1035 _state: &GpuTensorHandle,
1036 _drift: f64,
1037 _scale: f64,
1038 _steps: u32,
1039 ) -> anyhow::Result<GpuTensorHandle> {
1040 Err(anyhow::anyhow!(
1041 "stochastic_evolution not supported by provider"
1042 ))
1043 }
1044
1045 fn set_rng_state(&self, _state: u64) -> anyhow::Result<()> {
1047 Err(anyhow::anyhow!("set_rng_state not supported by provider"))
1048 }
1049
1050 fn fspecial(&self, _request: &FspecialRequest) -> anyhow::Result<GpuTensorHandle> {
1052 Err(anyhow::anyhow!("fspecial not supported by provider"))
1053 }
1054
1055 fn imfilter(
1057 &self,
1058 _image: &GpuTensorHandle,
1059 _kernel: &GpuTensorHandle,
1060 _options: &ImfilterOptions,
1061 ) -> anyhow::Result<GpuTensorHandle> {
1062 Err(anyhow::anyhow!("imfilter not supported by provider"))
1063 }
1064
1065 fn random_integer_range(
1067 &self,
1068 _lower: i64,
1069 _upper: i64,
1070 _shape: &[usize],
1071 ) -> anyhow::Result<GpuTensorHandle> {
1072 Err(anyhow::anyhow!(
1073 "random_integer_range not supported by provider"
1074 ))
1075 }
1076
1077 fn random_integer_like(
1079 &self,
1080 prototype: &GpuTensorHandle,
1081 lower: i64,
1082 upper: i64,
1083 ) -> anyhow::Result<GpuTensorHandle> {
1084 self.random_integer_range(lower, upper, &prototype.shape)
1085 }
1086
1087 fn random_permutation(&self, _n: usize, _k: usize) -> anyhow::Result<GpuTensorHandle> {
1089 Err(anyhow!("random_permutation not supported by provider"))
1090 }
1091
1092 fn random_permutation_like(
1094 &self,
1095 _prototype: &GpuTensorHandle,
1096 n: usize,
1097 k: usize,
1098 ) -> anyhow::Result<GpuTensorHandle> {
1099 self.random_permutation(n, k)
1100 }
1101
1102 fn covariance(
1104 &self,
1105 _matrix: &GpuTensorHandle,
1106 _second: Option<&GpuTensorHandle>,
1107 _weights: Option<&GpuTensorHandle>,
1108 _options: &CovarianceOptions,
1109 ) -> anyhow::Result<GpuTensorHandle> {
1110 Err(anyhow::anyhow!("covariance not supported by provider"))
1111 }
1112
1113 fn corrcoef(
1115 &self,
1116 _matrix: &GpuTensorHandle,
1117 _options: &CorrcoefOptions,
1118 ) -> anyhow::Result<GpuTensorHandle> {
1119 Err(anyhow::anyhow!("corrcoef not supported by provider"))
1120 }
1121
1122 fn linspace(&self, _start: f64, _stop: f64, _count: usize) -> anyhow::Result<GpuTensorHandle> {
1124 Err(anyhow::anyhow!("linspace not supported by provider"))
1125 }
1126 fn elem_add(
1127 &self,
1128 _a: &GpuTensorHandle,
1129 _b: &GpuTensorHandle,
1130 ) -> anyhow::Result<GpuTensorHandle> {
1131 Err(anyhow::anyhow!("elem_add not supported by provider"))
1132 }
1133 fn elem_mul(
1134 &self,
1135 _a: &GpuTensorHandle,
1136 _b: &GpuTensorHandle,
1137 ) -> anyhow::Result<GpuTensorHandle> {
1138 Err(anyhow::anyhow!("elem_mul not supported by provider"))
1139 }
1140 fn elem_max(
1141 &self,
1142 _a: &GpuTensorHandle,
1143 _b: &GpuTensorHandle,
1144 ) -> anyhow::Result<GpuTensorHandle> {
1145 Err(anyhow::anyhow!("elem_max not supported by provider"))
1146 }
1147 fn elem_min(
1148 &self,
1149 _a: &GpuTensorHandle,
1150 _b: &GpuTensorHandle,
1151 ) -> anyhow::Result<GpuTensorHandle> {
1152 Err(anyhow::anyhow!("elem_min not supported by provider"))
1153 }
1154 fn elem_sub(
1155 &self,
1156 _a: &GpuTensorHandle,
1157 _b: &GpuTensorHandle,
1158 ) -> anyhow::Result<GpuTensorHandle> {
1159 Err(anyhow::anyhow!("elem_sub not supported by provider"))
1160 }
1161 fn elem_div(
1162 &self,
1163 _a: &GpuTensorHandle,
1164 _b: &GpuTensorHandle,
1165 ) -> anyhow::Result<GpuTensorHandle> {
1166 Err(anyhow::anyhow!("elem_div not supported by provider"))
1167 }
1168 fn elem_pow(
1169 &self,
1170 _a: &GpuTensorHandle,
1171 _b: &GpuTensorHandle,
1172 ) -> anyhow::Result<GpuTensorHandle> {
1173 Err(anyhow::anyhow!("elem_pow not supported by provider"))
1174 }
1175
1176 fn elem_hypot(
1177 &self,
1178 _a: &GpuTensorHandle,
1179 _b: &GpuTensorHandle,
1180 ) -> anyhow::Result<GpuTensorHandle> {
1181 Err(anyhow::anyhow!("elem_hypot not supported by provider"))
1182 }
1183 fn elem_ge(
1184 &self,
1185 _a: &GpuTensorHandle,
1186 _b: &GpuTensorHandle,
1187 ) -> anyhow::Result<GpuTensorHandle> {
1188 Err(anyhow::anyhow!("elem_ge not supported by provider"))
1189 }
1190 fn elem_le(
1191 &self,
1192 _a: &GpuTensorHandle,
1193 _b: &GpuTensorHandle,
1194 ) -> anyhow::Result<GpuTensorHandle> {
1195 Err(anyhow::anyhow!("elem_le not supported by provider"))
1196 }
1197 fn elem_lt(
1198 &self,
1199 _a: &GpuTensorHandle,
1200 _b: &GpuTensorHandle,
1201 ) -> anyhow::Result<GpuTensorHandle> {
1202 Err(anyhow::anyhow!("elem_lt not supported by provider"))
1203 }
1204 fn elem_gt(
1205 &self,
1206 _a: &GpuTensorHandle,
1207 _b: &GpuTensorHandle,
1208 ) -> anyhow::Result<GpuTensorHandle> {
1209 Err(anyhow::anyhow!("elem_gt not supported by provider"))
1210 }
1211 fn elem_eq(
1212 &self,
1213 _a: &GpuTensorHandle,
1214 _b: &GpuTensorHandle,
1215 ) -> anyhow::Result<GpuTensorHandle> {
1216 Err(anyhow::anyhow!("elem_eq not supported by provider"))
1217 }
1218 fn elem_ne(
1219 &self,
1220 _a: &GpuTensorHandle,
1221 _b: &GpuTensorHandle,
1222 ) -> anyhow::Result<GpuTensorHandle> {
1223 Err(anyhow::anyhow!("elem_ne not supported by provider"))
1224 }
1225 fn logical_and(
1226 &self,
1227 _a: &GpuTensorHandle,
1228 _b: &GpuTensorHandle,
1229 ) -> anyhow::Result<GpuTensorHandle> {
1230 Err(anyhow::anyhow!("logical_and not supported by provider"))
1231 }
1232 fn logical_or(
1233 &self,
1234 _a: &GpuTensorHandle,
1235 _b: &GpuTensorHandle,
1236 ) -> anyhow::Result<GpuTensorHandle> {
1237 Err(anyhow::anyhow!("logical_or not supported by provider"))
1238 }
1239 fn logical_xor(
1240 &self,
1241 _a: &GpuTensorHandle,
1242 _b: &GpuTensorHandle,
1243 ) -> anyhow::Result<GpuTensorHandle> {
1244 Err(anyhow::anyhow!("logical_xor not supported by provider"))
1245 }
1246 fn logical_not(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1247 Err(anyhow::anyhow!("logical_not not supported by provider"))
1248 }
1249 fn logical_islogical(&self, a: &GpuTensorHandle) -> anyhow::Result<bool> {
1250 Ok(handle_is_logical(a))
1251 }
1252 fn logical_isreal(&self, _a: &GpuTensorHandle) -> anyhow::Result<bool> {
1253 Err(anyhow::anyhow!("logical_isreal not supported by provider"))
1254 }
1255 fn logical_isfinite(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1256 Err(anyhow::anyhow!(
1257 "logical_isfinite not supported by provider"
1258 ))
1259 }
1260 fn logical_isnan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1261 Err(anyhow::anyhow!("logical_isnan not supported by provider"))
1262 }
1263 fn logical_isinf(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
1264 Err(anyhow::anyhow!("logical_isinf not supported by provider"))
1265 }
    // --- Elementwise binary/unary math and tensor-scalar ops (defaults reject) ---

    fn elem_atan2(
        &self,
        _y: &GpuTensorHandle,
        _x: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("elem_atan2 not supported by provider"))
    }
    fn unary_sin(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_sin not supported by provider"))
    }
    fn unary_gamma(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_gamma not supported by provider"))
    }
    fn unary_factorial(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_factorial not supported by provider"))
    }
    fn unary_asinh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_asinh not supported by provider"))
    }
    fn unary_sinh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_sinh not supported by provider"))
    }
    fn unary_cosh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_cosh not supported by provider"))
    }
    fn unary_asin(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_asin not supported by provider"))
    }
    fn unary_acos(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_acos not supported by provider"))
    }
    fn unary_acosh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_acosh not supported by provider"))
    }
    fn unary_tan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_tan not supported by provider"))
    }
    fn unary_tanh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_tanh not supported by provider"))
    }
    fn unary_atan(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_atan not supported by provider"))
    }
    fn unary_atanh(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_atanh not supported by provider"))
    }
    fn unary_ceil(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_ceil not supported by provider"))
    }
    fn unary_floor(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_floor not supported by provider"))
    }
    fn unary_round(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_round not supported by provider"))
    }
    fn unary_fix(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_fix not supported by provider"))
    }
    fn unary_cos(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_cos not supported by provider"))
    }
    fn unary_angle(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_angle not supported by provider"))
    }
    fn unary_imag(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_imag not supported by provider"))
    }
    fn unary_real(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_real not supported by provider"))
    }
    fn unary_conj(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_conj not supported by provider"))
    }
    fn unary_abs(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_abs not supported by provider"))
    }
    fn unary_sign(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_sign not supported by provider"))
    }
    fn unary_exp(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_exp not supported by provider"))
    }
    fn unary_expm1(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_expm1 not supported by provider"))
    }
    fn unary_log(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_log not supported by provider"))
    }
    fn unary_log2(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_log2 not supported by provider"))
    }
    fn unary_log10(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_log10 not supported by provider"))
    }
    fn unary_log1p(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_log1p not supported by provider"))
    }
    fn unary_sqrt(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_sqrt not supported by provider"))
    }
    fn unary_double(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_double not supported by provider"))
    }
    fn unary_single(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_single not supported by provider"))
    }
    fn unary_pow2(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("unary_pow2 not supported by provider"))
    }
    /// Presumably an ldexp-style `mantissa * 2^exponent` elementwise
    /// combine — TODO confirm against call sites.
    fn pow2_scale(
        &self,
        _mantissa: &GpuTensorHandle,
        _exponent: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pow2_scale not supported by provider"))
    }
    // Tensor-scalar variants; the `r`-prefixed names presumably put the
    // scalar on the left of the non-commutative op — TODO confirm.
    fn scalar_rsub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_rsub not supported by provider"))
    }
    fn scalar_rdiv(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_rdiv not supported by provider"))
    }
    fn scalar_add(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_add not supported by provider"))
    }
    fn scalar_sub(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_sub not supported by provider"))
    }
    fn scalar_mul(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_mul not supported by provider"))
    }
    fn scalar_max(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_max not supported by provider"))
    }
    fn scalar_min(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_min not supported by provider"))
    }
    fn scalar_div(&self, _a: &GpuTensorHandle, _scalar: f64) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scalar_div not supported by provider"))
    }
    // --- Sorting and matrix products (defaults reject) ---

    fn sort_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
        _order: SortOrder,
        _comparison: SortComparison,
    ) -> anyhow::Result<SortResult> {
        Err(anyhow::anyhow!("sort_dim not supported by provider"))
    }
    fn sort_rows(
        &self,
        _a: &GpuTensorHandle,
        _columns: &[SortRowsColumnSpec],
        _comparison: SortComparison,
    ) -> anyhow::Result<SortResult> {
        Err(anyhow::anyhow!("sort_rows not supported by provider"))
    }
    fn matmul(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("matmul not supported by provider"))
    }

    /// Symmetric rank-k update of `_a` — whether `A*Aᵀ` or `Aᵀ*A` is not
    /// visible here; confirm against call sites.
    fn syrk(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("syrk not supported by provider"))
    }
    fn pagefun(&self, _request: &PagefunRequest) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pagefun not supported by provider"))
    }
1440
1441 fn matmul_epilogue(
1446 &self,
1447 a: &GpuTensorHandle,
1448 b: &GpuTensorHandle,
1449 epilogue: &MatmulEpilogue,
1450 ) -> anyhow::Result<GpuTensorHandle> {
1451 if epilogue.is_noop() {
1452 return self.matmul(a, b);
1453 }
1454 Err(anyhow::anyhow!("matmul_epilogue not supported by provider"))
1455 }
    // --- Fused image/power-iteration kernels (defaults reject) ---

    fn image_normalize(
        &self,
        _input: &GpuTensorHandle,
        _desc: &ImageNormalizeDescriptor,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "image_normalize fusion not supported by provider"
        ))
    }
    fn matmul_power_step(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
        _epilogue: &PowerStepEpilogue,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "matmul_power_step normalization not supported by provider"
        ))
    }
    // --- Dense linear algebra (defaults reject) ---

    fn linsolve(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
        _options: &ProviderLinsolveOptions,
    ) -> anyhow::Result<ProviderLinsolveResult> {
        Err(anyhow::anyhow!("linsolve not supported by provider"))
    }
    fn inv(
        &self,
        _matrix: &GpuTensorHandle,
        _options: ProviderInvOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("inv not supported by provider"))
    }
    fn pinv(
        &self,
        _matrix: &GpuTensorHandle,
        _options: ProviderPinvOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("pinv not supported by provider"))
    }
    fn cond(
        &self,
        _matrix: &GpuTensorHandle,
        _norm: ProviderCondNorm,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cond not supported by provider"))
    }
    fn norm(
        &self,
        _tensor: &GpuTensorHandle,
        _order: ProviderNormOrder,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("norm not supported by provider"))
    }
    fn rank(
        &self,
        _matrix: &GpuTensorHandle,
        _tolerance: Option<f64>,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("rank not supported by provider"))
    }
    fn rcond(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("rcond not supported by provider"))
    }
    fn mldivide(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("mldivide not supported by provider"))
    }
    fn mrdivide(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("mrdivide not supported by provider"))
    }
    fn eig(&self, _a: &GpuTensorHandle, _compute_left: bool) -> anyhow::Result<ProviderEigResult> {
        Err(anyhow::anyhow!("eig not supported by provider"))
    }
    fn lu(&self, _a: &GpuTensorHandle) -> anyhow::Result<ProviderLuResult> {
        Err(anyhow::anyhow!("lu not supported by provider"))
    }

    fn chol(&self, _a: &GpuTensorHandle, _lower: bool) -> anyhow::Result<ProviderCholResult> {
        Err(anyhow::anyhow!("chol not supported by provider"))
    }
    fn qr(
        &self,
        _a: &GpuTensorHandle,
        _options: ProviderQrOptions,
    ) -> anyhow::Result<ProviderQrResult> {
        Err(anyhow::anyhow!("qr not supported by provider"))
    }
    /// Providers that remember the operands of a previous product may
    /// surrender them here; the default remembers nothing and yields `None`.
    fn take_matmul_sources(
        &self,
        _product: &GpuTensorHandle,
    ) -> Option<(GpuTensorHandle, GpuTensorHandle)> {
        None
    }
    /// Optional fused QR power-iteration step.
    ///
    /// The default discards its arguments and returns `Ok(None)`,
    /// signalling that this provider implements no fused path; the
    /// caller is then expected to use the generic route.
    fn qr_power_iter(
        &self,
        product: &GpuTensorHandle,
        _product_lhs: Option<&GpuTensorHandle>,
        q_handle: &GpuTensorHandle,
        options: &ProviderQrOptions,
    ) -> anyhow::Result<Option<ProviderQrPowerIterResult>> {
        // Silence unused-variable warnings for the non-underscored params.
        let _ = (product, q_handle, options);
        Ok(None)
    }
    // --- Layout, signal-processing and set operations (defaults reject) ---

    fn transpose(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("transpose not supported by provider"))
    }
    fn conv1d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _options: ProviderConv1dOptions,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv1d not supported by provider"))
    }
    fn conv2d(
        &self,
        _signal: &GpuTensorHandle,
        _kernel: &GpuTensorHandle,
        _mode: ProviderConvMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("conv2d not supported by provider"))
    }
    fn iir_filter(
        &self,
        _b: &GpuTensorHandle,
        _a: &GpuTensorHandle,
        _x: &GpuTensorHandle,
        _options: ProviderIirFilterOptions,
    ) -> anyhow::Result<ProviderIirFilterResult> {
        Err(anyhow::anyhow!("iir_filter not supported by provider"))
    }
    fn permute(
        &self,
        _handle: &GpuTensorHandle,
        _order: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("permute not supported by provider"))
    }
    fn flip(&self, _handle: &GpuTensorHandle, _axes: &[usize]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("flip not supported by provider"))
    }
    fn circshift(
        &self,
        _handle: &GpuTensorHandle,
        _shifts: &[isize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("circshift not supported by provider"))
    }
    fn diff_dim(
        &self,
        _handle: &GpuTensorHandle,
        _order: usize,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("diff_dim not supported by provider"))
    }
    fn fft_dim(
        &self,
        _handle: &GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("fft_dim not supported by provider"))
    }
    fn ifft_dim(
        &self,
        _handle: &GpuTensorHandle,
        _len: Option<usize>,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("ifft_dim not supported by provider"))
    }
    fn unique(
        &self,
        _handle: &GpuTensorHandle,
        _options: &UniqueOptions,
    ) -> anyhow::Result<UniqueResult> {
        Err(anyhow::anyhow!("unique not supported by provider"))
    }
    fn union(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
        _options: &UnionOptions,
    ) -> anyhow::Result<UnionResult> {
        Err(anyhow::anyhow!("union not supported by provider"))
    }
    fn setdiff(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
        _options: &SetdiffOptions,
    ) -> anyhow::Result<SetdiffResult> {
        Err(anyhow::anyhow!("setdiff not supported by provider"))
    }
    fn ismember(
        &self,
        _a: &GpuTensorHandle,
        _b: &GpuTensorHandle,
        _options: &IsMemberOptions,
    ) -> anyhow::Result<IsMemberResult> {
        Err(anyhow::anyhow!("ismember not supported by provider"))
    }
1670 fn reshape(
1671 &self,
1672 handle: &GpuTensorHandle,
1673 new_shape: &[usize],
1674 ) -> anyhow::Result<GpuTensorHandle> {
1675 let mut updated = handle.clone();
1676 updated.shape = new_shape.to_vec();
1677 Ok(updated)
1678 }
    // --- Concatenation/tiling and basic reductions (defaults reject) ---

    fn cat(&self, _dim: usize, _inputs: &[GpuTensorHandle]) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cat not supported by provider"))
    }
    fn repmat(
        &self,
        _handle: &GpuTensorHandle,
        _reps: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("repmat not supported by provider"))
    }
    fn kron(&self, _a: &GpuTensorHandle, _b: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("kron not supported by provider"))
    }
    fn reduce_sum(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_sum not supported by provider"))
    }
    fn reduce_sum_dim(&self, _a: &GpuTensorHandle, _dim: usize) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_sum_dim not supported by provider"))
    }
    fn dot(
        &self,
        _lhs: &GpuTensorHandle,
        _rhs: &GpuTensorHandle,
        _dim: Option<usize>,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("dot not supported by provider"))
    }
    fn reduce_nnz(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_nnz not supported by provider"))
    }
    fn reduce_nnz_dim(&self, _a: &GpuTensorHandle, _dim: usize) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_nnz_dim not supported by provider"))
    }
    fn reduce_prod(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_prod not supported by provider"))
    }
    fn reduce_prod_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_prod_dim not supported by provider"))
    }
    // --- Statistical reductions and prefix scans (defaults reject) ---

    fn reduce_mean(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_mean not supported by provider"))
    }
    fn reduce_mean_nd(
        &self,
        _a: &GpuTensorHandle,
        _dims_zero_based: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_mean_nd not supported by provider"))
    }
    fn reduce_moments_nd(
        &self,
        _a: &GpuTensorHandle,
        _dims_zero_based: &[usize],
    ) -> anyhow::Result<ProviderMoments2> {
        Err(anyhow::anyhow!(
            "reduce_moments_nd not supported by provider"
        ))
    }
    fn reduce_mean_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_mean_dim not supported by provider"))
    }
    fn reduce_std(
        &self,
        _a: &GpuTensorHandle,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_std not supported by provider"))
    }
    fn reduce_std_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
        _normalization: ProviderStdNormalization,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_std_dim not supported by provider"))
    }
    fn reduce_any(&self, _a: &GpuTensorHandle, _omit_nan: bool) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_any not supported by provider"))
    }
    fn reduce_any_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_any_dim not supported by provider"))
    }
    fn reduce_all(&self, _a: &GpuTensorHandle, _omit_nan: bool) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_all not supported by provider"))
    }
    fn reduce_all_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
        _omit_nan: bool,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_all_dim not supported by provider"))
    }
    fn reduce_median(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_median not supported by provider"))
    }
    fn reduce_median_dim(
        &self,
        _a: &GpuTensorHandle,
        _dim: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "reduce_median_dim not supported by provider"
        ))
    }
    fn reduce_min(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_min not supported by provider"))
    }
    fn reduce_min_dim(&self, _a: &GpuTensorHandle, _dim: usize) -> anyhow::Result<ReduceDimResult> {
        Err(anyhow::anyhow!("reduce_min_dim not supported by provider"))
    }
    fn reduce_max(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("reduce_max not supported by provider"))
    }
    fn reduce_max_dim(&self, _a: &GpuTensorHandle, _dim: usize) -> anyhow::Result<ReduceDimResult> {
        Err(anyhow::anyhow!("reduce_max_dim not supported by provider"))
    }
    fn cumsum_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumsum_scan not supported by provider"))
    }
    fn cumprod_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("cumprod_scan not supported by provider"))
    }
    fn cummin_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCumminResult> {
        Err(anyhow::anyhow!("cummin_scan not supported by provider"))
    }
    fn cummax_scan(
        &self,
        _input: &GpuTensorHandle,
        _dim: usize,
        _direction: ProviderScanDirection,
        _nan_mode: ProviderNanMode,
    ) -> anyhow::Result<ProviderCummaxResult> {
        Err(anyhow::anyhow!("cummax_scan not supported by provider"))
    }
1852
    // --- Search and generated-shader fusion entry points (defaults reject) ---

    fn find(
        &self,
        _a: &GpuTensorHandle,
        _limit: Option<usize>,
        _direction: FindDirection,
    ) -> anyhow::Result<ProviderFindResult> {
        Err(anyhow::anyhow!("find not supported by provider"))
    }

    /// Run a caller-generated elementwise shader over `_inputs`; `_shader`
    /// is the kernel source text.
    fn fused_elementwise(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _len: usize,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!(
            "fused_elementwise not supported by provider"
        ))
    }

    fn map_nan_to_zero(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("map_nan_to_zero not supported by provider"))
    }

    fn not_nan_mask(&self, _a: &GpuTensorHandle) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("not_nan_mask not supported by provider"))
    }

    /// Reduction counterpart of `fused_elementwise`: `_reduce_len` elements
    /// per slice, `_num_slices` slices, dispatched with `_workgroup_size`.
    #[allow(clippy::too_many_arguments)]
    fn fused_reduction(
        &self,
        _shader: &str,
        _inputs: &[GpuTensorHandle],
        _output_shape: &[usize],
        _reduce_len: usize,
        _num_slices: usize,
        _workgroup_size: u32,
        _flavor: ReductionFlavor,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("fused_reduction not supported by provider"))
    }
1903
    /// Optional hook for eager pipeline/shader preparation; default is a no-op.
    fn warmup(&self) {}

    /// `(hits, misses)` of the provider's fusion shader cache; the default
    /// reports an always-empty cache.
    fn fused_cache_counters(&self) -> (u64, u64) {
        (0, 0)
    }

    /// Duration of the last `warmup`, if the provider records one.
    fn last_warmup_millis(&self) -> Option<u64> {
        None
    }

    /// Default telemetry: zeroed dispatch/byte counters, with only the
    /// fusion-cache fields populated from `fused_cache_counters`.
    fn telemetry_snapshot(&self) -> ProviderTelemetry {
        let (hits, misses) = self.fused_cache_counters();
        ProviderTelemetry {
            fused_elementwise: ProviderDispatchStats::default(),
            fused_reduction: ProviderDispatchStats::default(),
            matmul: ProviderDispatchStats::default(),
            upload_bytes: 0,
            download_bytes: 0,
            fusion_cache_hits: hits,
            fusion_cache_misses: misses,
            bind_group_cache_hits: 0,
            bind_group_cache_misses: 0,
            bind_group_cache_by_layout: None,
            kernel_launches: Vec::new(),
        }
    }

    fn reset_telemetry(&self) {}

    /// Workgroup size used for reductions when the caller doesn't pick one.
    fn default_reduction_workgroup_size(&self) -> u32 {
        256
    }

    /// Size threshold consulted when choosing between one- and two-pass
    /// reduction strategies — exact semantics live at the call sites.
    fn two_pass_threshold(&self) -> usize {
        1024
    }

    fn reduction_two_pass_mode(&self) -> ReductionTwoPassMode {
        ReductionTwoPassMode::Auto
    }
1952
    // --- Scatter, index conversion, and structural predicates (defaults reject) ---

    fn scatter_column(
        &self,
        _matrix: &GpuTensorHandle,
        _col_index: usize,
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scatter_column not supported by provider"))
    }

    fn scatter_row(
        &self,
        _matrix: &GpuTensorHandle,
        _row_index: usize,
        _values: &GpuTensorHandle,
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("scatter_row not supported by provider"))
    }

    fn sub2ind(
        &self,
        _dims: &[usize],
        _strides: &[usize],
        _inputs: &[&GpuTensorHandle],
        _scalar_mask: &[bool],
        _len: usize,
        _output_shape: &[usize],
    ) -> anyhow::Result<GpuTensorHandle> {
        Err(anyhow::anyhow!("sub2ind not supported by provider"))
    }

    /// Capability probe paired with `ind2sub`; the default honestly
    /// reports `false` to match the erroring default below.
    fn supports_ind2sub(&self) -> bool {
        false
    }

    fn ind2sub(
        &self,
        _dims: &[usize],
        _strides: &[usize],
        _indices: &GpuTensorHandle,
        _total: usize,
        _len: usize,
        _output_shape: &[usize],
    ) -> anyhow::Result<Vec<GpuTensorHandle>> {
        Err(anyhow::anyhow!("ind2sub not supported by provider"))
    }

    fn issymmetric(
        &self,
        _matrix: &GpuTensorHandle,
        _kind: ProviderSymmetryKind,
        _tolerance: f64,
    ) -> anyhow::Result<bool> {
        Err(anyhow::anyhow!(
            "issymmetric predicate not supported by provider"
        ))
    }

    fn ishermitian(
        &self,
        _matrix: &GpuTensorHandle,
        _kind: ProviderHermitianKind,
        _tolerance: f64,
    ) -> anyhow::Result<bool> {
        Err(anyhow::anyhow!(
            "ishermitian predicate not supported by provider"
        ))
    }

    fn bandwidth(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<ProviderBandwidth> {
        Err(anyhow::anyhow!("bandwidth not supported by provider"))
    }

    fn sym_rcm(&self, _matrix: &GpuTensorHandle) -> anyhow::Result<Vec<usize>> {
        Err(anyhow::anyhow!("sym_rcm not supported by provider"))
    }
2041}
2042
/// Process-wide fallback provider, consulted when no thread-local override is set.
static GLOBAL_PROVIDER: Lazy<RwLock<Option<&'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(None));
/// All registered providers, keyed by their device id.
static PROVIDER_REGISTRY: Lazy<RwLock<HashMap<u32, &'static dyn AccelProvider>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));
/// Monotonic source of fresh device ids (starts at 1).
static DEVICE_ID_COUNTER: AtomicU32 = AtomicU32::new(1);
thread_local! {
    /// Per-thread provider override, checked before `GLOBAL_PROVIDER`.
    static THREAD_PROVIDER: Cell<Option<&'static dyn AccelProvider>> = Cell::new(None);
}
2051
/// Install `p` as the global provider and register it under its device id.
///
/// # Safety
///
/// The body performs no unsafe operations itself; the `unsafe` marker
/// appears to encode a crate-level registration contract (presumably that
/// `p` stays valid for the program's lifetime and registration does not
/// race with teardown) — TODO confirm the intended contract.
pub unsafe fn register_provider(p: &'static dyn AccelProvider) {
    if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
        *guard = Some(p);
    }
    register_provider_for_device(p.device_id(), p);
}
2065
2066unsafe fn register_provider_for_device(device_id: u32, provider: &'static dyn AccelProvider) {
2067 if let Ok(mut guard) = PROVIDER_REGISTRY.write() {
2068 guard.insert(device_id, provider);
2069 }
2070}
2071
2072pub fn provider() -> Option<&'static dyn AccelProvider> {
2073 if let Some(p) = THREAD_PROVIDER.with(|cell| cell.get()) {
2074 return Some(p);
2075 }
2076 GLOBAL_PROVIDER
2077 .read()
2078 .ok()
2079 .and_then(|guard| guard.as_ref().copied())
2080}
2081
2082pub fn clear_provider() {
2084 if let Ok(mut guard) = GLOBAL_PROVIDER.write() {
2085 *guard = None;
2086 }
2087 if let Ok(mut map) = PROVIDER_REGISTRY.write() {
2088 map.clear();
2089 }
2090}
2091
2092pub fn provider_for_device(device_id: u32) -> Option<&'static dyn AccelProvider> {
2093 PROVIDER_REGISTRY
2094 .read()
2095 .ok()
2096 .and_then(|guard| guard.get(&device_id).copied())
2097 .or_else(|| provider())
2098}
2099
/// Convenience wrapper: resolve the provider that owns `handle`'s device.
pub fn provider_for_handle(handle: &GpuTensorHandle) -> Option<&'static dyn AccelProvider> {
    provider_for_device(handle.device_id)
}
2103
/// Hand out a fresh, process-unique device id.
pub fn next_device_id() -> u32 {
    DEVICE_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
2107
/// RAII-style guard that restores the previous thread-local provider
/// override when dropped (see [`ThreadProviderGuard::set`]).
pub struct ThreadProviderGuard {
    // Override that was active before this guard was created.
    prev: Option<&'static dyn AccelProvider>,
}
2111
2112impl ThreadProviderGuard {
2113 pub fn set(provider: Option<&'static dyn AccelProvider>) -> Self {
2114 let prev = THREAD_PROVIDER.with(|cell| {
2115 let old = cell.get();
2116 cell.set(provider);
2117 old
2118 });
2119 ThreadProviderGuard { prev }
2120 }
2121}
2122
2123impl Drop for ThreadProviderGuard {
2124 fn drop(&mut self) {
2125 let prev = self.prev.take();
2126 THREAD_PROVIDER.with(|cell| cell.set(prev));
2127 }
2128}
2129
/// Unconditionally set (or clear, with `None`) this thread's provider
/// override. Unlike [`ThreadProviderGuard::set`], nothing is restored later.
pub fn set_thread_provider(provider: Option<&'static dyn AccelProvider>) {
    THREAD_PROVIDER.with(|cell| cell.set(provider));
}
2133
2134pub fn try_elem_add(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2136 if let Some(p) = provider() {
2137 if let Ok(h) = p.elem_add(a, b) {
2138 return Some(h);
2139 }
2140 }
2141 None
2142}
2143
2144pub fn try_elem_hypot(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2146 if let Some(p) = provider() {
2147 if let Ok(h) = p.elem_hypot(a, b) {
2148 return Some(h);
2149 }
2150 }
2151 None
2152}
2153
2154pub fn try_elem_max(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2156 if let Some(p) = provider() {
2157 if let Ok(h) = p.elem_max(a, b) {
2158 return Some(h);
2159 }
2160 }
2161 None
2162}
2163
2164pub fn try_elem_min(a: &GpuTensorHandle, b: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2166 if let Some(p) = provider() {
2167 if let Ok(h) = p.elem_min(a, b) {
2168 return Some(h);
2169 }
2170 }
2171 None
2172}
2173
2174pub fn try_elem_atan2(y: &GpuTensorHandle, x: &GpuTensorHandle) -> Option<GpuTensorHandle> {
2176 if let Some(p) = provider() {
2177 if let Ok(h) = p.elem_atan2(y, x) {
2178 return Some(h);
2179 }
2180 }
2181 None
2182}
2183
/// Host-resident tensor with owned storage.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HostTensorOwned {
    /// Flattened element values; `data.len()` should equal the product of
    /// `shape` — element ordering convention is not visible here.
    pub data: Vec<f64>,
    /// Dimension sizes.
    pub shape: Vec<usize>,
}
2190
/// Borrowed counterpart of [`HostTensorOwned`]: a non-owning view over
/// host tensor data and its shape.
#[derive(Debug)]
pub struct HostTensorView<'a> {
    pub data: &'a [f64],
    pub shape: &'a [usize],
}
2196
/// Borrowed 1-D axis samples used as one input axis of a meshgrid request.
#[derive(Debug)]
pub struct MeshgridAxisView<'a> {
    pub data: &'a [f64],
}
2202
/// Handles produced by a provider-side meshgrid: one output grid per
/// input axis.
#[derive(Debug, Clone)]
pub struct ProviderMeshgridResult {
    pub outputs: Vec<GpuTensorHandle>,
}
2208
/// How a scale vector is combined with the values it applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScaleOp {
    Multiply,
    Divide,
}
2219
/// Post-processing fused onto a matmul. Judging by `is_noop`, the result
/// is presumably scaled by `alpha`/offset by `beta`, optionally scaled
/// per-row/per-column, clamped, raised to a power, and its diagonal
/// captured — TODO confirm exact evaluation order against providers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MatmulEpilogue {
    /// Scalar multiplier; `1.0` is the identity.
    pub alpha: f64,
    /// Scalar offset; `0.0` is the identity.
    pub beta: f64,
    /// Optional per-row scale vector.
    pub row_scale: Option<GpuTensorHandle>,
    /// Optional per-column scale vector.
    pub col_scale: Option<GpuTensorHandle>,
    /// How `row_scale` is applied (multiply or divide).
    pub row_op: ScaleOp,
    /// How `col_scale` is applied (multiply or divide).
    pub col_op: ScaleOp,
    #[serde(default)]
    /// Optional lower clamp bound.
    pub clamp_min: Option<f64>,
    #[serde(default)]
    /// Optional upper clamp bound.
    pub clamp_max: Option<f64>,
    #[serde(default)]
    /// Optional exponent applied to each element.
    pub pow_exponent: Option<f64>,
    #[serde(default)]
    /// Optional handle that receives the result's diagonal.
    pub diag_output: Option<GpuTensorHandle>,
}
2247
2248impl MatmulEpilogue {
2249 pub fn noop() -> Self {
2250 Self {
2251 alpha: 1.0,
2252 beta: 0.0,
2253 row_scale: None,
2254 col_scale: None,
2255 row_op: ScaleOp::Multiply,
2256 col_op: ScaleOp::Multiply,
2257 clamp_min: None,
2258 clamp_max: None,
2259 pow_exponent: None,
2260 diag_output: None,
2261 }
2262 }
2263 pub fn is_noop(&self) -> bool {
2264 self.alpha == 1.0
2265 && self.beta == 0.0
2266 && self.row_scale.is_none()
2267 && self.col_scale.is_none()
2268 && self.clamp_min.is_none()
2269 && self.clamp_max.is_none()
2270 && self.pow_exponent.is_none()
2271 && self.diag_output.is_none()
2272 }
2273}
2274
2275#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
2276pub struct PowerStepEpilogue {
2277 pub epsilon: f64,
2278}
2279
2280impl Default for PowerStepEpilogue {
2281 fn default() -> Self {
2282 Self { epsilon: 0.0 }
2283 }
2284}
2285
/// Parameters for the fused `image_normalize` kernel.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageNormalizeDescriptor {
    /// Number of images in the batch.
    pub batch: usize,
    /// Image height in pixels.
    pub height: usize,
    /// Image width in pixels.
    pub width: usize,
    /// Stabilizing epsilon for the normalization denominator.
    pub epsilon: f64,
    #[serde(default)]
    /// Optional multiplicative gain applied after normalization — TODO
    /// confirm application order in provider implementations.
    pub gain: Option<f64>,
    #[serde(default)]
    /// Optional additive bias.
    pub bias: Option<f64>,
    #[serde(default)]
    /// Optional gamma exponent.
    pub gamma: Option<f64>,
}