1use {
2 crate::errors::{SbpfVmError, SbpfVmResult},
3 serde::{Deserialize, Serialize},
4};
5
/// Identifies which of the four mapped address windows a virtual address falls in.
///
/// Each variant's window starts at the matching `Memory::*_START` constant
/// (`INPUT_START`, `RODATA_START`, `STACK_START`, `HEAP_START`).
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum MemoryRegion {
    /// The input buffer handed to `Memory::new`; writable.
    Input,
    /// Read-only data; the write path rejects stores into this region.
    Rodata,
    /// The stack; writable.
    Stack,
    /// The heap served by `Memory::alloc`; writable.
    Heap,
}
14
/// Backing storage for the VM's address space.
///
/// Each region is a plain byte vector; a virtual address is resolved to a
/// region and an offset from that region's `Memory::*_START` base address.
// NOTE(review): field order may matter for positional serde formats
// (e.g. bincode) — keep it stable unless serialized data is migrated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Memory {
    /// Read-only data image; stores into this region are rejected.
    pub rodata: Vec<u8>,
    /// Stack memory, zero-filled to the requested size by `Memory::new`.
    pub stack: Vec<u8>,
    /// Heap memory, zero-filled by `Memory::new` and re-zeroed by `reset_heap`.
    pub heap: Vec<u8>,
    /// Input buffer, mapped as provided; writable.
    pub input: Vec<u8>,
    /// Bump-allocation cursor: offset into `heap` of the next byte `alloc` hands out.
    pub heap_ptr: usize,
}
24
25impl Memory {
26 pub const RODATA_START: u64 = 0x100000000; pub const STACK_START: u64 = 0x200000000; pub const HEAP_START: u64 = 0x300000000; pub const INPUT_START: u64 = 0x400000000; pub const DEFAULT_HEAP_SIZE: usize = 32768; pub const STACK_FRAME_SIZE: u64 = 4096; pub fn new(input: Vec<u8>, rodata: Vec<u8>, stack_size: usize, heap_size: usize) -> Self {
36 Self {
37 input,
38 rodata,
39 stack: vec![0u8; stack_size],
40 heap: vec![0u8; heap_size],
41 heap_ptr: 0,
42 }
43 }
44
45 pub fn initial_frame_pointer(&self) -> u64 {
46 Self::STACK_START + Self::STACK_FRAME_SIZE
47 }
48
49 pub fn stack_size(max_call_depth: usize) -> usize {
50 Self::STACK_FRAME_SIZE as usize * max_call_depth
51 }
52
53 fn translate(&self, addr: u64) -> SbpfVmResult<(MemoryRegion, usize)> {
55 if addr >= Self::INPUT_START {
56 let offset = (addr - Self::INPUT_START) as usize;
57 if offset < self.input.len() {
58 Ok((MemoryRegion::Input, offset))
59 } else {
60 Err(SbpfVmError::MemoryOutOfBounds(addr, 0))
61 }
62 } else if addr >= Self::HEAP_START {
63 let offset = (addr - Self::HEAP_START) as usize;
64 if offset < self.heap.len() {
65 Ok((MemoryRegion::Heap, offset))
66 } else {
67 Err(SbpfVmError::MemoryOutOfBounds(addr, 0))
68 }
69 } else if addr >= Self::STACK_START {
70 let offset = (addr - Self::STACK_START) as usize;
71 if offset < self.stack.len() {
72 Ok((MemoryRegion::Stack, offset))
73 } else {
74 Err(SbpfVmError::MemoryOutOfBounds(addr, 0))
75 }
76 } else if addr >= Self::RODATA_START {
77 let offset = (addr - Self::RODATA_START) as usize;
78 if offset < self.rodata.len() {
79 Ok((MemoryRegion::Rodata, offset))
80 } else {
81 Err(SbpfVmError::MemoryOutOfBounds(addr, 0))
82 }
83 } else {
84 Err(SbpfVmError::InvalidMemoryAccess(addr))
85 }
86 }
87
88 fn get_slice(&self, region: MemoryRegion, offset: usize, len: usize) -> SbpfVmResult<&[u8]> {
89 let data = match region {
90 MemoryRegion::Input => &self.input,
91 MemoryRegion::Rodata => &self.rodata,
92 MemoryRegion::Stack => &self.stack,
93 MemoryRegion::Heap => &self.heap,
94 };
95
96 if offset + len > data.len() {
97 return Err(SbpfVmError::MemoryOutOfBounds(offset as u64, len));
98 }
99
100 Ok(&data[offset..offset + len])
101 }
102
103 fn get_slice_mut(
104 &mut self,
105 region: MemoryRegion,
106 offset: usize,
107 len: usize,
108 ) -> SbpfVmResult<&mut [u8]> {
109 if region == MemoryRegion::Rodata {
111 return Err(SbpfVmError::InvalidMemoryAccess(
112 Self::RODATA_START + offset as u64,
113 ));
114 }
115
116 let data = match region {
117 MemoryRegion::Input => &mut self.input,
118 MemoryRegion::Stack => &mut self.stack,
119 MemoryRegion::Heap => &mut self.heap,
120 MemoryRegion::Rodata => unreachable!(),
121 };
122
123 if offset + len > data.len() {
124 return Err(SbpfVmError::MemoryOutOfBounds(offset as u64, len));
125 }
126
127 Ok(&mut data[offset..offset + len])
128 }
129
130 pub fn read_u8(&self, addr: u64) -> SbpfVmResult<u8> {
131 let (region, offset) = self.translate(addr)?;
132 let slice = self.get_slice(region, offset, 1)?;
133 Ok(slice[0])
134 }
135
136 pub fn read_u16(&self, addr: u64) -> SbpfVmResult<u16> {
137 let (region, offset) = self.translate(addr)?;
138 let slice = self.get_slice(region, offset, 2)?;
139 Ok(u16::from_le_bytes([slice[0], slice[1]]))
140 }
141
142 pub fn read_u32(&self, addr: u64) -> SbpfVmResult<u32> {
143 let (region, offset) = self.translate(addr)?;
144 let slice = self.get_slice(region, offset, 4)?;
145 Ok(u32::from_le_bytes([slice[0], slice[1], slice[2], slice[3]]))
146 }
147
148 pub fn read_u64(&self, addr: u64) -> SbpfVmResult<u64> {
149 let (region, offset) = self.translate(addr)?;
150 let slice = self.get_slice(region, offset, 8)?;
151 Ok(u64::from_le_bytes([
152 slice[0], slice[1], slice[2], slice[3], slice[4], slice[5], slice[6], slice[7],
153 ]))
154 }
155
156 pub fn read_bytes(&self, addr: u64, len: usize) -> SbpfVmResult<&[u8]> {
157 let (region, offset) = self.translate(addr)?;
158 self.get_slice(region, offset, len)
159 }
160
161 pub fn write_u8(&mut self, addr: u64, value: u8) -> SbpfVmResult<()> {
162 let (region, offset) = self.translate(addr)?;
163 let slice = self.get_slice_mut(region, offset, 1)?;
164 slice[0] = value;
165 Ok(())
166 }
167
168 pub fn write_u16(&mut self, addr: u64, value: u16) -> SbpfVmResult<()> {
169 let (region, offset) = self.translate(addr)?;
170 let slice = self.get_slice_mut(region, offset, 2)?;
171 slice.copy_from_slice(&value.to_le_bytes());
172 Ok(())
173 }
174
175 pub fn write_u32(&mut self, addr: u64, value: u32) -> SbpfVmResult<()> {
176 let (region, offset) = self.translate(addr)?;
177 let slice = self.get_slice_mut(region, offset, 4)?;
178 slice.copy_from_slice(&value.to_le_bytes());
179 Ok(())
180 }
181
182 pub fn write_u64(&mut self, addr: u64, value: u64) -> SbpfVmResult<()> {
183 let (region, offset) = self.translate(addr)?;
184 let slice = self.get_slice_mut(region, offset, 8)?;
185 slice.copy_from_slice(&value.to_le_bytes());
186 Ok(())
187 }
188
189 pub fn write_i64(&mut self, addr: u64, value: i64) -> SbpfVmResult<()> {
190 let (region, offset) = self.translate(addr)?;
191 let slice = self.get_slice_mut(region, offset, 8)?;
192 slice.copy_from_slice(&value.to_le_bytes());
193 Ok(())
194 }
195
196 pub fn write_bytes(&mut self, addr: u64, bytes: &[u8]) -> SbpfVmResult<()> {
197 let (region, offset) = self.translate(addr)?;
198 let slice = self.get_slice_mut(region, offset, bytes.len())?;
199 slice.copy_from_slice(bytes);
200 Ok(())
201 }
202
203 pub fn alloc(&mut self, size: usize) -> SbpfVmResult<u64> {
204 if self.heap_ptr + size > self.heap.len() {
205 return Err(SbpfVmError::MemoryOutOfBounds(
206 Self::HEAP_START + self.heap_ptr as u64,
207 size,
208 ));
209 }
210 let addr = Self::HEAP_START + self.heap_ptr as u64;
211 self.heap_ptr += size;
212 Ok(addr)
213 }
214
215 pub fn reset_heap(&mut self) {
216 self.heap_ptr = 0;
217 self.heap.fill(0);
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_regions() {
        let input = vec![1, 2, 3, 4];
        let rodata = vec![5, 6, 7, 8];
        let memory = Memory::new(input, rodata, 1024, 1024);

        assert_eq!(memory.read_u8(Memory::INPUT_START).unwrap(), 1);
        assert_eq!(memory.read_u8(Memory::INPUT_START + 3).unwrap(), 4);

        assert_eq!(memory.read_u8(Memory::RODATA_START).unwrap(), 5);
        assert_eq!(memory.read_u8(Memory::RODATA_START + 3).unwrap(), 8);
    }

    #[test]
    fn test_read_write() {
        let mut memory = Memory::new(
            vec![0; 16],
            vec![0; 16],
            Memory::STACK_FRAME_SIZE as usize,
            1024,
        );

        let fp = memory.initial_frame_pointer();

        let addr = fp - 1;
        memory.write_u8(addr, 0x5).unwrap();
        assert_eq!(memory.read_u8(addr).unwrap(), 0x5);

        let addr = fp - 2;
        memory.write_u16(addr, 0xabcd).unwrap();
        assert_eq!(memory.read_u16(addr).unwrap(), 0xabcd);

        let addr = fp - 4;
        memory.write_u32(addr, 0xabcd1234).unwrap();
        assert_eq!(memory.read_u32(addr).unwrap(), 0xabcd1234);

        let addr = fp - 8;
        memory.write_u64(addr, 0x123456789abcdef0).unwrap();
        assert_eq!(memory.read_u64(addr).unwrap(), 0x123456789abcdef0);
    }

    #[test]
    fn test_heap_allocation() {
        let mut memory = Memory::new(vec![], vec![], 1024, 1024);

        let addr1 = memory.alloc(64).unwrap();
        assert_eq!(addr1, Memory::HEAP_START);

        let addr2 = memory.alloc(128).unwrap();
        assert_eq!(addr2, Memory::HEAP_START + 64);

        memory.write_u64(addr1, 0x12345678).unwrap();
        assert_eq!(memory.read_u64(addr1).unwrap(), 0x12345678);
    }

    #[test]
    fn test_rodata_readonly() {
        let mut memory = Memory::new(vec![], vec![1, 2, 3, 4], 1024, 1024);

        let result = memory.write_u8(Memory::RODATA_START, 12);
        assert!(result.is_err());
    }

    /// Reads that leave their region, or hit no region at all, must fail.
    #[test]
    fn test_out_of_bounds_reads_fail() {
        let memory = Memory::new(vec![1, 2, 3, 4], vec![5, 6, 7, 8], 1024, 1024);

        // One byte past the end of the input region.
        assert!(memory.read_u8(Memory::INPUT_START + 4).is_err());
        // Multi-byte read straddling the end of rodata.
        assert!(memory.read_u32(Memory::RODATA_START + 2).is_err());
        // Gap between the end of rodata and the stack base.
        assert!(memory.read_u8(Memory::RODATA_START + 4).is_err());
        // Below every mapped region.
        assert!(memory.read_u8(0).is_err());
    }

    /// A write that does not fit must be rejected without partially writing.
    #[test]
    fn test_partial_write_rejected() {
        let mut memory = Memory::new(vec![0; 8], vec![], 1024, 1024);

        assert!(memory
            .write_bytes(Memory::INPUT_START + 6, &[1, 2, 3])
            .is_err());
        assert_eq!(
            memory.read_bytes(Memory::INPUT_START, 8).unwrap(),
            &[0u8; 8][..]
        );
    }

    /// The input region is writable and round-trips every integer width.
    #[test]
    fn test_input_is_writable() {
        let mut memory = Memory::new(vec![0; 8], vec![], 1024, 1024);

        memory.write_u64(Memory::INPUT_START, u64::MAX).unwrap();
        assert_eq!(memory.read_u64(Memory::INPUT_START).unwrap(), u64::MAX);
        assert_eq!(
            memory.read_bytes(Memory::INPUT_START, 8).unwrap(),
            &[0xffu8; 8][..]
        );

        memory.write_i64(Memory::INPUT_START, -2).unwrap();
        assert_eq!(
            memory.read_u64(Memory::INPUT_START).unwrap(),
            (-2i64) as u64
        );
    }

    /// Exhausting the heap fails; `reset_heap` zeroes it and rewinds the bump pointer.
    #[test]
    fn test_heap_exhaustion_and_reset() {
        let mut memory = Memory::new(vec![], vec![], 1024, 64);

        let addr = memory.alloc(64).unwrap();
        assert!(memory.alloc(1).is_err());

        memory.write_u8(addr, 0xff).unwrap();
        memory.reset_heap();
        assert_eq!(memory.read_u8(addr).unwrap(), 0);
        assert_eq!(memory.alloc(8).unwrap(), Memory::HEAP_START);
    }
}
293}