Skip to main content

adele_ring/
cpu.rs

1//! `CpuBackend` — the always-available backend, using rayon to parallelize over
2//! batch items (and, for single values, over channels above a threshold).
3
4use num_bigint::BigUint;
5use rayon::prelude::*;
6
7use crate::backend::ArithmeticBackend;
8use crate::batch::RnsBatch;
9use crate::rns::{garner_crt, gpu_mul_channel, RnsInt};
10use crate::RAYON_CHANNEL_THRESHOLD;
11
12/// CPU backend backed by a dedicated rayon thread pool.
13pub struct CpuBackend {
14    pool: rayon::ThreadPool,
15}
16
17impl Default for CpuBackend {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl CpuBackend {
24    /// Build a backend with a thread pool spanning all logical cores.
25    pub fn new() -> Self {
26        Self {
27            pool: rayon::ThreadPoolBuilder::new()
28                .num_threads(0) // 0 = use all logical cores
29                .thread_name(|i| format!("adele-ring-cpu-{i}"))
30                .build()
31                .expect("rayon pool init failed"),
32        }
33    }
34
35    /// Single-value addition: parallel over channels only when `k` is large.
36    ///
37    /// Below [`RAYON_CHANNEL_THRESHOLD`] channels, rayon's per-task overhead
38    /// (~50ns) exceeds the cost of a channel op (~1ns), so we stay sequential.
39    pub fn rns_add_single(&self, a: &RnsInt, b: &RnsInt) -> RnsInt {
40        let moduli = a.channels.moduli();
41        let k = a.channels.len();
42        let residues: Vec<u64> = if k >= RAYON_CHANNEL_THRESHOLD {
43            self.pool.install(|| {
44                a.residues
45                    .par_iter()
46                    .zip(b.residues.par_iter())
47                    .zip(moduli.par_iter())
48                    .map(|((&av, &bv), &m)| (av + bv) % m)
49                    .collect()
50            })
51        } else {
52            a.residues
53                .iter()
54                .zip(b.residues.iter())
55                .zip(moduli.iter())
56                .map(|((&av, &bv), &m)| (av + bv) % m)
57                .collect()
58        };
59        RnsInt::from_residues(residues, a.channels.clone())
60    }
61}
62
63impl ArithmeticBackend for CpuBackend {
64    fn batch_rns_add(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
65        let k = a.channels.len();
66        let moduli = a.channels.moduli();
67        let mut result = RnsBatch::zeros(a.batch_size, a.channels.clone());
68        self.pool.install(|| {
69            result
70                .data
71                .par_chunks_mut(k)
72                .enumerate()
73                .for_each(|(b_idx, out_row)| {
74                    let base = b_idx * k;
75                    for c in 0..k {
76                        let av = a.data[base + c];
77                        let bv = b.data[base + c];
78                        out_row[c] = (av + bv) % moduli[c];
79                    }
80                });
81        });
82        result
83    }
84
85    fn batch_rns_mul(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
86        let k = a.channels.len();
87        let moduli = a.channels.moduli();
88        let mut result = RnsBatch::zeros(a.batch_size, a.channels.clone());
89        self.pool.install(|| {
90            result
91                .data
92                .par_chunks_mut(k)
93                .enumerate()
94                .for_each(|(b_idx, out_row)| {
95                    let base = b_idx * k;
96                    for c in 0..k {
97                        let av = a.data[base + c];
98                        let bv = b.data[base + c];
99                        out_row[c] = gpu_mul_channel(av, bv, moduli[c]);
100                    }
101                });
102        });
103        result
104    }
105
106    fn batch_crt(&self, batch: &RnsBatch) -> Vec<BigUint> {
107        let k = batch.channels.len();
108        let moduli = batch.channels.moduli();
109        self.pool.install(|| {
110            (0..batch.batch_size)
111                .into_par_iter()
112                .map(|b| {
113                    let residues = &batch.data[b * k..(b + 1) * k];
114                    garner_crt(residues, moduli)
115                })
116                .collect()
117        })
118    }
119
120    fn name(&self) -> &'static str {
121        "cpu-rayon"
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rns::Channels;
    use num_bigint::{BigInt, BigUint};

    /// Encode each of `values` as an `RnsInt` over `ch` and pack them into a batch.
    fn batch_of(values: &[i64], ch: &Channels) -> RnsBatch {
        let items: Vec<RnsInt> = values
            .iter()
            .map(|&v| RnsInt::from_i64(v, ch.clone()))
            .collect();
        RnsBatch::from_rns_ints(&items)
    }

    /// Decode every item of `batch` back to a signed integer.
    fn decode(batch: RnsBatch) -> Vec<BigInt> {
        batch
            .to_rns_ints()
            .into_iter()
            .map(|item| item.to_bigint())
            .collect()
    }

    #[test]
    fn batch_add_matches_scalar() {
        let ch = Channels::standard(32);
        // Distinct values per row so a row-offset bug cannot go unnoticed.
        let xs: Vec<i64> = (0..64).map(|i| 1_000 * i + 123).collect();
        let ys: Vec<i64> = (0..64).map(|i| 7 * i + 456).collect();
        let cpu = CpuBackend::new();
        let sum = cpu.batch_rns_add(&batch_of(&xs, &ch), &batch_of(&ys, &ch));
        let want: Vec<BigInt> = xs
            .iter()
            .zip(&ys)
            .map(|(&x, &y)| BigInt::from(x + y))
            .collect();
        assert_eq!(decode(sum), want);
    }

    #[test]
    fn batch_mul_matches_scalar() {
        let ch = Channels::standard(32);
        let xs: Vec<i64> = (0..64).map(|i| 1_000 * i + 123).collect();
        let ys: Vec<i64> = (0..64).map(|i| 7 * i + 456).collect();
        let cpu = CpuBackend::new();
        let prod = cpu.batch_rns_mul(&batch_of(&xs, &ch), &batch_of(&ys, &ch));
        let want: Vec<BigInt> = xs
            .iter()
            .zip(&ys)
            .map(|(&x, &y)| BigInt::from(x * y))
            .collect();
        assert_eq!(decode(prod), want);
    }

    #[test]
    fn batch_crt_recovers_values() {
        let ch = Channels::standard(32);
        let xs: Vec<i64> = (0..16).map(|i| 97 * i + 1).collect();
        let cpu = CpuBackend::new();
        let got = cpu.batch_crt(&batch_of(&xs, &ch));
        let want: Vec<BigUint> = xs
            .iter()
            .map(|&x| BigUint::from(u64::try_from(x).expect("test values are non-negative")))
            .collect();
        assert_eq!(got, want);
    }

    #[test]
    fn single_add() {
        let ch = Channels::standard(32);
        let cpu = CpuBackend::new();
        let r = cpu.rns_add_single(&RnsInt::from_i64(10, ch.clone()), &RnsInt::from_i64(32, ch));
        assert_eq!(r.to_bigint(), BigInt::from(42));
    }
}