rostl_datastructures/
heap.rs

1//! Implements [path oblivious heap](https://eprint.iacr.org/2019/274).
2use bytemuck::{Pod, Zeroable};
3use rand::{rngs::ThreadRng, Rng};
4use rostl_oram::{
5  circuit_oram::{remove_element, write_block_to_empty_slot, Block, CircuitORAM, S, Z},
6  heap_tree::HeapTree,
7  prelude::{PositionType, K},
8};
9use rostl_primitives::traits::{Cmov, _Cmovbase};
10use rostl_primitives::{cmov_body, cxchg_body, impl_cmov_for_generic_pod};
11
12#[derive(Clone, Copy, Debug, Zeroable)]
13#[repr(C)]
14/// A logical heap element.
15pub struct HeapElement<V>
16where
17  V: Cmov + Pod,
18{
19  /// The key associated with the heap element.
20  pub key: K,
21  /// The value associated with the heap element.
22  pub value: V,
23}
24unsafe impl<V: Cmov + Pod> Pod for HeapElement<V> {}
25impl_cmov_for_generic_pod!(HeapElement<V>; where V: Cmov + Pod);
26impl<V: Cmov + Pod> Default for HeapElement<V> {
27  fn default() -> Self {
28    Self { key: K::MAX, value: V::zeroed() }
29  }
30}
31
32#[derive(Debug)]
33/// An oblivious heap.
34/// Elements are stored in an ORAM, along with information about the location of the minimum element in each subtree.
35/// # Invariants
36/// * `metadata` stores the minimum element in each subtree.
37/// * Heap elements are stored in a non-recursive ORAM.
38/// * After insertion, the oram key (timestamp) and path for an element do not change.
39pub struct Heap<V>
40where
41  V: Cmov + Pod,
42{
43  /// The heap elements.
44  pub data: CircuitORAM<HeapElement<V>>,
45  /// The metadata tree used for storing the element with minimum key in the subtree.
46  pub metadata: HeapTree<Block<HeapElement<V>>>,
47  /// Thread local rng.
48  pub rng: ThreadRng,
49  /// maximum size of the heap.
50  pub max_size: usize,
51  /// timestamp: usize,
52  pub timestamp: K,
53}
54
55impl<V> Heap<V>
56where
57  V: Cmov + Pod + std::fmt::Debug,
58{
59  /// Creates a Heap with maximum size of `n` elements.
60  pub fn new(n: usize) -> Self {
61    let data = CircuitORAM::new(n);
62    let default_value = Block::<HeapElement<V>>::default();
63    let metadata = HeapTree::new_with(data.h, default_value);
64    Self { data, metadata, rng: rand::rng(), max_size: n, timestamp: 0 }
65  }
66
67  /// Finds the minimum element in the heap.
68  /// # Returns
69  /// * The minimum element in the heap. => if the heap is non-empty.
70  /// * A `HeapElement<V>` with `pos = DUMMY` => if the heap is empty.
71  pub fn find_min(&self) -> Block<HeapElement<V>> {
72    let mut min_node = *self.metadata.get_path_at_depth(0, 0);
73
74    for elem in &self.data.stash[0..S] {
75      let should_mov = (!elem.is_empty()) & (elem.value.key < min_node.value.key);
76      min_node.cmov(elem, should_mov);
77    }
78
79    min_node
80  }
81
82  fn evict(&mut self, pos: PositionType) {
83    self.data.read_path_and_get_nodes(pos);
84    self.data.evict_once_fast(pos);
85    self.data.write_back_path(pos);
86  }
87
88  /// Prints the heap for debugging purposes.
89  #[cfg(test)]
90  pub fn print_for_debug(&self) {
91    let data = &self.data;
92    println!("Stash: {:?}", data.stash);
93    for i in 0..data.h {
94      print!("Level {i}: ");
95      for j in 0..(1 << i) {
96        print!("{} ", j << (data.h - 1 - i));
97        print!("data.h:{} ", data.h);
98        print!(
99          "{:?} ",
100          data.tree.get_path_at_depth(
101            i,
102            ((j << (data.h - 1 - i)) as u32).reverse_bits() >> (32 - data.h + 1)
103          )
104        );
105      }
106      println!();
107    }
108  }
109
110  // Updates the metadata for the minimum element along a path `pos`.
111  // # Preconditions:
112  // * The path is already loaded into the stash.
113  // * All the metadata except for this path is correct.
114  fn update_min(&mut self, pos: PositionType) {
115    let data = &self.data;
116    let mut h_index = self.metadata.height;
117    let metadata = &mut self.metadata;
118
119    let mut curr_min = Block::<HeapElement<V>>::default();
120    curr_min.value.key = K::MAX;
121
122    for elems in data.stash[S..(S + self.data.h * Z)].chunks(2).rev() {
123      for elem in elems {
124        let should_mov = (!elem.is_empty()) & (elem.value.key < curr_min.value.key);
125        curr_min.cmov(elem, should_mov);
126      }
127
128      if h_index != metadata.height {
129        let sibling = metadata.get_sibling(h_index, pos);
130
131        let should_mov = (!sibling.is_empty()) & (sibling.value.key < curr_min.value.key);
132        curr_min.cmov(sibling, should_mov);
133      }
134
135      *metadata.get_path_at_depth_mut(h_index - 1, pos) = curr_min;
136
137      h_index -= 1;
138    }
139  }
140
141  /// Inserts a new element `value` with priority `key` into the heap.
142  /// # Returns
143  /// * the position and timestamp of the inserted element.
144  pub fn insert(&mut self, key: K, value: V) -> (PositionType, K) {
145    let new_pos = self.rng.random_range(0..self.data.max_n as PositionType);
146    let oram_key: K = self.timestamp;
147    self.timestamp += 1;
148    let heap_value = HeapElement::<V> { key, value };
149
150    write_block_to_empty_slot(
151      &mut self.data.stash[..S],
152      &Block::<HeapElement<V>> { pos: new_pos, key: oram_key, value: heap_value },
153    );
154
155    for _ in 0..2 {
156      let pos_to_evict = self.rng.random_range(0..self.data.max_n as PositionType);
157      self.evict(pos_to_evict);
158      self.update_min(pos_to_evict);
159    }
160
161    (new_pos, oram_key)
162  }
163
164  /// Deletes an element from the heap given it's timestamp and path.
165  /// # Behavior
166  /// * If the element is not in the heap, nothing happens.
167  pub fn delete(&mut self, pos: PositionType, timestamp: K) {
168    self.data.read_path_and_get_nodes(pos);
169    remove_element(&mut self.data.stash, timestamp);
170    self.data.evict_once_fast(pos);
171    self.data.write_back_path(pos);
172    self.update_min(pos);
173
174    let pos_to_evict = self.rng.random_range(0..self.data.max_n as PositionType);
175    self.evict(pos_to_evict);
176    self.update_min(pos_to_evict);
177  }
178
179  /// Find and delete the minimum element from the heap.
180  pub fn extract_min(&mut self) {
181    let to_delete = self.find_min();
182    self.delete(to_delete.pos, to_delete.key);
183  }
184}
185
#[cfg(test)]
mod tests {
  use std::{cmp::Reverse, collections::BinaryHeap};

  use super::*;

  #[test]
  fn test_insert_and_find_min() {
    let mut heap = Heap::new(4);

    heap.insert(10, 100);
    heap.insert(5, 50);
    heap.insert(20, 200);

    let min_element = heap.find_min();

    assert_eq!(min_element.value.key, 5);
    assert_eq!(min_element.value.value, 50);
  }

  #[test]
  fn test_insert_and_extract_min() {
    let mut heap = Heap::new(4);

    heap.insert(30, 300);
    heap.insert(10, 100);
    heap.insert(20, 200);

    let min_element = heap.find_min();
    assert_eq!(min_element.value.key, 10);
    assert_eq!(min_element.value.value, 100);

    heap.extract_min();

    let new_min_element = heap.find_min();
    assert_eq!(new_min_element.value.key, 20);
    assert_eq!(new_min_element.value.value, 200);
  }

  #[test]
  fn test_delete() {
    let mut heap = Heap::new(4);

    let _location = heap.insert(15, 150);

    let min_element = heap.find_min();

    heap.delete(min_element.pos, min_element.key);

    let min_element_after_delete = heap.find_min();
    assert!(min_element_after_delete.is_empty())
  }

  /// Deletes elements using the `(pos, timestamp)` pair returned by `insert`: a non-minimum
  /// element, the minimum, and finally an element that is no longer present (a no-op).
  #[test]
  fn test_delete_by_insert_location() {
    let mut heap = Heap::new(8);

    let (pos_small, ts_small) = heap.insert(5, 50);
    heap.insert(10, 100);
    let (pos_large, ts_large) = heap.insert(20, 200);

    // Removing a non-minimum element leaves the minimum in place.
    heap.delete(pos_large, ts_large);
    let min_element = heap.find_min();
    assert_eq!(min_element.value.key, 5);
    assert_eq!(min_element.value.value, 50);

    // Removing the minimum exposes the next-smallest element.
    heap.delete(pos_small, ts_small);
    assert_eq!(heap.find_min().value.key, 10);

    // Deleting an element that is no longer in the heap changes nothing.
    heap.delete(pos_small, ts_small);
    let min_element = heap.find_min();
    assert_eq!(min_element.value.key, 10);
    assert_eq!(min_element.value.value, 100);
  }

  #[test]
  fn test_multiple_inserts_and_extracts() {
    let mut heap = Heap::new(8);
    let keys: Vec<K> = (1..=8).collect();

    for &key in keys.iter().rev() {
      heap.insert(key, (key * 10) as u64);
    }

    // Keys are distinct, so they must come out in exactly ascending order.
    for &expected_key in &keys {
      let min_element = heap.find_min();
      assert!(!min_element.is_empty());
      assert_eq!(min_element.value.key, expected_key);
      assert_eq!(min_element.value.value, (expected_key * 10) as u64);
      heap.extract_min();
    }

    assert!(heap.find_min().is_empty(), "Heap should be empty after extracting every element");
  }

  /// Random inserts and extract-mins checked against `std::collections::BinaryHeap`.
  #[test]
  fn test_stress_with_many_operations() {
    let mut heap = Heap::new(32); // Larger heap for stress test
    let mut reference_heap = BinaryHeap::new();
    let operations = 100;
    let mut rng = rand::rng();

    for _ in 0..operations {
      match rng.random_range(0..2) {
        0 => {
          // Insert
          let key = rng.random_range(0..1000);
          let value = key as u64 * 10;
          heap.insert(key, value);
          reference_heap.push(Reverse((key, value)));
        }
        1 => {
          // Extract min (only when the reference says the heap is non-empty)
          if let Some(Reverse((reference_min_key, reference_min_val))) = reference_heap.pop() {
            let min_element = heap.find_min();
            heap.extract_min();

            assert_eq!(
              min_element.value.key, reference_min_key,
              "Extract min returned incorrect key"
            );
            assert_eq!(
              min_element.value.value, reference_min_val,
              "Extract min returned incorrect value"
            );
          }
        }
        _ => unreachable!(),
      }

      // After every operation the heap's minimum (or emptiness) must match the reference.
      let min_element = heap.find_min();
      match reference_heap.peek() {
        Some(Reverse((reference_min_key, _))) => assert_eq!(
          min_element.value.key, *reference_min_key,
          "Heap minimum doesn't match reference after operation"
        ),
        None => assert!(
          min_element.is_empty(),
          "Heap should be empty when the reference heap is empty"
        ),
      }
    }
  }
}