use super::MemoryTr;
use crate::InstructionResult;
use context_interface::cfg::GasParams;
use core::{
cell::{Ref, RefCell, RefMut},
cmp::min,
fmt,
ops::Range,
};
use primitives::{hex, B256, U256};
use std::{rc::Rc, vec::Vec};
/// Extension trait over [`RefCell`] providing borrow helpers that treat a
/// borrow conflict as a logic bug rather than a recoverable error (see the
/// impl below).
trait RefcellExt<T> {
    /// Immutably borrows the cell; a borrow conflict is considered unreachable.
    fn dbg_borrow(&self) -> Ref<'_, T>;
    /// Mutably borrows the cell; a borrow conflict is considered unreachable.
    fn dbg_borrow_mut(&self) -> RefMut<'_, T>;
}
impl<T> RefcellExt<T> for RefCell<T> {
    #[inline]
    fn dbg_borrow(&self) -> Ref<'_, T> {
        match self.try_borrow() {
            Ok(b) => b,
            // A failed borrow here is an internal invariant violation.
            // NOTE(review): `debug_unreachable!` is defined elsewhere in the
            // crate — presumably it panics in debug builds and acts as an
            // unreachable hint in release; confirm its definition.
            Err(e) => debug_unreachable!("{e}"),
        }
    }
    #[inline]
    fn dbg_borrow_mut(&self) -> RefMut<'_, T> {
        match self.try_borrow_mut() {
            Ok(b) => b,
            // Same invariant as `dbg_borrow`: overlapping borrows are a bug.
            Err(e) => debug_unreachable!("{e}"),
        }
    }
}
/// Memory shared between call contexts: one flat byte buffer is shared by a
/// stack of contexts, each owning the suffix that starts at its checkpoint.
#[derive(Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SharedMemory {
    /// The shared backing buffer. `None` only for instances built with
    /// [`SharedMemory::invalid`]; every memory operation assumes `Some`.
    buffer: Option<Rc<RefCell<Vec<u8>>>>,
    /// Offset into `buffer` where this context's memory begins.
    my_checkpoint: usize,
    /// Buffer length recorded when a child context was created. `Some` while
    /// a child is active; the buffer is truncated back to this length by
    /// `free_child_context`.
    child_checkpoint: Option<usize>,
    /// Hard cap, in bytes, on the global buffer extent for this context.
    #[cfg(feature = "memory_limit")]
    memory_limit: u64,
}
impl fmt::Debug for SharedMemory {
    /// Formats the context-local view of memory (length plus hex dump),
    /// eliding the internal bookkeeping fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let context_hex = hex::encode(&*self.context_memory());
        f.debug_struct("SharedMemory")
            .field("current_len", &self.len())
            .field("context_memory", &context_hex)
            .finish_non_exhaustive()
    }
}
impl Default for SharedMemory {
    /// Equivalent to [`SharedMemory::new`]: a fresh buffer with the default
    /// preallocated capacity.
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}
// Trait adapter: every method delegates to the inherent `SharedMemory`
// method of the same (or closely related) name.
impl MemoryTr for SharedMemory {
    fn set_data(&mut self, memory_offset: usize, data_offset: usize, len: usize, data: &[u8]) {
        self.set_data(memory_offset, data_offset, len, data);
    }
    fn set(&mut self, memory_offset: usize, data: &[u8]) {
        self.set(memory_offset, data);
    }
    fn size(&self) -> usize {
        // Context-local length, not the full shared-buffer length.
        self.len()
    }
    fn copy(&mut self, destination: usize, source: usize, len: usize) {
        self.copy(destination, source, len);
    }
    fn slice(&self, range: Range<usize>) -> Ref<'_, [u8]> {
        self.slice_range(range)
    }
    fn local_memory_offset(&self) -> usize {
        // Global buffer offset where this context's memory starts.
        self.my_checkpoint
    }
    fn set_data_from_global(
        &mut self,
        memory_offset: usize,
        data_offset: usize,
        len: usize,
        data_range: Range<usize>,
    ) {
        self.global_to_local_set_data(memory_offset, data_offset, len, data_range);
    }
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    fn global_slice(&self, range: Range<usize>) -> Ref<'_, [u8]> {
        self.global_slice_range(range)
    }
    fn resize(&mut self, new_size: usize) -> bool {
        self.resize(new_size);
        // The inherent `resize` delegates to `Vec::resize`, which does not
        // fail, so this implementation always reports success.
        true
    }
    #[cfg(feature = "memory_limit")]
    #[inline]
    fn limit_reached(&self, offset: usize, len: usize) -> bool {
        // Saturating adds avoid usize overflow; the would-be *global* end of
        // the access is compared against the configured byte limit.
        self.my_checkpoint
            .saturating_add(offset)
            .saturating_add(len) as u64
            > self.memory_limit
    }
}
impl SharedMemory {
    /// Creates a new shared memory backed by a fresh buffer with a 4 KiB
    /// preallocated capacity.
    #[inline]
    pub fn new() -> Self {
        Self::with_capacity(4 * 1024)
    }
    /// Creates an unusable placeholder with no backing buffer.
    ///
    /// Any memory operation on such an instance trips the debug assertion in
    /// [`Self::buffer`]; callers must replace it before use.
    #[inline]
    pub fn invalid() -> Self {
        Self {
            buffer: None,
            my_checkpoint: 0,
            child_checkpoint: None,
            #[cfg(feature = "memory_limit")]
            memory_limit: 0,
        }
    }
    /// Creates a new memory instance that shares the given buffer, rooted at
    /// checkpoint 0.
    pub fn new_with_buffer(buffer: Rc<RefCell<Vec<u8>>>) -> Self {
        Self {
            buffer: Some(buffer),
            my_checkpoint: 0,
            child_checkpoint: None,
            #[cfg(feature = "memory_limit")]
            memory_limit: u64::MAX,
        }
    }
    /// Creates a new memory instance with a buffer of the given capacity and
    /// no effective memory limit.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buffer: Some(Rc::new(RefCell::new(Vec::with_capacity(capacity)))),
            my_checkpoint: 0,
            child_checkpoint: None,
            #[cfg(feature = "memory_limit")]
            memory_limit: u64::MAX,
        }
    }
    /// Creates a new memory instance with a hard byte limit on buffer growth.
    #[cfg(feature = "memory_limit")]
    #[inline]
    pub fn new_with_memory_limit(memory_limit: u64) -> Self {
        Self {
            memory_limit,
            ..Self::new()
        }
    }
    /// Sets the memory limit; a no-op unless the `memory_limit` feature is
    /// enabled.
    #[inline]
    pub fn set_memory_limit(&mut self, limit: u64) {
        #[cfg(feature = "memory_limit")]
        {
            self.memory_limit = limit;
        }
        // Silences the unused-variable warning when the feature is disabled
        // (harmless when enabled, since `u64` is `Copy`).
        let _ = limit;
    }
    /// Returns the shared buffer.
    ///
    /// Debug-asserts that the buffer is present; [`Self::invalid`] is the
    /// only constructor that leaves it `None`.
    #[inline]
    fn buffer(&self) -> &Rc<RefCell<Vec<u8>>> {
        debug_assert!(self.buffer.is_some(), "cannot use SharedMemory::empty");
        // SAFETY: asserted above; only `invalid()` constructs `buffer: None`,
        // and such instances must never be used for memory operations.
        unsafe { self.buffer.as_ref().unwrap_unchecked() }
    }
    /// Immutably borrows the whole shared buffer.
    #[inline]
    fn buffer_ref(&self) -> Ref<'_, Vec<u8>> {
        self.buffer().dbg_borrow()
    }
    /// Mutably borrows the whole shared buffer.
    #[inline]
    fn buffer_ref_mut(&self) -> RefMut<'_, Vec<u8>> {
        self.buffer().dbg_borrow_mut()
    }
    /// Returns `buffer[range.start + base .. range.end + base]`.
    ///
    /// `base` is 0 for global access and `my_checkpoint` for context-local
    /// access; an out-of-bounds range is treated as unreachable.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    fn slice_range_with_base(&self, range: Range<usize>, base: usize) -> Ref<'_, [u8]> {
        let buffer = self.buffer_ref();
        Ref::map(buffer, |b| {
            let range = range.start + base..range.end + base;
            match b.get(range.clone()) {
                Some(slice) => slice,
                None => debug_unreachable!("slice OOB: {range:?}; len: {}", self.len()),
            }
        })
    }
    /// Creates a child context whose memory starts at the current end of the
    /// buffer.
    ///
    /// # Panics
    ///
    /// Panics if a child context is already active (not yet freed).
    #[inline]
    pub fn new_child_context(&mut self) -> SharedMemory {
        if self.child_checkpoint.is_some() {
            panic!("new_child_context was already called without freeing child context");
        }
        let new_checkpoint = self.full_len();
        self.child_checkpoint = Some(new_checkpoint);
        SharedMemory {
            buffer: Some(self.buffer().clone()),
            my_checkpoint: new_checkpoint,
            child_checkpoint: None,
            #[cfg(feature = "memory_limit")]
            memory_limit: self.memory_limit,
        }
    }
    /// Frees the active child context, truncating the buffer back to the
    /// length recorded when the child was created. No-op without a child.
    #[inline]
    pub fn free_child_context(&mut self) {
        let Some(child_checkpoint) = self.child_checkpoint.take() else {
            return;
        };
        // SAFETY: `child_checkpoint` was a previous buffer length, and a
        // child only ever resizes relative to its own (>=) checkpoint, so
        // this shrinks the Vec; `u8` needs no drop glue.
        unsafe {
            self.buffer_ref_mut().set_len(child_checkpoint);
        }
    }
    /// Returns the length of this context's memory (global length minus this
    /// context's checkpoint).
    #[inline]
    pub fn len(&self) -> usize {
        self.full_len() - self.my_checkpoint
    }
    /// Returns the full length of the shared buffer across all contexts.
    fn full_len(&self) -> usize {
        self.buffer_ref().len()
    }
    /// Returns `true` if this context's memory is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Resizes this context's memory to `new_size` bytes, zero-filling any
    /// newly added bytes.
    #[inline]
    pub fn resize(&mut self, new_size: usize) {
        self.buffer()
            .dbg_borrow_mut()
            .resize(self.my_checkpoint + new_size, 0);
    }
    /// Returns `size` bytes of this context's memory starting at `offset`.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn slice_len(&self, offset: usize, size: usize) -> Ref<'_, [u8]> {
        self.slice_range(offset..offset + size)
    }
    /// Returns the given context-local byte range of memory.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn slice_range(&self, range: Range<usize>) -> Ref<'_, [u8]> {
        self.slice_range_with_base(range, self.my_checkpoint)
    }
    /// Returns the given *global* byte range of the shared buffer.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn global_slice_range(&self, range: Range<usize>) -> Ref<'_, [u8]> {
        self.slice_range_with_base(range, 0)
    }
    /// Returns a mutable slice of `size` bytes of this context's memory at
    /// `offset`; an out-of-bounds range is treated as unreachable.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn slice_mut(&mut self, offset: usize, size: usize) -> RefMut<'_, [u8]> {
        let buffer = self.buffer_ref_mut();
        RefMut::map(buffer, |b| {
            match b.get_mut(self.my_checkpoint + offset..self.my_checkpoint + offset + size) {
                Some(slice) => slice,
                None => debug_unreachable!("slice OOB: {offset}..{}", offset + size),
            }
        })
    }
    /// Returns the byte at context-local `offset`.
    #[inline]
    pub fn get_byte(&self, offset: usize) -> u8 {
        self.slice_len(offset, 1)[0]
    }
    /// Returns the 32-byte word at context-local `offset`.
    #[inline]
    pub fn get_word(&self, offset: usize) -> B256 {
        (*self.slice_len(offset, 32)).try_into().unwrap()
    }
    /// Returns the 32-byte word at `offset`, interpreted as a big-endian U256.
    #[inline]
    pub fn get_u256(&self, offset: usize) -> U256 {
        self.get_word(offset).into()
    }
    /// Writes a single byte at context-local `offset`.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn set_byte(&mut self, offset: usize, byte: u8) {
        self.set(offset, &[byte]);
    }
    /// Writes a 32-byte word at context-local `offset`.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn set_word(&mut self, offset: usize, value: &B256) {
        self.set(offset, &value[..]);
    }
    /// Writes `value` as 32 big-endian bytes at context-local `offset`.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn set_u256(&mut self, offset: usize, value: U256) {
        self.set(offset, &value.to_be_bytes::<32>());
    }
    /// Copies `value` into this context's memory at `offset`.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn set(&mut self, offset: usize, value: &[u8]) {
        if !value.is_empty() {
            self.slice_mut(offset, value.len()).copy_from_slice(value);
        }
    }
    /// Copies `len` bytes from `data[data_offset..]` into this context's
    /// memory at `memory_offset`, zero-filling bytes past the end of `data`.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn set_data(&mut self, memory_offset: usize, data_offset: usize, len: usize, data: &[u8]) {
        let mut dst = self.context_memory_mut();
        // SAFETY: the free function `set_data` requires
        // `memory_offset + len <= dst.len()`; callers are expected to have
        // resized memory first — TODO(review) confirm all call sites.
        unsafe { set_data(dst.as_mut(), data, memory_offset, data_offset, len) };
    }
    /// Like [`Self::set_data`], but the source is a *global* byte range of
    /// the shared buffer that must lie before this context's checkpoint
    /// (i.e. inside an ancestor context's memory).
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn global_to_local_set_data(
        &mut self,
        memory_offset: usize,
        data_offset: usize,
        len: usize,
        data_range: Range<usize>,
    ) {
        let mut buffer = self.buffer_ref_mut();
        // Split so `src` covers everything before this context and `dst` is
        // this context's memory, allowing both to be borrowed at once.
        let (src, dst) = buffer.split_at_mut(self.my_checkpoint);
        let src = if data_range.is_empty() {
            &mut []
        } else {
            // Panics if `data_range` escapes the ancestor region.
            src.get_mut(data_range).unwrap()
        };
        // SAFETY: as with `set_data`, `memory_offset + len` is assumed to be
        // within this context's memory — TODO(review) confirm call sites.
        unsafe { set_data(dst, src, memory_offset, data_offset, len) };
    }
    /// Copies `len` bytes within this context's memory from `src` to `dst`;
    /// `copy_within` handles overlapping ranges and panics when out of
    /// bounds.
    #[inline]
    #[cfg_attr(debug_assertions, track_caller)]
    pub fn copy(&mut self, dst: usize, src: usize, len: usize) {
        self.context_memory_mut().copy_within(src..src + len, dst);
    }
    /// Returns this context's entire memory as an immutable slice.
    #[inline]
    pub fn context_memory(&self) -> Ref<'_, [u8]> {
        let buffer = self.buffer_ref();
        Ref::map(buffer, |b| match b.get(self.my_checkpoint..) {
            Some(slice) => slice,
            None => debug_unreachable!("Context memory should be always valid"),
        })
    }
    /// Returns this context's entire memory as a mutable slice.
    #[inline]
    pub fn context_memory_mut(&mut self) -> RefMut<'_, [u8]> {
        let buffer = self.buffer_ref_mut();
        RefMut::map(buffer, |b| match b.get_mut(self.my_checkpoint..) {
            Some(slice) => slice,
            None => debug_unreachable!("Context memory should be always valid"),
        })
    }
}
/// Copies up to `len` bytes from `src[src_offset..]` into
/// `dst[dst_offset..dst_offset + len]`, zero-filling any tail not covered by
/// `src`. Source bounds are clamped internally; only `dst` bounds are the
/// caller's responsibility.
///
/// # Safety
///
/// `dst_offset + len` must not exceed `dst.len()`: the zero-fill fast path
/// checks this via `unwrap`, but the copy/fill paths below use unchecked
/// slicing.
unsafe fn set_data(dst: &mut [u8], src: &[u8], dst_offset: usize, src_offset: usize, len: usize) {
    if len == 0 {
        return;
    }
    // Source entirely out of range: the whole destination span becomes zeros.
    if src_offset >= src.len() {
        dst.get_mut(dst_offset..dst_offset + len).unwrap().fill(0);
        return;
    }
    // Clamp the copy to the bytes actually available in `src`.
    let src_end = min(src_offset + len, src.len());
    let src_len = src_end - src_offset;
    debug_assert!(src_offset < src.len() && src_end <= src.len());
    // SAFETY: bounds established by the early return and `min` above.
    let data = unsafe { src.get_unchecked(src_offset..src_end) };
    // SAFETY: caller guarantees `dst_offset + len <= dst.len()`, and
    // `src_len <= len`.
    unsafe {
        dst.get_unchecked_mut(dst_offset..dst_offset + src_len)
            .copy_from_slice(data)
    };
    // SAFETY: same caller guarantee; zero the remainder past the copied bytes.
    unsafe {
        dst.get_unchecked_mut(dst_offset + src_len..dst_offset + len)
            .fill(0)
    };
}
/// Returns the number of 32-byte words needed to hold `len` bytes, i.e.
/// `len` divided by 32 and rounded up.
#[inline]
pub const fn num_words(len: usize) -> usize {
    // Equivalent to `len.div_ceil(32)`, written so the intermediate
    // `len - 1` cannot overflow even at `usize::MAX`.
    if len == 0 {
        0
    } else {
        (len - 1) / 32 + 1
    }
}
/// Ensures `memory` covers `offset + len` bytes, charging expansion gas when
/// growth is required.
///
/// Returns `MemoryLimitOOG` when the configured hard limit is exceeded and
/// `MemoryOOG` when the gas charge fails (via the cold path).
#[inline]
pub fn resize_memory<Memory: MemoryTr>(
    gas: &mut crate::Gas,
    memory: &mut Memory,
    gas_table: &GasParams,
    offset: usize,
    len: usize,
) -> Result<(), InstructionResult> {
    // Enforce the hard memory cap before any gas accounting.
    #[cfg(feature = "memory_limit")]
    if memory.limit_reached(offset, len) {
        return Err(InstructionResult::MemoryLimitOOG);
    }
    let required_words = num_words(offset.saturating_add(len));
    // Fast path: current memory already spans enough words; nothing to do.
    if required_words <= gas.memory().words_num {
        return Ok(());
    }
    resize_memory_cold(gas, memory, gas_table, required_words)
}
/// Slow path of [`resize_memory`]: charges gas for the expansion and grows
/// `memory` to `new_num_words * 32` bytes. Kept `#[cold]`/`#[inline(never)]`
/// so the common no-resize path stays small.
#[cold]
#[inline(never)]
fn resize_memory_cold<Memory: MemoryTr>(
    gas: &mut crate::Gas,
    memory: &mut Memory,
    gas_table: &GasParams,
    new_num_words: usize,
) -> Result<(), InstructionResult> {
    let cost = gas_table.memory_cost(new_num_words);
    // SAFETY(review): assumes `set_words_num` cannot fail here — the caller
    // only reaches this path with `new_num_words > words_num`. Confirm this
    // invariant against the `Gas` API's documentation.
    let cost = unsafe {
        gas.memory_mut()
            .set_words_num(new_num_words, cost)
            .unwrap_unchecked()
    };
    if !gas.record_regular_cost(cost) {
        return Err(InstructionResult::MemoryOOG);
    }
    memory.resize(new_num_words * 32);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `num_words` rounds a byte length up to whole 32-byte words, including
    /// near-overflow inputs.
    #[test]
    fn test_num_words() {
        assert_eq!(num_words(0), 0);
        assert_eq!(num_words(1), 1);
        assert_eq!(num_words(31), 1);
        assert_eq!(num_words(32), 1);
        assert_eq!(num_words(33), 2);
        assert_eq!(num_words(63), 2);
        assert_eq!(num_words(64), 2);
        assert_eq!(num_words(65), 3);
        assert_eq!(num_words(usize::MAX - 31), usize::MAX / 32);
        assert_eq!(num_words(usize::MAX - 30), (usize::MAX / 32) + 1);
        assert_eq!(num_words(usize::MAX), (usize::MAX / 32) + 1);
    }

    /// Child contexts inherit the buffer, start at the parent's end, and
    /// freeing a child truncates the buffer back to its checkpoint.
    ///
    /// The raw `set_len` calls simulate growth without zero-initialization;
    /// the tests only inspect lengths, never the byte contents.
    #[test]
    fn new_free_child_context() {
        let mut sm1 = SharedMemory::new();
        assert_eq!(sm1.buffer_ref().len(), 0);
        assert_eq!(sm1.my_checkpoint, 0);
        unsafe { sm1.buffer_ref_mut().set_len(32) };
        assert_eq!(sm1.len(), 32);
        let mut sm2 = sm1.new_child_context();
        assert_eq!(sm2.buffer_ref().len(), 32);
        assert_eq!(sm2.my_checkpoint, 32);
        assert_eq!(sm2.len(), 0);
        unsafe { sm2.buffer_ref_mut().set_len(96) };
        assert_eq!(sm2.len(), 64);
        let mut sm3 = sm2.new_child_context();
        assert_eq!(sm3.buffer_ref().len(), 96);
        assert_eq!(sm3.my_checkpoint, 96);
        assert_eq!(sm3.len(), 0);
        unsafe { sm3.buffer_ref_mut().set_len(128) };
        let sm4 = sm3.new_child_context();
        assert_eq!(sm4.buffer_ref().len(), 128);
        assert_eq!(sm4.my_checkpoint, 128);
        assert_eq!(sm4.len(), 0);
        drop(sm4);
        // Unwind: each free restores the buffer to the parent's recorded
        // checkpoint while the parent's own view is unaffected.
        sm3.free_child_context();
        assert_eq!(sm3.buffer_ref().len(), 128);
        assert_eq!(sm3.my_checkpoint, 96);
        assert_eq!(sm3.len(), 32);
        sm2.free_child_context();
        assert_eq!(sm2.buffer_ref().len(), 96);
        assert_eq!(sm2.my_checkpoint, 32);
        assert_eq!(sm2.len(), 64);
        sm1.free_child_context();
        assert_eq!(sm1.buffer_ref().len(), 32);
        assert_eq!(sm1.my_checkpoint, 0);
        assert_eq!(sm1.len(), 32);
    }

    /// `resize` is context-local: it grows/zero-fills only from the caller's
    /// checkpoint, and freeing the child restores the parent's length.
    #[test]
    fn resize() {
        let mut sm1 = SharedMemory::new();
        sm1.resize(32);
        assert_eq!(sm1.buffer_ref().len(), 32);
        assert_eq!(sm1.len(), 32);
        assert_eq!(sm1.buffer_ref().get(0..32), Some(&[0_u8; 32] as &[u8]));
        let mut sm2 = sm1.new_child_context();
        sm2.resize(96);
        // Child's 96 bytes start at its checkpoint (32), so global len is 128.
        assert_eq!(sm2.buffer_ref().len(), 128);
        assert_eq!(sm2.len(), 96);
        assert_eq!(sm2.buffer_ref().get(32..128), Some(&[0_u8; 96] as &[u8]));
        sm1.free_child_context();
        assert_eq!(sm1.buffer_ref().len(), 32);
        assert_eq!(sm1.len(), 32);
        assert_eq!(sm1.buffer_ref().get(0..32), Some(&[0_u8; 32] as &[u8]));
    }
}