use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::hash::BuildHasherDefault;
use std::path::Path;
use aes::cipher::generic_array::GenericArray;
use aes::Aes128Ctr;
use aes_gcm::aead::{Aead, NewAead, Payload};
use aes_gcm::{Aes128Gcm, Key as AesGcmKey, Nonce as AesGcmNonce};
use bytes::{Buf, Bytes, BytesMut};
use chacha20::cipher::{NewStreamCipher, SyncStreamCipher};
use chacha20::{ChaCha8, Key, Nonce};
use ctr::cipher::{NewCipher, StreamCipher};
use log::{debug, info};
use nohash_hasher::NoHashHasher;
use rand::seq::SliceRandom;
use rand::{thread_rng, AsByteSliceMut, Rng};
use serde::{Deserialize, Serialize};
use crate::io::BaseIOService;
use crate::oram::pathoram::tree::TreeNode;
use crate::oram::BaseORAM;
use crate::{ORAMConfig, ORAMManager};
pub mod tree;
/// A single ORAM block: an identifier together with its payload bytes.
///
/// Within this file an `id` of `-1` marks a filler block: `write_bucket` pads buckets with
/// such blocks and `access` discards them when reading a path into the stash.
#[derive(Clone, Serialize, Deserialize)]
pub struct Block {
/// Block identifier; `-1` is the filler marker used by `write_bucket`.
id: i64,
/// Block contents. Filler and freshly initialised blocks carry `args.b` zero bytes.
payload: Bytes,
}
/// Serialized form of one tree node: a group of blocks plus format metadata.
///
/// The struct is (de)serialized with bincode in `raw_write_bucket` / `raw_read_bucket`.
#[derive(Serialize, Deserialize)]
pub struct Bucket {
/// Format version; `raw_write_bucket` always writes `"1.0"`.
version: String,
/// Name of the ORAM scheme that produced the bucket (`self.name()` at write time).
format: String,
/// Blocks stored in this bucket.
blocks: Vec<Block>,
}
/// Serialized form of an encrypted bucket, as produced by `encrypt_bucket`.
#[derive(Serialize, Deserialize)]
pub struct EncryptedBytes {
/// IV / nonce generated by `encrypt` for this ciphertext.
iv: Bytes,
/// Encrypted, bincode-serialized `Bucket`.
ciphertext: BytesMut,
}
/// Client state of a Path ORAM instance.
pub struct PathORAM<'a> {
/// Configuration this ORAM was created with.
pub args: &'a ORAMConfig,
/// I/O backend used to read and write bucket files and client data files.
pub io: Box<dyn BaseIOService + 'a>,
/// Maps a block id to the leaf its path is currently assigned to (see `access`).
pub position_map: HashMap<i64, i64, BuildHasherDefault<NoHashHasher<i64>>>,
/// Blocks currently held on the client side, outside of the tree.
pub stash: Vec<Block>,
/// Tree built by `TreeNode::create_tree(args.n)`; used to resolve paths to bucket ids.
pub tree: TreeNode,
/// Key used by `encrypt` / `decrypt`; left empty unless `load_encryption_key` sets it.
pub encryption_key: Vec<u8>,
}
/// `BaseORAM` front-end for Path ORAM: every method delegates to the inherent implementation.
impl BaseORAM for PathORAM<'_> {
    /// Reports whether the main invariant holds (see `verify_main_invariant`).
    fn test_state(&mut self) -> bool {
        self.verify_main_invariant()
    }

    /// Initializes a fresh ORAM by running `setup`.
    fn init(&mut self) {
        self.setup();
    }

    /// Persists the client state one last time.
    fn cleanup(&mut self) {
        debug!("Path ORAM cleanup...");
        self.save();
        debug!("...done!");
    }

    /// Persists the client state after each operation.
    fn post_op(&mut self) {
        self.save();
    }

    /// Reads block `block_id` and returns a copy of its payload.
    ///
    /// # Panics
    /// Panics when `access` yields no data.
    fn read(&mut self, block_id: i64) -> Vec<u8> {
        if let Some(payload) = self.access("read", block_id, None) {
            return payload.to_vec();
        }
        panic!("Could not read block")
    }

    /// Replaces the payload of block `block_id` with `data` and returns `data.len()`.
    fn write(&mut self, block_id: i64, data: Bytes) -> usize {
        // Capture the length first so the buffer can be handed over to `access`.
        let written = data.len();
        let _ = self.access("write", block_id, Some(data));
        written
    }

    /// Product of the configured `b`, `z` and `n` values.
    fn size(&self) -> i64 {
        let config = self.args;
        config.b * config.z * config.n
    }

    /// Name of this ORAM scheme.
    fn name(&self) -> String {
        "pathoram".to_string()
    }

    /// Configuration this ORAM was created with.
    fn args(&self) -> &ORAMConfig {
        self.args
    }
}
impl<'a> PathORAM<'a> {
/// Creates a Path ORAM client on top of `io`.
///
/// The tree is built from `args.n`. Unless `args.init` is set, the previously saved client
/// state is loaded right away via `load`.
pub fn new(args: &'a ORAMConfig, io: Box<dyn BaseIOService + 'a>) -> Self {
    let tree = TreeNode::create_tree(args.n);
    let mut oram = Self {
        args,
        io,
        position_map: HashMap::default(),
        stash: Vec::new(),
        tree,
        encryption_key: Vec::new(),
    };
    if args.init {
        // A fresh ORAM has nothing on disk yet; `setup` will populate it.
        return oram;
    }
    oram.load();
    oram
}
/// Initializes a fresh ORAM: loads the encryption key, builds the position map and writes
/// the initial buckets to public storage.
pub fn setup(&mut self) {
    info!("Initializing Path ORAM...");
    self.load_encryption_key();
    // The position map initialisation also yields the bucket layout to write out.
    let bucket_layout = self.init_position_map();
    self.init_public_storage(bucket_layout);
    info!("...initialization complete!");
}
/// Checks that every non-filler block stored in the tree sits in a bucket that lies on the
/// path to the leaf recorded for it in the position map.
///
/// All buckets `0..args.n` are read; the scan does not stop at the first violation.
///
/// # Panics
/// Panics if a stored block id has no entry in the position map.
pub fn verify_main_invariant(&mut self) -> bool {
    let mut violations = 0usize;
    for bucket_id in 0..self.args.n {
        for block in self.read_bucket(bucket_id) {
            // Filler blocks are not tracked by the position map.
            if block.id == -1 {
                continue;
            }
            let leaf = *self.position_map.get(&block.id).unwrap();
            if !self.tree.path(leaf).contains(&bucket_id) {
                violations += 1;
            }
        }
    }
    violations == 0
}
/// Performs one Path ORAM access on block `a`.
///
/// The block is remapped to a fresh random leaf, the path to its previous leaf is read into
/// the stash, the block is (optionally) overwritten, and the path is written back from the
/// leaf up to the root with as many eligible stash blocks as fit into each bucket.
///
/// * `op` - `"write"` replaces the payload with `data_star`; any other value only reads.
/// * `a` - id of the block to access.
/// * `data_star` - new payload, required when `op` is `"write"` and ignored otherwise.
///
/// Returns the payload the block held *before* this access.
///
/// # Panics
/// Panics if `a` (or any stash block) has no position map entry, if the block is not found
/// after reading its path, or if `op` is `"write"` and `data_star` is `None`.
pub fn access(&mut self, op: &str, a: i64, data_star: Option<Bytes>) -> Option<Bytes> {
    // Remember the current leaf of the block, then remap the block to a random leaf.
    let x = *self.position_map.get(&a).unwrap();
    let tree_height = self.tree.height;
    let max_leaf = 2i64.pow(tree_height as u32) - 1;
    let new_random_leaf = thread_rng().gen_range(0, max_leaf + 1);
    self.position_map.insert(a, new_random_leaf);

    // Read every bucket on the path to the old leaf into the stash, dropping filler blocks.
    for l in 0..=tree_height {
        let bucket_id = self.tree.pathl(x, l);
        let bucket = self.read_bucket(bucket_id);
        self.stash.extend(bucket.into_iter().filter(|b| b.id != -1));
    }

    // The last matching stash entry wins, exactly like a full forward scan would.
    let previous = match self.stash.iter().rev().find(|b| b.id == a) {
        Some(block) => block.payload.clone(),
        None => panic!("Failed to find block {} in stash", a),
    };

    // Move the accessed block to the front of the stash in a single update, carrying either
    // its old payload (read) or the new one (write).
    self.stash.retain(|b| b.id != a);
    let payload = if op == "write" {
        data_star.unwrap()
    } else {
        previous.clone()
    };
    self.stash.insert(0, Block { id: a, payload });

    // Write the path back, leaf first, so blocks are pushed as deep as possible.
    let bucket_capacity = self.args.z as usize;
    for l in (0..=tree_height).rev() {
        let pxl = self.tree.pathl(x, l);

        // Select, in stash order, the blocks whose own path shares this bucket. Only the
        // selected blocks are cloned rather than the whole stash.
        let mut s_prime: Vec<Block> =
            Vec::with_capacity(min(self.stash.len(), bucket_capacity));
        for block in &self.stash {
            let leaf = *self.position_map.get(&block.id).unwrap();
            if self.tree.pathl(leaf, l) == pxl && s_prime.len() < bucket_capacity {
                s_prime.push(block.clone());
            }
        }

        let s_prime_block_ids: HashSet<i64, BuildHasherDefault<NoHashHasher<i64>>> =
            s_prime.iter().map(|b| b.id).collect();
        self.stash.retain(|b| !s_prime_block_ids.contains(&b.id));
        self.write_bucket(pxl, s_prime);
    }

    Some(previous)
}
/// Location of the serialized stash: `stash.bin` inside `args.client_data_dir`.
pub fn stash_path(&self) -> String {
    let file = Path::new(&self.args.client_data_dir).join("stash.bin");
    file.to_str().unwrap().to_owned()
}
/// Location of the serialized position map: `position_map.bin` inside
/// `args.client_data_dir`.
pub fn position_map_path(&self) -> String {
    let file = Path::new(&self.args.client_data_dir).join("position_map.bin");
    file.to_str().unwrap().to_owned()
}
/// Decrypts the bucket encryption key and stores it in `self.encryption_key`.
///
/// Does nothing when `args.disable_encryption` is set. Otherwise a key is derived from the
/// configured passphrase and salt and used to decrypt `args.encrypted_encryption_key`.
///
/// # Panics
/// Panics if the key cannot be decrypted (for example because of a wrong passphrase).
pub fn load_encryption_key(&mut self) {
    if self.args.disable_encryption {
        return;
    }
    let (derived_key, _) =
        ORAMManager::derive_key(&self.args.encryption_passphrase, &self.args.salt);
    // Clone only the field that is needed instead of the whole configuration.
    let (ciphertext, nonce) =
        ORAMManager::deserialize_key(self.args.encrypted_encryption_key.clone());
    self.encryption_key = ORAMManager::decrypt_key(derived_key, ciphertext, nonce)
        .expect("Failed to load encryption key. Invalid passphrase?");
}
/// Restores the client state (encryption key, stash and position map) saved by `save`.
///
/// # Panics
/// Panics if either file cannot be deserialized.
pub fn load(&mut self) {
    debug!("Loading client data from disk...");
    self.load_encryption_key();

    let stash_file = self.stash_path();
    let raw_stash = self.io.read_file(stash_file);
    self.stash = bincode::deserialize(&raw_stash).unwrap();

    let position_map_file = self.position_map_path();
    let raw_position_map = self.io.read_file(position_map_file);
    self.position_map = bincode::deserialize(&raw_position_map).unwrap();

    debug!("...done!");
}
/// Persists the stash and the position map through the I/O backend.
///
/// The client data directory is created first if it does not exist yet.
///
/// # Panics
/// Panics if the directory cannot be created or if serialization fails.
pub fn save(&mut self) {
    debug!("Saving client data to disk...");
    if let Err(e) = std::fs::create_dir_all(Path::new(&self.args.client_data_dir)) {
        panic!("Failed to create client directory: {}", e);
    }

    // Serialize both structures before writing anything.
    let stash_bytes = bincode::serialize(&self.stash).unwrap();
    let position_map_bytes = bincode::serialize(&self.position_map).unwrap();

    let stash_file = self.stash_path();
    self.io.write_file(stash_file, &stash_bytes);
    let position_map_file = self.position_map_path();
    self.io.write_file(position_map_file, &position_map_bytes);

    debug!("...done!");
}
/// Randomly distributes all `n * z` block ids over the buckets and fills the position map.
///
/// Block ids are shuffled and assigned `z` per bucket. Every block is then mapped to the
/// first leaf (in a shuffled leaf order) whose path contains the block's bucket, so the
/// initial layout satisfies the invariant checked by `verify_main_invariant`.
///
/// Returns the layout as `bucket id -> (slot index -> block id)`.
///
/// # Panics
/// Panics if no leaf path contains a bucket that holds a block.
pub fn init_position_map(&mut self) -> HashMap<i64, HashMap<i64, i64>> {
    let block_count = self.args.n * self.args.z;
    let mut block_ids: Vec<i64> = (0..block_count).collect();
    block_ids.shuffle(&mut thread_rng());

    // Hand out the shuffled ids bucket by bucket, slot by slot.
    let mut bucket_of_block: HashMap<i64, i64> = HashMap::new();
    let mut rbmap: HashMap<i64, HashMap<i64, i64>> = HashMap::new();
    let mut shuffled_ids = block_ids.iter();
    for bucket_id in 0..self.args.n {
        for block_index in 0..self.args.z {
            let block_id = shuffled_ids.next().unwrap();
            bucket_of_block.insert(*block_id, bucket_id);
            Self::mapmap_insert(&mut rbmap, &bucket_id, block_index, block_id);
        }
    }

    let leaves_count = self.tree.leaves_count;
    let mut leaves: Vec<i64> = (0..leaves_count).collect();
    leaves.shuffle(&mut thread_rng());

    // The chosen leaf depends only on the bucket, so it is searched once per bucket and
    // reused for every block stored in that bucket.
    let mut leaf_of_bucket: HashMap<i64, i64> = HashMap::new();
    for block_id in 0..block_count {
        let bucket_id = *bucket_of_block.get(&block_id).unwrap();
        let cached_leaf = leaf_of_bucket.get(&bucket_id).copied();
        let leaf = match cached_leaf {
            Some(leaf) => leaf,
            None => {
                let mut found_leaf = None;
                for x in &leaves {
                    if self.tree.path(*x).contains(&bucket_id) {
                        found_leaf = Some(*x);
                        break;
                    }
                }
                let leaf = found_leaf.unwrap();
                leaf_of_bucket.insert(bucket_id, leaf);
                leaf
            }
        };
        self.position_map.insert(block_id, leaf);
    }
    rbmap
}
/// Records `block_id` at slot `block_index` of bucket `bucket_id` in the nested map,
/// creating the inner map for the bucket on first use.
///
/// An existing entry for the same bucket and slot is overwritten.
pub fn mapmap_insert(
    rbmap: &mut HashMap<i64, HashMap<i64, i64>>,
    bucket_id: &i64,
    block_index: i64,
    block_id: &i64,
) {
    let slots = rbmap.entry(*bucket_id).or_default();
    slots.insert(block_index, *block_id);
}
/// Writes the initial bucket files described by `rbmap` (`bucket id -> slot -> block id`).
///
/// # Panics
/// Panics if a bucket is missing one of the slots `0..args.z`.
pub fn init_public_storage(&mut self, rbmap: HashMap<i64, HashMap<i64, i64>>) {
    debug!("Initializing public storage...");
    for (bucket_id, slots) in rbmap {
        let node_path = self.node_path(bucket_id);
        // Collect the block ids in slot order.
        let block_ids: Vec<i64> = (0..self.args.z)
            .map(|slot| *slots.get(&slot).unwrap())
            .collect();
        self.write_empty_bucket(block_ids, node_path);
    }
    debug!("...done!");
}
/// Reads the blocks stored in bucket `bucket_id`.
fn read_bucket(&self, bucket_id: i64) -> Vec<Block> {
    self.raw_read_bucket(self.node_path(bucket_id))
}
/// Writes `blocks` to bucket `bucket_id`, padding with filler blocks (id `-1`, `args.b`
/// zero bytes) until the bucket holds exactly `args.z` blocks.
///
/// # Panics
/// Panics if more than `args.z` blocks are supplied.
fn write_bucket(&mut self, bucket_id: i64, mut blocks: Vec<Block>) {
    let missing_blocks = self.args.z - blocks.len() as i64;
    if missing_blocks < 0 {
        panic!("Error: trying to write more blocks than the bucket can hold");
    }
    if missing_blocks > 0 {
        let filler = Block {
            id: -1,
            payload: Bytes::from(vec![0u8; self.args.b as usize]),
        };
        // `args.z` exceeds the current length here, so this only appends filler blocks.
        blocks.resize(self.args.z as usize, filler);
    }
    let node_path = self.node_path(bucket_id);
    self.raw_write_bucket(node_path, blocks);
}
pub fn write_empty_bucket(&mut self, block_ids: Vec<i64>, node_path: String) {
let empty_block_contents = Bytes::from(vec![0u8; self.args.b as usize]);
let mut blocks = Vec::new();
for block_id in block_ids {
blocks.push(Block {
id: block_id,
payload: empty_block_contents.clone(),
});
}
self.raw_write_bucket(node_path, blocks);
}
/// Reads the bucket file at `path` and returns its blocks, decrypting it first unless
/// `args.disable_encryption` is set.
///
/// # Panics
/// Panics if the file contents cannot be deserialized.
fn raw_read_bucket(&self, path: String) -> Vec<Block> {
    let raw = self.io.read_file(path);
    if self.args.disable_encryption {
        let bucket: Bucket = bincode::deserialize(raw.as_slice()).unwrap();
        return bucket.blocks;
    }
    let sealed: EncryptedBytes = bincode::deserialize(raw.as_slice()).unwrap();
    self.decrypt_bucket(sealed).blocks
}
/// Serializes `blocks` as a `Bucket` and writes it to `path`, encrypting it unless
/// `args.disable_encryption` is set.
///
/// The plaintext/encrypted decision uses the same configuration switch as
/// `raw_read_bucket`, so a bucket is always read back in the format it was written in and
/// an encrypted ORAM never silently falls back to plaintext when no key is loaded.
///
/// # Panics
/// Panics if serialization or encryption fails.
fn raw_write_bucket(&mut self, path: String, blocks: Vec<Block>) {
    let bucket = Bucket {
        blocks,
        format: self.name(),
        version: String::from("1.0"),
    };
    let bytes: Vec<u8> = if self.args.disable_encryption {
        bincode::serialize(&bucket).unwrap()
    } else {
        let encrypted_bytes = self.encrypt_bucket(bucket);
        bincode::serialize(&encrypted_bytes).unwrap()
    };
    self.io.write_file(path, bytes.as_slice());
}
/// Serializes `bucket` and encrypts it with the configured cipher.
fn encrypt_bucket(&self, bucket: Bucket) -> EncryptedBytes {
    let mut serialized = bincode::serialize(&bucket).unwrap();
    let (iv, detached) = self.encrypt(&mut serialized);

    // `encrypt` either returns a separate ciphertext or has encrypted `serialized` in place.
    let sealed: &[u8] = match detached.as_ref() {
        Some(ciphertext) => ciphertext.as_slice(),
        None => serialized.as_slice(),
    };
    let mut ciphertext = BytesMut::new();
    ciphertext.extend_from_slice(sealed);

    EncryptedBytes {
        iv: Bytes::from(iv),
        ciphertext,
    }
}
fn decrypt_bucket(&self, encrypted_bytes: EncryptedBytes) -> Bucket {
let mut data = encrypted_bytes.ciphertext;
let ciphertext = data.as_byte_slice_mut();
match self.decrypt(encrypted_bytes.iv.bytes(), ciphertext) {
Some(plaintext) => {
let bucket: Bucket = bincode::deserialize(&plaintext).unwrap();
bucket
}
None => {
let bucket: Bucket = bincode::deserialize(&ciphertext).unwrap();
bucket
}
}
}
/// Encrypts `data` with the cipher named in `args.cipher` under a freshly generated IV.
///
/// Returns `(iv, ciphertext)`. For `"aes-ctr"` and `"chacha8"` the keystream is applied to
/// `data` in place and the ciphertext slot is `None`; for `"aes-gcm"` `data` is left
/// untouched and the ciphertext is returned separately.
///
/// # Panics
/// Panics on an unknown cipher name or if AES-GCM encryption fails.
fn encrypt(&self, data: &mut [u8]) -> (Vec<u8>, Option<Vec<u8>>) {
    match self.args.cipher.as_str() {
        "aes-ctr" => {
            let iv: [u8; 16] = thread_rng().gen();
            let key = GenericArray::from_slice(&self.encryption_key);
            let nonce = GenericArray::from_slice(&iv);
            let mut keystream = Aes128Ctr::new(key, nonce);
            keystream.apply_keystream(data);
            (iv.to_vec(), None)
        }
        "chacha8" => {
            let iv: [u8; 12] = thread_rng().gen();
            let key = Key::from_slice(&self.encryption_key);
            let nonce = Nonce::from_slice(&iv);
            let mut keystream = ChaCha8::new(key, nonce);
            keystream.apply_keystream(data);
            (iv.to_vec(), None)
        }
        "aes-gcm" => {
            let key = AesGcmKey::from_slice(&self.encryption_key);
            let aead = Aes128Gcm::new(key);
            let iv: [u8; 12] = thread_rng().gen();
            let nonce = AesGcmNonce::from_slice(&iv);
            let associated_data: &[u8] = b"oramfs";
            let sealed = aead
                .encrypt(
                    nonce,
                    Payload {
                        aad: associated_data,
                        msg: data,
                    },
                )
                .expect("AES-GCM encryption failure");
            (iv.to_vec(), Some(sealed))
        }
        _ => panic!("Unsupported cipher."),
    }
}
/// Decrypts `data` with the cipher named in `args.cipher`, using `iv` as IV / nonce.
///
/// For `"aes-ctr"` and `"chacha8"` the keystream is applied to `data` in place and `None`
/// is returned; for `"aes-gcm"` the plaintext is returned separately.
///
/// # Panics
/// Panics on an unknown cipher name or if AES-GCM decryption fails.
fn decrypt(&self, iv: &[u8], data: &mut [u8]) -> Option<Vec<u8>> {
    match self.args.cipher.as_str() {
        "aes-ctr" => {
            let key = GenericArray::from_slice(&self.encryption_key);
            let nonce = GenericArray::from_slice(iv);
            let mut keystream = Aes128Ctr::new(key, nonce);
            keystream.apply_keystream(data);
            None
        }
        "chacha8" => {
            let key = Key::from_slice(self.encryption_key.as_slice());
            let nonce = Nonce::from_slice(iv);
            let mut keystream = ChaCha8::new(key, nonce);
            keystream.apply_keystream(data);
            None
        }
        "aes-gcm" => {
            let key = AesGcmKey::from_slice(&self.encryption_key);
            let aead = Aes128Gcm::new(key);
            let nonce = AesGcmNonce::from_slice(iv);
            let associated_data: &[u8] = b"oramfs";
            let opened = aead
                .decrypt(
                    nonce,
                    Payload {
                        aad: associated_data,
                        msg: data,
                    },
                )
                .expect("[SECURITY WARNING] It looks like the ciphertext or tag has been tampered with. Aborting.");
            Some(opened)
        }
        _ => panic!("Unsupported cipher."),
    }
}
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use crate::io::MemoryIOService;
use crate::{ORAMConfig, PathORAM};
/// Builds the configuration shared by the tests: an in-memory Path ORAM with `n = 255`,
/// `z = 4`, `b = 16384`, created in init mode.
fn cli_for_oram(disable_encryption: bool) -> ORAMConfig {
    ORAMConfig {
        name: String::new(),
        private_directory: String::from("private"),
        public_directory: String::from("public"),
        mountpoint: String::new(),
        algorithm: String::from("pathoram"),
        cipher: String::new(),
        client_data_dir: String::new(),
        encrypted_encryption_key: String::new(),
        encryption_passphrase: String::new(),
        salt: String::new(),
        io: String::from("memory"),
        n: 255,
        z: 4,
        b: 16384,
        init: true,
        disable_encryption,
        manual: true,
        foreground: false,
        interactive: false,
        phc: String::new(),
    }
}
/// A block written through `access` reads back unchanged when encryption is disabled.
#[test]
fn test_access() {
    let args = cli_for_oram(true);
    let mut pathoram = PathORAM::new(&args, Box::new(MemoryIOService::new()));
    pathoram.setup();
    assert!(pathoram.verify_main_invariant());

    let data = Bytes::from(vec![43; args.b as usize]);
    let _ = pathoram.access("write", 1, Some(data.clone()));
    let read_bytes = pathoram.access("read", 1, None).unwrap();

    println!("{:?}", data);
    println!("{:?}", read_bytes);
    assert_eq!(data, read_bytes);
}
/// A block written through `access` reads back unchanged with AES-GCM encryption enabled.
#[test]
fn test_encryption() {
    let mut args = cli_for_oram(false);
    args.cipher = String::from("aes-gcm");
    args.encrypted_encryption_key =
        String::from("A38eJ8oREvmjEMVTAA68+m8KceZWaJ3vYlOrXo0Qe+Q=:ZDLbGGMuPSLeNycz");
    args.encryption_passphrase = String::from("a");
    args.salt = String::from("OS6qK/8mJA22SMWANwsiaw");

    let mut pathoram = PathORAM::new(&args, Box::new(MemoryIOService::new()));
    pathoram.io.write_file(
        args.encrypted_encryption_key.clone(),
        vec![66; 32].as_slice(),
    );
    pathoram.setup();
    assert!(pathoram.verify_main_invariant());

    let data = Bytes::from(vec![43; args.b as usize]);
    let _ = pathoram.access("write", 1, Some(data.clone()));
    let read_bytes = pathoram.access("read", 1, None).unwrap();

    println!("{:?}", data);
    println!("{:?}", read_bytes);
    assert_eq!(data, read_bytes);
}
}