keynesis_core/
memsec.rs

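/*!
# Memsec utility functions

Secure memory helpers: a volatile `memset` and constant-time `memeq` /
`memcmp` comparisons. The `Scrubbed` trait is implemented here for the
integer primitives, byte slices and byte arrays, `str`, and the `Option`,
`Vec`, `Box`, `Cell` and `RefCell` wrappers.
*/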
6
7use std::ptr;
8
/// A value whose memory can be wiped in place.
///
/// Implementors clear their contents by overwriting them with a dummy
/// value, so the original data no longer sits in memory once `scrub`
/// returns.
pub trait Scrubbed {
    /// Overwrite the contents of `self` in place with a dummy value.
    fn scrub(&mut self);
}
14
/// Perform a secure memset: write `val` into the `count` bytes starting at
/// `dst`.
///
/// Every byte is written with `ptr::write_volatile`, so the compiler may not
/// elide the stores, even when it can prove the memory is never read again
/// (e.g. right before it is freed). Volatile accesses are only ordered with
/// respect to other volatile accesses. A `compiler_fence(SeqCst)` after the
/// loop therefore also keeps the compiler from moving surrounding memory
/// operations across the wipe.
///
/// # Performance consideration
///
/// Bytes are written one at a time with volatile stores. This is slower than
/// `std::ptr::write_bytes`; use it for wiping sensitive data, not for bulk
/// initialisation.
///
/// # Safety
///
/// The destination memory (`dst` to `dst + count`) must be properly allocated,
/// valid for writes and ready to use. When `count` is `0`, `dst` is never
/// dereferenced.
#[inline(never)]
pub unsafe fn memset(dst: *mut u8, val: u8, count: usize) {
    for i in 0..count {
        ptr::write_volatile(dst.add(i), val);
    }
    // Prevent the compiler from reordering non-volatile memory operations
    // across the wipe (volatile stores alone do not guarantee this).
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}
32
/// Compare two byte buffers of length `len` for equality in constant time.
///
/// Every byte of both buffers is loaded with a volatile read, and the XOR of
/// each pair is OR-ed into an accumulator. There is no early exit, so the
/// work done does not depend on where (or whether) the buffers differ.
///
/// Returns `true` when the `len` bytes at `v1` equal the `len` bytes at `v2`.
///
/// # Panics
///
/// Panics if called with a `len` of 0.
///
/// # Safety
///
/// `v1` and `v2` must both be valid pointers, each readable for `len`
/// bytes of allocated memory.
#[inline(never)]
pub unsafe fn memeq(v1: *const u8, v2: *const u8, len: usize) -> bool {
    assert!(
        len != 0,
        "Cannot perform equality comparison if the length is 0"
    );

    let mut difference: u8 = 0;
    let mut index = 0;
    while index < len {
        let left = ptr::read_volatile(v1.add(index));
        let right = ptr::read_volatile(v2.add(index));
        // Any differing bit leaves a permanent mark in `difference`.
        difference |= left ^ right;
        index += 1;
    }

    difference == 0
}
63
/// Compare two byte buffers of length `len` lexicographically in constant time.
///
/// The result matches `<[u8]>::cmp` on the two `len`-byte slices: the first
/// differing byte decides, and equal buffers give `Ordering::Equal`. Every
/// byte is read with a volatile load, and the selection of the deciding
/// difference is done with masks rather than data-dependent branches.
///
/// # Panics
///
/// Panics if called with a `len` of 0.
///
/// # Safety
///
/// `v1` and `v2` must both be valid pointers, each readable for `len`
/// bytes of allocated memory.
#[inline(never)]
pub unsafe fn memcmp(v1: *const u8, v2: *const u8, len: usize) -> std::cmp::Ordering {
    assert!(
        len != 0,
        "Cannot perform ordering comparison if the length is 0"
    );

    // Walk from the last byte back to the first. For each position,
    // `keep_previous` is all ones when the two bytes are equal (delta == 0),
    // which keeps the difference recorded so far. Otherwise it is zero, and
    // the current delta replaces it. When the loop ends, `first_diff` holds
    // the difference at the lowest differing index, or 0 if none differ.
    let mut first_diff: i32 = 0;
    let mut index = len;
    while index > 0 {
        index -= 1;
        let lhs = i32::from(ptr::read_volatile(v1.add(index)));
        let rhs = i32::from(ptr::read_volatile(v2.add(index)));
        let delta = lhs - rhs;
        let keep_previous = ((delta - 1) & !delta) >> 8;
        first_diff = (first_diff & keep_previous) | delta;
    }

    // Branch-free sign of `first_diff` (which lies in -255..=255): -1, 0 or 1.
    let sign = ((first_diff - 1) >> 8) + (first_diff >> 8) + 1;

    sign.cmp(&0)
}
93
94macro_rules! impl_scrubbed_primitive {
95    ($t:ty) => {
96        impl Scrubbed for $t {
97            #[inline(never)]
98            fn scrub(&mut self) {
99                *self = 0;
100            }
101        }
102    };
103}
104
105impl_scrubbed_primitive!(u8);
106impl_scrubbed_primitive!(u16);
107impl_scrubbed_primitive!(u32);
108impl_scrubbed_primitive!(u64);
109impl_scrubbed_primitive!(u128);
110impl_scrubbed_primitive!(usize);
111impl_scrubbed_primitive!(i8);
112impl_scrubbed_primitive!(i16);
113impl_scrubbed_primitive!(i32);
114impl_scrubbed_primitive!(i64);
115impl_scrubbed_primitive!(i128);
116impl_scrubbed_primitive!(isize);
117
118macro_rules! impl_scrubbed_array {
119    ($t:ty) => {
120        impl Scrubbed for $t {
121            fn scrub(&mut self) {
122                unsafe { memset(self.as_mut_ptr(), 0, self.len()) }
123            }
124        }
125    };
126}
127
128impl_scrubbed_array!([u8]);
129impl_scrubbed_array!(str);
130
131impl<const N: usize> Scrubbed for [u8; N] {
132    fn scrub(&mut self) {
133        unsafe { memset(self.as_mut_ptr(), 0, self.len()) }
134    }
135}
136
137impl<T: Scrubbed> Scrubbed for Option<T> {
138    fn scrub(&mut self) {
139        self.as_mut().map(Scrubbed::scrub);
140    }
141}
142
143impl<T: Scrubbed> Scrubbed for Vec<T> {
144    fn scrub(&mut self) {
145        self.iter_mut().for_each(Scrubbed::scrub)
146    }
147}
148
149impl<T: Scrubbed> Scrubbed for Box<T> {
150    fn scrub(&mut self) {
151        self.as_mut().scrub()
152    }
153}
154
155impl<T: Scrubbed> Scrubbed for std::cell::Cell<T> {
156    fn scrub(&mut self) {
157        self.get_mut().scrub()
158    }
159}
160
161impl<T: Scrubbed> Scrubbed for std::cell::RefCell<T> {
162    fn scrub(&mut self) {
163        self.get_mut().scrub()
164    }
165}
166
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;

    use super::*;
    use quickcheck::TestResult;

    // Pin the panic message so an unrelated panic cannot make these pass.
    #[test]
    #[should_panic(expected = "Cannot perform equality comparison if the length is 0")]
    fn eq_empty() {
        let bytes: Vec<u8> = Vec::new();
        unsafe { memeq(bytes.as_ptr(), bytes.as_ptr(), bytes.len()) };
    }

    #[test]
    #[should_panic(expected = "Cannot perform ordering comparison if the length is 0")]
    fn ord_empty() {
        let bytes: Vec<u8> = Vec::new();
        unsafe { memcmp(bytes.as_ptr(), bytes.as_ptr(), bytes.len()) };
    }

    #[test]
    fn memset_fills_only_the_requested_range() {
        let mut buf = [7u8; 6];
        unsafe { memset(buf.as_mut_ptr().add(1), 0xAA, 4) };
        assert_eq!(buf, [7, 0xAA, 0xAA, 0xAA, 0xAA, 7]);
    }

    #[test]
    fn scrub_zeroes_supported_types() {
        let mut int: u32 = 0xDEAD_BEEF;
        int.scrub();
        assert_eq!(int, 0);

        let mut array = [1u8, 2, 3, 4];
        array.scrub();
        assert_eq!(array, [0; 4]);

        let mut text = String::from("secret");
        text.as_mut_str().scrub();
        assert!(text.bytes().all(|b| b == 0));

        let mut list = vec![5u16, 6, 7];
        list.scrub();
        assert_eq!(list, [0, 0, 0]);

        let mut maybe = Some(9i64);
        maybe.scrub();
        assert_eq!(maybe, Some(0));
    }

    #[quickcheck]
    fn eq(bytes: Vec<u8>) -> TestResult {
        if bytes.is_empty() {
            TestResult::discard()
        } else {
            let b = unsafe { memeq(bytes.as_ptr(), bytes.as_ptr(), bytes.len()) };
            TestResult::from_bool(b)
        }
    }

    #[quickcheck]
    fn ord_eq(bytes: Vec<u8>) -> TestResult {
        if bytes.is_empty() {
            TestResult::discard()
        } else {
            let ord = unsafe { memcmp(bytes.as_ptr(), bytes.as_ptr(), bytes.len()) };
            TestResult::from_bool(ord == Ordering::Equal)
        }
    }

    #[quickcheck]
    fn neq(a: Vec<u8>, b: Vec<u8>) -> TestResult {
        let len = std::cmp::min(a.len(), b.len());

        if a[..len] == b[..len] || len == 0 {
            TestResult::discard()
        } else {
            let b = unsafe { memeq(a.as_ptr(), b.as_ptr(), len) };

            TestResult::from_bool(!b)
        }
    }

    #[quickcheck]
    fn ord(a: Vec<u8>, b: Vec<u8>) -> TestResult {
        let len = std::cmp::min(a.len(), b.len());

        if len == 0 {
            TestResult::discard()
        } else {
            let a = &a[..len];
            let b = &b[..len];
            let ord = unsafe { memcmp(a.as_ptr(), b.as_ptr(), len) };

            TestResult::from_bool(ord == a.cmp(b))
        }
    }
}