ariadnetor_tensor/
capability.rs1use ariadnetor_core::Scalar;
14use ariadnetor_core::backend::ComputeBackend;
15use ariadnetor_native::NativeBackend;
16
17use crate::{BlockSparseStorage, DenseStorage};
18
19pub trait OpsFor<St>: ComputeBackend {}
23
24impl<T: Scalar> OpsFor<DenseStorage<T>> for NativeBackend {}
25impl<T: Scalar> OpsFor<BlockSparseStorage<T>> for NativeBackend {}
26
27#[cfg(not(feature = "pluggability-litmus"))]
38pub type Host = NativeBackend;
39
40#[cfg(feature = "pluggability-litmus")]
44pub type Host = alt_host::AltHostBackend;
45
/// Alternative host backend, compiled only with the `pluggability-litmus`
/// feature, where it is what the `Host` alias resolves to.
#[cfg(feature = "pluggability-litmus")]
mod alt_host {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, OnceLock};

    use ariadnetor_core::Scalar;
    use ariadnetor_core::backend::{
        BackendError, ComputeBackend, DeviceType, EigDescriptor, EighDescriptor, ExecPolicy,
        GemmDescriptor, LqDescriptor, MemoryOrder, QrDescriptor, SolveDescriptor, SvdDescriptor,
        TransposeDescriptor,
    };
    use ariadnetor_native::NativeBackend;

    use crate::{BlockSparseStorage, DenseStorage, OpsFor};

    /// Backend that forwards all work to a wrapped [`NativeBackend`] and keeps
    /// a running tally of how many kernel methods were invoked through it.
    pub struct AltHostBackend {
        /// Backend that actually executes every operation.
        inner: NativeBackend,
        /// Number of kernel calls routed through this wrapper so far.
        kernel_calls: AtomicUsize,
    }

    impl AltHostBackend {
        /// Wraps a freshly constructed `NativeBackend` with the tally at zero.
        fn new() -> Self {
            let inner = NativeBackend::new();
            let kernel_calls = AtomicUsize::new(0);
            Self {
                inner,
                kernel_calls,
            }
        }

        /// Returns a handle to the single shared instance.
        ///
        /// The instance lives in a function-local `static` and is built on the
        /// first call; every call hands out another `Arc` to that same value.
        pub fn shared() -> Arc<AltHostBackend> {
            static INSTANCE: OnceLock<Arc<AltHostBackend>> = OnceLock::new();
            let handle = INSTANCE.get_or_init(|| Arc::new(Self::new()));
            Arc::clone(handle)
        }

        /// Number of kernel calls recorded so far.
        pub fn count(&self) -> usize {
            self.kernel_calls.load(Ordering::SeqCst)
        }

        /// Adds one to the kernel-call tally, then yields the wrapped backend
        /// so the caller can delegate to it in the same expression.
        fn counted(&self) -> &NativeBackend {
            self.kernel_calls.fetch_add(1, Ordering::SeqCst);
            &self.inner
        }
    }

    impl ComputeBackend for AltHostBackend {
        /// Identifies this backend as `"alt-host"`.
        fn name(&self) -> &'static str {
            "alt-host"
        }

        /// Same device type as the wrapped backend.
        fn device_type(&self) -> DeviceType {
            self.inner.device_type()
        }

        /// Same preferred memory order as the wrapped backend.
        fn preferred_order(&self) -> MemoryOrder {
            self.inner.preferred_order()
        }

        // Kernel entry points: each one is tallied first (whether or not the
        // delegated call later succeeds) and then forwarded unchanged.

        fn gemm<T: Scalar>(&self, desc: GemmDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted().gemm(desc)
        }

        fn transpose<T: Scalar>(
            &self,
            desc: TransposeDescriptor<'_, T>,
        ) -> Result<(), BackendError> {
            self.counted().transpose(desc)
        }

        fn svd<T: Scalar>(&self, desc: SvdDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted().svd(desc)
        }

        fn qr<T: Scalar>(&self, desc: QrDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted().qr(desc)
        }

        fn lq<T: Scalar>(&self, desc: LqDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted().lq(desc)
        }

        fn eigh<T: Scalar>(&self, desc: EighDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted().eigh(desc)
        }

        fn eig<T: Scalar>(&self, desc: EigDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted().eig(desc)
        }

        fn solve<T: Scalar>(&self, desc: SolveDescriptor<'_, T>) -> Result<(), BackendError> {
            self.counted().solve(desc)
        }

        // `par_for_*` queries: answered by the wrapped backend and left out of
        // the kernel-call tally.

        fn par_for_svd(&self, m: usize, n: usize) -> ExecPolicy {
            self.inner.par_for_svd(m, n)
        }

        fn par_for_qr(&self, m: usize, n: usize) -> ExecPolicy {
            self.inner.par_for_qr(m, n)
        }

        fn par_for_lq(&self, m: usize, n: usize) -> ExecPolicy {
            self.inner.par_for_lq(m, n)
        }

        fn par_for_eigh(&self, n: usize) -> ExecPolicy {
            self.inner.par_for_eigh(n)
        }

        fn par_for_eig(&self, n: usize) -> ExecPolicy {
            self.inner.par_for_eig(n)
        }

        fn par_for_gemm(&self, m: usize, n: usize, k: usize) -> ExecPolicy {
            self.inner.par_for_gemm(m, n, k)
        }

        fn par_for_solve(&self, n: usize, nrhs: usize) -> ExecPolicy {
            self.inner.par_for_solve(n, nrhs)
        }

        fn par_for_transpose(&self, shape: &[usize]) -> ExecPolicy {
            self.inner.par_for_transpose(shape)
        }
    }

    // The wrapper is declared for the same two storage flavors as `NativeBackend`.
    impl<T> OpsFor<DenseStorage<T>> for AltHostBackend where T: Scalar {}
    impl<T> OpsFor<BlockSparseStorage<T>> for AltHostBackend where T: Scalar {}
}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: the test passes simply by type-checking, which
    /// requires `OpsFor` impls for dense and block-sparse `f64` storage on
    /// both `NativeBackend` and whichever backend `Host` currently names.
    #[test]
    fn native_and_host_declare_ops_for_both_storage_flavors() {
        // Instantiable only when `B: OpsFor<St>` holds; does nothing at runtime.
        fn requires_ops_for<St, B>()
        where
            B: OpsFor<St>,
        {
        }

        requires_ops_for::<DenseStorage<f64>, NativeBackend>();
        requires_ops_for::<BlockSparseStorage<f64>, NativeBackend>();
        requires_ops_for::<DenseStorage<f64>, Host>();
        requires_ops_for::<BlockSparseStorage<f64>, Host>();
    }
}