use crate::{
ebpf::{ELF_INSN_DUMP_OFFSET, MM_STACK_START, SCRATCH_REGS},
error::{EbpfError, UserDefinedError},
memory_region::MemoryRegion,
};
#[derive(Clone, Debug)]
struct CallFrame {
vm_addr: u64,
saved_reg: [u64; 4],
return_ptr: usize,
}
/// The VM call stack: a fixed number of frames backed by one contiguous host buffer.
///
/// The host buffer is exposed to the VM as a [`MemoryRegion`] starting at
/// `MM_STACK_START`. Frame `i` starts at `MM_STACK_START + i * 2 * frame_size` in VM
/// address space. `frame_size` is also passed to `MemoryRegion::new_from_slice`;
/// presumably it is the VM gap size, consistent with `vm_gap_shift` — TODO confirm.
///
/// `Clone` is implemented by hand rather than derived; see the note on that impl.
#[derive(Debug)]
pub struct CallFrames {
    /// Host memory backing every frame. Must never be resized or replaced,
    /// because `region` refers to its heap buffer.
    stack: Vec<u8>,
    /// VM view of `stack`.
    region: MemoryRegion,
    /// Size in bytes of a single frame, as passed to `new`.
    frame_size: usize,
    /// Index of the currently active frame.
    frame_index: usize,
    /// Highest `frame_index` reached so far.
    frame_index_max: usize,
    /// Per-frame bookkeeping; `frames.len()` is the configured depth.
    frames: Vec<CallFrame>,
}

impl Clone for CallFrames {
    /// Deep-copies the call stack, including its backing memory.
    ///
    /// `MemoryRegion` carries no lifetime, so it refers to `stack`'s buffer by address.
    /// A derived `Clone` would copy `stack` into a fresh buffer but keep the region
    /// pointing at the *source's* buffer. The clone would then read and write the
    /// source's memory, and dangle once the source is dropped. Rebuilding the region
    /// over the cloned buffer keeps each instance self-contained.
    fn clone(&self) -> Self {
        let stack = self.stack.clone();
        let region = MemoryRegion::new_from_slice(
            &stack[..],
            MM_STACK_START,
            self.frame_size as u64,
            true,
        );
        CallFrames {
            stack,
            region,
            frame_size: self.frame_size,
            frame_index: self.frame_index,
            frame_index_max: self.frame_index_max,
            frames: self.frames.clone(),
        }
    }
}

impl CallFrames {
    /// Creates a call stack with `depth` frames of `frame_size` bytes each.
    ///
    /// # Panics
    ///
    /// Panics if `depth * frame_size` overflows `usize`.
    pub fn new(depth: usize, frame_size: usize) -> Self {
        let stack_len = depth
            .checked_mul(frame_size)
            .expect("call stack size (depth * frame_size) overflows usize");
        let stack = vec![0u8; stack_len];
        // Moving `stack` into the struct below moves only the Vec header, not its
        // heap buffer, so the address captured by `region` stays valid.
        let region =
            MemoryRegion::new_from_slice(&stack[..], MM_STACK_START, frame_size as u64, true);
        // Consecutive frames start `2 * frame_size` apart in VM address space.
        let frames: Vec<CallFrame> = (0..depth)
            .map(|i| CallFrame {
                vm_addr: MM_STACK_START + (i * 2 * frame_size) as u64,
                saved_reg: [0u64; SCRATCH_REGS],
                return_ptr: 0,
            })
            .collect();
        CallFrames {
            stack,
            region,
            frame_size,
            frame_index: 0,
            frame_index_max: 0,
            frames,
        }
    }

    /// Returns the memory region that maps the stack into VM address space.
    pub fn get_region(&self) -> &MemoryRegion {
        &self.region
    }

    /// Returns the VM start address of every frame, in frame order.
    pub fn get_frame_pointers(&self) -> Vec<u64> {
        self.frames.iter().map(|frame| frame.vm_addr).collect()
    }

    /// Returns the stack-top address of the current frame: its start address
    /// plus `1 << region.vm_gap_shift`.
    ///
    /// # Panics
    ///
    /// Panics if the stack was created with a depth of 0 (there is no current frame).
    pub fn get_stack_top(&self) -> u64 {
        self.frames[self.frame_index].vm_addr + (1 << self.region.vm_gap_shift)
    }

    /// Returns the index of the current frame (0 is the root frame).
    pub fn get_frame_index(&self) -> usize {
        self.frame_index
    }

    /// Returns the deepest frame index reached so far.
    pub fn get_max_frame_index(&self) -> usize {
        self.frame_index_max
    }

    /// Saves `saved_reg` and `return_ptr` in the current frame, then enters the next frame.
    ///
    /// Returns the stack top of the newly entered frame.
    ///
    /// # Errors
    ///
    /// Returns `EbpfError::CallDepthExceeded` when the next frame would be beyond the
    /// configured depth. The error carries `return_ptr + ELF_INSN_DUMP_OFFSET - 1` and
    /// the depth.
    ///
    /// # Panics
    ///
    /// Panics if `saved_reg.len()` differs from the number of registers a frame stores.
    pub fn push<E: UserDefinedError>(
        &mut self,
        saved_reg: &[u64],
        return_ptr: usize,
    ) -> Result<u64, EbpfError<E>> {
        if self.frame_index + 1 >= self.frames.len() {
            return Err(EbpfError::CallDepthExceeded(
                return_ptr + ELF_INSN_DUMP_OFFSET - 1,
                self.frames.len(),
            ));
        }
        let frame = &mut self.frames[self.frame_index];
        frame.saved_reg[..].copy_from_slice(saved_reg);
        frame.return_ptr = return_ptr;
        self.frame_index += 1;
        self.frame_index_max = self.frame_index_max.max(self.frame_index);
        Ok(self.get_stack_top())
    }

    /// Leaves the current frame and returns to its caller's frame.
    ///
    /// Returns the registers and return address that `push` stored in the caller's
    /// frame, together with the caller frame's stack top.
    ///
    /// # Errors
    ///
    /// Returns `EbpfError::ExitRootCallFrame` when already at the root frame.
    pub fn pop<E: UserDefinedError>(
        &mut self,
    ) -> Result<([u64; SCRATCH_REGS], u64, usize), EbpfError<E>> {
        if self.frame_index == 0 {
            return Err(EbpfError::ExitRootCallFrame);
        }
        self.frame_index -= 1;
        Ok((
            self.frames[self.frame_index].saved_reg,
            self.get_stack_top(),
            self.frames[self.frame_index].return_ptr,
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::user_error::UserError;

    /// Pushes frames up to the configured depth, checks that one more push fails,
    /// then pops everything back and verifies the saved state round-trips.
    #[test]
    fn test_frames() {
        const DEPTH: usize = 10;
        const FRAME_SIZE: u64 = 8;
        let mut frames = CallFrames::new(DEPTH, FRAME_SIZE as usize);
        let mut ptrs: Vec<u64> = Vec::new();
        for i in 0..DEPTH - 1 {
            let registers = vec![i as u64; FRAME_SIZE as usize];
            assert_eq!(frames.get_frame_index(), i);
            ptrs.push(frames.get_frame_pointers()[i]);
            let top = frames.push::<UserError>(&registers[0..4], i).unwrap();
            let new_ptrs = frames.get_frame_pointers();
            assert_eq!(top, new_ptrs[i + 1] + FRAME_SIZE);
            assert_ne!(top, ptrs[i] + FRAME_SIZE - 1);
            assert!(!(ptrs[i] <= new_ptrs[i + 1] && new_ptrs[i + 1] < ptrs[i] + FRAME_SIZE));
        }
        let i = DEPTH - 1;
        let registers = vec![i as u64; FRAME_SIZE as usize];
        assert_eq!(frames.get_frame_index(), i);
        ptrs.push(frames.get_frame_pointers()[i]);
        // The depth check fires before any registers are copied.
        assert!(frames.push::<UserError>(&registers, DEPTH - 1).is_err());
        for i in (0..DEPTH - 1).rev() {
            let (saved_reg, stack_ptr, return_ptr) = frames.pop::<UserError>().unwrap();
            assert_eq!(saved_reg, [i as u64, i as u64, i as u64, i as u64]);
            assert_eq!(ptrs[i] + FRAME_SIZE, stack_ptr);
            assert_eq!(i, return_ptr);
        }
        assert!(frames.pop::<UserError>().is_err());
    }
}