Skip to main content

vaea_ntt/
rns.rs

1// Copyright (C) 2024-2026 Vaea SAS
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//
4// This file is part of VaeaNTT.
5//
6// VaeaNTT is free software: you can redistribute it and/or modify it under
7// the terms of the GNU Affero General Public License as published by the
8// Free Software Foundation, either version 3 of the License, or (at your
9// option) any later version.
10//
11// VaeaNTT is distributed in the hope that it will be useful, but WITHOUT
12// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13// FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
14// License for more details.
15//
16// You should have received a copy of the GNU Affero General Public License
17// along with VaeaNTT. If not, see <https://www.gnu.org/licenses/>.
18
19
//! # Residue Number System (RNS) — Multi-Moduli Decomposition
//!
//! RNS represents a large integer by its residues modulo several small,
//! pairwise-coprime moduli. Each residue can be processed independently,
//! which makes RNS well suited to parallelism and avoids multi-precision
//! arithmetic.
//!
//! For CKKS, the product Q = q₁·q₂·…·q_L determines the available precision
//! levels; rescaling removes one modulus per level.
29
30use crate::ntt64::arith::Ntt64Arith;
31use crate::ntt64::context::Ntt64Context;
32use crate::poly::Poly64;
33use alloc::vec::Vec;
34
35// ---------------------------------------------------------------------------
36// RnsContext — RNS context
37// ---------------------------------------------------------------------------
38
39/// RNS context: a set of coprime moduli.
40///
41/// Precomputes modular arithmetic and NTT contexts for each modulus,
42/// enabling efficient component-wise polynomial operations.
43pub struct RnsContext {
44    /// The moduli q₁, q₂, …, q_L.
45    pub moduli: Vec<u64>,
46    /// Modular arithmetic contexts for each modulus (Barrett, Montgomery).
47    pub ariths: Vec<Ntt64Arith>,
48    /// NTT contexts for each modulus.
49    pub ntt_ctxs: Vec<Ntt64Context>,
50    /// Polynomial degree N.
51    pub poly_degree: usize,
52}
53
54impl RnsContext {
55    /// Creates an RNS context with the given moduli.
56    ///
57    /// Precomputes all modular arithmetic and NTT contexts.
58    /// Each modulus must be NTT-friendly for the given polynomial degree.
59    ///
60    /// # Panics
61    /// - If `poly_degree` is not a power of 2
62    /// - If `moduli` is empty
63    /// - If any modulus is not NTT-friendly for the given degree
64    pub fn new(poly_degree: usize, moduli: Vec<u64>) -> Self {
65        assert!(
66            poly_degree.is_power_of_two(),
67            "poly_degree must be a power of 2"
68        );
69        assert!(!moduli.is_empty(), "at least one modulus is required");
70
71        let ariths: Vec<Ntt64Arith> = moduli.iter().map(|&q| Ntt64Arith::new(q)).collect();
72
73        let ntt_ctxs: Vec<Ntt64Context> = ariths
74            .iter()
75            .map(|arith| Ntt64Context::new(poly_degree, arith.clone()))
76            .collect();
77
78        Self {
79            moduli,
80            ariths,
81            ntt_ctxs,
82            poly_degree,
83        }
84    }
85
86    /// Number of moduli (= total number of levels).
87    #[inline]
88    pub fn num_moduli(&self) -> usize {
89        self.moduli.len()
90    }
91}
92
93// ---------------------------------------------------------------------------
94// RnsPoly — polynomial in RNS representation
95// ---------------------------------------------------------------------------
96
97/// Polynomial in RNS representation: one component per modulus.
98///
99/// Each component `components[i]` is a polynomial in Z_{q_i}\[X\]/(X^N+1),
100/// stored in NTT domain by default for performance.
101///
102/// The `level` indicates the number of active moduli. CKKS rescaling reduces
103/// the level by removing the last modulus.
104#[derive(Clone, Debug)]
105pub struct RnsPoly {
106    /// `components[i]` = polynomial modulo `moduli[i]`.
107    pub components: Vec<Poly64>,
108    /// Current level (number of active moduli).
109    pub level: usize,
110}
111
112impl RnsPoly {
113    /// Encodes a signed-integer polynomial into RNS representation.
114    ///
115    /// For each modulus q_i:
116    /// 1. Reduces each coefficient mod q_i (handles negatives)
117    /// 2. Converts to NTT domain
118    ///
119    /// # Arguments
120    /// * `coeffs` — polynomial coefficients in Z (signed, coefficient domain)
121    /// * `ctx` — RNS context
122    pub fn from_coefficients(coeffs: &[i64], ctx: &RnsContext) -> Self {
123        let n = ctx.poly_degree;
124        assert!(
125            coeffs.len() <= n,
126            "too many coefficients: {} > {}",
127            coeffs.len(),
128            n
129        );
130
131        let level = ctx.num_moduli();
132        let mut components = Vec::with_capacity(level);
133
134        for i in 0..level {
135            let q = ctx.moduli[i];
136            let q_i64 = q as i64;
137
138            let mut poly = Poly64::new_zero(n);
139            for (j, &c) in coeffs.iter().enumerate() {
140                let r = c % q_i64;
141                poly.data[j] = if r < 0 { (r + q_i64) as u64 } else { r as u64 };
142            }
143
144            poly.forward_ntt(&ctx.ntt_ctxs[i]);
145            components.push(poly);
146        }
147
148        Self { components, level }
149    }
150
151    /// Component-wise addition in RNS.
152    ///
153    /// Both polynomials must have the same level.
154    pub fn add(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
155        assert_eq!(
156            self.level, other.level,
157            "levels must match: {} != {}",
158            self.level, other.level
159        );
160
161        let mut result = self.clone();
162        for i in 0..self.level {
163            result.components[i].add_assign(&other.components[i], &ctx.ariths[i]);
164        }
165        result
166    }
167
168    /// Component-wise subtraction in RNS.
169    pub fn sub(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
170        assert_eq!(self.level, other.level, "levels must match");
171
172        let mut result = self.clone();
173        for i in 0..self.level {
174            result.components[i].sub_assign(&other.components[i], &ctx.ariths[i]);
175        }
176        result
177    }
178
179    /// Component-wise multiplication in RNS (NTT domain).
180    ///
181    /// All components must be in NTT domain.
182    pub fn mul(&self, other: &RnsPoly, ctx: &RnsContext) -> RnsPoly {
183        assert_eq!(self.level, other.level, "levels must match");
184
185        let mut result = self.clone();
186        for i in 0..self.level {
187            result.components[i].mul_assign(&other.components[i], &ctx.ariths[i]);
188        }
189        result
190    }
191
192    /// Drops the last modulus (CKKS rescaling).
193    ///
194    /// After this operation, the level decreases by 1 and the last component
195    /// is removed. The scale factor Δ is implicitly divided by q_L.
196    ///
197    /// # Panics
198    /// Panics if the level is already 1.
199    pub fn drop_last_modulus(&mut self) {
200        assert!(self.level > 1, "cannot reduce level below 1");
201        self.components.pop();
202        self.level -= 1;
203    }
204
205    /// Converts all components from NTT domain to coefficient domain.
206    pub fn forward_all(&mut self, ctx: &RnsContext) {
207        for i in 0..self.level {
208            if !self.components[i].is_ntt {
209                self.components[i].forward_ntt(&ctx.ntt_ctxs[i]);
210            }
211        }
212    }
213
214    /// Converts all components from NTT domain to coefficient domain.
215    pub fn inverse_all(&mut self, ctx: &RnsContext) {
216        for i in 0..self.level {
217            if self.components[i].is_ntt {
218                self.components[i].inverse_ntt(&ctx.ntt_ctxs[i]);
219            }
220        }
221    }
222}
223
224// ---------------------------------------------------------------------------
225// Tests
226// ---------------------------------------------------------------------------
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ntt64::prime::is_prime;
    use alloc::vec;
    use alloc::vec::Vec;

    const TEST_N: usize = 256;
    const TEST_Q1: u64 = 7681; // 15·512+1
    const TEST_Q2: u64 = 12289; // 24·512+1

    /// Two-modulus context shared by most tests.
    fn test_rns_ctx() -> RnsContext {
        RnsContext::new(TEST_N, vec![TEST_Q1, TEST_Q2])
    }

    /// Encoding then decoding returns each coefficient's residue mod q_i.
    #[test]
    fn test_rns_encode_decode() {
        let ctx = test_rns_ctx();
        let coeffs = vec![5i64, -3, 0, 7];
        let mut rns_poly = RnsPoly::from_coefficients(&coeffs, &ctx);

        rns_poly.inverse_all(&ctx);

        assert_eq!(rns_poly.components[0].data[0], 5);
        assert_eq!(rns_poly.components[0].data[1], TEST_Q1 - 3);
        assert_eq!(rns_poly.components[0].data[2], 0);
        assert_eq!(rns_poly.components[0].data[3], 7);

        assert_eq!(rns_poly.components[1].data[0], 5);
        assert_eq!(rns_poly.components[1].data[1], TEST_Q2 - 3);
        assert_eq!(rns_poly.components[1].data[2], 0);
        assert_eq!(rns_poly.components[1].data[3], 7);
    }

    /// Coefficients whose magnitude exceeds a modulus still land in [0, q).
    #[test]
    fn test_rns_encode_reduces_out_of_range_coefficients() {
        let ctx = test_rns_ctx();
        let coeffs = vec![-(TEST_Q1 as i64) - 3, TEST_Q2 as i64 + 4];
        let mut rns_poly = RnsPoly::from_coefficients(&coeffs, &ctx);

        rns_poly.inverse_all(&ctx);

        assert_eq!(rns_poly.components[0].data[0], TEST_Q1 - 3);
        assert_eq!(rns_poly.components[0].data[1], (TEST_Q2 + 4) % TEST_Q1);
        assert_eq!(rns_poly.components[1].data[0], TEST_Q2 - (TEST_Q1 + 3));
        assert_eq!(rns_poly.components[1].data[1], 4);
    }

    /// (a + b) * c == a*c + b*c in every RNS component.
    #[test]
    fn test_rns_add_mul_distributivity() {
        let ctx = test_rns_ctx();

        let a_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| i % 100).collect();
        let b_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| (i * 3 + 7) % 100).collect();
        let c_coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| (i * 2 + 1) % 50).collect();

        let a = RnsPoly::from_coefficients(&a_coeffs, &ctx);
        let b = RnsPoly::from_coefficients(&b_coeffs, &ctx);
        let c = RnsPoly::from_coefficients(&c_coeffs, &ctx);

        // (a + b) * c
        let ab = a.add(&b, &ctx);
        let mut lhs = ab.mul(&c, &ctx);

        // a*c + b*c
        let ac = a.mul(&c, &ctx);
        let bc = b.mul(&c, &ctx);
        let mut rhs = ac.add(&bc, &ctx);

        lhs.inverse_all(&ctx);
        rhs.inverse_all(&ctx);

        for i in 0..ctx.num_moduli() {
            for j in 0..TEST_N {
                assert_eq!(
                    lhs.components[i].data[j], rhs.components[i].data[j],
                    "(a+b)*c != a*c+b*c — modulus {}, coeff {}",
                    ctx.moduli[i], j
                );
            }
        }
    }

    /// inverse_all followed by forward_all and inverse_all is the identity.
    #[test]
    fn test_rns_forward_inverse_roundtrip() {
        let ctx = test_rns_ctx();
        let coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| (i * 7) % 301 - 150).collect();
        let mut poly = RnsPoly::from_coefficients(&coeffs, &ctx);

        poly.inverse_all(&ctx);
        let coeff_domain = poly.clone();

        poly.forward_all(&ctx);
        for comp in &poly.components {
            assert!(comp.is_ntt, "forward_all must leave components in NTT domain");
        }
        poly.inverse_all(&ctx);

        for i in 0..ctx.num_moduli() {
            for j in 0..TEST_N {
                assert_eq!(
                    poly.components[i].data[j], coeff_domain.components[i].data[j],
                    "round-trip mismatch — modulus {}, coeff {}",
                    ctx.moduli[i], j
                );
            }
        }
    }

    /// Dropping a modulus lowers the level and removes one component.
    #[test]
    fn test_rns_drop_last_modulus() {
        let ctx = test_rns_ctx();
        let coeffs = vec![1i64, 2, 3];
        let mut poly = RnsPoly::from_coefficients(&coeffs, &ctx);

        assert_eq!(poly.level, 2);
        assert_eq!(poly.components.len(), 2);

        poly.drop_last_modulus();

        assert_eq!(poly.level, 1);
        assert_eq!(poly.components.len(), 1);
    }

    #[test]
    #[should_panic(expected = "cannot reduce")]
    fn test_rns_drop_last_modulus_panics_at_level_1() {
        let ctx = RnsContext::new(TEST_N, vec![TEST_Q1]);
        let coeffs = vec![1i64];
        let mut poly = RnsPoly::from_coefficients(&coeffs, &ctx);
        poly.drop_last_modulus();
    }

    /// More than N coefficients cannot be encoded.
    #[test]
    #[should_panic(expected = "too many coefficients")]
    fn test_rns_from_coefficients_rejects_oversized_input() {
        let ctx = test_rns_ctx();
        let coeffs = vec![0i64; TEST_N + 1];
        let _ = RnsPoly::from_coefficients(&coeffs, &ctx);
    }

    /// Operands at different levels are rejected.
    #[test]
    #[should_panic(expected = "levels must match")]
    fn test_rns_add_rejects_level_mismatch() {
        let ctx = test_rns_ctx();
        let a = RnsPoly::from_coefficients(&[1, 2, 3], &ctx);
        let mut b = a.clone();
        b.drop_last_modulus();
        let _ = a.add(&b, &ctx);
    }

    /// a - a == 0 in every component.
    #[test]
    fn test_rns_sub() {
        let ctx = test_rns_ctx();
        let coeffs: Vec<i64> = (0..TEST_N as i64).map(|i| i % 1000 - 500).collect();
        let a = RnsPoly::from_coefficients(&coeffs, &ctx);

        let mut zero = a.sub(&a, &ctx);
        zero.inverse_all(&ctx);

        for i in 0..ctx.num_moduli() {
            for j in 0..TEST_N {
                assert_eq!(
                    zero.components[i].data[j], 0,
                    "a - a != 0 — modulus {}, coeff {}",
                    ctx.moduli[i], j
                );
            }
        }
    }

    /// The test moduli are prime and satisfy q ≡ 1 (mod 2N).
    #[test]
    fn test_ntt_friendly_primes_are_valid() {
        assert!(is_prime(TEST_Q1), "q1 = {TEST_Q1} should be prime");
        assert!(is_prime(TEST_Q2), "q2 = {TEST_Q2} should be prime");

        let two_n = 2 * TEST_N as u64;
        assert_eq!((TEST_Q1 - 1) % two_n, 0);
        assert_eq!((TEST_Q2 - 1) % two_n, 0);
    }
}