concrete_shortint/engine/
mod.rs1use crate::ServerKey;
2use concrete_core::prelude::*;
3use std::cell::RefCell;
4use std::collections::BTreeMap;
5use std::fmt::Debug;
6
7mod client_side;
8mod server_side;
9mod wopbs;
10
11thread_local! {
12 static LOCAL_ENGINE: RefCell<ShortintEngine> = RefCell::new(ShortintEngine::new());
13}
14
15fn new_seeder() -> Box<dyn Seeder> {
16 let seeder: Box<dyn Seeder>;
17 #[cfg(target_arch = "x86_64")]
18 {
19 if RdseedSeeder::is_available() {
20 seeder = Box::new(RdseedSeeder);
21 } else {
22 seeder = Box::new(UnixSeeder::new(0));
23 }
24 }
25 #[cfg(not(target_arch = "x86_64"))]
26 {
27 seeder = Box::new(UnixSeeder::new(0));
28 }
29
30 seeder
31}
32
33pub struct Buffers {
35 pub(crate) accumulator: GlweCiphertext64,
36 pub(crate) buffer_lwe_after_ks: LweCiphertext64,
37}
38
39#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)]
42struct KeyId {
43 accumulator_dim: GlweSize,
44 lwe_dim_after_pbs: usize,
45 glwe_size: GlweSize,
46 poly_size: PolynomialSize,
47}
48
49impl ServerKey {
50 #[inline]
51 fn key_id(&self) -> KeyId {
52 KeyId {
53 accumulator_dim: self.bootstrapping_key.glwe_dimension().to_glwe_size(),
54 lwe_dim_after_pbs: self.bootstrapping_key.output_lwe_dimension().0,
55 glwe_size: self.bootstrapping_key.glwe_dimension().to_glwe_size(),
56 poly_size: self.bootstrapping_key.polynomial_size(),
57 }
58 }
59}
60
/// Type-erased wrapper around any error raised by the underlying `concrete_core` engines.
///
/// `EngineError` cannot itself implement [`std::error::Error`]. If it did, the blanket
/// `From<E: Error>` conversion below would overlap with the reflexive
/// `impl<T> From<T> for T` provided by `core`.
// The wrapped error is only observed through the derived `Debug` impl, which the
// dead-code lint does not count as a read of the field.
#[allow(dead_code)]
#[derive(Debug)]
pub struct EngineError {
    error: Box<dyn std::error::Error>,
}

impl<E: std::error::Error + 'static> From<E> for EngineError {
    /// Boxes `source` so that any engine error can be propagated with `?`.
    fn from(source: E) -> Self {
        EngineError {
            error: Box::new(source),
        }
    }
}

/// Result type returned by fallible engine helpers in this module.
pub(crate) type EngineResult<T> = std::result::Result<T, EngineError>;
81
82pub struct ShortintEngine {
89 pub(crate) engine: DefaultEngine,
90 pub(crate) fft_engine: FftEngine,
91 pub(crate) par_engine: DefaultParallelEngine,
92 buffers: BTreeMap<KeyId, Buffers>,
93}
94
95impl ShortintEngine {
96 #[inline]
99 pub fn with_thread_local_mut<F, R>(func: F) -> R
100 where
101 F: FnOnce(&mut Self) -> R,
102 {
103 LOCAL_ENGINE.with(|engine_cell| func(&mut engine_cell.borrow_mut()))
104 }
105
106 fn new() -> Self {
117 let engine = DefaultEngine::new(new_seeder()).expect("Failed to create a DefaultEngine");
118 let par_engine = DefaultParallelEngine::new(new_seeder())
119 .expect("Failed to create a DefaultParallelEngine");
120 let fft_engine = FftEngine::new(()).unwrap();
121 Self {
122 engine,
123 fft_engine,
125 par_engine,
126 buffers: Default::default(),
127 }
128 }
129
130 fn generate_accumulator_with_engine<F>(
131 engine: &mut DefaultEngine,
132 server_key: &ServerKey,
133 f: F,
134 ) -> EngineResult<GlweCiphertext64>
135 where
136 F: Fn(u64) -> u64,
137 {
138 let modulus_sup = server_key.message_modulus.0 * server_key.carry_modulus.0;
140
141 let box_size = server_key.bootstrapping_key.polynomial_size().0 / modulus_sup;
143
144 let delta =
146 (1_u64 << 63) / (server_key.message_modulus.0 * server_key.carry_modulus.0) as u64;
147
148 let mut accumulator_u64 = vec![0_u64; server_key.bootstrapping_key.polynomial_size().0];
150
151 for i in 0..modulus_sup {
153 let index = i as usize * box_size;
154 accumulator_u64[index..index + box_size]
155 .iter_mut()
156 .for_each(|a| *a = f(i as u64) * delta);
157 }
158
159 let half_box_size = box_size / 2;
160
161 for a_i in accumulator_u64[0..half_box_size].iter_mut() {
163 *a_i = (*a_i).wrapping_neg();
164 }
165
166 accumulator_u64.rotate_left(half_box_size);
168
169 let accumulator_plaintext = engine.create_plaintext_vector_from(&accumulator_u64)?;
171
172 let accumulator = engine.trivially_encrypt_glwe_ciphertext(
173 server_key.bootstrapping_key.glwe_dimension().to_glwe_size(),
174 &accumulator_plaintext,
175 )?;
176
177 Ok(accumulator)
178 }
179
180 pub fn buffers_for_key(
188 &mut self,
189 server_key: &ServerKey,
190 ) -> (&mut Buffers, &mut DefaultEngine, &mut FftEngine) {
191 let key = server_key.key_id();
192 let engine = &mut self.engine;
194 let buffers_map = &mut self.buffers;
195 let buffers = buffers_map.entry(key).or_insert_with(|| {
196 let accumulator = Self::generate_accumulator_with_engine(engine, server_key, |n| {
197 n % server_key.message_modulus.0 as u64
198 })
199 .unwrap();
200
201 let zero_plaintext = engine.create_plaintext_from(&0_u64).unwrap();
203 let buffer_lwe_after_pbs = engine
204 .trivially_encrypt_lwe_ciphertext(
205 server_key
206 .key_switching_key
207 .output_lwe_dimension()
208 .to_lwe_size(),
209 &zero_plaintext,
210 )
211 .unwrap();
212
213 Buffers {
214 accumulator,
215 buffer_lwe_after_ks: buffer_lwe_after_pbs,
216 }
217 });
218
219 (buffers, engine, &mut self.fft_engine)
220 }
221}