Skip to main content

ariadnetor_tensor/
capability.rs

//! Backend-capability scaffolding.
//!
//! [`OpsFor<St>`] is the compile-time half of capability dispatch: a
//! backend implements it for each storage flavor whose operations it
//! actually supports — analogous to Kokkos' `SpaceAccessibility`. It is
//! deliberately not sealed, so out-of-tree backends (e.g. a future GPU
//! backend) can declare their own capability by implementing it.
//!
//! [`Host`] aliases the default host backend, so signatures can name the
//! substrate through one stable alias instead of spelling the concrete
//! backend type; repointing the substrate is then a one-line change.
12
13use ariadnetor_core::Scalar;
14use ariadnetor_core::backend::ComputeBackend;
15use ariadnetor_native::NativeBackend;
16
17use crate::{BlockSparseStorage, DenseStorage};
18
19/// Compile-time marker: backend `Self` supports operations on storage
20/// flavor `St`. Implemented selectively per (backend, storage) pair, so a
21/// backend that cannot operate on a given storage simply omits the impl.
22pub trait OpsFor<St>: ComputeBackend {}
23
24impl<T: Scalar> OpsFor<DenseStorage<T>> for NativeBackend {}
25impl<T: Scalar> OpsFor<BlockSparseStorage<T>> for NativeBackend {}
26
27/// The default host compute substrate, aliased so signatures name it
28/// through one stable alias rather than spelling the concrete backend
29/// type; repointing the substrate is then a one-line change.
30///
31/// The `pluggability-litmus` feature exercises exactly that one-line
32/// repoint: it swaps the substrate to `AltHostBackend`, a distinct
33/// stateful backend, so the whole host-pinned surface is proven to hold
34/// against a substrate that is not the concrete native type. The litmus
35/// build is a standalone check (`cargo make litmus`), not part of the
36/// default gate.
37#[cfg(not(feature = "pluggability-litmus"))]
38pub type Host = NativeBackend;
39
40/// Litmus substrate: under `pluggability-litmus`, `Host` resolves to the
41/// alternate backend instead of `NativeBackend`, proving the substrate is
42/// swappable in one line.
43#[cfg(feature = "pluggability-litmus")]
44pub type Host = alt_host::AltHostBackend;
45
/// Pluggability-litmus alternate host backend.
///
/// A distinct, stateful (non-zero-sized) backend that delegates every
/// kernel to an inner [`NativeBackend`] while counting dispatches. Made
/// the [`Host`] substrate by the feature-gated alias above, it proves
/// the call-site-backend design holds against a substrate that is not
/// the concrete native type: every `Host::shared()` call site, every
/// `host_order()` constructor, and every `&Host` / `Arc<Host>` /
/// `Host: OpsFor<…>` use type-checks and runs against this type. The
/// dispatch counter lets a routing test observe that host-ergonomic
/// paths actually reach the aliased substrate rather than a hard-coded
/// native handle.
#[cfg(feature = "pluggability-litmus")]
mod alt_host {
    use std::fmt;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, OnceLock};

    use ariadnetor_core::Scalar;
    use ariadnetor_core::backend::{
        BackendError, ComputeBackend, DeviceType, EigDescriptor, EighDescriptor, ExecPolicy,
        GemmDescriptor, LqDescriptor, MemoryOrder, QrDescriptor, SolveDescriptor, SvdDescriptor,
        TransposeDescriptor,
    };
    use ariadnetor_native::NativeBackend;

    use crate::{BlockSparseStorage, DenseStorage, OpsFor};

    /// Stateful delegate over [`NativeBackend`] that counts kernel
    /// dispatches. It is the `Host` substrate when the
    /// `pluggability-litmus` feature is enabled.
    pub struct AltHostBackend {
        /// Backend that actually executes every kernel and policy query.
        inner: NativeBackend,
        /// Number of kernel dispatches (`gemm`, `svd`, …) seen so far.
        kernel_calls: AtomicUsize,
    }

    // `Host` is a public alias for this type under the litmus feature, so it
    // should be debuggable like any public type. The impl is written by hand
    // so it does not depend on `NativeBackend` implementing `Debug`; the
    // inner backend is elided.
    impl fmt::Debug for AltHostBackend {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("AltHostBackend")
                .field("kernel_calls", &self.count())
                .finish_non_exhaustive()
        }
    }

    impl AltHostBackend {
        fn new() -> Self {
            Self {
                inner: NativeBackend::new(),
                kernel_calls: AtomicUsize::new(0),
            }
        }

        /// Shared singleton, mirroring [`NativeBackend::shared`] so the
        /// `Host::shared()` call sites resolve unchanged under the alias.
        pub fn shared() -> Arc<AltHostBackend> {
            static INSTANCE: OnceLock<Arc<AltHostBackend>> = OnceLock::new();
            // Refcount bump only: every caller shares one backend and counter.
            Arc::clone(INSTANCE.get_or_init(|| Arc::new(AltHostBackend::new())))
        }

        /// Number of kernel dispatches routed through this backend.
        #[must_use]
        pub fn count(&self) -> usize {
            self.kernel_calls.load(Ordering::SeqCst)
        }

        /// Records one kernel dispatch.
        fn bump(&self) {
            self.kernel_calls.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl ComputeBackend for AltHostBackend {
        fn name(&self) -> &'static str {
            "alt-host"
        }

        fn device_type(&self) -> DeviceType {
            self.inner.device_type()
        }

        fn preferred_order(&self) -> MemoryOrder {
            self.inner.preferred_order()
        }

        // Kernels: count the dispatch, then hand off to the native backend.

        fn gemm<T: Scalar>(&self, desc: GemmDescriptor<'_, T>) -> Result<(), BackendError> {
            self.bump();
            self.inner.gemm(desc)
        }

        fn transpose<T: Scalar>(
            &self,
            desc: TransposeDescriptor<'_, T>,
        ) -> Result<(), BackendError> {
            self.bump();
            self.inner.transpose(desc)
        }

        fn svd<T: Scalar>(&self, desc: SvdDescriptor<'_, T>) -> Result<(), BackendError> {
            self.bump();
            self.inner.svd(desc)
        }

        fn qr<T: Scalar>(&self, desc: QrDescriptor<'_, T>) -> Result<(), BackendError> {
            self.bump();
            self.inner.qr(desc)
        }

        fn lq<T: Scalar>(&self, desc: LqDescriptor<'_, T>) -> Result<(), BackendError> {
            self.bump();
            self.inner.lq(desc)
        }

        fn eigh<T: Scalar>(&self, desc: EighDescriptor<'_, T>) -> Result<(), BackendError> {
            self.bump();
            self.inner.eigh(desc)
        }

        fn eig<T: Scalar>(&self, desc: EigDescriptor<'_, T>) -> Result<(), BackendError> {
            self.bump();
            self.inner.eig(desc)
        }

        fn solve<T: Scalar>(&self, desc: SolveDescriptor<'_, T>) -> Result<(), BackendError> {
            self.bump();
            self.inner.solve(desc)
        }

        // Policy hooks carry no compute authority but must mirror the native
        // substrate's hardware-aware thresholds, so delegate all eight.
        fn par_for_svd(&self, m: usize, n: usize) -> ExecPolicy {
            self.inner.par_for_svd(m, n)
        }

        fn par_for_qr(&self, m: usize, n: usize) -> ExecPolicy {
            self.inner.par_for_qr(m, n)
        }

        fn par_for_lq(&self, m: usize, n: usize) -> ExecPolicy {
            self.inner.par_for_lq(m, n)
        }

        fn par_for_eigh(&self, n: usize) -> ExecPolicy {
            self.inner.par_for_eigh(n)
        }

        fn par_for_eig(&self, n: usize) -> ExecPolicy {
            self.inner.par_for_eig(n)
        }

        fn par_for_gemm(&self, m: usize, n: usize, k: usize) -> ExecPolicy {
            self.inner.par_for_gemm(m, n, k)
        }

        fn par_for_solve(&self, n: usize, nrhs: usize) -> ExecPolicy {
            self.inner.par_for_solve(n, nrhs)
        }

        fn par_for_transpose(&self, shape: &[usize]) -> ExecPolicy {
            self.inner.par_for_transpose(shape)
        }
    }

    // Delegating every kernel to a full `NativeBackend`, the litmus backend
    // genuinely supports both storage flavors and declares the capability
    // exactly as an out-of-tree backend would (`OpsFor` is unsealed), so it
    // can stand in for `Host` on the `OpsFor`-gated public surface.
    impl<T: Scalar> OpsFor<DenseStorage<T>> for AltHostBackend {}
    impl<T: Scalar> OpsFor<BlockSparseStorage<T>> for AltHostBackend {}
}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that each (backend, storage) pair carries an
    /// `OpsFor` impl. `OpsFor` has no methods, so the only meaningful
    /// assertion is that these bounds are satisfiable: this test stops
    /// compiling if an impl is dropped or if `Host` is repointed at a
    /// backend that lacks either capability.
    #[test]
    fn native_and_host_declare_ops_for_both_storage_flavors() {
        fn requires<Storage, Backend>()
        where
            Backend: OpsFor<Storage>,
        {
        }

        // Dense storage: concrete native backend, then the `Host` alias.
        requires::<DenseStorage<f64>, NativeBackend>();
        requires::<DenseStorage<f64>, Host>();

        // Block-sparse storage: concrete native backend, then the `Host` alias.
        requires::<BlockSparseStorage<f64>, NativeBackend>();
        requires::<BlockSparseStorage<f64>, Host>();
    }
}
222}