dashu_int/modular/mul.rs

1use crate::{
2    add,
3    arch::word::Word,
4    cmp, div,
5    div_const::ConstLargeDivisor,
6    error::panic_different_rings,
7    helper_macros::debug_assert_zero,
8    memory::{self, Memory, MemoryAllocation},
9    modular::repr::{Reduced, ReducedRepr},
10    mul,
11    primitive::{extend_word, locate_top_word_plus_one, split_dword},
12    shift, sqr,
13};
14use alloc::alloc::Layout;
15use core::ops::{Deref, Mul, MulAssign};
16use num_modular::Reducer;
17
18use super::repr::{ReducedDword, ReducedLarge, ReducedWord};
19
20impl<'a> Mul<Reduced<'a>> for Reduced<'a> {
21    type Output = Reduced<'a>;
22
23    #[inline]
24    fn mul(self, rhs: Reduced<'a>) -> Reduced<'a> {
25        self.mul(&rhs)
26    }
27}
28
29impl<'a> Mul<&Reduced<'a>> for Reduced<'a> {
30    type Output = Reduced<'a>;
31
32    #[inline]
33    fn mul(mut self, rhs: &Reduced<'a>) -> Reduced<'a> {
34        self.mul_assign(rhs);
35        self
36    }
37}
38
39impl<'a> Mul<Reduced<'a>> for &Reduced<'a> {
40    type Output = Reduced<'a>;
41
42    #[inline]
43    fn mul(self, rhs: Reduced<'a>) -> Reduced<'a> {
44        rhs.mul(self)
45    }
46}
47
48impl<'a> Mul<&Reduced<'a>> for &Reduced<'a> {
49    type Output = Reduced<'a>;
50
51    #[inline]
52    fn mul(self, rhs: &Reduced<'a>) -> Reduced<'a> {
53        self.clone().mul(rhs)
54    }
55}
56
57impl<'a> MulAssign<Reduced<'a>> for Reduced<'a> {
58    #[inline]
59    fn mul_assign(&mut self, rhs: Reduced<'a>) {
60        self.mul_assign(&rhs)
61    }
62}
63
64impl<'a> MulAssign<&Reduced<'a>> for Reduced<'a> {
65    #[inline]
66    fn mul_assign(&mut self, rhs: &Reduced<'a>) {
67        match (self.repr_mut(), rhs.repr()) {
68            (ReducedRepr::Single(raw0, ring), ReducedRepr::Single(raw1, ring1)) => {
69                Reduced::check_same_ring_single(ring, ring1);
70                ring.0.mul_in_place(&mut raw0.0, &raw1.0)
71            }
72            (ReducedRepr::Double(raw0, ring), ReducedRepr::Double(raw1, ring1)) => {
73                Reduced::check_same_ring_double(ring, ring1);
74                ring.0.mul_in_place(&mut raw0.0, &raw1.0)
75            }
76            (ReducedRepr::Large(raw0, ring), ReducedRepr::Large(raw1, ring1)) => {
77                Reduced::check_same_ring_large(ring, ring1);
78                let memory_requirement = mul_memory_requirement(ring);
79                let mut allocation = MemoryAllocation::new(memory_requirement);
80                mul_in_place(ring, raw0, raw1, &mut allocation.memory());
81            }
82            _ => panic_different_rings(),
83        }
84    }
85}
86
87impl<'a> Reduced<'a> {
88    /// Calculate target^2 mod m in reduced form
89    ///
90    /// # Examples
91    ///
92    /// ```
93    /// # use dashu_int::{fast_div::ConstDivisor, UBig};
94    /// let p = UBig::from(0x1234u16);
95    /// let ring = ConstDivisor::new(p.clone());
96    /// let a = ring.reduce(4000);
97    /// assert_eq!(a.sqr(), ring.reduce(4000 * 4000));
98    /// ```
99    pub fn sqr(&self) -> Self {
100        match self.repr() {
101            ReducedRepr::Single(raw, ring) => {
102                Reduced::from_single(ReducedWord(ring.0.sqr(raw.0)), ring)
103            }
104            ReducedRepr::Double(raw, ring) => {
105                Reduced::from_double(ReducedDword(ring.0.sqr(raw.0)), ring)
106            }
107            ReducedRepr::Large(raw, ring) => {
108                let mut result = raw.clone();
109                let memory_requirement = mul_memory_requirement(ring);
110                let mut allocation = MemoryAllocation::new(memory_requirement);
111                sqr_in_place(ring, &mut result, &mut allocation.memory());
112                Reduced::from_large(result, ring)
113            }
114        }
115    }
116}
117
118pub(crate) fn mul_memory_requirement(ring: &ConstLargeDivisor) -> Layout {
119    let n = ring.normalized_divisor.len();
120    memory::add_layout(
121        memory::array_layout::<Word>(2 * n),
122        memory::max_layout(
123            mul::memory_requirement_exact(2 * n, n),
124            div::memory_requirement_exact(2 * n, n),
125        ),
126    )
127}
128
129/// Returns a * b allocated in memory.
130pub(crate) fn mul_normalized<'a>(
131    ring: &ConstLargeDivisor,
132    a: &[Word],
133    b: &[Word],
134    memory: &'a mut Memory,
135) -> &'a [Word] {
136    let modulus = ring.normalized_divisor.deref();
137    let n = modulus.len();
138    debug_assert!(a.len() == n && b.len() == n);
139
140    // trim the leading zeros in a, b
141    let na = locate_top_word_plus_one(a);
142    let nb = locate_top_word_plus_one(b);
143
144    // product = a * b
145    let (product, mut memory) = memory.allocate_slice_fill::<Word>(n.max(na + nb), 0);
146    if na | nb == 0 {
147        return product;
148    } else if na == 1 && nb == 1 {
149        let (a0, b0) = (extend_word(a[0]), extend_word(b[0]));
150        let (lo, hi) = split_dword(a0 * b0);
151        product[0] = lo;
152        product[1] = hi;
153    } else {
154        mul::multiply(&mut product[..na + nb], &a[..na], &b[..nb], &mut memory);
155    }
156
157    // return (product >> shift) % normalized_modulus
158    debug_assert_zero!(shift::shr_in_place(product, ring.shift));
159    if na + nb > n {
160        let _overflow = div::div_rem_in_place(product, modulus, ring.fast_div_top, &mut memory);
161        &product[..n]
162    } else {
163        if cmp::cmp_same_len(product, modulus).is_ge() {
164            debug_assert_zero!(add::sub_same_len_in_place(product, modulus));
165        }
166        product
167    }
168}
169
170/// lhs *= rhs
171pub(crate) fn mul_in_place(
172    ring: &ConstLargeDivisor,
173    lhs: &mut ReducedLarge,
174    rhs: &ReducedLarge,
175    memory: &mut Memory,
176) {
177    if lhs.0 == rhs.0 {
178        // shortcut to squaring
179        let prod = sqr_normalized(ring, &lhs.0, memory);
180        lhs.0.copy_from_slice(prod)
181    } else {
182        let prod = mul_normalized(ring, &lhs.0, &rhs.0, memory);
183        lhs.0.copy_from_slice(prod)
184    }
185}
186
187/// Returns a^2 allocated in memory.
188pub(crate) fn sqr_normalized<'a>(
189    ring: &ConstLargeDivisor,
190    a: &[Word],
191    memory: &'a mut Memory,
192) -> &'a [Word] {
193    let modulus = ring.normalized_divisor.deref();
194    let n = modulus.len();
195    debug_assert!(a.len() == n);
196
197    // trim the leading zeros in a
198    let na = locate_top_word_plus_one(a);
199
200    // product = a * a
201    let (product, mut memory) = memory.allocate_slice_fill::<Word>(n.max(na * 2), 0);
202    if na == 0 {
203        return product;
204    } else if na == 1 {
205        let a0 = extend_word(a[0]);
206        let (lo, hi) = split_dword(a0 * a0);
207        product[0] = lo;
208        product[1] = hi;
209    } else {
210        sqr::sqr(&mut product[..na * 2], &a[..na], &mut memory);
211    }
212
213    // return (product >> shift) % normalized_modulus
214    debug_assert_zero!(shift::shr_in_place(product, ring.shift));
215    if na * 2 > n {
216        let _overflow = div::div_rem_in_place(product, modulus, ring.fast_div_top, &mut memory);
217        &product[..n]
218    } else {
219        if cmp::cmp_same_len(product, modulus).is_ge() {
220            debug_assert_zero!(add::sub_same_len_in_place(product, modulus));
221        }
222        product
223    }
224}
225
226/// raw = raw^2
227pub(crate) fn sqr_in_place(ring: &ConstLargeDivisor, raw: &mut ReducedLarge, memory: &mut Memory) {
228    let prod = sqr_normalized(ring, &raw.0, memory);
229    raw.0.copy_from_slice(prod)
230}