use crate::{
bucket::{Bucket, PathOramBlock},
utils::{bitonic_sort_by_keys, CompleteBinaryTreeIndex, TreeIndex},
Address, BucketSize, OramBlock, OramError, StashSize,
};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
/// Number of dummy blocks appended to the stash each time `write_to_path`
/// finds that there are not enough dummy blocks left to fill every level
/// of the path being written.
const STASH_GROWTH_INCREMENT: usize = 10;
/// A stash of [`PathOramBlock`]s.
///
/// The stash is created with `path_size + overflow_size` dummy blocks and
/// grows only when `write_to_path` runs out of dummy blocks.
#[derive(Debug)]
pub struct ObliviousStash<V: OramBlock> {
/// Backing storage for every slot of the stash, dummy blocks included.
blocks: Vec<PathOramBlock<V>>,
/// Number of leading slots of `blocks` that `read_from_path` fills from
/// the buckets of a tree path; slots from this index onward are the ones
/// counted by the test-only `occupancy` helper.
path_size: StashSize,
}
impl<V: OramBlock> ObliviousStash<V> {
    /// Returns the number of block slots currently held by the stash,
    /// counting dummy blocks as well as real ones.
    fn len(&self) -> usize {
        let Self { blocks, .. } = self;
        blocks.len()
    }
}
impl<V: OramBlock> ObliviousStash<V> {
/// Creates a stash holding `path_size + overflow_size` dummy blocks.
///
/// # Errors
///
/// Returns an [`OramError`] if the total number of blocks cannot be
/// converted to `usize`.
pub fn new(path_size: StashSize, overflow_size: StashSize) -> Result<Self, OramError> {
    let total_blocks = usize::try_from(path_size + overflow_size)?;
    let blocks = vec![PathOramBlock::<V>::dummy(); total_blocks];

    Ok(Self { blocks, path_size })
}
/// Writes stash blocks back to the buckets on the tree path identified by `position`.
///
/// Every real block is first given a level on the path (deepest eligible
/// level first) or marked as overflow; dummy blocks are then used to pad
/// each level up to `Z` blocks. The stash is sorted by those assignments
/// and the leading `Z * (height + 1)` blocks are copied into
/// `physical_memory`.
///
/// # Errors
///
/// Returns an [`OramError`] if one of the integer conversions (`try_from`) fails.
///
/// # Panics
///
/// Panics if a node index on the path is out of bounds for `physical_memory`.
pub fn write_to_path<const Z: BucketSize>(
&mut self,
physical_memory: &mut [Bucket<V, Z>],
position: TreeIndex,
) -> Result<(), OramError> {
// NOTE(review): `ct_depth` presumably returns the tree level of `position`;
// it is used below as the deepest level of the path — confirm.
let height = position.ct_depth();
// One sort key per stash block; `TreeIndex::MAX` is the initial "unassigned" key.
let mut level_assignments = vec![TreeIndex::MAX; self.len()];
// Number of blocks assigned so far to each level `0..=height`.
let mut level_counts = vec![0; usize::try_from(height)? + 1];
// Phase 1: place each real block on the path, or mark it as overflow.
for (i, block) in self.blocks.iter().enumerate() {
let block_is_dummy = block.ct_is_dummy();
// Dummy blocks get a placeholder position (`1 << height`) so the computation
// below still runs; its outcome is discarded via `!block_is_dummy`.
let an_arbitrary_leaf: TreeIndex = 1 << height;
let block_position =
TreeIndex::conditional_select(&block.position, &an_arbitrary_leaf, block_is_dummy);
let mut assigned = Choice::from(0);
// Scan levels from `height` down to 0 and take the first level whose bucket
// is not full and where the block's path and `position`'s path pass through
// the same node.
for (level, count) in level_counts.iter_mut().enumerate().rev() {
let level_bucket_full: Choice = count.ct_eq(&(u64::try_from(Z)?));
let level_u64 = u64::try_from(level)?;
let level_satisfies_invariant = block_position
.ct_node_on_path(level_u64, height)
.ct_eq(&position.ct_node_on_path(level_u64, height));
let should_assign = level_satisfies_invariant
& (!level_bucket_full)
& (!block_is_dummy)
& (!assigned);
assigned |= should_assign;
// The count and the key are updated with `conditional_assign` rather than
// an `if`, so the same operations run whether or not the block is placed.
let level_count_incremented = *count + 1;
count.conditional_assign(&level_count_incremented, should_assign);
level_assignments[i].conditional_assign(&level_u64, should_assign);
}
// Anything not placed above (including every dummy block) gets key `MAX - 1`.
level_assignments[i].conditional_assign(&(TreeIndex::MAX - 1), !assigned);
}
// Phase 2: use dummy blocks to pad every level up to `Z` blocks.
let mut exists_unfilled_levels: Choice = 1.into();
let mut first_unassigned_block_index: usize = 0;
// NOTE: this loop and the `if` below branch on a `Choice` converted to `bool`,
// so the number of passes depends on whether the stash ran out of dummy blocks.
while exists_unfilled_levels.into() {
// After a resize, only the newly appended blocks are examined; earlier
// blocks are skipped.
for (i, block) in self
.blocks
.iter()
.enumerate()
.skip(first_unassigned_block_index)
{
let block_free = block.ct_is_dummy();
let mut assigned: Choice = 0.into();
// Give a dummy block to the first level (scanning from level 0) that is
// not yet full; real blocks and already-used dummies are no-ops.
for (level, count) in level_counts.iter_mut().enumerate() {
let full = count.ct_eq(&(u64::try_from(Z)?));
let no_op = assigned | full | !block_free;
level_assignments[i].conditional_assign(&(u64::try_from(level))?, !no_op);
count.conditional_assign(&(*count + 1), !no_op);
assigned |= !no_op;
}
}
// Recompute whether any level still holds fewer than `Z` blocks.
exists_unfilled_levels = 0.into();
for count in level_counts.iter() {
let full = count.ct_eq(&(u64::try_from(Z)?));
exists_unfilled_levels |= !full;
}
// Not enough dummy blocks: grow the stash (and the key vector) by
// `STASH_GROWTH_INCREMENT` dummies and make another pass over the new blocks.
if exists_unfilled_levels.into() {
first_unassigned_block_index = self.blocks.len();
self.blocks.resize(
self.blocks.len() + STASH_GROWTH_INCREMENT,
PathOramBlock::<V>::dummy(),
);
level_assignments.resize(
level_assignments.len() + STASH_GROWTH_INCREMENT,
TreeIndex::MAX,
);
log::warn!(
"Stash overflow occurred. Stash resized to {} blocks.",
self.blocks.len()
);
}
}
// NOTE(review): presumably sorts `self.blocks` in ascending order of
// `level_assignments`, so the `Z` blocks for level `d` land at indices
// `d * Z..(d + 1) * Z` — confirm against `bitonic_sort_by_keys`.
bitonic_sort_by_keys(&mut self.blocks, &mut level_assignments);
// Copy the leading blocks into the path's buckets, `Z` per level. The blocks
// are copied, not removed from the stash.
for depth in 0..=height {
let bucket_to_write =
&mut physical_memory[usize::try_from(position.ct_node_on_path(depth, height))?];
for slot_number in 0..Z {
let stash_index = (usize::try_from(depth)?) * Z + slot_number;
bucket_to_write.blocks[slot_number] = self.blocks[stash_index];
}
}
Ok(())
}
/// Scans the whole stash for the block stored at `address`, returning its
/// current value and updating it in place.
///
/// For the matching block, the position is replaced by `new_position` and
/// the value by `value_callback(&old_value)`. Every block is visited and
/// `value_callback` is invoked once per block; the results are applied with
/// conditional assignments rather than branches. If no block matches, the
/// stash is left unchanged and `V::default()` is returned.
pub fn access<F: Fn(&V) -> V>(
    &mut self,
    address: Address,
    new_position: TreeIndex,
    value_callback: F,
) -> Result<V, OramError> {
    let mut current = V::default();

    self.blocks.iter_mut().for_each(|slot| {
        let hit = slot.address.ct_eq(&address);

        // Capture the stored value before it is overwritten below.
        current.conditional_assign(&slot.value, hit);
        slot.position.conditional_assign(&new_position, hit);

        // The replacement is computed for every slot but only stored on a hit.
        let replacement = value_callback(&current);
        slot.value.conditional_assign(&replacement, hit);
    });

    Ok(current)
}
/// Counts the non-dummy blocks stored at index `path_size` and beyond.
///
/// Test-only helper; panics if `path_size` does not fit in `usize`.
#[cfg(test)]
pub fn occupancy(&self) -> StashSize {
    let overflow_start: usize = self.path_size.try_into().unwrap();

    self.blocks
        .iter()
        .skip(overflow_start)
        .fold(0, |real_blocks, block| {
            if block.is_dummy() {
                real_blocks
            } else {
                real_blocks + 1
            }
        })
}
/// Copies the buckets on the tree path identified by `position` into the
/// leading slots of the stash.
///
/// `path_size / Z` levels are read, from the highest level index down to
/// level 0; the bucket for level `l` overwrites stash slots
/// `l * Z..(l + 1) * Z`. The buckets in `physical_memory` are not modified.
///
/// # Errors
///
/// Returns an [`OramError`] if one of the integer conversions fails.
///
/// # Panics
///
/// Panics if `Z` is zero, or if a bucket or stash index is out of bounds.
pub fn read_from_path<const Z: crate::BucketSize>(
    &mut self,
    physical_memory: &mut [Bucket<V, Z>],
    position: TreeIndex,
) -> Result<(), OramError> {
    let tree_height = position.ct_depth();
    let levels_to_read = self.path_size / u64::try_from(Z)?;

    for level in (0..levels_to_read).rev() {
        let node = position.ct_node_on_path(level, tree_height);
        let source_bucket = physical_memory[usize::try_from(node)?];

        // First stash slot belonging to this level.
        let first_slot = Z * usize::try_from(level)?;
        for offset in 0..Z {
            self.blocks[first_slot + offset] = source_bucket.blocks[offset];
        }
    }

    Ok(())
}
}