Skip to main content

fixed_bigint/fixeduint/
mul_acc_ops_impl.rs

1//! MulAccOps implementation for FixedUInt.
2//!
3//! Provides row-level fused multiply-accumulate operations used by
4//! Montgomery multiplication variants. All limb access is internal to
5//! this module; the public trait surface never exposes raw arrays.
6
7use super::{FixedUInt, MachineWord};
8use crate::const_numtraits::{ConstCarryingAdd, ConstZero};
9use crate::mul_acc_ops::MulAccOps;
10use crate::patch_num_traits::CarryingMul;
11
12impl<T, const N: usize> MulAccOps for FixedUInt<T, N>
13where
14    T: MachineWord + CarryingMul<Output = T> + ConstCarryingAdd,
15{
16    type Word = T;
17
18    fn word_count() -> usize {
19        N
20    }
21
22    fn get_word(&self, i: usize) -> Option<T> {
23        self.array.get(i).copied()
24    }
25
26    fn mul_acc_row(scalar: T, multiplicand: &Self, acc: &mut Self, carry_in: T) -> T {
27        let mut carry = carry_in;
28        let mut j = 0;
29        while j < N {
30            let (lo, hi) =
31                CarryingMul::carrying_mul_add(scalar, multiplicand.array[j], acc.array[j], carry);
32            acc.array[j] = lo;
33            carry = hi;
34            j += 1;
35        }
36        carry
37    }
38
39    fn mul_acc_shift_row(scalar: T, multiplicand: &Self, acc: &mut Self, acc_hi: T) -> T {
40        debug_assert!(N > 0, "MulAccOps requires at least one word");
41        // First word: discarded (zero by construction)
42        let (_, mut carry) = CarryingMul::carrying_mul_add(
43            scalar,
44            multiplicand.array[0],
45            acc.array[0],
46            <T as ConstZero>::zero(),
47        );
48
49        // Remaining words: shift down by one position
50        let mut j = 1;
51        while j < N {
52            let (lo, hi) =
53                CarryingMul::carrying_mul_add(scalar, multiplicand.array[j], acc.array[j], carry);
54            acc.array[j - 1] = lo;
55            carry = hi;
56            j += 1;
57        }
58
59        // Fold acc_hi + carry into acc[N-1]
60        let (sum, overflow) = <T as ConstCarryingAdd>::carrying_add(acc_hi, carry, false);
61        acc.array[N - 1] = sum;
62
63        // Branchless: convert overflow bool to word via carrying_add(0, 0, overflow)
64        let (overflow_word, _) = <T as ConstCarryingAdd>::carrying_add(
65            <T as ConstZero>::zero(),
66            <T as ConstZero>::zero(),
67            overflow,
68        );
69        overflow_word
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    type U8x1 = FixedUInt<u8, 1>;
    type U16 = FixedUInt<u8, 2>;
    type U32 = FixedUInt<u8, 4>;
    type U64x4 = FixedUInt<u64, 4>;

    /// 16-bit operands biased toward limb boundaries (0x00FF, 0x0100, 0xFFFF, ...),
    /// where carry-propagation mistakes tend to show up.
    const SAMPLES_U16: [u16; 12] = [
        0x0000, 0x0001, 0x0002, 0x007F, 0x0080, 0x00FE, 0x00FF, 0x0100, 0x1234, 0x8000, 0xFFFE,
        0xFFFF,
    ];

    /// Single-limb scalars: zero, identity, a small multiplier, the top bit,
    /// and the maximum limb value.
    const SAMPLES_U8: [u8; 5] = [0x00, 0x01, 0x02, 0x80, 0xFF];

    #[test]
    fn test_word_access() {
        let val = U16::from(0x1234u16);
        assert_eq!(U16::word_count(), 2);
        assert_eq!(val.get_word(0), Some(0x34u8));
        assert_eq!(val.get_word(1), Some(0x12u8));
        assert_eq!(val.get_word(2), None);
    }

    #[test]
    fn test_zero_init() {
        use crate::const_numtraits::ConstZero;
        let z = <U32 as ConstZero>::zero();
        assert_eq!(z, U32::from(0u8));
    }

    #[test]
    fn test_mul_acc_row() {
        // Compute acc += 3 * 5 with acc starting at 10
        // Expected: 10 + 15 = 25, carry = 0
        let multiplicand = U16::from(5u8);
        let mut acc = U16::from(10u8);
        let carry = U16::mul_acc_row(3u8, &multiplicand, &mut acc, 0u8);
        assert_eq!(acc, U16::from(25u8));
        assert_eq!(carry, 0u8);
    }

    #[test]
    fn test_mul_acc_row_with_overflow() {
        // Compute acc += 200 * 200 with acc starting at 0.
        // 200 * 200 = 40000 = 0x9C40 overflows the low u8 limb, so the
        // inter-limb carry must reach limb 1. The row as a whole still fits in
        // 16 bits, so the returned carry-out is 0 (see
        // test_mul_acc_row_carry_out for a nonzero carry-out).
        use crate::const_numtraits::ConstZero;
        let multiplicand = U16::from(200u16);
        let mut acc = <U16 as ConstZero>::zero();
        let carry = U16::mul_acc_row(200u8, &multiplicand, &mut acc, 0u8);
        assert_eq!(acc, U16::from(40000u16));
        assert_eq!(carry, 0u8);
    }

    #[test]
    fn test_mul_acc_row_carry_out() {
        // Worst case: acc = 0xFFFF, scalar = 0xFF, multiplicand = 0xFFFF, carry_in = 0xFF
        // 0xFFFF + 0xFF * 0xFFFF + 0xFF = 0xFFFF + 0xFEFF01 + 0xFF = 0xFFFFFF
        //   -> acc keeps the low 16 bits (0xFFFF), carry-out is 0xFF.
        let multiplicand = U16::from(0xFFFFu16);
        let mut acc = U16::from(0xFFFFu16);
        let carry = U16::mul_acc_row(0xFFu8, &multiplicand, &mut acc, 0xFFu8);
        assert_eq!(acc, U16::from(0xFFFFu16));
        assert_eq!(carry, 0xFFu8);
    }

    #[test]
    fn test_mul_acc_row_carry_in_ripples() {
        // acc = 0x00FF, scalar = 1, multiplicand = 0, carry_in = 1
        //   j=0: 1*0 + 0xFF + 1 = 0x100 -> acc[0] = 0x00, carry = 0x01
        //   j=1: 1*0 + 0x00 + 1 = 0x01  -> acc[1] = 0x01, carry = 0x00
        // Result: acc = 0x0100, return 0
        let multiplicand = U16::from(0u16);
        let mut acc = U16::from(0x00FFu16);
        let carry = U16::mul_acc_row(1u8, &multiplicand, &mut acc, 1u8);
        assert_eq!(acc, U16::from(0x0100u16));
        assert_eq!(carry, 0u8);
    }

    #[test]
    fn test_mul_acc_row_n1() {
        // N=1: 0xFF * 0xFF + 0xFF + 0xFF = 0xFFFF -> acc = 0xFF, carry = 0xFF
        let multiplicand = U8x1::from(0xFFu8);
        let mut acc = U8x1::from(0xFFu8);
        let carry = U8x1::mul_acc_row(0xFFu8, &multiplicand, &mut acc, 0xFFu8);
        assert_eq!(acc, U8x1::from(0xFFu8));
        assert_eq!(carry, 0xFFu8);
    }

    #[test]
    fn test_mul_acc_row_matches_wide_reference() {
        // Cross-check against plain u32 arithmetic. The largest possible value,
        // 0xFFFF + 0xFF * 0xFFFF + 0xFF = 0xFFFFFF, fits easily.
        for &a in SAMPLES_U16.iter() {
            for &m in SAMPLES_U16.iter() {
                for &s in SAMPLES_U8.iter() {
                    for &c in [0x00u8, 0x01, 0xFF].iter() {
                        let wide = u32::from(a) + u32::from(s) * u32::from(m) + u32::from(c);
                        let multiplicand = U16::from(m);
                        let mut acc = U16::from(a);
                        let carry = U16::mul_acc_row(s, &multiplicand, &mut acc, c);
                        // Truncation to the low 16 bits is the point of the check.
                        assert_eq!(
                            acc,
                            U16::from(wide as u16),
                            "a={:#06x} m={:#06x} s={:#04x} c={:#04x}",
                            a,
                            m,
                            s,
                            c
                        );
                        assert_eq!(
                            u32::from(carry),
                            wide >> 16,
                            "a={:#06x} m={:#06x} s={:#04x} c={:#04x}",
                            a,
                            m,
                            s,
                            c
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_word_count_u64x4() {
        use crate::const_numtraits::ConstZero;
        assert_eq!(U64x4::word_count(), 4);
        let z = <U64x4 as ConstZero>::zero();
        assert_eq!(z.get_word(3), Some(0u64));
        assert_eq!(z.get_word(4), None);
    }

    #[test]
    fn test_mul_acc_shift_row_no_overflow() {
        // scalar=2, multiplicand=0x0003, acc=0x0006, acc_hi=0
        // Word-by-word (u8 limbs, N=2):
        //   j=0: carrying_mul_add(2, 3, 6, 0) = 2*3+6+0 = 12 = (0x0C, 0x00)
        //        -> discard lo=0x0C, carry=0x00
        //   j=1: carrying_mul_add(2, 0, 0, 0) = 0 = (0x00, 0x00)
        //        -> acc[0] = 0x00, carry=0x00
        //   fold: acc_hi(0) + carry(0) = 0, no overflow -> acc[1] = 0x00
        // Result: acc = 0x0000, return 0
        let multiplicand = U16::from(3u8);
        let mut acc = U16::from(6u8);
        let overflow = U16::mul_acc_shift_row(2u8, &multiplicand, &mut acc, 0u8);
        assert_eq!(acc, U16::from(0u16));
        assert_eq!(overflow, 0u8);
    }

    #[test]
    fn test_mul_acc_shift_row_with_carry() {
        // scalar=0xFF, multiplicand=0x00FF, acc=0x00FF, acc_hi=0x10
        // Word-by-word (u8 limbs, N=2):
        //   j=0: carrying_mul_add(0xFF, 0xFF, 0xFF, 0) = 255*255+255+0 = 65280 = (0x00, 0xFF)
        //        -> discard lo=0x00, carry=0xFF
        //   j=1: carrying_mul_add(0xFF, 0x00, 0x00, 0xFF) = 0+0+0xFF = 255 = (0xFF, 0x00)
        //        -> acc[0] = 0xFF, carry=0x00
        //   fold: acc_hi(0x10) + carry(0x00) = 0x10, no overflow -> acc[1] = 0x10
        // Result: acc = 0x10FF, return 0
        let multiplicand = U16::from(0x00FFu16);
        let mut acc = U16::from(0x00FFu16);
        let overflow = U16::mul_acc_shift_row(0xFFu8, &multiplicand, &mut acc, 0x10u8);
        assert_eq!(acc, U16::from(0x10FFu16));
        assert_eq!(overflow, 0u8);
    }

    #[test]
    fn test_mul_acc_shift_row_fold_overflow() {
        // Force the fold step (acc_hi + carry) to overflow.
        // scalar=0xFF, multiplicand=0xFFFF, acc=0xFFFF, acc_hi=0xFF
        // Word-by-word (u8 limbs, N=2):
        //   j=0: carrying_mul_add(0xFF, 0xFF, 0xFF, 0) = 255*255+255 = 65280 = (0x00, 0xFF)
        //        -> discard lo, carry=0xFF
        //   j=1: carrying_mul_add(0xFF, 0xFF, 0xFF, 0xFF) = 255*255+255+255 = 65535 = (0xFF, 0xFF)
        //        -> acc[0] = 0xFF, carry=0xFF
        //   fold: acc_hi(0xFF) + carry(0xFF) = 0x1FE -> sum=0xFE, overflow=true
        //        -> acc[1] = 0xFE, return 1
        let multiplicand = U16::from(0xFFFFu16);
        let mut acc = U16::from(0xFFFFu16);
        let overflow = U16::mul_acc_shift_row(0xFFu8, &multiplicand, &mut acc, 0xFFu8);
        assert_eq!(acc, U16::from(0xFEFFu16));
        assert_eq!(overflow, 1u8);
    }

    #[test]
    fn test_mul_acc_shift_row_n1() {
        // N=1: single word. The shift discards the only word and folds acc_hi.
        // scalar=3, multiplicand=0x05, acc=0x07, acc_hi=0x02
        // j=0: carrying_mul_add(3, 5, 7, 0) = 15+7 = 22 = (0x16, 0x00)
        //      -> discard lo=0x16, carry=0x00
        // No j=1..N-1 loop iterations.
        // fold: acc_hi(2) + carry(0) = 2, no overflow -> acc[0] = 0x02
        // Result: acc = 0x02, return 0
        let multiplicand = U8x1::from(5u8);
        let mut acc = U8x1::from(7u8);
        let overflow = U8x1::mul_acc_shift_row(3u8, &multiplicand, &mut acc, 2u8);
        assert_eq!(acc, U8x1::from(2u8));
        assert_eq!(overflow, 0u8);
    }

    #[test]
    fn test_mul_acc_shift_row_matches_wide_reference() {
        // The shifted row equals ((acc_hi << 16) + acc + scalar * multiplicand) >> 8:
        // the dropped low byte never affects the higher words, because its
        // carry is propagated. The largest value, 0xFF0000 + 0xFFFF + 0xFEFF01
        // = 0x1FEFF00, fits in u32, and the spilled bit is at most 1.
        for &a in SAMPLES_U16.iter() {
            for &m in SAMPLES_U16.iter() {
                for &s in SAMPLES_U8.iter() {
                    for &hi in [0x00u8, 0x01, 0x10, 0xFF].iter() {
                        let wide =
                            (u32::from(hi) << 16) + u32::from(a) + u32::from(s) * u32::from(m);
                        let shifted = wide >> 8;
                        let multiplicand = U16::from(m);
                        let mut acc = U16::from(a);
                        let overflow = U16::mul_acc_shift_row(s, &multiplicand, &mut acc, hi);
                        // Truncation to the low 16 bits is the point of the check.
                        assert_eq!(
                            acc,
                            U16::from(shifted as u16),
                            "a={:#06x} m={:#06x} s={:#04x} hi={:#04x}",
                            a,
                            m,
                            s,
                            hi
                        );
                        assert_eq!(
                            u32::from(overflow),
                            shifted >> 16,
                            "a={:#06x} m={:#06x} s={:#04x} hi={:#04x}",
                            a,
                            m,
                            s,
                            hi
                        );
                    }
                }
            }
        }
    }
}