use alloc::vec;
use aligned_cmov::{
subtle::{Choice, ConstantTimeEq, ConstantTimeLess},
typenum::{PartialDiv, Prod, Unsigned, U16, U64, U8},
A64Bytes, A8Bytes, ArrayLength, AsAlignedChunks, AsNeSlice, CMov,
};
use alloc::{boxed::Box, vec::Vec};
use balanced_tree_index::TreeIndex;
use core::{marker::PhantomData, ops::Mul};
use mc_oblivious_traits::{
log2_ceil, ORAMStorage, ORAMStorageCreator, PositionMap, PositionMapCreator, ORAM,
};
use rand_core::{CryptoRng, RngCore};
/// Size in bytes of the metadata kept alongside each stored block: two `u64`
/// words, the leaf number (word 0) and the block number (word 1).
/// A leaf number of 0 marks the slot as vacant (see `meta_is_vacant`).
type MetaSize = U16;
/// Borrows the leaf-number word (`u64` index 0) of a block's metadata.
fn meta_leaf_num(src: &A8Bytes<MetaSize>) -> &u64 {
    let words = src.as_ne_u64_slice();
    &words[0]
}
/// Mutably borrows the leaf-number word (`u64` index 0) of a block's metadata.
fn meta_leaf_num_mut(src: &mut A8Bytes<MetaSize>) -> &mut u64 {
    let words = src.as_mut_ne_u64_slice();
    &mut words[0]
}
/// Borrows the block-number word (`u64` index 1) of a block's metadata.
fn meta_block_num(src: &A8Bytes<MetaSize>) -> &u64 {
    let words = src.as_ne_u64_slice();
    &words[1]
}
/// Mutably borrows the block-number word (`u64` index 1) of a block's metadata.
fn meta_block_num_mut(src: &mut A8Bytes<MetaSize>) -> &mut u64 {
    let words = src.as_mut_ne_u64_slice();
    &mut words[1]
}
fn meta_is_vacant(src: &A8Bytes<MetaSize>) -> Choice {
meta_leaf_num(src).ct_eq(&0)
}
/// When `condition` is set, marks a metadata slot vacant by zeroing its leaf
/// number via `cmov`; the block-number word is not modified.
fn meta_set_vacant(condition: Choice, src: &mut A8Bytes<MetaSize>) {
    let leaf = meta_leaf_num_mut(src);
    leaf.cmov(condition, &0u64);
}
/// A Path ORAM: a tree of buckets held in `StorageType`, a position map from
/// block number to leaf, and a fixed-size stash.
///
/// Type parameters:
/// - `ValueSize`: size in bytes of one stored value.
/// - `Z`: number of blocks per bucket; a storage bucket holds `Z * ValueSize`
///   data bytes and `Z * MetaSize` metadata bytes.
/// - `StorageType`: backing storage for the buckets.
/// - `RngType`: cryptographic RNG used to pick new leaf positions.
pub struct PathORAM<ValueSize, Z, StorageType, RngType>
where
    ValueSize: ArrayLength<u8> + PartialDiv<U8> + PartialDiv<U64>,
    Z: Unsigned + Mul<ValueSize> + Mul<MetaSize>,
    RngType: RngCore + CryptoRng + Send + Sync + 'static,
    StorageType: ORAMStorage<Prod<Z, ValueSize>, Prod<Z, MetaSize>> + Send + Sync + 'static,
    Prod<Z, ValueSize>: ArrayLength<u8> + PartialDiv<U8>,
    Prod<Z, MetaSize>: ArrayLength<u8> + PartialDiv<U8>,
{
    /// Tree height; new leaf positions are drawn at this height below index 1.
    height: u32,
    /// Bucket storage, created in `new` with `2 << height` buckets.
    storage: StorageType,
    /// Position map: block number -> leaf where the block currently lives.
    pos: Box<dyn PositionMap + Send + Sync + 'static>,
    /// RNG used to draw new leaf positions during `access`.
    rng: RngType,
    /// Stash values; `stash_data[i]` pairs with `stash_meta[i]`.
    stash_data: Vec<A64Bytes<ValueSize>>,
    /// Stash metadata; a slot is free when its leaf number is 0.
    stash_meta: Vec<A8Bytes<MetaSize>>,
    /// Scratch space for the branch checked out during `access`.
    branch: BranchCheckout<ValueSize, Z>,
}
impl<ValueSize, Z, StorageType, RngType> PathORAM<ValueSize, Z, StorageType, RngType>
where
    ValueSize: ArrayLength<u8> + PartialDiv<U8> + PartialDiv<U64>,
    Z: Unsigned + Mul<ValueSize> + Mul<MetaSize>,
    RngType: RngCore + CryptoRng + Send + Sync + 'static,
    StorageType: ORAMStorage<Prod<Z, ValueSize>, Prod<Z, MetaSize>> + Send + Sync + 'static,
    Prod<Z, ValueSize>: ArrayLength<u8> + PartialDiv<U8>,
    Prod<Z, MetaSize>: ArrayLength<u8> + PartialDiv<U8>,
{
    /// Builds a Path ORAM holding `size` blocks.
    ///
    /// * `size` - number of addressable blocks; must be a nonzero power of two.
    /// * `stash_size` - number of stash slots (also forwarded to the position
    ///   map creator).
    /// * `rng_maker` - RNG factory; called once here for this ORAM's own RNG,
    ///   then handed to the position map creator.
    ///
    /// # Panics
    /// If `size` is zero or not a power of two, or if storage creation fails.
    pub fn new<
        PMC: PositionMapCreator<RngType>,
        SC: ORAMStorageCreator<Prod<Z, ValueSize>, Prod<Z, MetaSize>, Output = StorageType>,
        F: FnMut() -> RngType + 'static,
    >(
        size: u64,
        stash_size: usize,
        rng_maker: &mut F,
    ) -> Self {
        assert!(size != 0, "size cannot be zero");
        // `size` is nonzero here, so this is the same test as `size & (size - 1) == 0`.
        assert!(size.is_power_of_two(), "size must be a power of two");

        // Each bucket holds Z blocks, so subtract log2(Z) levels; saturate so
        // that sizes below Z still give height 0.
        let height = log2_ceil(size).saturating_sub(log2_ceil(Z::U64));
        // Number of buckets requested from storage: 2^(height + 1).
        let bucket_count = 2u64 << height;

        // Keep this call order: our own RNG is drawn from `rng_maker` before
        // the factory is passed on to the position map creator.
        let mut rng = rng_maker();
        let storage = SC::create(bucket_count, &mut rng).expect("Storage failed");
        let pos = PMC::create(size, height, stash_size, rng_maker);

        let stash_data = vec![Default::default(); stash_size];
        let stash_meta = vec![Default::default(); stash_size];

        Self {
            height,
            storage,
            pos,
            rng,
            stash_data,
            stash_meta,
            branch: BranchCheckout::default(),
        }
    }
}
impl<ValueSize, Z, StorageType, RngType> ORAM<ValueSize>
    for PathORAM<ValueSize, Z, StorageType, RngType>
where
    ValueSize: ArrayLength<u8> + PartialDiv<U8> + PartialDiv<U64>,
    Z: Unsigned + Mul<ValueSize> + Mul<MetaSize>,
    RngType: RngCore + CryptoRng + Send + Sync + 'static,
    StorageType: ORAMStorage<Prod<Z, ValueSize>, Prod<Z, MetaSize>> + Send + Sync + 'static,
    Prod<Z, ValueSize>: ArrayLength<u8> + PartialDiv<U8>,
    Prod<Z, MetaSize>: ArrayLength<u8> + PartialDiv<U8>,
{
    /// Number of blocks, as reported by the position map.
    fn len(&self) -> u64 {
        self.pos.len()
    }
    /// Accesses block `key`: passes its value to `f` (which may read or
    /// modify it) and returns `f`'s result.
    ///
    /// Steps:
    /// 1. Draw a new random leaf for `key`, store it in the position map, and
    ///    recover the leaf where the block currently lives.
    /// 2. Check out that branch and find-and-remove the block from the branch
    ///    and then from the stash. If neither search finds it, `f` sees a
    ///    `Default` value.
    /// 3. Run `f`, retag the block with `key` and its new leaf, and put it in
    ///    the stash.
    /// 4. Pack the branch, evict stash entries into it, and check it back in.
    ///
    /// # Panics
    /// With "Stash overflow!" if the stash has no vacant slot for the block.
    fn access<T, F: FnOnce(&mut A64Bytes<ValueSize>) -> T>(&mut self, key: u64, f: F) -> T {
        let result: T;
        // Choose the block's next leaf at `self.height` below index 1; writing
        // it to the position map returns the block's current leaf.
        let new_pos = 1u64.random_child_at_height(self.height, &mut self.rng);
        let current_pos = self.pos.write(&key, &new_pos);
        debug_assert!(current_pos != 0, "position map told us the item is at 0");
        // No branch is checked out between accesses.
        debug_assert!(self.branch.leaf == 0);
        self.branch.checkout(&mut self.storage, current_pos);
        // Fetch the block, let `f` visit it, and put it back into the stash.
        {
            debug_assert!(self.branch.leaf == current_pos);
            let mut meta = A8Bytes::<MetaSize>::default();
            let mut data = A64Bytes::<ValueSize>::default();
            // Both searches always run. `data`/`meta` keep their default values
            // if the block is in neither place.
            self.branch
                .ct_find_and_remove(1.into(), &key, &mut data, &mut meta);
            details::ct_find_and_remove(
                1.into(),
                &key,
                &mut data,
                &mut meta,
                &mut self.stash_data,
                &mut self.stash_meta,
            );
            debug_assert!(
                meta_block_num(&meta) == &key || meta_is_vacant(&meta).into(),
                "Hmm, we didn't find the expected item something else"
            );
            debug_assert!(self.branch.leaf == current_pos);
            result = f(&mut data);
            // Set the block number (needed when the block was not found) and
            // the new leaf chosen above.
            *meta_block_num_mut(&mut meta) = key;
            *meta_leaf_num_mut(&mut meta) = new_pos;
            details::ct_insert(
                1.into(),
                &data,
                &mut meta,
                &mut self.stash_data,
                &mut self.stash_meta,
            );
            // `ct_insert` marks `meta` vacant only if it found a free stash slot.
            assert!(bool::from(meta_is_vacant(&meta)), "Stash overflow!");
        }
        // Eviction: pack the branch, then offer every stash slot to the branch.
        // Vacant stash slots are ignored by `ct_insert`.
        {
            debug_assert!(self.branch.leaf == current_pos);
            self.branch.pack();
            for idx in 0..self.stash_data.len() {
                self.branch
                    .ct_insert(1.into(), &self.stash_data[idx], &mut self.stash_meta[idx]);
            }
        }
        debug_assert!(self.branch.leaf == current_pos);
        self.branch.checkin(&mut self.storage);
        debug_assert!(self.branch.leaf == 0);
        result
    }
}
/// Scratch copy of the buckets along one branch of the tree, checked out from
/// storage for the duration of an `access`.
struct BranchCheckout<ValueSize, Z>
where
    ValueSize: ArrayLength<u8> + PartialDiv<U8> + PartialDiv<U64>,
    Z: Unsigned + Mul<ValueSize> + Mul<MetaSize>,
    Prod<Z, ValueSize>: ArrayLength<u8> + PartialDiv<U8>,
    Prod<Z, MetaSize>: ArrayLength<u8> + PartialDiv<U8>,
{
    /// Leaf identifying the checked-out branch, or 0 when nothing is checked out.
    leaf: u64,
    /// Data of each bucket on the branch; `data[i]` pairs with `meta[i]`.
    data: Vec<A64Bytes<Prod<Z, ValueSize>>>,
    /// Metadata of each bucket on the branch.
    meta: Vec<A8Bytes<Prod<Z, MetaSize>>>,
    /// Carries the `ValueSize` type parameter without storing a value of it.
    _value_size: PhantomData<fn() -> ValueSize>,
}
impl<ValueSize, Z> Default for BranchCheckout<ValueSize, Z>
where
ValueSize: ArrayLength<u8> + PartialDiv<U8> + PartialDiv<U64>,
Z: Unsigned + Mul<ValueSize> + Mul<MetaSize>,
Prod<Z, ValueSize>: ArrayLength<u8> + PartialDiv<U8>,
Prod<Z, MetaSize>: ArrayLength<u8> + PartialDiv<U8>,
{
fn default() -> Self {
Self {
leaf: 0,
data: Default::default(),
meta: Default::default(),
_value_size: Default::default(),
}
}
}
impl<ValueSize, Z> BranchCheckout<ValueSize, Z>
where
    ValueSize: ArrayLength<u8> + PartialDiv<U8> + PartialDiv<U64>,
    Z: Unsigned + Mul<ValueSize> + Mul<MetaSize>,
    Prod<Z, ValueSize>: ArrayLength<u8> + PartialDiv<U8>,
    Prod<Z, MetaSize>: ArrayLength<u8> + PartialDiv<U8>,
{
    /// If `condition` is set, searches every bucket of the checked-out branch
    /// for the block numbered `query`. A match is moved into `dest_data` /
    /// `dest_meta` and its slot is marked vacant. With no match, `dest_*` are
    /// left unchanged.
    ///
    /// Every bucket is visited. `condition` is not cleared between buckets, so
    /// this assumes the block appears at most once in the branch; a match in
    /// a later bucket would overwrite an earlier one.
    pub fn ct_find_and_remove(
        &mut self,
        condition: Choice,
        query: &u64,
        dest_data: &mut A64Bytes<ValueSize>,
        dest_meta: &mut A8Bytes<MetaSize>,
    ) {
        debug_assert!(self.data.len() == self.meta.len());
        for idx in 0..self.data.len() {
            // View this bucket as Z individual block slots.
            let bucket_data: &mut [A64Bytes<ValueSize>] = self.data[idx].as_mut_aligned_chunks();
            let bucket_meta: &mut [A8Bytes<MetaSize>] = self.meta[idx].as_mut_aligned_chunks();
            debug_assert!(bucket_data.len() == Z::USIZE);
            debug_assert!(bucket_meta.len() == Z::USIZE);
            details::ct_find_and_remove(
                condition,
                query,
                dest_data,
                dest_meta,
                bucket_data,
                bucket_meta,
            );
        }
    }
    /// If `condition` is set and `src_meta` is not vacant, inserts the block
    /// into the lowest-index bucket that (a) the block may legally occupy given
    /// its leaf and (b) has a vacant slot. On success `src_meta` is marked
    /// vacant; otherwise it is left as is (the block stays wherever it was).
    pub fn ct_insert(
        &mut self,
        mut condition: Choice,
        src_data: &A64Bytes<ValueSize>,
        src_meta: &mut A8Bytes<MetaSize>,
    ) {
        condition &= !meta_is_vacant(src_meta);
        let lowest_legal_index = self.lowest_legal_index(*meta_leaf_num(src_meta));
        Self::insert_into_branch_suffix(
            condition,
            src_data,
            src_meta,
            lowest_legal_index,
            &mut self.data,
            &mut self.meta,
        );
    }
    /// Moves blocks already on the branch toward bucket index 0 as far as is
    /// legal. In `access` this runs before the stash is evicted into the
    /// branch.
    ///
    /// For each bucket `bucket_num >= 1`, every slot is offered to buckets
    /// `lowest_legal_index..bucket_num`, taking the first vacant slot at the
    /// lowest such index. Bucket 0 is skipped because there is no lower
    /// index to move into. Vacant slots are never moved (`details::ct_insert`
    /// masks them out).
    pub fn pack(&mut self) {
        debug_assert!(self.leaf != 0);
        debug_assert!(self.data.len() == self.meta.len());
        let data_len = self.data.len();
        for bucket_num in 1..self.data.len() {
            // Split so that bucket `bucket_num` (head of `upper_*`) can be
            // mutated while the buckets below it (`lower_*`) are also borrowed.
            let (lower_data, upper_data) = self.data.split_at_mut(bucket_num);
            let (lower_meta, upper_meta) = self.meta.split_at_mut(bucket_num);
            let bucket_data: &mut [A64Bytes<ValueSize>] = upper_data[0].as_mut_aligned_chunks();
            let bucket_meta: &mut [A8Bytes<MetaSize>] = upper_meta[0].as_mut_aligned_chunks();
            debug_assert!(bucket_data.len() == bucket_meta.len());
            for idx in 0..bucket_data.len() {
                let src_data: &mut A64Bytes<ValueSize> = &mut bucket_data[idx];
                let src_meta: &mut A8Bytes<MetaSize> = &mut bucket_meta[idx];
                // The `_impl` form is used because `self` cannot be borrowed
                // while `self.data` / `self.meta` are split above.
                let lowest_legal_index =
                    Self::lowest_legal_index_impl(*meta_leaf_num(src_meta), self.leaf, data_len);
                Self::insert_into_branch_suffix(
                    1.into(),
                    src_data,
                    src_meta,
                    lowest_legal_index,
                    lower_data,
                    lower_meta,
                );
            }
        }
        debug_assert!(self.leaf != 0);
    }
    /// Loads the branch for `leaf` from `storage`, resizing the bucket buffers
    /// to `leaf.height() + 1` entries first. Expects nothing to be checked
    /// out (debug-asserted).
    pub fn checkout(
        &mut self,
        storage: &mut impl ORAMStorage<Prod<Z, ValueSize>, Prod<Z, MetaSize>>,
        leaf: u64,
    ) {
        debug_assert!(self.leaf == 0);
        self.data
            .resize_with(leaf.height() as usize + 1, Default::default);
        self.meta
            .resize_with(leaf.height() as usize + 1, Default::default);
        storage.checkout(leaf, &mut self.data, &mut self.meta);
        self.leaf = leaf;
    }
    /// Writes the checked-out branch back to `storage` and resets `leaf` to 0
    /// (nothing checked out).
    pub fn checkin(
        &mut self,
        storage: &mut impl ORAMStorage<Prod<Z, ValueSize>, Prod<Z, MetaSize>>,
    ) {
        debug_assert!(self.leaf != 0);
        storage.checkin(self.leaf, &mut self.data, &mut self.meta);
        self.leaf = 0;
    }
    /// Smallest bucket index on the checked-out branch where a block assigned
    /// to leaf `query` may be stored. It is computed as
    /// `data.len() - 1 - common_ancestor_height(self.leaf, query)`, so a
    /// greater common-ancestor height gives a smaller index.
    fn lowest_legal_index(&self, query: u64) -> usize {
        Self::lowest_legal_index_impl(query, self.leaf, self.data.len())
    }
    /// Stand-alone form of `lowest_legal_index` taking `leaf` and `data_len`
    /// explicitly, so `pack` can call it while `self.data` / `self.meta` are
    /// mutably borrowed.
    fn lowest_legal_index_impl(mut query: u64, leaf: u64, data_len: usize) -> usize {
        // A vacant slot has leaf 0; map it to tree index 1 so the result is
        // still well-defined. That index is presumably the root; confirm
        // against `TreeIndex`.
        query.cmov(query.ct_eq(&0), &1);
        debug_assert!(
            leaf != 0,
            "this should not be called when there is not currently a checkout"
        );
        let common_ancestor_height = leaf.common_ancestor_height(&query) as usize;
        debug_assert!(data_len > common_ancestor_height);
        data_len - 1 - common_ancestor_height
    }
    /// If `condition` is set and `src_meta` is not vacant, inserts the block
    /// into the first vacant slot of the lowest-index bucket with index
    /// `>= insert_after_index` that has room, then marks `src_meta` vacant.
    /// Despite its name, `insert_after_index` is inclusive.
    ///
    /// All of `dest_*` is visited. `dest_*` may be only a prefix of the branch,
    /// as in `pack`.
    fn insert_into_branch_suffix(
        condition: Choice,
        src_data: &A64Bytes<ValueSize>,
        src_meta: &mut A8Bytes<MetaSize>,
        insert_after_index: usize,
        dest_data: &mut [A64Bytes<Prod<Z, ValueSize>>],
        dest_meta: &mut [A8Bytes<Prod<Z, MetaSize>>],
    ) {
        debug_assert!(dest_data.len() == dest_meta.len());
        for idx in 0..dest_data.len() {
            // Buckets below the threshold are masked off. After the first
            // successful insert `src_meta` is vacant, which disables the rest.
            details::ct_insert::<ValueSize>(
                condition & !(idx as u64).ct_lt(&(insert_after_index as u64)),
                src_data,
                src_meta,
                dest_data[idx].as_mut_aligned_chunks(),
                dest_meta[idx].as_mut_aligned_chunks(),
            )
        }
    }
}
/// Slot-level helpers over a bucket or the stash. They select slots with
/// `Choice` masks and `cmov` instead of `if` branches or early exits, and
/// always visit every slot.
/// NOTE(review): the intent is no secret-dependent control flow; whether
/// the compiler preserves that is not checked here.
mod details {
    use super::*;
    /// If `condition` is set, moves the first non-vacant slot of `src_*` whose
    /// block number equals `query` into `dest_data` / `dest_meta`, and marks
    /// that source slot vacant. With no match, `dest_*` are unchanged.
    pub fn ct_find_and_remove<ValueSize: ArrayLength<u8>>(
        mut condition: Choice,
        query: &u64,
        dest_data: &mut A64Bytes<ValueSize>,
        dest_meta: &mut A8Bytes<MetaSize>,
        src_data: &mut [A64Bytes<ValueSize>],
        src_meta: &mut [A8Bytes<MetaSize>],
    ) {
        debug_assert!(src_data.len() == src_meta.len());
        for idx in 0..src_meta.len() {
            // Selected iff we are still searching, the block number matches,
            // and the slot is occupied.
            let test = condition
                & (query.ct_eq(meta_block_num(&src_meta[idx])))
                & !meta_is_vacant(&src_meta[idx]);
            dest_meta.cmov(test, &src_meta[idx]);
            dest_data.cmov(test, &src_data[idx]);
            meta_set_vacant(test, &mut src_meta[idx]);
            // After a hit, stop selecting without leaving the loop.
            condition &= !test;
        }
    }
    /// If `condition` is set and `src_meta` is not vacant, copies the block
    /// into the first vacant slot of `dest_*` and marks `src_meta` vacant.
    /// With no vacant slot, nothing changes, so a still-occupied `src_meta`
    /// afterwards means the insert did not happen.
    pub fn ct_insert<ValueSize: ArrayLength<u8>>(
        mut condition: Choice,
        src_data: &A64Bytes<ValueSize>,
        src_meta: &mut A8Bytes<MetaSize>,
        dest_data: &mut [A64Bytes<ValueSize>],
        dest_meta: &mut [A8Bytes<MetaSize>],
    ) {
        debug_assert!(dest_data.len() == dest_meta.len());
        // Never insert a vacant source.
        condition &= !meta_is_vacant(src_meta);
        for idx in 0..dest_meta.len() {
            let test = condition & meta_is_vacant(&dest_meta[idx]);
            dest_meta[idx].cmov(test, src_meta);
            dest_data[idx].cmov(test, src_data);
            meta_set_vacant(test, src_meta);
            // After placing the block once, stop selecting without leaving the loop.
            condition &= !test;
        }
    }
}