1use crate::ntt64::arith::Ntt64Arith;
31use crate::ntt64::context::Ntt64Context;
32use crate::poly::Poly64;
33use alloc::vec::Vec;
34
35pub struct RnsContext {
44 pub moduli: Vec<u64>,
46 pub ariths: Vec<Ntt64Arith>,
48 pub ntt_ctxs: Vec<Ntt64Context>,
50 pub poly_degree: usize,
52}
53
54impl RnsContext {
55 pub fn new(poly_degree: usize, moduli: Vec<u64>) -> Self {
65 assert!(
66 poly_degree.is_power_of_two(),
67 "poly_degree must be a power of 2"
68 );
69 assert!(!moduli.is_empty(), "at least one modulus is required");
70
71 let ariths: Vec<Ntt64Arith> = moduli.iter().map(|&q| Ntt64Arith::new(q)).collect();
72
73 let ntt_ctxs: Vec<Ntt64Context> = ariths
74 .iter()
75 .map(|arith| Ntt64Context::new(poly_degree, arith.clone()))
76 .collect();
77
78 Self {
79 moduli,
80 ariths,
81 ntt_ctxs,
82 poly_degree,
83 }
84 }
85
86 #[inline]
88 pub fn num_moduli(&self) -> usize {
89 self.moduli.len()
90 }
91}
92
93#[derive(Clone, Debug)]
105pub struct RnsPoly {
106 pub components: Vec<Poly64>,
108 pub level: usize,
110}
111
112impl RnsPoly {
113 pub fn from_coefficients(coeffs: &[i64], ctx: &RnsContext) -> Self {
123 let n = ctx.poly_degree;
124 assert!(
125 coeffs.len() <= n,
126 "too many coefficients: {} > {}",
127 coeffs.len(),
128 n
129 );
130
131 let level = ctx.num_moduli();
132 let mut components = Vec::with_capacity(level);
133
134 for i in 0..level {
135 let q = ctx.moduli[i];
136 let q_i64 = q as i64;
137
138 let mut poly = Poly64::new_zero(n);
139 for (j, &c) in coeffs.iter().enumerate() {
140 let r = c % q_i64;
141 poly.data[j] = if r < 0 { (r + q_i64) as u64 } else { r as u64 };
142 }
143
144 poly.forward_ntt(&ctx.ntt_ctxs[i]);
145 components.push(poly);
146 }
147
148 Self { components, level }
149 }
150
151 pub fn add(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
155 assert_eq!(
156 self.level, other.level,
157 "levels must match: {} != {}",
158 self.level, other.level
159 );
160
161 let mut result = self.clone();
162 for i in 0..self.level {
163 result.components[i].add_assign(&other.components[i], &ctx.ariths[i]);
164 }
165 result
166 }
167
168 pub fn sub(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
170 assert_eq!(self.level, other.level, "levels must match");
171
172 let mut result = self.clone();
173 for i in 0..self.level {
174 result.components[i].sub_assign(&other.components[i], &ctx.ariths[i]);
175 }
176 result
177 }
178
179 pub fn mul(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
183 assert_eq!(self.level, other.level, "levels must match");
184
185 let mut result = self.clone();
186 for i in 0..self.level {
187 result.components[i].mul_assign(&other.components[i], &ctx.ariths[i]);
188 }
189 result
190 }
191
192 pub fn drop_last_modulus(&mut self) {
200 assert!(self.level > 1, "cannot reduce level below 1");
201 self.components.pop();
202 self.level -= 1;
203 }
204
205 pub fn forward_all(&mut self, ctx: &RnsContext) {
207 for i in 0..self.level {
208 if !self.components[i].is_ntt {
209 self.components[i].forward_ntt(&ctx.ntt_ctxs[i]);
210 }
211 }
212 }
213
214 pub fn inverse_all(&mut self, ctx: &RnsContext) {
216 for i in 0..self.level {
217 if self.components[i].is_ntt {
218 self.components[i].inverse_ntt(&ctx.ntt_ctxs[i]);
219 }
220 }
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ntt64::prime::is_prime;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Ring dimension used by every test.
    const TEST_N: usize = 256;
    /// Small NTT-friendly primes; `test_ntt_friendly_primes_are_valid`
    /// checks that both are prime and `≡ 1 (mod 2 * TEST_N)`.
    const TEST_Q1: u64 = 7681;
    const TEST_Q2: u64 = 12289;

    /// Two-modulus context shared by most tests.
    fn test_rns_ctx() -> RnsContext {
        RnsContext::new(TEST_N, vec![TEST_Q1, TEST_Q2])
    }

    /// Canonical residue of `c` modulo `q`, computed independently of the
    /// code under test (via `i128::rem_euclid`).
    fn expected_residue(c: i64, q: u64) -> u64 {
        i128::from(c).rem_euclid(i128::from(q)) as u64
    }

    #[test]
    fn test_rns_encode_decode() {
        let ctx = test_rns_ctx();
        let coeffs = vec![5i64, -3, 0, 7];
        let mut rns_poly = RnsPoly::from_coefficients(&coeffs, &ctx);

        rns_poly.inverse_all(&ctx);

        assert_eq!(rns_poly.components[0].data[0], 5);
        assert_eq!(rns_poly.components[0].data[1], TEST_Q1 - 3);
        assert_eq!(rns_poly.components[0].data[2], 0);
        assert_eq!(rns_poly.components[0].data[3], 7);

        assert_eq!(rns_poly.components[1].data[0], 5);
        assert_eq!(rns_poly.components[1].data[1], TEST_Q2 - 3);
        assert_eq!(rns_poly.components[1].data[2], 0);
        assert_eq!(rns_poly.components[1].data[3], 7);
    }

    #[test]
    fn test_rns_encode_reduces_out_of_range_coefficients() {
        let ctx = test_rns_ctx();
        let coeffs = [
            TEST_Q1 as i64 + 5,
            -(TEST_Q2 as i64) - 1,
            i64::MAX,
            i64::MIN,
        ];
        let mut rns_poly = RnsPoly::from_coefficients(&coeffs, &ctx);
        rns_poly.inverse_all(&ctx);

        for (i, &q) in ctx.moduli.iter().enumerate() {
            for (j, &c) in coeffs.iter().enumerate() {
                assert_eq!(
                    rns_poly.components[i].data[j],
                    expected_residue(c, q),
                    "wrong residue of {} — modulus {}, coeff {}",
                    c,
                    q,
                    j
                );
            }
        }
    }

    #[test]
    #[should_panic(expected = "too many coefficients")]
    fn test_rns_from_coefficients_rejects_overlong_input() {
        let ctx = test_rns_ctx();
        let coeffs = [0i64; TEST_N + 1];
        let _ = RnsPoly::from_coefficients(&coeffs, &ctx);
    }

    #[test]
    fn test_rns_add_matches_coefficientwise_sum() {
        let ctx = test_rns_ctx();
        let a_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| i * 37 - 4000).collect();
        let b_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| 9000 - i * 11).collect();

        let a = RnsPoly::from_coefficients(&a_coeffs, &ctx);
        let b = RnsPoly::from_coefficients(&b_coeffs, &ctx);
        let mut sum = a.add(&b, &ctx);
        sum.inverse_all(&ctx);

        for (i, &q) in ctx.moduli.iter().enumerate() {
            for j in 0..TEST_N {
                assert_eq!(
                    sum.components[i].data[j],
                    expected_residue(a_coeffs[j] + b_coeffs[j], q),
                    "a + b wrong — modulus {}, coeff {}",
                    q,
                    j
                );
            }
        }
    }

    #[test]
    fn test_rns_mul_low_degree_product() {
        // (1 + 2x) * (3 - x) = 3 + 5x - 2x^2. The product's degree is far below
        // TEST_N, so no reduction modulo x^TEST_N ± 1 is involved.
        let ctx = test_rns_ctx();
        let a = RnsPoly::from_coefficients(&[1, 2], &ctx);
        let b = RnsPoly::from_coefficients(&[3, -1], &ctx);
        let mut prod = a.mul(&b, &ctx);
        prod.inverse_all(&ctx);

        let expected = [3i64, 5, -2];
        for (i, &q) in ctx.moduli.iter().enumerate() {
            for j in 0..TEST_N {
                let c = expected.get(j).copied().unwrap_or(0);
                assert_eq!(
                    prod.components[i].data[j],
                    expected_residue(c, q),
                    "(1 + 2x)(3 - x) wrong — modulus {}, coeff {}",
                    q,
                    j
                );
            }
        }
    }

    #[test]
    fn test_rns_add_mul_distributivity() {
        let ctx = test_rns_ctx();

        let a_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| i % 100).collect();
        let b_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| (i * 3 + 7) % 100).collect();
        let c_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| (i * 2 + 1) % 50).collect();

        let a = RnsPoly::from_coefficients(&a_coeffs, &ctx);
        let b = RnsPoly::from_coefficients(&b_coeffs, &ctx);
        let c = RnsPoly::from_coefficients(&c_coeffs, &ctx);

        // (a + b) * c
        let ab = a.add(&b, &ctx);
        let mut lhs = ab.mul(&c, &ctx);

        // a * c + b * c
        let ac = a.mul(&c, &ctx);
        let bc = b.mul(&c, &ctx);
        let mut rhs = ac.add(&bc, &ctx);

        lhs.inverse_all(&ctx);
        rhs.inverse_all(&ctx);

        for i in 0..ctx.num_moduli() {
            for j in 0..TEST_N {
                assert_eq!(
                    lhs.components[i].data[j], rhs.components[i].data[j],
                    "(a+b)*c != a*c+b*c — modulus {}, coeff {}",
                    ctx.moduli[i], j
                );
            }
        }
    }

    #[test]
    #[should_panic(expected = "levels must match")]
    fn test_rns_add_rejects_level_mismatch() {
        let ctx = test_rns_ctx();
        let a = RnsPoly::from_coefficients(&[1, 2, 3], &ctx);
        let mut b = a.clone();
        b.drop_last_modulus();
        let _ = a.add(&b, &ctx);
    }

    #[test]
    fn test_rns_drop_last_modulus() {
        let ctx = test_rns_ctx();
        let coeffs = vec![1i64, 2, 3];
        let mut poly = RnsPoly::from_coefficients(&coeffs, &ctx);

        assert_eq!(poly.level, 2);
        assert_eq!(poly.components.len(), 2);

        poly.drop_last_modulus();

        assert_eq!(poly.level, 1);
        assert_eq!(poly.components.len(), 1);
    }

    #[test]
    #[should_panic(expected = "cannot reduce")]
    fn test_rns_drop_last_modulus_panics_at_level_1() {
        let ctx = RnsContext::new(TEST_N, vec![TEST_Q1]);
        let coeffs = vec![1i64];
        let mut poly = RnsPoly::from_coefficients(&coeffs, &ctx);
        poly.drop_last_modulus();
    }

    #[test]
    fn test_rns_sub() {
        let ctx = test_rns_ctx();
        let coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| i % 1000 - 500).collect();
        let a = RnsPoly::from_coefficients(&coeffs, &ctx);

        let mut zero = a.sub(&a, &ctx);
        zero.inverse_all(&ctx);

        for i in 0..ctx.num_moduli() {
            for j in 0..TEST_N {
                assert_eq!(
                    zero.components[i].data[j], 0,
                    "a - a != 0 — modulus {}, coeff {}",
                    ctx.moduli[i], j
                );
            }
        }
    }

    #[test]
    fn test_ntt_friendly_primes_are_valid() {
        assert!(is_prime(TEST_Q1), "q1 = {TEST_Q1} should be prime");
        assert!(is_prime(TEST_Q2), "q2 = {TEST_Q2} should be prime");

        let two_n = 2 * TEST_N as u64;
        assert_eq!((TEST_Q1 - 1) % two_n, 0);
        assert_eq!((TEST_Q2 - 1) % two_n, 0);
    }
}