fhe_math/rq/context.rs

1use itertools::Itertools;
2use num_bigint::BigUint;
3use std::{fmt::Debug, sync::Arc};
4
5use crate::{ntt::NttOperator, rns::RnsContext, zq::Modulus, Error, Result};
6
/// Struct that holds the context associated with elements in rq.
///
/// A context is built by `Context::new` from a list of moduli and a
/// polynomial degree, and stores the per-modulus precomputations used by the
/// polynomial arithmetic.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct Context {
    /// The moduli, in the order they were passed to `Context::new`.
    pub(crate) moduli: Box<[u64]>,
    /// One `Modulus` operator per entry of `moduli`, in the same order.
    pub(crate) q: Box<[Modulus]>,
    /// RNS context built from `moduli`.
    pub(crate) rns: Arc<RnsContext>,
    /// One NTT operator of size `degree` per entry of `moduli`.
    pub(crate) ops: Box<[NttOperator]>,
    /// Polynomial degree (validated by `Context::new` to be a power of two >= 8).
    pub(crate) degree: usize,
    /// Bit-reversal permutation of the indices `0..degree`.
    pub(crate) bitrev: Box<[usize]>,
    /// For every modulus except the last, the inverse of the last modulus
    /// modulo that modulus.
    pub(crate) inv_last_qi_mod_qj: Box<[u64]>,
    /// Shoup representations (`Modulus::shoup`) of `inv_last_qi_mod_qj`.
    pub(crate) inv_last_qi_mod_qj_shoup: Box<[u64]>,
    /// Context obtained by dropping the last modulus; `None` when this
    /// context has a single modulus.
    pub(crate) next_context: Option<Arc<Context>>,
}
20
21impl Debug for Context {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("Context")
24            .field("moduli", &self.moduli)
25            // .field("q", &self.q)
26            // .field("rns", &self.rns)
27            // .field("ops", &self.ops)
28            // .field("degree", &self.degree)
29            // .field("bitrev", &self.bitrev)
30            // .field("inv_last_qi_mod_qj", &self.inv_last_qi_mod_qj)
31            // .field("inv_last_qi_mod_qj_shoup", &self.inv_last_qi_mod_qj_shoup)
32            .field("next_context", &self.next_context)
33            .finish()
34    }
35}
36
37impl Context {
38    /// Creates a context from a list of moduli and a polynomial degree.
39    ///
40    /// Returns an error if the moduli are not primes less than 62 bits which
41    /// supports the NTT of size `degree`.
42    pub fn new(moduli: &[u64], degree: usize) -> Result<Self> {
43        if !degree.is_power_of_two() || degree < 8 {
44            Err(Error::Default(
45                "The degree is not a power of two larger or equal to 8".to_string(),
46            ))
47        } else {
48            let rns = Arc::new(RnsContext::new(moduli)?);
49            let (q, ops): (Vec<Modulus>, Vec<NttOperator>) = moduli
50                .iter()
51                .map(|modulus| {
52                    let qi = Modulus::new(*modulus)?;
53                    NttOperator::new(&qi, degree)
54                        .ok_or_else(|| {
55                            Error::Default("Impossible to construct a Ntt operator".to_string())
56                        })
57                        .map(|op| (qi, op))
58                })
59                .collect::<Result<Vec<(Modulus, NttOperator)>>>()?
60                .into_iter()
61                .unzip();
62            let bitrev = (0..degree)
63                .map(|j| j.reverse_bits() >> (degree.leading_zeros() + 1))
64                .collect_vec();
65
66            let mut inv_last_qi_mod_qj = vec![];
67            let mut inv_last_qi_mod_qj_shoup = vec![];
68            let q_last = moduli.last().unwrap();
69            for qi in &q[..q.len() - 1] {
70                let inv = qi.inv(qi.reduce(*q_last)).unwrap();
71                inv_last_qi_mod_qj.push(inv);
72                inv_last_qi_mod_qj_shoup.push(qi.shoup(inv));
73            }
74
75            let next_context = if moduli.len() >= 2 {
76                Some(Arc::new(Context::new(&moduli[..moduli.len() - 1], degree)?))
77            } else {
78                None
79            };
80
81            Ok(Self {
82                moduli: moduli.to_owned().into_boxed_slice(),
83                q: q.into_boxed_slice(),
84                rns,
85                ops: ops.into_boxed_slice(),
86                degree,
87                bitrev: bitrev.into_boxed_slice(),
88                inv_last_qi_mod_qj: inv_last_qi_mod_qj.into_boxed_slice(),
89                inv_last_qi_mod_qj_shoup: inv_last_qi_mod_qj_shoup.into_boxed_slice(),
90                next_context,
91            })
92        }
93    }
94
95    /// Creates a context in an `Arc`.
96    pub fn new_arc(moduli: &[u64], degree: usize) -> Result<Arc<Self>> {
97        Self::new(moduli, degree).map(Arc::new)
98    }
99
100    /// Returns the modulus as a BigUint.
101    pub fn modulus(&self) -> &BigUint {
102        self.rns.modulus()
103    }
104
105    /// Returns a reference to the moduli in this context.
106    pub fn moduli(&self) -> &[u64] {
107        &self.moduli
108    }
109
110    /// Returns a reference to the moduli as Modulus in this context.
111    pub fn moduli_operators(&self) -> &[Modulus] {
112        &self.q
113    }
114
115    /// Returns the number of iterations to switch to a children context.
116    /// Returns an error if the context provided is not a child context.
117    pub fn niterations_to(&self, context: &Arc<Context>) -> Result<usize> {
118        if context.as_ref() == self {
119            return Ok(0);
120        }
121
122        let mut niterations = 0;
123        let mut found = false;
124        let mut current_ctx = Arc::new(self.clone());
125        while current_ctx.next_context.is_some() {
126            niterations += 1;
127            current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
128            if &current_ctx == context {
129                found = true;
130                break;
131            }
132        }
133        if found {
134            Ok(niterations)
135        } else {
136            Err(Error::InvalidContext)
137        }
138    }
139
140    /// Returns the context after `i` iterations.
141    pub fn context_at_level(&self, i: usize) -> Result<Arc<Self>> {
142        if i >= self.moduli.len() {
143            Err(Error::Default(
144                "No context at the specified level".to_string(),
145            ))
146        } else {
147            let mut current_ctx = Arc::new(self.clone());
148            for _ in 0..i {
149                current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
150            }
151            Ok(current_ctx)
152        }
153    }
154}
155
#[cfg(test)]
mod tests {
    use std::{error::Error, sync::Arc};

    use crate::ntt::supports_ntt;
    use crate::rq::Context;

    const MODULI: &[u64; 5] = &[
        1153,
        4611686018326724609,
        4611686018309947393,
        4611686018232352769,
        4611686018171535361,
    ];

    #[test]
    fn context_constructor() {
        for modulus in MODULI {
            // Every modulus is = 1 modulo 2 * 16, so it supports degree 16.
            assert!(Context::new(&[*modulus], 16).is_ok());

            if supports_ntt(*modulus, 128) {
                assert!(Context::new(&[*modulus], 128).is_ok());
            } else {
                assert!(Context::new(&[*modulus], 128).is_err());
            }
        }

        // All moduli in MODULI are = 1 modulo 2 * 16.
        assert!(Context::new(MODULI, 16).is_ok());

        // This should fail since 1153 != 1 modulo 2 * 128.
        assert!(Context::new(MODULI, 128).is_err());
    }

    #[test]
    fn context_constructor_rejects_invalid_degree() {
        // The degree must be a power of two larger or equal to 8.
        assert!(Context::new(MODULI, 0).is_err());
        assert!(Context::new(MODULI, 4).is_err());
        assert!(Context::new(MODULI, 12).is_err());
    }

    #[test]
    fn next_context() -> Result<(), Box<dyn Error>> {
        // A context should have a child pointing to a context with one less modulus.
        let context = Arc::new(Context::new(MODULI, 16)?);
        assert_eq!(
            context.next_context,
            Some(Arc::new(Context::new(&MODULI[..MODULI.len() - 1], 16)?))
        );

        // We can go down the chain of the MODULI.len() - 1 child contexts.
        let mut number_of_children = 0;
        let mut current = context.next_context.as_ref();
        while let Some(child) = current {
            number_of_children += 1;
            current = child.next_context.as_ref();
        }
        assert_eq!(number_of_children, MODULI.len() - 1);

        Ok(())
    }

    #[test]
    fn niterations_to() -> Result<(), Box<dyn Error>> {
        let context = Arc::new(Context::new(MODULI, 16)?);

        // A context is zero iterations away from itself.
        assert_eq!(context.niterations_to(&context).ok(), Some(0));

        // Dropping the first modulus does not produce a child context.
        assert_eq!(
            context
                .niterations_to(&Arc::new(Context::new(&MODULI[1..], 16)?))
                .err(),
            Some(crate::Error::InvalidContext)
        );

        // Dropping the last i moduli takes exactly i iterations.
        for i in 1..MODULI.len() {
            assert_eq!(
                context
                    .niterations_to(&Arc::new(Context::new(&MODULI[..MODULI.len() - i], 16)?))
                    .ok(),
                Some(i)
            );
        }

        Ok(())
    }

    #[test]
    fn context_at_level() -> Result<(), Box<dyn Error>> {
        let context = Context::new(MODULI, 16)?;

        // Level i is the context with the last i moduli dropped.
        for i in 0..MODULI.len() {
            assert_eq!(
                context.context_at_level(i)?,
                Arc::new(Context::new(&MODULI[..MODULI.len() - i], 16)?)
            );
        }

        // There is no level with zero moduli.
        assert!(context.context_at_level(MODULI.len()).is_err());

        Ok(())
    }
}