use std::{alloc::Layout, mem::size_of, ptr::null_mut};
pub mod bench;

#[cfg(target_os = "solana")]
use anchor_lang::{
    prelude::*,
    solana_program::entrypoint::{HEAP_LENGTH, HEAP_START_ADDRESS},
};

#[cfg(target_os = "solana")]
#[global_allocator]
pub static GLOBAL_ALLOCATOR: BumpAllocator = BumpAllocator {
    start: HEAP_START_ADDRESS as usize,
    len: HEAP_LENGTH,
};

#[cfg(target_os = "solana")]
#[error_code]
pub enum HeapError {
    #[msg("The provided position to free is invalid.")]
    InvalidHeapPos,
}
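/// Bump allocator over a fixed memory region (the Solana program heap when
/// deployed on-chain). Allocations are carved from the top of the region
/// downward; the first word at `start` is reserved to store the cursor.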
pub struct BumpAllocator {
    pub start: usize,
    pub len: usize,
}

impl BumpAllocator {
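    /// The first word of the region is reserved to hold the bump cursor.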
    const RESERVED_MEM: usize = size_of::<*mut u8>();

    #[cfg(target_os = "solana")]
    pub fn new() -> Self {
        Self {
            start: HEAP_START_ADDRESS as usize,
            len: HEAP_LENGTH,
        }
    }

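    /// Returns the current cursor position (0 if nothing has been allocated yet).
    ///
    /// # Safety
    /// `self.start` must point to readable memory holding the cursor word.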
    pub unsafe fn pos(&self) -> usize {
        let pos_ptr = self.start as *mut usize;
        *pos_ptr
    }

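    /// Moves the cursor to `pos`; subsequent allocations resume from there.
    ///
    /// # Safety
    /// No bounds checks are performed; the caller must ensure `pos` lies within the heap.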
    pub unsafe fn move_cursor(&self, pos: usize) {
        let pos_ptr = self.start as *mut usize;
        *pos_ptr = pos;
    }

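    /// Logs and returns the number of heap bytes in use, measured from the end
    /// of the heap down to the current cursor. Assumes at least one allocation
    /// has been made so that the cursor is initialized.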
    #[cfg(target_os = "solana")]
    pub fn log_total_heap(&self, msg: &str) -> u64 {
        const HEAP_END_ADDRESS: u64 = HEAP_START_ADDRESS as u64 + HEAP_LENGTH as u64;

        let heap_start = unsafe { self.pos() } as u64;
        let heap_used = HEAP_END_ADDRESS - heap_start;
        msg!("{}: total heap used: {}", msg, heap_used);
        heap_used
    }

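    /// Returns the current cursor position; pair with `free_heap` to release
    /// everything allocated after this point.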
    #[cfg(target_os = "solana")]
    pub fn get_heap_pos(&self) -> usize {
        unsafe { self.pos() }
    }

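    /// Resets the cursor to `pos`, releasing every allocation made after that
    /// position. Fails if `pos` lies outside the usable heap range.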
    #[cfg(target_os = "solana")]
    pub fn free_heap(&self, pos: usize) -> Result<()> {
        if pos < self.start + BumpAllocator::RESERVED_MEM || pos > self.start + self.len {
            return err!(HeapError::InvalidHeapPos);
        }

        unsafe { self.move_cursor(pos) };
        Ok(())
    }
}

unsafe impl std::alloc::GlobalAlloc for BumpAllocator {
    #[inline]
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        let pos_ptr = self.start as *mut usize;

        let mut pos = *pos_ptr;
        if pos == 0 {
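            // First allocation: start the cursor at the top of the heap.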
            pos = self.start + self.len;
        }
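        // Bump downward and round down to the requested alignment.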
        pos = pos.saturating_sub(layout.size());
        pos &= !(layout.align().wrapping_sub(1));
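        // Refuse the allocation if it would overlap the reserved cursor word.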
        if pos < self.start + BumpAllocator::RESERVED_MEM {
            return null_mut();
        }
        *pos_ptr = pos;
        pos as *mut u8
    }
    #[inline]
    unsafe fn dealloc(&self, _: *mut u8, _: Layout) {
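        // Bump allocator: individual deallocations are a no-op; rewind the
        // cursor with `free_heap` / `move_cursor` to reclaim memory.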
    }
}

#[cfg(test)]
mod test {
    use std::{
        alloc::{GlobalAlloc, Layout},
        mem::size_of,
        ptr::null_mut,
    };

    use super::*;

    #[test]
    fn test_pos_move_cursor_heap() {
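        // Fill the 128-byte heap one byte at a time, rewind the cursor with
        // `move_cursor`, and verify allocation resumes from the saved position.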
        use std::mem::size_of;

        {
            let heap = [0u8; 128];
            let allocator = BumpAllocator {
                start: heap.as_ptr() as *const _ as usize,
                len: heap.len(),
            };
            let pos = unsafe { allocator.pos() };
            assert_eq!(pos, unsafe { allocator.pos() });
            assert_eq!(pos, 0);
            let mut pos_64 = 0;
            for i in 0..128 - size_of::<*mut u8>() {
                if i == 64 {
                    pos_64 = unsafe { allocator.pos() };
                }
                let ptr = unsafe {
                    allocator.alloc(Layout::from_size_align(1, size_of::<u8>()).unwrap())
                };
                assert_eq!(
                    ptr as *const _ as usize,
                    heap.as_ptr() as *const _ as usize + heap.len() - 1 - i
                );
                assert_eq!(ptr as *const _ as usize, unsafe { allocator.pos() });
            }
            let pos_128 = unsafe { allocator.pos() };
            unsafe { allocator.move_cursor(pos_64) };
            assert_eq!(pos_64, unsafe { allocator.pos() });
            assert_ne!(pos_64 + 1, unsafe { allocator.pos() });
            for i in 0..64 - size_of::<*mut u8>() {
                let ptr = unsafe {
                    allocator.alloc(Layout::from_size_align(1, size_of::<u8>()).unwrap())
                };
                assert_eq!(
                    ptr as *const _ as usize,
                    heap.as_ptr() as *const _ as usize + heap.len() - 1 - (i + 64)
                );
                assert_eq!(ptr as *const _ as usize, unsafe { allocator.pos() });
            }
            assert_eq!(pos_128, unsafe { allocator.pos() });
            unsafe { allocator.move_cursor(pos) };
            assert_eq!(pos, unsafe { allocator.pos() });
            assert_ne!(pos + 1, unsafe { allocator.pos() });
        }
    }

    #[test]
    fn test_bump_allocator() {
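        // Covers heap exhaustion (null on out-of-memory), alignment handling,
        // and a single large allocation.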
174 {
176 let heap = [0u8; 128];
177 let allocator = BumpAllocator {
178 start: heap.as_ptr() as *const _ as usize,
179 len: heap.len(),
180 };
181 for i in 0..128 - size_of::<*mut u8>() {
182 let ptr = unsafe {
183 allocator.alloc(Layout::from_size_align(1, size_of::<u8>()).unwrap())
184 };
185 assert_eq!(
186 ptr as *const _ as usize,
187 heap.as_ptr() as *const _ as usize + heap.len() - 1 - i
188 );
189 }
190 assert_eq!(null_mut(), unsafe {
191 allocator.alloc(Layout::from_size_align(1, 1).unwrap())
192 });
193 }
194 {
196 let heap = [0u8; 128];
197 let allocator = BumpAllocator {
198 start: heap.as_ptr() as *const _ as usize,
199 len: heap.len(),
200 };
201 let ptr =
202 unsafe { allocator.alloc(Layout::from_size_align(1, size_of::<u8>()).unwrap()) };
203 assert_eq!(0, ptr.align_offset(size_of::<u8>()));
204 let ptr =
205 unsafe { allocator.alloc(Layout::from_size_align(1, size_of::<u16>()).unwrap()) };
206 assert_eq!(0, ptr.align_offset(size_of::<u16>()));
207 let ptr =
208 unsafe { allocator.alloc(Layout::from_size_align(1, size_of::<u32>()).unwrap()) };
209 assert_eq!(0, ptr.align_offset(size_of::<u32>()));
210 let ptr =
211 unsafe { allocator.alloc(Layout::from_size_align(1, size_of::<u64>()).unwrap()) };
212 assert_eq!(0, ptr.align_offset(size_of::<u64>()));
213 let ptr =
214 unsafe { allocator.alloc(Layout::from_size_align(1, size_of::<u128>()).unwrap()) };
215 assert_eq!(0, ptr.align_offset(size_of::<u128>()));
216 let ptr = unsafe { allocator.alloc(Layout::from_size_align(1, 64).unwrap()) };
217 assert_eq!(0, ptr.align_offset(64));
218 }
219 {
221 let heap = [0u8; 128];
222 let allocator = BumpAllocator {
223 start: heap.as_ptr() as *const _ as usize,
224 len: heap.len(),
225 };
226 let ptr =
227 unsafe { allocator.alloc(Layout::from_size_align(120, size_of::<u8>()).unwrap()) };
228 assert_ne!(ptr, null_mut());
229 assert_eq!(0, ptr.align_offset(size_of::<u64>()));
230 }
231 }
232}