rostl_datastructures/
stack.rs

1//! This module implements an oblivious stack
2// The stack is implemented as a linked list on top of NRORAM.
3
4use bytemuck::{Pod, Zeroable};
5use rand::{rngs::ThreadRng, Rng};
6use rostl_oram::{
7  circuit_oram::CircuitORAM,
8  prelude::{PositionType, DUMMY_POS},
9};
10use rostl_primitives::{
11  cmov_body, cxchg_body, impl_cmov_for_generic_pod, indexable::Length, traits::Cmov,
12  traits::_Cmovbase,
13};
14
15#[repr(align(8))]
16#[derive(Debug, Default, Clone, Copy, Zeroable)]
17struct StackElement<T>
18where
19  T: Cmov + Pod,
20{
21  value: T,
22  next: PositionType,
23}
24unsafe impl<T: Cmov + Pod> Pod for StackElement<T> {}
25impl_cmov_for_generic_pod!(StackElement<T>; where T: Cmov + Pod);
26
27/// Implements a stack with a fixed maximum size.
28/// The stack access pattern and size are oblivious.
29/// The stack is implemented as a linked list on top of NRORAM.
30/// # Invariants
31/// * 1) The linked list of elements is in monotonic decreasing order of Ids.
32#[derive(Debug)]
33pub struct Stack<T>
34where
35  T: Cmov + Pod,
36{
37  oram: CircuitORAM<StackElement<T>>,
38  top: PositionType,
39  size: usize,
40  rng: ThreadRng,
41}
42
43impl<T> Stack<T>
44where
45  T: Cmov + Pod + Default + Clone + std::fmt::Debug,
46{
47  /// Creates a new stack.
48  pub fn new(max_size: usize) -> Self {
49    Self { oram: CircuitORAM::new(max_size), top: DUMMY_POS, size: 0, rng: rand::rng() }
50  }
51
52  /// Pushes a new element on the stack if `real` is true.
53  /// If `real` is false, the element is not pushed and the stack size is not incremented.
54  pub fn maybe_push(&mut self, real: bool, value: T) {
55    debug_assert!(!real || self.size < self.oram.max_n);
56
57    let new_id = self.size + 1; // inv1
58    let read_pos = self.rng.random_range(0..self.oram.max_n as PositionType);
59
60    let mut new_pos = self.rng.random_range(0..self.oram.max_n as PositionType);
61    new_pos.cmov(&DUMMY_POS, !real); // if not real, new_pos is DUMMY_POS, oram will ignore the write
62
63    let wv = StackElement { value, next: self.top };
64
65    let _found = self.oram.write_or_insert(read_pos, new_pos, new_id, wv);
66    debug_assert!(!_found);
67
68    self.top.cmov(&new_pos, real); // if real, top is new_pos
69    self.size.cmov(&(self.size + 1), real);
70  }
71
72  /// Pops the top element from the stack if `real` is true.
73  /// The popped element is returned in `out`.
74  /// The stack size is decremented by 1 if `real` is true.
75  /// If `real` is false, the element is not popped, the stack size is not decremented, and `out` is not modified.
76  pub fn maybe_pop(&mut self, real: bool, out: &mut T) {
77    debug_assert!(!real || self.size > 0);
78
79    let target_id = self.size; // inv1 - the position of the top of the stack is the size of the stack.
80    let mut read_pos = self.rng.random_range(0..self.oram.max_n as PositionType);
81    read_pos.cmov(&self.top, real);
82    let mut new_pos = read_pos;
83    new_pos.cmov(&DUMMY_POS, real); // if real, we should delete the top element. if not real, we should not change the read element.
84
85    let mut imse = StackElement::default();
86
87    self.oram.read(read_pos, read_pos, target_id, &mut imse);
88
89    out.cmov(&imse.value, real);
90    self.top.cmov(&imse.next, real);
91    self.size.cmov(&(self.size - 1), real);
92  }
93}
94
95impl<T> Length for Stack<T>
96where
97  T: Cmov + Pod,
98{
99  fn len(&self) -> usize {
100    self.size
101  }
102}
103
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_stack() {
    let mut stack = Stack::<u32>::new(10);

    // Real pushes grow the stack one element at a time.
    for (value, expected_len) in [(100, 1), (222, 2), (3333, 3)] {
      stack.maybe_push(true, value);
      assert_eq!(stack.len(), expected_len);
    }

    // A dummy push must leave the stack unchanged.
    stack.maybe_push(false, 123214);
    assert_eq!(stack.len(), 3);

    // Real pops hand the elements back in LIFO order.
    let mut out = 0;
    for (expected_out, expected_len) in [(3333, 2), (222, 1)] {
      stack.maybe_pop(true, &mut out);
      assert_eq!(stack.len(), expected_len);
      assert_eq!(out, expected_out);
    }

    // A dummy pop must neither shrink the stack nor overwrite `out`.
    out = 123;
    stack.maybe_pop(false, &mut out);
    assert_eq!(stack.len(), 1);
    assert_eq!(out, 123);

    // The final real pop empties the stack.
    stack.maybe_pop(true, &mut out);
    assert_eq!(stack.len(), 0);
    assert_eq!(out, 100);
  }

  // UNDONE(git-61): Benchmark Stack.
}