mc_oblivious_traits/
testing.rs

1// Copyright (c) 2018-2023 The MobileCoin Foundation
2
3//! Some generic tests that exercise objects implementing these traits
4
5use crate::{ObliviousHashMap, OMAP_FOUND, OMAP_INVALID_KEY, OMAP_NOT_FOUND, OMAP_OVERFLOW, ORAM};
6use aligned_cmov::{subtle::Choice, typenum::U8, A64Bytes, A8Bytes, Aligned, ArrayLength};
7use alloc::{
8    collections::{btree_map::Entry, BTreeMap},
9    vec::Vec,
10};
11use rand_core::{CryptoRng, RngCore};
12
13/// Exercise an ORAM by writing, reading, and rewriting, a progressively larger
14/// set of random locations
15pub fn exercise_oram<BlockSize, O, R>(mut num_rounds: usize, oram: &mut O, rng: &mut R)
16where
17    BlockSize: ArrayLength<u8>,
18    O: ORAM<BlockSize>,
19    R: RngCore + CryptoRng,
20{
21    let len = oram.len();
22    assert!(len != 0, "len is zero");
23    assert_eq!(len & (len - 1), 0, "len is not a power of two");
24    let mut expected = BTreeMap::<u64, A64Bytes<BlockSize>>::default();
25    let mut probe_positions = Vec::<u64>::new();
26    let mut probe_idx = 0usize;
27
28    while num_rounds > 0 {
29        if probe_idx >= probe_positions.len() {
30            probe_positions.push(rng.next_u64() & (len - 1));
31            probe_idx = 0;
32        }
33        let query = probe_positions[probe_idx];
34
35        query_oram_and_randomize(&mut expected, query, oram, rng);
36
37        probe_idx += 1;
38        num_rounds -= 1;
39    }
40}
41
42/// Exercise an ORAM by writing, reading, and rewriting, all locations
43/// consecutively
44pub fn exercise_oram_consecutive<BlockSize, O, R>(mut num_rounds: usize, oram: &mut O, rng: &mut R)
45where
46    BlockSize: ArrayLength<u8>,
47    O: ORAM<BlockSize>,
48    R: RngCore + CryptoRng,
49{
50    let len = oram.len();
51    assert!(len != 0, "len is zero");
52    assert_eq!(len & (len - 1), 0, "len is not a power of two");
53    let mut expected = BTreeMap::<u64, A64Bytes<BlockSize>>::default();
54
55    while num_rounds > 0 {
56        let query = num_rounds as u64 & (len - 1);
57        query_oram_and_randomize(&mut expected, query, oram, rng);
58
59        num_rounds -= 1;
60    }
61}
62
63fn query_oram_and_randomize<BlockSize, O, R>(
64    expected: &mut BTreeMap<
65        u64,
66        Aligned<aligned_cmov::A64, aligned_cmov::GenericArray<u8, BlockSize>>,
67    >,
68    query: u64,
69    oram: &mut O,
70    rng: &mut R,
71) where
72    BlockSize: ArrayLength<u8>,
73    O: ORAM<BlockSize>,
74    R: RngCore + CryptoRng,
75{
76    let expected_ent = expected.entry(query).or_default();
77    oram.access(query, |val| {
78        assert_eq!(val, expected_ent);
79        rng.fill_bytes(val);
80        expected_ent.clone_from_slice(val.as_slice());
81    });
82}
83
84/// Exercise an ORAM by writing, reading, and rewriting, first cycling through
85/// all N locations num_pre_rounds times to warm up the oram, then repeatedly
86/// cycling through all N locations a total of num_rounds times as a worst case
87/// access sequence and measuring the stash size.
88pub fn measure_oram_stash_size_distribution<BlockSize, O, R>(
89    mut num_pre_rounds: usize,
90    mut num_rounds: usize,
91    oram: &mut O,
92    rng: &mut R,
93) -> BTreeMap<usize, usize>
94where
95    BlockSize: ArrayLength<u8>,
96    O: ORAM<BlockSize>,
97    R: RngCore + CryptoRng,
98{
99    let len = oram.len();
100    assert!(len != 0, "len is zero");
101    assert_eq!(len & (len - 1), 0, "len is not a power of two");
102
103    let mut expected = BTreeMap::<u64, A64Bytes<BlockSize>>::default();
104    let mut probe_idx = 0u64;
105    let mut stash_size_by_count = BTreeMap::<usize, usize>::default();
106
107    while num_pre_rounds > 0 {
108        query_oram_and_randomize(&mut expected, probe_idx, oram, rng);
109        probe_idx = (probe_idx + 1) & (len - 1);
110        num_pre_rounds -= 1;
111    }
112
113    while num_rounds > 0 {
114        query_oram_and_randomize(&mut expected, probe_idx, oram, rng);
115        *stash_size_by_count.entry(oram.stash_size()).or_default() += 1;
116        probe_idx = (probe_idx + 1) & (len - 1);
117        num_rounds -= 1;
118    }
119    stash_size_by_count
120}
121
122/// Exercise an OMAP by writing, reading, accessing, and removing a
123/// progressively larger set of random locations
124pub fn exercise_omap<KeySize, ValSize, O, R>(mut num_rounds: usize, omap: &mut O, rng: &mut R)
125where
126    KeySize: ArrayLength<u8>,
127    ValSize: ArrayLength<u8>,
128    O: ObliviousHashMap<KeySize, ValSize>,
129    R: RngCore + CryptoRng,
130{
131    let mut expected = BTreeMap::<A8Bytes<KeySize>, A8Bytes<ValSize>>::default();
132    let mut probe_positions = Vec::<A8Bytes<KeySize>>::new();
133    let mut probe_idx = 0usize;
134
135    while num_rounds > 0 {
136        if probe_idx >= probe_positions.len() {
137            let mut bytes = A8Bytes::<KeySize>::default();
138            rng.fill_bytes(&mut bytes);
139            probe_positions.push(bytes);
140            probe_idx = 0;
141        }
142
143        // In one round, do a query from the sequence and a random query
144        let query1 = probe_positions[probe_idx].clone();
145        let query2 = {
146            let mut bytes = A8Bytes::<KeySize>::default();
147            rng.fill_bytes(&mut bytes);
148            bytes
149        };
150
151        for query in &[query1, query2] {
152            // First, read at query and sanity check it
153            {
154                let mut output = A8Bytes::<ValSize>::default();
155                let result_code = omap.read(query, &mut output);
156
157                let expected_ent = expected.entry(query.clone());
158                match expected_ent {
159                    Entry::Vacant(_) => {
160                        assert_eq!(result_code, OMAP_NOT_FOUND);
161                    }
162                    Entry::Occupied(occ) => {
163                        assert_eq!(result_code, OMAP_FOUND);
164                        assert_eq!(&output, occ.get());
165                    }
166                };
167            }
168
169            // decide what random action to take that modifies the map
170            let action = rng.next_u32() % 7;
171            match action {
172                // In this case we only READ and continue through the loop
173                0 => {
174                    continue;
175                }
176                1 | 2 => {
177                    // In this case we WRITE to the omap, allowing overwrite
178                    let mut new_val = A8Bytes::<ValSize>::default();
179                    rng.fill_bytes(new_val.as_mut_slice());
180                    let result_code = omap.vartime_write(query, &new_val, Choice::from(1));
181
182                    if expected.contains_key(query) {
183                        assert_eq!(result_code, OMAP_FOUND);
184                    } else {
185                        assert_eq!(result_code, OMAP_NOT_FOUND);
186                    }
187
188                    expected
189                        .entry(query.clone())
190                        .or_default()
191                        .copy_from_slice(new_val.as_slice());
192                }
193                3 | 4 => {
194                    // In this case we WRITE to the omap, not allowing overwrite
195                    let mut new_val = A8Bytes::<ValSize>::default();
196                    rng.fill_bytes(new_val.as_mut_slice());
197                    let result_code = omap.vartime_write(query, &new_val, Choice::from(0));
198
199                    if expected.contains_key(query) {
200                        assert_eq!(result_code, OMAP_FOUND);
201                    } else {
202                        assert_eq!(result_code, OMAP_NOT_FOUND);
203                    }
204
205                    expected.entry(query.clone()).or_insert(new_val);
206                }
207                5 => {
208                    // In this case we ACCESS the omap
209                    omap.access(query, |result_code, val| {
210                        match expected.entry(query.clone()) {
211                            Entry::Vacant(_) => {
212                                assert_eq!(result_code, OMAP_NOT_FOUND);
213                            }
214                            Entry::Occupied(mut occ) => {
215                                assert_eq!(result_code, OMAP_FOUND);
216                                rng.fill_bytes(val.as_mut_slice());
217                                *occ.get_mut() = val.clone();
218                            }
219                        }
220                    })
221                }
222                _ => {
223                    // In this case we REMOVE from the omap
224                    let result_code = omap.remove(query);
225
226                    if expected.remove(query).is_some() {
227                        assert_eq!(result_code, OMAP_FOUND);
228                    } else {
229                        assert_eq!(result_code, OMAP_NOT_FOUND);
230                    }
231                }
232            };
233
234            // Finally read from the position again as an extra check
235            {
236                // In this case we READ from omap
237                let mut output = A8Bytes::<ValSize>::default();
238                let result_code = omap.read(query, &mut output);
239
240                let expected_ent = expected.entry(query.clone());
241                match expected_ent {
242                    Entry::Vacant(_) => {
243                        assert_eq!(result_code, OMAP_NOT_FOUND);
244                    }
245                    Entry::Occupied(occ) => {
246                        assert_eq!(result_code, OMAP_FOUND);
247                        assert_eq!(&output, occ.get(),);
248                    }
249                };
250            }
251        }
252
253        probe_idx += 1;
254        num_rounds -= 1;
255    }
256}
257
258/// Take an empty omap and add items consecutively to it until it overflows.
259/// Then test that on overflow we have rollback semantics, and we can still find
260/// all of the items that we added.
261pub fn test_omap_overflow<KeySize, ValSize, O>(omap: &mut O) -> u64
262where
263    KeySize: ArrayLength<u8>,
264    ValSize: ArrayLength<u8>,
265    O: ObliviousHashMap<KeySize, ValSize>,
266{
267    // count from 1 because 0 is an invalid key
268    let mut idx = 1u64;
269    let mut key = A8Bytes::<KeySize>::default();
270    let mut val = A8Bytes::<ValSize>::default();
271
272    loop {
273        assert_eq!(omap.len(), idx - 1, "unexpected omap.len()");
274        key[0..8].copy_from_slice(&idx.to_le_bytes());
275        val[0..8].copy_from_slice(&idx.to_le_bytes());
276        let result_code = omap.vartime_write(&key, &val, Choice::from(0));
277
278        if result_code == OMAP_FOUND {
279            panic!("unexpectedly found item idx = {}", idx);
280        } else if result_code == OMAP_INVALID_KEY {
281            panic!("unexpectedly recieved OMAP_INVALID_KEY, idx = {}", idx);
282        } else if result_code == OMAP_OVERFLOW {
283            // Now that we got an overflow, lets test if rollback semantics worked.
284            assert_eq!(
285                omap.len(),
286                idx - 1,
287                "omap.len() unexpected value after overflow"
288            );
289            let mut temp = A8Bytes::<ValSize>::default();
290            for idx2 in 1u64..idx {
291                key[0..8].copy_from_slice(&idx2.to_le_bytes());
292                val[0..8].copy_from_slice(&idx2.to_le_bytes());
293                let result_code = omap.read(&key, &mut temp);
294                assert_eq!(
295                    result_code, OMAP_FOUND,
296                    "Failed to find an item that should be in the map: idx2 = {}",
297                    idx2
298                );
299                assert_eq!(
300                    temp, val,
301                    "Value that was stored in the map was wrong after overflow: idx2 = {}",
302                    idx2
303                );
304            }
305            return omap.len();
306        } else if result_code != OMAP_NOT_FOUND {
307            panic!("unexpected result code: {}", result_code);
308        }
309
310        idx += 1;
311    }
312}
313
314/// Exercise an OMAP used as an oblivious counter table via the
315/// access_and_insert operation
316pub fn exercise_omap_counter_table<KeySize, O, R>(mut num_rounds: usize, omap: &mut O, rng: &mut R)
317where
318    KeySize: ArrayLength<u8>,
319    O: ObliviousHashMap<KeySize, U8>,
320    R: RngCore + CryptoRng,
321{
322    type ValSize = U8;
323    let zero: A8Bytes<ValSize> = Default::default();
324
325    let mut expected = BTreeMap::<A8Bytes<KeySize>, A8Bytes<ValSize>>::default();
326    let mut probe_positions = Vec::<A8Bytes<KeySize>>::new();
327    let mut probe_idx = 0usize;
328
329    while num_rounds > 0 {
330        if probe_idx >= probe_positions.len() {
331            let mut bytes = A8Bytes::<KeySize>::default();
332            rng.fill_bytes(&mut bytes);
333            probe_positions.push(bytes);
334            probe_idx = 0;
335        }
336
337        // In one round, do a query from the sequence and a random query
338        let query1 = probe_positions[probe_idx].clone();
339        let query2 = {
340            let mut bytes = A8Bytes::<KeySize>::default();
341            rng.fill_bytes(&mut bytes);
342            bytes
343        };
344
345        for query in &[query1, query2] {
346            // First, read at query and sanity check it
347            {
348                let mut output = A8Bytes::<ValSize>::default();
349                let result_code = omap.read(query, &mut output);
350
351                let expected_ent = expected.entry(query.clone());
352                match expected_ent {
353                    Entry::Vacant(_) => {
354                        // Value should be absent or 0 (0's are created by the random inserts)
355                        assert!(
356                            result_code == OMAP_NOT_FOUND
357                                || (result_code == OMAP_FOUND && output == zero),
358                            "Expected no value but omap found nonzero value: result_code {}",
359                            result_code
360                        );
361                    }
362                    Entry::Occupied(occ) => {
363                        assert!(
364                            result_code == OMAP_FOUND
365                                || (result_code == OMAP_NOT_FOUND && occ.get() == &zero),
366                            "Expected a value but OMAP found none: result_code: {}",
367                            result_code
368                        );
369                        assert_eq!(&output, occ.get());
370                    }
371                };
372            }
373
374            // Next, use access_and_insert to increment it
375            let result_code = omap.access_and_insert(query, &zero, rng, |_status_code, buffer| {
376                let num = u64::from_ne_bytes(*buffer.as_ref()) + 1;
377                *buffer = Aligned(num.to_ne_bytes().into());
378
379                expected
380                    .entry(query.clone())
381                    .or_default()
382                    .copy_from_slice(buffer);
383            });
384            assert!(result_code != OMAP_INVALID_KEY, "Invalid key");
385            if result_code == OMAP_OVERFLOW {
386                // When overflow occurs, we don't know if the change was rolled back or not,
387                // so we have to read the map to figure it out, if we want to continue the test.
388                let mut buffer = A8Bytes::<ValSize>::default();
389                let _result_code = omap.read(query, &mut buffer);
390
391                let map_num = u64::from_ne_bytes(*buffer.as_ref());
392                let expected_buf = expected.get(query).unwrap().clone();
393                let expected_num = u64::from_ne_bytes(*expected_buf.as_ref());
394                assert!(
395                    map_num == expected_num || map_num + 1 == expected_num,
396                    "Unexpected value in omap: map_num {}, expected_num = {}",
397                    map_num,
398                    expected_num
399                );
400                expected
401                    .entry(query.clone())
402                    .or_default()
403                    .copy_from_slice(&buffer);
404            }
405
406            // Finally read from the position again as an extra check
407            {
408                // In this case we READ from omap
409                let mut output = A8Bytes::<ValSize>::default();
410                let result_code = omap.read(query, &mut output);
411
412                let expected_ent = expected.entry(query.clone());
413                match expected_ent {
414                    Entry::Vacant(_) => {
415                        // Value should be absent or 0 (0's are created by the random inserts)
416                        assert!(
417                            result_code == OMAP_NOT_FOUND
418                                || (result_code == OMAP_FOUND && output == zero),
419                            "Expected no value but omap found nonzero value: result_code {}",
420                            result_code
421                        );
422                    }
423                    Entry::Occupied(occ) => {
424                        assert!(
425                            result_code == OMAP_FOUND
426                                || (result_code == OMAP_NOT_FOUND && occ.get() == &zero),
427                            "Expected a value but OMAP found none: result_code: {}",
428                            result_code
429                        );
430                        assert_eq!(&output, occ.get());
431                    }
432                };
433            }
434        }
435
436        probe_idx += 1;
437        num_rounds -= 1;
438    }
439}