use std::ops::Deref;
use crypto_bigint::Encoding;
use crypto_bigint::U256;
use ff::PrimeField;
use wasmer::AsStoreRef;
use wasmer::Memory;
use wasmer::MemoryView;
use super::calculator::from_vec_u32;
use super::calculator::u256_to_vec_u32;
use crate::error::Result;
/// Wrapper around a wasm [`Memory`] that reads and writes field elements in the
/// guest module's in-memory layout (see the methods on `SafeMemory`).
#[derive(Clone, Debug)]
pub struct SafeMemory {
    /// The wrapped wasm linear memory.
    pub memory: Memory,
    /// Prime modulus of the field; `short_min` is derived from it.
    pub prime: U256,
    /// `2^31` (`0x8000_0000`): exclusive upper bound for the short encoding.
    short_max: U256,
    /// `-2^31 mod prime`, i.e. the field image of the short encoding's lower bound.
    short_min: U256,
    /// Size of a field element's big-integer payload in 32-bit words
    /// (`alloc_fr` reserves `n32 * 4` payload bytes).
    n32: usize,
}
impl Deref for SafeMemory {
type Target = Memory;
fn deref(&self) -> &Self::Target {
&self.memory
}
}
impl SafeMemory {
    /// Type-word bit marking a long (multi-limb) field element.
    const LONG_FLAG: u32 = 0x8000_0000;
    /// Type-word bit marking a long element whose payload is in Montgomery form
    /// (the convention used by circom's wasm runtime).
    const MONTGOMERY_FLAG: u32 = 0x4000_0000;

    /// Wraps `memory` so field elements can be read from and written to it.
    ///
    /// * `n32` - size of a field element's big-integer payload, in 32-bit words.
    /// * `prime` - the field modulus; used to recognise small negative values.
    pub fn new(memory: Memory, n32: usize, prime: U256) -> Self {
        // The short encoding holds an i32, i.e. integers in (-2^31, 2^31). A negative
        // integer -k lives in the field as `prime - k`, so the lower bound is kept as
        // its residue `prime - 2^31`.
        let short_max = U256::from(0x8000_0000u64);
        let short_min = short_max.neg_mod(&prime);
        Self {
            memory,
            prime,
            short_max,
            short_min,
            n32,
        }
    }

    /// Returns a view of the linear memory bound to `store`.
    pub fn view<'a>(&self, store: &'a impl AsStoreRef) -> MemoryView<'a> {
        self.memory.view(store)
    }

    /// Reads the bump-allocator cursor, stored as a `u32` at address 0.
    pub fn free_pos(&self, store: &impl AsStoreRef) -> u32 {
        self.read_u32(store, 0)
    }

    /// Overwrites the bump-allocator cursor at address 0 with `ptr`.
    pub fn set_free_pos(&mut self, store: &impl AsStoreRef, ptr: u32) {
        self.write_u32(store, 0, ptr);
    }

    /// Reserves a slot for one `u32` and returns its address.
    ///
    /// 8 bytes are reserved rather than 4 (presumably to keep slots 8-byte aligned).
    pub fn alloc_u32(&mut self, store: &impl AsStoreRef) -> u32 {
        let p = self.free_pos(store);
        self.set_free_pos(store, p + 8);
        p
    }

    /// Writes `num` little-endian at byte offset `ptr`.
    ///
    /// # Panics
    /// Panics if `ptr..ptr + 4` lies outside the memory.
    pub fn write_u32(&mut self, store: &impl AsStoreRef, ptr: usize, num: u32) {
        self.write_bytes(store, ptr, &num.to_le_bytes());
    }

    /// Reads a little-endian `u32` from byte offset `ptr`.
    ///
    /// # Panics
    /// Panics if `ptr..ptr + 4` lies outside the memory.
    pub fn read_u32(&self, store: &impl AsStoreRef, ptr: usize) -> u32 {
        u32::from_le_bytes(self.read_bytes(store, ptr))
    }

    /// Reserves a field-element slot (8 header bytes plus `n32` payload words)
    /// and returns its address.
    pub fn alloc_fr(&mut self, store: &impl AsStoreRef) -> u32 {
        let p = self.free_pos(store);
        self.set_free_pos(store, p + self.n32 as u32 * 4 + 8);
        p
    }

    /// Writes the field element `fr` at `ptr`.
    ///
    /// Residues representing integers in `(-2^31, 2^31)` use the compact short
    /// encoding; everything else (including unreduced values `>= prime`) uses
    /// the long normal encoding.
    pub fn write_fr(&mut self, store: &impl AsStoreRef, ptr: usize, fr: U256) -> Result<()> {
        if fr < self.short_max {
            // 0 <= fr < 2^31, so the low word is a non-negative i32.
            let value = Self::low_u32(&fr) as i32;
            self.write_short(store, ptr, value)
        } else if fr > self.short_min && fr < self.prime {
            // fr == prime - k with 0 < k < 2^31: the field image of -k.
            let k = Self::low_u32(&self.prime.wrapping_sub(&fr)) as i32;
            self.write_short(store, ptr, -k)
        } else {
            self.write_long_normal(store, ptr, fr)
        }
    }

    /// Reads the field element stored at `ptr`.
    ///
    /// Layout: word 0 is the short value (a signed i32), word 1 is the type word,
    /// and a long value's payload follows at `ptr + 8`. Short negatives map to
    /// `p - |v|`. Long payloads carrying [`Self::MONTGOMERY_FLAG`] are converted
    /// out of Montgomery form.
    ///
    /// # Panics
    /// Panics if the element lies outside the memory.
    pub fn read_fr<F: PrimeField>(&self, store: &impl AsStoreRef, ptr: usize) -> F {
        let type_word = self.read_u32(store, ptr + 4);
        if type_word & Self::LONG_FLAG != 0 {
            let num = self.read_big(store, ptr + 8);
            let value: F = from_vec_u32(u256_to_vec_u32(num));
            if type_word & Self::MONTGOMERY_FLAG != 0 {
                value * self.montgomery_r_inv::<F>()
            } else {
                value
            }
        } else {
            // Reinterpret the stored bits as i32 so negative shorts keep their sign.
            let short = self.read_u32(store, ptr) as i32;
            let magnitude = F::from(u64::from(short.unsigned_abs()));
            if short < 0 {
                -magnitude
            } else {
                magnitude
            }
        }
    }

    /// Returns `R^-1` in `F`, where `R = 2^(32 * n32)` is the Montgomery radix.
    fn montgomery_r_inv<F: PrimeField>(&self) -> F {
        let radix_bits = 32 * self.n32 as u64;
        // A power of two is always invertible modulo an odd prime, so unwrap cannot fail.
        F::from(2u64).pow_vartime([radix_bits]).invert().unwrap()
    }

    /// Returns the least-significant 32 bits of `num`.
    fn low_u32(num: &U256) -> u32 {
        let bytes = num.to_le_bytes();
        u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
    }

    /// Writes `value` in the short encoding: its i32 bits, then a zero type word.
    fn write_short(&mut self, store: &impl AsStoreRef, ptr: usize, value: i32) -> Result<()> {
        // Bit-for-bit reinterpretation of the i32 (two's complement), not a value cast.
        self.write_u32(store, ptr, value as u32);
        self.write_u32(store, ptr + 4, 0);
        Ok(())
    }

    /// Writes `fr` in the long normal encoding: zero short word, [`Self::LONG_FLAG`]
    /// type word, then the 32-byte little-endian payload.
    fn write_long_normal(&mut self, store: &impl AsStoreRef, ptr: usize, fr: U256) -> Result<()> {
        self.write_u32(store, ptr, 0);
        self.write_u32(store, ptr + 4, Self::LONG_FLAG);
        self.write_big(store, ptr + 8, fr)
    }

    /// Writes `num` as 32 little-endian bytes at `ptr`.
    ///
    /// NOTE(review): always 32 bytes (a `U256`), so this assumes `n32 == 8`.
    fn write_big(&self, store: &impl AsStoreRef, ptr: usize, num: U256) -> Result<()> {
        self.write_bytes(store, ptr, &num.to_le_bytes());
        Ok(())
    }

    /// Reads a 32-byte little-endian big integer from `ptr`.
    ///
    /// NOTE(review): always 32 bytes (a `U256`), so this assumes `n32 == 8`.
    ///
    /// # Panics
    /// Panics if `ptr..ptr + 32` lies outside the memory.
    pub fn read_big(&self, store: &impl AsStoreRef, ptr: usize) -> U256 {
        // `from_le_slice` requires exactly 32 bytes, so read a fixed-size window.
        let bytes: [u8; 32] = self.read_bytes(store, ptr);
        U256::from_le_slice(&bytes)
    }

    /// Copies `bytes` into memory at `ptr` through wasmer's bounds-checked API.
    fn write_bytes(&self, store: &impl AsStoreRef, ptr: usize, bytes: &[u8]) {
        if let Err(err) = self.view(store).write(ptr as u64, bytes) {
            panic!(
                "failed to write {} bytes at wasm address {}: {:?}",
                bytes.len(),
                ptr,
                err
            );
        }
    }

    /// Copies `N` bytes out of memory at `ptr` through wasmer's bounds-checked API.
    fn read_bytes<const N: usize>(&self, store: &impl AsStoreRef, ptr: usize) -> [u8; N] {
        let mut buf = [0u8; N];
        if let Err(err) = self.view(store).read(ptr as u64, &mut buf) {
            panic!(
                "failed to read {} bytes at wasm address {}: {:?}",
                N, ptr, err
            );
        }
        buf
    }
}