use std::clone::Clone;
use std::cmp::PartialEq;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::mem::size_of;
use hyperlight_common::mem::PAGE_SIZE_USIZE;
use crate::{log_then_return, new_error, Result};
// Signature of a read callback: given a shared reference to a memory object
// of type `S` and a `usize` (the suite below passes an offset), it yields a
// `Result<T>`.
type ReaderFn<S, T> = dyn Fn(&S, usize) -> Result<T>;
// Signature of a write callback: given a mutable reference to a memory object
// of type `S`, a `usize` (the suite below passes an offset) and a value of
// type `T`, it yields `Result<()>`.
type WriterFn<S, T> = dyn Fn(&mut S, usize, T) -> Result<()>;
/// Runs a fixed battery of read/write checks against a reader/writer pair.
///
/// Every scenario builds a fresh memory object of `PAGE_SIZE_USIZE` bytes via
/// `shared_memory_new`, then drives `reader` and/or `writer` at one offset.
///
/// Scenarios that must succeed (write `initial_val`, read it back, and compare
/// with `initial_val` converted to `T`):
/// * offset `0`
/// * offset `mem_size - size_of::<T>()`
/// * offset `mem_size / 2`
///
/// Scenarios that must fail:
/// * write-then-read at offset `mem_size * 2`
/// * write at offset `mem_size * 2`
/// * read at offset `mem_size`
/// * write at offset `mem_size`
///
/// # Errors
///
/// Returns the first error raised by a scenario that was expected to succeed,
/// an error when the value read back differs from the value written, an error
/// when `initial_val` cannot be converted to `T`, or an error when a scenario
/// that was expected to fail succeeded instead.
pub(super) fn read_write_test_suite<S, T, U, ShmNew: Fn(usize) -> Result<S>>(
    initial_val: U,
    shared_memory_new: ShmNew,
    reader: Box<ReaderFn<S, T>>,
    writer: Box<WriterFn<S, U>>,
) -> Result<()>
where
    T: PartialEq + Debug + Clone + TryFrom<U>,
    U: Debug + Clone,
{
    let region_len = PAGE_SIZE_USIZE;

    // Read only, on a freshly created memory object.
    let read_at = |len: usize, at: usize| -> Result<T> {
        let region = shared_memory_new(len)?;
        reader(&region, at)
    };

    // Write only, on a freshly created memory object.
    let write_at = |len: usize, at: usize, value: U| -> Result<()> {
        let mut region = shared_memory_new(len)?;
        writer(&mut region, at, value)
    };

    // Write `value`, read the same offset back, and require that the result
    // equals `value` converted to `T`.
    let round_trip = |len: usize, at: usize, value: U| -> Result<()> {
        let mut region = shared_memory_new(len)?;
        writer(&mut region, at, value.clone())?;
        let read_back = reader(&region, at)?;
        let expected: T =
            T::try_from(value.clone()).map_err(|_| new_error!("cannot convert types"))?;
        if expected == read_back {
            return Ok(());
        }
        log_then_return!(
            "(mem_size: {}, offset: {}, val: {:?}), actual returned val = {:?}",
            len,
            at,
            value,
            read_back,
        );
    };

    // Offsets at which a full round trip is expected to succeed: the start,
    // the last position at which a `T` still fits, and the midpoint.
    round_trip(region_len, 0, initial_val.clone())?;
    round_trip(region_len, region_len - size_of::<T>(), initial_val.clone())?;
    round_trip(region_len, region_len / 2, initial_val.clone())?;

    // Offsets at which every operation is expected to be rejected.
    swap_res(round_trip(region_len, region_len * 2, initial_val.clone()))?;
    swap_res(write_at(region_len, region_len * 2, initial_val.clone()))?;
    swap_res(read_at(region_len, region_len))?;
    swap_res(write_at(region_len, region_len, initial_val))?;

    Ok(())
}
/// Inverts the outcome of `r` for scenarios that are expected to fail.
///
/// An `Err` input (whatever the error is) becomes `Ok(())`; an `Ok` input is
/// logged and turned into an error stating that a failure was expected.
fn swap_res<T>(r: Result<T>) -> Result<()> {
    if r.is_ok() {
        log_then_return!("result was expected to be an error, but wasn't");
    }
    Ok(())
}