rostl_datastructures/
array.rs

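//! Oblivious arrays whose elements have a fixed-size type.
//! Provides compile-time-sized arrays (`ShortArray`, `LongArray`, `FixedArray`) and a
//! runtime-sized array (`DynamicArray`). The array length is public; accesses are
//! oblivious, i.e. the access pattern is meant not to reveal which indices are accessed.
//!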
4
5use std::mem::ManuallyDrop;
6
7use bytemuck::Pod;
8use rand::{rngs::ThreadRng, Rng};
9use rostl_oram::{
10  circuit_oram::CircuitORAM,
11  linear_oram::{oblivious_read_index, oblivious_write_index},
12  prelude::PositionType,
13  recursive_oram::RecursivePositionMap,
14};
15use rostl_primitives::{indexable::Length, traits::Cmov};
16
17/// A fixed sized array defined at compile time.
18/// The size of the array is public.
19pub type Array<T, const N: usize> = FixedArray<T, N>;
20/// A fixed sized array defined at runtime.
21/// The size of the array is public.
22pub type DArray<T> = DynamicArray<T>;
23
/// Oblivious fixed-size array backed by a plain `[T; N]`.
///
/// Accesses go through the linear-scan helpers from `rostl_oram::linear_oram`,
/// which makes this variant a good fit for small `N`. The length `N` is public.
#[repr(C)]
#[derive(Debug)]
pub struct ShortArray<T, const N: usize> {
  /// Backing storage; visible to the rest of this crate (`pub(crate)`), not publicly.
  pub(crate) data: [T; N],
}
34
35impl<T, const N: usize> ShortArray<T, N>
36where
37  T: Cmov + Pod + Default,
38{
39  /// Creates a new `ShortArray` with the given size `n`.
40  pub fn new() -> Self {
41    Self { data: [T::default(); N] }
42  }
43
44  /// Reads from the index
45  pub fn read(&self, index: usize, out: &mut T) {
46    oblivious_read_index(&self.data, index, out);
47  }
48
49  /// Writes to the index
50  pub fn write(&mut self, index: usize, value: T) {
51    oblivious_write_index(&mut self.data, index, value);
52  }
53}
54
55impl<T, const N: usize> Length for ShortArray<T, N> {
56  fn len(&self) -> usize {
57    N
58  }
59}
60
61impl<T, const N: usize> Default for ShortArray<T, N>
62where
63  T: Cmov + Pod + Default,
64{
65  fn default() -> Self {
66    Self::new()
67  }
68}
69
70/// A fixed-size oblivious array, optimal for large sizes.
71/// The size of the array is public.
72#[repr(C)]
73#[derive(Debug)]
74pub struct LongArray<T, const N: usize>
75where
76  T: Cmov + Pod,
77{
78  /// The actual data storage oram
79  data: CircuitORAM<T>,
80  /// The position map for the oram
81  pos_map: RecursivePositionMap,
82  /// The local rng for the oram
83  rng: ThreadRng,
84}
85impl<T, const N: usize> LongArray<T, N>
86where
87  T: Cmov + Pod + Default + std::fmt::Debug,
88{
89  /// Creates a new `LongArray` with the given size `n`.
90  pub fn new() -> Self {
91    Self { data: CircuitORAM::new(N), pos_map: RecursivePositionMap::new(N), rng: rand::rng() }
92  }
93
94  /// Reads from the index
95  pub fn read(&mut self, index: usize, out: &mut T) {
96    let new_pos = self.rng.random_range(0..N as PositionType);
97    let old_pos = self.pos_map.access_position(index, new_pos);
98    self.data.read(old_pos, new_pos, index, out);
99  }
100
101  /// Writes to the index
102  pub fn write(&mut self, index: usize, value: T) {
103    let new_pos = self.rng.random_range(0..N as PositionType);
104    let old_pos = self.pos_map.access_position(index, new_pos);
105    self.data.write_or_insert(old_pos, new_pos, index, value);
106  }
107}
108
109impl<T: Cmov + Pod, const N: usize> Length for LongArray<T, N> {
110  fn len(&self) -> usize {
111    N
112  }
113}
114
115impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for LongArray<T, N> {
116  fn default() -> Self {
117    Self::new()
118  }
119}
120
/// Largest `N` for which `FixedArray` uses the linear-scan `ShortArray`.
///
/// Any larger `N` uses the ORAM-backed `LongArray` instead.
// UNDONE(git-52): Optimize SHORT_ARRAY_THRESHOLD
const SHORT_ARRAY_THRESHOLD: usize = 128;
123
124/// A fixed-size array that switches between `ShortArray` and `LongArray` based on the size.
125/// The size of the array is public.
126///
127/// # Invariants
128/// if `N <= SHORT_ARRAY_THRESHOLD`, then `ShortArray` is used, otherwise `LongArray` is used.
129///
130#[repr(C)]
131pub union FixedArray<T, const N: usize>
132where
133  T: Cmov + Pod,
134{
135  /// Short variant, linear scan
136  short: ManuallyDrop<ShortArray<T, N>>,
137  /// Long variant, oram
138  long: ManuallyDrop<LongArray<T, N>>,
139}
140
141impl<T, const N: usize> Drop for FixedArray<T, N>
142where
143  T: Cmov + Pod,
144{
145  fn drop(&mut self) {
146    if N <= SHORT_ARRAY_THRESHOLD {
147      unsafe {
148        ManuallyDrop::drop(&mut self.short);
149      }
150    } else {
151      unsafe {
152        ManuallyDrop::drop(&mut self.long);
153      }
154    }
155  }
156}
157
158impl<T, const N: usize> std::fmt::Debug for FixedArray<T, N>
159where
160  T: Cmov + Pod + std::fmt::Debug,
161{
162  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163    if N <= SHORT_ARRAY_THRESHOLD {
164      let short_array: &ManuallyDrop<ShortArray<T, N>>;
165      unsafe {
166        short_array = &self.short;
167      }
168      short_array.fmt(f)
169    } else {
170      let long_array: &ManuallyDrop<LongArray<T, N>>;
171      unsafe {
172        long_array = &self.long;
173      }
174      long_array.fmt(f)
175    }
176  }
177}
178
179impl<T, const N: usize> FixedArray<T, N>
180where
181  T: Cmov + Pod + Default + std::fmt::Debug,
182{
183  /// Creates a new `LongArray` with the given size `n`.
184  pub fn new() -> Self {
185    if N <= SHORT_ARRAY_THRESHOLD {
186      FixedArray { short: ManuallyDrop::new(ShortArray::new()) }
187    } else {
188      FixedArray { long: ManuallyDrop::new(LongArray::new()) }
189    }
190  }
191
192  /// Reads from the index
193  pub fn read(&mut self, index: usize, out: &mut T) {
194    if N <= SHORT_ARRAY_THRESHOLD {
195      // Do an unsafe cast to avoid borrowing issues
196      let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
197      unsafe {
198        short_array = &mut self.short;
199      }
200      short_array.read(index, out);
201    } else {
202      let long_array: &mut ManuallyDrop<LongArray<T, N>>;
203      unsafe {
204        long_array = &mut self.long;
205      }
206      long_array.read(index, out);
207    }
208  }
209
210  /// Writes to the index
211  pub fn write(&mut self, index: usize, value: T) {
212    if N <= SHORT_ARRAY_THRESHOLD {
213      // Do an unsafe cast to avoid borrowing issues
214      let short_array: &mut ManuallyDrop<ShortArray<T, N>>;
215      unsafe {
216        short_array = &mut self.short;
217      }
218      short_array.write(index, value);
219    } else {
220      let long_array: &mut ManuallyDrop<LongArray<T, N>>;
221      unsafe {
222        long_array = &mut self.long;
223      }
224      long_array.write(index, value);
225    }
226  }
227}
228
229impl<T: Cmov + Pod, const N: usize> Length for FixedArray<T, N> {
230  fn len(&self) -> usize {
231    N
232  }
233}
234
235impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Default for FixedArray<T, N> {
236  fn default() -> Self {
237    Self::new()
238  }
239}
240
// NOTE: earlier sketch, superseded by the `Drop` impl for `FixedArray` above, which drops
// the active `ManuallyDrop` field. This version would not have run any destructor:
// `std::mem::drop` on a `&mut` reference only drops the reference, not the referent.
// impl<T: Cmov + Pod + Default + std::fmt::Debug, const N: usize> Drop for FixedArray<T, N> {
//   fn drop(&mut self) {
//     if N <= SHORT_ARRAY_THRESHOLD {
//       let short_array: &mut ShortArray<T, N>;
//       unsafe {
//         short_array = std::mem::transmute::<&mut Self, &mut ShortArray<T, N>>(self);
//       }
//       std::mem::drop(short_array);
//     } else {
//       let long_array: &mut LongArray<T, N>;
//       unsafe {
//         long_array = std::mem::transmute::<&mut Self, &mut LongArray<T, N>>(self);
//       }
//       std::mem::drop(long_array);
//     }
//   }
// }
258
259/// An array whose size is determined at runtime.
260/// The size of the array is public.
261/// The array is oblivious to the access pattern.
262///
263#[derive(Debug)]
264pub struct DynamicArray<T>
265where
266  T: Cmov + Pod,
267{
268  /// The actual data storage oram
269  data: CircuitORAM<T>,
270  /// The position map for the oram
271  pos_map: RecursivePositionMap,
272  /// The local rng for the oram
273  rng: ThreadRng,
274}
275
276impl<T> DynamicArray<T>
277where
278  T: Cmov + Pod + Default + std::fmt::Debug,
279{
280  /// Creates a new `LongArray` with the given size `n`.
281  pub fn new(n: usize) -> Self {
282    Self { data: CircuitORAM::new(n), pos_map: RecursivePositionMap::new(n), rng: rand::rng() }
283  }
284
285  /// Resizes the array to have `n` elements.
286  pub fn resize(&mut self, n: usize) {
287    let mut new_array = Self::new(n);
288    for i in 0..self.len() {
289      let mut value = Default::default();
290      self.read(i, &mut value);
291      new_array.write(i, value);
292    }
293    // UNDONE(git-57): Is this 0 cost in rust? DynamicArray is noncopy, so I would expect move semantics here, but double check
294    *self = new_array;
295  }
296
297  /// Reads from the index
298  pub fn read(&mut self, index: usize, out: &mut T) {
299    let new_pos = self.rng.random_range(0..self.len() as PositionType);
300    let old_pos = self.pos_map.access_position(index, new_pos);
301    self.data.read(old_pos, new_pos, index, out);
302  }
303
304  /// Writes to the index
305  pub fn write(&mut self, index: usize, value: T) {
306    let new_pos = self.rng.random_range(0..self.len() as PositionType);
307    let old_pos = self.pos_map.access_position(index, new_pos);
308    self.data.write_or_insert(old_pos, new_pos, index, value);
309  }
310
311  /// Updates the value at the index using the update function.
312  pub fn update<R, F>(&mut self, index: usize, update_func: F) -> (bool, R)
313  where
314    F: FnOnce(&mut T) -> R,
315  {
316    let new_pos = self.rng.random_range(0..self.len() as PositionType);
317    let old_pos = self.pos_map.access_position(index, new_pos);
318    self.data.update(old_pos, new_pos, index, update_func)
319  }
320}
321
322impl<T: Cmov + Pod> Length for DynamicArray<T> {
323  #[inline(always)]
324  fn len(&self) -> usize {
325    self.pos_map.n
326  }
327}
328
// UNDONE(git-30): Benchmark short array
// UNDONE(git-30): Benchmark long array
// UNDONE(git-30): Benchmark fixed array
// UNDONE(git-30): Benchmark dynamic array
// If monomorphizing `update` is truly zero-cost in Rust, then the following two can be implemented on top of `update`:
// UNDONE(git-31): Implement versions of read and write that hide the operation from the caller.
// UNDONE(git-31): Implement read and write that have an enable flag (maybe_read, maybe_write).
336
#[cfg(test)]
mod tests {
  use super::*;

  // Exercises a compile-time-sized array:
  // - fresh elements read back as the default value;
  // - every written value reads back unchanged;
  // - the length never changes.
  macro_rules! m_test_fixed_array_exhaustive {
    ($arraytp:ident, $valtp:ty, $size:expr) => {{
      println!("Testing {} with size {}", stringify!($arraytp), $size);
      let mut arr = $arraytp::<$valtp, $size>::new();
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let value = i as $valtp;
        arr.write(i, value);
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), $size);
    }};
  }

  // Exercises a runtime-sized array like the fixed-size test, then checks `resize`:
  // - growing preserves the old prefix and default-initializes new slots;
  // - shrinking preserves the surviving prefix.
  macro_rules! m_test_dynamic_array_exhaustive {
    ($arraytp:ident, $valtp:ty, $size:expr) => {{
      println!("Testing {} with size {}", stringify!($arraytp), $size);
      let mut arr = $arraytp::<$valtp>::new($size);
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let value = i as $valtp;
        arr.write(i, value);
      }
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), $size);
      arr.resize($size + 1);
      assert_eq!(arr.len(), $size + 1);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), $size + 1);
      for i in $size..($size + 1) {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), $size + 1);
      arr.resize(2 * $size);
      assert_eq!(arr.len(), 2 * $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), 2 * $size);
      for i in $size..(2 * $size) {
        let mut value = Default::default();
        arr.read(i, &mut value);
        assert_eq!(value, Default::default());
      }
      assert_eq!(arr.len(), 2 * $size);
      // Shrink back to the original length: only the first `$size` elements survive.
      arr.resize($size);
      assert_eq!(arr.len(), $size);
      for i in 0..$size {
        let mut value = Default::default();
        arr.read(i, &mut value);
        let v = i as $valtp;
        assert_eq!(value, v);
      }
      assert_eq!(arr.len(), $size);
      // UNDONE(git-29): Test update
    }};
  }

  #[test]
  fn test_fixed_arrays() {
    m_test_fixed_array_exhaustive!(ShortArray, u32, 1);
    m_test_fixed_array_exhaustive!(ShortArray, u32, 2);
    m_test_fixed_array_exhaustive!(ShortArray, u32, 3);
    m_test_fixed_array_exhaustive!(ShortArray, u64, 15);
    m_test_fixed_array_exhaustive!(ShortArray, u8, 33);
    m_test_fixed_array_exhaustive!(ShortArray, u64, 200);

    // m_test_fixed_array_exhaustive!(LongArray, u32, 1);
    m_test_fixed_array_exhaustive!(LongArray, u32, 2);
    m_test_fixed_array_exhaustive!(LongArray, u32, 3);
    m_test_fixed_array_exhaustive!(LongArray, u64, 15);
    m_test_fixed_array_exhaustive!(LongArray, u8, 33);

    m_test_fixed_array_exhaustive!(FixedArray, u32, 1);
    m_test_fixed_array_exhaustive!(FixedArray, u32, 2);
    m_test_fixed_array_exhaustive!(FixedArray, u32, 3);
    m_test_fixed_array_exhaustive!(FixedArray, u64, 15);
    m_test_fixed_array_exhaustive!(FixedArray, u8, 33);
    m_test_fixed_array_exhaustive!(FixedArray, u64, 200);
  }

  #[test]
  fn test_dynamic_array() {
    // m_test_dynamic_array_exhaustive!(DynamicArray, u32, 1);
    m_test_dynamic_array_exhaustive!(DynamicArray, u32, 2);
    m_test_dynamic_array_exhaustive!(DynamicArray, u32, 3);
    m_test_dynamic_array_exhaustive!(DynamicArray, u64, 15);
    m_test_dynamic_array_exhaustive!(DynamicArray, u8, 33);
    m_test_dynamic_array_exhaustive!(DynamicArray, u64, 200);
  }
}
456}