use std::sync::{Arc, OnceLock};
use num_bigint::BigUint;
use crate::batch::RnsBatch;
use crate::cpu::CpuBackend;
use crate::gpu::GpuBackend;
/// Common interface for the arithmetic backends that [`Executor`] dispatches to.
///
/// The `Send + Sync` supertraits require every implementor to be shareable
/// across threads.
pub trait ArithmeticBackend: Send + Sync {
/// Combines the batches `a` and `b` into a newly returned batch.
/// NOTE(review): presumably element-wise RNS addition — confirm against the implementations.
fn batch_rns_add(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch;
/// Combines the batches `a` and `b` into a newly returned batch.
/// NOTE(review): presumably element-wise RNS multiplication — confirm against the implementations.
fn batch_rns_mul(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch;
/// Converts `batch` into a vector of `BigUint` values.
/// NOTE(review): presumably one CRT-reconstructed integer per batch entry — confirm.
fn batch_crt(&self, batch: &RnsBatch) -> Vec<BigUint>;
/// Returns a static, human-readable identifier for this backend.
fn name(&self) -> &'static str;
}
/// Holds a CPU backend plus an optional GPU backend and routes batch
/// operations to one of them.
pub struct Executor {
/// Backend that is always available; also used for every `crt` call.
cpu: Arc<CpuBackend>,
/// GPU backend, present only when `GpuBackend::try_init()` succeeded in `init`.
gpu: Option<Arc<GpuBackend>>,
/// Minimum batch size at which `add`/`mul` are sent to the GPU backend
/// (when one exists). `init` sets this to 128; callers may adjust it.
pub gpu_threshold: usize,
}
impl Executor {
    /// Batch size at which a freshly initialised executor starts using the GPU.
    const DEFAULT_GPU_THRESHOLD: usize = 128;

    /// Creates an executor with a CPU backend and, if available, a GPU backend.
    ///
    /// `GpuBackend::try_init()` is attempted once; if it fails the error is
    /// discarded and the executor runs CPU-only. The outcome is reported
    /// through `log::info!`.
    pub fn init() -> Self {
        let cpu = Arc::new(CpuBackend::new());
        let gpu = GpuBackend::try_init().ok().map(Arc::new);

        if let Some(active) = &gpu {
            log::info!("adele-ring: GPU backend active ({})", active.adapter_name());
        } else {
            log::info!("adele-ring: no GPU found, using CPU backend");
        }

        Self {
            cpu,
            gpu,
            gpu_threshold: Self::DEFAULT_GPU_THRESHOLD,
        }
    }

    /// Returns `true` when a GPU backend was successfully initialised.
    pub fn has_gpu(&self) -> bool {
        self.gpu().is_some()
    }

    /// Borrows the CPU backend.
    pub fn cpu(&self) -> &CpuBackend {
        self.cpu.as_ref()
    }

    /// Borrows the GPU backend, or `None` when the executor is CPU-only.
    pub fn gpu(&self) -> Option<&GpuBackend> {
        self.gpu.as_deref()
    }

    /// Chooses the backend for a batch of `batch_size` entries.
    ///
    /// The GPU is chosen only when it exists and `batch_size` is at least
    /// `gpu_threshold`; every other case falls back to the CPU.
    fn select(&self, batch_size: usize) -> &dyn ArithmeticBackend {
        let large_enough = batch_size >= self.gpu_threshold;
        match self.gpu() {
            Some(gpu) if large_enough => gpu,
            _ => self.cpu(),
        }
    }

    /// Runs `batch_rns_add` on the backend selected from `a.batch_size`.
    pub fn add(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
        let backend = self.select(a.batch_size);
        backend.batch_rns_add(a, b)
    }

    /// Runs `batch_rns_mul` on the backend selected from `a.batch_size`.
    pub fn mul(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
        let backend = self.select(a.batch_size);
        backend.batch_rns_mul(a, b)
    }

    /// Runs `batch_crt` on the CPU backend, regardless of batch size or GPU
    /// availability.
    pub fn crt(&self, batch: &RnsBatch) -> Vec<BigUint> {
        self.cpu().batch_crt(batch)
    }
}
impl Default for Executor {
fn default() -> Self {
Self::init()
}
}
/// Returns the process-wide [`Executor`], creating it with
/// [`Executor::init`] on the first call.
///
/// Later calls return the same instance; `OnceLock` guarantees the stored
/// value is initialised only once even when called from several threads.
pub fn executor() -> &'static Executor {
    // The cell is only ever touched through this function, so it lives here.
    static EXECUTOR: OnceLock<Executor> = OnceLock::new();
    EXECUTOR.get_or_init(Executor::init)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rns::{Channels, RnsInt};

    /// Adding a batch of 7s to a batch of 35s through the shared executor
    /// must yield 42 in every slot.
    #[test]
    fn executor_adds_batches() {
        let ch = Channels::standard(32);
        // Builds a 200-entry batch in which every entry holds `value`.
        let uniform =
            |value: i64| RnsBatch::from_rns_ints(&vec![RnsInt::from_i64(value, ch.clone()); 200]);
        let a = uniform(7);
        let b = uniform(35);

        let sum = executor().add(&a, &b);

        let expected = num_bigint::BigInt::from(42);
        sum.to_rns_ints()
            .into_iter()
            .for_each(|item| assert_eq!(item.to_bigint(), expected));
    }

    /// When a GPU is present, its addition result must match the CPU result
    /// exactly. The test is a no-op on machines without a GPU.
    #[test]
    fn cpu_gpu_identical() {
        let exec = Executor::init();
        let Some(gpu_backend) = exec.gpu() else {
            return;
        };

        let ch = Channels::standard(32);
        // Builds a 256-entry batch in which every entry holds `value`.
        let uniform =
            |value: i64| RnsBatch::from_rns_ints(&vec![RnsInt::from_i64(value, ch.clone()); 256]);
        let a = uniform(123);
        let b = uniform(456);

        let cpu_sum = exec.cpu().batch_rns_add(&a, &b);
        let gpu_sum = gpu_backend.batch_rns_add(&a, &b);
        assert_eq!(cpu_sum.data, gpu_sum.data);
    }
}