Skip to main content

fixed_bigint/fixeduint/
extended_precision_impl.rs

// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
//! Extended precision arithmetic operations for FixedUInt.
//!
//! These operations expose carry/borrow inputs and outputs, and widening
//! multiplication, useful for implementing arbitrary-precision arithmetic.
19
20use super::{FixedUInt, MachineWord};
21use crate::const_numtraits::{
22    ConstBorrowingSub, ConstCarryingAdd, ConstCarryingMul, ConstWideningMul,
23};
24use crate::machineword::ConstMachineWord;
25
26c0nst::c0nst! {
27    /// Add with carry input, returns sum and carry output.
28    /// Uses ConstCarryingAdd on limb types for consistency.
29    c0nst fn add_with_carry<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd, const N: usize>(
30        a: &[T; N],
31        b: &[T; N],
32        carry_in: bool,
33    ) -> ([T; N], bool) {
34        let mut result = [T::zero(); N];
35        let mut carry = carry_in;
36        let mut i = 0usize;
37        while i < N {
38            let (sum, c) = ConstCarryingAdd::carrying_add(a[i], b[i], carry);
39            result[i] = sum;
40            carry = c;
41            i += 1;
42        }
43        (result, carry)
44    }
45
46    /// Subtract with borrow input, returns difference and borrow output.
47    /// Uses ConstBorrowingSub on limb types for consistency.
48    c0nst fn sub_with_borrow<T: [c0nst] ConstMachineWord + [c0nst] ConstBorrowingSub, const N: usize>(
49        a: &[T; N],
50        b: &[T; N],
51        borrow_in: bool,
52    ) -> ([T; N], bool) {
53        let mut result = [T::zero(); N];
54        let mut borrow = borrow_in;
55        let mut i = 0usize;
56        while i < N {
57            let (diff, b) = ConstBorrowingSub::borrowing_sub(a[i], b[i], borrow);
58            result[i] = diff;
59            borrow = b;
60            i += 1;
61        }
62        (result, borrow)
63    }
64
65    impl<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + MachineWord, const N: usize> c0nst ConstCarryingAdd for FixedUInt<T, N> {
66        fn carrying_add(self, rhs: Self, carry: bool) -> (Self, bool) {
67            let (array, carry_out) = add_with_carry(&self.array, &rhs.array, carry);
68            (Self { array }, carry_out)
69        }
70    }
71
72    impl<T: [c0nst] ConstMachineWord + [c0nst] ConstBorrowingSub + MachineWord, const N: usize> c0nst ConstBorrowingSub for FixedUInt<T, N> {
73        fn borrowing_sub(self, rhs: Self, borrow: bool) -> (Self, bool) {
74            let (array, borrow_out) = sub_with_borrow(&self.array, &rhs.array, borrow);
75            (Self { array }, borrow_out)
76        }
77    }
78
79    /// Helper: get value at position in 2N-word result (split into low/high).
80    c0nst fn get_at<T: [c0nst] ConstMachineWord, const N: usize>(
81        lo: &[T; N], hi: &[T; N], pos: usize
82    ) -> T {
83        if pos < N { lo[pos] } else if pos < 2 * N { hi[pos - N] } else { T::zero() }
84    }
85
86    /// Helper: set value at position in 2N-word result (split into low/high).
87    c0nst fn set_at<T: [c0nst] ConstMachineWord, const N: usize>(
88        lo: &mut [T; N], hi: &mut [T; N], pos: usize, val: T
89    ) {
90        if pos < N { lo[pos] = val; } else if pos < 2 * N { hi[pos - N] = val; }
91    }
92
93    impl<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize> c0nst ConstWideningMul for FixedUInt<T, N> {
94        fn widening_mul(self, rhs: Self) -> (Self, Self) {
95            // Schoolbook multiplication: for each (i,j), add the 2-word product a[i]*b[j]
96            // to result[i+j : i+j+2], propagating any carry upward.
97            let mut result_low = [T::zero(); N];
98            let mut result_high = [T::zero(); N];
99
100            let mut i = 0usize;
101            while i < N {
102                let mut j = 0usize;
103                while j < N {
104                    let pos = i + j;
105                    let (mul_lo, mul_hi) = ConstWideningMul::widening_mul(self.array[i], rhs.array[j]);
106
107                    // Add the 2-word product (mul_hi, mul_lo) at position pos.
108                    // Step 1: add mul_lo at pos
109                    let cur0 = get_at(&result_low, &result_high, pos);
110                    let (sum0, c0) = ConstCarryingAdd::carrying_add(cur0, mul_lo, false);
111                    set_at(&mut result_low, &mut result_high, pos, sum0);
112
113                    // Step 2: add mul_hi + carry at pos+1
114                    let cur1 = get_at(&result_low, &result_high, pos + 1);
115                    let (sum1, c1) = ConstCarryingAdd::carrying_add(cur1, mul_hi, c0);
116                    set_at(&mut result_low, &mut result_high, pos + 1, sum1);
117
118                    // Step 3: propagate any remaining carry
119                    let mut carry = c1;
120                    let mut p = pos + 2;
121                    while carry && p < 2 * N {
122                        let cur = get_at(&result_low, &result_high, p);
123                        let (sum, c) = ConstCarryingAdd::carrying_add(cur, T::zero(), true);
124                        set_at(&mut result_low, &mut result_high, p, sum);
125                        carry = c;
126                        p += 1;
127                    }
128
129                    j += 1;
130                }
131                i += 1;
132            }
133
134            (Self { array: result_low }, Self { array: result_high })
135        }
136    }
137
138    impl<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize> c0nst ConstCarryingMul for FixedUInt<T, N> {
139        fn carrying_mul(self, rhs: Self, carry: Self) -> (Self, Self) {
140            // widening_mul + add carry to result
141            let (lo, hi) = ConstWideningMul::widening_mul(self, rhs);
142
143            // Add carry to lo, propagate overflow to hi using carry bit directly
144            let (lo2, c) = add_with_carry(&lo.array, &carry.array, false);
145
146            // Pass the carry bit directly instead of constructing a temporary array
147            let zeros = [T::zero(); N];
148            let (hi2, _) = add_with_carry(&hi.array, &zeros, c);
149
150            (Self { array: lo2 }, Self { array: hi2 })
151        }
152
153        fn carrying_mul_add(self, rhs: Self, addend: Self, carry: Self) -> (Self, Self) {
154            // widening_mul + add addend + add carry
155            let (lo, hi) = ConstWideningMul::widening_mul(self, rhs);
156
157            // Add carry to lo
158            let (lo2, c1) = add_with_carry(&lo.array, &carry.array, false);
159
160            // Add addend to lo2
161            let (lo3, c2) = add_with_carry(&lo2, &addend.array, false);
162
163            // Add carry bits to hi separately (both c1 and c2 can be true = need to add 2)
164            let zeros = [T::zero(); N];
165            let (hi2, _) = add_with_carry(&hi.array, &zeros, c1);
166            let (hi3, _) = add_with_carry(&hi2, &zeros, c2);
167
168            (Self { array: lo3 }, Self { array: hi3 })
169        }
170    }
171}
172
173#[cfg(test)]
174mod tests {
175    use super::*;
176
177    type U16 = FixedUInt<u8, 2>;
178    type U32 = FixedUInt<u8, 4>;
179
180    c0nst::c0nst! {
181        pub c0nst fn const_carrying_add<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize>(
182            a: FixedUInt<T, N>,
183            b: FixedUInt<T, N>,
184            carry: bool,
185        ) -> (FixedUInt<T, N>, bool) {
186            ConstCarryingAdd::carrying_add(a, b, carry)
187        }
188
189        pub c0nst fn const_borrowing_sub<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize>(
190            a: FixedUInt<T, N>,
191            b: FixedUInt<T, N>,
192            borrow: bool,
193        ) -> (FixedUInt<T, N>, bool) {
194            ConstBorrowingSub::borrowing_sub(a, b, borrow)
195        }
196
197        pub c0nst fn const_widening_mul<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize>(
198            a: FixedUInt<T, N>,
199            b: FixedUInt<T, N>,
200        ) -> (FixedUInt<T, N>, FixedUInt<T, N>) {
201            ConstWideningMul::widening_mul(a, b)
202        }
203
204        pub c0nst fn const_carrying_mul<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize>(
205            a: FixedUInt<T, N>,
206            b: FixedUInt<T, N>,
207            carry: FixedUInt<T, N>,
208        ) -> (FixedUInt<T, N>, FixedUInt<T, N>) {
209            ConstCarryingMul::carrying_mul(a, b, carry)
210        }
211
212        pub c0nst fn const_carrying_mul_add<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize>(
213            a: FixedUInt<T, N>,
214            b: FixedUInt<T, N>,
215            addend: FixedUInt<T, N>,
216            carry: FixedUInt<T, N>,
217        ) -> (FixedUInt<T, N>, FixedUInt<T, N>) {
218            ConstCarryingMul::carrying_mul_add(a, b, addend, carry)
219        }
220    }
221
222    #[test]
223    fn test_carrying_add_no_carry() {
224        let a = U16::from(100u8);
225        let b = U16::from(50u8);
226
227        // Without carry in
228        let (sum, carry_out) = const_carrying_add(a, b, false);
229        assert_eq!(sum, U16::from(150u8));
230        assert!(!carry_out);
231
232        // With carry in
233        let (sum, carry_out) = const_carrying_add(a, b, true);
234        assert_eq!(sum, U16::from(151u8));
235        assert!(!carry_out);
236    }
237
238    #[test]
239    fn test_carrying_add_with_overflow() {
240        let max = U16::from(0xFFFFu16);
241        let one = U16::from(1u8);
242
243        // max + 0 with carry = max + 1 = 0 with carry out
244        let (sum, carry_out) = const_carrying_add(max, U16::from(0u8), true);
245        assert_eq!(sum, U16::from(0u8));
246        assert!(carry_out);
247
248        // max + 1 = 0 with carry out
249        let (sum, carry_out) = const_carrying_add(max, one, false);
250        assert_eq!(sum, U16::from(0u8));
251        assert!(carry_out);
252
253        // max + max = 0xFFFE with carry out
254        let (sum, carry_out) = const_carrying_add(max, max, false);
255        assert_eq!(sum, U16::from(0xFFFEu16));
256        assert!(carry_out);
257    }
258
259    #[test]
260    fn test_borrowing_sub_no_borrow() {
261        let a = U16::from(150u8);
262        let b = U16::from(50u8);
263
264        // Without borrow in
265        let (diff, borrow_out) = const_borrowing_sub(a, b, false);
266        assert_eq!(diff, U16::from(100u8));
267        assert!(!borrow_out);
268
269        // With borrow in
270        let (diff, borrow_out) = const_borrowing_sub(a, b, true);
271        assert_eq!(diff, U16::from(99u8));
272        assert!(!borrow_out);
273    }
274
275    #[test]
276    fn test_borrowing_sub_with_underflow() {
277        let zero = U16::from(0u8);
278        let one = U16::from(1u8);
279
280        // 0 - 1 = 0xFFFF with borrow out
281        let (diff, borrow_out) = const_borrowing_sub(zero, one, false);
282        assert_eq!(diff, U16::from(0xFFFFu16));
283        assert!(borrow_out);
284
285        // 0 - 0 with borrow = 0xFFFF with borrow out
286        let (diff, borrow_out) = const_borrowing_sub(zero, zero, true);
287        assert_eq!(diff, U16::from(0xFFFFu16));
288        assert!(borrow_out);
289
290        // 1 - 1 with borrow = 0xFFFF with borrow out
291        let (diff, borrow_out) = const_borrowing_sub(one, one, true);
292        assert_eq!(diff, U16::from(0xFFFFu16));
293        assert!(borrow_out);
294    }
295
296    #[test]
297    fn test_widening_mul() {
298        // 100 * 100 = 10000 (fits in 16 bits, high = 0)
299        let a = U16::from(100u8);
300        let (lo, hi) = const_widening_mul(a, a);
301        assert_eq!(lo, U16::from(10000u16));
302        assert_eq!(hi, U16::from(0u8));
303
304        // 256 * 256 = 65536 = 0x10000 (low = 0, high = 1)
305        let b = U16::from(256u16);
306        let (lo, hi) = const_widening_mul(b, b);
307        assert_eq!(lo, U16::from(0u8));
308        assert_eq!(hi, U16::from(1u8));
309
310        // 0xFFFF * 0xFFFF = 0xFFFE0001
311        let max = U16::from(0xFFFFu16);
312        let (lo, hi) = const_widening_mul(max, max);
313        assert_eq!(lo, U16::from(0x0001u16)); // low 16 bits of 0xFFFE0001
314        assert_eq!(hi, U16::from(0xFFFEu16)); // high 16 bits of 0xFFFE0001
315    }
316
317    #[test]
318    fn test_widening_mul_larger() {
319        // Test with 32-bit type (U32 = FixedUInt<u8, 4>)
320        let a = U32::from(0x10000u32); // 2^16
321        let b = U32::from(0x10000u32); // 2^16
322        let (lo, hi) = const_widening_mul(a, b);
323        // 2^16 * 2^16 = 2^32 = 0x100000000
324        // low 32 bits = 0, high 32 bits = 1
325        assert_eq!(lo, U32::from(0u8));
326        assert_eq!(hi, U32::from(1u8));
327    }
328
329    #[test]
330    fn test_carrying_mul() {
331        let a = U16::from(100u8);
332        let b = U16::from(100u8);
333        let carry = U16::from(5u8);
334
335        // 100 * 100 + 5 = 10005
336        let (lo, hi) = const_carrying_mul(a, b, carry);
337        assert_eq!(lo, U16::from(10005u16));
338        assert_eq!(hi, U16::from(0u8));
339
340        // With larger carry that causes overflow in low part
341        let max = U16::from(0xFFFFu16);
342        let one = U16::from(1u8);
343        // 1 * 1 + 0xFFFF = 0x10000 = (0, 1)
344        let (lo, hi) = const_carrying_mul(one, one, max);
345        assert_eq!(lo, U16::from(0u8));
346        assert_eq!(hi, U16::from(1u8));
347    }
348
349    #[test]
350    fn test_carrying_mul_add() {
351        let a = U16::from(100u8);
352        let b = U16::from(100u8);
353        let addend = U16::from(10u8);
354        let carry = U16::from(5u8);
355
356        // 100 * 100 + 10 + 5 = 10015
357        let (lo, hi) = const_carrying_mul_add(a, b, addend, carry);
358        assert_eq!(lo, U16::from(10015u16));
359        assert_eq!(hi, U16::from(0u8));
360    }
361
362    #[test]
363    fn test_carrying_mul_add_double_overflow() {
364        // Test case where both addend and carry cause overflow
365        let max = U16::from(0xFFFFu16);
366        let one = U16::from(1u8);
367
368        // 1 * 1 + 0xFFFF + 0xFFFF = 1 + 0x1FFFE = 0x1FFFF = (0xFFFF, 1)
369        let (lo, hi) = const_carrying_mul_add(one, one, max, max);
370        assert_eq!(lo, U16::from(0xFFFFu16));
371        assert_eq!(hi, U16::from(1u8));
372    }
373
374    #[test]
375    fn test_const_context() {
376        #[cfg(feature = "nightly")]
377        {
378            const A: U16 = FixedUInt { array: [100, 0] };
379            const B: U16 = FixedUInt { array: [50, 0] };
380
381            // Test carrying_add in const context
382            const ADD_RESULT: (U16, bool) = const_carrying_add(A, B, false);
383            assert_eq!(ADD_RESULT.0, U16::from(150u8));
384            assert!(!ADD_RESULT.1);
385
386            const ADD_WITH_CARRY: (U16, bool) = const_carrying_add(A, B, true);
387            assert_eq!(ADD_WITH_CARRY.0, U16::from(151u8));
388
389            // Test borrowing_sub in const context
390            const SUB_RESULT: (U16, bool) = const_borrowing_sub(A, B, false);
391            assert_eq!(SUB_RESULT.0, U16::from(50u8));
392            assert!(!SUB_RESULT.1);
393
394            // Test widening_mul in const context
395            const C: U16 = FixedUInt { array: [0, 1] }; // 256
396            const MUL_RESULT: (U16, U16) = const_widening_mul(C, C);
397            assert_eq!(MUL_RESULT.0, U16::from(0u8)); // 256*256 = 65536, low = 0
398            assert_eq!(MUL_RESULT.1, U16::from(1u8)); // high = 1
399        }
400    }
401
402    /// Polymorphic test: verify widening_mul produces identical results across
403    /// different word layouts for the same values.
404    #[test]
405    fn test_widening_mul_polymorphic() {
406        // Generic test function following crate pattern
407        fn test_widening<T>(a: T, b: T, expected_lo: T, expected_hi: T)
408        where
409            T: ConstWideningMul
410                + ConstCarryingAdd
411                + ConstBorrowingSub
412                + Eq
413                + core::fmt::Debug
414                + Copy,
415        {
416            let (lo, hi) = ConstWideningMul::widening_mul(a, b);
417            assert_eq!(lo, expected_lo, "lo mismatch");
418            assert_eq!(hi, expected_hi, "hi mismatch");
419        }
420
421        // Test 1: 256 * 256 = 65536
422        // As u16 (FixedUInt<u8, 2>): 16-bit overflow, 256*256 = 0x10000 = (lo=0, hi=1)
423        test_widening(
424            U16::from(256u16),
425            U16::from(256u16),
426            U16::from(0u16),
427            U16::from(1u16),
428        );
429
430        // As u32 (FixedUInt<u8, 4>): fits in 32 bits, so lo=65536, hi=0
431        test_widening(
432            U32::from(256u32),
433            U32::from(256u32),
434            U32::from(65536u32),
435            U32::from(0u32),
436        );
437
438        // Test 2: 0xFFFF * 0xFFFF = 0xFFFE0001
439        test_widening(
440            U16::from(0xFFFFu16),
441            U16::from(0xFFFFu16),
442            U16::from(0x0001u16),
443            U16::from(0xFFFEu16),
444        );
445
446        // Test 3: 0xFFFFFFFF * 2 = 0x1_FFFFFFFE (tests carry across word boundary)
447        test_widening(
448            U32::from(0xFFFFFFFFu32),
449            U32::from(2u32),
450            U32::from(0xFFFFFFFEu32),
451            U32::from(1u32),
452        );
453    }
454
455    /// Polymorphic test for carrying_mul_add with edge cases.
456    #[test]
457    fn test_carrying_mul_add_polymorphic() {
458        fn test_cma<T>(a: T, b: T, addend: T, carry: T, expected_lo: T, expected_hi: T)
459        where
460            T: ConstCarryingMul + Eq + core::fmt::Debug + Copy,
461        {
462            let (lo, hi) = ConstCarryingMul::carrying_mul_add(a, b, addend, carry);
463            assert_eq!(lo, expected_lo, "lo mismatch");
464            assert_eq!(hi, expected_hi, "hi mismatch");
465        }
466
467        // Test: max * max + max + max = 0xFFFF * 0xFFFF + 0xFFFF + 0xFFFF
468        //     = 0xFFFE0001 + 0x1FFFE = 0xFFFFFFFF
469        // lo = 0xFFFF, hi = 0xFFFF
470        let max16 = U16::from(0xFFFFu16);
471        test_cma(
472            max16,
473            max16,
474            max16,
475            max16,
476            U16::from(0xFFFFu16),
477            U16::from(0xFFFFu16),
478        );
479
480        // Same test with different layout (U32 = FixedUInt<u8, 4>)
481        let max32 = U32::from(0xFFFFFFFFu32);
482        let zero32 = U32::from(0u32);
483        // max * 1 + 0 + max = max + max = 2*max = 0x1_FFFFFFFE
484        test_cma(
485            max32,
486            U32::from(1u32),
487            zero32,
488            max32,
489            U32::from(0xFFFFFFFEu32),
490            U32::from(1u32),
491        );
492    }
493}