const_time_bignum/
lib.rs

//! # Const-time-bignum
//! A fixed-width bignum library that aims to operate in constant time and without any heap allocations.
//!
//! ⚠️ This library is currently under development and should not be used.
5
6
7use std::fmt;
8use std::ops::{Add, Div, Mul, Rem, Shl, Shr, Sub};
9
/// A 288-bit unsigned integer stored as 36 little-endian bytes: index 0 holds the least
/// significant byte and index 35 the most significant.
// Lower-case name deliberately mirrors the primitive integer types (`u64`, `u128`, ...).
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug)]
pub struct u288([u8; 36]);
13
14impl fmt::Display for u288 {
15    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
16        write!(f, "{}", self.to_hex())
17    }
18}
19
20impl Add for u288 {
21    type Output = u288;
22    fn add(self, other: Self) -> Self::Output {
23        let mut output = self;
24        let mut carry = 0;
25        for (i, byte) in output.0.iter_mut().enumerate() {
26            // LSB first
27            let sum: u64 = *byte as u64 + other.0[i] as u64 + carry as u64;
28            *byte = (sum % 256) as u8;
29            carry = sum / 256;
30        }
31        if carry > 0 {
32            panic!("overflow");
33        }
34        output
35    }
36}
37
38impl Sub for u288 {
39    type Output = u288;
40    fn sub(self, other: Self) -> Self::Output {
41        let mut output = self;
42        let mut carry = 0;
43        for (i, byte) in output.0.iter_mut().enumerate() {
44            let difference: i64 = *byte as i64 - other.0[i] as i64 - carry as i64;
45            *byte = ((difference + 256) % 256) as u8;
46            carry = difference.is_negative() as u8;
47        }
48        if carry > 0 {
49            panic!("overflow");
50        }
51        output
52    }
53}
54
55impl Mul for u288 {
56    type Output = u288;
57    fn mul(self, other: Self) -> Self::Output {
58        let mut total_sum = u288::new();
59        for (i, byte_self) in self.0.iter().enumerate() {
60            // Multiply entire second number by each byte in self
61            let mut working_sum = other;
62            let mut carry = 0;
63            for byte_other in working_sum.0.iter_mut() {
64                let product = *byte_other as u64 * *byte_self as u64 + carry as u64;
65                *byte_other = (product % 256) as u8;
66                carry = product / 256;
67            }
68            if carry > 0 {
69                panic!("overflow");
70            }
71            working_sum.0.rotate_right(i);
72            total_sum = total_sum + working_sum;
73        }
74        total_sum
75    }
76}
77
78// NOTE: This shifts in base 256
79impl Shl<u288> for u288 {
80    type Output = u288;
81    fn shl(mut self, other: Self) -> Self::Output {
82        let mut output = self;
83        let mut i = u288::new(); // initializes to 0
84        let one = u288::from_hex("1");
85        while other > i {
86            for j in 0..self.0.len() - 1 {
87                output.0[j + 1] = self.0[j];
88            }
89            output.0[0] = 0;
90            self = output;
91            i = i + one; // Increment
92        }
93        output
94    }
95}
96
97// NOTE: This shifts in base 256
98impl Shl<usize> for u288 {
99    type Output = u288;
100    fn shl(mut self, other: usize) -> Self::Output {
101        let mut output = self;
102        let mut i: usize = 0; // initializes to 0
103        while other > i {
104            for j in 0..self.0.len() - 1 {
105                output.0[j + 1] = self.0[j];
106            }
107            output.0[0] = 0;
108            self = output;
109            i += 1; // Increment
110        }
111        output
112    }
113}
114
115// NOTE: This shifts in base 256
116impl Shr<u288> for u288 {
117    type Output = u288;
118    fn shr(mut self, other: Self) -> Self::Output {
119        let mut output = self;
120        let mut i = u288::new(); // initializes to 0
121        let one = u288::from_hex("1");
122        while other > i {
123            for j in (1..self.0.len() - 2).rev() {
124                output.0[j - 1] = self.0[j];
125            }
126            output.0[output.0.len() - 1] = 0;
127            self = output;
128            i = i + one; // Increment
129        }
130        output
131    }
132}
133
134// NOTE: This shifts in base 256
135impl Shr<usize> for u288 {
136    type Output = u288;
137    fn shr(mut self, other: usize) -> Self::Output {
138        let mut output = self;
139        let mut i: usize = 0; // initializes to 0
140        while other > i {
141            for j in (1..self.0.len() - 2).rev() {
142                output.0[j - 1] = self.0[j];
143            }
144            output.0[output.0.len() - 1] = 0;
145            self = output;
146            i += 1; // Increment
147        }
148        output
149    }
150}
151
152// This is slow. TODO: Look into implementing a more performant algorithm!
153// TODO: Do this in constant time!
154impl Rem for u288 {
155    type Output = u288;
156    fn rem(self, other: Self) -> Self::Output {
157        let mut numerator = self;
158        let mut divisor = other;
159        let mut quotient = u288::new(); // 0
160        let one = u288::from_hex("1");
161
162        // Align divisor to msb of numerator and store the shift amount in n
163        let mut n: usize = 0;
164        let mut flag = 1; // Flag to detect when msb has been hit
165        for i in (0..numerator.0.len()).rev() {
166            // Iterate over the bytes backwards
167            n += flag & (numerator.0[i] != 0 && divisor.0[i] == 0) as usize;
168            flag &= (divisor.0[i] == 0) as usize;
169        }
170        divisor = divisor << n; // TODO: Make this constant time!
171
172        // TODO: This is temporary! Need to find a more permament solution
173        let mut n: i64 = n as i64;
174
175        // Keep shifting divisor to the right (decrease, in-memory left shift due to le)
176        while other <= numerator {
177            // Subtract until not possible anymore, then add to quotient
178            let mut i = u288::new();
179            while divisor <= numerator {
180                numerator = numerator - divisor;
181                i = i + one;
182            }
183            quotient = quotient + i << n as usize;
184            n -= 1;
185            divisor = divisor >> 1;
186        }
187        numerator
188    }
189
190    // fn rem(self, other: Self) -> Self::Output {
191    //     let mut numerator = self;
192    //     while numerator >= other {
193    //         // bigu288::new() is equal to 0
194    //         numerator = numerator - other;
195    //     }
196    //     numerator // Remainder
197    // }
198}
199// TODO: Do this in constant time!
200impl Div for u288 {
201    type Output = u288;
202    fn div(self, other: Self) -> Self::Output {
203        let mut numerator = self;
204        let mut divisor = other;
205        let mut quotient = u288::new(); // 0
206        let one = u288::from_hex("1");
207
208        // Align divisor to msb of numerator and store the shift amount in n
209        let mut n: usize = 0;
210        let mut flag = 1; // Flag to detect when msb has been hit
211        for i in (0..numerator.0.len()).rev() {
212            // Iterate over the bytes backwards
213            n += flag & (numerator.0[i] != 0 && divisor.0[i] == 0) as usize;
214            flag &= !(divisor.0[i] != 0) as usize;
215        }
216        divisor = divisor << n; // TODO: Make this constant time!
217
218        // TODO: This is temporary! Need to find a more permament solution
219        let mut n: i64 = n as i64;
220
221        // Keep shifting divisor to the right (decrease, in-memory left shift due to le)
222        while other <= numerator {
223            // Subtract until not possible anymore, then add to quotient
224            let mut i = u288::new();
225            while divisor <= numerator {
226                numerator = numerator - divisor;
227                i = i + one;
228            }
229            quotient = quotient + i << n as usize;
230            n -= 1;
231            divisor = divisor >> 1;
232        }
233        quotient
234    }
235
236    // fn div(self, other: Self) -> Self::Output {
237    //     let mut quotient = u288::new();
238    //     let mut numerator = self;
239    //     while numerator >= other {
240    //         // bigu288::new() is equal to 0
241    //         numerator = numerator - other;
242    //         quotient = quotient + u288::from_hex("1");
243    //     }
244    //     quotient
245    // }
246}
247
248// I don't actually know if a simple == is constant time, but to be on the safe side I implemented
249// a constant time loop.
250impl PartialEq<u288> for u288 {
251    fn eq(&self, other: &u288) -> bool {
252        let mut equal = 1;
253        for (i, byte_self) in self.0.iter().enumerate() {
254            equal &= (*byte_self == other.0[i]) as u8;
255        }
256        equal == 1
257    }
258}
259
260// impl PartialEq<u8> for u288 {
261//     fn eq(&self, other: &u8) -> bool {
262//         self.0[0] == *other
263//     }
264// }
265
266impl PartialOrd<u288> for u288 {
267    fn lt(&self, other: &Self) -> bool {
268        let mut lt = 0;
269        for (i, byte_self) in self.0.iter().enumerate() {
270            lt = (*byte_self < other.0[i]) as u8 | (lt & (*byte_self == other.0[i]) as u8) as u8;
271        }
272        lt == 1
273    }
274    fn gt(&self, other: &Self) -> bool {
275        let mut gt = 0;
276        for (i, byte_self) in self.0.iter().enumerate() {
277            gt = (*byte_self > other.0[i]) as u8 | (gt & (*byte_self == other.0[i]) as u8) as u8;
278        }
279        gt == 1
280    }
281    fn le(&self, other: &Self) -> bool {
282        let mut le = 1;
283        for (i, byte_self) in self.0.iter().enumerate() {
284            le = (*byte_self < other.0[i]) as u8 | (le & (*byte_self == other.0[i]) as u8) as u8;
285        }
286        le == 1
287    }
288    fn ge(&self, other: &Self) -> bool {
289        let mut ge = 1;
290        for (i, byte_self) in self.0.iter().enumerate() {
291            ge = (*byte_self > other.0[i]) as u8 | (ge & (*byte_self == other.0[i]) as u8) as u8;
292        }
293        ge == 1
294    }
295    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
296        todo!("implement partialcmp");
297    }
298}
299
300impl Eq for u288 {}
301
302impl u288 {
303    pub fn from_slice(bytes: &[u8]) -> u288 {
304        let mut big_u288 = u288::new();
305        big_u288.0 = pad_array_bigu288(bytes).as_slice().try_into().unwrap();
306        big_u288
307    }
308    pub fn from_hex(input: &str) -> u288 {
309        let mut big_u288 = u288::new();
310        // Iterate over the string backwards (we want little endian)
311        let input_padded_le: [u8; 72] = pad_array_hex(&input.bytes().rev().collect::<Vec<_>>()[..]);
312        for (i, char) in input_padded_le.iter().enumerate() {
313            let hex_digit = u8::from_str_radix(
314                &String::from_utf8(vec![*char]).unwrap_or("0".to_string()),
315                16,
316            )
317            .unwrap_or(0);
318            big_u288.0[i / 2] += hex_digit << 4 * (i % 2);
319        }
320        big_u288
321    }
322    pub fn to_hex(&self) -> String {
323        let mut out = String::new();
324        for byte in self.get_bytes().iter().rev() {
325            out += &format!("{:x}{:x}", byte >> 4, byte & 15);
326        }
327        out
328    }
329    pub fn get_bytes(&self) -> [u8; 36] {
330        self.0
331    }
332    pub fn new() -> u288 {
333        u288([0; 36])
334    }
335}
336
/// Returns a 72-byte buffer (the hex-digit width of a 288-bit value) whose first
/// `input.len()` bytes are `input` and whose remainder is zero.
///
/// Panics if `input` is longer than 72 bytes.
// TODO: Make the width configurable.
fn pad_array_hex(input: &[u8]) -> [u8; 72] {
    let used = input.len();
    let mut buffer = [0; 72];
    let prefix = &mut buffer[..used];
    prefix.copy_from_slice(input);
    buffer
}
342
/// Returns a 36-byte buffer (the byte width of a 288-bit value) whose first `input.len()`
/// bytes are `input` and whose remainder is zero.
///
/// Panics if `input` is longer than 36 bytes.
// TODO: Make the width configurable.
fn pad_array_bigu288(input: &[u8]) -> [u8; 36] {
    let used = input.len();
    let mut buffer = [0; 36];
    let prefix = &mut buffer[..used];
    prefix.copy_from_slice(input);
    buffer
}