rostl_datastructures/
queue.rs

1//! This module implements oblivious queues
2#![allow(clippy::needless_bitwise_bool)] // UNDONE(git-8): This is needed to enforce the bitwise operations to not short circuit. Investigate if we should be using helper functions instead.
3
4use bytemuck::{Pod, Zeroable};
5use rostl_primitives::{
6  cmov_body, cxchg_body, impl_cmov_for_generic_pod, indexable::Length, traits::Cmov,
7  traits::_Cmovbase,
8};
9
10use crate::array::ShortArray;
11
12/// An element in a short queue.
13/// See `ShortQueue` for more details.
14/// # Invariant
15/// * `timestamp` == 0 ==> `value` is not valid
16/// * `timestamp` != 0 ==> `value` is valid and in the queue
17/// * `timestamp` is unique for each enqueued element and in the range `[lowest_timestamp, highest_timestamp]`
18#[repr(C)]
19#[derive(Debug, Default, Clone, Copy, Zeroable)]
20pub struct ShortQueueElement<T>
21where
22  T: Cmov + Pod,
23{
24  timestamp: usize,
25  pub(crate) value: T,
26}
27unsafe impl<T: Cmov + Pod> Pod for ShortQueueElement<T> {}
28impl_cmov_for_generic_pod!(ShortQueueElement<T>; where T: Cmov + Pod);
29
30impl<T> ShortQueueElement<T>
31where
32  T: Cmov + Pod,
33{
34  /// Returns true if the element is empty.
35  pub const fn is_empty(&self) -> bool {
36    self.timestamp == 0
37  }
38}
39
40/// Implements a queue with a fixed maximum size.
41/// The queue access pattern and size are oblivious.
42///
43/// There are two trivial efficient ways to implement this for short queues:
44/// 1. Use oblivious compaction:
45///  - Push: n (2 + log n)
46///  - Pop: n
47///  - Iter: n
48/// 2. Use timestamps:
49///  - Push: n
50///  - Pop: n
51///  - Iter: n (1 + log^2 n)
52///
53/// This implementation uses timestamps (2.), as we only need push and pop for the unsorted map.
54/// # Invariants
55/// * `highest_timestamp` is the timestamp of the most recently added element
56/// * `lowest_timestamp` is the timestamp of the oldest element added, or just non-zero if the queue is empty
57/// * `size` is the number of `elements` with non-zero timestamps
58/// * an element in `elements` is valid if its timestamp is non-zero, in which case the timestamp is unique and in the range `[lowest_timestamp, highest_timestamp]`
59#[derive(Debug)]
60pub struct ShortQueue<T, const N: usize>
61where
62  T: Cmov + Pod,
63{
64  // The timestamp of the most recently added element
65  highest_timestamp: usize,
66  // The timestamp of the oldest element added
67  lowest_timestamp: usize,
68  // Number of elements in the queue
69  pub(crate) size: usize,
70  // The array that stores the elements and their timestamps
71  pub(crate) elements: ShortArray<ShortQueueElement<T>, N>,
72}
73
74impl<T, const N: usize> ShortQueue<T, N>
75where
76  T: Cmov + Pod + Default,
77{
78  /// Creates a new empty `ShortQueue` with maximum size `N`.
79  pub fn new() -> Self {
80    Self { highest_timestamp: 0, lowest_timestamp: 1, size: 0, elements: ShortArray::new() }
81  }
82
83  /// Pushes `element` into the queue if `real` is true.
84  pub fn maybe_push(&mut self, real: bool, element: T) {
85    debug_assert!(!real | (self.size < N));
86
87    self.size.cmov(&(self.size + 1), real);
88    self.highest_timestamp.cmov(&(self.highest_timestamp + 1), real);
89    let mut inserted = !real;
90    let mut lowest_timestamp = self.highest_timestamp;
91    for i in 0..self.elements.len() {
92      let curr = &mut self.elements.data[i];
93      let is_empty = curr.is_empty();
94      let should_insert = !inserted & is_empty;
95      let is_lowest_timemstamp = !is_empty & (curr.timestamp < lowest_timestamp);
96      curr.timestamp.cmov(&self.highest_timestamp, should_insert);
97      curr.value.cmov(&element, should_insert);
98      lowest_timestamp.cmov(&curr.timestamp, is_lowest_timemstamp);
99      inserted |= should_insert;
100    }
101
102    debug_assert!(inserted);
103
104    self.lowest_timestamp.cmov(&lowest_timestamp, real);
105  }
106
107  /// Pops an element from the queue into `out` if `real` is true.
108  pub fn maybe_pop(&mut self, real: bool, out: &mut T) {
109    debug_assert!(!real | (self.size > 0));
110
111    self.size.cmov(&(self.size.wrapping_sub(1)), real);
112    let mut second_lowest_timestamp = self.highest_timestamp;
113    for i in 0..self.elements.len() {
114      let curr = &mut self.elements.data[i];
115      let is_lowest = curr.timestamp == self.lowest_timestamp;
116      let could_be_second_lowest =
117        !curr.is_empty() & !is_lowest & (curr.timestamp < second_lowest_timestamp);
118      let should_pop = real & is_lowest;
119      second_lowest_timestamp.cmov(&curr.timestamp, could_be_second_lowest);
120      out.cmov(&curr.value, should_pop);
121      curr.timestamp.cmov(&0, should_pop);
122    }
123    self.lowest_timestamp.cmov(&second_lowest_timestamp, real);
124  }
125}
126
127impl<T, const N: usize> Length for ShortQueue<T, N>
128where
129  T: Cmov + Pod + Default,
130{
131  fn len(&self) -> usize {
132    self.size
133  }
134}
135
136impl<T, const N: usize> Default for ShortQueue<T, N>
137where
138  T: Cmov + Pod + Default,
139{
140  fn default() -> Self {
141    Self::new()
142  }
143}
144
145// UNDONE(git-36): Implement ShortStack and LongStack using CircuitORAM
146// UNDONE(git-37): Implement LongQueue using CircuitORAM
147
148// UNDONE(git-39): Benchmark ShortQueue
149
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_short_queue() {
    let mut queue: ShortQueue<u32, 3> = ShortQueue::new();
    assert_eq!(queue.len(), 0);
    queue.maybe_push(true, 1); // ==> [1]
    assert_eq!(queue.len(), 1);

    queue.maybe_push(true, 2); // ==> [1, 2]
    assert_eq!(queue.len(), 2);

    queue.maybe_push(false, 42);
    assert_eq!(queue.len(), 2);

    queue.maybe_push(true, 3); // ==> [1, 2, 3]
    assert_eq!(queue.len(), 3);

    queue.maybe_push(false, 4);
    assert_eq!(queue.len(), 3);

    let mut out = 0;
    queue.maybe_pop(true, &mut out); // ==> [2, 3]
    assert_eq!(out, 1);
    assert_eq!(queue.len(), 2);

    queue.maybe_pop(true, &mut out); // ==> [3]
    assert_eq!(out, 2);
    assert_eq!(queue.len(), 1);

    queue.maybe_pop(false, &mut out);
    assert_eq!(queue.len(), 1);

    queue.maybe_pop(true, &mut out); // ==> []
    assert_eq!(out, 3);
    assert_eq!(queue.len(), 0);

    queue.maybe_pop(false, &mut out);
    assert_eq!(queue.len(), 0);
  }

  /// Interleaving pushes and pops exercises slot reuse (a freed slot is refilled before later
  /// free slots) and the recomputation of the oldest timestamp on both push and pop.
  #[test]
  fn test_short_queue_interleaved_push_pop() {
    let mut queue: ShortQueue<u32, 3> = ShortQueue::new();
    let mut out = 0;

    queue.maybe_push(true, 1);
    queue.maybe_push(true, 2); // ==> [1, 2]
    queue.maybe_pop(true, &mut out); // ==> [2]
    assert_eq!(out, 1);

    queue.maybe_push(true, 3); // reuses the slot freed by 1 ==> [2, 3]
    queue.maybe_push(true, 4); // ==> [2, 3, 4]
    assert_eq!(queue.len(), 3);

    for expected in [2, 3, 4] {
      queue.maybe_pop(true, &mut out);
      assert_eq!(out, expected);
    }
    assert_eq!(queue.len(), 0);

    // The queue must be fully usable again after being drained.
    queue.maybe_push(true, 5);
    queue.maybe_pop(true, &mut out);
    assert_eq!(out, 5);
    assert_eq!(queue.len(), 0);
  }

  /// A dummy pop must not touch `out` or the queue contents.
  #[test]
  fn test_short_queue_dummy_pop_leaves_out_untouched() {
    let mut queue: ShortQueue<u32, 2> = ShortQueue::new();
    queue.maybe_push(true, 7);
    let mut out = 99;
    queue.maybe_pop(false, &mut out);
    assert_eq!(out, 99);
    assert_eq!(queue.len(), 1);
    queue.maybe_pop(true, &mut out);
    assert_eq!(out, 7);
    assert_eq!(queue.len(), 0);
  }

  /// Timestamps keep growing across fill/drain rounds; FIFO order must hold in every round.
  #[test]
  fn test_short_queue_many_rounds() {
    let mut queue: ShortQueue<u32, 4> = ShortQueue::default();
    let mut out = 0;
    for round in 0..5u32 {
      for k in 0..4u32 {
        queue.maybe_push(true, round * 10 + k);
      }
      assert_eq!(queue.len(), 4);
      for k in 0..4u32 {
        queue.maybe_pop(true, &mut out);
        assert_eq!(out, round * 10 + k);
      }
      assert_eq!(queue.len(), 0);
    }
  }
}