1use num_bigint::BigUint;
5use rayon::prelude::*;
6
7use crate::backend::ArithmeticBackend;
8use crate::batch::RnsBatch;
9use crate::rns::{garner_crt, gpu_mul_channel, RnsInt};
10use crate::RAYON_CHANNEL_THRESHOLD;
11
12pub struct CpuBackend {
14 pool: rayon::ThreadPool,
15}
16
17impl Default for CpuBackend {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl CpuBackend {
24 pub fn new() -> Self {
26 Self {
27 pool: rayon::ThreadPoolBuilder::new()
28 .num_threads(0) .thread_name(|i| format!("adele-ring-cpu-{i}"))
30 .build()
31 .expect("rayon pool init failed"),
32 }
33 }
34
35 pub fn rns_add_single(&self, a: &RnsInt, b: &RnsInt) -> RnsInt {
40 let moduli = a.channels.moduli();
41 let k = a.channels.len();
42 let residues: Vec<u64> = if k >= RAYON_CHANNEL_THRESHOLD {
43 self.pool.install(|| {
44 a.residues
45 .par_iter()
46 .zip(b.residues.par_iter())
47 .zip(moduli.par_iter())
48 .map(|((&av, &bv), &m)| (av + bv) % m)
49 .collect()
50 })
51 } else {
52 a.residues
53 .iter()
54 .zip(b.residues.iter())
55 .zip(moduli.iter())
56 .map(|((&av, &bv), &m)| (av + bv) % m)
57 .collect()
58 };
59 RnsInt::from_residues(residues, a.channels.clone())
60 }
61}
62
63impl ArithmeticBackend for CpuBackend {
64 fn batch_rns_add(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
65 let k = a.channels.len();
66 let moduli = a.channels.moduli();
67 let mut result = RnsBatch::zeros(a.batch_size, a.channels.clone());
68 self.pool.install(|| {
69 result
70 .data
71 .par_chunks_mut(k)
72 .enumerate()
73 .for_each(|(b_idx, out_row)| {
74 let base = b_idx * k;
75 for c in 0..k {
76 let av = a.data[base + c];
77 let bv = b.data[base + c];
78 out_row[c] = (av + bv) % moduli[c];
79 }
80 });
81 });
82 result
83 }
84
85 fn batch_rns_mul(&self, a: &RnsBatch, b: &RnsBatch) -> RnsBatch {
86 let k = a.channels.len();
87 let moduli = a.channels.moduli();
88 let mut result = RnsBatch::zeros(a.batch_size, a.channels.clone());
89 self.pool.install(|| {
90 result
91 .data
92 .par_chunks_mut(k)
93 .enumerate()
94 .for_each(|(b_idx, out_row)| {
95 let base = b_idx * k;
96 for c in 0..k {
97 let av = a.data[base + c];
98 let bv = b.data[base + c];
99 out_row[c] = gpu_mul_channel(av, bv, moduli[c]);
100 }
101 });
102 });
103 result
104 }
105
106 fn batch_crt(&self, batch: &RnsBatch) -> Vec<BigUint> {
107 let k = batch.channels.len();
108 let moduli = batch.channels.moduli();
109 self.pool.install(|| {
110 (0..batch.batch_size)
111 .into_par_iter()
112 .map(|b| {
113 let residues = &batch.data[b * k..(b + 1) * k];
114 garner_crt(residues, moduli)
115 })
116 .collect()
117 })
118 }
119
120 fn name(&self) -> &'static str {
121 "cpu-rayon"
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rns::Channels;
    use num_bigint::BigInt;

    /// Number of identical items placed in each test batch.
    const BATCH_LEN: usize = 64;

    /// Builds two batches holding `BATCH_LEN` copies of `x` and of `y`,
    /// both over clones of the same `Channels::standard(32)` value.
    fn uniform_pair(x: i64, y: i64) -> (RnsBatch, RnsBatch) {
        let ch = Channels::standard(32);
        let lhs = vec![RnsInt::from_i64(x, ch.clone()); BATCH_LEN];
        let rhs = vec![RnsInt::from_i64(y, ch.clone()); BATCH_LEN];
        (RnsBatch::from_rns_ints(&lhs), RnsBatch::from_rns_ints(&rhs))
    }

    /// Asserts that every item of `batch` converts to `expected`.
    fn assert_every_item(batch: RnsBatch, expected: i64) {
        let want = BigInt::from(expected);
        for item in batch.to_rns_ints() {
            assert_eq!(item.to_bigint(), want);
        }
    }

    #[test]
    fn batch_add_matches_scalar() {
        let (a, b) = uniform_pair(123, 456);
        let sum = CpuBackend::new().batch_rns_add(&a, &b);
        assert_every_item(sum, 579);
    }

    #[test]
    fn batch_mul_matches_scalar() {
        let (a, b) = uniform_pair(123, 456);
        let prod = CpuBackend::new().batch_rns_mul(&a, &b);
        assert_every_item(prod, 123 * 456);
    }

    #[test]
    fn single_add() {
        let ch = Channels::standard(32);
        let ten = RnsInt::from_i64(10, ch.clone());
        let thirty_two = RnsInt::from_i64(32, ch);
        let total = CpuBackend::new().rns_add_single(&ten, &thirty_two);
        assert_eq!(total.to_bigint(), BigInt::from(42));
    }
}