// concrete_shortint/engine/mod.rs

1use crate::ServerKey;
2use concrete_core::prelude::*;
3use std::cell::RefCell;
4use std::collections::BTreeMap;
5use std::fmt::Debug;
6
7mod client_side;
8mod server_side;
9mod wopbs;
10
11thread_local! {
12    static LOCAL_ENGINE: RefCell<ShortintEngine> = RefCell::new(ShortintEngine::new());
13}
14
15fn new_seeder() -> Box<dyn Seeder> {
16    let seeder: Box<dyn Seeder>;
17    #[cfg(target_arch = "x86_64")]
18    {
19        if RdseedSeeder::is_available() {
20            seeder = Box::new(RdseedSeeder);
21        } else {
22            seeder = Box::new(UnixSeeder::new(0));
23        }
24    }
25    #[cfg(not(target_arch = "x86_64"))]
26    {
27        seeder = Box::new(UnixSeeder::new(0));
28    }
29
30    seeder
31}
32
33/// Stores buffers associated to a ServerKey
34pub struct Buffers {
35    pub(crate) accumulator: GlweCiphertext64,
36    pub(crate) buffer_lwe_after_ks: LweCiphertext64,
37}
38
39/// This allows to store and retrieve the `Buffers`
40/// corresponding to a `ServerKey` in a `BTreeMap`
41#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)]
42struct KeyId {
43    accumulator_dim: GlweSize,
44    lwe_dim_after_pbs: usize,
45    glwe_size: GlweSize,
46    poly_size: PolynomialSize,
47}
48
49impl ServerKey {
50    #[inline]
51    fn key_id(&self) -> KeyId {
52        KeyId {
53            accumulator_dim: self.bootstrapping_key.glwe_dimension().to_glwe_size(),
54            lwe_dim_after_pbs: self.bootstrapping_key.output_lwe_dimension().0,
55            glwe_size: self.bootstrapping_key.glwe_dimension().to_glwe_size(),
56            poly_size: self.bootstrapping_key.polynomial_size(),
57        }
58    }
59}
60
/// Type-erased error that lets any of the `EngineError` types returned by the
/// `concrete-core` engines be forwarded through a single error type.
///
/// Note: this type cannot itself implement `std::error::Error`. If it did, the
/// blanket `From<T: std::error::Error>` conversion would overlap with the
/// standard reflexive `From<T> for T` implementation.
// The wrapped `error` is presumably only read through the derived `Debug` impl,
// which the dead-code lint does not count as a use; hence the `allow`.
#[allow(dead_code)]
#[derive(Debug)]
pub struct EngineError {
    error: Box<dyn std::error::Error>,
}
68
69impl<T> From<T> for EngineError
70where
71    T: std::error::Error + 'static,
72{
73    fn from(error: T) -> Self {
74        Self {
75            error: Box::new(error),
76        }
77    }
78}
79
80pub(crate) type EngineResult<T> = Result<T, EngineError>;
81
82/// ShortintEngine
83///
84/// This 'engine' holds the necessary engines from `concrete-core`
85/// as well as the buffers that we want to keep around to save processing time.
86///
87/// This structs actually implements the logics into its methods.
88pub struct ShortintEngine {
89    pub(crate) engine: DefaultEngine,
90    pub(crate) fft_engine: FftEngine,
91    pub(crate) par_engine: DefaultParallelEngine,
92    buffers: BTreeMap<KeyId, Buffers>,
93}
94
95impl ShortintEngine {
96    /// Safely gives access to the `thead_local` shortint engine
97    /// to call one (or many) of its method.
98    #[inline]
99    pub fn with_thread_local_mut<F, R>(func: F) -> R
100    where
101        F: FnOnce(&mut Self) -> R,
102    {
103        LOCAL_ENGINE.with(|engine_cell| func(&mut engine_cell.borrow_mut()))
104    }
105
106    /// Creates a new shortint engine
107    ///
108    /// Creating a `ShortintEngine` should not be needed, as each
109    /// rust thread gets its own `thread_local` engine created automatically,
110    /// see [ShortintEngine::with_thread_local_mut]
111    ///
112    ///
113    /// # Panics
114    ///
115    /// This will panic if the `CoreEngine` failed to create.
116    fn new() -> Self {
117        let engine = DefaultEngine::new(new_seeder()).expect("Failed to create a DefaultEngine");
118        let par_engine = DefaultParallelEngine::new(new_seeder())
119            .expect("Failed to create a DefaultParallelEngine");
120        let fft_engine = FftEngine::new(()).unwrap();
121        Self {
122            engine,
123            // fftw_engine,
124            fft_engine,
125            par_engine,
126            buffers: Default::default(),
127        }
128    }
129
130    fn generate_accumulator_with_engine<F>(
131        engine: &mut DefaultEngine,
132        server_key: &ServerKey,
133        f: F,
134    ) -> EngineResult<GlweCiphertext64>
135    where
136        F: Fn(u64) -> u64,
137    {
138        // Modulus of the msg contained in the msg bits and operations buffer
139        let modulus_sup = server_key.message_modulus.0 * server_key.carry_modulus.0;
140
141        // N/(p/2) = size of each block
142        let box_size = server_key.bootstrapping_key.polynomial_size().0 / modulus_sup;
143
144        // Value of the shift we multiply our messages by
145        let delta =
146            (1_u64 << 63) / (server_key.message_modulus.0 * server_key.carry_modulus.0) as u64;
147
148        // Create the accumulator
149        let mut accumulator_u64 = vec![0_u64; server_key.bootstrapping_key.polynomial_size().0];
150
151        // This accumulator extracts the carry bits
152        for i in 0..modulus_sup {
153            let index = i as usize * box_size;
154            accumulator_u64[index..index + box_size]
155                .iter_mut()
156                .for_each(|a| *a = f(i as u64) * delta);
157        }
158
159        let half_box_size = box_size / 2;
160
161        // Negate the first half_box_size coefficients
162        for a_i in accumulator_u64[0..half_box_size].iter_mut() {
163            *a_i = (*a_i).wrapping_neg();
164        }
165
166        // Rotate the accumulator
167        accumulator_u64.rotate_left(half_box_size);
168
169        // Everywhere
170        let accumulator_plaintext = engine.create_plaintext_vector_from(&accumulator_u64)?;
171
172        let accumulator = engine.trivially_encrypt_glwe_ciphertext(
173            server_key.bootstrapping_key.glwe_dimension().to_glwe_size(),
174            &accumulator_plaintext,
175        )?;
176
177        Ok(accumulator)
178    }
179
180    /// Returns the `Buffers` for the given `ServerKey`
181    ///
182    /// Takes care creating the buffers if they do not exists for the given key
183    ///
184    /// This also `&mut CoreEngine` to simply borrow checking for the caller
185    /// (since returned buffers are borrowed from `self`, using the `self.engine`
186    /// wouldn't be possible after calling `buffers_for_key`)
187    pub fn buffers_for_key(
188        &mut self,
189        server_key: &ServerKey,
190    ) -> (&mut Buffers, &mut DefaultEngine, &mut FftEngine) {
191        let key = server_key.key_id();
192        // To make borrow checker happy
193        let engine = &mut self.engine;
194        let buffers_map = &mut self.buffers;
195        let buffers = buffers_map.entry(key).or_insert_with(|| {
196            let accumulator = Self::generate_accumulator_with_engine(engine, server_key, |n| {
197                n % server_key.message_modulus.0 as u64
198            })
199            .unwrap();
200
201            // Allocate the buffer for the output of the PBS
202            let zero_plaintext = engine.create_plaintext_from(&0_u64).unwrap();
203            let buffer_lwe_after_pbs = engine
204                .trivially_encrypt_lwe_ciphertext(
205                    server_key
206                        .key_switching_key
207                        .output_lwe_dimension()
208                        .to_lwe_size(),
209                    &zero_plaintext,
210                )
211                .unwrap();
212
213            Buffers {
214                accumulator,
215                buffer_lwe_after_ks: buffer_lwe_after_pbs,
216            }
217        });
218
219        (buffers, engine, &mut self.fft_engine)
220    }
221}