// rialo_s_program_runtime/mem_pool.rs

1use std::array;
2
3use rialo_s_compute_budget::{
4    compute_budget::{MAX_CALL_DEPTH, MAX_INSTRUCTION_STACK_DEPTH, STACK_FRAME_SIZE},
5    compute_budget_limits::{MAX_HEAP_FRAME_BYTES, MIN_HEAP_FRAME_BYTES},
6};
7use solana_sbpf::{aligned_memory::AlignedMemory, ebpf::HOST_ALIGN};
8
/// Restores a value to a clean state so it can be handed out again.
///
/// [`Pool::put`] calls this on every value it accepts, right before storing it.
trait Reset {
    /// Clears whatever state the value accumulated during its previous use.
    fn reset(&mut self);
}
12
13struct Pool<T: Reset, const SIZE: usize> {
14    items: [Option<T>; SIZE],
15    next_empty: usize,
16}
17
18impl<T: Reset, const SIZE: usize> Pool<T, SIZE> {
19    fn new(items: [T; SIZE]) -> Self {
20        Self {
21            items: items.map(|i| Some(i)),
22            next_empty: SIZE,
23        }
24    }
25
26    fn len(&self) -> usize {
27        SIZE
28    }
29
30    fn get(&mut self) -> Option<T> {
31        if self.next_empty == 0 {
32            return None;
33        }
34        self.next_empty = self.next_empty.saturating_sub(1);
35        self.items
36            .get_mut(self.next_empty)
37            .and_then(|item| item.take())
38    }
39
40    fn put(&mut self, mut value: T) -> bool {
41        self.items
42            .get_mut(self.next_empty)
43            .map(|item| {
44                value.reset();
45                item.replace(value);
46                self.next_empty = self.next_empty.saturating_add(1);
47                true
48            })
49            .unwrap_or(false)
50    }
51}
52
53impl Reset for AlignedMemory<{ HOST_ALIGN }> {
54    fn reset(&mut self) {
55        self.as_slice_mut().fill(0)
56    }
57}
58
59pub struct VmMemoryPool {
60    stack: Pool<AlignedMemory<{ HOST_ALIGN }>, MAX_INSTRUCTION_STACK_DEPTH>,
61    heap: Pool<AlignedMemory<{ HOST_ALIGN }>, MAX_INSTRUCTION_STACK_DEPTH>,
62}
63
64impl VmMemoryPool {
65    pub fn new() -> Self {
66        Self {
67            stack: Pool::new(array::from_fn(|_| {
68                AlignedMemory::zero_filled(STACK_FRAME_SIZE * MAX_CALL_DEPTH)
69            })),
70            heap: Pool::new(array::from_fn(|_| {
71                AlignedMemory::zero_filled(MAX_HEAP_FRAME_BYTES as usize)
72            })),
73        }
74    }
75
76    pub fn stack_len(&self) -> usize {
77        self.stack.len()
78    }
79
80    pub fn heap_len(&self) -> usize {
81        self.heap.len()
82    }
83
84    pub fn get_stack(&mut self, size: usize) -> AlignedMemory<{ HOST_ALIGN }> {
85        debug_assert!(size == STACK_FRAME_SIZE * MAX_CALL_DEPTH);
86        self.stack
87            .get()
88            .unwrap_or_else(|| AlignedMemory::zero_filled(size))
89    }
90
91    pub fn put_stack(&mut self, stack: AlignedMemory<{ HOST_ALIGN }>) -> bool {
92        self.stack.put(stack)
93    }
94
95    pub fn get_heap(&mut self, heap_size: u32) -> AlignedMemory<{ HOST_ALIGN }> {
96        debug_assert!((MIN_HEAP_FRAME_BYTES..=MAX_HEAP_FRAME_BYTES).contains(&heap_size));
97        self.heap
98            .get()
99            .unwrap_or_else(|| AlignedMemory::zero_filled(MAX_HEAP_FRAME_BYTES as usize))
100    }
101
102    pub fn put_heap(&mut self, heap: AlignedMemory<{ HOST_ALIGN }>) -> bool {
103        let heap_size = heap.len();
104        debug_assert!(
105            heap_size >= MIN_HEAP_FRAME_BYTES as usize
106                && heap_size <= MAX_HEAP_FRAME_BYTES as usize
107        );
108        self.heap.put(heap)
109    }
110}
111
112impl Default for VmMemoryPool {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
#[cfg(test)]
mod test {
    use super::*;

    /// Test value whose second field is cleared by `reset`, so tests can observe whether a value
    /// went through `Pool::put`.
    #[derive(Debug, Eq, PartialEq)]
    struct Item(u8, u8);

    impl Reset for Item {
        fn reset(&mut self) {
            self.1 = 0;
        }
    }

    #[test]
    fn test_pool() {
        let mut pool = Pool::<Item, 2>::new([Item(0, 1), Item(1, 1)]);
        // `len` reports capacity, independent of how many items are currently held.
        assert_eq!(pool.len(), 2);
        // Initial items come back in LIFO order and untouched, since they were never `put`.
        assert_eq!(pool.get(), Some(Item(1, 1)));
        assert_eq!(pool.get(), Some(Item(0, 1)));
        assert_eq!(pool.get(), None);
        assert_eq!(pool.len(), 2);
        // `put` succeeds while there is room and resets the value before storing it.
        assert!(pool.put(Item(1, 1)));
        assert_eq!(pool.get(), Some(Item(1, 0)));
        assert!(pool.put(Item(2, 2)));
        assert!(pool.put(Item(3, 3)));
        // The pool is full, so the extra item is rejected.
        assert!(!pool.put(Item(4, 4)));
        assert_eq!(pool.get(), Some(Item(3, 0)));
        assert_eq!(pool.get(), Some(Item(2, 0)));
        assert_eq!(pool.get(), None);
    }

    #[test]
    fn test_zero_capacity_pool() {
        let mut pool = Pool::<Item, 0>::new([]);
        assert_eq!(pool.len(), 0);
        assert_eq!(pool.get(), None);
        assert!(!pool.put(Item(0, 1)));
        assert_eq!(pool.get(), None);
    }

    #[test]
    fn test_vm_memory_pool_zeroes_returned_memory() {
        let mut pool = VmMemoryPool::new();
        assert_eq!(pool.stack_len(), MAX_INSTRUCTION_STACK_DEPTH);
        assert_eq!(pool.heap_len(), MAX_INSTRUCTION_STACK_DEPTH);

        let mut heap = pool.get_heap(MAX_HEAP_FRAME_BYTES);
        assert_eq!(heap.len(), MAX_HEAP_FRAME_BYTES as usize);
        heap.as_slice_mut().fill(0xAA);
        assert!(pool.put_heap(heap));

        // The pool is LIFO, so the same region comes straight back; it must have been zeroed.
        let heap = pool.get_heap(MAX_HEAP_FRAME_BYTES);
        assert!(heap.as_slice().iter().all(|&byte| byte == 0));

        let stack = pool.get_stack(STACK_FRAME_SIZE * MAX_CALL_DEPTH);
        assert_eq!(stack.len(), STACK_FRAME_SIZE * MAX_CALL_DEPTH);
    }
}
145}