use serde::{de::DeserializeOwned, Deserialize, Serialize};
use vec_map::VecMap;
/// Memory split into a fixed 32-slot register file and a paged region.
///
/// Addresses below 32 are held in `registers`; every other address lives in `page_table`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "T: Serialize", deserialize = "T: DeserializeOwned"))]
pub struct Memory<T: Copy> {
    /// Values for addresses `0..32`.
    pub registers: Registers<T>,
    /// Values for addresses `32` and above.
    pub page_table: PagedMemory<T>,
}
impl<V: Copy + 'static> IntoIterator for Memory<V> {
type Item = (u64, V);
type IntoIter = Box<dyn Iterator<Item = Self::Item>>;
fn into_iter(self) -> Self::IntoIter {
Box::new(self.registers.into_iter().chain(self.page_table))
}
}
// Only `T: Copy` is required: both halves start with every slot empty (`None`), so no
// `T::default()` value is ever constructed. The former `T: Default` bound was unnecessary.
impl<T: Copy> Default for Memory<T> {
    /// Creates an empty memory: an all-empty register file and a paged region with no pages.
    fn default() -> Self {
        Self { registers: Registers::default(), page_table: PagedMemory::default() }
    }
}
impl<T: Copy> Memory<T> {
    /// Creates an empty memory whose paged region is built by
    /// [`PagedMemory::new_preallocated`].
    pub fn new_preallocated() -> Self {
        let page_table = PagedMemory::new_preallocated();
        Self { registers: Registers::default(), page_table }
    }

    /// Returns an [`Entry`] for in-place manipulation of the value at `addr`.
    ///
    /// Addresses `0..32` are served by the register file, all others by the paged region.
    #[inline]
    pub fn entry(&mut self, addr: u64) -> Entry<'_, T> {
        match addr {
            0..=31 => self.registers.entry(addr),
            _ => self.page_table.entry(addr),
        }
    }

    /// Stores `value` at `addr` and returns the value previously stored there, if any.
    #[inline]
    pub fn insert(&mut self, addr: u64, value: T) -> Option<T> {
        match addr {
            0..=31 => self.registers.insert(addr, value),
            _ => self.page_table.insert(addr, value),
        }
    }

    /// Returns a reference to the value stored at `addr`, or `None` if the slot is empty.
    #[inline]
    pub fn get(&self, addr: u64) -> Option<&T> {
        match addr {
            0..=31 => self.registers.get(addr),
            _ => self.page_table.get(addr),
        }
    }

    /// Empties the slot at `addr`, returning the value it held, if any.
    #[inline]
    pub fn remove(&mut self, addr: u64) -> Option<T> {
        match addr {
            0..=31 => self.registers.remove(addr),
            _ => self.page_table.remove(addr),
        }
    }

    /// Empties both the paged region and the register file.
    #[inline]
    pub fn clear(&mut self) {
        self.page_table.clear();
        self.registers.clear();
    }
}
// Only `V: Copy` is required: `Memory::new_preallocated` and `Memory::insert` are available for
// every `Copy` value type. The former `V: Default` bound was unnecessary.
impl<V: Copy> FromIterator<(u64, V)> for Memory<V> {
    /// Builds a memory by inserting each `(addr, value)` pair in order.
    ///
    /// When the same address appears more than once, the later pair replaces the earlier one.
    fn from_iter<T: IntoIterator<Item = (u64, V)>>(iter: T) -> Self {
        let mut memory = Self::new_preallocated();
        for (addr, value) in iter {
            memory.insert(addr, value);
        }
        memory
    }
}
/// A fixed table of 32 optional values, indexed directly by address.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "T: Serialize", deserialize = "T: DeserializeOwned"))]
pub struct Registers<T: Copy> {
    /// Slot `i` holds the value at address `i`, or `None` if nothing is stored there.
    pub registers: [Option<T>; 32],
}
impl<T: Copy> Default for Registers<T> {
    /// Creates a register file in which all 32 slots are empty.
    fn default() -> Self {
        let registers = [None; 32];
        Self { registers }
    }
}
impl<T: Copy> Registers<T> {
    /// Shared access to the slot backing `addr`.
    ///
    /// Panics if `addr` does not index one of the 32 slots.
    #[inline]
    fn slot(&self, addr: u64) -> &Option<T> {
        &self.registers[addr as usize]
    }

    /// Mutable access to the slot backing `addr`.
    ///
    /// Panics if `addr` does not index one of the 32 slots.
    #[inline]
    fn slot_mut(&mut self, addr: u64) -> &mut Option<T> {
        &mut self.registers[addr as usize]
    }

    /// Returns an [`Entry`] for in-place manipulation of register `addr`.
    #[inline]
    pub fn entry(&mut self, addr: u64) -> Entry<'_, T> {
        let slot = self.slot_mut(addr);
        if let Some(value) = slot {
            Entry::Occupied(OccupiedEntry { entry: value })
        } else {
            Entry::Vacant(VacantEntry { entry: slot })
        }
    }

    /// Stores `value` in register `addr`, returning the previous value if there was one.
    #[inline]
    pub fn insert(&mut self, addr: u64, value: T) -> Option<T> {
        self.slot_mut(addr).replace(value)
    }

    /// Empties register `addr`, returning the value it held, if any.
    #[inline]
    pub fn remove(&mut self, addr: u64) -> Option<T> {
        self.slot_mut(addr).take()
    }

    /// Returns a reference to the value in register `addr`, or `None` if it is empty.
    #[inline]
    pub fn get(&self, addr: u64) -> Option<&T> {
        self.slot(addr).as_ref()
    }

    /// Empties every register.
    #[inline]
    pub fn clear(&mut self) {
        self.registers = [None; 32];
    }
}
impl<V: Copy> FromIterator<(u64, V)> for Registers<V> {
    /// Builds a register file from `(addr, value)` pairs.
    ///
    /// A later pair for the same address replaces an earlier one. Panics if an address does
    /// not index one of the 32 slots.
    fn from_iter<T: IntoIterator<Item = (u64, V)>>(iter: T) -> Self {
        let mut registers = Self::default();
        iter.into_iter().for_each(|(addr, value)| {
            registers.insert(addr, value);
        });
        registers
    }
}
impl<V: Copy + 'static> IntoIterator for Registers<V> {
    type Item = (u64, V);
    type IntoIter = Box<dyn Iterator<Item = Self::Item>>;

    /// Consumes the register file and yields `(addr, value)` for each occupied slot, in
    /// ascending address order.
    fn into_iter(self) -> Self::IntoIter {
        let occupied = (0u64..)
            .zip(self.registers)
            .filter_map(|(addr, slot)| slot.map(|value| (addr, value)));
        Box::new(occupied)
    }
}
/// A page whose slots are stored in a `VecMap`.
///
/// NOTE(review): nothing in this file uses `Page` (paged memory is built from `NewPage`). The
/// `dead_code` allow presumably keeps it for external users or serialized-format compatibility.
/// Confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Page<V>(VecMap<V>);
impl<V> Default for Page<V> {
    /// Creates a page with no populated slots.
    fn default() -> Self {
        Self(VecMap::new())
    }
}
/// Base-2 logarithm of the address range the page index is sized to cover (`2^MAX_LOG_ADDR`
/// addresses).
pub(crate) const MAX_LOG_ADDR: usize = sp1_primitives::consts::MAX_JIT_LOG_ADDR;
/// Base-2 logarithm of the number of slots per page.
const LOG_PAGE_LEN: usize = 18;
/// Number of slots per page (`2^LOG_PAGE_LEN`).
const PAGE_LEN: usize = 1 << LOG_PAGE_LEN;
/// Number of page-index entries needed to cover `2^MAX_LOG_ADDR` addresses.
///
/// The `/ 8` reflects that one slot stands for 8 addresses, matching
/// `PagedMemory::NUM_IGNORED_LOWER_BITS == 3`. Keep the two in sync.
const MAX_PAGE_COUNT: usize = (1 << MAX_LOG_ADDR) / 8 / PAGE_LEN;
/// Sentinel stored in `PagedMemory::index` for a page number with no allocated page.
const NO_PAGE: u32 = u32::MAX;
/// Mask that extracts a slot's position within its page from a compressed address.
const PAGE_MASK: usize = PAGE_LEN - 1;
/// One page of optional slots.
///
/// A page built with `NewPage::new` holds exactly `PAGE_LEN` slots. The `Default` page is
/// empty (zero-length).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "V: Serialize", deserialize = "V: DeserializeOwned"))]
pub struct NewPage<V>(Vec<Option<V>>);
impl<V: Copy> NewPage<V> {
    /// Creates a fully sized page: `PAGE_LEN` slots, all empty.
    pub fn new() -> Self {
        let mut slots = Vec::with_capacity(PAGE_LEN);
        slots.resize(PAGE_LEN, None);
        Self(slots)
    }
}
impl<V: Copy> Default for NewPage<V> {
    /// Returns a zero-length page that owns no allocation.
    ///
    /// Unlike [`NewPage::new`], this page has no slots, so it must not be indexed. It serves as
    /// a cheap placeholder, e.g. for `std::mem::take`.
    fn default() -> Self {
        Self(Vec::default())
    }
}
/// Page-granular storage for a large address space; pages are allocated on demand.
///
/// In [`Memory`], this backs every address `>= 32`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "V: Serialize", deserialize = "V: DeserializeOwned"))]
pub struct PagedMemory<V: Copy> {
    /// Allocated pages, in allocation order.
    pub page_table: Vec<NewPage<V>>,
    /// For each page number, the position of its page in `page_table`, or `NO_PAGE` if that
    /// page has not been allocated.
    pub index: Vec<u32>,
}
impl<V: Copy> PagedMemory<V> {
    /// Number of low address bits discarded when mapping an address to a slot.
    ///
    /// Each slot therefore stands for `1 << NUM_IGNORED_LOWER_BITS` consecutive addresses.
    const NUM_IGNORED_LOWER_BITS: usize = 3;

    /// Creates an empty paged memory.
    ///
    /// No pages are allocated yet, but the page index already has room for all
    /// `MAX_PAGE_COUNT` pages.
    pub fn new_preallocated() -> Self {
        let index = vec![NO_PAGE; MAX_PAGE_COUNT];
        Self { page_table: Vec::new(), index }
    }

    /// Returns the value stored at `addr`, or `None` if its page or slot is empty.
    ///
    /// Panics if `addr` lies beyond the range covered by the page index.
    pub fn get(&self, addr: u64) -> Option<&V> {
        let (upper, lower) = Self::indices(addr);
        let position = self.page_position(upper)?;
        self.page_table[position].0[lower].as_ref()
    }

    /// Mutable counterpart of [`PagedMemory::get`]. It never allocates a page.
    pub fn get_mut(&mut self, addr: u64) -> Option<&mut V> {
        let (upper, lower) = Self::indices(addr);
        let position = self.page_position(upper)?;
        self.page_table[position].0[lower].as_mut()
    }

    /// Stores `value` at `addr` and returns the value previously stored there, if any.
    ///
    /// The page containing `addr` is allocated on first use.
    pub fn insert(&mut self, addr: u64, value: V) -> Option<V> {
        let (upper, lower) = Self::indices(addr);
        let position = match self.page_position(upper) {
            Some(position) => position,
            None => self.allocate_page(upper),
        };
        self.page_table[position].0[lower].replace(value)
    }

    /// Empties the slot at `addr` and returns what it held. Pages are never freed here.
    pub fn remove(&mut self, addr: u64) -> Option<V> {
        let (upper, lower) = Self::indices(addr);
        let position = self.page_position(upper)?;
        self.page_table[position].0[lower].take()
    }

    /// Returns an [`Entry`] for in-place manipulation of the slot at `addr`.
    ///
    /// If the page containing `addr` is not allocated yet, it is allocated now, even if the
    /// returned vacant entry is never filled.
    pub fn entry(&mut self, addr: u64) -> Entry<'_, V> {
        let (upper, lower) = Self::indices(addr);
        let Some(position) = self.page_position(upper) else {
            let position = self.allocate_page(upper);
            return Entry::Vacant(VacantEntry { entry: &mut self.page_table[position].0[lower] });
        };
        let slot = &mut self.page_table[position].0[lower];
        match slot {
            Some(value) => Entry::Occupied(OccupiedEntry { entry: value }),
            None => Entry::Vacant(VacantEntry { entry: slot }),
        }
    }

    /// Iterates over every address that currently holds a value, in ascending order.
    pub fn keys(&self) -> impl Iterator<Item = u64> + '_ {
        let pages = &self.page_table;
        self.index
            .iter()
            .enumerate()
            .filter(|&(_, &position)| position != NO_PAGE)
            .flat_map(move |(upper, &position)| {
                let base = upper << LOG_PAGE_LEN;
                pages[position as usize]
                    .0
                    .iter()
                    .enumerate()
                    .filter(|(_, slot)| slot.is_some())
                    .map(move |(lower, _)| Self::decompress_addr(base + lower))
            })
    }

    /// Counts the occupied slots by scanning every allocated page.
    pub fn exact_len(&self) -> usize {
        let mut total = 0;
        for &position in &self.index {
            if position != NO_PAGE {
                total += self.page_table[position as usize].0.iter().flatten().count();
            }
        }
        total
    }

    /// Cheap upper bound on the number of stored values: allocated pages times `PAGE_LEN`.
    pub fn estimate_len(&self) -> usize {
        PAGE_LEN * self.index.iter().filter(|&&position| position != NO_PAGE).count()
    }

    /// Drops every page and marks every page number as unallocated.
    ///
    /// The index keeps its full length.
    pub fn clear(&mut self) {
        self.index.fill(NO_PAGE);
        self.page_table.clear();
    }

    /// Returns the position in `page_table` of page number `upper`, or `None` if that page is
    /// unallocated.
    ///
    /// Panics if `upper` is outside the index.
    #[inline]
    fn page_position(&self, upper: usize) -> Option<usize> {
        let position = self.index[upper];
        if position == NO_PAGE {
            None
        } else {
            Some(position as usize)
        }
    }

    /// Appends a fresh, fully sized page and records it as page number `upper`.
    ///
    /// Returns the new page's position in `page_table`.
    fn allocate_page(&mut self, upper: usize) -> usize {
        let position = self.page_table.len();
        self.index[upper] = position as u32;
        self.page_table.push(NewPage::new());
        position
    }

    /// Splits `addr` into `(page number, slot within page)`.
    #[inline]
    const fn indices(addr: u64) -> (usize, usize) {
        let slot = Self::compress_addr(addr);
        let upper = slot >> LOG_PAGE_LEN;
        let lower = slot & PAGE_MASK;
        (upper, lower)
    }

    /// Maps an address to its slot number by dropping the ignored low bits.
    #[inline]
    const fn compress_addr(addr: u64) -> usize {
        (addr as usize) >> Self::NUM_IGNORED_LOWER_BITS
    }

    /// Inverse of [`PagedMemory::compress_addr`]: returns the lowest address of a slot number.
    #[inline]
    const fn decompress_addr(addr: usize) -> u64 {
        (addr << Self::NUM_IGNORED_LOWER_BITS) as u64
    }
}
impl<V: Copy> Default for PagedMemory<V> {
fn default() -> Self {
Self { page_table: Vec::new(), index: vec![NO_PAGE; MAX_PAGE_COUNT] }
}
}
/// A view into a single memory slot, returned by the `entry` methods.
///
/// The slot is either vacant or occupied.
pub enum Entry<'a, V: Copy> {
    /// The slot holds no value.
    Vacant(VacantEntry<'a, V>),
    /// The slot holds a value.
    Occupied(OccupiedEntry<'a, V>),
}
impl<'a, V: Copy> Entry<'a, V> {
    /// Returns a mutable reference to the slot's value, first storing `default` if the slot is
    /// vacant.
    pub fn or_insert(self, default: V) -> &'a mut V {
        self.or_insert_with(|| default)
    }

    /// Like [`Entry::or_insert`], but calls `default` to compute the value only when the slot
    /// is vacant.
    pub fn or_insert_with<F: FnOnce() -> V>(self, default: F) -> &'a mut V {
        match self {
            Entry::Occupied(occupied) => occupied.into_mut(),
            Entry::Vacant(vacant) => vacant.insert(default()),
        }
    }

    /// Applies `f` to the stored value if the slot is occupied. A vacant entry passes through
    /// untouched.
    ///
    /// Returns the entry so that calls can be chained.
    pub fn and_modify<F: FnOnce(&mut V)>(mut self, f: F) -> Self {
        if let Entry::Occupied(occupied) = &mut self {
            f(occupied.get_mut());
        }
        self
    }
}
impl<'a, V: Copy + Default> Entry<'a, V> {
    /// Returns a mutable reference to the slot's value, first storing `V::default()` if the
    /// slot is vacant.
    pub fn or_default(self) -> &'a mut V {
        self.or_insert_with(V::default)
    }
}
/// A vacant slot handed out by an `entry` call.
///
/// It is a mutable borrow of an `Option` that held `None` when the entry was created.
pub struct VacantEntry<'a, V: Copy> {
    entry: &'a mut Option<V>,
}

impl<'a, V: Copy> VacantEntry<'a, V> {
    /// Stores `value` in the slot and returns a mutable reference to it.
    ///
    /// The reference lives as long as the original borrow.
    pub fn insert(self, value: V) -> &'a mut V {
        // `Option::insert` writes the value and returns a reference to it in one step,
        // avoiding the `as_mut().unwrap()` panic path.
        self.entry.insert(value)
    }
}
/// An occupied slot handed out by an `entry` call.
///
/// It is a mutable borrow of the value already stored at the requested address.
pub struct OccupiedEntry<'a, V> {
    entry: &'a mut V,
}

impl<'a, V: Copy> OccupiedEntry<'a, V> {
    /// Returns a shared reference to the stored value.
    pub fn get(&self) -> &V {
        self.entry
    }

    /// Returns a mutable reference to the stored value, scoped to this entry's borrow.
    pub fn get_mut(&mut self) -> &mut V {
        self.entry
    }

    /// Overwrites the stored value with `value` and returns the previous one.
    pub fn insert(&mut self, value: V) -> V {
        let previous = *self.entry;
        *self.entry = value;
        previous
    }

    /// Converts the entry into a mutable reference that lives as long as the original borrow.
    pub fn into_mut(self) -> &'a mut V {
        self.entry
    }

    /// Returns a copy of the stored value.
    ///
    /// NOTE(review): despite its name, this does NOT empty the slot. The entry borrows only the
    /// value (`&mut V`), not the surrounding `Option`, so the value stays in memory afterwards.
    /// Callers that need removal should use the container's `remove(addr)` instead.
    pub fn remove(self) -> V {
        *self.entry
    }
}
impl<V: Copy> FromIterator<(u64, V)> for PagedMemory<V> {
    /// Builds a paged memory by inserting each `(addr, value)` pair in order.
    ///
    /// A later pair for the same address replaces an earlier one.
    fn from_iter<T: IntoIterator<Item = (u64, V)>>(iter: T) -> Self {
        let mut memory = Self::new_preallocated();
        iter.into_iter().for_each(|(addr, value)| {
            memory.insert(addr, value);
        });
        memory
    }
}
impl<V: Copy + 'static> IntoIterator for PagedMemory<V> {
type Item = (u64, V);
type IntoIter = Box<dyn Iterator<Item = Self::Item>>;
fn into_iter(mut self) -> Self::IntoIter {
Box::new(self.index.into_iter().enumerate().filter(|(_, i)| *i != NO_PAGE).flat_map(
move |(i, index)| {
let upper = i << LOG_PAGE_LEN;
std::mem::take(&mut self.page_table[index as usize])
.0
.into_iter()
.enumerate()
.filter_map(move |(lower, v)| {
v.map(|v| (Self::decompress_addr(upper + lower), v))
})
},
))
}
}