1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#![allow(clippy::module_name_repetitions)]

/// ⚠️ Computes `result += lhs * rhs` and checks for overflow.
///
/// **Warning.** This function is not part of the stable API.
///
/// All three slices hold little-endian `u64` limbs and may have any length,
/// including zero. The sum is written back into `result`, and only the limbs
/// that fit in `result` are kept. The return value is `true` when at least one
/// non-zero limb had to be dropped, i.e. the exact sum does not fit.
///
/// # Algorithm
///
/// Schoolbook multiplication, one row per limb of `lhs`. This wrapper carries
/// no inlining hint; see [`mul_inline`] for the variant that always inlines.
///
/// # Examples
///
/// ```
/// # use ruint::algorithms::mul;
/// let mut result = [0];
/// let overflow = mul(&[3], &[4], &mut result);
/// assert_eq!(overflow, false);
/// assert_eq!(result, [12]);
///
/// // u64::MAX * 2 = 2^65 - 2 needs two limbs; only the low one is kept.
/// let mut result = [0];
/// let overflow = mul(&[u64::MAX], &[2], &mut result);
/// assert_eq!(overflow, true);
/// assert_eq!(result, [u64::MAX - 1]);
/// ```
pub fn mul(lhs: &[u64], rhs: &[u64], result: &mut [u64]) -> bool {
    // All the work lives in the always-inlined implementation; leaving this
    // wrapper unannotated lets the compiler decide whether to inline it.
    mul_inline(lhs, rhs, result)
}

/// ⚠️ Same as [`mul`], but will always inline.
///
/// **Warning.** This function is not part of the stable API.
///
/// Computes `result += lhs * rhs` in place. All three slices are little-endian
/// `u64` limbs of arbitrary length. Limbs of the exact sum that fall past the
/// end of `result` are discarded, so `result` ends up holding the sum modulo
/// `2^(64 * result.len())`. Returns `true` if any discarded limb was non-zero,
/// i.e. the exact value of `result + lhs * rhs` does not fit in `result`.
#[allow(clippy::inline_always)] // We want to decide at the call site.
#[inline(always)]
#[allow(clippy::cast_possible_truncation)] // Intentional truncation.
pub fn mul_inline(lhs: &[u64], rhs: &[u64], result: &mut [u64]) -> bool {
    // Bitwise OR of every limb that lands past the end of `result`;
    // non-zero means the exact sum did not fit.
    let mut overflow = 0;
    for (i, lhs) in lhs.iter().copied().enumerate() {
        // Row `i` of the schoolbook product is added starting at limb `i`.
        let mut result = result.iter_mut().skip(i);
        let mut rhs = rhs.iter().copied();
        // Never overflows: carry + limb + lhs * rhs
        // <= 2 * (2^64 - 1) + (2^64 - 1)^2 = 2^128 - 1.
        let mut carry = 0_u128;
        loop {
            match (result.next(), rhs.next()) {
                // Partial product.
                (Some(result), Some(rhs)) => {
                    carry += u128::from(*result) + u128::from(lhs) * u128::from(rhs);
                    *result = carry as u64;
                    carry >>= 64;
                }
                // Carry propagation.
                (Some(result), None) => {
                    if carry == 0 {
                        // Nothing left to add, so the remaining limbs of `result`
                        // would be rewritten unchanged. Stop here instead of
                        // walking to the end of the slice for every row.
                        break;
                    }
                    carry += u128::from(*result);
                    *result = carry as u64;
                    carry >>= 64;
                }
                // Excess product: these limbs lie past the end of `result`.
                (None, Some(rhs)) => {
                    carry += u128::from(lhs) * u128::from(rhs);
                    overflow |= carry as u64;
                    carry >>= 64;
                }
                // Fin.
                (None, None) => {
                    break;
                }
            }
        }
        // Any carry still pending would land past the end of `result`.
        // After the early break above it is 0, and after the last shift it is
        // < 2^64, so this cast is lossless.
        overflow |= carry as u64;
    }
    overflow != 0
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplies `lhs * rhs` into a zero-initialized buffer of `expected.len()`
    /// limbs and checks both the (possibly wrapped) result and the overflow flag.
    fn test_vals(lhs: &[u64], rhs: &[u64], expected: &[u64], expected_overflow: bool) {
        let mut result = vec![0; expected.len()];
        let overflow = mul(lhs, rhs, &mut result);
        assert_eq!(overflow, expected_overflow);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_empty() {
        test_vals(&[], &[], &[], false);
        test_vals(&[], &[1], &[], false);
        test_vals(&[1], &[], &[], false);
        test_vals(&[1], &[1], &[], true);
        test_vals(&[], &[], &[0], false);
        test_vals(&[], &[1], &[0], false);
        test_vals(&[1], &[], &[0], false);
        test_vals(&[1], &[1], &[1], false);
    }

    #[test]
    fn test_single_limb_overflow() {
        // u64::MAX * 2 = 2^65 - 2: with one limb the high bit is lost.
        test_vals(&[u64::MAX], &[2], &[u64::MAX - 1], true);
        test_vals(&[u64::MAX], &[2], &[u64::MAX - 1, 1], false);
    }

    #[test]
    fn test_multi_limb() {
        // (2^128 - 1)^2 = 2^256 - 2^129 + 1.
        let max = [u64::MAX, u64::MAX];
        test_vals(&max, &max, &[1, 0, u64::MAX - 1, u64::MAX], false);
        test_vals(&max, &max, &[1, 0, u64::MAX - 1, u64::MAX, 0], false);
        test_vals(&max, &max, &[1, 0], true);
    }

    #[test]
    fn test_excess_product() {
        // lhs limbs whose row starts past the end of `result` only matter if non-zero.
        test_vals(&[2, 0], &[3], &[6], false);
        test_vals(&[2, 1], &[3], &[6], true);
    }

    #[test]
    fn test_accumulate() {
        // The product is added to whatever `result` already holds.
        let mut result = [5, 0];
        assert!(!mul(&[3], &[4], &mut result));
        assert_eq!(result, [17, 0]);

        // The carry ripples through existing limbs...
        let mut result = [u64::MAX, u64::MAX, 0];
        assert!(!mul(&[1], &[1], &mut result));
        assert_eq!(result, [0, 0, 1]);

        // ...and out of the top limb, which is reported as overflow.
        let mut result = [u64::MAX, u64::MAX];
        assert!(mul(&[1], &[1], &mut result));
        assert_eq!(result, [0, 0]);

        // Limbs above the product are left untouched.
        let mut result = [0, 0, 0, 7, 9];
        assert!(!mul(&[1, 2], &[3], &mut result));
        assert_eq!(result, [3, 6, 0, 7, 9]);
    }
}