use std::borrow::Borrow;
use crate::prelude::K;
use bytemuck::Pod;
use rostl_primitives::traits::Cmov;
use rostl_sort::rotate::rotate_left;
/// A linear-scan oblivious RAM.
///
/// Every access method in this file walks the whole of `data` and applies conditional moves
/// (`Cmov`) to each slot, so the sequence of slots visited does not depend on the index being
/// accessed (assuming `Cmov` itself is constant-time).
#[derive(Debug)]
pub struct LinearORAM<T: Cmov + Pod> {
    /// Backing storage: one slot per logical address.
    pub data: Vec<T>,
}
/// Obliviously swaps `value` into `data[index]`, handing the previous contents back through `ret`.
///
/// Every slot is visited in order and receives the same two conditional moves, so the sequence
/// of slots touched is independent of `index`. If `index >= data.len()`, neither `data` nor
/// `ret` is modified.
#[inline]
pub fn oblivious_read_update_index<T: Cmov>(data: &mut [T], index: usize, ret: &mut T, value: T) {
    for (slot, position) in data.iter_mut().zip(0usize..) {
        let hit = position == index;
        // Capture the old contents before the slot is overwritten.
        ret.cmov(slot, hit);
        slot.cmov(&value, hit);
    }
}
/// Obliviously copies `data[index]` into `out`.
///
/// Each element is offered to `out` through a conditional move and only the one at `index` is
/// taken. An out-of-range `index` leaves `out` unchanged.
#[inline]
pub fn oblivious_read_index<T: Cmov>(data: &[T], index: usize, out: &mut T) {
    data.iter()
        .enumerate()
        .for_each(|(position, candidate)| out.cmov(candidate, position == index));
}
/// Obliviously writes `value` into `data[index]`.
///
/// Every slot receives a conditional move from `value`, and only the slot at `index` actually
/// changes. `value` may be passed owned or borrowed (anything implementing `Borrow<T>`). An
/// out-of-range `index` leaves `data` unchanged.
#[inline]
pub fn oblivious_write_index<T: Cmov, U: Borrow<T>>(data: &mut [T], index: usize, value: U) {
    let source: &T = value.borrow();
    for (position, slot) in (0usize..).zip(data.iter_mut()) {
        slot.cmov(source, position == index);
    }
}
/// Obliviously copies the window `src[src_offset..src_offset + dst.len()]` into `dst`.
///
/// The copy loop visits every element of `src` exactly once and issues one conditional move per
/// element, so the slots it touches do not depend on `src_offset`. Element `src_offset + j` first
/// lands in `dst[(src_offset + j) % dst.len()]`, and a final left rotation by
/// `src_offset % dst.len()` moves it to `dst[j]`.
///
/// Edge cases:
/// - If the window runs past the end of `src`, only `dst[..src.len() - src_offset]` receives
///   source data. The remaining slots hold earlier contents of `dst`, in rotated order.
/// - If `src_offset >= src.len()`, nothing is copied and no rotation happens, so `dst` is
///   left unchanged.
/// - An empty `dst` is a no-op.
///
/// NOTE(review): the obliviousness of `rostl_sort::rotate_left` is assumed, not visible here.
#[inline]
pub fn oblivious_memcpy<T: Cmov + Copy>(dst: &mut [T], src: &[T], src_offset: usize) {
    let len = dst.len();
    // `dst.len()` is public, so this early return leaks nothing. Without it, the `% len`
    // operations below would panic with a division by zero.
    if len == 0 {
        return;
    }
    // Exact upper bound of the window. If the true sum exceeds `usize::MAX`, saturating is still
    // exact, because every slice index is strictly below `usize::MAX`. A plain `+` would
    // overflow instead.
    let end = src_offset.saturating_add(len);
    for (i, item) in src.iter().enumerate() {
        // Non-short-circuiting `&`, so no branch depends on the (possibly secret) `src_offset`.
        let choice = (i >= src_offset) & (i < end);
        dst[i % len].cmov(item, choice);
    }
    let mut shift_amount = src_offset % len;
    // When nothing was copied, skip the rotation so `dst` stays exactly as it was.
    shift_amount.cmov(&0, src_offset >= src.len());
    rotate_left(dst, shift_amount);
}
impl<T> LinearORAM<T>
where
T: Cmov + Pod + Default + std::fmt::Debug,
{
pub fn new(max_n: usize) -> Self {
Self { data: vec![T::default(); max_n] }
}
pub fn read(&self, index: K, ret: &mut T) {
debug_assert!(index < self.data.len());
for i in 0..self.data.len() {
let choice = i == index;
ret.cmov(&self.data[i], choice);
}
}
pub fn write(&mut self, index: K, value: T) {
for i in 0..self.data.len() {
let choice = i == index;
self.data[i].cmov(&value, choice);
}
}
pub fn read_update(&mut self, index: K, value: T, ret: &mut T) {
oblivious_read_update_index(&mut self.data, index, ret, value);
}
#[cfg(test)]
pub(crate) fn print_for_debug(&self) {
for i in 0..self.data.len() {
print!("{:?}, ", self.data[i]);
}
println!();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_read() {
        let default = 0;
        let index = 3;
        let oram = LinearORAM::<u32>::new(10);
        let mut ret = 0;
        oram.read(index, &mut ret);
        assert_eq!(ret, default);
    }

    #[test]
    fn test_write() {
        let default = 0;
        let new_value = 25;
        let index = 3;
        let mut oram = LinearORAM::<u32>::new(10);
        oram.write(index, new_value);
        let mut ret = default;
        oram.read(index, &mut ret);
        assert_eq!(ret, new_value);
    }

    #[test]
    fn test_read_update() {
        let mut oram = LinearORAM::<u32>::new(6);
        oram.write(2, 7);
        let mut ret = 0;
        oram.read_update(2, 9, &mut ret);
        // read_update hands back the old value and stores the new one.
        assert_eq!(ret, 7);
        let mut after = 0;
        oram.read(2, &mut after);
        assert_eq!(after, 9);
        // No other slot was modified.
        assert_eq!(oram.data, vec![0, 0, 9, 0, 0, 0]);
    }

    #[test]
    fn test_index_helpers_ignore_out_of_range() {
        let mut data = vec![1u32, 2, 3];

        let mut out = 42;
        oblivious_read_index(&data, 3, &mut out);
        assert_eq!(out, 42);

        oblivious_write_index(&mut data, 3, 99u32);
        assert_eq!(data, vec![1, 2, 3]);

        let mut ret = 42;
        oblivious_read_update_index(&mut data, 7, &mut ret, 99);
        assert_eq!(ret, 42);
        assert_eq!(data, vec![1, 2, 3]);
    }

    #[test]
    fn test_oblivious_memcpy() {
        let src = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut dst = vec![0u8; 5];
        oblivious_memcpy(&mut dst, &src, 0);
        assert_eq!(dst, vec![1, 2, 3, 4, 5]);
        oblivious_memcpy(&mut dst, &src, 1);
        assert_eq!(dst, vec![2, 3, 4, 5, 6]);
        oblivious_memcpy(&mut dst, &src, 2);
        assert_eq!(dst, vec![3, 4, 5, 6, 7]);
        oblivious_memcpy(&mut dst, &src, 3);
        assert_eq!(dst, vec![4, 5, 6, 7, 8]);
        oblivious_memcpy(&mut dst, &src, 4);
        assert_eq!(dst, vec![5, 6, 7, 8, 9]);
        oblivious_memcpy(&mut dst, &src, 5);
        assert_eq!(dst, vec![6, 7, 8, 9, 10]);
        // Partial windows: only the prefix backed by `src` is specified.
        oblivious_memcpy(&mut dst, &src, 6);
        assert_eq!(dst[..4], vec![7, 8, 9, 10]);
        oblivious_memcpy(&mut dst, &src, 7);
        assert_eq!(dst[..3], vec![8, 9, 10]);
        oblivious_memcpy(&mut dst, &src, 8);
        assert_eq!(dst[..2], vec![9, 10]);
        oblivious_memcpy(&mut dst, &src, 9);
        assert_eq!(dst[..1], vec![10]);
        // An offset at or past the end of `src` copies nothing and must not rotate `dst`.
        let before = dst.clone();
        oblivious_memcpy(&mut dst, &src, 10);
        assert_eq!(dst, before);
        oblivious_memcpy(&mut dst, &src, 12);
        assert_eq!(dst, before);
    }
}