1use feanor_math::algorithms::matmul::strassen::strassen_mem_size;
2use feanor_math::integer::*;
3use feanor_math::matrix::*;
4use feanor_math::homomorphism::*;
5use feanor_math::seq::*;
6use feanor_math::rings::zn::*;
7use feanor_math::rings::zn::zn_64::*;
8use feanor_math::divisibility::DivisibilityRingStore;
9use feanor_math::primitive_int::*;
10use feanor_math::ring::*;
11use feanor_math::ordered::OrderedRingStore;
12use tracing::instrument;
13
14use std::alloc::Allocator;
15use std::alloc::Global;
16
17use crate::{ZZbig, ZZi64, ZZi128};
18use super::RNSOperation;
19
20pub struct AlmostExactMatrixBaseConversion<A = Global>
37 where A: Allocator + Clone
38{
39 from_summands: Vec<Zn>,
40 to_summands: Vec<Zn>,
41 q_over_Q: Vec<ZnEl>,
43 Q_over_q_mod_and_downscaled: OwnedMatrix<i128>,
46 gamma: i128,
47 Q_mod_q: Vec<ZnEl>,
49 allocator: A
50}
51
/// Bound factor for `any_lift` in `Zn`: the code assumes that lifting an element of `Z/qZ`
/// via `any_lift` yields an integer in `[0, ZN_ANY_LIFT_FACTOR * q]`. `apply()` debug-asserts
/// this, and `new_with()` relies on it for its overflow bounds.
const ZN_ANY_LIFT_FACTOR: i64 = 6;

/// Base-2 logarithm of the side length of the square blocks in which `apply()` performs
/// its matrix multiplication.
const BLOCK_SIZE_LOG2: usize = 4;

/// Rounds `len` up to the next multiple of the block size `1 << BLOCK_SIZE_LOG2`.
///
/// `pad_to_block(0) == 0`; every other length is rounded up to at least one full block.
fn pad_to_block(len: usize) -> usize {
    // `next_multiple_of` has no `len - 1` step, so it cannot underflow for `len == 0`
    len.next_multiple_of(1 << BLOCK_SIZE_LOG2)
}
60
61impl<A> AlmostExactMatrixBaseConversion<A>
62 where A: Allocator + Clone
63{
64 #[instrument(skip_all)]
69 pub fn new_with(in_rings: Vec<Zn>, out_rings: Vec<Zn>, allocator: A) -> Self {
70
71 let Q = ZZbig.prod((0..in_rings.len()).map(|i| int_cast(*in_rings.at(i).modulus(), ZZbig, ZZi64)));
72
73 let max = |l, r| if ZZbig.is_geq(&l, &r) { l } else { r };
74 let max_computation_result = ZZbig.prod([
75 in_rings.iter().map(|ring| int_cast(*ring.modulus() * ZN_ANY_LIFT_FACTOR, ZZbig, ZZi64)).reduce(max).unwrap(),
76 out_rings.iter().map(|ring| int_cast(*ring.modulus(), ZZbig, ZZi64)).reduce(max).unwrap(),
77 ZZbig.int_hom().map(in_rings.len() as i32)
78 ].into_iter());
79 assert!(ZZbig.is_lt(&max_computation_result, &ZZbig.power_of_two(i128::BITS as usize - 1)), "temporarily unreduced modular lift sum will overflow");
80
81 let log2_r = ZZi64.abs_log2_ceil(&(in_rings.len() as i64)).unwrap();
83 let log2_qmax = ZZi64.abs_log2_ceil(&(0..in_rings.len()).map(|i| *in_rings.at(i).modulus()).max().unwrap()).unwrap();
84 let log2_any_lift_factor = ZZi64.abs_log2_ceil(&ZN_ANY_LIFT_FACTOR).unwrap();
85 let gamma = ZZbig.power_of_two(log2_r + log2_qmax + log2_any_lift_factor + 2);
86 assert!(ZZbig.abs_log2_ceil(&gamma).unwrap() + log2_r + log2_any_lift_factor + 1 < ZZi128.get_ring().representable_bits().unwrap(), "correction computation will overflow");
88 let gamma_log2 = ZZbig.abs_log2_ceil(&gamma).unwrap();
89 assert!(gamma_log2 == ZZbig.abs_log2_floor(&gamma).unwrap());
90
91 let Q_over_q = OwnedMatrix::from_fn_in(pad_to_block(out_rings.len() + 1), pad_to_block(in_rings.len()), |i, j| {
92 if i < out_rings.len() && j < in_rings.len() {
93 let ring = out_rings.at(i);
94 ring.smallest_lift(ring.coerce(&ZZbig, ZZbig.checked_div(&Q, &int_cast(*in_rings.at(j).modulus(), ZZbig, ZZi64)).unwrap())) as i128
95 } else if i == out_rings.len() && j < in_rings.len() {
96 int_cast(ZZbig.rounded_div(ZZbig.clone_el(&gamma), &int_cast(*in_rings.at(j).modulus(), ZZbig, ZZi64)), ZZi128, ZZbig)
97 } else {
98 0
99 }
100 }, Global);
101 let q_over_Q = (0..(in_rings.len())).map(|i|
102 in_rings.at(i).invert(&in_rings.at(i).coerce(&ZZbig, ZZbig.checked_div(&Q, &int_cast(*in_rings.at(i).modulus(), ZZbig, ZZi64)).unwrap())).unwrap()
103 ).collect();
104
105 Self {
106 Q_over_q_mod_and_downscaled: Q_over_q,
107 q_over_Q: q_over_Q,
108 Q_mod_q: (0..out_rings.len()).map(|i| out_rings.at(i).coerce(&ZZbig, ZZbig.clone_el(&Q))).collect(),
109 gamma: ZZi128.power_of_two(gamma_log2),
110 allocator: allocator.clone(),
111 from_summands: in_rings,
112 to_summands: out_rings
113 }
114 }
115}
116
117impl<A> RNSOperation for AlmostExactMatrixBaseConversion<A>
118 where A: Allocator + Clone
119{
120 type Ring = Zn;
121
122 type RingType = ZnBase;
123
124 fn input_rings<'a>(&'a self) -> &'a [Zn] {
125 &self.from_summands
126 }
127
128 fn output_rings<'a>(&'a self) -> &'a [Zn] {
129 &self.to_summands
130 }
131
132 #[instrument(skip_all)]
143 fn apply<V1, V2>(&self, input: Submatrix<V1, El<Self::Ring>>, mut output: SubmatrixMut<V2, El<Self::Ring>>)
144 where V1: AsPointerToSlice<El<Self::Ring>>,
145 V2: AsPointerToSlice<El<Self::Ring>>
146 {
147 {
148 assert_eq!(input.row_count(), self.input_rings().len());
149 assert_eq!(output.row_count(), self.output_rings().len());
150 assert_eq!(input.col_count(), output.col_count());
151
152 let in_len = input.row_count();
153 let out_len = output.row_count();
154 let col_count = input.col_count();
155
156 let int_to_homs = (0..self.output_rings().len()).map(|k| self.output_rings().at(k).can_hom(&ZZi128).unwrap()).collect::<Vec<_>>();
157
158 let mut lifts = OwnedMatrix::from_fn_in(pad_to_block(in_len), pad_to_block(col_count), |_, _| 0, self.allocator.clone());
159 let mut lifts = lifts.data_mut();
160
161 for i in 0..in_len {
162 for j in 0..col_count {
163 *lifts.at_mut(i, j) = self.from_summands[i].any_lift(self.from_summands[i].mul_ref(input.at(i, j), self.q_over_Q.at(i))) as i128;
166 debug_assert!(*lifts.at(i, 0) >= 0 && *lifts.at(i, 0) <= ZN_ANY_LIFT_FACTOR as i128 * *self.from_summands[i].modulus() as i128);
167 }
168 }
169
170 let mut output_unreduced = OwnedMatrix::from_fn_in(pad_to_block(out_len + 1), pad_to_block(col_count), |_, _| 0, self.allocator.clone());
171 let mut output_unreduced = output_unreduced.data_mut();
172
173 const STRASSEN_THRESHOLD_LOG2: usize = 3;
176 let mem_size = strassen_mem_size(pad_to_block(in_len) > (1 << BLOCK_SIZE_LOG2), BLOCK_SIZE_LOG2, STRASSEN_THRESHOLD_LOG2);
177 let mut memory = Vec::with_capacity_in(mem_size, self.allocator.clone());
178 memory.resize(mem_size, 0);
179
180 {
181 for i in 0..(pad_to_block(out_len + 1) / (1 << BLOCK_SIZE_LOG2)) {
182 for k in 0..(pad_to_block(in_len) / (1 << BLOCK_SIZE_LOG2)) {
183 for j in 0..(pad_to_block(col_count) / (1 << BLOCK_SIZE_LOG2)) {
184 let rows = (i << BLOCK_SIZE_LOG2)..((i + 1) << BLOCK_SIZE_LOG2);
185 let cols = (j << BLOCK_SIZE_LOG2)..((j + 1) << BLOCK_SIZE_LOG2);
186 let ks = (k << BLOCK_SIZE_LOG2)..((k + 1) << BLOCK_SIZE_LOG2);
187 if k == 0 {
188 feanor_math::algorithms::matmul::strassen::dispatch_strassen_impl::<_, _, _, _, false, false, false, false>(
189 BLOCK_SIZE_LOG2,
190 STRASSEN_THRESHOLD_LOG2,
191 TransposableSubmatrix::from(self.Q_over_q_mod_and_downscaled.data().submatrix(rows.clone(), ks.clone())),
192 TransposableSubmatrix::from(lifts.as_const().submatrix(ks, cols.clone())),
193 TransposableSubmatrixMut::from(output_unreduced.reborrow().submatrix(rows, cols)),
194 StaticRing::<i128>::RING,
195 &mut memory
196 );
197 } else {
198 feanor_math::algorithms::matmul::strassen::dispatch_strassen_impl::<_, _, _, _, true, false, false, false>(
199 BLOCK_SIZE_LOG2,
200 STRASSEN_THRESHOLD_LOG2,
201 TransposableSubmatrix::from(self.Q_over_q_mod_and_downscaled.data().submatrix(rows.clone(), ks.clone())),
202 TransposableSubmatrix::from(lifts.as_const().submatrix(ks, cols.clone())),
203 TransposableSubmatrixMut::from(output_unreduced.reborrow().submatrix(rows, cols)),
204 StaticRing::<i128>::RING,
205 &mut memory
206 );
207 }
208 }
209 }
210 }
211 }
212
213 for j in 0..col_count {
214 let mut correction = *output_unreduced.at(out_len, j);
215 correction = ZZi128.rounded_div(correction, &self.gamma);
216
217 for i in 0..out_len {
218 *output.at_mut(i, j) = self.to_summands[i].sub(
219 int_to_homs.at(i).map_ref(output_unreduced.at(i, j)),
220 self.to_summands[i].mul_ref_snd(int_to_homs[i].map(correction), &self.Q_mod_q[i])
221 );
222 }
223 }
224 }
225 }
226}
227
228#[cfg(test)]
229use feanor_math::assert_el_eq;
230#[cfg(test)]
231use test::Bencher;
232#[cfg(test)]
233use feanor_math::algorithms::miller_rabin::is_prime;
234#[cfg(test)]
235use feanor_math::rings::finite::FiniteRingStore;
236
237#[test]
238fn test_rns_base_conversion() {
239 let from = vec![Zn::new(17), Zn::new(97)];
240 let to = vec![Zn::new(17), Zn::new(97), Zn::new(113), Zn::new(257)];
241
242 let table = AlmostExactMatrixBaseConversion::new_with(from.clone(), to.clone(), Global);
243
244 for k in -(17 * 97 / 4)..=(17 * 97 / 4) {
246 let input = from.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
247 let expected = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
248 let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
249
250 table.apply(
251 Submatrix::from_1d(&input, 2, 1),
252 SubmatrixMut::from_1d(&mut actual, 4, 1)
253 );
254
255 for j in 0..to.len() {
256 assert_el_eq!(to.at(j), expected.at(j), actual.at(j));
257 }
258 }
259
260 for k in (-17 * 97 / 2)..=(17 * 97 / 2) {
261 let input = from.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
262 let expected = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
263 let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
264
265 table.apply(
266 Submatrix::from_1d(&input, 2, 1),
267 SubmatrixMut::from_1d(&mut actual, 4, 1)
268 );
269
270 for j in 0..to.len() {
271 assert!(
272 to.at(j).eq_el(expected.at(j), actual.at(j)) ||
273 to.at(j).eq_el(&to.at(j).add_ref_fst(expected.at(j), to.at(j).int_hom().map(17 * 97)), actual.at(j)) ||
274 to.at(j).eq_el(&to.at(j).sub_ref_fst(expected.at(j), to.at(j).int_hom().map(17 * 97)), actual.at(j))
275 );
276 }
277 }
278}
279
280#[test]
281fn test_rns_base_conversion_small() {
282 let from = vec![Zn::new(3), Zn::new(97)];
283 let to = vec![Zn::new(17)];
284 let table = AlmostExactMatrixBaseConversion::new_with(from.clone(), to.clone(), Global);
285
286 for k in -(97 * 3 / 2)..(97 * 3 / 2) {
287 let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
288 table.apply(
289 Submatrix::from_1d(&[from[0].int_hom().map(k), from[1].int_hom().map(k)], 2, 1),
290 SubmatrixMut::from_1d(&mut actual, 1, 1)
291 );
292
293 assert!(
294 to[0].eq_el(&to[0].int_hom().map(k), actual.at(0)) ||
295 to[0].eq_el(&to[0].int_hom().map(k + 97 * 3), actual.at(0)) ||
296 to[0].eq_el(&to[0].int_hom().map(k - 97 * 3), actual.at(0))
297 );
298 }
299}
300
301#[test]
302fn test_rns_base_conversion_not_coprime() {
303 let from = vec![Zn::new(17), Zn::new(97), Zn::new(113)];
304 let to = vec![Zn::new(17), Zn::new(97), Zn::new(113), Zn::new(257)];
305 let table = AlmostExactMatrixBaseConversion::new_with(from.clone(), to.clone(), Global);
306
307 for k in -(17 * 97 * 113 / 4)..=(17 * 97 * 113 / 4) {
308 let x = from.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
309 let y = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
310 let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
311
312 table.apply(
313 Submatrix::from_1d(&x, 3, 1),
314 SubmatrixMut::from_1d(&mut actual, 4, 1)
315 );
316
317 for i in 0..y.len() {
318 assert!(to[i].eq_el(&y[i], actual.at(i)));
319 }
320 }
321}
322
323#[test]
324fn test_rns_base_conversion_not_coprime_permuted() {
325 let from = vec![Zn::new(113), Zn::new(17), Zn::new(97)];
326 let to = vec![Zn::new(17), Zn::new(97), Zn::new(113), Zn::new(257)];
327 let table = AlmostExactMatrixBaseConversion::new_with(from.clone(), to.clone(), Global);
328
329 for k in -(17 * 97 * 113 / 4)..=(17 * 97 * 113 / 4) {
330 let x = from.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
331 let y = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
332 let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
333
334 table.apply(
335 Submatrix::from_1d(&x, 3, 1),
336 SubmatrixMut::from_1d(&mut actual, 4, 1)
337 );
338
339 for i in 0..y.len() {
340 assert!(to[i].eq_el(&y[i], actual.at(i)));
341 }
342 }
343}
344
345#[test]
346fn test_rns_base_conversion_coprime() {
347 let from = vec![Zn::new(17), Zn::new(97), Zn::new(113)];
348 let to = vec![Zn::new(19), Zn::new(23), Zn::new(257)];
349 let table = AlmostExactMatrixBaseConversion::new_with(from.clone(), to.clone(), Global);
350
351 for k in -(17 * 97 * 113 / 4)..=(17 * 97 * 113 / 4) {
352 let x = from.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
353 let y = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
354 let mut actual = to.iter().map(|Zn| Zn.int_hom().map(k)).collect::<Vec<_>>();
355
356 table.apply(
357 Submatrix::from_1d(&x, 3, 1),
358 SubmatrixMut::from_1d(&mut actual, 3, 1)
359 );
360
361 for i in 0..y.len() {
362 assert!(to[i].eq_el(&y[i], actual.at(i)));
363 }
364 }
365}
366
367#[bench]
368fn bench_rns_base_conversion(bencher: &mut Bencher) {
369 let in_moduli_count = 20;
370 let out_moduli_count = 40;
371 let cols = 1000;
372 let mut primes = ((1 << 30)..).map(|k| (1 << 10) * k + 1).filter(|p| is_prime(&StaticRing::<i64>::RING, p, 10)).map(|p| Zn::new(p as u64));
373 let in_moduli = primes.by_ref().take(in_moduli_count).collect::<Vec<_>>();
374 let out_moduli = primes.take(out_moduli_count).collect::<Vec<_>>();
375 let conv = AlmostExactMatrixBaseConversion::new_with(in_moduli.clone(), out_moduli.clone(), Global);
376
377 let mut rng = oorandom::Rand64::new(1);
378 let mut in_data = (0..(in_moduli_count * cols)).map(|idx| in_moduli[idx / cols].zero()).collect::<Vec<_>>();
379 let mut in_matrix = SubmatrixMut::from_1d(&mut in_data, in_moduli_count, cols);
380 let mut out_data = (0..(out_moduli_count * cols)).map(|idx| out_moduli[idx / cols].zero()).collect::<Vec<_>>();
381 let mut out_matrix = SubmatrixMut::from_1d(&mut out_data, out_moduli_count, cols);
382
383 bencher.iter(|| {
384 for i in 0..in_moduli_count {
385 for j in 0..cols {
386 *in_matrix.at_mut(i, j) = in_moduli[i].random_element(|| rng.rand_u64());
387 }
388 }
389 conv.apply(in_matrix.as_const(), out_matrix.reborrow());
390 for i in 0..out_moduli_count {
391 for j in 0..cols {
392 std::hint::black_box(out_matrix.at(i, j));
393 }
394 }
395 });
396}
397
398#[test]
399fn test_base_conversion_large() {
400 let primes: [i64; 34] = [
401 72057594040066049,
402 288230376150870017,
403 288230376150876161,
404 288230376150878209,
405 288230376150890497,
406 288230376150945793,
407 288230376150956033,
408 288230376151062529,
409 288230376151123969,
410 288230376151130113,
411 288230376151191553,
412 288230376151388161,
413 288230376151422977,
414 288230376151529473,
415 288230376151545857,
416 288230376151554049,
417 288230376151601153,
418 288230376151625729,
419 288230376151683073,
420 288230376151748609,
421 288230376151760897,
422 288230376151779329,
423 288230376151812097,
424 288230376151902209,
425 288230376151951361,
426 288230376151994369,
427 288230376152027137,
428 288230376152061953,
429 288230376152137729,
430 288230376152154113,
431 288230376152156161,
432 288230376152205313,
433 288230376152227841,
434 288230376152340481,
435 ];
436 let in_len = 17;
437 let from = &primes[..in_len];
438 let from_prod = ZZbig.prod(from.iter().map(|p| int_cast(*p, ZZbig, StaticRing::<i64>::RING)));
439 let to = &primes[in_len..];
440 let number = ZZbig.get_ring().parse("156545561910861509258548850310120795193837265771491906959215072510998373539323526014165281634346450795208120921520265422129013635769405993324585707811035953253906720513250161495607960734366886366296007741500531044904559075687514262946086011957808717474666493477109586105297965072817051127737667010", 10).unwrap();
441 assert!(ZZbig.is_lt(&number, &from_prod));
442
443 let from = from.iter().map(|p| Zn::new(*p as u64)).collect::<Vec<_>>();
444 let to = to.iter().map(|p| Zn::new(*p as u64)).collect::<Vec<_>>();
445 let conversion = AlmostExactMatrixBaseConversion::new_with(from, to, Global);
446
447 let input = (0..in_len).map(|i| conversion.input_rings().at(i).coerce(&ZZbig, ZZbig.clone_el(&number))).collect::<Vec<_>>();
448 let expected = (0..(primes.len() - in_len)).map(|i| conversion.output_rings().at(i).coerce(&ZZbig, ZZbig.clone_el(&number))).collect::<Vec<_>>();
449 let mut output = (0..(primes.len() - in_len)).map(|i| conversion.output_rings().at(i).zero()).collect::<Vec<_>>();
450 conversion.apply(Submatrix::from_1d(&input, in_len, 1), SubmatrixMut::from_1d(&mut output, primes.len() - in_len, 1));
451
452 assert!(
453 expected.iter().zip(output.iter()).enumerate().all(|(i, (e, a))| conversion.output_rings().at(i).eq_el(e, a)) ||
454 expected.iter().zip(output.iter()).enumerate().all(|(i, (e, a))| conversion.output_rings().at(i).eq_el(e, &conversion.output_rings().at(i).add_ref_fst(a, conversion.output_rings().at(i).coerce(&ZZbig, ZZbig.clone_el(&from_prod))))) ||
455 expected.iter().zip(output.iter()).enumerate().all(|(i, (e, a))| conversion.output_rings().at(i).eq_el(e, &conversion.output_rings().at(i).sub_ref_fst(a, conversion.output_rings().at(i).coerce(&ZZbig, ZZbig.clone_el(&from_prod)))))
456 );
457}