use revm_primitives::{B256, U256};
use crate::alloc::vec::Vec;
use core::{
cmp::min,
fmt,
ops::{BitAnd, Not},
};
/// A single byte buffer shared by a stack of contexts.
///
/// Every context's memory lives back to back in one `Vec<u8>`. The current context only
/// sees the bytes from `last_checkpoint` to the end of the buffer.
#[derive(Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SharedMemory {
/// Backing storage holding the memory of every active context.
buffer: Vec<u8>,
/// Buffer lengths recorded by `new_context`, one entry per context not yet freed.
checkpoints: Vec<usize>,
/// Offset into `buffer` at which the current context's memory starts.
last_checkpoint: usize,
/// Upper bound used by `limit_reached`; only present with the `memory_limit` feature.
#[cfg(feature = "memory_limit")]
memory_limit: u64,
}
/// A `SharedMemory` with an empty buffer, no checkpoints and, when the `memory_limit`
/// feature is enabled, a limit of `u64::MAX`.
///
/// Unlike `SharedMemory::new`, nothing is pre-allocated: `Vec::new()` does not allocate,
/// which is what allows this value to be a `const`.
pub const EMPTY_SHARED_MEMORY: SharedMemory = SharedMemory {
buffer: Vec::new(),
checkpoints: Vec::new(),
last_checkpoint: 0,
#[cfg(feature = "memory_limit")]
memory_limit: u64::MAX,
};
impl fmt::Debug for SharedMemory {
    /// Prints the current context only: its length and its bytes as a hex string.
    ///
    /// The struct's other fields are left out, which `finish_non_exhaustive` signals in
    /// the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let current_len = self.len();
        let context_hex = crate::primitives::hex::encode(self.context_memory());

        let mut out = f.debug_struct("SharedMemory");
        out.field("current_len", &current_len);
        out.field("context_memory", &context_hex);
        out.finish_non_exhaustive()
    }
}
impl Default for SharedMemory {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl SharedMemory {
    /// Creates an empty shared memory whose buffer is pre-allocated with 4 KiB of capacity.
    #[inline]
    pub fn new() -> Self {
        Self::with_capacity(4 * 1024)
    }

    /// Creates an empty shared memory whose buffer is pre-allocated with `capacity` bytes.
    ///
    /// Room for 32 checkpoints is reserved up front as well.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(capacity),
            checkpoints: Vec::with_capacity(32),
            last_checkpoint: 0,
            #[cfg(feature = "memory_limit")]
            memory_limit: u64::MAX,
        }
    }

    /// Creates an empty shared memory like [`SharedMemory::new`], but with `memory_limit`
    /// as the bound checked by [`SharedMemory::limit_reached`].
    #[cfg(feature = "memory_limit")]
    #[inline]
    pub fn new_with_memory_limit(memory_limit: u64) -> Self {
        Self {
            memory_limit,
            ..Self::new()
        }
    }

    /// Returns `true` if sizing the current context to `new_size` bytes would make the
    /// shared buffer longer than `memory_limit`.
    ///
    /// The sum saturates at `usize::MAX` instead of wrapping, so an oversized request can
    /// never be mistaken for a small one.
    #[cfg(feature = "memory_limit")]
    #[inline]
    pub fn limit_reached(&self, new_size: usize) -> bool {
        self.last_checkpoint.saturating_add(new_size) as u64 > self.memory_limit
    }

    /// Opens a new context that starts at the current end of the buffer.
    ///
    /// The new context is empty until [`SharedMemory::resize`] is called.
    #[inline]
    pub fn new_context(&mut self) {
        let new_checkpoint = self.buffer.len();
        self.checkpoints.push(new_checkpoint);
        self.last_checkpoint = new_checkpoint;
    }

    /// Closes the current context: drops its bytes from the buffer and makes the previous
    /// checkpoint (or `0` if none is left) the current one.
    ///
    /// Does nothing when no checkpoint has been recorded.
    #[inline]
    pub fn free_context(&mut self) {
        if let Some(old_checkpoint) = self.checkpoints.pop() {
            self.last_checkpoint = self.checkpoints.last().copied().unwrap_or_default();
            // Checkpoints are buffer lengths recorded earlier, so this only ever shrinks
            // the buffer; `truncate` does that without needing `unsafe`.
            self.buffer.truncate(old_checkpoint);
        }
    }

    /// Returns the length in bytes of the current context's memory.
    #[inline]
    pub fn len(&self) -> usize {
        self.buffer.len() - self.last_checkpoint
    }

    /// Returns `true` if the current context's memory has length zero.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Resizes the current context's memory to `new_size` bytes, zero-filling any growth.
    ///
    /// The target buffer length saturates instead of wrapping, so an absurd `new_size`
    /// fails inside `Vec::resize` rather than silently shrinking the buffer below the
    /// current checkpoint.
    #[inline]
    pub fn resize(&mut self, new_size: usize) {
        self.buffer
            .resize(self.last_checkpoint.saturating_add(new_size), 0);
    }

    /// Returns the `size` bytes of the current context's memory starting at `offset`.
    ///
    /// The range is resolved against the current context only, with a saturating end, so
    /// it can never wrap around into another context's bytes. An out-of-bounds range is
    /// handed to `debug_unreachable!` (defined elsewhere in the crate; presumably a panic
    /// in debug builds only, so callers must size the memory first).
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn slice(&self, offset: usize, size: usize) -> &[u8] {
        let end = offset.saturating_add(size);
        self.context_memory()
            .get(offset..end)
            .unwrap_or_else(|| {
                debug_unreachable!("slice OOB: {offset}..{end}; len: {}", self.len())
            })
    }

    /// Mutable counterpart of [`SharedMemory::slice`], with the same bounds handling.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn slice_mut(&mut self, offset: usize, size: usize) -> &mut [u8] {
        // Read the length first: `self` is mutably borrowed for the rest of the function.
        let len = self.len();
        let end = offset.saturating_add(size);
        self.context_memory_mut()
            .get_mut(offset..end)
            .unwrap_or_else(|| debug_unreachable!("slice OOB: {offset}..{end}; len: {}", len))
    }

    /// Returns the byte at `offset` in the current context's memory.
    #[inline]
    pub fn get_byte(&self, offset: usize) -> u8 {
        self.slice(offset, 1)[0]
    }

    /// Returns the 32 bytes starting at `offset` as a `B256`.
    #[inline]
    pub fn get_word(&self, offset: usize) -> B256 {
        self.slice(offset, 32).try_into().unwrap()
    }

    /// Returns the 32 bytes starting at `offset` converted from `B256` into a `U256`.
    #[inline]
    pub fn get_u256(&self, offset: usize) -> U256 {
        self.get_word(offset).into()
    }

    /// Writes `byte` at `offset` in the current context's memory.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn set_byte(&mut self, offset: usize, byte: u8) {
        self.set(offset, &[byte]);
    }

    /// Writes the 32 bytes of `value` starting at `offset`.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn set_word(&mut self, offset: usize, value: &B256) {
        self.set(offset, &value[..]);
    }

    /// Writes `value` as 32 big-endian bytes starting at `offset`.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn set_u256(&mut self, offset: usize, value: U256) {
        self.set(offset, &value.to_be_bytes::<32>());
    }

    /// Copies `value` into the current context's memory starting at `offset`.
    ///
    /// An empty `value` is a no-op and performs no bounds check.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn set(&mut self, offset: usize, value: &[u8]) {
        if !value.is_empty() {
            self.slice_mut(offset, value.len()).copy_from_slice(value);
        }
    }

    /// Fills `len` bytes of memory at `memory_offset` from `data[data_offset..]`,
    /// zero-filling whatever `data` cannot supply.
    ///
    /// The memory range itself is expected to be in bounds already (see
    /// [`SharedMemory::slice_mut`]); only the range into `data` is clamped here.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn set_data(&mut self, memory_offset: usize, data_offset: usize, len: usize, data: &[u8]) {
        if data_offset >= data.len() {
            // Nothing to copy: the whole destination range becomes zeros.
            self.slice_mut(memory_offset, len).fill(0);
            return;
        }
        // Saturate so that a huge `len` clamps to `data.len()` instead of wrapping to a
        // value below `data_offset`.
        let data_end = min(data_offset.saturating_add(len), data.len());
        let data_len = data_end - data_offset;
        debug_assert!(data_offset < data.len() && data_end <= data.len());
        // SAFETY: `data_offset < data.len()` was checked above, and `data_end` is the
        // minimum of `data.len()` and a value that is at least `data_offset`, so
        // `data_offset <= data_end <= data.len()`.
        let data = unsafe { data.get_unchecked(data_offset..data_end) };
        self.slice_mut(memory_offset, data_len)
            .copy_from_slice(data);
        // Zero the tail of the destination that `data` did not cover.
        self.slice_mut(memory_offset + data_len, len - data_len)
            .fill(0);
    }

    /// Copies `len` bytes from `src` to `dst` within the current context's memory.
    ///
    /// This is a thin wrapper around `copy_within`, which allows the two ranges to
    /// overlap and panics if either range is out of bounds.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn copy(&mut self, dst: usize, src: usize, len: usize) {
        self.context_memory_mut().copy_within(src..src + len, dst);
    }

    /// Returns the current context's memory: the bytes from the last checkpoint to the
    /// end of the buffer.
    #[inline]
    pub fn context_memory(&self) -> &[u8] {
        // SAFETY: `last_checkpoint` is always a previously recorded buffer length (or 0),
        // and the buffer is never shrunk below it, so the range is within bounds.
        unsafe {
            self.buffer
                .get_unchecked(self.last_checkpoint..self.buffer.len())
        }
    }

    /// Mutable counterpart of [`SharedMemory::context_memory`].
    #[inline]
    fn context_memory_mut(&mut self) -> &mut [u8] {
        let buf_len = self.buffer.len();
        // SAFETY: same invariant as `context_memory`: `last_checkpoint <= buffer.len()`.
        unsafe { self.buffer.get_unchecked_mut(self.last_checkpoint..buf_len) }
    }
}
/// Rounds `x` up to the nearest multiple of 32; a multiple of 32 is returned unchanged.
///
/// The final addition saturates, so for `x > usize::MAX - 31` the result is `usize::MAX`,
/// which is not a multiple of 32.
#[inline]
pub fn next_multiple_of_32(x: usize) -> usize {
    // Distance to the next multiple: the two's-complement negation of `x mod 32`,
    // reduced mod 32 again so that exact multiples get a padding of 0 rather than 32.
    let remainder = x.bitand(31);
    let negated = Not::not(remainder).wrapping_add(1);
    let padding = negated.bitand(31);
    x.saturating_add(padding)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact multiples must map to themselves; everything else rounds up.
    #[test]
    fn test_next_multiple_of_32() {
        // next_multiple_of_32 returns x when it is a multiple of 32
        for i in 0..32 {
            let x = i * 32;
            assert_eq!(x, next_multiple_of_32(x));
        }

        // next_multiple_of_32 rounds up to the nearest multiple of 32 when `x % 32 != 0`
        for x in 0..1024 {
            if x % 32 == 0 {
                continue;
            }
            let next_multiple = x + 32 - (x % 32);
            assert_eq!(next_multiple, next_multiple_of_32(x));
        }
    }

    /// Checks the checkpoint bookkeeping across three nested contexts.
    ///
    /// Contexts are grown with the safe `resize` rather than `Vec::set_len`, which would
    /// expose uninitialised capacity and violate `set_len`'s safety contract.
    #[test]
    fn new_free_context() {
        let mut shared_memory = SharedMemory::new();

        // First context starts at offset 0.
        shared_memory.new_context();
        assert_eq!(shared_memory.buffer.len(), 0);
        assert_eq!(shared_memory.checkpoints.len(), 1);
        assert_eq!(shared_memory.last_checkpoint, 0);

        shared_memory.resize(32);
        assert_eq!(shared_memory.len(), 32);

        // Second context starts where the first one ends.
        shared_memory.new_context();
        assert_eq!(shared_memory.buffer.len(), 32);
        assert_eq!(shared_memory.checkpoints.len(), 2);
        assert_eq!(shared_memory.last_checkpoint, 32);
        assert_eq!(shared_memory.len(), 0);

        shared_memory.resize(64);
        assert_eq!(shared_memory.len(), 64);

        // Third context stays empty.
        shared_memory.new_context();
        assert_eq!(shared_memory.buffer.len(), 96);
        assert_eq!(shared_memory.checkpoints.len(), 3);
        assert_eq!(shared_memory.last_checkpoint, 96);
        assert_eq!(shared_memory.len(), 0);

        // Freeing the empty third context leaves the buffer untouched.
        shared_memory.free_context();
        assert_eq!(shared_memory.buffer.len(), 96);
        assert_eq!(shared_memory.checkpoints.len(), 2);
        assert_eq!(shared_memory.last_checkpoint, 32);
        assert_eq!(shared_memory.len(), 64);

        // Freeing the second context drops its 64 bytes.
        shared_memory.free_context();
        assert_eq!(shared_memory.buffer.len(), 32);
        assert_eq!(shared_memory.checkpoints.len(), 1);
        assert_eq!(shared_memory.last_checkpoint, 0);
        assert_eq!(shared_memory.len(), 32);

        // Freeing the first context empties the buffer.
        shared_memory.free_context();
        assert_eq!(shared_memory.buffer.len(), 0);
        assert_eq!(shared_memory.checkpoints.len(), 0);
        assert_eq!(shared_memory.last_checkpoint, 0);
        assert_eq!(shared_memory.len(), 0);
    }

    /// `resize` is relative to the current context and zero-fills new bytes.
    #[test]
    fn resize() {
        let mut shared_memory = SharedMemory::new();
        shared_memory.new_context();

        shared_memory.resize(32);
        assert_eq!(shared_memory.buffer.len(), 32);
        assert_eq!(shared_memory.len(), 32);
        assert_eq!(shared_memory.buffer.get(0..32), Some(&[0_u8; 32] as &[u8]));

        shared_memory.new_context();
        shared_memory.resize(96);
        assert_eq!(shared_memory.buffer.len(), 128);
        assert_eq!(shared_memory.len(), 96);
        assert_eq!(
            shared_memory.buffer.get(32..128),
            Some(&[0_u8; 96] as &[u8])
        );

        // Back in the first context, growing from 32 to 64 bytes.
        shared_memory.free_context();
        shared_memory.resize(64);
        assert_eq!(shared_memory.buffer.len(), 64);
        assert_eq!(shared_memory.len(), 64);
        assert_eq!(shared_memory.buffer.get(0..64), Some(&[0_u8; 64] as &[u8]));
    }
}