1use itertools::Itertools;
2use num_bigint::BigUint;
3use std::{fmt::Debug, sync::Arc};
4
5use crate::{ntt::NttOperator, rns::RnsContext, zq::Modulus, Error, Result};
6
/// Context shared by polynomials over a chain of moduli: per-modulus
/// arithmetic and NTT operators, an RNS context, and precomputed tables.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct Context {
    /// The moduli this context was built from, in the order given to `Context::new`.
    pub(crate) moduli: Box<[u64]>,
    /// One `Modulus` operator per entry of `moduli` (same order).
    pub(crate) q: Box<[Modulus]>,
    /// RNS context constructed from `moduli`.
    pub(crate) rns: Arc<RnsContext>,
    /// One NTT operator per entry of `moduli` (same order), built for `degree`.
    pub(crate) ops: Box<[NttOperator]>,
    /// The degree passed to `Context::new`; validated there to be a power of two >= 8.
    pub(crate) degree: usize,
    /// Bit-reversal permutation of `0..degree` over log2(degree) bits.
    pub(crate) bitrev: Box<[usize]>,
    /// For every modulus except the last, the inverse of the last modulus
    /// reduced modulo that modulus (one entry fewer than `moduli`).
    pub(crate) inv_last_qi_mod_qj: Box<[u64]>,
    /// `Modulus::shoup` precomputations of the values in `inv_last_qi_mod_qj`.
    pub(crate) inv_last_qi_mod_qj_shoup: Box<[u64]>,
    /// Context built from `moduli` with the last modulus removed; `None` when
    /// this context has a single modulus.
    pub(crate) next_context: Option<Arc<Context>>,
}
20
21impl Debug for Context {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("Context")
24 .field("moduli", &self.moduli)
25 .field("next_context", &self.next_context)
33 .finish()
34 }
35}
36
37impl Context {
38 pub fn new(moduli: &[u64], degree: usize) -> Result<Self> {
43 if !degree.is_power_of_two() || degree < 8 {
44 Err(Error::Default(
45 "The degree is not a power of two larger or equal to 8".to_string(),
46 ))
47 } else {
48 let rns = Arc::new(RnsContext::new(moduli)?);
49 let (q, ops): (Vec<Modulus>, Vec<NttOperator>) = moduli
50 .iter()
51 .map(|modulus| {
52 let qi = Modulus::new(*modulus)?;
53 NttOperator::new(&qi, degree)
54 .ok_or_else(|| {
55 Error::Default("Impossible to construct a Ntt operator".to_string())
56 })
57 .map(|op| (qi, op))
58 })
59 .collect::<Result<Vec<(Modulus, NttOperator)>>>()?
60 .into_iter()
61 .unzip();
62 let bitrev = (0..degree)
63 .map(|j| j.reverse_bits() >> (degree.leading_zeros() + 1))
64 .collect_vec();
65
66 let mut inv_last_qi_mod_qj = vec![];
67 let mut inv_last_qi_mod_qj_shoup = vec![];
68 let q_last = moduli.last().unwrap();
69 for qi in &q[..q.len() - 1] {
70 let inv = qi.inv(qi.reduce(*q_last)).unwrap();
71 inv_last_qi_mod_qj.push(inv);
72 inv_last_qi_mod_qj_shoup.push(qi.shoup(inv));
73 }
74
75 let next_context = if moduli.len() >= 2 {
76 Some(Arc::new(Context::new(&moduli[..moduli.len() - 1], degree)?))
77 } else {
78 None
79 };
80
81 Ok(Self {
82 moduli: moduli.to_owned().into_boxed_slice(),
83 q: q.into_boxed_slice(),
84 rns,
85 ops: ops.into_boxed_slice(),
86 degree,
87 bitrev: bitrev.into_boxed_slice(),
88 inv_last_qi_mod_qj: inv_last_qi_mod_qj.into_boxed_slice(),
89 inv_last_qi_mod_qj_shoup: inv_last_qi_mod_qj_shoup.into_boxed_slice(),
90 next_context,
91 })
92 }
93 }
94
95 pub fn new_arc(moduli: &[u64], degree: usize) -> Result<Arc<Self>> {
97 Self::new(moduli, degree).map(Arc::new)
98 }
99
100 pub fn modulus(&self) -> &BigUint {
102 self.rns.modulus()
103 }
104
105 pub fn moduli(&self) -> &[u64] {
107 &self.moduli
108 }
109
110 pub fn moduli_operators(&self) -> &[Modulus] {
112 &self.q
113 }
114
115 pub fn niterations_to(&self, context: &Arc<Context>) -> Result<usize> {
118 if context.as_ref() == self {
119 return Ok(0);
120 }
121
122 let mut niterations = 0;
123 let mut found = false;
124 let mut current_ctx = Arc::new(self.clone());
125 while current_ctx.next_context.is_some() {
126 niterations += 1;
127 current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
128 if ¤t_ctx == context {
129 found = true;
130 break;
131 }
132 }
133 if found {
134 Ok(niterations)
135 } else {
136 Err(Error::InvalidContext)
137 }
138 }
139
140 pub fn context_at_level(&self, i: usize) -> Result<Arc<Self>> {
142 if i >= self.moduli.len() {
143 Err(Error::Default(
144 "No context at the specified level".to_string(),
145 ))
146 } else {
147 let mut current_ctx = Arc::new(self.clone());
148 for _ in 0..i {
149 current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
150 }
151 Ok(current_ctx)
152 }
153 }
154}
155
#[cfg(test)]
mod tests {
    use std::{error::Error, sync::Arc};

    use crate::ntt::supports_ntt;
    use crate::rq::Context;

    /// Moduli shared by all tests in this module.
    const MODULI: &[u64; 5] = &[
        1153,
        4611686018326724609,
        4611686018309947393,
        4611686018232352769,
        4611686018171535361,
    ];

    /// Single-modulus contexts succeed at degree 16. At degree 128 they succeed
    /// exactly when the modulus supports the NTT. The full list works at degree
    /// 16 and is rejected at degree 128.
    #[test]
    fn context_constructor() {
        for modulus in MODULI {
            assert!(Context::new(&[*modulus], 16).is_ok());

            if supports_ntt(*modulus, 128) {
                assert!(Context::new(&[*modulus], 128).is_ok());
            } else {
                assert!(Context::new(&[*modulus], 128).is_err());
            }
        }

        assert!(Context::new(MODULI, 16).is_ok());

        assert!(Context::new(MODULI, 128).is_err());
    }

    /// The child context drops the last modulus, and the chain has one child
    /// per modulus beyond the first.
    #[test]
    fn next_context() -> Result<(), Box<dyn Error>> {
        let context = Arc::new(Context::new(MODULI, 16)?);

        assert_eq!(
            context.next_context,
            Some(Arc::new(Context::new(&MODULI[..MODULI.len() - 1], 16)?))
        );

        let mut number_of_children = 0;
        let mut current = context;
        while current.next_context.is_some() {
            number_of_children += 1;
            current = current.next_context.as_ref().unwrap().clone();
        }
        assert_eq!(number_of_children, MODULI.len() - 1);

        Ok(())
    }

    /// `niterations_to` returns 0 for the context itself, fails for contexts
    /// outside the chain, and counts the steps to each child context.
    #[test]
    fn niterations_to() -> Result<(), Box<dyn Error>> {
        let context = Arc::new(Context::new(MODULI, 16)?);

        assert_eq!(context.niterations_to(&context).ok(), Some(0));

        assert_eq!(
            context
                .niterations_to(&Arc::new(Context::new(&MODULI[1..], 16)?))
                .err(),
            Some(crate::Error::InvalidContext)
        );

        for i in 1..MODULI.len() {
            assert_eq!(
                context
                    .niterations_to(&Arc::new(Context::new(&MODULI[..MODULI.len() - i], 16)?))
                    .ok(),
                Some(i)
            );
        }

        Ok(())
    }

    /// Level `i` holds the first `MODULI.len() - i` moduli and matches an
    /// independently built context. It is reachable in exactly `i` steps.
    /// Levels past the end of the chain are rejected.
    #[test]
    fn context_at_level() -> Result<(), Box<dyn Error>> {
        let context = Context::new(MODULI, 16)?;

        for i in 0..MODULI.len() {
            let level = context.context_at_level(i)?;
            assert_eq!(level.moduli(), &MODULI[..MODULI.len() - i]);
            assert_eq!(*level, Context::new(&MODULI[..MODULI.len() - i], 16)?);
            assert_eq!(context.niterations_to(&level)?, i);
        }

        assert!(context.context_at_level(MODULI.len()).is_err());

        Ok(())
    }
}