Skip to main content

rialo_s_program_runtime/
mem_pool.rs

1// Copyright (c) Subzero Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3// This file is either (a) original to Subzero Labs, Inc. or (b) derived from the Anza codebase and modified by Subzero Labs, Inc.
4
5use std::array;
6
7use rialo_s_compute_budget::{
8    compute_budget::{MAX_CALL_DEPTH, MAX_INSTRUCTION_STACK_DEPTH, STACK_FRAME_SIZE},
9    compute_budget_limits::{MAX_HEAP_FRAME_BYTES, MIN_HEAP_FRAME_BYTES},
10};
11use solana_sbpf::{aligned_memory::AlignedMemory, ebpf::HOST_ALIGN};
12
/// Implemented by values that can be stored in a [`Pool`] and reused.
trait Reset {
    /// Returns `self` to its clean state.
    ///
    /// `Pool::put` calls this on a value right before storing it, so whatever
    /// a later `Pool::get` hands out has already been reset.
    fn reset(&mut self);
}
16
/// Fixed-capacity, last-in-first-out pool of reusable `T` values stored inline.
///
/// Invariant upheld by the methods of this type: slots `0..next_empty` hold
/// `Some(_)` and slots `next_empty..SIZE` hold `None`.
///
/// The type itself carries no trait bound on `T`; the `Reset` requirement is
/// stated on the `impl` block, which is the only place that calls `reset`.
struct Pool<T, const SIZE: usize> {
    /// Storage slots; `None` marks a slot whose item is currently handed out.
    items: [Option<T>; SIZE],
    /// Index of the first empty slot, which equals the number of stored items.
    next_empty: usize,
}
21
22impl<T: Reset, const SIZE: usize> Pool<T, SIZE> {
23    fn new(items: [T; SIZE]) -> Self {
24        Self {
25            items: items.map(|i| Some(i)),
26            next_empty: SIZE,
27        }
28    }
29
30    fn len(&self) -> usize {
31        SIZE
32    }
33
34    fn get(&mut self) -> Option<T> {
35        if self.next_empty == 0 {
36            return None;
37        }
38        self.next_empty = self.next_empty.saturating_sub(1);
39        self.items
40            .get_mut(self.next_empty)
41            .and_then(|item| item.take())
42    }
43
44    fn put(&mut self, mut value: T) -> bool {
45        self.items
46            .get_mut(self.next_empty)
47            .map(|item| {
48                value.reset();
49                item.replace(value);
50                self.next_empty = self.next_empty.saturating_add(1);
51                true
52            })
53            .unwrap_or(false)
54    }
55}
56
57impl Reset for AlignedMemory<{ HOST_ALIGN }> {
58    fn reset(&mut self) {
59        self.as_slice_mut().fill(0)
60    }
61}
62
/// Pre-allocated stack and heap buffers that are handed out by
/// `get_stack` / `get_heap` and given back through `put_stack` / `put_heap`.
///
/// Each of the two pools has `MAX_INSTRUCTION_STACK_DEPTH` slots.
pub struct VmMemoryPool {
    /// Stack buffers; `new` creates each with length
    /// `STACK_FRAME_SIZE * MAX_CALL_DEPTH`.
    stack: Pool<AlignedMemory<{ HOST_ALIGN }>, MAX_INSTRUCTION_STACK_DEPTH>,
    /// Heap buffers; `new` creates each with length `MAX_HEAP_FRAME_BYTES`.
    heap: Pool<AlignedMemory<{ HOST_ALIGN }>, MAX_INSTRUCTION_STACK_DEPTH>,
}
67
68impl VmMemoryPool {
69    pub fn new() -> Self {
70        Self {
71            stack: Pool::new(array::from_fn(|_| {
72                AlignedMemory::zero_filled(STACK_FRAME_SIZE * MAX_CALL_DEPTH)
73            })),
74            heap: Pool::new(array::from_fn(|_| {
75                AlignedMemory::zero_filled(MAX_HEAP_FRAME_BYTES as usize)
76            })),
77        }
78    }
79
80    pub fn stack_len(&self) -> usize {
81        self.stack.len()
82    }
83
84    pub fn heap_len(&self) -> usize {
85        self.heap.len()
86    }
87
88    pub fn get_stack(&mut self, size: usize) -> AlignedMemory<{ HOST_ALIGN }> {
89        debug_assert!(size == STACK_FRAME_SIZE * MAX_CALL_DEPTH);
90        self.stack
91            .get()
92            .unwrap_or_else(|| AlignedMemory::zero_filled(size))
93    }
94
95    pub fn put_stack(&mut self, stack: AlignedMemory<{ HOST_ALIGN }>) -> bool {
96        self.stack.put(stack)
97    }
98
99    pub fn get_heap(&mut self, heap_size: u32) -> AlignedMemory<{ HOST_ALIGN }> {
100        debug_assert!((MIN_HEAP_FRAME_BYTES..=MAX_HEAP_FRAME_BYTES).contains(&heap_size));
101        self.heap
102            .get()
103            .unwrap_or_else(|| AlignedMemory::zero_filled(MAX_HEAP_FRAME_BYTES as usize))
104    }
105
106    pub fn put_heap(&mut self, heap: AlignedMemory<{ HOST_ALIGN }>) -> bool {
107        let heap_size = heap.len();
108        debug_assert!(
109            heap_size >= MIN_HEAP_FRAME_BYTES as usize
110                && heap_size <= MAX_HEAP_FRAME_BYTES as usize
111        );
112        self.heap.put(heap)
113    }
114}
115
impl Default for VmMemoryPool {
    /// Equivalent to [`VmMemoryPool::new`].
    fn default() -> Self {
        Self::new()
    }
}
121
#[cfg(test)]
mod test {
    use super::*;

    /// Test item: field 0 is a tag that `reset` leaves untouched, field 1 is
    /// state that `reset` clears to 0.
    #[derive(Debug, Eq, PartialEq)]
    struct Item(u8, u8);
    impl Reset for Item {
        fn reset(&mut self) {
            self.1 = 0;
        }
    }

    /// Checks LIFO ordering, reset-on-put, and rejection of items once the
    /// pool is full.
    #[test]
    fn test_pool() {
        let mut pool = Pool::<Item, 2>::new([Item(0, 1), Item(1, 1)]);
        // The initial items come back last-first and have not been reset.
        assert_eq!(pool.get(), Some(Item(1, 1)));
        assert_eq!(pool.get(), Some(Item(0, 1)));
        // Empty pool.
        assert_eq!(pool.get(), None);
        // An item that is put back is reset before it is handed out again.
        pool.put(Item(1, 1));
        assert_eq!(pool.get(), Some(Item(1, 0)));
        // Fill both slots; a third put must be rejected.
        pool.put(Item(2, 2));
        pool.put(Item(3, 3));
        assert!(!pool.put(Item(4, 4)));
        // Stored items come back in LIFO order, both reset.
        assert_eq!(pool.get(), Some(Item(3, 0)));
        assert_eq!(pool.get(), Some(Item(2, 0)));
        assert_eq!(pool.get(), None);
    }
}