light_heap/
lib.rs

1use std::{alloc::Layout, mem::size_of, ptr::null_mut};
2pub mod bench;
3
4#[cfg(target_os = "solana")]
5use anchor_lang::{
6    prelude::*,
7    solana_program::entrypoint::{HEAP_LENGTH, HEAP_START_ADDRESS},
8};
9
/// Process-wide allocator for on-chain builds: a [`BumpAllocator`] spanning the whole
/// Solana program heap (`HEAP_START_ADDRESS .. HEAP_START_ADDRESS + HEAP_LENGTH`).
#[cfg(target_os = "solana")]
#[global_allocator]
pub static GLOBAL_ALLOCATOR: BumpAllocator = BumpAllocator {
    len: HEAP_LENGTH,
    start: HEAP_START_ADDRESS as usize,
};
16
// Errors reported by the heap-management helpers on `BumpAllocator` (on-chain builds only).
#[cfg(target_os = "solana")]
#[error_code]
pub enum HeapError {
    // Returned by `free_heap` when the requested cursor position lies outside the
    // allocator's usable region.
    #[msg("The provided position to free is invalid.")]
    InvalidHeapPos,
}
/// A downward-growing bump allocator over the memory range `start .. start + len`.
///
/// The first `size_of::<*mut u8>()` bytes at `start` hold the allocation cursor, stored as a
/// `usize`, so `start` must be aligned for `usize`. A cursor value of `0` means nothing has been
/// allocated yet. Allocations are carved from the top of the range downwards, and memory is only
/// reclaimed by moving the cursor back up.
#[derive(Debug)]
pub struct BumpAllocator {
    /// First address of the managed region; the cursor word lives here.
    pub start: usize,
    /// Length of the managed region in bytes, including the cursor word.
    pub len: usize,
}
27
28impl BumpAllocator {
29    const RESERVED_MEM: usize = size_of::<*mut u8>();
30
31    #[cfg(target_os = "solana")]
32    pub fn new() -> Self {
33        Self {
34            start: HEAP_START_ADDRESS as usize,
35            len: HEAP_LENGTH,
36        }
37    }
38
39    /// Returns the current position of the heap.
40    ///
41    /// # Safety
42    ///
43    /// This function is unsafe because it returns a raw pointer.
44    pub unsafe fn pos(&self) -> usize {
45        let pos_ptr = self.start as *mut usize;
46        *pos_ptr
47    }
48
49    /// Reset heap start cursor to position.
50    ///
51    /// # Safety
52    ///
53    /// Do not use this function if you initialized heap memory after pos which you still need.
54    pub unsafe fn move_cursor(&self, pos: usize) {
55        let pos_ptr = self.start as *mut usize;
56        *pos_ptr = pos;
57    }
58
59    #[cfg(target_os = "solana")]
60    pub fn log_total_heap(&self, msg: &str) -> u64 {
61        const HEAP_END_ADDRESS: u64 = HEAP_START_ADDRESS as u64 + HEAP_LENGTH as u64;
62
63        let heap_start = unsafe { self.pos() } as u64;
64        let heap_used = HEAP_END_ADDRESS - heap_start;
65        msg!("{}: total heap used: {}", msg, heap_used);
66        heap_used
67    }
68
69    #[cfg(target_os = "solana")]
70    pub fn get_heap_pos(&self) -> usize {
71        let heap_start = unsafe { self.pos() } as usize;
72        heap_start
73    }
74
75    #[cfg(target_os = "solana")]
76    pub fn free_heap(&self, pos: usize) -> Result<()> {
77        if pos < self.start + BumpAllocator::RESERVED_MEM || pos > self.start + self.len {
78            return err!(HeapError::InvalidHeapPos);
79        }
80
81        unsafe { self.move_cursor(pos) };
82        Ok(())
83    }
84}
85
86unsafe impl std::alloc::GlobalAlloc for BumpAllocator {
87    #[inline]
88    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
89        let pos_ptr = self.start as *mut usize;
90
91        let mut pos = *pos_ptr;
92        if pos == 0 {
93            // First time, set starting position
94            pos = self.start + self.len;
95        }
96        pos = pos.saturating_sub(layout.size());
97        pos &= !(layout.align().wrapping_sub(1));
98        if pos < self.start + BumpAllocator::RESERVED_MEM {
99            return null_mut();
100        }
101        *pos_ptr = pos;
102        pos as *mut u8
103    }
104    #[inline]
105    unsafe fn dealloc(&self, _: *mut u8, _: Layout) {
106        // no dellaoc in Solana runtime :*(
107    }
108}
109
#[cfg(test)]
mod test {
    use std::{
        alloc::{GlobalAlloc, Layout},
        mem::size_of,
        ptr::null_mut,
    };

    use super::*;

    /// Size in bytes of every heap used by these tests.
    const HEAP_SIZE: usize = 128;

    /// Backing storage for a test heap.
    ///
    /// The allocator stores its cursor as a `usize` at the very start of the heap, so the buffer
    /// must be `usize`-aligned. The 64-byte alignment also makes the address-based assertions
    /// deterministic instead of depending on where the stack frame happens to land. The buffer is
    /// borrowed mutably, which is what permits the allocator's writes through raw pointers.
    #[repr(C, align(64))]
    struct TestHeap([u8; HEAP_SIZE]);

    impl TestHeap {
        fn new() -> Self {
            TestHeap([0; HEAP_SIZE])
        }

        /// Returns an allocator managing this buffer.
        fn allocator(&mut self) -> BumpAllocator {
            BumpAllocator {
                start: self.0.as_mut_ptr() as usize,
                len: HEAP_SIZE,
            }
        }
    }

    /// Layout of a single byte with byte alignment.
    fn one_byte() -> Layout {
        Layout::from_size_align(1, size_of::<u8>()).unwrap()
    }

    #[test]
    fn test_pos_move_cursor_heap() {
        let mut heap = TestHeap::new();
        let allocator = heap.allocator();
        let base = allocator.start;

        let pos = unsafe { allocator.pos() };
        assert_eq!(pos, unsafe { allocator.pos() });
        assert_eq!(pos, 0);

        let mut pos_64 = 0;
        for i in 0..HEAP_SIZE - size_of::<*mut u8>() {
            if i == 64 {
                pos_64 = unsafe { allocator.pos() };
            }
            let ptr = unsafe { allocator.alloc(one_byte()) };
            assert_eq!(ptr as usize, base + HEAP_SIZE - 1 - i);
            assert_eq!(ptr as usize, unsafe { allocator.pos() });
        }
        let pos_128 = unsafe { allocator.pos() };

        // free half of the heap
        unsafe { allocator.move_cursor(pos_64) };
        assert_eq!(pos_64, unsafe { allocator.pos() });
        assert_ne!(pos_64 + 1, unsafe { allocator.pos() });

        // allocate second half of the heap again
        for i in 0..64 - size_of::<*mut u8>() {
            let ptr = unsafe { allocator.alloc(one_byte()) };
            assert_eq!(ptr as usize, base + HEAP_SIZE - 1 - (i + 64));
            assert_eq!(ptr as usize, unsafe { allocator.pos() });
        }
        assert_eq!(pos_128, unsafe { allocator.pos() });

        // free all of the heap
        unsafe { allocator.move_cursor(pos) };
        assert_eq!(pos, unsafe { allocator.pos() });
        assert_ne!(pos + 1, unsafe { allocator.pos() });
    }

    /// taken from solana-program https://github.com/solana-labs/solana/blob/9a520fd5b42bafefa4815afe3e5390b4ea7482ca/sdk/program/src/entrypoint.rs#L374
    #[test]
    fn test_bump_allocator() {
        // alloc the entire heap, one byte at a time
        {
            let mut heap = TestHeap::new();
            let allocator = heap.allocator();
            let base = allocator.start;
            for i in 0..HEAP_SIZE - size_of::<*mut u8>() {
                let ptr = unsafe { allocator.alloc(one_byte()) };
                assert_eq!(ptr as usize, base + HEAP_SIZE - 1 - i);
            }
            assert_eq!(null_mut(), unsafe {
                allocator.alloc(Layout::from_size_align(1, 1).unwrap())
            });
        }
        // check alignment
        {
            let mut heap = TestHeap::new();
            let allocator = heap.allocator();
            for align in [
                size_of::<u8>(),
                size_of::<u16>(),
                size_of::<u32>(),
                size_of::<u64>(),
                size_of::<u128>(),
                64,
            ] {
                let ptr = unsafe { allocator.alloc(Layout::from_size_align(1, align).unwrap()) };
                // `align_offset` of a null pointer is 0, so rule out allocation failure explicitly.
                assert!(!ptr.is_null());
                assert_eq!(0, ptr.align_offset(align));
            }
        }
        // alloc entire block (minus the pos ptr)
        {
            let mut heap = TestHeap::new();
            let allocator = heap.allocator();
            let ptr =
                unsafe { allocator.alloc(Layout::from_size_align(120, size_of::<u8>()).unwrap()) };
            assert_ne!(ptr, null_mut());
            assert_eq!(0, ptr.align_offset(size_of::<u64>()));
        }
    }
}