Skip to main content

malachite_nz/natural/arithmetic/
mod_power_of_2_square.rs

1// Copyright © 2026 Mikhail Hogrefe
2//
3// Uses code adopted from the GNU MP Library.
4//
5//      Copyright © 1991-1994, 1996, 1997, 2000-2005, 2008, 2009, 2010, 2011, 2012, 2015 Free
6//      Software Foundation, Inc.
7//
8// This file is part of Malachite.
9//
10// Malachite is free software: you can redistribute it and/or modify it under the terms of the GNU
11// Lesser General Public License (LGPL) as published by the Free Software Foundation; either version
12// 3 of the License, or (at your option) any later version. See <https://www.gnu.org/licenses/>.
13
14use crate::natural::InnerNatural::{Large, Small};
15use crate::natural::arithmetic::add::limbs_slice_add_same_length_in_place_left;
16use crate::natural::arithmetic::add_mul::limbs_slice_add_mul_limb_same_length_in_place_left;
17use crate::natural::arithmetic::mod_power_of_2::limbs_vec_mod_power_of_2_in_place;
18use crate::natural::arithmetic::mul::limb::limbs_mul_limb_to_out;
19use crate::natural::arithmetic::mul::limbs_mul_greater_to_out_basecase;
20use crate::natural::arithmetic::mul::mul_low::{
21    limbs_mul_low_same_length, limbs_mul_low_same_length_basecase,
22};
23use crate::natural::arithmetic::mul::toom::{TUNE_PROGRAM_BUILD, WANT_FAT_BINARY};
24use crate::natural::arithmetic::shl::{limbs_shl_to_out, limbs_slice_shl_in_place};
25use crate::natural::arithmetic::square::{
26    limbs_square, limbs_square_diagonal, limbs_square_to_out, limbs_square_to_out_basecase,
27    limbs_square_to_out_scratch_len,
28};
29use crate::natural::{Natural, bit_to_limb_count_ceiling};
30use crate::platform::{
31    DoubleLimb, Limb, MULLO_BASECASE_THRESHOLD, MULLO_DC_THRESHOLD, SQR_TOOM2_THRESHOLD,
32    SQR_TOOM3_THRESHOLD, SQR_TOOM4_THRESHOLD, SQR_TOOM8_THRESHOLD, SQRLO_DC_THRESHOLD,
33};
34use alloc::vec::Vec;
35use malachite_base::num::arithmetic::traits::{
36    ModPowerOf2Square, ModPowerOf2SquareAssign, Parity, Square, WrappingSquare,
37};
38use malachite_base::num::basic::integers::PrimitiveInt;
39use malachite_base::num::basic::traits::Zero;
40use malachite_base::num::conversion::traits::SplitInHalf;
41use malachite_base::num::logic::traits::SignificantBits;
42
43// # Worst-case complexity
44// $T(n) = O(n)$
45//
46// $M(n) = O(1)$
47//
48// where $T$ is time, $M$ is additional memory, and $n$ is `xs.len()`.
49//
50// This is equivalent to `MPN_SQRLO_DIAGONAL` from `mpn/generic/sqrlo_basecase.c`, GMP 6.2.1.
51fn limbs_square_low_diagonal(out: &mut [Limb], xs: &[Limb]) {
52    let n = xs.len();
53    let half_n = n >> 1;
54    limbs_square_diagonal(out, &xs[..half_n]);
55    if n.odd() {
56        out[n - 1] = xs[half_n].wrapping_square();
57    }
58}
59
60// # Worst-case complexity
61// $T(n) = O(n)$
62//
63// $M(n) = O(1)$
64//
65// where $T$ is time, $M$ is additional memory, and $n$ is `xs.len()`.
66//
67// This is equivalent to `MPN_SQRLO_DIAG_ADDLSH1` from `mpn/generic/sqrlo_basecase.c`, GMP 6.2.1.
68pub_test! {limbs_square_diagonal_shl_add(out: &mut [Limb], scratch: &mut [Limb], xs: &[Limb]) {
69    let n = xs.len();
70    assert_eq!(scratch.len(), n - 1);
71    assert_eq!(out.len(), n);
72    limbs_square_low_diagonal(out, xs);
73    limbs_slice_shl_in_place(scratch, 1);
74    limbs_slice_add_same_length_in_place_left(&mut out[1..], scratch);
75}}
76
// Largest operand length, in limbs, accepted by the general case of the basecase low square.
//
// TODO tune
#[cfg(feature = "test_build")]
pub const SQRLO_DC_THRESHOLD_LIMIT: usize = 500;

#[cfg(not(feature = "test_build"))]
const SQRLO_DC_THRESHOLD_LIMIT: usize = 500;

// Length of the fixed-size scratch array used by the basecase low square: one limb fewer than the
// limit above, but never less than one.
//
// TODO tune
const SQRLO_BASECASE_ALLOC: usize = match SQRLO_DC_THRESHOLD_LIMIT {
    0 | 1 => 1,
    limit => limit - 1,
};
90
91// # Worst-case complexity
92// $T(n) = O(n^2)$
93//
94// $M(n) = O(n)$
95//
96// where $T$ is time, $M$ is additional memory, and $n$ is `xs.len()`.
97//
98// This is equivalent to `mpn_sqrlo_basecase` from `mpn/generic/sqrlo_basecase.c`, GMP 6.2.1.
99pub_test! {limbs_square_low_basecase(out: &mut [Limb], xs: &[Limb]) {
100    let n = xs.len();
101    let out = &mut out[..n];
102    assert_ne!(n, 0);
103    let xs_0 = xs[0];
104    match n {
105        1 => out[0] = xs_0.wrapping_square(),
106        2 => {
107            let p_hi;
108            (p_hi, out[0]) = DoubleLimb::from(xs_0).square().split_in_half();
109            out[1] = (xs_0.wrapping_mul(xs[1]) << 1).wrapping_add(p_hi);
110        }
111        _ => {
112            let scratch = &mut [0; SQRLO_BASECASE_ALLOC];
113            // must fit n - 1 limbs in scratch
114            assert!(n <= SQRLO_DC_THRESHOLD_LIMIT);
115            let scratch = &mut scratch[..n - 1];
116            limbs_mul_limb_to_out::<DoubleLimb, Limb>(scratch, &xs[1..], xs_0);
117            for i in 1.. {
118                let two_i = i << 1;
119                if two_i >= n - 1 {
120                    break;
121                }
122                limbs_slice_add_mul_limb_same_length_in_place_left(
123                    &mut scratch[two_i..],
124                    &xs[i + 1..n - i],
125                    xs[i],
126                );
127            }
128            limbs_square_diagonal_shl_add(out, scratch, xs);
129        }
130    }
131}}
132
133// TODO tune
134const SQRLO_BASECASE_THRESHOLD: usize = 8;
135
136// TODO tune
137/// This is equivalent to `MAYBE_range_basecase` from `mpn/generic/sqrlo.c`, GMP 6.2.1. Investigate
138/// changes from 6.1.2?
139const MAYBE_RANGE_BASECASE_MOD_SQUARE: bool = TUNE_PROGRAM_BUILD
140    || WANT_FAT_BINARY
141    || (if SQRLO_DC_THRESHOLD == 0 {
142        SQRLO_BASECASE_THRESHOLD
143    } else {
144        SQRLO_DC_THRESHOLD
145    }) < SQR_TOOM2_THRESHOLD * 36 / (36 - 11);
146
147// TODO tune
148/// This is equivalent to `MAYBE_range_toom22` from `mpn/generic/sqrlo.c`, GMP 6.2.1. Investigate
149/// changes from 6.1.2?
150const MAYBE_RANGE_TOOM22_MOD_SQUARE: bool = TUNE_PROGRAM_BUILD
151    || WANT_FAT_BINARY
152    || (if SQRLO_DC_THRESHOLD == 0 {
153        SQRLO_BASECASE_THRESHOLD
154    } else {
155        SQRLO_DC_THRESHOLD
156    }) < SQR_TOOM3_THRESHOLD * 36 / (36 - 11);
157
158// # Worst-case complexity
159// Constant time and additional memory.
160//
161// This is equivalent to `mpn_sqrlo_itch` from `mpn/generic/sqrlo.c`, GMP 6.2.1. Investigate changes
162// from 6.1.2?
163pub_const_test! {limbs_square_low_scratch_len(len: usize) -> usize {
164    len << 1
165}}
166
167// Requires a scratch space of 2 * `xs.len()` limbs at `scratch`.
168//
169// # Worst-case complexity
170// $T(n) = O(n \log n \log\log n)$
171//
172// $M(n) = O(n \log n)$
173//
174// where $T$ is time, $M$ is additional memory, and $n$ is `xs.len()`.
175//
176// This is equivalent to `mpn_dc_sqrlo` from `mpn/generic/sqrlo.c`, GMP 6.2.1. Investigate changes
177// from 6.1.2?
178pub_test! {
179#[allow(clippy::absurd_extreme_comparisons)]
180limbs_square_low_divide_and_conquer(
181    out: &mut [Limb],
182    xs: &[Limb],
183    scratch: &mut [Limb]
184) {
185    let len = xs.len();
186    let out = &mut out[..len];
187    assert!(len > 1);
188    // We need a fractional approximation of the value 0 < a <= 1/2, giving the minimum in the
189    // function k = (1 - a) ^ e / (1 - 2 * a ^ e).
190    let len_small = if MAYBE_RANGE_BASECASE_MOD_SQUARE && len < SQR_TOOM2_THRESHOLD * 36 / (36 - 11)
191    {
192        len >> 1
193    } else if MAYBE_RANGE_TOOM22_MOD_SQUARE && len < SQR_TOOM3_THRESHOLD * 36 / (36 - 11) {
194        len * 11 / 36 // n1 ~= n*(1-.694...)
195    } else if len < SQR_TOOM4_THRESHOLD * 40 / (40 - 9) {
196        len * 9 / 40 // n1 ~= n*(1-.775...)
197    } else if len < SQR_TOOM8_THRESHOLD * 10 / 9 {
198        len * 7 / 39 // n1 ~= n*(1-.821...)
199    } else {
200        len / 10 // n1 ~= n*(1-.899...) [TOOM88]
201    };
202    let len_big = len - len_small;
203    // x0 ^ 2
204    let (xs_lo, xs_hi) = xs.split_at(len_big);
205    let mut square_scratch = vec![0; limbs_square_to_out_scratch_len(xs_lo.len())];
206    limbs_square_to_out(scratch, xs_lo, &mut square_scratch);
207    let xs_lo = &xs_lo[..len_small];
208    let (out_lo, out_hi) = out.split_at_mut(len_big);
209    let (scratch_lo, scratch_hi) = scratch.split_at_mut(len);
210    out_lo.copy_from_slice(&scratch_lo[..len_big]);
211    // x1 * x0 * 2^(n2 Limb::WIDTH)
212    if len_small < MULLO_BASECASE_THRESHOLD {
213        limbs_mul_greater_to_out_basecase(scratch_hi, xs_hi, xs_lo);
214    } else if len_small < MULLO_DC_THRESHOLD {
215        limbs_mul_low_same_length_basecase(scratch_hi, xs_hi, xs_lo);
216    } else {
217        limbs_mul_low_same_length(scratch_hi, xs_hi, xs_lo);
218    }
219    limbs_shl_to_out(out_hi, &scratch_hi[..len_small], 1);
220    limbs_slice_add_same_length_in_place_left(out_hi, &scratch_lo[len_big..]);
221}}
222
// TODO tune

// must be at least SQRLO_BASECASE_THRESHOLD
const SQRLO_BASECASE_THRESHOLD_LIMIT: usize = 8;

// Operand length, in limbs, from which the low square is taken from a full square instead of the
// divide-and-conquer routine.
//
// TODO tune
const SQRLO_SQR_THRESHOLD: usize = 6440;

// Length of the fixed-size stack buffer that receives a full square of a short operand: twice the
// limit above, but never less than one.
//
// TODO tune
const SQR_BASECASE_ALLOC: usize = match SQRLO_BASECASE_THRESHOLD_LIMIT {
    0 => 1,
    limit => limit << 1,
};
237
238// Square an n-limb number and return the lowest n limbs of the result.
239//
240// # Worst-case complexity
241// $T(n) = O(n \log n \log\log n)$
242//
243// $M(n) = O(n \log n)$
244//
245// where $T$ is time, $M$ is additional memory, and $n$ is `xs.len()`.
246//
247// This is equivalent to `mpn_sqrlo` from `mpn/generic/sqrlo.c`, GMP 6.2.1. Investigate changes from
248// 6.1.2?
249pub_crate_test! {limbs_square_low(out: &mut [Limb], xs: &[Limb]) {
250    assert!(SQRLO_BASECASE_THRESHOLD_LIMIT >= SQRLO_BASECASE_THRESHOLD);
251    let len = xs.len();
252    assert_ne!(len, 0);
253    let out = &mut out[..len];
254    if len < SQRLO_BASECASE_THRESHOLD {
255        // Allocate workspace of fixed size on stack: fast!
256        let scratch = &mut [0; SQR_BASECASE_ALLOC];
257        limbs_square_to_out_basecase(scratch, xs);
258        out.copy_from_slice(&scratch[..len]);
259    } else if len < SQRLO_DC_THRESHOLD {
260        limbs_square_low_basecase(out, xs);
261    } else {
262        let mut scratch = vec![0; limbs_square_low_scratch_len(len)];
263        if len < SQRLO_SQR_THRESHOLD {
264            limbs_square_low_divide_and_conquer(out, xs, &mut scratch);
265        } else {
266            // For really large operands, use plain mpn_mul_n but throw away upper n limbs of the
267            // result.
268            let mut square_scratch = vec![0; limbs_square_to_out_scratch_len(xs.len())];
269            limbs_square_to_out(&mut scratch, xs, &mut square_scratch);
270            out.copy_from_slice(&scratch[..len]);
271        }
272    }
273}}
274
275// Interpreting a `Vec<Limb>` as the limbs (in ascending order) of a `Natural`, returns a `Vec` of
276// the limbs of the square of the `Natural` mod `2 ^ pow`. Assumes the input is already reduced mod
277// `2 ^ pow`. The input `Vec` may be mutated. The input may not be empty or have trailing zeros.
278//
279// # Worst-case complexity
280// $T(n) = O(n \log n \log\log n)$
281//
282// $M(n) = O(n \log n)$
283//
284// where $T$ is time, $M$ is additional memory, and $n$ is `pow`.
285//
286// # Panics
287// Panics if the input is empty. May panic if the input has trailing zeros.
288pub_crate_test! {limbs_mod_power_of_2_square(xs: &mut Vec<Limb>, pow: u64) -> Vec<Limb> {
289    let len = xs.len();
290    assert_ne!(len, 0);
291    let max_len = bit_to_limb_count_ceiling(pow);
292    if max_len > len << 1 {
293        return limbs_square(xs);
294    }
295    // Should really be max_len / sqrt(2); 0.75 * max_len is close enough
296    let limit = max_len.checked_mul(3).unwrap() >> 2;
297    let mut square = if len >= limit {
298        if len != max_len {
299            xs.resize(max_len, 0);
300        }
301        let mut square_limbs = vec![0; max_len];
302        limbs_square_low(&mut square_limbs, xs);
303        square_limbs
304    } else {
305        limbs_square(xs)
306    };
307    limbs_vec_mod_power_of_2_in_place(&mut square, pow);
308    square
309}}
310
311// Interpreting a slice of `Limb` as the limbs (in ascending order) of a `Natural`, returns a `Vec`
312// of the limbs of the square of the `Natural` mod `2 ^ pow`. Assumes the input is already reduced
313// mod `2 ^ pow`. The input may not be empty or have trailing zeros.
314//
315// # Worst-case complexity
316// $T(n) = O(n \log n \log\log n)$
317//
318// $M(n) = O(n \log n)$
319//
320// where $T$ is time, $M$ is additional memory, and $n$ is `pow`.
321//
322// # Panics
323// Panics if the input is empty. May panic if the input has trailing zeros.
324pub_crate_test! {limbs_mod_power_of_2_square_ref(xs: &[Limb], pow: u64) -> Vec<Limb> {
325    let len = xs.len();
326    assert_ne!(len, 0);
327    let max_len = bit_to_limb_count_ceiling(pow);
328    if max_len > len << 1 {
329        return limbs_square(xs);
330    }
331    // Should really be max_len / sqrt(2); 0.75 * max_len is close enough
332    let limit = max_len.checked_mul(3).unwrap() >> 2;
333    let mut square = if len >= limit {
334        let mut xs_adjusted_vec;
335        let xs_adjusted = if len == max_len {
336            xs
337        } else {
338            xs_adjusted_vec = vec![0; max_len];
339            xs_adjusted_vec[..len].copy_from_slice(xs);
340            &xs_adjusted_vec
341        };
342        let mut square = vec![0; max_len];
343        limbs_square_low(&mut square, xs_adjusted);
344        square
345    } else {
346        limbs_square(xs)
347    };
348    limbs_vec_mod_power_of_2_in_place(&mut square, pow);
349    square
350}}
351
352impl ModPowerOf2Square for Natural {
353    type Output = Self;
354
355    /// Squares a [`Natural`] modulo $2^k$. The input must be already reduced modulo $2^k$. The
356    /// [`Natural`] is taken by value.
357    ///
358    /// $f(x, k) = y$, where $x, y < 2^k$ and $x^2 \equiv y \mod 2^k$.
359    ///
360    /// # Worst-case complexity
361    /// $T(n) = O(n \log n \log\log n)$
362    ///
363    /// $M(n) = O(n \log n)$
364    ///
365    /// where $T$ is time, $M$ is additional memory, and $n$ is `pow`.
366    ///
367    /// # Panics
368    /// Panics if `self` is greater than or equal to $2^k$.
369    ///
370    /// # Examples
371    /// ```
372    /// use core::str::FromStr;
373    /// use malachite_base::num::arithmetic::traits::ModPowerOf2Square;
374    /// use malachite_base::num::basic::traits::Zero;
375    /// use malachite_nz::natural::Natural;
376    ///
377    /// assert_eq!(Natural::ZERO.mod_power_of_2_square(2), 0);
378    /// assert_eq!(Natural::from(5u32).mod_power_of_2_square(3), 1);
379    /// assert_eq!(
380    ///     Natural::from_str("12345678987654321")
381    ///         .unwrap()
382    ///         .mod_power_of_2_square(64)
383    ///         .to_string(),
384    ///     "16556040056090124897"
385    /// );
386    /// ```
387    #[inline]
388    fn mod_power_of_2_square(mut self, pow: u64) -> Self {
389        self.mod_power_of_2_square_assign(pow);
390        self
391    }
392}
393
394impl ModPowerOf2Square for &Natural {
395    type Output = Natural;
396
397    /// Squares a [`Natural`] modulo $2^k$. The input must be already reduced modulo $2^k$. The
398    /// [`Natural`] is taken by reference.
399    ///
400    /// $f(x, k) = y$, where $x, y < 2^k$ and $x^2 \equiv y \mod 2^k$.
401    ///
402    /// # Worst-case complexity
403    /// $T(n) = O(n \log n \log\log n)$
404    ///
405    /// $M(n) = O(n \log n)$
406    ///
407    /// where $T$ is time, $M$ is additional memory, and $n$ is `pow`.
408    ///
409    /// # Panics
410    /// Panics if `self` is greater than or equal to $2^k$.
411    ///
412    /// # Examples
413    /// ```
414    /// use core::str::FromStr;
415    /// use malachite_base::num::arithmetic::traits::ModPowerOf2Square;
416    /// use malachite_base::num::basic::traits::Zero;
417    /// use malachite_nz::natural::Natural;
418    ///
419    /// assert_eq!((&Natural::ZERO).mod_power_of_2_square(2), 0);
420    /// assert_eq!((&Natural::from(5u32)).mod_power_of_2_square(3), 1);
421    /// assert_eq!(
422    ///     (&Natural::from_str("12345678987654321").unwrap())
423    ///         .mod_power_of_2_square(64)
424    ///         .to_string(),
425    ///     "16556040056090124897"
426    /// );
427    /// ```
428    #[inline]
429    fn mod_power_of_2_square(self, pow: u64) -> Natural {
430        assert!(
431            self.significant_bits() <= pow,
432            "self must be reduced mod 2^pow, but {self} >= 2^{pow}"
433        );
434        match self {
435            &Natural::ZERO => Natural::ZERO,
436            Natural(Small(x)) if pow <= Limb::WIDTH => Natural(Small(x.mod_power_of_2_square(pow))),
437            Natural(Small(x)) => {
438                let x_double = DoubleLimb::from(*x);
439                Natural::from(if pow <= Limb::WIDTH << 1 {
440                    x_double.mod_power_of_2_square(pow)
441                } else {
442                    x_double.square()
443                })
444            }
445            Natural(Large(xs)) => {
446                Natural::from_owned_limbs_asc(limbs_mod_power_of_2_square_ref(xs, pow))
447            }
448        }
449    }
450}
451
452impl ModPowerOf2SquareAssign for Natural {
453    /// Squares a [`Natural`] modulo $2^k$, in place. The input must be already reduced modulo
454    /// $2^k$.
455    ///
456    /// $x \gets y$, where $x, y < 2^k$ and $x^2 \equiv y \mod 2^k$.
457    ///
458    /// # Worst-case complexity
459    /// $T(n) = O(n \log n \log\log n)$
460    ///
461    /// $M(n) = O(n \log n)$
462    ///
463    /// where $T$ is time, $M$ is additional memory, and $n$ is `pow`.
464    ///
465    /// # Panics
466    /// Panics if `self` is greater than or equal to $2^k$.
467    ///
468    /// # Examples
469    /// ```
470    /// use core::str::FromStr;
471    /// use malachite_base::num::arithmetic::traits::ModPowerOf2SquareAssign;
472    /// use malachite_base::num::basic::traits::Zero;
473    /// use malachite_nz::natural::Natural;
474    ///
475    /// let mut n = Natural::ZERO;
476    /// n.mod_power_of_2_square_assign(2);
477    /// assert_eq!(n, 0);
478    ///
479    /// let mut n = Natural::from(5u32);
480    /// n.mod_power_of_2_square_assign(3);
481    /// assert_eq!(n, 1);
482    ///
483    /// let mut n = Natural::from_str("12345678987654321").unwrap();
484    /// n.mod_power_of_2_square_assign(64);
485    /// assert_eq!(n.to_string(), "16556040056090124897");
486    /// ```
487    #[inline]
488    fn mod_power_of_2_square_assign(&mut self, pow: u64) {
489        assert!(
490            self.significant_bits() <= pow,
491            "self must be reduced mod 2^pow, but {self} >= 2^{pow}"
492        );
493        match self {
494            &mut Self::ZERO => {}
495            Self(Small(x)) if pow <= Limb::WIDTH => x.mod_power_of_2_square_assign(pow),
496            Self(Small(x)) => {
497                let x_double = DoubleLimb::from(*x);
498                *self = Self::from(if pow <= Limb::WIDTH << 1 {
499                    x_double.mod_power_of_2_square(pow)
500                } else {
501                    x_double.square()
502                });
503            }
504            Self(Large(xs)) => {
505                *xs = limbs_mod_power_of_2_square(xs, pow);
506                self.trim();
507            }
508        }
509    }
510}