//! sp1_core_executor — `memory.rs`
//!
//! Memory model for the executor: a 32-slot register file plus a paged
//! main memory.
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use vec_map::VecMap;
/// A memory.
///
/// Consists of registers, as well as a page table for main memory.
/// Addresses below 32 are routed to [`Registers`]; all other addresses go
/// through [`PagedMemory`] (see [`Memory::entry`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "T: Serialize"))]
#[serde(bound(deserialize = "T: DeserializeOwned"))]
pub struct Memory<T: Copy> {
    /// The registers (addresses `0..32`).
    pub registers: Registers<T>,
    /// The page table holding all non-register memory.
    pub page_table: PagedMemory<T>,
}
15
16impl<V: Copy + 'static> IntoIterator for Memory<V> {
17    type Item = (u64, V);
18
19    type IntoIter = Box<dyn Iterator<Item = Self::Item>>;
20
21    fn into_iter(self) -> Self::IntoIter {
22        Box::new(self.registers.into_iter().chain(self.page_table))
23    }
24}
25
26impl<T: Copy + Default> Default for Memory<T> {
27    fn default() -> Self {
28        Self { registers: Registers::default(), page_table: PagedMemory::default() }
29    }
30}
31
32impl<T: Copy> Memory<T> {
33    /// Initialize a new memory with preallocated page table.
34    pub fn new_preallocated() -> Self {
35        Self { registers: Registers::default(), page_table: PagedMemory::new_preallocated() }
36    }
37
38    /// Get an entry for the given address.
39    ///
40    /// When possible, prefer directly accessing the `page_table` or `registers` fields.
41    /// This method often incurs unnecessary branching.
42    #[inline]
43    pub fn entry(&mut self, addr: u64) -> Entry<'_, T> {
44        if addr < 32 {
45            self.registers.entry(addr)
46        } else {
47            self.page_table.entry(addr)
48        }
49    }
50
51    /// Insert a value into the memory.
52    ///
53    /// When possible, prefer directly accessing the `page_table` or `registers` fields.
54    /// This method often incurs unnecessary branching.   
55    #[inline]
56    pub fn insert(&mut self, addr: u64, value: T) -> Option<T> {
57        if addr < 32 {
58            self.registers.insert(addr, value)
59        } else {
60            self.page_table.insert(addr, value)
61        }
62    }
63
64    /// Get a value from the memory.
65    ///
66    /// When possible, prefer directly accessing the `page_table` or `registers` fields.
67    /// This method often incurs unnecessary branching.
68    #[inline]
69    pub fn get(&self, addr: u64) -> Option<&T> {
70        if addr < 32 {
71            self.registers.get(addr)
72        } else {
73            self.page_table.get(addr)
74        }
75    }
76
77    /// Remove a value from the memory.
78    ///
79    /// When possible, prefer directly accessing the `page_table` or `registers` fields.
80    /// This method often incurs unnecessary branching.
81    #[inline]
82    pub fn remove(&mut self, addr: u64) -> Option<T> {
83        if addr < 32 {
84            self.registers.remove(addr)
85        } else {
86            self.page_table.remove(addr)
87        }
88    }
89
90    /// Clear the memory.
91    #[inline]
92    pub fn clear(&mut self) {
93        self.registers.clear();
94        self.page_table.clear();
95    }
96}
97
98impl<V: Copy + Default> FromIterator<(u64, V)> for Memory<V> {
99    fn from_iter<T: IntoIterator<Item = (u64, V)>>(iter: T) -> Self {
100        let mut memory = Self::new_preallocated();
101        for (addr, value) in iter {
102            memory.insert(addr, value);
103        }
104        memory
105    }
106}
107
/// An array of 32 registers.
///
/// Each slot is `None` until written. Register numbers are the raw indices
/// `0..32` (see the `addr < 32` routing in `Memory`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "T: Serialize"))]
#[serde(bound(deserialize = "T: DeserializeOwned"))]
pub struct Registers<T: Copy> {
    // Backing array; index = register number, `None` = never written.
    pub registers: [Option<T>; 32],
}
115
116impl<T: Copy> Default for Registers<T> {
117    fn default() -> Self {
118        Self { registers: [None; 32] }
119    }
120}
121
122impl<T: Copy> Registers<T> {
123    /// Get an entry for the given register.
124    #[inline]
125    pub fn entry(&mut self, addr: u64) -> Entry<'_, T> {
126        let entry = &mut self.registers[addr as usize];
127        match entry {
128            Some(v) => Entry::Occupied(OccupiedEntry { entry: v }),
129            None => Entry::Vacant(VacantEntry { entry }),
130        }
131    }
132
133    /// Insert a value into the registers.
134    ///
135    /// Assumes addr < 32.
136    #[inline]
137    pub fn insert(&mut self, addr: u64, value: T) -> Option<T> {
138        self.registers[addr as usize].replace(value)
139    }
140
141    /// Remove a value from the registers, and return it if it exists.
142    ///
143    /// Assumes addr < 32.
144    #[inline]
145    pub fn remove(&mut self, addr: u64) -> Option<T> {
146        self.registers[addr as usize].take()
147    }
148
149    /// Get a reference to the value at the given address, if it exists.
150    ///
151    /// Assumes addr < 32.
152    #[inline]
153    pub fn get(&self, addr: u64) -> Option<&T> {
154        self.registers[addr as usize].as_ref()
155    }
156
157    /// Clear the registers.
158    #[inline]
159    pub fn clear(&mut self) {
160        self.registers.fill(None);
161    }
162}
163
164impl<V: Copy> FromIterator<(u64, V)> for Registers<V> {
165    fn from_iter<T: IntoIterator<Item = (u64, V)>>(iter: T) -> Self {
166        let mut mmu = Self::default();
167        for (k, v) in iter {
168            mmu.insert(k, v);
169        }
170        mmu
171    }
172}
173
174impl<V: Copy + 'static> IntoIterator for Registers<V> {
175    type Item = (u64, V);
176
177    type IntoIter = Box<dyn Iterator<Item = Self::Item>>;
178
179    fn into_iter(self) -> Self::IntoIter {
180        Box::new(
181            self.registers
182                .into_iter()
183                .enumerate()
184                .filter_map(move |(i, v)| v.map(|v| (i as u64, v))),
185        )
186    }
187}
188
/// A page of memory.
///
/// Sparse page representation backed by a `VecMap`. Appears unused within
/// this file (hence `#[allow(dead_code)]`); `NewPage` is the dense
/// representation actually used by `PagedMemory`.
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Page<V>(VecMap<V>);
193
194impl<V> Default for Page<V> {
195    fn default() -> Self {
196        Self(VecMap::default())
197    }
198}
199
/// Maximum number of meaningful address bits (a 2^40-byte address space).
pub(crate) const MAX_LOG_ADDR: usize = 40;
/// log2 of the number of slots per page.
const LOG_PAGE_LEN: usize = 18;
/// Number of slots in a page (2^18).
const PAGE_LEN: usize = 1 << LOG_PAGE_LEN;
/// Upper bound on the number of pages: the compressed address space
/// (addresses / 8, since non-register addresses are multiples of 8)
/// divided by the page length.
const MAX_PAGE_COUNT: usize = (1 << MAX_LOG_ADDR) / 8 / PAGE_LEN;
/// Sentinel stored in `PagedMemory::index` marking "no page allocated".
const NO_PAGE: u32 = u32::MAX;
/// Mask extracting the within-page slot from a compressed address.
const PAGE_MASK: usize = PAGE_LEN - 1;
206
/// A dense page: a fixed-length vector of optional values.
///
/// Serde bounds mirror the other containers in this file so `V` only needs
/// `Serialize`/`DeserializeOwned` when the page itself is (de)serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "V: Serialize"))]
#[serde(bound(deserialize = "V: DeserializeOwned"))]
pub struct NewPage<V>(Vec<Option<V>>);
211
212impl<V: Copy> NewPage<V> {
213    pub fn new() -> Self {
214        Self(vec![None; PAGE_LEN])
215    }
216}
217
impl<V: Copy> Default for NewPage<V> {
    /// A zero-length placeholder page.
    ///
    /// NOTE(review): intentionally NOT the same as `NewPage::new()` — the
    /// empty vec allocates nothing, which makes `std::mem::take` on a page
    /// (see `PagedMemory::into_iter`) cheap. Indexing into a defaulted page
    /// would panic; confirm no code path does so.
    fn default() -> Self {
        Self(Vec::new())
    }
}
223
/// Paged memory. Balances both memory locality and total memory usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "V: Serialize"))]
#[serde(bound(deserialize = "V: DeserializeOwned"))]
pub struct PagedMemory<V: Copy> {
    /// The internal page table: only pages that have been touched are stored.
    pub page_table: Vec<NewPage<V>>,
    /// Maps a page number (upper bits of the compressed address) to its
    /// position in `page_table`, or `NO_PAGE` if the page is unallocated.
    pub index: Vec<u32>,
}
233
impl<V: Copy> PagedMemory<V> {
    /// The number of lower bits to ignore, since addresses (except registers) are a multiple of 8.
    const NUM_IGNORED_LOWER_BITS: usize = 3;

    /// Create a `PagedMemory` with capacity `MAX_PAGE_COUNT`.
    ///
    /// Only the page index (`MAX_PAGE_COUNT` sentinel entries) is allocated
    /// here; pages themselves are allocated lazily on first write.
    pub fn new_preallocated() -> Self {
        Self { page_table: Vec::new(), index: vec![NO_PAGE; MAX_PAGE_COUNT] }
    }

    /// Get a reference to the memory value at the given address, if it exists.
    ///
    /// Panics if the compressed address exceeds the `MAX_LOG_ADDR`-bit space
    /// (direct index into `self.index`).
    pub fn get(&self, addr: u64) -> Option<&V> {
        let (upper, lower) = Self::indices(addr);
        let index = self.index[upper];
        if index == NO_PAGE {
            None
        } else {
            self.page_table[index as usize].0[lower].as_ref()
        }
    }

    /// Get a mutable reference to the memory value at the given address, if it exists.
    pub fn get_mut(&mut self, addr: u64) -> Option<&mut V> {
        let (upper, lower) = Self::indices(addr);
        let index = self.index[upper];
        if index == NO_PAGE {
            None
        } else {
            self.page_table[index as usize].0[lower].as_mut()
        }
    }

    /// Insert a value at the given address. Returns the previous value, if any.
    ///
    /// Allocates the containing page on first touch.
    pub fn insert(&mut self, addr: u64, value: V) -> Option<V> {
        let (upper, lower) = Self::indices(addr);
        let mut index = self.index[upper];
        if index == NO_PAGE {
            // First write to this page: allocate it and record its position.
            index = self.page_table.len() as u32;
            self.index[upper] = index;
            self.page_table.push(NewPage::new());
        }
        self.page_table[index as usize].0[lower].replace(value)
    }

    /// Remove the value at the given address if it exists, returning it.
    ///
    /// The page itself is never deallocated, even if it becomes empty.
    pub fn remove(&mut self, addr: u64) -> Option<V> {
        let (upper, lower) = Self::indices(addr);
        let index = self.index[upper];
        if index == NO_PAGE {
            None
        } else {
            self.page_table[index as usize].0[lower].take()
        }
    }

    /// Gets the memory entry for the given address.
    ///
    /// NOTE(review): the vacant path allocates the page eagerly, so merely
    /// obtaining an `Entry` for an untouched address commits a full page even
    /// if the caller never inserts — confirm this is intended.
    pub fn entry(&mut self, addr: u64) -> Entry<'_, V> {
        let (upper, lower) = Self::indices(addr);
        let index = self.index[upper];
        if index == NO_PAGE {
            // Allocate the page so a `&mut Option<V>` slot can be handed out.
            let index = self.page_table.len();
            self.index[upper] = index as u32;
            self.page_table.push(NewPage::new());
            Entry::Vacant(VacantEntry { entry: &mut self.page_table[index].0[lower] })
        } else {
            let option = &mut self.page_table[index as usize].0[lower];
            match option {
                Some(v) => Entry::Occupied(OccupiedEntry { entry: v }),
                None => Entry::Vacant(VacantEntry { entry: option }),
            }
        }
    }

    /// Returns an iterator over the occupied addresses.
    pub fn keys(&self) -> impl Iterator<Item = u64> + '_ {
        // For each allocated page, reconstruct the sparse address from the
        // page number (upper bits) and the in-page slot (lower bits).
        self.index.iter().enumerate().filter(|(_, &i)| i != NO_PAGE).flat_map(|(i, index)| {
            let upper = i << LOG_PAGE_LEN;
            self.page_table[*index as usize]
                .0
                .iter()
                .enumerate()
                .filter_map(move |(lower, v)| v.map(|_| Self::decompress_addr(upper + lower)))
        })
    }

    /// Get the exact number of addresses in use. This function iterates through each page
    /// and is therefore somewhat expensive.
    pub fn exact_len(&self) -> usize {
        self.index
            .iter()
            .filter(|&&i| i != NO_PAGE)
            .map(|index| self.page_table[*index as usize].0.iter().filter(|v| v.is_some()).count())
            .sum()
    }

    /// Estimate the number of addresses in use.
    ///
    /// Upper bound: every allocated page is counted as completely full
    /// (`PAGE_LEN` slots), without inspecting individual slots.
    pub fn estimate_len(&self) -> usize {
        self.index.iter().filter(|&i| *i != NO_PAGE).count() * PAGE_LEN
    }

    /// Clears the page table. Drops all `Page`s, but retains the memory used by the table itself.
    pub fn clear(&mut self) {
        self.page_table.clear();
        self.index.fill(NO_PAGE);
    }

    /// Break apart an address into an upper and lower index.
    ///
    /// `upper` selects the page (index into `self.index`); `lower` selects
    /// the slot within that page.
    #[inline]
    const fn indices(addr: u64) -> (usize, usize) {
        let index = Self::compress_addr(addr);
        (index >> LOG_PAGE_LEN, index & PAGE_MASK)
    }

    /// Compress an address from the sparse address space to a contiguous space.
    ///
    /// Drops the 3 low bits (non-register addresses are multiples of 8).
    #[inline]
    const fn compress_addr(addr: u64) -> usize {
        addr as usize >> Self::NUM_IGNORED_LOWER_BITS
    }

    /// Decompress an address from a contiguous space to the sparse address space.
    #[inline]
    const fn decompress_addr(addr: usize) -> u64 {
        (addr << Self::NUM_IGNORED_LOWER_BITS) as u64
    }
}
358
359impl<V: Copy> Default for PagedMemory<V> {
360    fn default() -> Self {
361        Self { page_table: Vec::new(), index: vec![NO_PAGE; MAX_PAGE_COUNT] }
362    }
363}
364
/// An entry of `PagedMemory` or `Registers`, for in-place manipulation.
pub enum Entry<'a, V: Copy> {
    /// The slot holds no value; a value can be inserted via [`VacantEntry::insert`].
    Vacant(VacantEntry<'a, V>),
    /// The slot already holds a value that can be read, replaced, or copied out.
    Occupied(OccupiedEntry<'a, V>),
}
370
371impl<'a, V: Copy> Entry<'a, V> {
372    /// Ensures a value is in the entry, inserting the provided value if necessary.
373    /// Returns a mutable reference to the value.
374    pub fn or_insert(self, default: V) -> &'a mut V {
375        match self {
376            Entry::Vacant(entry) => entry.insert(default),
377            Entry::Occupied(entry) => entry.into_mut(),
378        }
379    }
380
381    /// Ensures a value is in the entry, computing a value if necessary.
382    /// Returns a mutable reference to the value.
383    pub fn or_insert_with<F: FnOnce() -> V>(self, default: F) -> &'a mut V {
384        match self {
385            Entry::Vacant(entry) => entry.insert(default()),
386            Entry::Occupied(entry) => entry.into_mut(),
387        }
388    }
389
390    /// Provides in-place mutable access to an occupied entry before any potential inserts into the
391    /// map.
392    pub fn and_modify<F: FnOnce(&mut V)>(mut self, f: F) -> Self {
393        match &mut self {
394            Entry::Vacant(_) => {}
395            Entry::Occupied(entry) => f(entry.get_mut()),
396        }
397        self
398    }
399}
400
401impl<'a, V: Copy + Default> Entry<'a, V> {
402    /// Ensures a value is in the entry, inserting the default value if necessary.
403    /// Returns a mutable reference to the value.
404    pub fn or_default(self) -> &'a mut V {
405        self.or_insert_with(Default::default)
406    }
407}
408
/// A vacant entry, for in-place manipulation.
pub struct VacantEntry<'a, V: Copy> {
    // The empty slot (always `None` by construction) that an insert fills.
    entry: &'a mut Option<V>,
}
413
414impl<'a, V: Copy> VacantEntry<'a, V> {
415    /// Insert a value into the `VacantEntry`, returning a mutable reference to it.
416    pub fn insert(self, value: V) -> &'a mut V {
417        // By construction, the slot in the page is `None`.
418        *self.entry = Some(value);
419        self.entry.as_mut().unwrap()
420    }
421}
422
/// An occupied entry, for in-place manipulation.
pub struct OccupiedEntry<'a, V> {
    // Mutable reference to the stored value (the `Some` payload of the slot).
    entry: &'a mut V,
}
427
428impl<'a, V: Copy> OccupiedEntry<'a, V> {
429    /// Get a reference to the value in the `OccupiedEntry`.
430    pub fn get(&self) -> &V {
431        self.entry
432    }
433
434    /// Get a mutable reference to the value in the `OccupiedEntry`.
435    pub fn get_mut(&mut self) -> &mut V {
436        self.entry
437    }
438
439    /// Insert a value in the `OccupiedEntry`, returning the previous value.
440    pub fn insert(&mut self, value: V) -> V {
441        std::mem::replace(self.entry, value)
442    }
443
444    /// Converts the `OccupiedEntry` the into a mutable reference to the associated value.
445    pub fn into_mut(self) -> &'a mut V {
446        self.entry
447    }
448
449    /// Removes the value from the `OccupiedEntry` and returns it.
450    pub fn remove(self) -> V {
451        *self.entry
452    }
453}
454
455impl<V: Copy> FromIterator<(u64, V)> for PagedMemory<V> {
456    fn from_iter<T: IntoIterator<Item = (u64, V)>>(iter: T) -> Self {
457        let mut mmu = Self::new_preallocated();
458        for (k, v) in iter {
459            mmu.insert(k, v);
460        }
461        mmu
462    }
463}
464
impl<V: Copy + 'static> IntoIterator for PagedMemory<V> {
    type Item = (u64, V);

    type IntoIter = Box<dyn Iterator<Item = Self::Item>>;

    /// Iterate over all occupied `(address, value)` pairs, page by page.
    ///
    /// Each page is moved out of `page_table` with `std::mem::take`, leaving
    /// the cheap zero-length `NewPage::default()` behind, so values are never
    /// cloned.
    fn into_iter(mut self) -> Self::IntoIter {
        // `self.index` is consumed here while the closure below takes
        // `self.page_table`; NOTE(review): this appears to rely on
        // field-precise (disjoint) closure captures — confirm edition 2021+.
        Box::new(self.index.into_iter().enumerate().filter(|(_, i)| *i != NO_PAGE).flat_map(
            move |(i, index)| {
                // `i` is the page number; recover the compressed base address.
                let upper = i << LOG_PAGE_LEN;
                std::mem::take(&mut self.page_table[index as usize])
                    .0
                    .into_iter()
                    .enumerate()
                    .filter_map(move |(lower, v)| {
                        // Skip empty slots; map back to the sparse address space.
                        v.map(|v| (Self::decompress_addr(upper + lower), v))
                    })
            },
        ))
    }
}