he_ring/rnsconv/
matrix_lift.rs

1use feanor_math::algorithms::matmul::strassen::strassen_mem_size;
2use feanor_math::integer::*;
3use feanor_math::matrix::*;
4use feanor_math::homomorphism::*;
5use feanor_math::seq::*;
6use feanor_math::rings::zn::*;
7use feanor_math::rings::zn::zn_64::*;
8use feanor_math::divisibility::DivisibilityRingStore;
9use feanor_math::primitive_int::*;
10use feanor_math::ring::*;
11use feanor_math::ordered::OrderedRingStore;
12use tracing::instrument;
13
14use std::alloc::Allocator;
15use std::alloc::Global;
16
17use crate::{ZZbig, ZZi64, ZZi128};
18use super::RNSOperation;
19
20///
21/// Stores values for an almost exact conversion between RNS bases.
22/// A complete conversion refers to the function
23/// ```text
24///   Z/QZ -> Z/Q'Z, x -> [lift(x)]
25/// ```
26/// In our case, the output of the function is allowed to have an error of `{ -Q, 0, Q }`,
27/// unless the shortest lift of the input is bounded by `Q/4`, in which case the result
28/// is always correct.
29/// 
30/// # Implementation
31/// 
32/// Similar to [`super::lift::AlmostExactBaseConversion`], but this
33/// implementation makes some assumptions on the sizes of the moduli, which allows
34/// to use a matrix multiplication for the performance-critical section.
35/// 
36pub struct AlmostExactMatrixBaseConversion<A = Global>
37    where A: Allocator + Clone
38{
39    from_summands: Vec<Zn>,
40    to_summands: Vec<Zn>,
41    /// the values `q/Q mod q` for each RNS factor q dividing Q (ordered as `from_summands`)
42    q_over_Q: Vec<ZnEl>,
43    /// shortest lifts of the values `Q/q mod q'` for each RNS factor q dividing Q (ordered as `from_summands_ordered`) and q' dividing Q';
44    /// finally, the last row are the values `gamma/q'` for each RNS factor q dividing Q (ordered as `from_summands_ordered`)
45    Q_over_q_mod_and_downscaled: OwnedMatrix<i128>,
46    gamma: i128,
47    /// `Q mod q'` for every `q'` dividing `Q'`
48    Q_mod_q: Vec<ZnEl>,
49    allocator: A
50}
51
// we currently use `any_lift()`; this is not documented anywhere yet, but the largest output of
// `zn_64::Zn::any_lift()` is currently `6 * modulus()`
const ZN_ANY_LIFT_FACTOR: i64 = 6;

// log2 of the edge length of the square blocks in which the matrix multiplication is performed;
// all matrices taking part in it are zero-padded to a multiple of `1 << BLOCK_SIZE_LOG2` in each dimension
const BLOCK_SIZE_LOG2: usize = 4;

///
/// Rounds `len` up to the next multiple of the block size `1 << BLOCK_SIZE_LOG2`.
/// 
/// In particular, `pad_to_block(0) == 0`, so empty dimensions stay empty.
///
fn pad_to_block(len: usize) -> usize {
    len.next_multiple_of(1 << BLOCK_SIZE_LOG2)
}
60
61impl<A> AlmostExactMatrixBaseConversion<A> 
62    where A: Allocator + Clone
63{
64    ///
65    /// Creates a new [`AlmostExactMatrixBaseConversion`] from `q` to `q'`. The moduli belonging to `q'`
66    /// are expected to be sorted.
67    /// 
68    #[instrument(skip_all)]
69    pub fn new_with(in_rings: Vec<Zn>, out_rings: Vec<Zn>, allocator: A) -> Self {
70        
71        let Q = ZZbig.prod((0..in_rings.len()).map(|i| int_cast(*in_rings.at(i).modulus(), ZZbig, ZZi64)));
72
73        let max = |l, r| if ZZbig.is_geq(&l, &r) { l } else { r };
74        let max_computation_result = ZZbig.prod([
75            in_rings.iter().map(|ring| int_cast(*ring.modulus() * ZN_ANY_LIFT_FACTOR, ZZbig, ZZi64)).reduce(max).unwrap(),
76            out_rings.iter().map(|ring| int_cast(*ring.modulus(), ZZbig, ZZi64)).reduce(max).unwrap(),
77            ZZbig.int_hom().map(in_rings.len() as i32)
78        ].into_iter());
79        assert!(ZZbig.is_lt(&max_computation_result, &ZZbig.power_of_two(i128::BITS as usize - 1)), "temporarily unreduced modular lift sum will overflow");
80
81        // When computing the approximate lifted value, we can work with `gamma` in place of `Q`, where `gamma >= 4 r max(q)` (`q` runs through the input factors)
82        let log2_r = ZZi64.abs_log2_ceil(&(in_rings.len() as i64)).unwrap();
83        let log2_qmax = ZZi64.abs_log2_ceil(&(0..in_rings.len()).map(|i| *in_rings.at(i).modulus()).max().unwrap()).unwrap();
84        let log2_any_lift_factor = ZZi64.abs_log2_ceil(&ZN_ANY_LIFT_FACTOR).unwrap();
85        let gamma = ZZbig.power_of_two(log2_r + log2_qmax + log2_any_lift_factor + 2);
86        // we compute a sum of `r` summands, each being a product of a lifted value (mod `q`, `q | Q`) and `gamma/q`; this must not overflow
87        assert!(ZZbig.abs_log2_ceil(&gamma).unwrap() + log2_r + log2_any_lift_factor + 1 < ZZi128.get_ring().representable_bits().unwrap(), "correction computation will overflow");
88        let gamma_log2 = ZZbig.abs_log2_ceil(&gamma).unwrap();
89        assert!(gamma_log2 == ZZbig.abs_log2_floor(&gamma).unwrap());
90
91        let Q_over_q = OwnedMatrix::from_fn_in(pad_to_block(out_rings.len() + 1), pad_to_block(in_rings.len()), |i, j| {
92            if i < out_rings.len() && j < in_rings.len() {
93                let ring = out_rings.at(i);
94                ring.smallest_lift(ring.coerce(&ZZbig, ZZbig.checked_div(&Q, &int_cast(*in_rings.at(j).modulus(), ZZbig, ZZi64)).unwrap())) as i128
95            } else if i == out_rings.len() && j < in_rings.len() {
96                int_cast(ZZbig.rounded_div(ZZbig.clone_el(&gamma), &int_cast(*in_rings.at(j).modulus(), ZZbig, ZZi64)), ZZi128, ZZbig)
97            } else {
98                0
99            }
100        }, Global);
101        let q_over_Q = (0..(in_rings.len())).map(|i| 
102            in_rings.at(i).invert(&in_rings.at(i).coerce(&ZZbig, ZZbig.checked_div(&Q, &int_cast(*in_rings.at(i).modulus(), ZZbig, ZZi64)).unwrap())).unwrap()
103        ).collect();
104
105        Self {
106            Q_over_q_mod_and_downscaled: Q_over_q,
107            q_over_Q: q_over_Q,
108            Q_mod_q: (0..out_rings.len()).map(|i| out_rings.at(i).coerce(&ZZbig, ZZbig.clone_el(&Q))).collect(),
109            gamma: ZZi128.power_of_two(gamma_log2),
110            allocator: allocator.clone(),
111            from_summands: in_rings,
112            to_summands: out_rings
113        }
114    }
115}
116
117impl<A> RNSOperation for AlmostExactMatrixBaseConversion<A> 
118    where A: Allocator + Clone
119{
120    type Ring = Zn;
121
122    type RingType = ZnBase;
123
124    fn input_rings<'a>(&'a self) -> &'a [Zn] {
125        &self.from_summands
126    }
127
128    fn output_rings<'a>(&'a self) -> &'a [Zn] {
129        &self.to_summands
130    }
131
132    ///
133    /// Performs the (almost) exact RNS base conversion
134    /// ```text
135    ///   Z/QZ -> Z/Q'Z, x -> smallest_lift(x) + kQ mod Q''
136    /// ```
137    /// where `k in { -1, 0, 1 }`.
138    /// 
139    /// Furthermore, if the shortest lift of the input is bounded by `Q/4`,
140    /// then the result is guaranteed to be exact.
141    /// 
142    #[instrument(skip_all)]
143    fn apply<V1, V2>(&self, input: Submatrix<V1, El<Self::Ring>>, mut output: SubmatrixMut<V2, El<Self::Ring>>)
144        where V1: AsPointerToSlice<El<Self::Ring>>,
145            V2: AsPointerToSlice<El<Self::Ring>>
146    {
147        {
148            assert_eq!(input.row_count(), self.input_rings().len());
149            assert_eq!(output.row_count(), self.output_rings().len());
150            assert_eq!(input.col_count(), output.col_count());
151
152            let in_len = input.row_count();
153            let out_len = output.row_count();
154            let col_count = input.col_count();
155
156            let int_to_homs = (0..self.output_rings().len()).map(|k| self.output_rings().at(k).can_hom(&ZZi128).unwrap()).collect::<Vec<_>>();
157
158            let mut lifts = OwnedMatrix::from_fn_in(pad_to_block(in_len), pad_to_block(col_count), |_, _| 0, self.allocator.clone());
159            let mut lifts = lifts.data_mut();
160
161            for i in 0..in_len {
162                for j in 0..col_count {
163                    // using `any_lift()` here is slightly dangerous, as I haven't documented anywhere that `zn_64::Zn::any_lift()` returns values `<= 6 * modulus()`, but
164                    // it currently does, so this is currently fine
165                    *lifts.at_mut(i, j) = self.from_summands[i].any_lift(self.from_summands[i].mul_ref(input.at(i, j), self.q_over_Q.at(i))) as i128;
166                    debug_assert!(*lifts.at(i, 0) >= 0 && *lifts.at(i, 0) <= ZN_ANY_LIFT_FACTOR as i128 * *self.from_summands[i].modulus() as i128);
167                }
168            }
169
170            let mut output_unreduced = OwnedMatrix::from_fn_in(pad_to_block(out_len + 1), pad_to_block(col_count), |_, _| 0, self.allocator.clone());
171            let mut output_unreduced = output_unreduced.data_mut();
172
173            // actually using Strassen's algorithm here doesn't make much of a difference, it is basically as fast as without for normal
174            // parameters; however, this way we can claim superior asymptotic performance :)
175            const STRASSEN_THRESHOLD_LOG2: usize = 3;
176            let mem_size = strassen_mem_size(pad_to_block(in_len) > (1 << BLOCK_SIZE_LOG2), BLOCK_SIZE_LOG2, STRASSEN_THRESHOLD_LOG2);
177            let mut memory = Vec::with_capacity_in(mem_size, self.allocator.clone());
178            memory.resize(mem_size, 0);
179
180            {
181                for i in 0..(pad_to_block(out_len + 1) / (1 << BLOCK_SIZE_LOG2)) {
182                    for k in 0..(pad_to_block(in_len) / (1 << BLOCK_SIZE_LOG2)) {
183                        for j in 0..(pad_to_block(col_count) / (1 << BLOCK_SIZE_LOG2)) {
184                            let rows = (i << BLOCK_SIZE_LOG2)..((i + 1) << BLOCK_SIZE_LOG2);
185                            let cols = (j << BLOCK_SIZE_LOG2)..((j + 1) << BLOCK_SIZE_LOG2);
186                            let ks = (k << BLOCK_SIZE_LOG2)..((k + 1) << BLOCK_SIZE_LOG2);
187                            if k == 0 {
188                                feanor_math::algorithms::matmul::strassen::dispatch_strassen_impl::<_, _, _, _, false, false, false, false>(
189                                    BLOCK_SIZE_LOG2, 
190                                    STRASSEN_THRESHOLD_LOG2, 
191                                    TransposableSubmatrix::from(self.Q_over_q_mod_and_downscaled.data().submatrix(rows.clone(), ks.clone())), 
192                                    TransposableSubmatrix::from(lifts.as_const().submatrix(ks, cols.clone())), 
193                                    TransposableSubmatrixMut::from(output_unreduced.reborrow().submatrix(rows, cols)), 
194                                    StaticRing::<i128>::RING, 
195                                    &mut memory
196                                );
197                            } else {   
198                                feanor_math::algorithms::matmul::strassen::dispatch_strassen_impl::<_, _, _, _, true, false, false, false>(
199                                    BLOCK_SIZE_LOG2, 
200                                    STRASSEN_THRESHOLD_LOG2, 
201                                    TransposableSubmatrix::from(self.Q_over_q_mod_and_downscaled.data().submatrix(rows.clone(), ks.clone())), 
202                                    TransposableSubmatrix::from(lifts.as_const().submatrix(ks, cols.clone())), 
203                                    TransposableSubmatrixMut::from(output_unreduced.reborrow().submatrix(rows, cols)), 
204                                    StaticRing::<i128>::RING, 
205                                    &mut memory
206                                );
207                            }
208                        }
209                    }
210                }
211            }
212
213            for j in 0..col_count {
214                let mut correction = *output_unreduced.at(out_len, j);
215                correction = ZZi128.rounded_div(correction, &self.gamma);
216
217                for i in 0..out_len {
218                    *output.at_mut(i, j) = self.to_summands[i].sub(
219                        int_to_homs.at(i).map_ref(output_unreduced.at(i, j)), 
220                        self.to_summands[i].mul_ref_snd(int_to_homs[i].map(correction), &self.Q_mod_q[i])
221                    );
222                }
223            }
224        }
225    }
226}
227
228#[cfg(test)]
229use feanor_math::assert_el_eq;
230#[cfg(test)]
231use test::Bencher;
232#[cfg(test)]
233use feanor_math::algorithms::miller_rabin::is_prime;
234#[cfg(test)]
235use feanor_math::rings::finite::FiniteRingStore;
236
237#[test]
238fn test_rns_base_conversion() {
239    let from = vec![Zn::new(17), Zn::new(97)];
240    let to = vec![Zn::new(17), Zn::new(97), Zn::new(113), Zn::new(257)];
241
242    let table = AlmostExactMatrixBaseConversion::new_with(from.clone(), to.clone(), Global);
243
244    // within this area, we guarantee that no error occurs
245    for k in -(17 * 97 / 4)..=(17 * 97 / 4) {
246        let input = from.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
247        let expected = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
248        let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
249
250        table.apply(
251            Submatrix::from_1d(&input, 2, 1), 
252            SubmatrixMut::from_1d(&mut actual, 4, 1)
253        );
254
255        for j in 0..to.len() {
256            assert_el_eq!(to.at(j), expected.at(j), actual.at(j));
257        }
258    }
259
260    for k in (-17 * 97 / 2)..=(17 * 97 / 2) {
261        let input = from.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
262        let expected = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
263        let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
264
265        table.apply(
266            Submatrix::from_1d(&input, 2, 1), 
267            SubmatrixMut::from_1d(&mut actual, 4, 1)
268        );
269
270        for j in 0..to.len() {
271            assert!(
272                to.at(j).eq_el(expected.at(j), actual.at(j)) ||
273                to.at(j).eq_el(&to.at(j).add_ref_fst(expected.at(j), to.at(j).int_hom().map(17 * 97)), actual.at(j)) ||
274                to.at(j).eq_el(&to.at(j).sub_ref_fst(expected.at(j), to.at(j).int_hom().map(17 * 97)), actual.at(j))
275            );
276        }
277    }
278}
279
280#[test]
281fn test_rns_base_conversion_small() {
282    let from = vec![Zn::new(3), Zn::new(97)];
283    let to = vec![Zn::new(17)];
284    let table = AlmostExactMatrixBaseConversion::new_with(from.clone(), to.clone(), Global);
285    
286    for k in -(97 * 3 / 2)..(97 * 3 / 2) {
287        let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
288        table.apply(
289            Submatrix::from_1d(&[from[0].int_hom().map(k), from[1].int_hom().map(k)], 2, 1), 
290            SubmatrixMut::from_1d(&mut actual, 1, 1)
291        );
292
293        assert!(
294            to[0].eq_el(&to[0].int_hom().map(k), actual.at(0)) ||
295            to[0].eq_el(&to[0].int_hom().map(k + 97 * 3), actual.at(0)) ||
296            to[0].eq_el(&to[0].int_hom().map(k - 97 * 3), actual.at(0))
297        );
298    }
299}
300
301#[test]
302fn test_rns_base_conversion_not_coprime() {
303    let from = vec![Zn::new(17), Zn::new(97), Zn::new(113)];
304    let to = vec![Zn::new(17), Zn::new(97), Zn::new(113), Zn::new(257)];
305    let table = AlmostExactMatrixBaseConversion::new_with(from.clone(), to.clone(), Global);
306
307    for k in -(17 * 97 * 113 / 4)..=(17 * 97 * 113 / 4) {
308        let x = from.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
309        let y = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
310        let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
311
312        table.apply(
313            Submatrix::from_1d(&x, 3, 1), 
314            SubmatrixMut::from_1d(&mut actual, 4, 1)
315        );
316        
317        for i in 0..y.len() {
318            assert!(to[i].eq_el(&y[i], actual.at(i)));
319        }
320    }
321}
322
323#[test]
324fn test_rns_base_conversion_not_coprime_permuted() {
325    let from = vec![Zn::new(113), Zn::new(17), Zn::new(97)];
326    let to = vec![Zn::new(17), Zn::new(97), Zn::new(113), Zn::new(257)];
327    let table = AlmostExactMatrixBaseConversion::new_with(from.clone(), to.clone(), Global);
328
329    for k in -(17 * 97 * 113 / 4)..=(17 * 97 * 113 / 4) {
330        let x = from.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
331        let y = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
332        let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
333
334        table.apply(
335            Submatrix::from_1d(&x, 3, 1), 
336            SubmatrixMut::from_1d(&mut actual, 4, 1)
337        );
338        
339        for i in 0..y.len() {
340            assert!(to[i].eq_el(&y[i], actual.at(i)));
341        }
342    }
343}
344
345#[test]
346fn test_rns_base_conversion_coprime() {
347    let from = vec![Zn::new(17), Zn::new(97), Zn::new(113)];
348    let to = vec![Zn::new(19), Zn::new(23), Zn::new(257)];
349    let table = AlmostExactMatrixBaseConversion::new_with(from.clone(), to.clone(), Global);
350
351    for k in -(17 * 97 * 113 / 4)..=(17 * 97 * 113 / 4) {
352        let x = from.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
353        let y = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
354        let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
355
356        table.apply(
357            Submatrix::from_1d(&x, 3, 1), 
358            SubmatrixMut::from_1d(&mut actual, 3, 1)
359        );
360        
361        for i in 0..y.len() {
362            assert!(to[i].eq_el(&y[i], actual.at(i)));
363        }
364    }
365}
366
367#[bench]
368fn bench_rns_base_conversion(bencher: &mut Bencher) {
369    let in_moduli_count = 20;
370    let out_moduli_count = 40;
371    let cols = 1000;
372    let mut primes = ((1 << 30)..).map(|k| (1 << 10) * k + 1).filter(|p| is_prime(&StaticRing::<i64>::RING, p, 10)).map(|p| Zn::new(p as u64));
373    let in_moduli = primes.by_ref().take(in_moduli_count).collect::<Vec<_>>();
374    let out_moduli = primes.take(out_moduli_count).collect::<Vec<_>>();
375    let conv = AlmostExactMatrixBaseConversion::new_with(in_moduli.clone(), out_moduli.clone(), Global);
376    
377    let mut rng = oorandom::Rand64::new(1);
378    let mut in_data = (0..(in_moduli_count * cols)).map(|idx| in_moduli[idx / cols].zero()).collect::<Vec<_>>();
379    let mut in_matrix = SubmatrixMut::from_1d(&mut in_data, in_moduli_count, cols);
380    let mut out_data = (0..(out_moduli_count * cols)).map(|idx| out_moduli[idx / cols].zero()).collect::<Vec<_>>();
381    let mut out_matrix = SubmatrixMut::from_1d(&mut out_data, out_moduli_count, cols);
382
383    bencher.iter(|| {
384        for i in 0..in_moduli_count {
385            for j in 0..cols {
386                *in_matrix.at_mut(i, j) = in_moduli[i].random_element(|| rng.rand_u64());
387            }
388        }
389        conv.apply(in_matrix.as_const(), out_matrix.reborrow());
390        for i in 0..out_moduli_count {
391            for j in 0..cols {
392                std::hint::black_box(out_matrix.at(i, j));
393            }
394        }
395    });
396}
397
398#[test]
399fn test_base_conversion_large() {
400    let primes: [i64; 34] = [
401        72057594040066049,
402        288230376150870017,
403        288230376150876161,
404        288230376150878209,
405        288230376150890497,
406        288230376150945793,
407        288230376150956033,
408        288230376151062529,
409        288230376151123969,
410        288230376151130113,
411        288230376151191553,
412        288230376151388161,
413        288230376151422977,
414        288230376151529473,
415        288230376151545857,
416        288230376151554049,
417        288230376151601153,
418        288230376151625729,
419        288230376151683073,
420        288230376151748609,
421        288230376151760897,
422        288230376151779329,
423        288230376151812097,
424        288230376151902209,
425        288230376151951361,
426        288230376151994369,
427        288230376152027137,
428        288230376152061953,
429        288230376152137729,
430        288230376152154113,
431        288230376152156161,
432        288230376152205313,
433        288230376152227841,
434        288230376152340481,
435    ];
436    let in_len = 17;
437    let from = &primes[..in_len];
438    let from_prod = ZZbig.prod(from.iter().map(|p| int_cast(*p, ZZbig, StaticRing::<i64>::RING)));
439    let to = &primes[in_len..];
440    let number = ZZbig.get_ring().parse("156545561910861509258548850310120795193837265771491906959215072510998373539323526014165281634346450795208120921520265422129013635769405993324585707811035953253906720513250161495607960734366886366296007741500531044904559075687514262946086011957808717474666493477109586105297965072817051127737667010", 10).unwrap();
441    assert!(ZZbig.is_lt(&number, &from_prod));
442    
443    let from = from.iter().map(|p| Zn::new(*p as u64)).collect::<Vec<_>>();
444    let to = to.iter().map(|p| Zn::new(*p as u64)).collect::<Vec<_>>();
445    let conversion = AlmostExactMatrixBaseConversion::new_with(from, to, Global);
446
447    let input = (0..in_len).map(|i| conversion.input_rings().at(i).coerce(&ZZbig, ZZbig.clone_el(&number))).collect::<Vec<_>>();
448    let expected = (0..(primes.len() - in_len)).map(|i| conversion.output_rings().at(i).coerce(&ZZbig, ZZbig.clone_el(&number))).collect::<Vec<_>>();
449    let mut output = (0..(primes.len() - in_len)).map(|i| conversion.output_rings().at(i).zero()).collect::<Vec<_>>();
450    conversion.apply(Submatrix::from_1d(&input, in_len, 1), SubmatrixMut::from_1d(&mut output, primes.len() - in_len, 1));
451
452    assert!(
453        expected.iter().zip(output.iter()).enumerate().all(|(i, (e, a))| conversion.output_rings().at(i).eq_el(e, a)) ||
454        expected.iter().zip(output.iter()).enumerate().all(|(i, (e, a))| conversion.output_rings().at(i).eq_el(e, &conversion.output_rings().at(i).add_ref_fst(a, conversion.output_rings().at(i).coerce(&ZZbig, ZZbig.clone_el(&from_prod))))) ||
455        expected.iter().zip(output.iter()).enumerate().all(|(i, (e, a))| conversion.output_rings().at(i).eq_el(e, &conversion.output_rings().at(i).sub_ref_fst(a, conversion.output_rings().at(i).coerce(&ZZbig, ZZbig.clone_el(&from_prod)))))
456    );
457}