use super::{position_map::PositionMap, stash::ObliviousStash};
use crate::{
bucket::{Bucket, PathOramBlock, PositionBlock},
linear_time_oram::LinearTimeOram,
utils::{
invert_permutation_oblivious, random_permutation_of_0_through_n_exclusive, to_usize_vec,
CompleteBinaryTreeIndex, TreeHeight,
},
Address, BlockSize, BucketSize, Oram, OramBlock, OramError, RecursionCutoff, StashSize,
};
use rand::{CryptoRng, Rng};
/// Recursion cutoff that [`DefaultOram::new`] hands to its Path ORAM backend's position map.
pub const DEFAULT_RECURSION_CUTOFF: RecursionCutoff = 16_384;
/// Bucket size `Z` (blocks per tree bucket) used by [`DefaultOram`]'s Path ORAM backend.
pub const DEFAULT_BLOCKS_PER_BUCKET: BucketSize = 4;
/// Number of positions `AB` packed into each position-map block by [`DefaultOram`].
pub const DEFAULT_POSITIONS_PER_BLOCK: BlockSize = 8;
/// Stash overflow size used by [`DefaultOram`]'s Path ORAM backend.
pub const DEFAULT_STASH_OVERFLOW_SIZE: StashSize = 40;
/// Capacities strictly below this make [`DefaultOram::new`] choose the linear-time backend.
const LINEAR_TIME_ORAM_CUTOFF: RecursionCutoff = 1_024;
/// A Path ORAM storing blocks of type `V` in a complete binary tree of buckets.
///
/// `Z` is the number of blocks held by each bucket and `AB` is the number of positions packed
/// into each position-map block.
#[derive(Debug)]
pub struct PathOram<V: OramBlock, const Z: BucketSize, const AB: BlockSize> {
    /// Tree buckets. The constructor allocates `block_capacity` entries and places leaves at
    /// indices `2^height..2^(height + 1)`, i.e. a 1-indexed heap layout (index 0 presumably
    /// unused — confirm).
    physical_memory: Vec<Bucket<V, Z>>,
    /// Stash that a path is read into during an access and written back from afterwards.
    stash: ObliviousStash<V>,
    /// Maps each address to the leaf position currently assigned to it.
    position_map: PositionMap<AB, Z>,
    /// Height of the bucket tree; leaves are at depth `height`.
    height: TreeHeight,
}
/// An ORAM whose backend is selected by capacity in [`DefaultOram::new`]: a
/// [`LinearTimeOram`] for small capacities, otherwise a [`PathOram`] with the crate defaults.
#[derive(Debug)]
pub struct DefaultOram<V: OramBlock>(DefaultOramBackend<V>);
/// The concrete ORAM implementation wrapped by [`DefaultOram`].
#[derive(Debug)]
enum DefaultOramBackend<V: OramBlock> {
    /// Path ORAM using the default bucket size and positions-per-block constants.
    Path(PathOram<V, DEFAULT_BLOCKS_PER_BUCKET, DEFAULT_POSITIONS_PER_BLOCK>),
    /// Linear-time ORAM, chosen for capacities below `LINEAR_TIME_ORAM_CUTOFF`.
    Linear(LinearTimeOram<V>),
}
impl<V: OramBlock> Oram for DefaultOram<V> {
    type V = V;

    /// Reports the capacity of whichever backend was chosen at construction time.
    fn block_capacity(&self) -> Result<Address, OramError> {
        let Self(backend) = self;
        match backend {
            DefaultOramBackend::Linear(linear) => linear.block_capacity(),
            DefaultOramBackend::Path(path) => path.block_capacity(),
        }
    }

    /// Forwards the access to the selected backend unchanged.
    fn access<R: rand::RngCore + CryptoRng, F: Fn(&Self::V) -> Self::V>(
        &mut self,
        index: Address,
        callback: F,
        rng: &mut R,
    ) -> Result<Self::V, OramError> {
        let Self(backend) = self;
        match backend {
            DefaultOramBackend::Linear(linear) => linear.access(index, callback, rng),
            DefaultOramBackend::Path(path) => path.access(index, callback, rng),
        }
    }
}
impl<V: OramBlock> DefaultOram<V> {
    /// Builds an ORAM holding `block_capacity` blocks, choosing the backend by size.
    ///
    /// Capacities below `LINEAR_TIME_ORAM_CUTOFF` get a [`LinearTimeOram`]. Larger ones get a
    /// [`PathOram`] configured with `DEFAULT_BLOCKS_PER_BUCKET`, `DEFAULT_POSITIONS_PER_BLOCK`,
    /// `DEFAULT_STASH_OVERFLOW_SIZE` and `DEFAULT_RECURSION_CUTOFF`.
    ///
    /// # Errors
    /// Propagates any error from the selected backend's constructor. For example, the Path ORAM
    /// rejects capacities that are not powers of two.
    pub fn new<R: Rng + CryptoRng>(
        block_capacity: Address,
        rng: &mut R,
    ) -> Result<Self, OramError> {
        let backend = if block_capacity >= LINEAR_TIME_ORAM_CUTOFF {
            let path = PathOram::<V, DEFAULT_BLOCKS_PER_BUCKET, DEFAULT_POSITIONS_PER_BLOCK>::new_with_parameters(
                block_capacity,
                rng,
                DEFAULT_STASH_OVERFLOW_SIZE,
                DEFAULT_RECURSION_CUTOFF,
            )?;
            DefaultOramBackend::Path(path)
        } else {
            DefaultOramBackend::Linear(LinearTimeOram::new(block_capacity)?)
        };
        Ok(Self(backend))
    }
}
impl<V: OramBlock, const Z: BucketSize, const AB: BlockSize> PathOram<V, Z, AB> {
    /// Creates a Path ORAM with `block_capacity` blocks, all initialized to `V::default()`.
    ///
    /// A uniformly random permutation assigns every address to one of the first two slots of
    /// some leaf bucket. The position map is then populated to match that placement.
    ///
    /// # Parameters
    /// - `block_capacity`: number of addressable blocks; must be a power of two greater than 1.
    /// - `rng`: randomness used by the position map and the initial placement.
    /// - `overflow_size`: passed to both the stash (alongside one path's size) and the
    ///   position map.
    /// - `recursion_cutoff`: forwarded to the position map; must be nonzero.
    ///
    /// # Errors
    /// Returns [`OramError::InvalidConfigurationError`] in these cases:
    /// - `block_capacity` is not a power of two greater than 1.
    /// - `Z < 2`.
    /// - `recursion_cutoff == 0`.
    /// - `AB == 0`.
    ///
    /// Errors from the stash, the position map and integer conversions are propagated.
    pub fn new_with_parameters<R: Rng + CryptoRng>(
        block_capacity: Address,
        rng: &mut R,
        overflow_size: StashSize,
        recursion_cutoff: RecursionCutoff,
    ) -> Result<Self, OramError> {
        // Number of blocks initially placed in each leaf bucket.
        const ADDRESSES_PER_LEAF: usize = 2;

        log::info!("PathOram::new(capacity = {})", block_capacity);

        // These checks only look at public configuration, so short-circuiting is fine.
        if !block_capacity.is_power_of_two() || block_capacity <= 1 {
            return Err(OramError::InvalidConfigurationError {
                parameter_name: "ORAM capacity".to_string(),
                parameter_value: block_capacity.to_string(),
            });
        }
        // Each leaf is seeded with ADDRESSES_PER_LEAF (= 2) blocks, so buckets need >= 2 slots.
        if Z <= 1 {
            return Err(OramError::InvalidConfigurationError {
                parameter_name: "Bucket size Z".to_string(),
                parameter_value: Z.to_string(),
            });
        }
        if recursion_cutoff == 0 {
            return Err(OramError::InvalidConfigurationError {
                parameter_name: "Recursion cutoff".to_string(),
                parameter_value: recursion_cutoff.to_string(),
            });
        }
        // Position blocks are filled AB entries at a time; AB == 0 would divide by zero below.
        if AB == 0 {
            return Err(OramError::InvalidConfigurationError {
                parameter_name: "Positions per block AB".to_string(),
                parameter_value: AB.to_string(),
            });
        }

        // The bucket vector has one entry per block. Leaves occupy indices
        // `2^height..2^(height + 1)`, which gives `block_capacity / 2` leaves.
        let number_of_nodes = block_capacity;
        let height: u64 = (block_capacity.ilog2() - 1).into();
        let path_size = u64::try_from(Z)? * (height + 1);
        let stash = ObliviousStash::new(path_size, overflow_size)?;

        let mut physical_memory =
            vec![Bucket::<V, Z>::default(); usize::try_from(number_of_nodes)?];

        let mut position_map =
            PositionMap::new(block_capacity, rng, overflow_size, recursion_cutoff)?;

        // Randomly assign each address to a leaf slot. Slot `s` lives in leaf
        // `first_leaf_index + s / ADDRESSES_PER_LEAF`.
        let slot_indices_to_addresses =
            random_permutation_of_0_through_n_exclusive(block_capacity, rng);
        let addresses_to_slot_indices = invert_permutation_oblivious(&slot_indices_to_addresses)?;
        let slot_indices_to_addresses = to_usize_vec(slot_indices_to_addresses)?;
        let mut addresses_to_slot_indices = to_usize_vec(addresses_to_slot_indices)?;

        let first_leaf_index: usize = 2u64.pow(height.try_into()?).try_into()?;
        let last_leaf_index = (2 * first_leaf_index) - 1;

        let leaf_buckets = physical_memory
            .iter_mut()
            .enumerate()
            .take(last_leaf_index + 1)
            .skip(first_leaf_index);
        for (leaf_index, tree_bucket) in leaf_buckets {
            let first_slot = (leaf_index - first_leaf_index) * ADDRESSES_PER_LEAF;
            let addresses = &slot_indices_to_addresses[first_slot..first_slot + ADDRESSES_PER_LEAF];
            // Z >= 2 (checked above), so the zip covers exactly ADDRESSES_PER_LEAF slots.
            for (block, &address) in tree_bucket.blocks.iter_mut().zip(addresses) {
                *block = PathOramBlock::<V> {
                    value: V::default(),
                    address: address.try_into()?,
                    position: leaf_index.try_into()?,
                };
            }
        }

        // Write the position map one AB-sized block at a time. When AB does not divide the
        // capacity, the final block is padded with slot 0. The padded entries belong to
        // addresses at or beyond `block_capacity`, which fall outside the ORAM's address range.
        let ab_address: Address = AB.try_into()?;
        let mut num_blocks = block_capacity / ab_address;
        if block_capacity % ab_address > 0 {
            num_blocks += 1;
        }
        addresses_to_slot_indices.resize(usize::try_from(num_blocks * ab_address)?, 0);

        for block_index in 0..num_blocks {
            let offset: usize = (block_index * ab_address).try_into()?;
            let slots = &addresses_to_slot_indices[offset..offset + AB];
            let mut data = [0; AB];
            for (position, &slot_index) in data.iter_mut().zip(slots) {
                *position = (first_leaf_index + slot_index / ADDRESSES_PER_LEAF).try_into()?;
            }
            let block = PositionBlock::<AB> { data };
            position_map.write_position_block(block_index * ab_address, block, rng)?;
        }

        Ok(Self {
            physical_memory,
            stash,
            position_map,
            height,
        })
    }

    /// Current number of blocks held in the stash (test-only diagnostic).
    #[cfg(test)]
    pub(crate) fn stash_occupancy(&self) -> StashSize {
        self.stash.occupancy()
    }
}
impl<V: OramBlock, const Z: BucketSize, const AB: BlockSize> Oram for PathOram<V, Z, AB> {
    type V = V;

    /// Accesses the block at `address`, applying `callback` to it inside the stash.
    ///
    /// The steps are:
    /// 1. Assign the block a fresh random leaf.
    /// 2. Read the path to its previous leaf into the stash.
    /// 3. Perform the stash access with `callback`.
    /// 4. Write the stash back along that same path.
    ///
    /// The return value is whatever the stash access produced. Per the `Oram` trait this is
    /// presumably the block's value before the update — confirm against
    /// `ObliviousStash::access`.
    ///
    /// # Errors
    /// Returns [`OramError::AddressOutOfBoundsError`] if `address >= self.block_capacity()`.
    /// Errors from the position map and stash are propagated. The path is written back even
    /// when the stash access itself returns an error.
    ///
    /// # Panics
    /// Panics if the position map yields a position that is not a leaf, which would mean an
    /// internal invariant has been violated.
    fn access<R: Rng + CryptoRng, F: Fn(&V) -> V>(
        &mut self,
        address: Address,
        callback: F,
        rng: &mut R,
    ) -> Result<V, OramError> {
        let capacity = self.block_capacity()?;
        // Valid addresses are 0..capacity; `capacity` itself is out of bounds.
        if address >= capacity {
            return Err(OramError::AddressOutOfBoundsError {
                attempted: address,
                capacity,
            });
        }
        let new_position = CompleteBinaryTreeIndex::random_leaf(self.height, rng)?;
        let position = self.position_map.write(address, new_position, rng)?;
        assert!(position.is_leaf(self.height));
        self.stash
            .read_from_path(&mut self.physical_memory, position)?;
        // Hold the access result until the path has been written back, whether or not it
        // succeeded.
        let result = self.stash.access(address, new_position, callback);
        self.stash
            .write_to_path(&mut self.physical_memory, position)?;
        result
    }

    /// Number of addressable blocks. The constructor allocates one bucket entry per block, so
    /// this is the length of the bucket vector.
    fn block_capacity(&self) -> Result<Address, OramError> {
        Ok(u64::try_from(self.physical_memory.len())?)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{bucket::*, test_utils::*};
    use rand::{rngs::StdRng, SeedableRng};

    create_path_oram_correctness_tests!(4, 8, 16384, 40);
    create_path_oram_correctness_tests!(4, 8, 1, 40);
    create_path_oram_correctness_tests!(4, 8, 1, 10);
    create_path_oram_correctness_tests!(4, 8, 1, 0);
    create_path_oram_correctness_tests!(3, 8, 1, 40);
    create_path_oram_correctness_tests!(5, 8, 1, 40);
    create_path_oram_correctness_tests!(4, 2, 1, 40);
    create_path_oram_correctness_tests!(4, 64, 1, 40);
    create_path_oram_stash_size_tests!(4, 8, 16384, 40);

    #[test]
    fn default_oram_linear_correctness() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut oram = DefaultOram::<BlockValue<1>>::new(64, &mut rng).unwrap();
        assert!(
            matches!(oram.0, DefaultOramBackend::Linear(_)),
            "a capacity below LINEAR_TIME_ORAM_CUTOFF should select the linear backend"
        );
        random_workload(&mut oram, 1000);
    }

    #[test]
    #[ignore]
    fn default_oram_path_correctness() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut oram = DefaultOram::<BlockValue<1>>::new(2048, &mut rng).unwrap();
        assert!(
            matches!(oram.0, DefaultOramBackend::Path(_)),
            "a capacity at or above LINEAR_TIME_ORAM_CUTOFF should select the Path ORAM backend"
        );
        random_workload(&mut oram, 1000);
    }

    /// Regression test: `capacity - 1` is the last valid address; `capacity` must be rejected.
    #[test]
    fn path_oram_access_rejects_address_equal_to_capacity() {
        let mut rng = StdRng::seed_from_u64(0);
        let mut oram = PathOram::<
            BlockValue<1>,
            DEFAULT_BLOCKS_PER_BUCKET,
            DEFAULT_POSITIONS_PER_BLOCK,
        >::new_with_parameters(
            64,
            &mut rng,
            DEFAULT_STASH_OVERFLOW_SIZE,
            DEFAULT_RECURSION_CUTOFF,
        )
        .unwrap();
        let capacity = oram.block_capacity().unwrap();
        assert!(oram
            .access(capacity - 1, |_| BlockValue::<1>::default(), &mut rng)
            .is_ok());
        assert!(matches!(
            oram.access(capacity, |_| BlockValue::<1>::default(), &mut rng),
            Err(OramError::AddressOutOfBoundsError { .. })
        ));
    }
}