rostl_datastructures/
stack.rs1use bytemuck::{Pod, Zeroable};
5use rand::{rngs::ThreadRng, Rng};
6use rostl_oram::{
7 circuit_oram::CircuitORAM,
8 prelude::{PositionType, DUMMY_POS},
9};
10use rostl_primitives::{
11 cmov_body, cxchg_body, impl_cmov_for_generic_pod, indexable::Length, traits::Cmov,
12 traits::_Cmovbase,
13};
14
15#[repr(align(8))]
16#[derive(Debug, Default, Clone, Copy, Zeroable)]
17struct StackElement<T>
18where
19 T: Cmov + Pod,
20{
21 value: T,
22 next: PositionType,
23}
24unsafe impl<T: Cmov + Pod> Pod for StackElement<T> {}
25impl_cmov_for_generic_pod!(StackElement<T>; where T: Cmov + Pod);
26
27#[derive(Debug)]
33pub struct Stack<T>
34where
35 T: Cmov + Pod,
36{
37 oram: CircuitORAM<StackElement<T>>,
38 top: PositionType,
39 size: usize,
40 rng: ThreadRng,
41}
42
43impl<T> Stack<T>
44where
45 T: Cmov + Pod + Default + Clone + std::fmt::Debug,
46{
47 pub fn new(max_size: usize) -> Self {
49 Self { oram: CircuitORAM::new(max_size), top: DUMMY_POS, size: 0, rng: rand::rng() }
50 }
51
52 pub fn maybe_push(&mut self, real: bool, value: T) {
55 debug_assert!(!real || self.size < self.oram.max_n);
56
57 let new_id = self.size + 1; let read_pos = self.rng.random_range(0..self.oram.max_n as PositionType);
59
60 let mut new_pos = self.rng.random_range(0..self.oram.max_n as PositionType);
61 new_pos.cmov(&DUMMY_POS, !real); let wv = StackElement { value, next: self.top };
64
65 let _found = self.oram.write_or_insert(read_pos, new_pos, new_id, wv);
66 debug_assert!(!_found);
67
68 self.top.cmov(&new_pos, real); self.size.cmov(&(self.size + 1), real);
70 }
71
72 pub fn maybe_pop(&mut self, real: bool, out: &mut T) {
77 debug_assert!(!real || self.size > 0);
78
79 let target_id = self.size; let mut read_pos = self.rng.random_range(0..self.oram.max_n as PositionType);
81 read_pos.cmov(&self.top, real);
82 let mut new_pos = read_pos;
83 new_pos.cmov(&DUMMY_POS, real); let mut imse = StackElement::default();
86
87 self.oram.read(read_pos, read_pos, target_id, &mut imse);
88
89 out.cmov(&imse.value, real);
90 self.top.cmov(&imse.next, real);
91 self.size.cmov(&(self.size - 1), real);
92 }
93}
94
95impl<T> Length for Stack<T>
96where
97 T: Cmov + Pod,
98{
99 fn len(&self) -> usize {
100 self.size
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixes real and dummy pushes/pops on a small stack, checking length and LIFO order.
    #[test]
    fn test_stack() {
        let mut stack = Stack::<u32>::new(10);

        // (real, value, expected length afterwards): a dummy push must not grow the stack.
        let pushes = [(true, 100, 1), (true, 222, 2), (true, 3333, 3), (false, 123214, 3)];
        for (real, value, expected_len) in pushes {
            stack.maybe_push(real, value);
            assert_eq!(stack.len(), expected_len);
        }

        // (real, value to preload into `out`, expected length, expected `out`):
        // real pops come back in LIFO order; a dummy pop leaves both `out` and the length alone.
        let pops = [
            (true, None, 2, 3333),
            (true, None, 1, 222),
            (false, Some(123), 1, 123),
            (true, None, 0, 100),
        ];
        let mut out = 0;
        for (real, preload, expected_len, expected_out) in pops {
            if let Some(v) = preload {
                out = v;
            }
            stack.maybe_pop(real, &mut out);
            assert_eq!(stack.len(), expected_len);
            assert_eq!(out, expected_out);
        }
    }
}