1use std::ptr;
8
/// Types whose contents can be overwritten in place ("scrubbed").
///
/// The implementations in this module overwrite integer primitives, byte
/// slices, string slices and byte arrays with zeros, and forward to the
/// contained value(s) for `Option`, `Vec`, `Box`, `Cell` and `RefCell`.
pub trait Scrubbed {
    /// Overwrites the contents of `self` in place.
    fn scrub(&mut self);
}
14
/// Sets `count` bytes starting at `dst` to `val`.
///
/// Each byte is stored individually with `ptr::write_volatile`. The compiler
/// may not elide volatile writes, so the stores happen even when the buffer is
/// never read again (for example right before it is freed). The
/// `#[inline(never)]` attribute is presumably an extra guard against the
/// optimizer reasoning about the writes at the call site.
///
/// # Safety
///
/// `dst` must be valid for writes of `count` bytes. When `count` is 0 no
/// memory is touched.
#[inline(never)]
pub unsafe fn memset(dst: *mut u8, val: u8, count: usize) {
    let mut cursor = dst;
    let mut remaining = count;

    while remaining != 0 {
        ptr::write_volatile(cursor, val);
        // After the final write this points one past the region, which is
        // still a valid offset for `add`.
        cursor = cursor.add(1);
        remaining -= 1;
    }
}
32
/// Returns `true` when the `len` bytes at `v1` and `v2` are all equal.
///
/// Every byte pair is read with `ptr::read_volatile`. The XOR of each pair is
/// OR-ed into an accumulator, and the source has no early exit when a
/// difference is found. NOTE(review): the language does not guarantee that the
/// generated machine code runs in constant time.
///
/// # Panics
///
/// Panics if `len` is 0.
///
/// # Safety
///
/// `v1` and `v2` must each be valid for reads of `len` bytes.
#[inline(never)]
pub unsafe fn memeq(v1: *const u8, v2: *const u8, len: usize) -> bool {
    assert!(
        len != 0,
        "Cannot perform equality comparison if the length is 0"
    );

    // Any differing bit anywhere leaves a non-zero bit in the accumulator.
    let difference_bits = (0..len).fold(0u8, |acc, offset| {
        let lhs = ptr::read_volatile(v1.add(offset));
        let rhs = ptr::read_volatile(v2.add(offset));
        acc | (lhs ^ rhs)
    });

    difference_bits == 0
}
63
/// Compares the `len` bytes at `v1` and `v2` lexicographically.
///
/// The first (lowest-index) differing byte decides the result, and bytes are
/// compared as unsigned values. This is the same ordering as comparing the
/// two byte slices with `Ord::cmp`.
///
/// The loop reads every byte with `ptr::read_volatile` and derives the result
/// with arithmetic masks. The source has no data-dependent branch or early
/// exit. NOTE(review): the language does not guarantee that the generated
/// machine code runs in constant time.
///
/// # Panics
///
/// Panics if `len` is 0.
///
/// # Safety
///
/// `v1` and `v2` must each be valid for reads of `len` bytes.
#[inline(never)]
pub unsafe fn memcmp(v1: *const u8, v2: *const u8, len: usize) -> std::cmp::Ordering {
    // Holds the byte difference of the most recently processed differing
    // position, or 0 while no difference has been seen.
    let mut res = 0;

    assert!(
        len != 0,
        "Cannot perform ordering comparison if the length is 0"
    );

    // Walk from the last byte to the first. When the loop ends, `res` holds
    // the difference at the *lowest* index where the inputs differ.
    for i in (0..len).rev() {
        let val1 = ptr::read_volatile(v1.add(i)) as i32;
        let val2 = ptr::read_volatile(v2.add(i)) as i32;
        // Both bytes are in 0..=255, so `diff` is in -255..=255.
        let diff = val1 - val2;
        // `((diff - 1) & !diff) >> 8` (arithmetic shift on i32) is -1 (all
        // ones) when `diff == 0` and 0 otherwise. A zero `diff` therefore keeps
        // the previous `res`, and a non-zero `diff` replaces it.
        res = (res & (((diff - 1) & !diff) >> 8)) | diff;
    }
    // Branch-free sign of `res` (in -255..=255): negative -> -1, zero -> 0,
    // positive -> 1.
    let res = ((res - 1) >> 8) + (res >> 8) + 1;

    res.cmp(&0)
}
93
/// Implements `Scrubbed` for an integer primitive type by overwriting the
/// value with zero.
macro_rules! impl_scrubbed_primitive {
    ($t:ty) => {
        impl Scrubbed for $t {
            /// Overwrites the value with zero.
            ///
            /// The store is volatile, like the writes in `memset`. A plain
            /// assignment can legally be removed as a dead store when the value
            /// is dropped or goes out of scope right after being scrubbed,
            /// which would silently skip the scrub.
            #[inline(never)]
            fn scrub(&mut self) {
                // SAFETY: `self` is a `&mut Self`, so the derived raw pointer is
                // non-null, properly aligned and valid for writing one `Self`.
                unsafe { std::ptr::write_volatile(self as *mut Self, 0) }
            }
        }
    };
}
104
105impl_scrubbed_primitive!(u8);
106impl_scrubbed_primitive!(u16);
107impl_scrubbed_primitive!(u32);
108impl_scrubbed_primitive!(u64);
109impl_scrubbed_primitive!(u128);
110impl_scrubbed_primitive!(usize);
111impl_scrubbed_primitive!(i8);
112impl_scrubbed_primitive!(i16);
113impl_scrubbed_primitive!(i32);
114impl_scrubbed_primitive!(i64);
115impl_scrubbed_primitive!(i128);
116impl_scrubbed_primitive!(isize);
117
/// Implements `Scrubbed` for a contiguous byte type (one that exposes
/// `as_mut_ptr() -> *mut u8` and a byte `len()`) by zeroing all of its bytes
/// with `memset`.
macro_rules! impl_scrubbed_array {
    ($t:ty) => {
        impl Scrubbed for $t {
            /// Zeroes every byte of the value in place.
            fn scrub(&mut self) {
                let byte_len = self.len();
                let first_byte = self.as_mut_ptr();
                // SAFETY: `first_byte` comes from `&mut self` and the value spans
                // exactly `byte_len` bytes, all of which are writable.
                unsafe { memset(first_byte, 0, byte_len) }
            }
        }
    };
}
127
128impl_scrubbed_array!([u8]);
129impl_scrubbed_array!(str);
130
131impl<const N: usize> Scrubbed for [u8; N] {
132 fn scrub(&mut self) {
133 unsafe { memset(self.as_mut_ptr(), 0, self.len()) }
134 }
135}
136
137impl<T: Scrubbed> Scrubbed for Option<T> {
138 fn scrub(&mut self) {
139 self.as_mut().map(Scrubbed::scrub);
140 }
141}
142
143impl<T: Scrubbed> Scrubbed for Vec<T> {
144 fn scrub(&mut self) {
145 self.iter_mut().for_each(Scrubbed::scrub)
146 }
147}
148
149impl<T: Scrubbed> Scrubbed for Box<T> {
150 fn scrub(&mut self) {
151 self.as_mut().scrub()
152 }
153}
154
155impl<T: Scrubbed> Scrubbed for std::cell::Cell<T> {
156 fn scrub(&mut self) {
157 self.get_mut().scrub()
158 }
159}
160
161impl<T: Scrubbed> Scrubbed for std::cell::RefCell<T> {
162 fn scrub(&mut self) {
163 self.get_mut().scrub()
164 }
165}
166
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;

    use super::*;
    use quickcheck::TestResult;

    // The `expected` strings tie these tests to the length assertions in
    // `memeq` / `memcmp`, so an unrelated panic cannot make them pass.
    #[test]
    #[should_panic(expected = "Cannot perform equality comparison if the length is 0")]
    fn eq_empty() {
        let bytes: Vec<u8> = Vec::new();
        unsafe { memeq(bytes.as_ptr(), bytes.as_ptr(), bytes.len()) };
    }

    #[test]
    #[should_panic(expected = "Cannot perform ordering comparison if the length is 0")]
    fn ord_empty() {
        let bytes: Vec<u8> = Vec::new();
        unsafe { memcmp(bytes.as_ptr(), bytes.as_ptr(), bytes.len()) };
    }

    // Deterministic check that the lowest-index differing byte decides the
    // ordering, even when a later byte differs in the opposite direction.
    #[test]
    fn ord_first_difference_wins() {
        let low = [1u8, 0xff];
        let high = [2u8, 0x00];

        assert_eq!(unsafe { memcmp(low.as_ptr(), high.as_ptr(), 2) }, Ordering::Less);
        assert_eq!(unsafe { memcmp(high.as_ptr(), low.as_ptr(), 2) }, Ordering::Greater);
        assert_eq!(unsafe { memcmp(low.as_ptr(), low.as_ptr(), 2) }, Ordering::Equal);
    }

    // Smoke test for several `Scrubbed` impls, including nested containers.
    #[test]
    fn scrub_zeroes_values() {
        let mut n: u64 = u64::MAX;
        n.scrub();
        assert_eq!(n, 0);

        let mut key = [0xa5u8; 16];
        key.scrub();
        assert_eq!(key, [0u8; 16]);

        let mut text = String::from("hunter2");
        text.as_mut_str().scrub();
        assert_eq!(text, "\0".repeat(7));

        let mut nested: Option<Box<u32>> = Some(Box::new(7));
        nested.scrub();
        assert_eq!(nested, Some(Box::new(0)));
    }

    #[quickcheck]
    fn eq(bytes: Vec<u8>) -> TestResult {
        if bytes.is_empty() {
            TestResult::discard()
        } else {
            let b = unsafe { memeq(bytes.as_ptr(), bytes.as_ptr(), bytes.len()) };
            TestResult::from_bool(b)
        }
    }

    #[quickcheck]
    fn ord_eq(bytes: Vec<u8>) -> TestResult {
        if bytes.is_empty() {
            TestResult::discard()
        } else {
            let ord = unsafe { memcmp(bytes.as_ptr(), bytes.as_ptr(), bytes.len()) };
            TestResult::from_bool(ord == Ordering::Equal)
        }
    }

    #[quickcheck]
    fn neq(a: Vec<u8>, b: Vec<u8>) -> TestResult {
        let len = std::cmp::min(a.len(), b.len());

        // Only inputs whose common prefix actually differs are meaningful.
        if len == 0 || a[..len] == b[..len] {
            TestResult::discard()
        } else {
            let b = unsafe { memeq(a.as_ptr(), b.as_ptr(), len) };

            TestResult::from_bool(!b)
        }
    }

    #[quickcheck]
    fn ord(a: Vec<u8>, b: Vec<u8>) -> TestResult {
        let len = std::cmp::min(a.len(), b.len());

        if len == 0 {
            TestResult::discard()
        } else {
            let a = &a[..len];
            let b = &b[..len];
            let ord = unsafe { memcmp(a.as_ptr(), b.as_ptr(), len) };

            TestResult::from_bool(ord == a.cmp(b))
        }
    }
}