1use crate::{
2 add,
3 arch::word::Word,
4 cmp, div,
5 div_const::ConstLargeDivisor,
6 error::panic_different_rings,
7 helper_macros::debug_assert_zero,
8 memory::{self, Memory, MemoryAllocation},
9 modular::repr::{Reduced, ReducedRepr},
10 mul,
11 primitive::{extend_word, locate_top_word_plus_one, split_dword},
12 shift, sqr,
13};
14use alloc::alloc::Layout;
15use core::ops::{Deref, Mul, MulAssign};
16use num_modular::Reducer;
17
18use super::repr::{ReducedDword, ReducedLarge, ReducedWord};
19
20impl<'a> Mul<Reduced<'a>> for Reduced<'a> {
21 type Output = Reduced<'a>;
22
23 #[inline]
24 fn mul(self, rhs: Reduced<'a>) -> Reduced<'a> {
25 self.mul(&rhs)
26 }
27}
28
29impl<'a> Mul<&Reduced<'a>> for Reduced<'a> {
30 type Output = Reduced<'a>;
31
32 #[inline]
33 fn mul(mut self, rhs: &Reduced<'a>) -> Reduced<'a> {
34 self.mul_assign(rhs);
35 self
36 }
37}
38
39impl<'a> Mul<Reduced<'a>> for &Reduced<'a> {
40 type Output = Reduced<'a>;
41
42 #[inline]
43 fn mul(self, rhs: Reduced<'a>) -> Reduced<'a> {
44 rhs.mul(self)
45 }
46}
47
48impl<'a> Mul<&Reduced<'a>> for &Reduced<'a> {
49 type Output = Reduced<'a>;
50
51 #[inline]
52 fn mul(self, rhs: &Reduced<'a>) -> Reduced<'a> {
53 self.clone().mul(rhs)
54 }
55}
56
57impl<'a> MulAssign<Reduced<'a>> for Reduced<'a> {
58 #[inline]
59 fn mul_assign(&mut self, rhs: Reduced<'a>) {
60 self.mul_assign(&rhs)
61 }
62}
63
64impl<'a> MulAssign<&Reduced<'a>> for Reduced<'a> {
65 #[inline]
66 fn mul_assign(&mut self, rhs: &Reduced<'a>) {
67 match (self.repr_mut(), rhs.repr()) {
68 (ReducedRepr::Single(raw0, ring), ReducedRepr::Single(raw1, ring1)) => {
69 Reduced::check_same_ring_single(ring, ring1);
70 ring.0.mul_in_place(&mut raw0.0, &raw1.0)
71 }
72 (ReducedRepr::Double(raw0, ring), ReducedRepr::Double(raw1, ring1)) => {
73 Reduced::check_same_ring_double(ring, ring1);
74 ring.0.mul_in_place(&mut raw0.0, &raw1.0)
75 }
76 (ReducedRepr::Large(raw0, ring), ReducedRepr::Large(raw1, ring1)) => {
77 Reduced::check_same_ring_large(ring, ring1);
78 let memory_requirement = mul_memory_requirement(ring);
79 let mut allocation = MemoryAllocation::new(memory_requirement);
80 mul_in_place(ring, raw0, raw1, &mut allocation.memory());
81 }
82 _ => panic_different_rings(),
83 }
84 }
85}
86
87impl<'a> Reduced<'a> {
88 pub fn sqr(&self) -> Self {
100 match self.repr() {
101 ReducedRepr::Single(raw, ring) => {
102 Reduced::from_single(ReducedWord(ring.0.sqr(raw.0)), ring)
103 }
104 ReducedRepr::Double(raw, ring) => {
105 Reduced::from_double(ReducedDword(ring.0.sqr(raw.0)), ring)
106 }
107 ReducedRepr::Large(raw, ring) => {
108 let mut result = raw.clone();
109 let memory_requirement = mul_memory_requirement(ring);
110 let mut allocation = MemoryAllocation::new(memory_requirement);
111 sqr_in_place(ring, &mut result, &mut allocation.memory());
112 Reduced::from_large(result, ring)
113 }
114 }
115 }
116}
117
118pub(crate) fn mul_memory_requirement(ring: &ConstLargeDivisor) -> Layout {
119 let n = ring.normalized_divisor.len();
120 memory::add_layout(
121 memory::array_layout::<Word>(2 * n),
122 memory::max_layout(
123 mul::memory_requirement_exact(2 * n, n),
124 div::memory_requirement_exact(2 * n, n),
125 ),
126 )
127}
128
129pub(crate) fn mul_normalized<'a>(
131 ring: &ConstLargeDivisor,
132 a: &[Word],
133 b: &[Word],
134 memory: &'a mut Memory,
135) -> &'a [Word] {
136 let modulus = ring.normalized_divisor.deref();
137 let n = modulus.len();
138 debug_assert!(a.len() == n && b.len() == n);
139
140 let na = locate_top_word_plus_one(a);
142 let nb = locate_top_word_plus_one(b);
143
144 let (product, mut memory) = memory.allocate_slice_fill::<Word>(n.max(na + nb), 0);
146 if na | nb == 0 {
147 return product;
148 } else if na == 1 && nb == 1 {
149 let (a0, b0) = (extend_word(a[0]), extend_word(b[0]));
150 let (lo, hi) = split_dword(a0 * b0);
151 product[0] = lo;
152 product[1] = hi;
153 } else {
154 mul::multiply(&mut product[..na + nb], &a[..na], &b[..nb], &mut memory);
155 }
156
157 debug_assert_zero!(shift::shr_in_place(product, ring.shift));
159 if na + nb > n {
160 let _overflow = div::div_rem_in_place(product, modulus, ring.fast_div_top, &mut memory);
161 &product[..n]
162 } else {
163 if cmp::cmp_same_len(product, modulus).is_ge() {
164 debug_assert_zero!(add::sub_same_len_in_place(product, modulus));
165 }
166 product
167 }
168}
169
170pub(crate) fn mul_in_place(
172 ring: &ConstLargeDivisor,
173 lhs: &mut ReducedLarge,
174 rhs: &ReducedLarge,
175 memory: &mut Memory,
176) {
177 if lhs.0 == rhs.0 {
178 let prod = sqr_normalized(ring, &lhs.0, memory);
180 lhs.0.copy_from_slice(prod)
181 } else {
182 let prod = mul_normalized(ring, &lhs.0, &rhs.0, memory);
183 lhs.0.copy_from_slice(prod)
184 }
185}
186
187pub(crate) fn sqr_normalized<'a>(
189 ring: &ConstLargeDivisor,
190 a: &[Word],
191 memory: &'a mut Memory,
192) -> &'a [Word] {
193 let modulus = ring.normalized_divisor.deref();
194 let n = modulus.len();
195 debug_assert!(a.len() == n);
196
197 let na = locate_top_word_plus_one(a);
199
200 let (product, mut memory) = memory.allocate_slice_fill::<Word>(n.max(na * 2), 0);
202 if na == 0 {
203 return product;
204 } else if na == 1 {
205 let a0 = extend_word(a[0]);
206 let (lo, hi) = split_dword(a0 * a0);
207 product[0] = lo;
208 product[1] = hi;
209 } else {
210 sqr::sqr(&mut product[..na * 2], &a[..na], &mut memory);
211 }
212
213 debug_assert_zero!(shift::shr_in_place(product, ring.shift));
215 if na * 2 > n {
216 let _overflow = div::div_rem_in_place(product, modulus, ring.fast_div_top, &mut memory);
217 &product[..n]
218 } else {
219 if cmp::cmp_same_len(product, modulus).is_ge() {
220 debug_assert_zero!(add::sub_same_len_in_place(product, modulus));
221 }
222 product
223 }
224}
225
226pub(crate) fn sqr_in_place(ring: &ConstLargeDivisor, raw: &mut ReducedLarge, memory: &mut Memory) {
228 let prod = sqr_normalized(ring, &raw.0, memory);
229 raw.0.copy_from_slice(prod)
230}