// sbpf_vm/memory.rs

use {
    crate::errors::{SbpfVmError, SbpfVmResult},
    serde::{Deserialize, Serialize},
};

/// One of the VM's memory regions; `Memory::translate` maps a virtual
/// address to one of these plus an offset.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryRegion {
    /// Program input parameters, based at `Memory::INPUT_START`.
    Input,
    /// Read-only data, based at `Memory::RODATA_START`.
    Rodata,
    /// Stack memory, based at `Memory::STACK_START`.
    Stack,
    /// Heap memory, based at `Memory::HEAP_START`.
    Heap,
}

15/// Memory layout
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Memory {
18    pub rodata: Vec<u8>,
19    pub stack: Vec<u8>,
20    pub heap: Vec<u8>,
21    pub input: Vec<u8>,
22    pub heap_ptr: usize,
23}
24
25impl Memory {
26    // Virtual address memory map
27    pub const RODATA_START: u64 = 0x100000000; // Read-only data (rodata)
28    pub const STACK_START: u64 = 0x200000000; // Stack data
29    pub const HEAP_START: u64 = 0x300000000; // Heap data
30    pub const INPUT_START: u64 = 0x400000000; // Program input parameters
31
32    pub const DEFAULT_HEAP_SIZE: usize = 32768; // 32KB
33    pub const STACK_FRAME_SIZE: u64 = 4096; // 4KB
34
35    pub fn new(input: Vec<u8>, rodata: Vec<u8>, stack_size: usize, heap_size: usize) -> Self {
36        Self {
37            input,
38            rodata,
39            stack: vec![0u8; stack_size],
40            heap: vec![0u8; heap_size],
41            heap_ptr: 0,
42        }
43    }
44
45    pub fn initial_frame_pointer(&self) -> u64 {
46        Self::STACK_START + Self::STACK_FRAME_SIZE
47    }
48
49    pub fn stack_size(max_call_depth: usize) -> usize {
50        Self::STACK_FRAME_SIZE as usize * max_call_depth
51    }
52
53    // Translate virtual address to region and offset
54    fn translate(&self, addr: u64) -> SbpfVmResult<(MemoryRegion, usize)> {
55        if addr >= Self::INPUT_START {
56            let offset = (addr - Self::INPUT_START) as usize;
57            if offset < self.input.len() {
58                Ok((MemoryRegion::Input, offset))
59            } else {
60                Err(SbpfVmError::MemoryOutOfBounds(addr, 0))
61            }
62        } else if addr >= Self::HEAP_START {
63            let offset = (addr - Self::HEAP_START) as usize;
64            if offset < self.heap.len() {
65                Ok((MemoryRegion::Heap, offset))
66            } else {
67                Err(SbpfVmError::MemoryOutOfBounds(addr, 0))
68            }
69        } else if addr >= Self::STACK_START {
70            let offset = (addr - Self::STACK_START) as usize;
71            if offset < self.stack.len() {
72                Ok((MemoryRegion::Stack, offset))
73            } else {
74                Err(SbpfVmError::MemoryOutOfBounds(addr, 0))
75            }
76        } else if addr >= Self::RODATA_START {
77            let offset = (addr - Self::RODATA_START) as usize;
78            if offset < self.rodata.len() {
79                Ok((MemoryRegion::Rodata, offset))
80            } else {
81                Err(SbpfVmError::MemoryOutOfBounds(addr, 0))
82            }
83        } else {
84            Err(SbpfVmError::InvalidMemoryAccess(addr))
85        }
86    }
87
88    fn get_slice(&self, region: MemoryRegion, offset: usize, len: usize) -> SbpfVmResult<&[u8]> {
89        let data = match region {
90            MemoryRegion::Input => &self.input,
91            MemoryRegion::Rodata => &self.rodata,
92            MemoryRegion::Stack => &self.stack,
93            MemoryRegion::Heap => &self.heap,
94        };
95
96        if offset + len > data.len() {
97            return Err(SbpfVmError::MemoryOutOfBounds(offset as u64, len));
98        }
99
100        Ok(&data[offset..offset + len])
101    }
102
103    fn get_slice_mut(
104        &mut self,
105        region: MemoryRegion,
106        offset: usize,
107        len: usize,
108    ) -> SbpfVmResult<&mut [u8]> {
109        // Rodata region is read-only
110        if region == MemoryRegion::Rodata {
111            return Err(SbpfVmError::InvalidMemoryAccess(
112                Self::RODATA_START + offset as u64,
113            ));
114        }
115
116        let data = match region {
117            MemoryRegion::Input => &mut self.input,
118            MemoryRegion::Stack => &mut self.stack,
119            MemoryRegion::Heap => &mut self.heap,
120            MemoryRegion::Rodata => unreachable!(),
121        };
122
123        if offset + len > data.len() {
124            return Err(SbpfVmError::MemoryOutOfBounds(offset as u64, len));
125        }
126
127        Ok(&mut data[offset..offset + len])
128    }
129
130    pub fn read_u8(&self, addr: u64) -> SbpfVmResult<u8> {
131        let (region, offset) = self.translate(addr)?;
132        let slice = self.get_slice(region, offset, 1)?;
133        Ok(slice[0])
134    }
135
136    pub fn read_u16(&self, addr: u64) -> SbpfVmResult<u16> {
137        let (region, offset) = self.translate(addr)?;
138        let slice = self.get_slice(region, offset, 2)?;
139        Ok(u16::from_le_bytes([slice[0], slice[1]]))
140    }
141
142    pub fn read_u32(&self, addr: u64) -> SbpfVmResult<u32> {
143        let (region, offset) = self.translate(addr)?;
144        let slice = self.get_slice(region, offset, 4)?;
145        Ok(u32::from_le_bytes([slice[0], slice[1], slice[2], slice[3]]))
146    }
147
148    pub fn read_u64(&self, addr: u64) -> SbpfVmResult<u64> {
149        let (region, offset) = self.translate(addr)?;
150        let slice = self.get_slice(region, offset, 8)?;
151        Ok(u64::from_le_bytes([
152            slice[0], slice[1], slice[2], slice[3], slice[4], slice[5], slice[6], slice[7],
153        ]))
154    }
155
156    pub fn read_bytes(&self, addr: u64, len: usize) -> SbpfVmResult<&[u8]> {
157        let (region, offset) = self.translate(addr)?;
158        self.get_slice(region, offset, len)
159    }
160
161    pub fn write_u8(&mut self, addr: u64, value: u8) -> SbpfVmResult<()> {
162        let (region, offset) = self.translate(addr)?;
163        let slice = self.get_slice_mut(region, offset, 1)?;
164        slice[0] = value;
165        Ok(())
166    }
167
168    pub fn write_u16(&mut self, addr: u64, value: u16) -> SbpfVmResult<()> {
169        let (region, offset) = self.translate(addr)?;
170        let slice = self.get_slice_mut(region, offset, 2)?;
171        slice.copy_from_slice(&value.to_le_bytes());
172        Ok(())
173    }
174
175    pub fn write_u32(&mut self, addr: u64, value: u32) -> SbpfVmResult<()> {
176        let (region, offset) = self.translate(addr)?;
177        let slice = self.get_slice_mut(region, offset, 4)?;
178        slice.copy_from_slice(&value.to_le_bytes());
179        Ok(())
180    }
181
182    pub fn write_u64(&mut self, addr: u64, value: u64) -> SbpfVmResult<()> {
183        let (region, offset) = self.translate(addr)?;
184        let slice = self.get_slice_mut(region, offset, 8)?;
185        slice.copy_from_slice(&value.to_le_bytes());
186        Ok(())
187    }
188
189    pub fn write_i64(&mut self, addr: u64, value: i64) -> SbpfVmResult<()> {
190        let (region, offset) = self.translate(addr)?;
191        let slice = self.get_slice_mut(region, offset, 8)?;
192        slice.copy_from_slice(&value.to_le_bytes());
193        Ok(())
194    }
195
196    pub fn write_bytes(&mut self, addr: u64, bytes: &[u8]) -> SbpfVmResult<()> {
197        let (region, offset) = self.translate(addr)?;
198        let slice = self.get_slice_mut(region, offset, bytes.len())?;
199        slice.copy_from_slice(bytes);
200        Ok(())
201    }
202
203    pub fn alloc(&mut self, size: usize) -> SbpfVmResult<u64> {
204        if self.heap_ptr + size > self.heap.len() {
205            return Err(SbpfVmError::MemoryOutOfBounds(
206                Self::HEAP_START + self.heap_ptr as u64,
207                size,
208            ));
209        }
210        let addr = Self::HEAP_START + self.heap_ptr as u64;
211        self.heap_ptr += size;
212        Ok(addr)
213    }
214
215    pub fn reset_heap(&mut self) {
216        self.heap_ptr = 0;
217        self.heap.fill(0);
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_regions() {
        let input = vec![1, 2, 3, 4];
        let rodata = vec![5, 6, 7, 8];
        let memory = Memory::new(input, rodata, 1024, 1024);

        // Test input and rodata region
        assert_eq!(memory.read_u8(Memory::INPUT_START).unwrap(), 1);
        assert_eq!(memory.read_u8(Memory::INPUT_START + 3).unwrap(), 4);

        assert_eq!(memory.read_u8(Memory::RODATA_START).unwrap(), 5);
        assert_eq!(memory.read_u8(Memory::RODATA_START + 3).unwrap(), 8);
    }

    #[test]
    fn test_read_write() {
        let mut memory = Memory::new(
            vec![0; 16],
            vec![0; 16],
            Memory::STACK_FRAME_SIZE as usize,
            1024,
        );

        let fp = memory.initial_frame_pointer();

        // Write and read u8
        let addr = fp - 1;
        memory.write_u8(addr, 0x5).unwrap();
        assert_eq!(memory.read_u8(addr).unwrap(), 0x5);

        // Write and read u16
        let addr = fp - 2;
        memory.write_u16(addr, 0xabcd).unwrap();
        assert_eq!(memory.read_u16(addr).unwrap(), 0xabcd);

        // Write and read u32
        let addr = fp - 4;
        memory.write_u32(addr, 0xabcd1234).unwrap();
        assert_eq!(memory.read_u32(addr).unwrap(), 0xabcd1234);

        // Write and read u64
        let addr = fp - 8;
        memory.write_u64(addr, 0x123456789abcdef0).unwrap();
        assert_eq!(memory.read_u64(addr).unwrap(), 0x123456789abcdef0);
    }

    #[test]
    fn test_little_endian_layout() {
        let mut memory = Memory::new(vec![], vec![], 1024, 1024);

        memory.write_u32(Memory::STACK_START, 0x11223344).unwrap();
        assert_eq!(memory.read_u8(Memory::STACK_START).unwrap(), 0x44);
        assert_eq!(memory.read_u16(Memory::STACK_START + 2).unwrap(), 0x1122);

        // i64 is stored as its two's-complement bit pattern.
        memory.write_i64(Memory::STACK_START + 8, -1).unwrap();
        assert_eq!(memory.read_u64(Memory::STACK_START + 8).unwrap(), u64::MAX);
    }

    #[test]
    fn test_bytes_roundtrip() {
        let mut memory = Memory::new(vec![], vec![], 1024, 1024);
        let payload = [9u8, 8, 7, 6];

        let addr = memory.alloc(payload.len()).unwrap();
        memory.write_bytes(addr, &payload).unwrap();
        assert_eq!(memory.read_bytes(addr, payload.len()).unwrap(), &payload[..]);
    }

    #[test]
    fn test_out_of_bounds_access() {
        let mut memory = Memory::new(vec![1, 2, 3, 4], vec![], 1024, 1024);

        // Start address past the end of the input region.
        assert!(matches!(
            memory.read_u8(Memory::INPUT_START + 4),
            Err(SbpfVmError::MemoryOutOfBounds(..))
        ));
        // Start address in bounds, but the access runs off the end.
        assert!(matches!(
            memory.read_u64(Memory::INPUT_START),
            Err(SbpfVmError::MemoryOutOfBounds(..))
        ));
        assert!(memory.write_u64(Memory::STACK_START + 1020, 0).is_err());

        // Addresses below the lowest mapped region.
        assert!(matches!(
            memory.read_u8(0),
            Err(SbpfVmError::InvalidMemoryAccess(..))
        ));
        assert!(matches!(
            memory.read_u8(Memory::RODATA_START - 1),
            Err(SbpfVmError::InvalidMemoryAccess(..))
        ));
    }

    #[test]
    fn test_heap_allocation() {
        let mut memory = Memory::new(vec![], vec![], 1024, 1024);

        let addr1 = memory.alloc(64).unwrap();
        assert_eq!(addr1, Memory::HEAP_START);

        let addr2 = memory.alloc(128).unwrap();
        assert_eq!(addr2, Memory::HEAP_START + 64);

        memory.write_u64(addr1, 0x12345678).unwrap();
        assert_eq!(memory.read_u64(addr1).unwrap(), 0x12345678);
    }

    #[test]
    fn test_heap_exhaustion_and_reset() {
        let mut memory = Memory::new(vec![], vec![], 1024, 1024);

        let addr = memory.alloc(1024).unwrap();
        assert_eq!(addr, Memory::HEAP_START);
        memory.write_u8(addr, 0xff).unwrap();

        // Heap is full: any further non-empty allocation must fail.
        assert!(memory.alloc(1).is_err());

        memory.reset_heap();
        assert_eq!(memory.heap_ptr, 0);
        assert_eq!(memory.read_u8(addr).unwrap(), 0);
        assert_eq!(memory.alloc(16).unwrap(), Memory::HEAP_START);
    }

    #[test]
    fn test_rodata_readonly() {
        let mut memory = Memory::new(vec![], vec![1, 2, 3, 4], 1024, 1024);

        // should fail to write to read-only region
        let result = memory.write_u8(Memory::RODATA_START, 12);
        assert!(matches!(result, Err(SbpfVmError::InvalidMemoryAccess(..))));
        // ...and the data must be left untouched.
        assert_eq!(memory.read_u8(Memory::RODATA_START).unwrap(), 1);
    }
}