kaiju_vm_core/
state.rs

1use core::error::*;
2use itertools::Itertools;
3use std::fmt;
4use std::mem::size_of;
5use std::ptr::copy_nonoverlapping;
6
/// A handle to a contiguous run of bytes inside a `State`'s combined buffer.
///
/// `address` is an absolute offset into that buffer (the stack region comes first,
/// the memory region follows it) and `size` is the length of the run in bytes.
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
pub struct Value {
    /// Absolute offset of the first byte of the run.
    pub address: usize,
    /// Number of bytes covered by the run.
    pub size: usize,
}
12
13impl Value {
14    pub fn new(address: usize, size: usize) -> Self {
15        Self { address, size }
16    }
17}
18
/// Byte-addressable VM storage.
///
/// A single buffer holds a stack at offsets `0..stack_size`, followed by a memory
/// region of `memory_size` bytes whose unused parts are tracked in a free list.
#[derive(Clone)]
pub struct State {
    /// Backing storage for both regions; `stack_size + memory_size` bytes long.
    bytes: Vec<u8>,
    /// Length in bytes of the memory region that follows the stack.
    memory_size: usize,
    /// Length in bytes of the stack region at the start of `bytes`.
    stack_size: usize,
    /// Free memory blocks as `(offset, length)` pairs; offsets are relative to the
    /// start of the memory region, not to the start of `bytes`.
    memory_free: Vec<(usize, usize)>,
    /// Current stack top: bytes below this offset are in use.
    stack_pos: usize,
}
27
28impl fmt::Debug for State {
29    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30        f.debug_struct("State")
31            .field("bytes", &format!("[...; {}]", self.bytes.len()))
32            .field("memory_size", &self.memory_size)
33            .field("stack_size", &self.stack_size)
34            .field("memory_free", &self.memory_free)
35            .field("stack_pos", &self.stack_pos)
36            .finish()
37    }
38}
39
40impl State {
41    pub fn new(stack_size: usize, memory_size: usize) -> Self {
42        Self {
43            bytes: vec![0; stack_size + memory_size],
44            stack_size,
45            memory_size,
46            stack_pos: 0,
47            memory_free: vec![(0, memory_size)],
48        }
49    }
50
51    #[inline]
52    pub fn stack_size(&self) -> usize {
53        self.stack_size
54    }
55
56    #[inline]
57    pub fn memory_size(&self) -> usize {
58        self.memory_size
59    }
60
61    #[inline]
62    pub fn all_size(&self) -> usize {
63        self.stack_size + self.memory_size
64    }
65
66    #[inline]
67    pub fn stack_pos(&self) -> usize {
68        self.stack_pos
69    }
70
71    #[inline]
72    pub fn stack_free(&self) -> usize {
73        self.stack_size - self.stack_pos
74    }
75
76    #[inline]
77    pub fn memory_free(&self) -> usize {
78        self.memory_free.iter().map(|(_, c)| c).sum()
79    }
80
81    #[inline]
82    pub fn all_free(&self) -> usize {
83        self.stack_free() + self.memory_free()
84    }
85
86    pub fn stack_push_data<T>(&mut self, value: &T) -> SimpleResult<Value> {
87        let size = size_of::<T>();
88        if self.stack_pos + size > self.stack_size {
89            Err(SimpleError::new(format!(
90                "Stack overflow while trying to push {} bytes",
91                size
92            )))
93        } else {
94            unsafe {
95                let dp = self.bytes.as_mut_ptr().add(self.stack_pos);
96                let sp = value as *const T as *const u8;
97                copy_nonoverlapping(sp, dp, size);
98            }
99            self.stack_pos += size;
100            Ok(Value::new(self.stack_pos - size, size))
101        }
102    }
103
104    pub fn stack_push_bytes(&mut self, source: &[u8]) -> SimpleResult<Value> {
105        if self.stack_pos + source.len() > self.stack_size {
106            Err(SimpleError::new(format!(
107                "Stack overflow while trying to push {} bytes",
108                source.len()
109            )))
110        } else {
111            unsafe {
112                let dp = self.bytes.as_mut_ptr().add(self.stack_pos);
113                let sp = source.as_ptr();
114                copy_nonoverlapping(sp, dp, source.len());
115            }
116            self.stack_pos += source.len();
117            Ok(Value::new(self.stack_pos - source.len(), source.len()))
118        }
119    }
120
121    pub fn stack_push_move(&mut self, source: usize, size: usize) -> SimpleResult<Value> {
122        if source + size > self.stack_size + self.memory_size {
123            Err(SimpleError::new(format!(
124                "Trying to push move {} bytes from outside of memory",
125                size
126            )))
127        } else if self.stack_pos + size > self.stack_size {
128            Err(SimpleError::new(format!(
129                "Stack overflow while trying to push {} bytes",
130                size
131            )))
132        } else if source + size > self.stack_pos && source < self.stack_pos + size {
133            Err(SimpleError::new(format!(
134                "Trying to push {} bytes in same memory fragment",
135                size
136            )))
137        } else {
138            unsafe {
139                let dp = self.bytes.as_mut_ptr().add(self.stack_pos);
140                let sp = self.bytes.as_ptr().add(source);
141                copy_nonoverlapping(sp, dp, size);
142            }
143            self.stack_pos += size;
144            Ok(Value::new(self.stack_pos - size, size))
145        }
146    }
147
148    pub fn stack_pop_bytes(&mut self, size: usize) -> SimpleResult<Vec<u8>> {
149        if size > self.stack_pos {
150            Err(SimpleError::new(format!(
151                "Stack underflow while trying to pop {} bytes",
152                size
153            )))
154        } else {
155            self.stack_pos -= size;
156            Ok(self.bytes[self.stack_pos..self.stack_pos + size].to_vec())
157        }
158    }
159
160    pub fn stack_pop_data<T: Default>(&mut self) -> SimpleResult<T> {
161        let size = size_of::<T>();
162        if size > self.stack_pos {
163            Err(SimpleError::new(format!(
164                "Stack underflow while trying to pop {} bytes",
165                size
166            )))
167        } else {
168            unsafe {
169                self.stack_pos -= size;
170                let sp = self.bytes.as_ptr().add(self.stack_pos);
171                let mut value = T::default();
172                let dp = &mut value as *mut T as *mut u8;
173                copy_nonoverlapping(sp, dp, size);
174                Ok(value)
175            }
176        }
177    }
178
179    pub fn stack_pop_move(&mut self, destination: usize, size: usize) -> SimpleResult<()> {
180        if destination + size > self.stack_size + self.memory_size {
181            Err(SimpleError::new(format!(
182                "Trying to pop move {} bytes to outside of memory",
183                size
184            )))
185        } else if size > self.stack_pos {
186            Err(SimpleError::new(format!(
187                "Stack overflow while trying to pop {} bytes",
188                size
189            )))
190        } else if destination + size > self.stack_pos - size && destination < self.stack_pos {
191            Err(SimpleError::new(format!(
192                "Trying to pop {} bytes in same memory fragment",
193                size
194            )))
195        } else {
196            self.stack_pos -= size;
197            unsafe {
198                let dp = self.bytes.as_mut_ptr().add(destination);
199                let sp = self.bytes.as_ptr().add(self.stack_pos);
200                copy_nonoverlapping(sp, dp, size);
201            }
202            Ok(())
203        }
204    }
205
206    pub fn stack_reset(&mut self, position: usize) -> SimpleResult<()> {
207        if position >= self.stack_size {
208            Err(SimpleError::new(format!(
209                "Stack overflow while trying to reset to position {}",
210                position
211            )))
212        } else {
213            self.stack_pos = position;
214            Ok(())
215        }
216    }
217
218    pub fn memory_move(
219        &mut self,
220        source: usize,
221        size: usize,
222        destination: usize,
223    ) -> SimpleResult<()> {
224        if source + size > self.stack_size + self.memory_size {
225            Err(SimpleError::new(format!(
226                "Trying to move {} bytes from outside of memory",
227                size
228            )))
229        } else if destination + size > self.stack_size + self.memory_size {
230            Err(SimpleError::new(format!(
231                "Trying to move {} bytes to outside of memory",
232                size
233            )))
234        } else {
235            unsafe {
236                let dp = self.bytes.as_mut_ptr().add(destination);
237                let sp = self.bytes.as_ptr().add(source);
238                copy_nonoverlapping(sp, dp, size);
239            }
240            Ok(())
241        }
242    }
243
244    pub fn store_data<T>(&mut self, destination: usize, value: &T) -> SimpleResult<()> {
245        let size = size_of::<T>();
246        if destination + size > self.stack_size + self.memory_size {
247            Err(SimpleError::new(format!(
248                "Trying to store {} bytes to outside of memory",
249                size
250            )))
251        } else {
252            unsafe {
253                let dp = self.bytes.as_mut_ptr().add(destination);
254                let sp = value as *const T as *const u8;
255                copy_nonoverlapping(sp, dp, size);
256            }
257            Ok(())
258        }
259    }
260
261    pub fn store_bytes(&mut self, destination: usize, value: &[u8]) -> SimpleResult<()> {
262        let size = value.len();
263        if destination + size > self.stack_size + self.memory_size {
264            Err(SimpleError::new(format!(
265                "Trying to store {} bytes to outside of memory",
266                size
267            )))
268        } else {
269            unsafe {
270                let dp = self.bytes.as_mut_ptr().add(destination);
271                let sp = value.as_ptr();
272                copy_nonoverlapping(sp, dp, size);
273            }
274            Ok(())
275        }
276    }
277
278    pub fn load_data<T: Default>(&self, source: usize) -> SimpleResult<T> {
279        let size = size_of::<T>();
280        if source + size > self.stack_size + self.memory_size {
281            Err(SimpleError::new(format!(
282                "Trying to load {} bytes from outside of memory",
283                size
284            )))
285        } else {
286            unsafe {
287                let sp = self.bytes.as_ptr().add(source);
288                let mut value = T::default();
289                let dp = &mut value as *mut T as *mut u8;
290                copy_nonoverlapping(sp, dp, size);
291                Ok(value)
292            }
293        }
294    }
295
296    pub fn load_bytes(&self, source: usize, size: usize) -> SimpleResult<Vec<u8>> {
297        if source + size > self.stack_size + self.memory_size {
298            Err(SimpleError::new(format!(
299                "Trying to load {} bytes from outside of memory",
300                size
301            )))
302        } else {
303            Ok(self.bytes[source..source + size].to_vec())
304        }
305    }
306
307    pub fn load_bytes_while<P>(&self, source: usize, mut predicate: P) -> Vec<u8>
308    where
309        P: FnMut(u8) -> bool,
310    {
311        self.bytes
312            .iter()
313            .skip(source)
314            .take_while(|b| predicate(**b))
315            .cloned()
316            .collect()
317    }
318
319    pub fn load_bytes_while_non_zero(&self, source: usize) -> Vec<u8> {
320        self.load_bytes_while(source, |b| b != 0)
321    }
322
323    pub fn map(&self, value: Value) -> SimpleResult<&[u8]> {
324        if value.address + value.size > self.stack_size + self.memory_size {
325            Err(SimpleError::new(format!(
326                "Trying to map {} bytes from outside of memory",
327                value.size
328            )))
329        } else {
330            Ok(&self.bytes[value.address..value.address + value.size])
331        }
332    }
333
334    pub fn map_mut(&mut self, value: Value) -> SimpleResult<&mut [u8]> {
335        if value.address + value.size > self.stack_size + self.memory_size {
336            Err(SimpleError::new(format!(
337                "Trying to map {} bytes from outside of memory",
338                value.size
339            )))
340        } else {
341            Ok(&mut self.bytes[value.address..value.address + value.size])
342        }
343    }
344
345    pub fn map_stack(&self) -> &[u8] {
346        &self.bytes[0..self.stack_size]
347    }
348
349    pub fn map_stack_mut(&mut self) -> &mut [u8] {
350        &mut self.bytes[0..self.stack_size]
351    }
352
353    pub fn map_memory(&self) -> &[u8] {
354        &self.bytes[self.stack_size..]
355    }
356
357    pub fn map_memory_mut(&mut self) -> &mut [u8] {
358        &mut self.bytes[self.stack_size..]
359    }
360
361    pub fn map_all(&self) -> &[u8] {
362        &self.bytes
363    }
364
365    pub fn map_all_mut(&mut self) -> &mut [u8] {
366        &mut self.bytes
367    }
368
369    pub fn alloc_stack_value(&mut self, size: usize) -> SimpleResult<Value> {
370        let address = self.stack_pos;
371        self.stack_push_bytes(&vec![0; size])?;
372        Ok(Value { address, size })
373    }
374
375    pub fn alloc_memory_value(&mut self, size: usize) -> SimpleResult<Value> {
376        let (index, address, s) = self.find_free_memory(size)?;
377        if self.memory_free[index].1 == size {
378            self.memory_free.remove(index);
379        } else {
380            self.memory_free[index] = (address + size, s - size);
381        }
382        Ok(Value {
383            address: address + self.stack_size,
384            size,
385        })
386    }
387
388    pub fn dealloc_memory_value(&mut self, value: &Value) -> SimpleResult<()> {
389        self.ensure_taken_memory(value)?;
390        self.memory_free
391            .push((value.address - self.stack_size, value.size));
392        self.defragment_free_memory();
393        Ok(())
394    }
395
396    fn find_free_memory(&self, size: usize) -> SimpleResult<(usize, usize, usize)> {
397        if let Some((i, (a, s))) = self
398            .memory_free
399            .iter()
400            .enumerate()
401            .find(|(_, (_, s))| size <= *s)
402        {
403            Ok((i, *a, *s))
404        } else {
405            Err(SimpleError::new(format!(
406                "Could not find free {} bytes in memory",
407                size
408            )))
409        }
410    }
411
412    fn ensure_taken_memory(&self, value: &Value) -> SimpleResult<()> {
413        let address = value.address - self.stack_size;
414        if !self
415            .memory_free
416            .iter()
417            .any(|(a, s)| address >= *a && address + value.size <= *a + *s)
418        {
419            Ok(())
420        } else {
421            Err(SimpleError::new(format!(
422                "Memory block at {} is free",
423                value.address
424            )))
425        }
426    }
427
428    fn defragment_free_memory(&mut self) {
429        self.memory_free.sort_by(|a, b| a.0.cmp(&b.0));
430        self.memory_free = self
431            .memory_free
432            .iter()
433            .cloned()
434            .coalesce(|a, b| {
435                if a.0 + a.1 == b.0 {
436                    Ok((a.0, a.1 + b.1))
437                } else {
438                    Err((a, b))
439                }
440            })
441            .collect();
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the `Value` an allocation is expected to hand back.
    fn value(address: usize, size: usize) -> Value {
        Value::new(address, size)
    }

    #[test]
    fn test_stack() {
        let mut state = State::new(8, 0);
        assert_eq!(state.stack_size(), 8);
        assert_eq!(state.stack_pos(), 0);

        // Two 4-byte values fill the 8-byte stack exactly; a third one must overflow.
        let first = state.alloc_stack_value(4).unwrap();
        assert_eq!(first, value(0, 4));
        let second = state.alloc_stack_value(4).unwrap();
        assert_eq!(second, value(4, 4));
        assert!(state.alloc_stack_value(4).is_err());

        // Popping unwinds the stack top back down to the bottom.
        assert_eq!(state.stack_pos(), 8);
        for expected_pos in &[4, 0] {
            state.stack_pop_bytes(4).unwrap();
            assert_eq!(state.stack_pos(), *expected_pos);
        }
    }

    #[test]
    fn test_memory() {
        // Memory blocks are handed out first-fit, with addresses offset past the stack.
        let mut state = State::new(8, 8);
        assert_eq!(state.memory_free, vec![(0, 8)]);
        assert_eq!(state.alloc_memory_value(4).unwrap(), value(8, 4));
        assert_eq!(state.memory_free, vec![(4, 4)]);
        assert_eq!(state.alloc_memory_value(4).unwrap(), value(12, 4));
        assert!(state.memory_free.is_empty());
        assert!(state.alloc_memory_value(4).is_err());

        // Freeing both halves in reverse order coalesces them back into one block.
        let mut state = State::new(8, 8);
        let low = state.alloc_memory_value(4).unwrap();
        let high = state.alloc_memory_value(4).unwrap();
        assert!(state.memory_free.is_empty());
        state.dealloc_memory_value(&high).unwrap();
        assert_eq!(state.memory_free, vec![(4, 4)]);
        state.dealloc_memory_value(&low).unwrap();
        assert_eq!(state.memory_free, vec![(0, 8)]);
    }
}
511}