Skip to main content

adele_ring/
backend.rs

1//! The [`ArithmeticBackend`] trait and the [`Executor`] that selects between CPU
2//! and GPU at runtime. All batch math in the crate flows through the Executor —
3//! it never hard-codes a backend.
4
5use std::sync::{Arc, OnceLock};
6
7use num_bigint::BigUint;
8
9use crate::batch::RnsBatch;
10use crate::cpu::CpuBackend;
11use crate::gpu::GpuBackend;
12
13/// A backend that can perform elementwise RNS arithmetic over a [`RnsBatch`].
14pub trait ArithmeticBackend: Send + Sync {
15    /// Elementwise add: `result[b][c] = (a[b][c] + b[b][c]) % m[c]`.
16    fn batch_rns_add(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch;
17
18    /// Elementwise multiply: `result[b][c] = (a[b][c] * b[b][c]) % m[c]`.
19    fn batch_rns_mul(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch;
20
21    /// CRT-reconstruct every item in the batch.
22    fn batch_crt(&self, batch: &RnsBatch) -> Vec<BigUint>;
23
24    /// Backend name for diagnostics.
25    fn name(&self) -> &'static str;
26}
27
28/// Runtime dispatcher between the CPU and (optional) GPU backends.
29pub struct Executor {
30    cpu: Arc<CpuBackend>,
31    gpu: Option<Arc<GpuBackend>>,
32    /// Batches smaller than this use the CPU even when a GPU is present, because
33    /// the GPU's upload/dispatch/download round-trip (~100µs) dominates for small
34    /// inputs. Public so callers can tune it for their hardware.
35    pub gpu_threshold: usize,
36}
37
38impl Executor {
39    /// Probe for a GPU and build the executor. This is the only constructor.
40    pub fn init() -> Self {
41        let cpu = Arc::new(CpuBackend::new());
42        let gpu = GpuBackend::try_init().ok().map(Arc::new);
43        match &gpu {
44            Some(g) => log::info!("adele-ring: GPU backend active ({})", g.adapter_name()),
45            None => log::info!("adele-ring: no GPU found, using CPU backend"),
46        }
47        Self {
48            cpu,
49            gpu,
50            gpu_threshold: 128,
51        }
52    }
53
54    /// Whether a GPU backend is available.
55    pub fn has_gpu(&self) -> bool {
56        self.gpu.is_some()
57    }
58
59    /// Borrow the CPU backend directly (used by benchmarks).
60    pub fn cpu(&self) -> &CpuBackend {
61        &self.cpu
62    }
63
64    /// Borrow the GPU backend if present (used by benchmarks).
65    pub fn gpu(&self) -> Option<&GpuBackend> {
66        self.gpu.as_deref()
67    }
68
69    /// Pick the best backend for a given batch size.
70    fn select(&self, batch_size: usize) -> &dyn ArithmeticBackend {
71        match &self.gpu {
72            Some(g) if batch_size >= self.gpu_threshold => g.as_ref(),
73            _ => self.cpu.as_ref(),
74        }
75    }
76
77    /// Elementwise batch addition.
78    pub fn add(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
79        self.select(a.batch_size).batch_rns_add(a, b)
80    }
81
82    /// Elementwise batch multiplication.
83    pub fn mul(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
84        self.select(a.batch_size).batch_rns_mul(a, b)
85    }
86
87    /// CRT reconstruction — always CPU-side (Garner's algorithm is sequential).
88    pub fn crt(&self, batch: &RnsBatch) -> Vec<BigUint> {
89        self.cpu.batch_crt(batch)
90    }
91}
92
93impl Default for Executor {
94    fn default() -> Self {
95        Self::init()
96    }
97}
98
99static EXECUTOR: OnceLock<Executor> = OnceLock::new();
100
101/// The lazily-initialized, crate-wide executor.
102pub fn executor() -> &'static Executor {
103    EXECUTOR.get_or_init(Executor::init)
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rns::{Channels, RnsInt};

    /// 200 copies of 7 plus 200 copies of 35, added through the shared
    /// executor, must all reconstruct to 42.
    #[test]
    fn executor_adds_batches() {
        let ch = Channels::standard(32);
        // Build a batch holding `count` copies of `value` on the shared channels.
        let filled = |value: i64, count: usize| {
            RnsBatch::from_rns_ints(&vec![RnsInt::from_i64(value, ch.clone()); count])
        };
        let a = filled(7, 200);
        let b = filled(35, 200);

        let sum = executor().add(&a, &b);

        let expected = num_bigint::BigInt::from(42);
        for item in sum.to_rns_ints() {
            assert_eq!(item.to_bigint(), expected);
        }
    }

    /// When a GPU is present, CPU and GPU addition must produce identical raw
    /// data. Without a GPU the test returns early and checks nothing.
    #[test]
    fn cpu_gpu_identical() {
        let exec = Executor::init();
        let Some(gpu_backend) = exec.gpu() else {
            return;
        };

        let ch = Channels::standard(32);
        let filled = |value: i64, count: usize| {
            RnsBatch::from_rns_ints(&vec![RnsInt::from_i64(value, ch.clone()); count])
        };
        let a = filled(123, 256);
        let b = filled(456, 256);

        let cpu = exec.cpu().batch_rns_add(&a, &b);
        let gpu = gpu_backend.batch_rns_add(&a, &b);
        assert_eq!(cpu.data, gpu.data);
    }
}