use crate::{
memory::{Entry, PagedMemory, MAX_LOG_ADDR},
ExecutionError, Opcode,
};
use std::cell::RefCell;
/// Paged memory that either owns a single backing store or layers a writable
/// `copy` on top of an `original` that is only read from while in that mode.
pub enum MaybeCowMemory<T: Copy> {
    /// Copy-on-write mode: lookups try `copy` first and fall back to
    /// `original`; all writes land in `copy`.
    Cow {
        copy: PagedMemory<T>,
        original: PagedMemory<T>,
    },
    /// Plain mode: one store serves both reads and writes.
    Owned { memory: PagedMemory<T> },
}
impl<T: Copy> MaybeCowMemory<T> {
    /// Creates an empty memory in owned mode.
    pub fn new_owned() -> Self {
        let memory = PagedMemory::default();
        Self::Owned { memory }
    }

    /// Creates a copy-on-write memory that reads through to `original` and
    /// starts with an empty write layer.
    pub fn new_cow(original: PagedMemory<T>) -> Self {
        let copy = PagedMemory::default();
        Self::Cow { copy, original }
    }

    /// Moves from owned mode into copy-on-write mode; the current contents
    /// become the `original` layer. Does nothing if already copy-on-write.
    pub fn copy_on_write(&mut self) {
        if let Self::Owned { memory } = self {
            let original = std::mem::take(memory);
            *self = Self::new_cow(original);
        }
    }

    /// Leaves copy-on-write mode by keeping only the `original` layer.
    /// Everything written to the `copy` layer is dropped. Does nothing if
    /// already in owned mode.
    pub fn owned(&mut self) {
        if let Self::Cow { original, .. } = self {
            let memory = std::mem::take(original);
            *self = Self::Owned { memory };
        }
    }

    /// Reads the value at `addr`, consulting `copy` before `original` in
    /// copy-on-write mode.
    ///
    /// # Panics
    /// If `addr` is not a multiple of 8.
    pub fn get(&self, addr: u64) -> Option<&T> {
        assert!(addr.is_multiple_of(8), "Address must be a multiple of 8");
        match self {
            Self::Owned { memory } => memory.get(addr),
            Self::Cow { copy, original } => match copy.get(addr) {
                Some(value) => Some(value),
                None => original.get(addr),
            },
        }
    }

    /// Returns the writable-layer entry for `addr`.
    ///
    /// In copy-on-write mode, a value that exists only in `original` is first
    /// copied into `copy`; the returned flag is `true` exactly when such a copy
    /// happened (always `false` in owned mode).
    ///
    /// # Panics
    /// If `addr` is not a multiple of 8.
    pub fn entry(&mut self, addr: u64) -> (Entry<'_, T>, bool) {
        assert!(addr.is_multiple_of(8), "Address must be a multiple of 8");
        let duplicated = match self {
            Self::Owned { .. } => false,
            Self::Cow { copy, original } => match copy.entry(addr) {
                Entry::Vacant(slot) => match original.get(addr) {
                    Some(&value) => {
                        slot.insert(value);
                        true
                    }
                    None => false,
                },
                Entry::Occupied(_) => false,
            },
        };
        // Re-borrow after the promotion above so the returned entry reflects it.
        let writable = match self {
            Self::Owned { memory } => memory,
            Self::Cow { copy, .. } => copy,
        };
        (writable.entry(addr), duplicated)
    }

    /// Stores `value` at `addr` in the writable layer (`copy` in copy-on-write
    /// mode) and returns that layer's previous value, if any. In copy-on-write
    /// mode a value present only in `original` is not reported as previous.
    ///
    /// # Panics
    /// If `addr` is not a multiple of 8.
    pub fn insert(&mut self, addr: u64, value: T) -> Option<T> {
        assert!(addr.is_multiple_of(8), "Address must be a multiple of 8");
        let writable = match self {
            Self::Owned { memory } => memory,
            Self::Cow { copy, .. } => copy,
        };
        writable.insert(addr, value)
    }
}
/// Tracks how many memory entries are stored, optionally against a cap.
#[derive(Clone)]
enum Limiter {
    /// No cap on entries and no address-range checking.
    NoLimit,
    /// `current` entries counted so far, out of at most `limit`.
    Limit {
        current: usize,
        limit: usize,
    },
}
impl Limiter {
    /// Builds a limiter from an optional byte budget.
    ///
    /// `None` disables limiting entirely. `Some(bytes)` caps the entry count at
    /// `bytes / 8` (the memory is addressed in 8-byte aligned words). The count
    /// saturates at `usize::MAX` instead of being truncated.
    fn new(memory_limit: Option<u64>) -> Self {
        match memory_limit {
            Some(memory_limit) => {
                // Clamp before narrowing: a bare `as usize` wraps on targets where
                // `usize` is narrower than `u64`, turning a huge budget into a tiny one.
                let limit = (memory_limit / 8).min(usize::MAX as u64) as usize;
                Self::Limit { current: 0, limit }
            }
            None => Self::NoLimit,
        }
    }

    /// Counts one additional stored entry.
    ///
    /// # Errors
    /// Returns `ExecutionError::TooMuchMemory` once the count exceeds the limit.
    /// The counter is incremented even when the error is returned.
    fn increase(&mut self) -> Result<(), ExecutionError> {
        if let Self::Limit { current, limit } = self {
            *current += 1;
            if *current > *limit {
                return Err(ExecutionError::TooMuchMemory());
            }
        }
        Ok(())
    }

    /// In limited mode, rejects `addr` if the 8-byte word starting there would
    /// extend past `1 << MAX_LOG_ADDR`. Unlimited mode performs no check.
    ///
    /// # Errors
    /// `ExecutionError::InvalidMemoryAccess`, tagged with `Opcode::SD` for
    /// writes and `Opcode::LD` for reads.
    fn check_ptr(&self, addr: u64, write: bool) -> Result<(), ExecutionError> {
        if let Self::Limit { .. } = self {
            let max_memory = 1u64 << MAX_LOG_ADDR;
            if addr > max_memory - 8 {
                return Err(ExecutionError::InvalidMemoryAccess(
                    if write { Opcode::SD } else { Opcode::LD },
                    addr,
                ));
            }
        }
        Ok(())
    }
}
/// A [`MaybeCowMemory`] guarded by an optional entry-count limit and an
/// address-range check.
///
/// Accessors do not return errors. Instead, the first error is stored in
/// `last_error`, and later accesses become no-ops or placeholder reads.
pub struct LimitedMemory<T: Copy> {
    /// Backing storage (owned or copy-on-write).
    memory: MaybeCowMemory<T>,
    /// Running count of stored entries against the limit.
    limiter: Limiter,
    /// Limiter snapshot taken when entering copy-on-write mode; restored by `owned`.
    before_cow: Option<Limiter>,
    /// First recorded error; a `RefCell` so `&self` accessors can record it.
    last_error: RefCell<Option<ExecutionError>>,
    /// Placeholder handed out by accessors once an error has been recorded.
    dummy_value: T,
}
/// Records the error from `result` in `last_error`, unless an error is already
/// stored there (only the first error is kept). `Ok` results leave it untouched.
///
/// # Panics
/// If `last_error` is already borrowed.
fn c(last_error: &RefCell<Option<ExecutionError>>, result: Result<(), ExecutionError>) {
    let mut slot = last_error.try_borrow_mut().expect("borrow twice");
    if slot.is_some() {
        return;
    }
    if let Err(error) = result {
        *slot = Some(error);
    }
}
impl<T: Copy + Default> LimitedMemory<T> {
    /// Creates an empty memory in owned mode.
    ///
    /// `memory_limit` is a byte budget passed to the limiter. `None` disables
    /// both the entry-count limit and the address-range check.
    pub fn new_owned(memory_limit: Option<u64>) -> Self {
        Self {
            memory: MaybeCowMemory::new_owned(),
            limiter: Limiter::new(memory_limit),
            before_cow: None,
            last_error: None.into(),
            dummy_value: T::default(),
        }
    }

    /// Enters copy-on-write mode and snapshots the limiter so `owned` can roll
    /// the accounting back together with the memory.
    ///
    /// Calling it again while already in copy-on-write mode is a no-op, the same
    /// as `MaybeCowMemory::copy_on_write`. The first snapshot is kept.
    #[inline]
    pub fn copy_on_write(&mut self) {
        self.memory.copy_on_write();
        // Snapshot only on entry; overwriting it here would capture entries
        // allocated during copy-on-write, and `owned` would then restore an
        // inflated count while discarding those entries.
        if self.before_cow.is_none() {
            self.before_cow = Some(self.limiter.clone());
        }
    }

    /// Leaves copy-on-write mode and discards the write layer. Restores the
    /// limiter snapshot taken by `copy_on_write`.
    ///
    /// A no-op when not in copy-on-write mode, the same as
    /// `MaybeCowMemory::owned`. A recorded `last_error` is not cleared.
    pub fn owned(&mut self) {
        self.memory.owned();
        if let Some(limiter) = self.before_cow.take() {
            self.limiter = limiter;
        }
    }

    /// Reads the word at `addr`, validating it as a read first.
    ///
    /// Once any error has been recorded, returns the placeholder value instead
    /// of touching memory.
    #[inline]
    pub fn get(&self, addr: u64) -> Option<&T> {
        self.check_ptr(addr, false);
        if self.has_last_error() {
            return Some(&self.dummy_value);
        }
        self.memory.get(addr)
    }

    /// Writes `value` at `addr`, validating it as a write first.
    ///
    /// Once any error has been recorded, the write is skipped. If the write
    /// creates a new slot in the writable layer, that slot is counted against
    /// the limit. An overflow is recorded in `last_error`, but the value has
    /// already been stored by then.
    #[inline]
    pub fn insert(&mut self, addr: u64, value: T) {
        self.check_ptr(addr, true);
        if self.has_last_error() {
            return;
        }
        let previous_value = self.memory.insert(addr, value);
        if previous_value.is_none() {
            c(&self.last_error, self.limiter.increase());
        }
    }

    /// Runs the limiter's address check and records any failure in `last_error`.
    #[inline]
    fn check_ptr(&self, addr: u64, write: bool) {
        c(&self.last_error, self.limiter.check_ptr(addr, write));
    }

    /// Whether an error has been recorded.
    #[inline]
    pub fn has_last_error(&self) -> bool {
        self.last_error.borrow().is_some()
    }

    /// Returns a clone of the recorded error.
    ///
    /// # Panics
    /// If no error has been recorded; check `has_last_error` first.
    #[inline]
    pub fn last_error(&self) -> ExecutionError {
        self.last_error
            .borrow()
            .clone()
            .expect("last_error called but no error was recorded")
    }
}
impl<T: Copy + Default> LimitedMemory<T> {
    /// Returns a mutable reference to the word at `addr`, creating it with
    /// `T::default()` if absent.
    ///
    /// The address is validated as a write. If any error has been recorded, by
    /// this check or earlier, the placeholder value is returned instead.
    ///
    /// A fresh slot counts against the limit. So does a value copied up from
    /// the original layer in copy-on-write mode. An overflow is recorded in
    /// `last_error`, but a reference into memory is still returned.
    #[inline]
    pub fn get_mut(&mut self, addr: u64) -> &'_ mut T {
        self.check_ptr(addr, true);
        if self.has_last_error() {
            &mut self.dummy_value
        } else {
            let (entry, copied_from_original) = self.memory.entry(addr);
            let is_new_slot = matches!(entry, Entry::Vacant(_));
            if copied_from_original || is_new_slot {
                c(&self.last_error, self.limiter.increase());
            }
            entry.or_default()
        }
    }
}