1use super::{FixedUInt, MachineWord};
21use crate::const_numtraits::{
22 ConstBorrowingSub, ConstCarryingAdd, ConstCarryingMul, ConstWideningMul,
23};
24use crate::machineword::ConstMachineWord;
25
26c0nst::c0nst! {
27 c0nst fn add_with_carry<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd, const N: usize>(
30 a: &[T; N],
31 b: &[T; N],
32 carry_in: bool,
33 ) -> ([T; N], bool) {
34 let mut result = [T::zero(); N];
35 let mut carry = carry_in;
36 let mut i = 0usize;
37 while i < N {
38 let (sum, c) = ConstCarryingAdd::carrying_add(a[i], b[i], carry);
39 result[i] = sum;
40 carry = c;
41 i += 1;
42 }
43 (result, carry)
44 }
45
46 c0nst fn sub_with_borrow<T: [c0nst] ConstMachineWord + [c0nst] ConstBorrowingSub, const N: usize>(
49 a: &[T; N],
50 b: &[T; N],
51 borrow_in: bool,
52 ) -> ([T; N], bool) {
53 let mut result = [T::zero(); N];
54 let mut borrow = borrow_in;
55 let mut i = 0usize;
56 while i < N {
57 let (diff, b) = ConstBorrowingSub::borrowing_sub(a[i], b[i], borrow);
58 result[i] = diff;
59 borrow = b;
60 i += 1;
61 }
62 (result, borrow)
63 }
64
65 impl<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + MachineWord, const N: usize> c0nst ConstCarryingAdd for FixedUInt<T, N> {
66 fn carrying_add(self, rhs: Self, carry: bool) -> (Self, bool) {
67 let (array, carry_out) = add_with_carry(&self.array, &rhs.array, carry);
68 (Self { array }, carry_out)
69 }
70 }
71
72 impl<T: [c0nst] ConstMachineWord + [c0nst] ConstBorrowingSub + MachineWord, const N: usize> c0nst ConstBorrowingSub for FixedUInt<T, N> {
73 fn borrowing_sub(self, rhs: Self, borrow: bool) -> (Self, bool) {
74 let (array, borrow_out) = sub_with_borrow(&self.array, &rhs.array, borrow);
75 (Self { array }, borrow_out)
76 }
77 }
78
79 c0nst fn get_at<T: [c0nst] ConstMachineWord, const N: usize>(
81 lo: &[T; N], hi: &[T; N], pos: usize
82 ) -> T {
83 if pos < N { lo[pos] } else if pos < 2 * N { hi[pos - N] } else { T::zero() }
84 }
85
86 c0nst fn set_at<T: [c0nst] ConstMachineWord, const N: usize>(
88 lo: &mut [T; N], hi: &mut [T; N], pos: usize, val: T
89 ) {
90 if pos < N { lo[pos] = val; } else if pos < 2 * N { hi[pos - N] = val; }
91 }
92
93 impl<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize> c0nst ConstWideningMul for FixedUInt<T, N> {
94 fn widening_mul(self, rhs: Self) -> (Self, Self) {
95 let mut result_low = [T::zero(); N];
98 let mut result_high = [T::zero(); N];
99
100 let mut i = 0usize;
101 while i < N {
102 let mut j = 0usize;
103 while j < N {
104 let pos = i + j;
105 let (mul_lo, mul_hi) = ConstWideningMul::widening_mul(self.array[i], rhs.array[j]);
106
107 let cur0 = get_at(&result_low, &result_high, pos);
110 let (sum0, c0) = ConstCarryingAdd::carrying_add(cur0, mul_lo, false);
111 set_at(&mut result_low, &mut result_high, pos, sum0);
112
113 let cur1 = get_at(&result_low, &result_high, pos + 1);
115 let (sum1, c1) = ConstCarryingAdd::carrying_add(cur1, mul_hi, c0);
116 set_at(&mut result_low, &mut result_high, pos + 1, sum1);
117
118 let mut carry = c1;
120 let mut p = pos + 2;
121 while carry && p < 2 * N {
122 let cur = get_at(&result_low, &result_high, p);
123 let (sum, c) = ConstCarryingAdd::carrying_add(cur, T::zero(), true);
124 set_at(&mut result_low, &mut result_high, p, sum);
125 carry = c;
126 p += 1;
127 }
128
129 j += 1;
130 }
131 i += 1;
132 }
133
134 (Self { array: result_low }, Self { array: result_high })
135 }
136 }
137
138 impl<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize> c0nst ConstCarryingMul for FixedUInt<T, N> {
139 fn carrying_mul(self, rhs: Self, carry: Self) -> (Self, Self) {
140 let (lo, hi) = ConstWideningMul::widening_mul(self, rhs);
142
143 let (lo2, c) = add_with_carry(&lo.array, &carry.array, false);
145
146 let zeros = [T::zero(); N];
148 let (hi2, _) = add_with_carry(&hi.array, &zeros, c);
149
150 (Self { array: lo2 }, Self { array: hi2 })
151 }
152
153 fn carrying_mul_add(self, rhs: Self, addend: Self, carry: Self) -> (Self, Self) {
154 let (lo, hi) = ConstWideningMul::widening_mul(self, rhs);
156
157 let (lo2, c1) = add_with_carry(&lo.array, &carry.array, false);
159
160 let (lo3, c2) = add_with_carry(&lo2, &addend.array, false);
162
163 let zeros = [T::zero(); N];
165 let (hi2, _) = add_with_carry(&hi.array, &zeros, c1);
166 let (hi3, _) = add_with_carry(&hi2, &zeros, c2);
167
168 (Self { array: lo3 }, Self { array: hi3 })
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    // Byte-sized words keep expected values easy to check by hand while still
    // exercising carries and borrows between words.
    type U16 = FixedUInt<u8, 2>;
    type U32 = FixedUInt<u8, 4>;

    c0nst::c0nst! {
        // Wrappers that route through the trait impls. They are called at
        // runtime below and, with the `nightly` feature, in const items
        // (see `test_const_context`).
        pub c0nst fn const_carrying_add<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize>(
            a: FixedUInt<T, N>,
            b: FixedUInt<T, N>,
            carry: bool,
        ) -> (FixedUInt<T, N>, bool) {
            ConstCarryingAdd::carrying_add(a, b, carry)
        }

        pub c0nst fn const_borrowing_sub<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize>(
            a: FixedUInt<T, N>,
            b: FixedUInt<T, N>,
            borrow: bool,
        ) -> (FixedUInt<T, N>, bool) {
            ConstBorrowingSub::borrowing_sub(a, b, borrow)
        }

        pub c0nst fn const_widening_mul<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize>(
            a: FixedUInt<T, N>,
            b: FixedUInt<T, N>,
        ) -> (FixedUInt<T, N>, FixedUInt<T, N>) {
            ConstWideningMul::widening_mul(a, b)
        }

        pub c0nst fn const_carrying_mul<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize>(
            a: FixedUInt<T, N>,
            b: FixedUInt<T, N>,
            carry: FixedUInt<T, N>,
        ) -> (FixedUInt<T, N>, FixedUInt<T, N>) {
            ConstCarryingMul::carrying_mul(a, b, carry)
        }

        pub c0nst fn const_carrying_mul_add<T: [c0nst] ConstMachineWord + [c0nst] ConstCarryingAdd + [c0nst] ConstBorrowingSub + MachineWord, const N: usize>(
            a: FixedUInt<T, N>,
            b: FixedUInt<T, N>,
            addend: FixedUInt<T, N>,
            carry: FixedUInt<T, N>,
        ) -> (FixedUInt<T, N>, FixedUInt<T, N>) {
            ConstCarryingMul::carrying_mul_add(a, b, addend, carry)
        }
    }

    #[test]
    fn test_carrying_add_no_carry() {
        let a = U16::from(100u8);
        let b = U16::from(50u8);

        let (sum, carry_out) = const_carrying_add(a, b, false);
        assert_eq!(sum, U16::from(150u8));
        assert!(!carry_out);

        let (sum, carry_out) = const_carrying_add(a, b, true);
        assert_eq!(sum, U16::from(151u8));
        assert!(!carry_out);
    }

    #[test]
    fn test_carrying_add_ripples_across_words() {
        // 0x00FF + 1: the low word overflows, and its carry must land in the
        // high word without escaping the whole value.
        let (sum, carry_out) = const_carrying_add(U16::from(0x00FFu16), U16::from(1u8), false);
        assert_eq!(sum, U16::from(0x0100u16));
        assert!(!carry_out);
    }

    #[test]
    fn test_carrying_add_with_overflow() {
        let max = U16::from(0xFFFFu16);
        let one = U16::from(1u8);

        let (sum, carry_out) = const_carrying_add(max, U16::from(0u8), true);
        assert_eq!(sum, U16::from(0u8));
        assert!(carry_out);

        let (sum, carry_out) = const_carrying_add(max, one, false);
        assert_eq!(sum, U16::from(0u8));
        assert!(carry_out);

        let (sum, carry_out) = const_carrying_add(max, max, false);
        assert_eq!(sum, U16::from(0xFFFEu16));
        assert!(carry_out);
    }

    #[test]
    fn test_borrowing_sub_no_borrow() {
        let a = U16::from(150u8);
        let b = U16::from(50u8);

        let (diff, borrow_out) = const_borrowing_sub(a, b, false);
        assert_eq!(diff, U16::from(100u8));
        assert!(!borrow_out);

        let (diff, borrow_out) = const_borrowing_sub(a, b, true);
        assert_eq!(diff, U16::from(99u8));
        assert!(!borrow_out);
    }

    #[test]
    fn test_borrowing_sub_ripples_across_words() {
        // 0x0100 - 1: the low word borrows from the high word.
        let (diff, borrow_out) = const_borrowing_sub(U16::from(0x0100u16), U16::from(1u8), false);
        assert_eq!(diff, U16::from(0x00FFu16));
        assert!(!borrow_out);
    }

    #[test]
    fn test_borrowing_sub_with_underflow() {
        let zero = U16::from(0u8);
        let one = U16::from(1u8);

        let (diff, borrow_out) = const_borrowing_sub(zero, one, false);
        assert_eq!(diff, U16::from(0xFFFFu16));
        assert!(borrow_out);

        let (diff, borrow_out) = const_borrowing_sub(zero, zero, true);
        assert_eq!(diff, U16::from(0xFFFFu16));
        assert!(borrow_out);

        let (diff, borrow_out) = const_borrowing_sub(one, one, true);
        assert_eq!(diff, U16::from(0xFFFFu16));
        assert!(borrow_out);
    }

    #[test]
    fn test_widening_mul() {
        let a = U16::from(100u8);
        let (lo, hi) = const_widening_mul(a, a);
        assert_eq!(lo, U16::from(10000u16));
        assert_eq!(hi, U16::from(0u8));

        let b = U16::from(256u16);
        let (lo, hi) = const_widening_mul(b, b);
        assert_eq!(lo, U16::from(0u8));
        assert_eq!(hi, U16::from(1u8));

        // 0xFFFF * 0xFFFF = 0xFFFE_0001
        let max = U16::from(0xFFFFu16);
        let (lo, hi) = const_widening_mul(max, max);
        assert_eq!(lo, U16::from(0x0001u16));
        assert_eq!(hi, U16::from(0xFFFEu16));
    }

    #[test]
    fn test_widening_mul_larger() {
        // 0x1_0000 * 0x1_0000 = 2^32, i.e. exactly the low bit of the high half.
        let a = U32::from(0x10000u32);
        let b = U32::from(0x10000u32);
        let (lo, hi) = const_widening_mul(a, b);
        assert_eq!(lo, U32::from(0u8));
        assert_eq!(hi, U32::from(1u8));
    }

    #[test]
    fn test_widening_mul_all_ones_multiword() {
        // 0xFFFF_FFFF^2 = 0xFFFF_FFFE_0000_0001. Every byte partial product
        // (0xFF * 0xFF = 0xFE01) has a non-zero high byte, which stresses carry
        // handling between rows.
        let max = U32::from(0xFFFFFFFFu32);
        let (lo, hi) = const_widening_mul(max, max);
        assert_eq!(lo, U32::from(1u32));
        assert_eq!(hi, U32::from(0xFFFFFFFEu32));

        // Asymmetric operands: 0x0102_0304 * 0x100 = 0x1_0203_0400.
        let (lo, hi) = const_widening_mul(U32::from(0x01020304u32), U32::from(0x100u32));
        assert_eq!(lo, U32::from(0x02030400u32));
        assert_eq!(hi, U32::from(1u32));

        // A zero operand yields zero in both halves.
        let zero = U32::from(0u32);
        let (lo, hi) = const_widening_mul(zero, max);
        assert_eq!(lo, zero);
        assert_eq!(hi, zero);
    }

    #[test]
    fn test_carrying_mul() {
        let a = U16::from(100u8);
        let b = U16::from(100u8);
        let carry = U16::from(5u8);

        let (lo, hi) = const_carrying_mul(a, b, carry);
        assert_eq!(lo, U16::from(10005u16));
        assert_eq!(hi, U16::from(0u8));

        // 1 * 1 + 0xFFFF = 0x1_0000: the carry spills into the high half.
        let max = U16::from(0xFFFFu16);
        let one = U16::from(1u8);
        let (lo, hi) = const_carrying_mul(one, one, max);
        assert_eq!(lo, U16::from(0u8));
        assert_eq!(hi, U16::from(1u8));
    }

    #[test]
    fn test_carrying_mul_add() {
        let a = U16::from(100u8);
        let b = U16::from(100u8);
        let addend = U16::from(10u8);
        let carry = U16::from(5u8);

        let (lo, hi) = const_carrying_mul_add(a, b, addend, carry);
        assert_eq!(lo, U16::from(10015u16));
        assert_eq!(hi, U16::from(0u8));
    }

    #[test]
    fn test_carrying_mul_add_double_overflow() {
        // 1 * 1 + 0xFFFF + 0xFFFF = 0x1_FFFF: both additions carry.
        let max = U16::from(0xFFFFu16);
        let one = U16::from(1u8);

        let (lo, hi) = const_carrying_mul_add(one, one, max, max);
        assert_eq!(lo, U16::from(0xFFFFu16));
        assert_eq!(hi, U16::from(1u8));
    }

    #[test]
    fn test_const_context() {
        #[cfg(feature = "nightly")]
        {
            const A: U16 = FixedUInt { array: [100, 0] };
            const B: U16 = FixedUInt { array: [50, 0] };

            const ADD_RESULT: (U16, bool) = const_carrying_add(A, B, false);
            assert_eq!(ADD_RESULT.0, U16::from(150u8));
            assert!(!ADD_RESULT.1);

            const ADD_WITH_CARRY: (U16, bool) = const_carrying_add(A, B, true);
            assert_eq!(ADD_WITH_CARRY.0, U16::from(151u8));
            assert!(!ADD_WITH_CARRY.1);

            const SUB_RESULT: (U16, bool) = const_borrowing_sub(A, B, false);
            assert_eq!(SUB_RESULT.0, U16::from(50u8));
            assert!(!SUB_RESULT.1);

            // C = 256, so C * C = 0x1_0000.
            const C: U16 = FixedUInt { array: [0, 1] };
            const MUL_RESULT: (U16, U16) = const_widening_mul(C, C);
            assert_eq!(MUL_RESULT.0, U16::from(0u8));
            assert_eq!(MUL_RESULT.1, U16::from(1u8));
        }
    }

    #[test]
    fn test_widening_mul_polymorphic() {
        fn test_widening<T>(a: T, b: T, expected_lo: T, expected_hi: T)
        where
            T: ConstWideningMul
                + ConstCarryingAdd
                + ConstBorrowingSub
                + Eq
                + core::fmt::Debug
                + Copy,
        {
            let (lo, hi) = ConstWideningMul::widening_mul(a, b);
            assert_eq!(lo, expected_lo, "lo mismatch");
            assert_eq!(hi, expected_hi, "hi mismatch");
        }

        test_widening(
            U16::from(256u16),
            U16::from(256u16),
            U16::from(0u16),
            U16::from(1u16),
        );

        test_widening(
            U32::from(256u32),
            U32::from(256u32),
            U32::from(65536u32),
            U32::from(0u32),
        );

        test_widening(
            U16::from(0xFFFFu16),
            U16::from(0xFFFFu16),
            U16::from(0x0001u16),
            U16::from(0xFFFEu16),
        );

        test_widening(
            U32::from(0xFFFFFFFFu32),
            U32::from(2u32),
            U32::from(0xFFFFFFFEu32),
            U32::from(1u32),
        );

        test_widening(
            U32::from(0xFFFFFFFFu32),
            U32::from(0xFFFFFFFFu32),
            U32::from(1u32),
            U32::from(0xFFFFFFFEu32),
        );
    }

    #[test]
    fn test_carrying_mul_add_polymorphic() {
        fn test_cma<T>(a: T, b: T, addend: T, carry: T, expected_lo: T, expected_hi: T)
        where
            T: ConstCarryingMul + Eq + core::fmt::Debug + Copy,
        {
            let (lo, hi) = ConstCarryingMul::carrying_mul_add(a, b, addend, carry);
            assert_eq!(lo, expected_lo, "lo mismatch");
            assert_eq!(hi, expected_hi, "hi mismatch");
        }

        // The largest possible result: max * max + max + max = 2^32 - 1.
        let max16 = U16::from(0xFFFFu16);
        test_cma(
            max16,
            max16,
            max16,
            max16,
            U16::from(0xFFFFu16),
            U16::from(0xFFFFu16),
        );

        let max32 = U32::from(0xFFFFFFFFu32);
        let zero32 = U32::from(0u32);
        test_cma(
            max32,
            U32::from(1u32),
            zero32,
            max32,
            U32::from(0xFFFFFFFEu32),
            U32::from(1u32),
        );

        // The same worst case at four words: max * max + max + max = 2^64 - 1.
        test_cma(max32, max32, max32, max32, max32, max32);
    }
}