Skip to main content

fixed_bigint/fixeduint/
mul_acc_ops_impl.rs

1//! MulAccOps implementation for FixedUInt.
2//!
3//! Provides row-level fused multiply-accumulate operations used by
4//! Montgomery multiplication variants. All limb access is internal to
5//! this module; the public trait surface never exposes raw arrays.
6
7use super::{FixedUInt, MachineWord};
8use crate::const_numtraits::{ConstCarryingAdd, ConstZero};
9use crate::mul_acc_ops::MulAccOps;
10use crate::patch_num_traits::CarryingMul;
11use crate::personality::{Ct, Nct};
12
// Shared method bodies for every `MulAccOps` impl on `FixedUInt`.
//
// Expand this inside an `impl<T, const N: usize> MulAccOps for FixedUInt<T, N, _>`
// block. The expansion relies on `T: CarryingMul<Output = T> + ConstCarryingAdd`
// and on `T` being `ConstZero` + `Copy` (presumably supplied via `MachineWord` —
// confirm against its definition). Limbs are stored least-significant first.
macro_rules! mul_acc_ops_common {
    () => {
        type Word = T;

        fn word_count() -> usize {
            N
        }

        /// Performs `acc += scalar * multiplicand + carry_in` one limb at a time,
        /// starting from the least-significant limb.
        ///
        /// Returns the word that spills out above `acc`'s most-significant limb.
        fn mul_acc_row(scalar: T, multiplicand: &Self, acc: &mut Self, carry_in: T) -> T {
            acc.array
                .iter_mut()
                .zip(multiplicand.array.iter().copied())
                .fold(carry_in, |carry, (slot, factor)| {
                    let (lo, hi) = CarryingMul::carrying_mul_add(scalar, factor, *slot, carry);
                    *slot = lo;
                    hi
                })
        }

        /// Replaces `acc` with the low `N` words of
        /// `floor((acc_hi * B^N + acc + scalar * multiplicand) / B)`, where `B` is
        /// the word radix. Returns word `N` of that quotient, which is 0 or 1.
        ///
        /// The low word of `acc + scalar * multiplicand` is dropped without being
        /// checked. Montgomery callers presumably choose `scalar` so that it is zero.
        fn mul_acc_shift_row(scalar: T, multiplicand: &Self, acc: &mut Self, acc_hi: T) -> T {
            debug_assert!(N > 0, "MulAccOps requires at least one word");
            let zero = <T as ConstZero>::zero();

            // Limb 0: the low half falls off the bottom of the shift.
            // Only the high half continues as the running carry.
            let (_, mut carry) =
                CarryingMul::carrying_mul_add(scalar, multiplicand.array[0], acc.array[0], zero);

            // Limbs 1..N: each result is stored one slot below its source limb.
            for j in 1..N {
                let (lo, hi) = CarryingMul::carrying_mul_add(
                    scalar,
                    multiplicand.array[j],
                    acc.array[j],
                    carry,
                );
                acc.array[j - 1] = lo;
                carry = hi;
            }

            // The new top limb is acc_hi + carry. The carry-out of that sum is word N.
            let (top, spilled) = <T as ConstCarryingAdd>::carrying_add(acc_hi, carry, false);
            acc.array[N - 1] = top;

            // carrying_add(0, 0, flag) turns the flag into a 0/1 word without branching on it.
            let (spill_word, _) = <T as ConstCarryingAdd>::carrying_add(zero, zero, spilled);
            spill_word
        }
    };
}
76
77impl<T, const N: usize> MulAccOps for FixedUInt<T, N, Nct>
78where
79    T: MachineWord + CarryingMul<Output = T> + ConstCarryingAdd,
80{
81    type GetWordOutput = Option<T>;
82
83    fn get_word(&self, i: usize) -> Option<T> {
84        self.array.get(i).copied()
85    }
86
87    mul_acc_ops_common!();
88}
89
90impl<T, const N: usize> MulAccOps for FixedUInt<T, N, Ct>
91where
92    T: MachineWord + CarryingMul<Output = T> + ConstCarryingAdd + subtle::ConditionallySelectable,
93{
94    type GetWordOutput = subtle::CtOption<T>;
95
96    fn get_word(&self, i: usize) -> subtle::CtOption<T> {
97        use subtle::{Choice, CtOption};
98        let mut selected = <T as ConstZero>::zero();
99        let mut j = 0;
100        while j < N {
101            let is_target = Choice::from((j == i) as u8);
102            selected = <T as subtle::ConditionallySelectable>::conditional_select(
103                &selected,
104                &self.array[j],
105                is_target,
106            );
107            j += 1;
108        }
109        CtOption::new(selected, Choice::from((i < N) as u8))
110    }
111
112    mul_acc_ops_common!();
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    type U8x1 = FixedUInt<u8, 1>;
    type U16 = FixedUInt<u8, 2>;
    type U32 = FixedUInt<u8, 4>;
    type U64x4 = FixedUInt<u64, 4>;

    #[test]
    fn test_word_access() {
        let val = U16::from(0x1234u16);
        assert_eq!(U16::word_count(), 2);
        assert_eq!(val.get_word(0), Some(0x34u8));
        assert_eq!(val.get_word(1), Some(0x12u8));
        assert_eq!(val.get_word(2), None);
    }

    #[test]
    fn test_zero_init() {
        use crate::const_numtraits::ConstZero;
        let z = <U32 as ConstZero>::zero();
        assert_eq!(z, U32::from(0u8));
    }

    #[test]
    fn test_mul_acc_row() {
        // Compute acc += 3 * 5 with acc starting at 10.
        // Expected: 10 + 15 = 25, carry = 0.
        let multiplicand = U16::from(5u8);
        let mut acc = U16::from(10u8);
        let carry = U16::mul_acc_row(3u8, &multiplicand, &mut acc, 0u8);
        assert_eq!(acc, U16::from(25u8));
        assert_eq!(carry, 0u8);
    }

    #[test]
    fn test_mul_acc_row_inter_limb_carry() {
        // Compute acc += 200 * 200 with acc starting at 0.
        // 200 * 200 = 40000 = 0x9C40. The limb-0 product carries 0x9C into limb 1.
        // The whole result still fits in two limbs, so the returned carry is 0.
        use crate::const_numtraits::ConstZero;
        let multiplicand = U16::from(200u16);
        let mut acc = <U16 as ConstZero>::zero();
        let carry = U16::mul_acc_row(200u8, &multiplicand, &mut acc, 0u8);
        assert_eq!(acc, U16::from(40000u16));
        assert_eq!(carry, 0u8);
    }

    #[test]
    fn test_mul_acc_row_carry_out() {
        // acc = 0xFFFF + 0xFF * 0xFFFF + carry_in 0xFF = 0xFF_FFFF.
        // The low two limbs (0xFFFF) stay in acc, and 0xFF is the returned carry.
        let multiplicand = U16::from(0xFFFFu16);
        let mut acc = U16::from(0xFFFFu16);
        let carry = U16::mul_acc_row(0xFFu8, &multiplicand, &mut acc, 0xFFu8);
        assert_eq!(acc, U16::from(0xFFFFu16));
        assert_eq!(carry, 0xFFu8);
    }

    #[test]
    fn test_mul_acc_row_carry_in_only() {
        // A zero scalar means the row only adds carry_in into limb 0.
        //   j=0: 0*0x34 + 0x00 + 7 = 7 -> acc[0] = 0x07, carry=0
        //   j=1: 0*0x12 + 0x01 + 0 = 1 -> acc[1] = 0x01, carry=0
        let multiplicand = U16::from(0x1234u16);
        let mut acc = U16::from(0x0100u16);
        let carry = U16::mul_acc_row(0u8, &multiplicand, &mut acc, 7u8);
        assert_eq!(acc, U16::from(0x0107u16));
        assert_eq!(carry, 0u8);
    }

    #[test]
    fn test_word_count_u64x4() {
        assert_eq!(U64x4::word_count(), 4);
    }

    #[test]
    fn test_mul_acc_shift_row_no_overflow() {
        // scalar=2, multiplicand=0x0003, acc=0x0006, acc_hi=0
        // Word-by-word (u8 limbs, N=2):
        //   j=0: carrying_mul_add(2, 3, 6, 0) = 2*3+6+0 = 12 = (0x0C, 0x00)
        //        -> discard lo=0x0C, carry=0x00
        //   j=1: carrying_mul_add(2, 0, 0, 0) = 0 = (0x00, 0x00)
        //        -> acc[0] = 0x00, carry=0x00
        //   fold: acc_hi(0) + carry(0) = 0, no overflow -> acc[1] = 0x00
        // Result: acc = 0x0000, return 0
        let multiplicand = U16::from(3u8);
        let mut acc = U16::from(6u8);
        let overflow = U16::mul_acc_shift_row(2u8, &multiplicand, &mut acc, 0u8);
        assert_eq!(acc, U16::from(0u16));
        assert_eq!(overflow, 0u8);
    }

    #[test]
    fn test_mul_acc_shift_row_with_carry() {
        // scalar=0xFF, multiplicand=0x00FF, acc=0x00FF, acc_hi=0x10
        // Word-by-word (u8 limbs, N=2):
        //   j=0: carrying_mul_add(0xFF, 0xFF, 0xFF, 0) = 255*255+255+0 = 65280 = (0x00, 0xFF)
        //        -> discard lo=0x00, carry=0xFF
        //   j=1: carrying_mul_add(0xFF, 0x00, 0x00, 0xFF) = 0+0+0xFF = 255 = (0xFF, 0x00)
        //        -> acc[0] = 0xFF, carry=0x00
        //   fold: acc_hi(0x10) + carry(0x00) = 0x10, no overflow -> acc[1] = 0x10
        // Result: acc = 0x10FF, return 0
        let multiplicand = U16::from(0x00FFu16);
        let mut acc = U16::from(0x00FFu16);
        let overflow = U16::mul_acc_shift_row(0xFFu8, &multiplicand, &mut acc, 0x10u8);
        assert_eq!(acc, U16::from(0x10FFu16));
        assert_eq!(overflow, 0u8);
    }

    #[test]
    fn test_mul_acc_shift_row_fold_overflow() {
        // Force the fold step (acc_hi + carry) to overflow.
        // scalar=0xFF, multiplicand=0xFFFF, acc=0xFFFF, acc_hi=0xFF
        // Word-by-word (u8 limbs, N=2):
        //   j=0: carrying_mul_add(0xFF, 0xFF, 0xFF, 0) = 255*255+255 = 65280 = (0x00, 0xFF)
        //        -> discard lo, carry=0xFF
        //   j=1: carrying_mul_add(0xFF, 0xFF, 0xFF, 0xFF) = 255*255+255+255 = 65535 = (0xFF, 0xFF)
        //        -> acc[0] = 0xFF, carry=0xFF
        //   fold: acc_hi(0xFF) + carry(0xFF) = 0x1FE -> sum=0xFE, overflow=true
        //        -> acc[1] = 0xFE, return 1
        let multiplicand = U16::from(0xFFFFu16);
        let mut acc = U16::from(0xFFFFu16);
        let overflow = U16::mul_acc_shift_row(0xFFu8, &multiplicand, &mut acc, 0xFFu8);
        assert_eq!(acc, U16::from(0xFEFFu16));
        assert_eq!(overflow, 1u8);
    }

    #[test]
    fn test_mul_acc_shift_row_n1() {
        // N=1: single word. The shift discards the only word and folds acc_hi.
        // scalar=3, multiplicand=0x05, acc=0x07, acc_hi=0x02
        // j=0: carrying_mul_add(3, 5, 7, 0) = 15+7 = 22 = (0x16, 0x00)
        //      -> discard lo=0x16, carry=0x00
        // No j=1..N-1 loop iterations.
        // fold: acc_hi(2) + carry(0) = 2, no overflow -> acc[0] = 0x02
        // Result: acc = 0x02, return 0
        let multiplicand = U8x1::from(5u8);
        let mut acc = U8x1::from(7u8);
        let overflow = U8x1::mul_acc_shift_row(3u8, &multiplicand, &mut acc, 2u8);
        assert_eq!(acc, U8x1::from(2u8));
        assert_eq!(overflow, 0u8);
    }

    #[test]
    fn test_mul_acc_shift_row_wide() {
        // N=4 exercises the middle limbs of the shift.
        // scalar=2, multiplicand limbs [2, 3, 0, 0], acc limbs [0, 1, 0, 0], acc_hi=9
        //   j=0: 2*2 + 0 + 0 = 4 -> discard, carry=0
        //   j=1: 2*3 + 1 + 0 = 7 -> acc[0] = 7, carry=0
        //   j=2, j=3: 0          -> acc[1] = acc[2] = 0
        //   fold: 9 + 0 = 9      -> acc[3] = 9, return 0
        let multiplicand = U32::from(0x0302u16);
        let mut acc = U32::from(0x0100u16);
        let overflow = U32::mul_acc_shift_row(2u8, &multiplicand, &mut acc, 9u8);
        assert_eq!(acc.get_word(0), Some(7u8));
        assert_eq!(acc.get_word(1), Some(0u8));
        assert_eq!(acc.get_word(2), Some(0u8));
        assert_eq!(acc.get_word(3), Some(9u8));
        assert_eq!(overflow, 0u8);
    }
}