rostl_datastructures/
array.rs

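//! Implements oblivious arrays with fixed-size (`Pod`) element types: compile-time-sized
//! (`ShortArray`, `LongArray`, `FixedArray`), runtime-sized (`DynamicArray`), and multi-way
//! (`MultiWayArray`). The accessed index is hidden; array sizes are public, and
//! `MultiWayArray` additionally leaks which subarray is accessed.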
3//!
4
5use std::{array::from_fn, mem::ManuallyDrop};
6
7use bytemuck::Pod;
8use rand::{rngs::ThreadRng, Rng};
9use rostl_oram::{
10  circuit_oram::CircuitORAM,
11  linear_oram::{oblivious_read_index, oblivious_write_index},
12  prelude::PositionType,
13  recursive_oram::RecursivePositionMap,
14};
15use rostl_primitives::{indexable::Length, traits::Cmov};
16
17/// A fixed sized array defined at compile time.
18/// The size of the array is public.
19pub type Array<T, const N: usize> = FixedArray<T, N>;
20/// A fixed sized array defined at runtime.
21/// The size of the array is public.
22pub type DArray<T> = DynamicArray<T>;
23
/// A fixed-size oblivious array stored inline as `[T; N]`, intended for small `N`.
///
/// Accesses go through the linear-scan helpers of `linear_oram`. The length `N`
/// is public; which index is accessed is not.
#[repr(C)]
#[derive(Debug)]
pub struct ShortArray<T, const N: usize> {
  /// The element storage. It is visible only within this crate (`pub(crate)`).
  pub(crate) data: [T; N],
}
34
35impl<T, const N: usize> ShortArray<T, N>
36where
37  T: Cmov + Pod + Default,
38{
39  /// Creates a new `ShortArray` with the given size `n`.
40  pub fn new() -> Self {
41    Self { data: [T::default(); N] }
42  }
43
44  /// Reads from the index
45  pub fn read(&self, index: usize, out: &mut T) {
46    oblivious_read_index(&self.data, index, out);
47  }
48
49  /// Writes to the index
50  pub fn write(&mut self, index: usize, value: T) {
51    oblivious_write_index(&mut self.data, index, value);
52  }
53}
54
55impl<T, const N: usize> Length for ShortArray<T, N> {
56  fn len(&self) -> usize {
57    N
58  }
59}
60
61impl<T, const N: usize> Default for ShortArray<T, N>
62where
63  T: Cmov + Pod + Default,
64{
65  fn default() -> Self {
66    Self::new()
67  }
68}
69
70/// A fixed-size oblivious array, optimal for large sizes.
71/// The size of the array is public.
72#[repr(C)]
73#[derive(Debug)]
74pub struct LongArray<T, const N: usize>
75where
76  T: Cmov + Pod,
77{
78  /// The actual data storage oram
79  data: CircuitORAM<T>,
80  /// The position map for the oram
81  pos_map: RecursivePositionMap,
82  /// The local rng for the oram
83  rng: ThreadRng,
84}
85impl<T, const N: usize> LongArray<T, N>
86where
87  T: Cmov + Pod + Default + std::fmt::Debug,
88{
89  /// Creates a new `LongArray` with the given size `n`.
90  pub fn new() -> Self {
91    Self { data: CircuitORAM::new(N), pos_map: RecursivePositionMap::new(N), rng: rand::rng() }
92  }
93
94  /// Reads from the index
95  pub fn read(&mut self, index: usize, out: &mut T) {
96    let new_pos = self.rng.random_range(0..N as PositionType);
97    let old_pos = self.pos_map.access_position(index, new_pos);
98    self.data.read(old_pos, new_pos, index, out);
99  }
100
101  /// Writes to the index
102  pub fn write(&mut self, index: usize, value: T) {
103    let new_pos = self.rng.random_range(0..N as PositionType);
104    let old_pos = self.pos_map.access_position(index, new_pos);
105    self.data.write_or_insert(old_pos, new_pos, index, value);
106  }
107}
108
109impl<T: Cmov + Pod, const N: usize> Length for LongArray<T, N> {
110  fn len(&self) -> usize {
111    N
112  }
113}
114
115impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for LongArray<T, N> {
116  fn default() -> Self {
117    Self::new()
118  }
119}
120
// UNDONE(git-52): Optimize SHORT_ARRAY_THRESHOLD
/// Largest `N` for which [`FixedArray`] uses the linear-scan [`ShortArray`]
/// backend. Any larger `N` uses the ORAM-based [`LongArray`].
const SHORT_ARRAY_THRESHOLD: usize = 128;
123
124/// A fixed-size array that switches between `ShortArray` and `LongArray` based on the size.
125/// The size of the array is public.
126///
127/// # Invariants
128/// if `N <= SHORT_ARRAY_THRESHOLD`, then `ShortArray` is used, otherwise `LongArray` is used.
129///
130#[repr(C)]
131pub union FixedArray<T, const N: usize>
132where
133  T: Cmov + Pod,
134{
135  /// Short variant, linear scan
136  short: ManuallyDrop<ShortArray<T, N>>,
137  /// Long variant, oram
138  long: ManuallyDrop<LongArray<T, N>>,
139}
140
141impl<T, const N: usize> Drop for FixedArray<T, N>
142where
143  T: Cmov + Pod,
144{
145  fn drop(&mut self) {
146    if N <= SHORT_ARRAY_THRESHOLD {
147      unsafe {
148        ManuallyDrop::drop(&mut self.short);
149      }
150    } else {
151      unsafe {
152        ManuallyDrop::drop(&mut self.long);
153      }
154    }
155  }
156}
157
158impl<T, const N: usize> std::fmt::Debug for FixedArray<T, N>
159where
160  T: Cmov + Pod + std::fmt::Debug,
161{
162  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163    if N <= SHORT_ARRAY_THRESHOLD {
164      let short_array: &ManuallyDrop<ShortArray<T, N>>;
165      unsafe {
166        short_array = &self.short;
167      }
168      short_array.fmt(f)
169    } else {
170      let long_array: &ManuallyDrop<LongArray<T, N>>;
171      unsafe {
172        long_array = &self.long;
173      }
174      long_array.fmt(f)
175    }
176  }
177}
178
179impl<T, const N: usize> FixedArray<T, N>
180where
181  T: Cmov + Pod + Default + std::fmt::Debug,
182{
183  /// Creates a new `LongArray` with the given size `n`.
184  pub fn new() -> Self {
185    if N <= SHORT_ARRAY_THRESHOLD {
186      FixedArray { short: ManuallyDrop::new(ShortArray::new()) }
187    } else {
188      FixedArray { long: ManuallyDrop::new(LongArray::new()) }
189    }
190  }
191
192  /// Reads from the index
193  pub fn read(&mut self, index: usize, out: &mut T) {
194    if N <= SHORT_ARRAY_THRESHOLD {
195      // Do an unsafe cast to avoid borrowing issues
196      let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
197      unsafe {
198        short_array = &mut self.short;
199      }
200      short_array.read(index, out);
201    } else {
202      let long_array: &mut ManuallyDrop<LongArray<T, N>>;
203      unsafe {
204        long_array = &mut self.long;
205      }
206      long_array.read(index, out);
207    }
208  }
209
210  /// Writes to the index
211  pub fn write(&mut self, index: usize, value: T) {
212    if N <= SHORT_ARRAY_THRESHOLD {
213      // Do an unsafe cast to avoid borrowing issues
214      let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
215      unsafe {
216        short_array = &mut self.short;
217      }
218      short_array.write(index, value);
219    } else {
220      let long_array: &mut ManuallyDrop<LongArray<T, N>>;
221      unsafe {
222        long_array = &mut self.long;
223      }
224      long_array.write(index, value);
225    }
226  }
227}
228
229impl<T: Cmov + Pod, const N: usize> Length for FixedArray<T, N> {
230  fn len(&self) -> usize {
231    N
232  }
233}
234
235impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for FixedArray<T, N> {
236  fn default() -> Self {
237    Self::new()
238  }
239}
240
241// impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Drop for FixedArray<T, N> {
242//   fn drop(&mut self) {
243//     if N <= SHORT_ARRAY_THRESHOLD {
244//       let short_array: &mut ShortArray<T, N>;
245//       unsafe {
246//         short_array = std::mem::transmute::<&mut Self, &mut ShortArray<T, N>>(self);
247//       }
248//       std::mem::drop(short_array);
249//     } else {
250//       let long_array: &mut LongArray<T, N>;
251//       unsafe {
252//         long_array = std::mem::transmute::<&mut Self, &mut LongArray<T, N>>(self);
253//       }
254//       std::mem::drop(long_array);
255//     }
256//   }
257// }
258
259/// An array whose size is determined at runtime.
260/// The size of the array is public.
261/// The array is oblivious to the access pattern.
262///
263#[derive(Debug)]
264pub struct DynamicArray<T>
265where
266  T: Cmov + Pod,
267{
268  /// The actual data storage oram
269  data: CircuitORAM<T>,
270  /// The position map for the oram
271  pos_map: RecursivePositionMap,
272  /// The local rng for the oram
273  rng: ThreadRng,
274}
275
276impl<T> DynamicArray<T>
277where
278  T: Cmov + Pod + Default + std::fmt::Debug,
279{
280  /// Creates a new `LongArray` with the given size `n`.
281  pub fn new(n: usize) -> Self {
282    Self { data: CircuitORAM::new(n), pos_map: RecursivePositionMap::new(n), rng: rand::rng() }
283  }
284
285  /// Resizes the array to have `n` elements.
286  pub fn resize(&mut self, n: usize) {
287    let mut new_array = Self::new(n);
288    for i in 0..self.len() {
289      let mut value = Default::default();
290      self.read(i, &mut value);
291      new_array.write(i, value);
292    }
293    // UNDONE(git-57): Is this 0 cost in rust? DynamicArray is noncopy, so I would expect move semantics here, but double check
294    *self = new_array;
295  }
296
297  /// Reads from the index
298  pub fn read(&mut self, index: usize, out: &mut T) {
299    let new_pos = self.rng.random_range(0..self.len() as PositionType);
300    let old_pos = self.pos_map.access_position(index, new_pos);
301    self.data.read(old_pos, new_pos, index, out);
302  }
303
304  /// Writes to the index
305  pub fn write(&mut self, index: usize, value: T) {
306    let new_pos = self.rng.random_range(0..self.len() as PositionType);
307    let old_pos = self.pos_map.access_position(index, new_pos);
308    self.data.write_or_insert(old_pos, new_pos, index, value);
309  }
310
311  /// Updates the value at the index using the update function.
312  pub fn update<R, F>(&mut self, index: usize, update_func: F) -> (bool, R)
313  where
314    F: FnOnce(&mut T) -> R,
315  {
316    let new_pos = self.rng.random_range(0..self.len() as PositionType);
317    let old_pos = self.pos_map.access_position(index, new_pos);
318    self.data.update(old_pos, new_pos, index, update_func)
319  }
320}
321
322impl<T: Cmov + Pod> Length for DynamicArray<T> {
323  #[inline(always)]
324  fn len(&self) -> usize {
325    self.pos_map.n
326  }
327}
328
329/// A set of `W` subarrays that can be used to store a fixed number of total elements defined at `new` time. It is leaked which subarray is being accessed.
330///
331#[derive(Debug)]
332pub struct MultiWayArray<T, const W: usize>
333where
334  T: Cmov + Pod,
335{
336  /// The actual data storage oram
337  data: CircuitORAM<T>,
338  /// The position maps for each subarray
339  pos_map: [RecursivePositionMap; W],
340  /// The local rng for the oram
341  rng: ThreadRng,
342}
343
344impl<T, const W: usize> MultiWayArray<T, W>
345where
346  T: Cmov + Pod + Default + std::fmt::Debug,
347{
348  /// Creates a new `MultiWayArray` with the given size `n`.
349  pub fn new(n: usize) -> Self {
350    assert!(W.is_power_of_two(), "W must be a power of two due to all the ilog2's here");
351    Self {
352      data: CircuitORAM::new(n),
353      pos_map: from_fn(|_| RecursivePositionMap::new(n)),
354      rng: rand::rng(),
355    }
356  }
357
358  fn get_real_index(&self, subarray: usize, index: usize) -> usize {
359    debug_assert!(subarray < W, "Subarray index out of bounds");
360    debug_assert!(index < self.len(), "Index out of bounds");
361    (index << W.ilog2()) | subarray
362  }
363
364  /// Reads from the subarray and index
365  pub fn read(&mut self, subarray: usize, index: usize, out: &mut T) {
366    let new_pos = self.rng.random_range(0..self.len() as PositionType);
367    let old_pos = self.pos_map[subarray].access_position(index, new_pos);
368    let real_index = self.get_real_index(subarray, index);
369    self.data.read(old_pos, new_pos, real_index, out);
370  }
371
372  /// Writes to the subarray and index
373  pub fn write(&mut self, subarray: usize, index: usize, value: T) {
374    let new_pos = self.rng.random_range(0..self.len() as PositionType);
375    let old_pos = self.pos_map[subarray].access_position(index, new_pos);
376    let real_index = self.get_real_index(subarray, index);
377    self.data.write_or_insert(old_pos, new_pos, real_index, value);
378  }
379
380  /// Updates the value at the subarray and index using the update function.
381  pub fn update<R, F>(&mut self, subarray: usize, index: usize, update_func: F) -> (bool, R)
382  where
383    F: FnOnce(&mut T) -> R,
384  {
385    let new_pos = self.rng.random_range(0..self.len() as PositionType);
386    let old_pos = self.pos_map[subarray].access_position(index, new_pos);
387    let real_index = self.get_real_index(subarray, index);
388    self.data.update(old_pos, new_pos, real_index, update_func)
389  }
390}
391
392impl<T: Cmov + Pod, const W: usize> Length for MultiWayArray<T, W> {
393  #[inline(always)]
394  fn len(&self) -> usize {
395    self.pos_map[0].n
396  }
397}
398
399// UNDONE(git-30): Benchmark short array
400// UNDONE(git-30): Benchmark long array
401// UNDONE(git-30): Benchmark fixed array
402// UNDONE(git-30): Benchmark dynamic array
// If `update` monomorphization is truly zero-cost in Rust, then we can implement the following two via an update function:
404// UNDONE(git-31): Implement versions of read and write that hide the operation from the caller.
405// UNDONE(git-31): Implement read and write that have an enable flag (maybe_read, maybe_write).
406
#[cfg(test)]
// The multi-way macro loops over `0..($size / $ways)`. With constant arguments
// where `$size < $ways`, that is the empty constant range `0..0`, which
// `clippy::reversed_empty_ranges` flags even though skipping the loop is intended.
#[allow(clippy::reversed_empty_ranges)]
mod tests {
  use super::*;

  // Exhaustively checks a compile-time-sized array type. Reads of never-written
  // slots must yield `Default::default()`; then every index `i` is written with
  // `i` and read back.
  macro_rules! m_test_fixed_array_exhaustive {
    ($arraytp:ident, $valtp:ty, $size:expr) => {{
      println!("Testing {} with size {}", stringify!($arraytp), $size);
      let mut arr = $arraytp::<$valtp, $size>::new();
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let value = i as $valtp;
        arr.write(i, value);
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), $size);
    }};
  }

  // Exhaustively checks a `MultiWayArray` with `$ways` subarrays sharing `$size`
  // total slots. Each subarray is filled with `$size / $ways` elements, checked,
  // then incremented through `update` and checked again.
  macro_rules! m_test_multiway_array_exhaustive {
    ($arraytp:ident, $valtp:ty, $size:expr, $ways:expr) => {{
      println!("Testing {} with size {} and {} ways", stringify!($arraytp), $size, $ways);
      let mut arr = $arraytp::<$valtp, $ways>::new($size);
      assert_eq!(arr.len(), $size);
      for w in 0..$ways {
        for i in 0..$size {
          let mut value = Default::default();
          arr.read(w, i, &mut value);
          assert_eq!(value, Default::default());
        }
      }
      assert_eq!(arr.len(), $size);

      for w in 0..$ways {
        for i in 0..($size / $ways) {
          let value = (i + w) as $valtp;
          arr.write(w, i, value);
        }
      }
      assert_eq!(arr.len(), $size);
      for w in 0..$ways {
        for i in 0..($size / $ways) {
          let mut value = Default::default();
          arr.read(w, i, &mut value);
          let v = (i + w) as $valtp;
          assert_eq!(value, v);
        }
      }
      assert_eq!(arr.len(), $size);

      // `update` should modify the stored element in place.
      for w in 0..$ways {
        for i in 0..($size / $ways) {
          let _ = arr.update(w, i, |v| *v += 1);
        }
      }
      for w in 0..$ways {
        for i in 0..($size / $ways) {
          let mut value = Default::default();
          arr.read(w, i, &mut value);
          let v = (i + w + 1) as $valtp;
          assert_eq!(value, v);
        }
      }
      assert_eq!(arr.len(), $size);
    }};
  }

  // Exhaustively checks a `DynamicArray`: fill and verify, grow twice while
  // checking that old contents survive and new slots read as default, then
  // increment the original elements through `update`.
  macro_rules! m_test_dynamic_array_exhaustive {
    ($arraytp:ident, $valtp:ty, $size:expr) => {{
      println!("Testing {} with size {}", stringify!($arraytp), $size);
      let mut arr = $arraytp::<$valtp>::new($size);
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let value = i as $valtp;
        arr.write(i, value);
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), $size);
      arr.resize($size + 1);
      assert_eq!(arr.len(), $size + 1);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), $size + 1);
      for i in $size..($size + 1) {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), $size + 1);
      arr.resize(2 * $size);
      assert_eq!(arr.len(), 2 * $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), 2 * $size);
      for i in $size..(2 * $size) {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), 2 * $size);

      // `update` should modify the stored element in place.
      for i in 0..$size {
        let _ = arr.update(i, |v| *v += 1);
      }
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = (i + 1) as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), 2 * $size);
    }};
  }

  #[test]
  fn test_fixed_arrays() {
    m_test_fixed_array_exhaustive!(ShortArray, u32, 1);
    m_test_fixed_array_exhaustive!(ShortArray, u32, 2);
    m_test_fixed_array_exhaustive!(ShortArray, u32, 3);
    m_test_fixed_array_exhaustive!(ShortArray, u64, 15);
    m_test_fixed_array_exhaustive!(ShortArray, u8, 33);
    m_test_fixed_array_exhaustive!(ShortArray, u64, 200);

    // m_test_fixed_array_exhaustive!(LongArray, u32, 1);
    m_test_fixed_array_exhaustive!(LongArray, u32, 2);
    m_test_fixed_array_exhaustive!(LongArray, u32, 3);
    m_test_fixed_array_exhaustive!(LongArray, u64, 15);
    m_test_fixed_array_exhaustive!(LongArray, u8, 33);

    m_test_fixed_array_exhaustive!(FixedArray, u32, 1);
    m_test_fixed_array_exhaustive!(FixedArray, u32, 2);
    m_test_fixed_array_exhaustive!(FixedArray, u32, 3);
    m_test_fixed_array_exhaustive!(FixedArray, u64, 15);
    m_test_fixed_array_exhaustive!(FixedArray, u8, 33);
    m_test_fixed_array_exhaustive!(FixedArray, u64, 200);
  }

  #[test]
  fn test_multiway_array() {
    // m_test_multiway_array_exhaustive!(MultiWayArray, u32, 1, 1);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 1);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 1);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 1);
    m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 1);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 1);

    // m_test_multiway_array_exhaustive!(MultiWayArray, u32, 1, 2);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 2);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 2);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 2);
    m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 2);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 2);

    // m_test_multiway_array_exhaustive!(MultiWayArray, u32, 1, 4);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 2, 4);
    m_test_multiway_array_exhaustive!(MultiWayArray, u32, 3, 4);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 15, 4);
    m_test_multiway_array_exhaustive!(MultiWayArray, u8, 33, 4);
    m_test_multiway_array_exhaustive!(MultiWayArray, u64, 200, 4);
  }

  #[test]
  fn test_dynamic_array() {
    // m_test_dynamic_array_exhaustive!(DynamicArray, u32, 1);
    m_test_dynamic_array_exhaustive!(DynamicArray, u32, 2);
    m_test_dynamic_array_exhaustive!(DynamicArray, u32, 3);
    m_test_dynamic_array_exhaustive!(DynamicArray, u64, 15);
    m_test_dynamic_array_exhaustive!(DynamicArray, u8, 33);
    m_test_dynamic_array_exhaustive!(DynamicArray, u64, 200);
  }
}