use super::{
index_allocator::{SimpleIndexAllocator, SlotId},
round_up_to_pow2,
};
use crate::sys::vm::{commit_stack_pages, reset_stack_pages_to_zero};
use crate::{Mmap, PoolingInstanceAllocatorConfig};
use anyhow::{anyhow, bail, Context, Result};
/// A pool of pre-reserved execution stacks for async fibers.
///
/// All stacks live in one contiguous [`Mmap`]; each slot is `stack_size`
/// bytes and consists of one inaccessible guard page at its lowest address
/// followed by the usable stack memory.
#[derive(Debug)]
pub struct StackPool {
    // The single virtual-memory mapping backing every stack slot.
    mapping: Mmap,
    // Size in bytes of one slot: the configured stack size rounded up to a
    // page boundary plus one guard page. Zero when stack allocation is
    // disabled (`config.stack_size == 0`).
    stack_size: usize,
    // Maximum number of concurrently allocated stacks (slot count).
    max_stacks: usize,
    // Host page size, cached at construction time.
    page_size: usize,
    // Hands out and reclaims free slot indices in `0..max_stacks`.
    index_allocator: SimpleIndexAllocator,
    // When true, a stack's memory is zeroed on deallocation (see
    // `zero_stack`).
    async_stack_zeroing: bool,
    // Upper bound on how many bytes of a freed stack are zeroed by memset
    // (keeping those pages resident); the remainder is reset via
    // `reset_stack_pages_to_zero`.
    async_stack_keep_resident: usize,
}
impl StackPool {
    /// Creates a new stack pool from `config`.
    ///
    /// Reserves a single mapping of `total_stacks` slots. Each slot is the
    /// configured stack size rounded up to a whole number of pages plus one
    /// extra page, and the lowest page of every slot is stripped of all
    /// access permissions to act as a guard page.
    ///
    /// # Errors
    ///
    /// Fails if the requested sizes overflow `usize`, if the mapping cannot
    /// be created, or if a guard page cannot be protected.
    pub fn new(config: &PoolingInstanceAllocatorConfig) -> Result<Self> {
        use rustix::mm::{mprotect, MprotectFlags};

        let page_size = crate::page_size();

        // A zero stack size disables fiber-stack allocation entirely
        // (`allocate` bails). Otherwise, round up to whole pages and add one
        // page for the guard.
        let stack_size = if config.stack_size == 0 {
            0
        } else {
            round_up_to_pow2(config.stack_size, page_size)
                .checked_add(page_size)
                .ok_or_else(|| anyhow!("stack size exceeds addressable memory"))?
        };

        // `total_stacks` is a u32, so this conversion cannot fail on the
        // 32/64-bit targets this code supports.
        let max_stacks = usize::try_from(config.limits.total_stacks).unwrap();

        let allocation_size = stack_size
            .checked_mul(max_stacks)
            .ok_or_else(|| anyhow!("total size of execution stacks exceeds addressable memory"))?;

        // Reserve the whole pool up front; the entire region is made
        // accessible here and guard pages are carved out below.
        let mapping = Mmap::accessible_reserved(allocation_size, allocation_size)
            .context("failed to create stack pool mapping")?;

        // Remove all permissions from the first page of each slot so a stack
        // overflow (stacks grow downward toward it) faults instead of
        // silently corrupting the neighboring slot.
        if allocation_size > 0 {
            unsafe {
                for i in 0..max_stacks {
                    let bottom_of_stack = mapping.as_ptr().add(i * stack_size).cast_mut();
                    mprotect(bottom_of_stack.cast(), page_size, MprotectFlags::empty())
                        .context("failed to protect stack guard page")?;
                }
            }
        }

        Ok(Self {
            mapping,
            stack_size,
            max_stacks,
            page_size,
            async_stack_zeroing: config.async_stack_zeroing,
            async_stack_keep_resident: config.async_stack_keep_resident,
            index_allocator: SimpleIndexAllocator::new(config.limits.total_stacks),
        })
    }

    /// Returns whether no stacks are currently handed out from this pool.
    pub fn is_empty(&self) -> bool {
        self.index_allocator.is_empty()
    }

    /// Allocates a fiber stack from the pool.
    ///
    /// # Errors
    ///
    /// Fails if the pool was configured with a zero stack size, if all
    /// `max_stacks` slots are in use, or if the stack's pages cannot be
    /// committed.
    pub fn allocate(&self) -> Result<wasmtime_fiber::FiberStack> {
        if self.stack_size == 0 {
            bail!("pooling allocator not configured to enable fiber stack allocation");
        }

        let index = self
            .index_allocator
            .alloc()
            .ok_or_else(|| {
                anyhow!(
                    "maximum concurrent fiber limit of {} reached",
                    self.max_stacks
                )
            })?
            .index();

        assert!(index < self.max_stacks);

        unsafe {
            // The usable stack excludes the slot's guard page.
            let size_without_guard = self.stack_size - self.page_size;

            // Skip past the guard page at the start (lowest address) of the
            // slot to find the usable region's base.
            let bottom_of_stack = self
                .mapping
                .as_ptr()
                .add((index * self.stack_size) + self.page_size)
                .cast_mut();

            // Make sure the stack's pages are committed before handing the
            // stack out (platform-specific helper).
            commit_stack_pages(bottom_of_stack, size_without_guard)?;

            let stack =
                wasmtime_fiber::FiberStack::from_raw_parts(bottom_of_stack, size_without_guard)?;
            Ok(stack)
        }
    }

    /// Returns a stack to the pool, optionally zeroing its memory first.
    ///
    /// # Safety
    ///
    /// `stack` must have been returned by this pool's [`Self::allocate`] and
    /// must not be used again after this call.
    pub unsafe fn deallocate(&self, stack: &wasmtime_fiber::FiberStack) {
        let top = stack
            .top()
            .expect("fiber stack not allocated from the pool") as usize;

        // Sanity-check that the stack's top pointer lies inside our mapping
        // before doing arithmetic with it.
        let base = self.mapping.as_ptr() as usize;
        let len = self.mapping.len();
        assert!(
            top > base && top <= (base + len),
            "fiber stack top pointer not in range"
        );

        // Reverse the layout math from `allocate`: `top` is one past the
        // highest usable address, the usable region is `stack_size` minus the
        // guard page, and the slot starts one guard page below that.
        let stack_size = self.stack_size - self.page_size;
        let bottom_of_stack = top - stack_size;
        let start_of_stack = bottom_of_stack - self.page_size;
        assert!(start_of_stack >= base && start_of_stack < (base + len));
        // A stack from this pool must sit exactly on a slot boundary.
        assert!((start_of_stack - base) % self.stack_size == 0);

        let index = (start_of_stack - base) / self.stack_size;
        assert!(index < self.max_stacks);

        if self.async_stack_zeroing {
            self.zero_stack(bottom_of_stack, stack_size);
        }

        self.index_allocator.free(SlotId(index as u32));
    }

    /// Zeroes a freed stack's usable memory (`bottom` is the first usable
    /// byte, `size` excludes the guard page).
    ///
    /// The top `async_stack_keep_resident` bytes — the most recently used
    /// region, since stacks grow downward — are zeroed with a plain memset
    /// so those pages stay resident; the remainder is handed to
    /// `reset_stack_pages_to_zero` (platform helper, presumably a
    /// decommit/madvise-style reset — defined elsewhere).
    fn zero_stack(&self, bottom: usize, size: usize) {
        let size_to_memset = size.min(self.async_stack_keep_resident);
        unsafe {
            std::ptr::write_bytes(
                (bottom + size - size_to_memset) as *mut u8,
                0,
                size_to_memset,
            );
            reset_stack_pages_to_zero(bottom as _, size - size_to_memset).unwrap();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InstanceLimits;

    #[cfg(all(unix, target_pointer_width = "64", feature = "async", not(miri)))]
    #[test]
    fn test_stack_pool() -> Result<()> {
        let config = PoolingInstanceAllocatorConfig {
            limits: InstanceLimits {
                total_stacks: 10,
                ..Default::default()
            },
            stack_size: 1,
            async_stack_zeroing: true,
            ..PoolingInstanceAllocatorConfig::default()
        };
        let pool = StackPool::new(&config)?;

        // A 1-byte request rounds up to one usable page plus one guard page.
        let native_page_size = crate::page_size();
        assert_eq!(pool.stack_size, 2 * native_page_size);
        assert_eq!(pool.max_stacks, 10);
        assert_eq!(pool.page_size, native_page_size);
        assert_eq!(pool.index_allocator.testing_freelist(), []);

        let base = pool.mapping.as_ptr() as usize;

        // Drain the pool, verifying stacks come out of slots 0..10 in order.
        let stacks = (0..10)
            .map(|expected_slot| {
                let stack = pool.allocate().expect("allocation should succeed");
                let slot = (stack.top().unwrap() as usize - base) / pool.stack_size - 1;
                assert_eq!(slot, expected_slot);
                stack
            })
            .collect::<Vec<_>>();
        assert_eq!(pool.index_allocator.testing_freelist(), []);

        // Every slot is taken, so one more allocation must fail.
        assert!(pool.allocate().is_err(), "allocation should fail");

        for stack in &stacks {
            unsafe {
                pool.deallocate(stack);
            }
        }

        // All slots should be back on the free list, in order.
        assert_eq!(
            pool.index_allocator.testing_freelist(),
            (0..10).map(SlotId).collect::<Vec<_>>(),
        );

        Ok(())
    }
}