1use std::sync::{Arc, OnceLock};
6
7use num_bigint::BigUint;
8
9use crate::batch::RnsBatch;
10use crate::cpu::CpuBackend;
11use crate::gpu::GpuBackend;
12
13pub trait ArithmeticBackend: Send + Sync {
15 fn batch_rns_add(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch;
17
18 fn batch_rns_mul(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch;
20
21 fn batch_crt(&self, batch: &RnsBatch) -> Vec<BigUint>;
23
24 fn name(&self) -> &'static str;
26}
27
28pub struct Executor {
30 cpu: Arc<CpuBackend>,
31 gpu: Option<Arc<GpuBackend>>,
32 pub gpu_threshold: usize,
36}
37
38impl Executor {
39 pub fn init() -> Self {
41 let cpu = Arc::new(CpuBackend::new());
42 let gpu = GpuBackend::try_init().ok().map(Arc::new);
43 match &gpu {
44 Some(g) => log::info!("adele-ring: GPU backend active ({})", g.adapter_name()),
45 None => log::info!("adele-ring: no GPU found, using CPU backend"),
46 }
47 Self {
48 cpu,
49 gpu,
50 gpu_threshold: 128,
51 }
52 }
53
54 pub fn has_gpu(&self) -> bool {
56 self.gpu.is_some()
57 }
58
59 pub fn cpu(&self) -> &CpuBackend {
61 &self.cpu
62 }
63
64 pub fn gpu(&self) -> Option<&GpuBackend> {
66 self.gpu.as_deref()
67 }
68
69 fn select(&self, batch_size: usize) -> &dyn ArithmeticBackend {
71 match &self.gpu {
72 Some(g) if batch_size >= self.gpu_threshold => g.as_ref(),
73 _ => self.cpu.as_ref(),
74 }
75 }
76
77 pub fn add(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
79 self.select(a.batch_size).batch_rns_add(a, b)
80 }
81
82 pub fn mul(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
84 self.select(a.batch_size).batch_rns_mul(a, b)
85 }
86
87 pub fn crt(&self, batch: &RnsBatch) -> Vec<BigUint> {
89 self.cpu.batch_crt(batch)
90 }
91}
92
93impl Default for Executor {
94 fn default() -> Self {
95 Self::init()
96 }
97}
98
99static EXECUTOR: OnceLock<Executor> = OnceLock::new();
100
101pub fn executor() -> &'static Executor {
103 EXECUTOR.get_or_init(Executor::init)
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rns::{Channels, RnsInt};

    /// Adding a batch of 7s to a batch of 35s through the shared executor
    /// must give 42 in every slot. With 200 elements the batch is above the
    /// default GPU threshold, so the GPU path is taken when a GPU exists.
    #[test]
    fn executor_adds_batches() {
        let ch = Channels::standard(32);
        // Builds a batch holding `len` copies of `value`.
        let constant = |value: i64, len: usize| {
            RnsBatch::from_rns_ints(&vec![RnsInt::from_i64(value, ch.clone()); len])
        };
        let lhs = constant(7, 200);
        let rhs = constant(35, 200);

        let total = executor().add(&lhs, &rhs);

        let expected = num_bigint::BigInt::from(42);
        for element in total.to_rns_ints() {
            assert_eq!(element.to_bigint(), expected);
        }
    }

    /// The CPU and GPU backends must produce identical raw data for the same
    /// addition. Silently passes on machines without a GPU.
    #[test]
    fn cpu_gpu_identical() {
        let exec = Executor::init();
        let Some(gpu_backend) = exec.gpu() else {
            return;
        };

        let ch = Channels::standard(32);
        let constant = |value: i64, len: usize| {
            RnsBatch::from_rns_ints(&vec![RnsInt::from_i64(value, ch.clone()); len])
        };
        let lhs = constant(123, 256);
        let rhs = constant(456, 256);

        let cpu_sum = exec.cpu().batch_rns_add(&lhs, &rhs);
        let gpu_sum = gpu_backend.batch_rns_add(&lhs, &rhs);
        assert_eq!(cpu_sum.data, gpu_sum.data);
    }
}