rostl_datastructures/
array.rs

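//! Implements oblivious arrays whose elements have a fixed-size (`Pod`) type.
//! `ShortArray`, `LongArray` and `FixedArray` have a compile-time size, `DynamicArray` has a
//! runtime size, and `MultiWayArray` groups several subarrays over one shared ORAM.
//! Array sizes are public and accesses are oblivious to the accessed index, except that
//! `MultiWayArray` leaks which subarray is accessed.
//!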
4
5use std::{array::from_fn, mem::ManuallyDrop};
6
7use bytemuck::Pod;
8use rand::{rng, Rng};
9use rostl_oram::{
10  circuit_oram::CircuitORAM,
11  linear_oram::{oblivious_read_index, oblivious_write_index},
12  prelude::PositionType,
13  recursive_oram::RecursivePositionMap,
14};
15use rostl_primitives::{indexable::Length, traits::Cmov};
16
17/// A fixed sized array defined at compile time.
18/// The size of the array is public.
19pub type Array<T, const N: usize> = FixedArray<T, N>;
20/// A fixed sized array defined at runtime.
21/// The size of the array is public.
22pub type DArray<T> = DynamicArray<T>;
23
/// A fixed-size oblivious array, optimal for small sizes.
///
/// Reads and writes go through the linear-scan helpers of `rostl_oram::linear_oram`,
/// so no position map or ORAM tree is needed.
/// The size of the array is public.
#[repr(C)]
#[derive(Debug)]
pub struct ShortArray<T, const N: usize> {
  /// The underlying element storage (visible inside this crate only).
  pub(crate) data: [T; N],
}
34
35impl<T, const N: usize> ShortArray<T, N>
36where
37  T: Cmov + Pod + Default,
38{
39  /// Creates a new `ShortArray` with the given size `n`.
40  pub fn new() -> Self {
41    Self { data: [T::default(); N] }
42  }
43
44  /// Reads from the index
45  pub fn read(&self, index: usize, out: &mut T) {
46    oblivious_read_index(&self.data, index, out);
47  }
48
49  /// Writes to the index
50  pub fn write(&mut self, index: usize, value: T) {
51    oblivious_write_index(&mut self.data, index, value);
52  }
53}
54
55impl<T, const N: usize> Length for ShortArray<T, N> {
56  fn len(&self) -> usize {
57    N
58  }
59}
60
61impl<T, const N: usize> Default for ShortArray<T, N>
62where
63  T: Cmov + Pod + Default,
64{
65  fn default() -> Self {
66    Self::new()
67  }
68}
69
/// A fixed-size oblivious array, optimal for large sizes.
/// The size of the array is public.
///
/// Elements are stored in a `CircuitORAM`; a `RecursivePositionMap` records, for every
/// index, the position it is currently mapped to in that ORAM.
#[repr(C)]
#[derive(Debug)]
pub struct LongArray<T, const N: usize>
where
  T: Cmov + Pod,
{
  /// The actual data storage oram (built with `N` as its size)
  data: CircuitORAM<T>,
  /// The position map for the oram (one entry per index in `0..N`)
  pos_map: RecursivePositionMap,
}
83impl<T, const N: usize> LongArray<T, N>
84where
85  T: Cmov + Pod + Default + std::fmt::Debug,
86{
87  /// Creates a new `LongArray` with the given size `n`.
88  pub fn new() -> Self {
89    Self { data: CircuitORAM::new(N), pos_map: RecursivePositionMap::new(N) }
90  }
91
92  /// Reads from the index
93  pub fn read(&mut self, index: usize, out: &mut T) {
94    let new_pos = rng().random_range(0..N as PositionType);
95    let old_pos = self.pos_map.access_position(index, new_pos);
96    self.data.read(old_pos, new_pos, index, out);
97  }
98
99  /// Writes to the index
100  pub fn write(&mut self, index: usize, value: T) {
101    let new_pos = rng().random_range(0..N as PositionType);
102    let old_pos = self.pos_map.access_position(index, new_pos);
103    self.data.write_or_insert(old_pos, new_pos, index, value);
104  }
105}
106
107impl<T: Cmov + Pod, const N: usize> Length for LongArray<T, N> {
108  fn len(&self) -> usize {
109    N
110  }
111}
112
113impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for LongArray<T, N> {
114  fn default() -> Self {
115    Self::new()
116  }
117}
118
// UNDONE(git-52): Optimize SHORT_ARRAY_THRESHOLD
/// Largest `N` for which `FixedArray` uses the linear-scan `ShortArray`;
/// any larger `N` selects the ORAM-backed `LongArray`.
const SHORT_ARRAY_THRESHOLD: usize = 128;
121
/// A fixed-size array that switches between `ShortArray` and `LongArray` based on the size.
/// The size of the array is public.
///
/// # Invariants
/// if `N <= SHORT_ARRAY_THRESHOLD`, then `ShortArray` is used, otherwise `LongArray` is used.
///
/// No runtime tag is stored: because `N` is a compile-time constant, the active field is
/// implied by the type, and every method re-derives it by comparing `N` with
/// `SHORT_ARRAY_THRESHOLD` before touching the union.
#[repr(C)]
pub union FixedArray<T, const N: usize>
where
  T: Cmov + Pod,
{
  /// Short variant, linear scan (active when `N <= SHORT_ARRAY_THRESHOLD`).
  /// Wrapped in `ManuallyDrop` because union fields are never dropped automatically.
  short: ManuallyDrop<ShortArray<T, N>>,
  /// Long variant, oram (active when `N > SHORT_ARRAY_THRESHOLD`).
  /// Wrapped in `ManuallyDrop` because union fields are never dropped automatically.
  long: ManuallyDrop<LongArray<T, N>>,
}
138
139impl<T, const N: usize> Drop for FixedArray<T, N>
140where
141  T: Cmov + Pod,
142{
143  fn drop(&mut self) {
144    if N <= SHORT_ARRAY_THRESHOLD {
145      unsafe {
146        ManuallyDrop::drop(&mut self.short);
147      }
148    } else {
149      unsafe {
150        ManuallyDrop::drop(&mut self.long);
151      }
152    }
153  }
154}
155
156impl<T, const N: usize> std::fmt::Debug for FixedArray<T, N>
157where
158  T: Cmov + Pod + std::fmt::Debug,
159{
160  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161    if N <= SHORT_ARRAY_THRESHOLD {
162      let short_array: &ManuallyDrop<ShortArray<T, N>>;
163      unsafe {
164        short_array = &self.short;
165      }
166      short_array.fmt(f)
167    } else {
168      let long_array: &ManuallyDrop<LongArray<T, N>>;
169      unsafe {
170        long_array = &self.long;
171      }
172      long_array.fmt(f)
173    }
174  }
175}
176
177impl<T, const N: usize> FixedArray<T, N>
178where
179  T: Cmov + Pod + Default + std::fmt::Debug,
180{
181  /// Creates a new `LongArray` with the given size `n`.
182  pub fn new() -> Self {
183    if N <= SHORT_ARRAY_THRESHOLD {
184      FixedArray { short: ManuallyDrop::new(ShortArray::new()) }
185    } else {
186      FixedArray { long: ManuallyDrop::new(LongArray::new()) }
187    }
188  }
189
190  /// Reads from the index
191  pub fn read(&mut self, index: usize, out: &mut T) {
192    if N <= SHORT_ARRAY_THRESHOLD {
193      // Do an unsafe cast to avoid borrowing issues
194      let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
195      unsafe {
196        short_array = &mut self.short;
197      }
198      short_array.read(index, out);
199    } else {
200      let long_array: &mut ManuallyDrop<LongArray<T, N>>;
201      unsafe {
202        long_array = &mut self.long;
203      }
204      long_array.read(index, out);
205    }
206  }
207
208  /// Writes to the index
209  pub fn write(&mut self, index: usize, value: T) {
210    if N <= SHORT_ARRAY_THRESHOLD {
211      // Do an unsafe cast to avoid borrowing issues
212      let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
213      unsafe {
214        short_array = &mut self.short;
215      }
216      short_array.write(index, value);
217    } else {
218      let long_array: &mut ManuallyDrop<LongArray<T, N>>;
219      unsafe {
220        long_array = &mut self.long;
221      }
222      long_array.write(index, value);
223    }
224  }
225}
226
227impl<T: Cmov + Pod, const N: usize> Length for FixedArray<T, N> {
228  fn len(&self) -> usize {
229    N
230  }
231}
232
233impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for FixedArray<T, N> {
234  fn default() -> Self {
235    Self::new()
236  }
237}
238
239// impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Drop for FixedArray<T, N> {
240//   fn drop(&mut self) {
241//     if N <= SHORT_ARRAY_THRESHOLD {
242//       let short_array: &mut ShortArray<T, N>;
243//       unsafe {
244//         short_array = std::mem::transmute::<&mut Self, &mut ShortArray<T, N>>(self);
245//       }
246//       std::mem::drop(short_array);
247//     } else {
248//       let long_array: &mut LongArray<T, N>;
249//       unsafe {
250//         long_array = std::mem::transmute::<&mut Self, &mut LongArray<T, N>>(self);
251//       }
252//       std::mem::drop(long_array);
253//     }
254//   }
255// }
256
/// An array whose size is determined at runtime.
/// The size of the array is public.
/// The array is oblivious to the access pattern.
///
/// Elements are stored in a `CircuitORAM`, and a `RecursivePositionMap` records each index's
/// current position; the length reported by `len()` is the position map's size.
#[derive(Debug)]
pub struct DynamicArray<T>
where
  T: Cmov + Pod,
{
  /// The actual data storage oram
  data: CircuitORAM<T>,
  /// The position map for the oram (its `n` is the array length)
  pos_map: RecursivePositionMap,
}
271
272impl<T> DynamicArray<T>
273where
274  T: Cmov + Pod + Default + std::fmt::Debug,
275{
276  /// Creates a new `LongArray` with the given size `n`.
277  pub fn new(n: usize) -> Self {
278    Self { data: CircuitORAM::new(n), pos_map: RecursivePositionMap::new(n) }
279  }
280
281  /// Resizes the array to have `n` elements.
282  pub fn resize(&mut self, n: usize) {
283    let mut new_array = Self::new(n);
284    for i in 0..self.len() {
285      let mut value = Default::default();
286      self.read(i, &mut value);
287      new_array.write(i, value);
288    }
289    // UNDONE(git-57): Is this 0 cost in rust? DynamicArray is noncopy, so I would expect move semantics here, but double check
290    *self = new_array;
291  }
292
293  /// Reads from the index
294  pub fn read(&mut self, index: usize, out: &mut T) {
295    let new_pos = rng().random_range(0..self.len() as PositionType);
296    let old_pos = self.pos_map.access_position(index, new_pos);
297    self.data.read(old_pos, new_pos, index, out);
298  }
299
300  /// Writes to the index
301  pub fn write(&mut self, index: usize, value: T) {
302    let new_pos = rng().random_range(0..self.len() as PositionType);
303    let old_pos = self.pos_map.access_position(index, new_pos);
304    self.data.write_or_insert(old_pos, new_pos, index, value);
305  }
306
307  /// Updates the value at the index using the update function.
308  pub fn update<R, F>(&mut self, index: usize, update_func: F) -> (bool, R)
309  where
310    F: FnOnce(&mut T) -> R,
311  {
312    let new_pos = rng().random_range(0..self.len() as PositionType);
313    let old_pos = self.pos_map.access_position(index, new_pos);
314    self.data.update(old_pos, new_pos, index, update_func)
315  }
316}
317
318impl<T: Cmov + Pod> Length for DynamicArray<T> {
319  #[inline(always)]
320  fn len(&self) -> usize {
321    self.pos_map.n
322  }
323}
324
/// A set of `W` subarrays that can be used to store a fixed number of total elements defined at `new` time. It is leaked which subarray is being accessed.
///
/// All subarrays share one `CircuitORAM`, and each subarray has its own
/// `RecursivePositionMap`. Element `index` of subarray `s` is stored under the ORAM key
/// `(index << log2(W)) | s`, so keys of different subarrays never collide as long as `s < W`.
#[derive(Debug)]
pub struct MultiWayArray<T, const W: usize>
where
  T: Cmov + Pod,
{
  /// The actual data storage oram (shared by all subarrays)
  data: CircuitORAM<T>,
  /// The position maps for each subarray
  pos_map: [RecursivePositionMap; W],
}
337
338impl<T, const W: usize> MultiWayArray<T, W>
339where
340  T: Cmov + Pod + Default + std::fmt::Debug,
341{
342  /// Creates a new `MultiWayArray` with the given size `n`.
343  pub fn new(n: usize) -> Self {
344    assert!(W.is_power_of_two(), "W must be a power of two due to all the ilog2's here");
345    Self { data: CircuitORAM::new(n), pos_map: from_fn(|_| RecursivePositionMap::new(n)) }
346  }
347
348  fn get_real_index(&self, subarray: usize, index: usize) -> usize {
349    debug_assert!(subarray < W, "Subarray index out of bounds");
350    debug_assert!(index < self.len(), "Index out of bounds");
351    (index << W.ilog2()) | subarray
352  }
353
354  /// Reads from the subarray and index
355  pub fn read(&mut self, subarray: usize, index: usize, out: &mut T) {
356    let new_pos = rng().random_range(0..self.len() as PositionType);
357    let old_pos = self.pos_map[subarray].access_position(index, new_pos);
358    let real_index = self.get_real_index(subarray, index);
359    self.data.read(old_pos, new_pos, real_index, out);
360  }
361
362  /// Writes to the subarray and index
363  pub fn write(&mut self, subarray: usize, index: usize, value: T) {
364    let new_pos = rng().random_range(0..self.len() as PositionType);
365    let old_pos = self.pos_map[subarray].access_position(index, new_pos);
366    let real_index = self.get_real_index(subarray, index);
367    self.data.write_or_insert(old_pos, new_pos, real_index, value);
368  }
369
370  /// Updates the value at the subarray and index using the update function.
371  pub fn update<R, F>(&mut self, subarray: usize, index: usize, update_func: F) -> (bool, R)
372  where
373    F: FnOnce(&mut T) -> R,
374  {
375    let new_pos = rng().random_range(0..self.len() as PositionType);
376    let old_pos = self.pos_map[subarray].access_position(index, new_pos);
377    let real_index = self.get_real_index(subarray, index);
378    self.data.update(old_pos, new_pos, real_index, update_func)
379  }
380}
381
382impl<T: Cmov + Pod, const W: usize> Length for MultiWayArray<T, W> {
383  #[inline(always)]
384  fn len(&self) -> usize {
385    self.pos_map[0].n
386  }
387}
388
// UNDONE(git-30): Benchmark short array
// UNDONE(git-30): Benchmark long array
// UNDONE(git-30): Benchmark fixed array
// UNDONE(git-30): Benchmark dynamic array
// If monomorphizing `update` in Rust is truly zero-cost, then we can implement the following two via an update function:
// UNDONE(git-31): Implement versions of read and write that hide the operation from the caller.
// UNDONE(git-31): Implement read and write that have an enable flag (maybe_read, maybe_write).
396
// Presumably needed because the multiway tests expand `0..($size / $ways)` with constant
// arguments such as `2 / 4`, i.e. an intentionally empty `0..0` loop, which
// `clippy::reversed_empty_ranges` flags in `for` loops.
#[cfg(test)]
#[allow(clippy::reversed_empty_ranges)]
mod tests {
  use super::*;

  // Exercises a compile-time-sized array type (`ShortArray`, `LongArray`, `FixedArray`):
  // unwritten slots read as default, every slot is then written with its index, read back,
  // and the length is checked to stay `$size` throughout.
  macro_rules! m_test_fixed_array_exhaustive {
    ($arraytp:ident, $valtp:ty, $size:expr) => {{
      println!("Testing {} with size {}", stringify!($arraytp), $size);
      let mut arr = $arraytp::<$valtp, $size>::new();
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let value = i as $valtp;
        arr.write(i, value);
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), $size);
    }};
  }

  // Exercises `MultiWayArray`: every index of every subarray first reads as default; then
  // only `$size / $ways` elements per subarray are written (about `$size` in total, matching
  // the shared ORAM built with `$size`) and read back.
  macro_rules! m_test_multiway_array_exhaustive {
    ($arraytp:ident, $valtp:ty, $size:expr, $ways:expr) => {{
      println!("Testing {} with size {}", stringify!($arraytp), $size);
      let mut arr = $arraytp::<$valtp, $ways>::new($size);
      assert_eq!(arr.len(), $size);
      for w in 0..$ways {
        for i in 0..$size {
          let mut value = Default::default();
          arr.read(w, i, &mut value);
          assert_eq!(value, Default::default());
        }
      }
      assert_eq!(arr.len(), $size);

      for w in 0..$ways {
        for i in 0..($size / $ways) {
          let value = (i + w) as $valtp;
          arr.write(w, i, value);
        }
      }
      assert_eq!(arr.len(), $size);
      for w in 0..$ways {
        for i in 0..($size / $ways) {
          let mut value = Default::default();
          arr.read(w, i, &mut value);
          let v = (i + w) as $valtp;
          assert_eq!(value, v);
        }
      }
      assert_eq!(arr.len(), $size);
    }};
  }

  // Exercises `DynamicArray`: same write/read round trip as the fixed-size test, then grows
  // the array to `$size + 1` and to `2 * $size`, checking that existing values survive each
  // resize and the new tail reads as default.
  macro_rules! m_test_dynamic_array_exhaustive {
    ($arraytp:ident, $valtp:ty, $size:expr) => {{
      println!("Testing {} with size {}", stringify!($arraytp), $size);
      let mut arr = $arraytp::<$valtp>::new($size);
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let value = i as $valtp;
        arr.write(i, value);
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), $size);
      arr.resize($size + 1);
      assert_eq!(arr.len(), $size + 1);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), $size + 1);
      for i in $size..($size + 1) {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), $size + 1);
      arr.resize(2 * $size);
      assert_eq!(arr.len(), 2 * $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), 2 * $size);
      for i in $size..(2 * $size) {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), 2 * $size);
      // UNDONE(git-29): Test update
    }};
  }

  // Note: `ShortArray` is also tested with size 200, above `SHORT_ARRAY_THRESHOLD`; the
  // threshold only drives `FixedArray`'s variant choice. NOTE(review): the size-1 cases
  // for the ORAM-backed types are commented out; the reason is not stated in this file.
  #[test]
  fn test_fixed_arrays() {
    m_test_fixed_array_exhaustive!(ShortArray, u32, 1);
    m_test_fixed_array_exhaustive!(ShortArray, u32, 2);
    m_test_fixed_array_exhaustive!(ShortArray, u32, 3);
    m_test_fixed_array_exhaustive!(ShortArray, u64, 15);
    m_test_fixed_array_exhaustive!(ShortArray, u8, 33);
    m_test_fixed_array_exhaustive!(ShortArray, u64, 200);

    // m_test_fixed_array_exhaustive!(LongArray, u32, 1);
    m_test_fixed_array_exhaustive!(LongArray, u32, 2);
    m_test_fixed_array_exhaustive!(LongArray, u32, 3);
    m_test_fixed_array_exhaustive!(LongArray, u64, 15);
    m_test_fixed_array_exhaustive!(LongArray, u8, 33);

    m_test_fixed_array_exhaustive!(FixedArray, u32, 1);
    m_test_fixed_array_exhaustive!(FixedArray, u32, 2);
    m_test_fixed_array_exhaustive!(FixedArray, u32, 3);
    m_test_fixed_array_exhaustive!(FixedArray, u64, 15);
    m_test_fixed_array_exhaustive!(FixedArray, u8, 33);
    m_test_fixed_array_exhaustive!(FixedArray, u64, 200);
  }

  // Covers 1, 2 and 4 ways; with 4 ways and sizes 2 or 3, the write/read-back loops are empty.
  #[test]
  fn test_multiway_array() {
    // m_test_multiway_array_exhaustive!(MultiWayArray, u32, 1, 1);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 1);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 1);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 1);
    m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 1);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 1);

    // m_test_multiway_array_exhaustive!(MultiWayArray, u32, 1, 2);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 2);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 2);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 2);
    m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 2);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 2);

    // m_test_multiway_array_exhaustive!(MultiWayArray, u32, 1, 4);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 4);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 4);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 4);
    m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 4);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 4);
  }

  #[test]
  fn test_dynamic_array() {
    // m_test_dynamic_array_exhaustive!(DynamicArray, u32, 1);
    m_test_dynamic_array_exhaustive!(DynamicArray, u32, 2);
    m_test_dynamic_array_exhaustive!(DynamicArray, u32, 3);
    m_test_dynamic_array_exhaustive!(DynamicArray, u64, 15);
    m_test_dynamic_array_exhaustive!(DynamicArray, u8, 33);
    m_test_dynamic_array_exhaustive!(DynamicArray, u64, 200);
  }
}