1use crate::{ObliviousHashMap, OMAP_FOUND, OMAP_INVALID_KEY, OMAP_NOT_FOUND, OMAP_OVERFLOW, ORAM};
6use aligned_cmov::{subtle::Choice, typenum::U8, A64Bytes, A8Bytes, Aligned, ArrayLength};
7use alloc::{
8 collections::{btree_map::Entry, BTreeMap},
9 vec::Vec,
10};
11use rand_core::{CryptoRng, RngCore};
12
13pub fn exercise_oram<BlockSize, O, R>(mut num_rounds: usize, oram: &mut O, rng: &mut R)
16where
17 BlockSize: ArrayLength<u8>,
18 O: ORAM<BlockSize>,
19 R: RngCore + CryptoRng,
20{
21 let len = oram.len();
22 assert!(len != 0, "len is zero");
23 assert_eq!(len & (len - 1), 0, "len is not a power of two");
24 let mut expected = BTreeMap::<u64, A64Bytes<BlockSize>>::default();
25 let mut probe_positions = Vec::<u64>::new();
26 let mut probe_idx = 0usize;
27
28 while num_rounds > 0 {
29 if probe_idx >= probe_positions.len() {
30 probe_positions.push(rng.next_u64() & (len - 1));
31 probe_idx = 0;
32 }
33 let query = probe_positions[probe_idx];
34
35 query_oram_and_randomize(&mut expected, query, oram, rng);
36
37 probe_idx += 1;
38 num_rounds -= 1;
39 }
40}
41
42pub fn exercise_oram_consecutive<BlockSize, O, R>(mut num_rounds: usize, oram: &mut O, rng: &mut R)
45where
46 BlockSize: ArrayLength<u8>,
47 O: ORAM<BlockSize>,
48 R: RngCore + CryptoRng,
49{
50 let len = oram.len();
51 assert!(len != 0, "len is zero");
52 assert_eq!(len & (len - 1), 0, "len is not a power of two");
53 let mut expected = BTreeMap::<u64, A64Bytes<BlockSize>>::default();
54
55 while num_rounds > 0 {
56 let query = num_rounds as u64 & (len - 1);
57 query_oram_and_randomize(&mut expected, query, oram, rng);
58
59 num_rounds -= 1;
60 }
61}
62
63fn query_oram_and_randomize<BlockSize, O, R>(
64 expected: &mut BTreeMap<
65 u64,
66 Aligned<aligned_cmov::A64, aligned_cmov::GenericArray<u8, BlockSize>>,
67 >,
68 query: u64,
69 oram: &mut O,
70 rng: &mut R,
71) where
72 BlockSize: ArrayLength<u8>,
73 O: ORAM<BlockSize>,
74 R: RngCore + CryptoRng,
75{
76 let expected_ent = expected.entry(query).or_default();
77 oram.access(query, |val| {
78 assert_eq!(val, expected_ent);
79 rng.fill_bytes(val);
80 expected_ent.clone_from_slice(val.as_slice());
81 });
82}
83
84pub fn measure_oram_stash_size_distribution<BlockSize, O, R>(
89 mut num_pre_rounds: usize,
90 mut num_rounds: usize,
91 oram: &mut O,
92 rng: &mut R,
93) -> BTreeMap<usize, usize>
94where
95 BlockSize: ArrayLength<u8>,
96 O: ORAM<BlockSize>,
97 R: RngCore + CryptoRng,
98{
99 let len = oram.len();
100 assert!(len != 0, "len is zero");
101 assert_eq!(len & (len - 1), 0, "len is not a power of two");
102
103 let mut expected = BTreeMap::<u64, A64Bytes<BlockSize>>::default();
104 let mut probe_idx = 0u64;
105 let mut stash_size_by_count = BTreeMap::<usize, usize>::default();
106
107 while num_pre_rounds > 0 {
108 query_oram_and_randomize(&mut expected, probe_idx, oram, rng);
109 probe_idx = (probe_idx + 1) & (len - 1);
110 num_pre_rounds -= 1;
111 }
112
113 while num_rounds > 0 {
114 query_oram_and_randomize(&mut expected, probe_idx, oram, rng);
115 *stash_size_by_count.entry(oram.stash_size()).or_default() += 1;
116 probe_idx = (probe_idx + 1) & (len - 1);
117 num_rounds -= 1;
118 }
119 stash_size_by_count
120}
121
122pub fn exercise_omap<KeySize, ValSize, O, R>(mut num_rounds: usize, omap: &mut O, rng: &mut R)
125where
126 KeySize: ArrayLength<u8>,
127 ValSize: ArrayLength<u8>,
128 O: ObliviousHashMap<KeySize, ValSize>,
129 R: RngCore + CryptoRng,
130{
131 let mut expected = BTreeMap::<A8Bytes<KeySize>, A8Bytes<ValSize>>::default();
132 let mut probe_positions = Vec::<A8Bytes<KeySize>>::new();
133 let mut probe_idx = 0usize;
134
135 while num_rounds > 0 {
136 if probe_idx >= probe_positions.len() {
137 let mut bytes = A8Bytes::<KeySize>::default();
138 rng.fill_bytes(&mut bytes);
139 probe_positions.push(bytes);
140 probe_idx = 0;
141 }
142
143 let query1 = probe_positions[probe_idx].clone();
145 let query2 = {
146 let mut bytes = A8Bytes::<KeySize>::default();
147 rng.fill_bytes(&mut bytes);
148 bytes
149 };
150
151 for query in &[query1, query2] {
152 {
154 let mut output = A8Bytes::<ValSize>::default();
155 let result_code = omap.read(query, &mut output);
156
157 let expected_ent = expected.entry(query.clone());
158 match expected_ent {
159 Entry::Vacant(_) => {
160 assert_eq!(result_code, OMAP_NOT_FOUND);
161 }
162 Entry::Occupied(occ) => {
163 assert_eq!(result_code, OMAP_FOUND);
164 assert_eq!(&output, occ.get());
165 }
166 };
167 }
168
169 let action = rng.next_u32() % 7;
171 match action {
172 0 => {
174 continue;
175 }
176 1 | 2 => {
177 let mut new_val = A8Bytes::<ValSize>::default();
179 rng.fill_bytes(new_val.as_mut_slice());
180 let result_code = omap.vartime_write(query, &new_val, Choice::from(1));
181
182 if expected.contains_key(query) {
183 assert_eq!(result_code, OMAP_FOUND);
184 } else {
185 assert_eq!(result_code, OMAP_NOT_FOUND);
186 }
187
188 expected
189 .entry(query.clone())
190 .or_default()
191 .copy_from_slice(new_val.as_slice());
192 }
193 3 | 4 => {
194 let mut new_val = A8Bytes::<ValSize>::default();
196 rng.fill_bytes(new_val.as_mut_slice());
197 let result_code = omap.vartime_write(query, &new_val, Choice::from(0));
198
199 if expected.contains_key(query) {
200 assert_eq!(result_code, OMAP_FOUND);
201 } else {
202 assert_eq!(result_code, OMAP_NOT_FOUND);
203 }
204
205 expected.entry(query.clone()).or_insert(new_val);
206 }
207 5 => {
208 omap.access(query, |result_code, val| {
210 match expected.entry(query.clone()) {
211 Entry::Vacant(_) => {
212 assert_eq!(result_code, OMAP_NOT_FOUND);
213 }
214 Entry::Occupied(mut occ) => {
215 assert_eq!(result_code, OMAP_FOUND);
216 rng.fill_bytes(val.as_mut_slice());
217 *occ.get_mut() = val.clone();
218 }
219 }
220 })
221 }
222 _ => {
223 let result_code = omap.remove(query);
225
226 if expected.remove(query).is_some() {
227 assert_eq!(result_code, OMAP_FOUND);
228 } else {
229 assert_eq!(result_code, OMAP_NOT_FOUND);
230 }
231 }
232 };
233
234 {
236 let mut output = A8Bytes::<ValSize>::default();
238 let result_code = omap.read(query, &mut output);
239
240 let expected_ent = expected.entry(query.clone());
241 match expected_ent {
242 Entry::Vacant(_) => {
243 assert_eq!(result_code, OMAP_NOT_FOUND);
244 }
245 Entry::Occupied(occ) => {
246 assert_eq!(result_code, OMAP_FOUND);
247 assert_eq!(&output, occ.get(),);
248 }
249 };
250 }
251 }
252
253 probe_idx += 1;
254 num_rounds -= 1;
255 }
256}
257
258pub fn test_omap_overflow<KeySize, ValSize, O>(omap: &mut O) -> u64
262where
263 KeySize: ArrayLength<u8>,
264 ValSize: ArrayLength<u8>,
265 O: ObliviousHashMap<KeySize, ValSize>,
266{
267 let mut idx = 1u64;
269 let mut key = A8Bytes::<KeySize>::default();
270 let mut val = A8Bytes::<ValSize>::default();
271
272 loop {
273 assert_eq!(omap.len(), idx - 1, "unexpected omap.len()");
274 key[0..8].copy_from_slice(&idx.to_le_bytes());
275 val[0..8].copy_from_slice(&idx.to_le_bytes());
276 let result_code = omap.vartime_write(&key, &val, Choice::from(0));
277
278 if result_code == OMAP_FOUND {
279 panic!("unexpectedly found item idx = {}", idx);
280 } else if result_code == OMAP_INVALID_KEY {
281 panic!("unexpectedly recieved OMAP_INVALID_KEY, idx = {}", idx);
282 } else if result_code == OMAP_OVERFLOW {
283 assert_eq!(
285 omap.len(),
286 idx - 1,
287 "omap.len() unexpected value after overflow"
288 );
289 let mut temp = A8Bytes::<ValSize>::default();
290 for idx2 in 1u64..idx {
291 key[0..8].copy_from_slice(&idx2.to_le_bytes());
292 val[0..8].copy_from_slice(&idx2.to_le_bytes());
293 let result_code = omap.read(&key, &mut temp);
294 assert_eq!(
295 result_code, OMAP_FOUND,
296 "Failed to find an item that should be in the map: idx2 = {}",
297 idx2
298 );
299 assert_eq!(
300 temp, val,
301 "Value that was stored in the map was wrong after overflow: idx2 = {}",
302 idx2
303 );
304 }
305 return omap.len();
306 } else if result_code != OMAP_NOT_FOUND {
307 panic!("unexpected result code: {}", result_code);
308 }
309
310 idx += 1;
311 }
312}
313
314pub fn exercise_omap_counter_table<KeySize, O, R>(mut num_rounds: usize, omap: &mut O, rng: &mut R)
317where
318 KeySize: ArrayLength<u8>,
319 O: ObliviousHashMap<KeySize, U8>,
320 R: RngCore + CryptoRng,
321{
322 type ValSize = U8;
323 let zero: A8Bytes<ValSize> = Default::default();
324
325 let mut expected = BTreeMap::<A8Bytes<KeySize>, A8Bytes<ValSize>>::default();
326 let mut probe_positions = Vec::<A8Bytes<KeySize>>::new();
327 let mut probe_idx = 0usize;
328
329 while num_rounds > 0 {
330 if probe_idx >= probe_positions.len() {
331 let mut bytes = A8Bytes::<KeySize>::default();
332 rng.fill_bytes(&mut bytes);
333 probe_positions.push(bytes);
334 probe_idx = 0;
335 }
336
337 let query1 = probe_positions[probe_idx].clone();
339 let query2 = {
340 let mut bytes = A8Bytes::<KeySize>::default();
341 rng.fill_bytes(&mut bytes);
342 bytes
343 };
344
345 for query in &[query1, query2] {
346 {
348 let mut output = A8Bytes::<ValSize>::default();
349 let result_code = omap.read(query, &mut output);
350
351 let expected_ent = expected.entry(query.clone());
352 match expected_ent {
353 Entry::Vacant(_) => {
354 assert!(
356 result_code == OMAP_NOT_FOUND
357 || (result_code == OMAP_FOUND && output == zero),
358 "Expected no value but omap found nonzero value: result_code {}",
359 result_code
360 );
361 }
362 Entry::Occupied(occ) => {
363 assert!(
364 result_code == OMAP_FOUND
365 || (result_code == OMAP_NOT_FOUND && occ.get() == &zero),
366 "Expected a value but OMAP found none: result_code: {}",
367 result_code
368 );
369 assert_eq!(&output, occ.get());
370 }
371 };
372 }
373
374 let result_code = omap.access_and_insert(query, &zero, rng, |_status_code, buffer| {
376 let num = u64::from_ne_bytes(*buffer.as_ref()) + 1;
377 *buffer = Aligned(num.to_ne_bytes().into());
378
379 expected
380 .entry(query.clone())
381 .or_default()
382 .copy_from_slice(buffer);
383 });
384 assert!(result_code != OMAP_INVALID_KEY, "Invalid key");
385 if result_code == OMAP_OVERFLOW {
386 let mut buffer = A8Bytes::<ValSize>::default();
389 let _result_code = omap.read(query, &mut buffer);
390
391 let map_num = u64::from_ne_bytes(*buffer.as_ref());
392 let expected_buf = expected.get(query).unwrap().clone();
393 let expected_num = u64::from_ne_bytes(*expected_buf.as_ref());
394 assert!(
395 map_num == expected_num || map_num + 1 == expected_num,
396 "Unexpected value in omap: map_num {}, expected_num = {}",
397 map_num,
398 expected_num
399 );
400 expected
401 .entry(query.clone())
402 .or_default()
403 .copy_from_slice(&buffer);
404 }
405
406 {
408 let mut output = A8Bytes::<ValSize>::default();
410 let result_code = omap.read(query, &mut output);
411
412 let expected_ent = expected.entry(query.clone());
413 match expected_ent {
414 Entry::Vacant(_) => {
415 assert!(
417 result_code == OMAP_NOT_FOUND
418 || (result_code == OMAP_FOUND && output == zero),
419 "Expected no value but omap found nonzero value: result_code {}",
420 result_code
421 );
422 }
423 Entry::Occupied(occ) => {
424 assert!(
425 result_code == OMAP_FOUND
426 || (result_code == OMAP_NOT_FOUND && occ.get() == &zero),
427 "Expected a value but OMAP found none: result_code: {}",
428 result_code
429 );
430 assert_eq!(&output, occ.get());
431 }
432 };
433 }
434 }
435
436 probe_idx += 1;
437 num_rounds -= 1;
438 }
439}