use std::cmp::min;
use std::fmt;
use std::io::Result as IoResult;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::mem::size_of;
use std::ptr::copy;
use std::ptr::{null_mut, read_volatile, write_volatile};
use std::result;
use std::slice::{from_raw_parts, from_raw_parts_mut};
use std::{isize, usize};
use DataInit;
/// Errors that can occur when accessing a region of volatile memory.
#[derive(Eq, PartialEq, Debug)]
pub enum VolatileMemoryError {
    /// `addr` lies outside the bounds of the volatile memory region.
    OutOfBounds { addr: u64 },
    /// Advancing `base` by `offset` would overflow a `u64` address.
    Overflow { base: u64, offset: u64 },
}
impl fmt::Display for VolatileMemoryError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
VolatileMemoryError::OutOfBounds { addr } => {
write!(f, "address 0x{:x} is out of bounds", addr)
}
VolatileMemoryError::Overflow { base, offset } => write!(
f,
"address 0x{:x} offset by 0x{:x} would overflow",
base, offset
),
}
}
}
/// Result type for volatile memory operations.
pub type VolatileMemoryResult<T> = result::Result<T, VolatileMemoryError>;
// Module-local shorthand aliases used by the implementations below.
use VolatileMemoryError as Error;
type Result<T> = VolatileMemoryResult<T>;
pub fn calc_offset(base: u64, offset: u64) -> Result<u64> {
match base.checked_add(offset) {
None => Err(Error::Overflow { base, offset }),
Some(m) => Ok(m),
}
}
/// Types that expose a region of raw memory intended for volatile access.
pub trait VolatileMemory {
    /// Gets a `VolatileSlice` covering `count` bytes starting at `offset`, or an error if the
    /// range is out of bounds or overflows.
    fn get_slice(&self, offset: u64, count: u64) -> Result<VolatileSlice>;
    /// Gets a typed `VolatileRef` to the `size_of::<T>()` bytes at `offset`.
    fn get_ref<T: DataInit>(&self, offset: u64) -> Result<VolatileRef<T>> {
        // Bounds are enforced by get_slice; only the pointer is retained.
        let slice = self.get_slice(offset, size_of::<T>() as u64)?;
        Ok(VolatileRef {
            addr: slice.addr as *mut T,
            phantom: PhantomData,
        })
    }
}
impl<'a> VolatileMemory for &'a mut [u8] {
fn get_slice(&self, offset: u64, count: u64) -> Result<VolatileSlice> {
let mem_end = calc_offset(offset, count)?;
if mem_end > self.len() as u64 {
return Err(Error::OutOfBounds { addr: mem_end });
}
Ok(unsafe { VolatileSlice::new((self.as_ptr() as u64 + offset) as *mut _, count) })
}
}
/// A slice of raw memory (pointer + length in bytes) intended for volatile access.
#[derive(Copy, Clone, Debug)]
pub struct VolatileSlice<'a> {
    addr: *mut u8,
    size: u64,
    // Ties the slice's lifetime to the borrowed memory without owning any data.
    phantom: PhantomData<&'a u8>,
}
impl<'a> Default for VolatileSlice<'a> {
fn default() -> VolatileSlice<'a> {
VolatileSlice {
addr: null_mut(),
size: 0,
phantom: PhantomData,
}
}
}
impl<'a> VolatileSlice<'a> {
    /// Creates a slice of raw memory that must support volatile access.
    ///
    /// # Safety
    /// The caller must guarantee that the memory at `addr` is `size` bytes long and stays valid
    /// for the lifetime of the new `VolatileSlice`, and that all other users of this memory use
    /// volatile accesses.
    pub unsafe fn new(addr: *mut u8, size: u64) -> VolatileSlice<'a> {
        VolatileSlice {
            addr,
            size,
            phantom: PhantomData,
        }
    }

    /// Returns the address of the region this slice covers.
    pub fn as_ptr(&self) -> *mut u8 {
        self.addr
    }

    /// Returns the size, in bytes, of the region this slice covers.
    pub fn size(&self) -> u64 {
        self.size
    }

    /// Returns a subslice that skips the first `count` bytes of this slice.
    ///
    /// Fails with `Overflow` if advancing the address by `count` overflows, and with
    /// `OutOfBounds` if `count` exceeds this slice's size.
    pub fn offset(self, count: u64) -> Result<VolatileSlice<'a>> {
        let new_addr =
            (self.addr as u64)
                .checked_add(count)
                .ok_or(VolatileMemoryError::Overflow {
                    base: self.addr as u64,
                    offset: count,
                })?;
        // Reject addresses that cannot be represented as a pointer on this target.
        // (Was `return Err(...)?;` — the trailing `?` on an already-built Err was a no-op
        // re-wrap; a plain return is the idiomatic form.)
        if new_addr > usize::MAX as u64 {
            return Err(VolatileMemoryError::Overflow {
                base: self.addr as u64,
                offset: count,
            });
        }
        let new_size = self
            .size
            .checked_sub(count)
            .ok_or(VolatileMemoryError::OutOfBounds { addr: new_addr })?;
        // SAFETY: the new range is a subrange of this slice, which the creator of this slice
        // guaranteed to be valid for volatile access.
        unsafe { Ok(VolatileSlice::new(new_addr as *mut u8, new_size)) }
    }

    /// Volatile-reads this slice into `buf` in `T`-sized chunks, stopping at whichever of the
    /// two ends first; any partial trailing chunk is left untouched.
    pub fn copy_to<T>(&self, buf: &mut [T])
    where
        T: DataInit,
    {
        let mut addr = self.addr;
        for v in buf.iter_mut().take(self.size as usize / size_of::<T>()) {
            // SAFETY: `take` bounds the reads to this slice's range, which the creator of this
            // slice guaranteed to be valid.
            unsafe {
                *v = read_volatile(addr as *const T);
                addr = addr.add(size_of::<T>());
            }
        }
    }

    /// Copies the overlapping prefix of this slice into `slice`.
    ///
    /// NOTE(review): this uses a plain (non-volatile) `ptr::copy`; confirm callers do not rely
    /// on volatile semantics for this transfer.
    pub fn copy_to_volatile_slice(&self, slice: VolatileSlice) {
        // SAFETY: both slices' creators guaranteed their ranges are valid, and the copy length
        // is clamped to the smaller of the two sizes.
        unsafe {
            copy(self.addr, slice.addr, min(self.size, slice.size) as usize);
        }
    }

    /// Volatile-writes `buf` into this slice in `T`-sized chunks, stopping at whichever of the
    /// two ends first.
    pub fn copy_from<T>(&self, buf: &[T])
    where
        T: DataInit,
    {
        let mut addr = self.addr;
        for &v in buf.iter().take(self.size as usize / size_of::<T>()) {
            // SAFETY: `take` bounds the writes to this slice's range, which the creator of this
            // slice guaranteed to be valid.
            unsafe {
                write_volatile(addr as *mut T, v);
                addr = addr.add(size_of::<T>());
            }
        }
    }

    /// Writes this slice's bytes to `w`; like `Write::write`, this may write only a prefix and
    /// returns the number of bytes written.
    pub fn write_to<T: Write>(&self, w: &mut T) -> IoResult<usize> {
        // SAFETY: the borrowed view exists only for the duration of this call.
        w.write(unsafe { self.as_slice() })
    }

    /// Writes all of this slice's bytes to `w`, failing if any write fails.
    pub fn write_all_to<T: Write>(&self, w: &mut T) -> IoResult<()> {
        // SAFETY: the borrowed view exists only for the duration of this call.
        w.write_all(unsafe { self.as_slice() })
    }

    /// Reads up to `self.size()` bytes from `r` into this slice, returning the count read.
    pub fn read_from<T: Read>(&self, r: &mut T) -> IoResult<usize> {
        // SAFETY: the mutable view exists only for the duration of this call.
        r.read(unsafe { self.as_mut_slice() })
    }

    /// Fills this entire slice from `r`, failing if `r` ends before the slice is full.
    pub fn read_exact_from<T: Read>(&self, r: &mut T) -> IoResult<()> {
        // SAFETY: the mutable view exists only for the duration of this call.
        r.read_exact(unsafe { self.as_mut_slice() })
    }

    // SAFETY contract for both helpers: the caller must not let the returned view outlive the
    // underlying memory or alias it with a conflicting access.
    unsafe fn as_slice(&self) -> &[u8] {
        from_raw_parts(self.addr, self.size as usize)
    }

    unsafe fn as_mut_slice(&self) -> &mut [u8] {
        from_raw_parts_mut(self.addr, self.size as usize)
    }
}
impl<'a> VolatileMemory for VolatileSlice<'a> {
fn get_slice(&self, offset: u64, count: u64) -> Result<VolatileSlice> {
let mem_end = calc_offset(offset, count)?;
if mem_end > self.size {
return Err(Error::OutOfBounds { addr: mem_end });
}
Ok(VolatileSlice {
addr: (self.addr as u64 + offset) as *mut _,
size: count,
phantom: PhantomData,
})
}
}
/// A typed location in raw memory intended for volatile loads and stores of a `T`.
#[derive(Debug)]
pub struct VolatileRef<'a, T: DataInit>
where
    T: 'a,
{
    addr: *mut T,
    // Ties the reference's lifetime to the borrowed memory without owning any data.
    phantom: PhantomData<&'a T>,
}
impl<'a, T: DataInit> VolatileRef<'a, T> {
    /// Creates a reference to raw memory that must support volatile access of a `T`.
    ///
    /// # Safety
    /// The caller must guarantee that `addr` points to `size_of::<T>()` bytes that stay valid
    /// for the lifetime `'a`, and that all other users of this memory use volatile accesses.
    pub unsafe fn new(addr: *mut T) -> VolatileRef<'a, T> {
        VolatileRef {
            addr,
            phantom: PhantomData,
        }
    }

    /// Returns the raw pointer backing this reference.
    pub fn as_ptr(&self) -> *mut T {
        self.addr
    }

    /// Returns the size of the referenced type in bytes.
    pub fn size(&self) -> u64 {
        size_of::<T>() as u64
    }

    /// Performs a volatile store of `v` to the referenced location.
    #[inline(always)]
    pub fn store(&self, v: T) {
        // SAFETY: the constructor's contract guarantees a valid, live pointer.
        unsafe { write_volatile(self.addr, v) }
    }

    /// Performs a volatile load from the referenced location.
    #[inline(always)]
    pub fn load(&self) -> T {
        // SAFETY: the constructor's contract guarantees a valid, live pointer.
        unsafe { read_volatile(self.addr) }
    }

    /// Converts this typed reference into a byte `VolatileSlice` over the same memory.
    pub fn to_slice(&self) -> VolatileSlice<'a> {
        // SAFETY: the slice covers exactly the bytes this reference is valid for.
        unsafe { VolatileSlice::new(self.addr as *mut u8, size_of::<T>() as u64) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread::{sleep, spawn};
    use std::time::Duration;

    // Simple heap-backed VolatileMemory implementation for exercising the trait.
    #[derive(Clone)]
    struct VecMem {
        mem: Arc<Vec<u8>>,
    }
    impl VecMem {
        fn new(size: usize) -> VecMem {
            VecMem {
                // vec! is the idiomatic way to build a zero-filled buffer of a given length.
                mem: Arc::new(vec![0; size]),
            }
        }
    }
    impl VolatileMemory for VecMem {
        fn get_slice(&self, offset: u64, count: u64) -> Result<VolatileSlice> {
            let mem_end = calc_offset(offset, count)?;
            if mem_end > self.mem.len() as u64 {
                return Err(Error::OutOfBounds { addr: mem_end });
            }
            Ok(unsafe { VolatileSlice::new((self.mem.as_ptr() as u64 + offset) as *mut _, count) })
        }
    }

    #[test]
    fn ref_store() {
        let mut a = [0u8; 1];
        {
            let a_ref = &mut a[..];
            let v_ref = a_ref.get_ref(0).unwrap();
            v_ref.store(2u8);
        }
        assert_eq!(a[0], 2);
    }

    #[test]
    fn ref_load() {
        let mut a = [5u8; 1];
        {
            let a_ref = &mut a[..];
            // The VolatileRef may outlive the inner block as long as the borrow of `a` does.
            let c = {
                let v_ref = a_ref.get_ref::<u8>(0).unwrap();
                assert_eq!(v_ref.load(), 5u8);
                v_ref
            };
            c.load();
        }
    }

    #[test]
    fn ref_to_slice() {
        let mut a = [1u8; 5];
        let a_ref = &mut a[..];
        // NOTE(review): this stores a u32 at byte offset 1, which is not 4-byte aligned;
        // write_volatile requires aligned pointers, so this relies on the platform tolerating
        // unaligned access — confirm against DataInit's intended contract.
        let v_ref = a_ref.get_ref(1).unwrap();
        v_ref.store(0x12345678u32);
        let ref_slice = v_ref.to_slice();
        assert_eq!(v_ref.as_ptr() as u64, ref_slice.as_ptr() as u64);
        assert_eq!(v_ref.size(), ref_slice.size());
    }

    #[test]
    fn observe_mutate() {
        let a = VecMem::new(1);
        let a_clone = a.clone();
        let v_ref = a.get_ref::<u8>(0).unwrap();
        v_ref.store(99);
        // Keep the handle so the thread can be joined instead of leaking past the test.
        let mutator = spawn(move || {
            sleep(Duration::from_millis(10));
            let clone_v_ref = a_clone.get_ref::<u8>(0).unwrap();
            clone_v_ref.store(0);
        });
        assert_eq!(v_ref.load(), 99);
        // Volatile loads prevent the compiler from hoisting the read out of the busy-wait.
        #[cfg(debug_assertions)]
        const RETRY_MAX: u64 = 500_000_000;
        #[cfg(not(debug_assertions))]
        const RETRY_MAX: u64 = 10_000_000_000;
        let mut retry = 0;
        while v_ref.load() == 99 && retry < RETRY_MAX {
            retry += 1;
        }
        assert_ne!(retry, RETRY_MAX, "maximum retry exceeded");
        assert_eq!(v_ref.load(), 0);
        mutator.join().unwrap();
    }

    #[test]
    fn slice_size() {
        let a = VecMem::new(100);
        let s = a.get_slice(0, 27).unwrap();
        assert_eq!(s.size(), 27);
        let s = a.get_slice(34, 27).unwrap();
        assert_eq!(s.size(), 27);
        let s = s.get_slice(20, 5).unwrap();
        assert_eq!(s.size(), 5);
    }

    #[test]
    fn slice_overflow_error() {
        use std::u64::MAX;
        let a = VecMem::new(1);
        let res = a.get_slice(MAX, 1).unwrap_err();
        assert_eq!(
            res,
            Error::Overflow {
                base: MAX,
                offset: 1,
            }
        );
    }

    #[test]
    fn slice_oob_error() {
        let a = VecMem::new(100);
        a.get_slice(50, 50).unwrap();
        let res = a.get_slice(55, 50).unwrap_err();
        assert_eq!(res, Error::OutOfBounds { addr: 105 });
    }

    #[test]
    fn ref_overflow_error() {
        use std::u64::MAX;
        let a = VecMem::new(1);
        let res = a.get_ref::<u8>(MAX).unwrap_err();
        assert_eq!(
            res,
            Error::Overflow {
                base: MAX,
                offset: 1,
            }
        );
    }

    #[test]
    fn ref_oob_error() {
        let a = VecMem::new(100);
        a.get_ref::<u8>(99).unwrap();
        let res = a.get_ref::<u16>(99).unwrap_err();
        assert_eq!(res, Error::OutOfBounds { addr: 101 });
    }

    #[test]
    fn ref_oob_too_large() {
        let a = VecMem::new(3);
        let res = a.get_ref::<u32>(0).unwrap_err();
        assert_eq!(res, Error::OutOfBounds { addr: 4 });
    }
}