use super::dependencies::{MemSlice, MemoryLike};
use super::errors::{HostError, VMLogicError};
use super::gas_counter::GasCounter;
use core::mem::size_of;
use std::borrow::Cow;
use std::collections::hash_map::Entry;
use unc_parameters::vm::LimitConfig;
use unc_parameters::ExtCosts::*;
/// Result alias used throughout this module: every fallible operation here fails with
/// a [`VMLogicError`].
type Result<T> = std::result::Result<T, VMLogicError>;
/// Thin wrapper around a mutable [`MemoryLike`] backend.
///
/// The methods on this type charge gas (through a `GasCounter`) before delegating to the
/// backend, which is reached as `self.0`.
pub(super) struct Memory<'a>(&'a mut dyn MemoryLike);
// Expands to a `pub(super)` method named `$name` that reads `size_of::<$int>()` bytes at
// `offset` through the gas-metered `get_into` and decodes them as a little-endian `$int`.
// Any error from `get_into` (gas or memory access) is returned unchanged.
macro_rules! memory_get {
    ($int:ty, $name:ident) => {
        pub(super) fn $name(
            &mut self,
            gas_counter: &mut GasCounter,
            offset: u64,
        ) -> Result<$int> {
            let mut bytes = [0u8; size_of::<$int>()];
            self.get_into(gas_counter, offset, &mut bytes)
                .map(|()| <$int>::from_le_bytes(bytes))
        }
    };
}
// Expands to a `pub(super)` method named `$name` that encodes `value` as little-endian
// bytes and writes them at `offset` through the gas-metered `set`, returning whatever
// `set` returns.
macro_rules! memory_set {
    ($int:ty, $name:ident) => {
        pub(super) fn $name(
            &mut self,
            gas_counter: &mut GasCounter,
            offset: u64,
            value: $int,
        ) -> Result<()> {
            let encoded = value.to_le_bytes();
            self.set(gas_counter, offset, &encoded)
        }
    };
}
impl<'a> Memory<'a> {
pub(super) fn new(mem: &'a mut dyn MemoryLike) -> Self {
Self(mem)
}
pub(super) fn view<'s>(
&'s self,
gas_counter: &mut GasCounter,
slice: MemSlice,
) -> Result<Cow<'s, [u8]>> {
gas_counter.pay_base(read_memory_base)?;
gas_counter.pay_per(read_memory_byte, slice.len)?;
self.0.view_memory(slice).map_err(|_| HostError::MemoryAccessViolation.into())
}
pub(super) fn view_for_free(&self, slice: MemSlice) -> Result<Cow<[u8]>> {
self.0.view_memory(slice).map_err(|_| HostError::MemoryAccessViolation.into())
}
fn get_into(&self, gas_counter: &mut GasCounter, offset: u64, buf: &mut [u8]) -> Result<()> {
gas_counter.pay_base(read_memory_base)?;
let len = u64::try_from(buf.len()).map_err(|_| HostError::MemoryAccessViolation)?;
gas_counter.pay_per(read_memory_byte, len)?;
self.0.read_memory(offset, buf).map_err(|_| HostError::MemoryAccessViolation.into())
}
pub(super) fn set(
&mut self,
gas_counter: &mut GasCounter,
offset: u64,
buf: &[u8],
) -> Result<()> {
gas_counter.pay_base(write_memory_base)?;
gas_counter.pay_per(write_memory_byte, buf.len() as _)?;
self.0.write_memory(offset, buf).map_err(|_| HostError::MemoryAccessViolation.into())
}
#[cfg(test)]
pub(super) fn set_for_free(&mut self, offset: u64, buf: &[u8]) -> Result<()> {
self.0.write_memory(offset, buf).map_err(|_| HostError::MemoryAccessViolation.into())
}
memory_get!(u128, get_u128);
memory_get!(u32, get_u32);
memory_get!(u16, get_u16);
memory_get!(u8, get_u8);
memory_set!(u128, set_u128);
}
/// Host-side registers: byte buffers keyed by a `u64` register id, together with a
/// running tally of the memory they account for.
#[derive(Clone, Default)]
pub(super) struct Registers {
    /// Running total of `len + size_of::<u64>()` over all stored registers.
    /// It is checked against `LimitConfig::registers_memory_limit`.
    total_memory_usage: u64,
    /// Register contents, keyed by register id.
    registers: std::collections::HashMap<u64, Box<[u8]>>,
}
/// Gas-metered access to [`Registers`].
impl Registers {
    /// Returns the contents of register `register_id`.
    ///
    /// Charges `read_register_base` plus `read_register_byte` per byte, but only when the
    /// register exists.
    ///
    /// # Errors
    /// Returns `InvalidRegisterId` if no such register is set. Otherwise propagates any
    /// gas-charging error.
    pub(super) fn get<'s>(
        &'s self,
        gas_counter: &mut GasCounter,
        register_id: u64,
    ) -> Result<&'s [u8]> {
        if let Some(data) = self.registers.get(&register_id) {
            gas_counter.pay_base(read_register_base)?;
            let len = u64::try_from(data.len()).map_err(|_| HostError::MemoryAccessViolation)?;
            gas_counter.pay_per(read_register_byte, len)?;
            Ok(&data[..])
        } else {
            Err(HostError::InvalidRegisterId { register_id }.into())
        }
    }

    /// Test-only read that bypasses gas accounting. Returns `None` for unknown ids.
    #[cfg(test)]
    pub(super) fn get_for_free<'s>(&'s self, register_id: u64) -> Option<&'s [u8]> {
        self.registers.get(&register_id).map(|data| &data[..])
    }

    /// Returns the length of register `register_id` without charging gas, or `None` if
    /// it is not set.
    pub(super) fn get_len(&self, register_id: u64) -> Option<u64> {
        self.registers.get(&register_id).map(|data| data.len() as u64)
    }

    /// Stores `data` in register `register_id`, replacing any previous contents.
    ///
    /// Charges `write_register_base` plus `write_register_byte` per byte *before* checking
    /// the limits, so a rejected write still consumes that gas.
    ///
    /// # Errors
    /// Propagates any gas-charging error. Returns `MemoryAccessViolation` if any limit in
    /// `config` would be exceeded.
    pub(super) fn set<T>(
        &mut self,
        gas_counter: &mut GasCounter,
        config: &LimitConfig,
        register_id: u64,
        data: T,
    ) -> Result<()>
    where
        T: Into<Box<[u8]>> + AsRef<[u8]>,
    {
        let data_len =
            u64::try_from(data.as_ref().len()).map_err(|_| HostError::MemoryAccessViolation)?;
        gas_counter.pay_base(write_register_base)?;
        gas_counter.pay_per(write_register_byte, data_len)?;
        let entry = self.check_set_register(config, register_id, data_len)?;
        let data = data.into();
        match entry {
            Entry::Occupied(mut entry) => {
                entry.insert(data);
            }
            Entry::Vacant(entry) => {
                entry.insert(data);
            }
        }
        Ok(())
    }

    /// Validates a write of `data_len` bytes to `register_id` against `config`.
    ///
    /// On success, `total_memory_usage` is updated to its post-write value and the map
    /// entry to write into is returned.
    ///
    /// # Errors
    /// Returns `MemoryAccessViolation` (leaving all state untouched) in any of these cases:
    /// - `data_len` exceeds `max_register_size`;
    /// - the register count has reached `max_number_registers`;
    /// - the new total usage would exceed `registers_memory_limit` or overflow.
    fn check_set_register<'a>(
        &'a mut self,
        config: &LimitConfig,
        register_id: u64,
        data_len: u64,
    ) -> Result<Entry<'a, u64, Box<[u8]>>> {
        if data_len > config.max_register_size {
            return Err(HostError::MemoryAccessViolation.into());
        }
        // The count is checked before looking the id up. Once the map holds
        // `max_number_registers` entries, even overwriting an existing register is rejected.
        if self.registers.len() as u64 >= config.max_number_registers {
            return Err(HostError::MemoryAccessViolation.into());
        }
        let entry = self.registers.entry(register_id);
        // Each register is accounted as its data length plus `size_of::<u64>()` bytes.
        let calc_usage = |len: u64| len + size_of::<u64>() as u64;
        let old_mem_usage = match &entry {
            Entry::Occupied(entry) => calc_usage(entry.get().len() as u64),
            Entry::Vacant(_) => 0,
        };
        let usage = self
            .total_memory_usage
            .checked_sub(old_mem_usage)
            .expect("total_memory_usage always includes the usage of every stored register")
            .checked_add(calc_usage(data_len))
            .ok_or(HostError::MemoryAccessViolation)?;
        if usage > config.registers_memory_limit {
            return Err(HostError::MemoryAccessViolation.into());
        }
        self.total_memory_usage = usage;
        Ok(entry)
    }
}
/// Reads `len` bytes at `ptr` from `memory`. When `len` is the sentinel `u64::MAX`, it
/// instead returns the whole register whose id is `ptr`.
///
/// Gas is charged by whichever accessor handles the request (`Registers::get` or
/// `Memory::view`). Register contents are always returned borrowed.
pub(super) fn get_memory_or_register<'a, 'b>(
    gas_counter: &mut GasCounter,
    memory: &'b Memory<'a>,
    registers: &'b Registers,
    ptr: u64,
    len: u64,
) -> Result<Cow<'b, [u8]>> {
    match len {
        // Sentinel: `ptr` names a register rather than a memory address.
        u64::MAX => {
            let contents = registers.get(gas_counter, ptr)?;
            Ok(Cow::Borrowed(contents))
        }
        _ => memory.view(gas_counter, MemSlice { ptr, len }),
    }
}
#[cfg(test)]
mod tests {
    use super::HostError;
    use super::Registers;
    use crate::logic::gas_counter::GasCounter;
    use crate::logic::LimitConfig;
    use crate::tests::test_vm_config;
    use unc_parameters::ExtCostsConfig;

    /// Bundles a gas counter, a limit config and a fresh `Registers` for the tests below.
    struct RegistersTestContext {
        gas: GasCounter,
        cfg: LimitConfig,
        regs: Registers,
    }

    impl RegistersTestContext {
        /// Builds a context from `ExtCostsConfig::test()` costs and the test VM limit
        /// config. Individual tests then tighten single limits on `cfg`.
        fn new() -> Self {
            let costs = ExtCostsConfig::test();
            Self {
                gas: GasCounter::new(costs, u64::MAX, 0, u64::MAX, false),
                cfg: test_vm_config().limit_config,
                regs: Default::default(),
            }
        }

        /// Sets the register and then reads it back. The read-back goes through the metered
        /// `get`, so it also charges read gas.
        #[track_caller]
        fn assert_set_success(&mut self, register_id: u64, value: &str) {
            self.regs.set(&mut self.gas, &self.cfg, register_id, value.as_bytes()).unwrap();
            self.assert_read(register_id, Some(value));
        }

        /// Expects `set` to fail with `MemoryAccessViolation`. Write gas is still charged,
        /// because `set` pays before checking limits.
        #[track_caller]
        fn assert_set_failure(&mut self, register_id: u64, value: &str) {
            let want = Err(HostError::MemoryAccessViolation.into());
            let got = self.regs.set(&mut self.gas, &self.cfg, register_id, value.as_bytes());
            assert_eq!(want, got);
        }

        /// Checks register contents via `get` (metered when the register exists) and
        /// `get_len` (free). `None` means the register must be absent.
        #[track_caller]
        fn assert_read(&mut self, register_id: u64, value: Option<&str>) {
            if let Some(value) = value {
                assert_eq!(Ok(value.as_bytes()), self.regs.get(&mut self.gas, register_id));
                assert_eq!(Some(value.len() as u64), self.regs.get_len(register_id));
            } else {
                let err = HostError::InvalidRegisterId { register_id }.into();
                assert_eq!(Err(err), self.regs.get(&mut self.gas, register_id));
                assert_eq!(None, self.regs.get_len(register_id));
            }
        }

        /// Asserts burnt and used gas both equal `gas`. The expected totals depend on
        /// `ExtCostsConfig::test()` and on the exact sequence of metered calls made by each
        /// test.
        #[track_caller]
        fn assert_used_gas(&self, gas: u64) {
            assert_eq!((gas, gas), (self.gas.burnt_gas(), self.gas.used_gas()));
        }
    }

    /// Setting one register does not create others.
    #[test]
    fn registers_set() {
        let mut ctx = RegistersTestContext::new();
        ctx.assert_read(42, None);
        ctx.assert_read(24, None);
        ctx.assert_set_success(42, "foo");
        ctx.assert_read(24, None);
        ctx.assert_used_gas(5394388050);
    }

    /// With two registers already stored, every further `set` fails, including overwrites
    /// of existing ids, because the count check runs before the id lookup.
    #[test]
    fn registers_max_number_limit() {
        let mut ctx = RegistersTestContext::new();
        ctx.cfg.max_number_registers = 2;
        ctx.assert_set_success(42, "foo");
        ctx.assert_set_success(24, "bar");
        ctx.assert_set_failure(12, "baz");
        ctx.assert_set_failure(42, "O_o");
        ctx.assert_set_failure(24, "O_o");
        ctx.assert_used_gas(19419557634);
    }

    /// Values longer than `max_register_size` are rejected.
    #[test]
    fn registers_register_size_limit() {
        let mut ctx = RegistersTestContext::new();
        ctx.cfg.max_register_size = 3;
        ctx.assert_set_success(42, "foo");
        ctx.assert_set_failure(24, "quux");
        ctx.assert_used_gas(8275116792);
    }

    /// Usage is `len + 8` per register, so a limit of 11 fits a 3-byte value (3 + 8)
    /// but not a 4-byte one. Overwrites release the old entry's usage first.
    #[test]
    fn registers_usage_limit() {
        let mut ctx = RegistersTestContext::new();
        ctx.cfg.registers_memory_limit = 11;
        ctx.assert_set_success(42, "foo");
        ctx.assert_set_success(42, "bar");
        ctx.assert_set_success(42, "");
        ctx.assert_set_success(42, "baz");
        ctx.assert_set_failure(42, "quux");
        ctx.assert_used_gas(24446580564);
    }
}