1use bytemuck::{Pod, Zeroable};
3use rand::{rngs::ThreadRng, Rng};
4use rostl_oram::{
5 circuit_oram::{remove_element, write_block_to_empty_slot, Block, CircuitORAM, S, Z},
6 heap_tree::HeapTree,
7 prelude::{PositionType, K},
8};
9use rostl_primitives::traits::{Cmov, _Cmovbase};
10use rostl_primitives::{cmov_body, cxchg_body, impl_cmov_for_generic_pod};
11
/// A single heap entry: the priority `key` used for ordering plus an opaque payload `value`.
///
/// This is the value type stored inside the ORAM blocks, which is why it must be `Pod` and
/// support `Cmov`.
#[derive(Clone, Copy, Debug, Zeroable)]
#[repr(C)]
pub struct HeapElement<V>
where
    V: Cmov + Pod,
{
    /// Priority of the element; smaller keys are returned first by the heap's minimum search.
    /// `K::MAX` is what the `Default` impl uses for a placeholder element.
    pub key: K,
    /// Payload carried alongside the key.
    pub value: V,
}
// SAFETY: `HeapElement` is `#[repr(C)]`, derives `Zeroable`, and its fields are `Pod`
// (assuming `K` is a primitive integer type — TODO confirm). `Pod` additionally requires that
// the type contains no padding bytes, i.e.
// `size_of::<HeapElement<V>>() == size_of::<K>() + size_of::<V>()`.
// NOTE(review): nothing enforces that last condition for arbitrary `V` — a `V` that is smaller
// than, or not a multiple of, `K`'s alignment introduces trailing/interior padding and makes
// this impl unsound. Consider a compile-time size check or restricting the allowed `V` types.
unsafe impl<V: Cmov + Pod> Pod for HeapElement<V> {}
// Provides `Cmov` for `HeapElement<V>` via the crate's generic-Pod helper macro.
// NOTE(review): presumably implemented over the raw bytes of the value — confirm in
// `rostl_primitives`.
impl_cmov_for_generic_pod!(HeapElement<V>; where V: Cmov + Pod);
26impl<V: Cmov + Pod> Default for HeapElement<V> {
27 fn default() -> Self {
28 Self { key: K::MAX, value: V::zeroed() }
29 }
30}
31
/// A min-heap whose elements are stored inside a Circuit ORAM.
///
/// Every element lives in `data` as a block tagged with a random leaf position and a unique
/// ORAM key (its insertion timestamp). A second tree, `metadata`, caches for every node the
/// block with the smallest heap key found in that node's subtree, so the overall minimum is the
/// metadata root combined with a scan of the ORAM stash.
#[derive(Debug)]
pub struct Heap<V>
where
    V: Cmov + Pod,
{
    /// Underlying Circuit ORAM holding the heap elements.
    pub data: CircuitORAM<HeapElement<V>>,
    /// Per-node cache of the minimum block in the corresponding subtree of `data.tree`.
    /// Built with the same height as the ORAM tree (`data.h`); refreshed along a path by
    /// `update_min` after each path eviction.
    pub metadata: HeapTree<Block<HeapElement<V>>>,
    /// Source of the random leaf positions for new elements and of random eviction paths.
    pub rng: ThreadRng,
    /// The capacity `n` passed to `new`.
    /// NOTE(review): not read anywhere in this impl — positions are drawn from `data.max_n`.
    pub max_size: usize,
    /// ORAM key to give the next inserted element; incremented by every `insert`, so each
    /// element gets a distinct ORAM key that `delete` later uses to identify it.
    pub timestamp: K,
}
54
55impl<V> Heap<V>
56where
57 V: Cmov + Pod + std::fmt::Debug,
58{
59 pub fn new(n: usize) -> Self {
61 let data = CircuitORAM::new(n);
62 let default_value = Block::<HeapElement<V>>::default();
63 let metadata = HeapTree::new_with(data.h, default_value);
64 Self { data, metadata, rng: rand::rng(), max_size: n, timestamp: 0 }
65 }
66
67 pub fn find_min(&self) -> Block<HeapElement<V>> {
72 let mut min_node = *self.metadata.get_path_at_depth(0, 0);
73
74 for elem in &self.data.stash[0..S] {
75 let should_mov = (!elem.is_empty()) & (elem.value.key < min_node.value.key);
76 min_node.cmov(elem, should_mov);
77 }
78
79 min_node
80 }
81
82 fn evict(&mut self, pos: PositionType) {
83 self.data.read_path_and_get_nodes(pos);
84 self.data.evict_once_fast(pos);
85 self.data.write_back_path(pos);
86 }
87
88 #[cfg(test)]
90 pub fn print_for_debug(&self) {
91 let data = &self.data;
92 println!("Stash: {:?}", data.stash);
93 for i in 0..data.h {
94 print!("Level {i}: ");
95 for j in 0..(1 << i) {
96 print!("{} ", j << (data.h - 1 - i));
97 print!("data.h:{} ", data.h);
98 print!(
99 "{:?} ",
100 data.tree.get_path_at_depth(
101 i,
102 ((j << (data.h - 1 - i)) as u32).reverse_bits() >> (32 - data.h + 1)
103 )
104 );
105 }
106 println!();
107 }
108 }
109
110 fn update_min(&mut self, pos: PositionType) {
115 let data = &self.data;
116 let mut h_index = self.metadata.height;
117 let metadata = &mut self.metadata;
118
119 let mut curr_min = Block::<HeapElement<V>>::default();
120 curr_min.value.key = K::MAX;
121
122 for elems in data.stash[S..(S + self.data.h * Z)].chunks(2).rev() {
123 for elem in elems {
124 let should_mov = (!elem.is_empty()) & (elem.value.key < curr_min.value.key);
125 curr_min.cmov(elem, should_mov);
126 }
127
128 if h_index != metadata.height {
129 let sibling = metadata.get_sibling(h_index, pos);
130
131 let should_mov = (!sibling.is_empty()) & (sibling.value.key < curr_min.value.key);
132 curr_min.cmov(sibling, should_mov);
133 }
134
135 *metadata.get_path_at_depth_mut(h_index - 1, pos) = curr_min;
136
137 h_index -= 1;
138 }
139 }
140
141 pub fn insert(&mut self, key: K, value: V) -> (PositionType, K) {
145 let new_pos = self.rng.random_range(0..self.data.max_n as PositionType);
146 let oram_key: K = self.timestamp;
147 self.timestamp += 1;
148 let heap_value = HeapElement::<V> { key, value };
149
150 write_block_to_empty_slot(
151 &mut self.data.stash[..S],
152 &Block::<HeapElement<V>> { pos: new_pos, key: oram_key, value: heap_value },
153 );
154
155 for _ in 0..2 {
156 let pos_to_evict = self.rng.random_range(0..self.data.max_n as PositionType);
157 self.evict(pos_to_evict);
158 self.update_min(pos_to_evict);
159 }
160
161 (new_pos, oram_key)
162 }
163
164 pub fn delete(&mut self, pos: PositionType, timestamp: K) {
168 self.data.read_path_and_get_nodes(pos);
169 remove_element(&mut self.data.stash, timestamp);
170 self.data.evict_once_fast(pos);
171 self.data.write_back_path(pos);
172 self.update_min(pos);
173
174 let pos_to_evict = self.rng.random_range(0..self.data.max_n as PositionType);
175 self.evict(pos_to_evict);
176 self.update_min(pos_to_evict);
177 }
178
179 pub fn extract_min(&mut self) {
181 let to_delete = self.find_min();
182 self.delete(to_delete.pos, to_delete.key);
183 }
184}
185
#[cfg(test)]
mod tests {
    use std::{cmp::Reverse, collections::BinaryHeap};

    use super::*;

    #[test]
    fn test_insert_and_find_min() {
        let mut heap = Heap::new(4);

        heap.insert(10, 100);
        heap.insert(5, 50);
        heap.insert(20, 200);

        let min_element = heap.find_min();

        assert_eq!(min_element.value.key, 5);
        assert_eq!(min_element.value.value, 50);
    }

    #[test]
    fn test_insert_and_extract_min() {
        let mut heap = Heap::new(4);

        heap.insert(30, 300);
        heap.insert(10, 100);
        heap.insert(20, 200);

        let min_element = heap.find_min();
        assert_eq!(min_element.value.key, 10);
        assert_eq!(min_element.value.value, 100);

        heap.extract_min();

        let new_min_element = heap.find_min();
        assert_eq!(new_min_element.value.key, 20);
        assert_eq!(new_min_element.value.value, 200);
    }

    #[test]
    fn test_delete() {
        let mut heap = Heap::new(4);

        let _location = heap.insert(15, 150);

        let min_element = heap.find_min();

        heap.delete(min_element.pos, min_element.key);

        let min_element_after_delete = heap.find_min();
        assert!(min_element_after_delete.is_empty())
    }

    /// Deleting a non-minimum element through the handle returned by `insert` must leave the
    /// remaining elements intact and correctly ordered.
    #[test]
    fn test_delete_by_handle() {
        let mut heap = Heap::new(8);

        heap.insert(30, 300_u64);
        heap.insert(10, 100);
        let (pos, timestamp) = heap.insert(20, 200);

        heap.delete(pos, timestamp);

        let min_element = heap.find_min();
        assert_eq!(min_element.value.key, 10);
        assert_eq!(min_element.value.value, 100);

        heap.extract_min();
        let min_element = heap.find_min();
        assert_eq!(min_element.value.key, 30);
        assert_eq!(min_element.value.value, 300);

        heap.extract_min();
        assert!(heap.find_min().is_empty());
    }

    #[test]
    fn test_multiple_inserts_and_extracts() {
        let mut heap = Heap::new(8);

        for i in (1..=8).rev() {
            heap.insert(i, (i * 10) as u64);
        }

        let mut last_key = 0;
        for _ in 0..8 {
            let min_element = heap.find_min();
            assert!(!min_element.is_empty());
            assert!(min_element.value.key >= last_key);
            last_key = min_element.value.key;
            heap.extract_min();
        }
    }

    /// Random interleaving of inserts and extract-mins, checked against `std`'s `BinaryHeap`.
    #[test]
    fn test_stress_with_many_operations() {
        let mut heap = Heap::new(32);
        let mut reference_heap = BinaryHeap::new();
        let operations = 100;

        for _ in 0..operations {
            let op = rand::rng().random_range(0..2);
            match op {
                // Insert a random key.
                0 => {
                    let key = rand::rng().random_range(0..1000);
                    let value = key as u64 * 10;
                    heap.insert(key, value);
                    reference_heap.push(Reverse((key, value)));
                }
                // Extract the minimum (if any) and compare it with the reference minimum.
                1 => {
                    if let Some(Reverse((reference_min_key, reference_min_val))) =
                        reference_heap.pop()
                    {
                        // Read the heap's minimum *before* removing it.
                        let min_element = heap.find_min();
                        heap.extract_min();

                        assert_eq!(
                            min_element.value.key, reference_min_key,
                            "Extract min returned incorrect key"
                        );
                        assert_eq!(
                            min_element.value.value, reference_min_val,
                            "Extract min returned incorrect value"
                        );
                    }
                }
                _ => unreachable!(),
            }

            // After every operation the heap minimum must agree with the reference. An empty
            // reference means nothing has been inserted (or everything was extracted).
            if let Some(Reverse((reference_min_key, _))) = reference_heap.peek() {
                let min_element = heap.find_min();
                assert_eq!(
                    min_element.value.key, *reference_min_key,
                    "Heap minimum doesn't match reference after operation"
                );
            }
        }
    }
}