use std::ptr::NonNull;
use wasmer::{
vm::{
MemoryError, MemoryStyle, TableStyle, VMMemory, VMMemoryDefinition, VMTable,
VMTableDefinition,
},
MemoryType, Pages, TableType, Tunables,
};
pub struct LimitingTunables<T: Tunables> {
limit: Pages,
base: T,
}
impl<T: Tunables> LimitingTunables<T> {
    /// Wraps `base`, limiting every memory it creates to at most `limit` pages.
    pub fn new(base: T, limit: Pages) -> Self {
        Self { base, limit }
    }

    /// Returns a copy of `requested` whose `maximum` falls back to the configured limit.
    ///
    /// An explicit maximum is kept verbatim, even when it is larger than the limit;
    /// rejecting such requests is left to `validate_memory`.
    fn adjust_memory(&self, requested: &MemoryType) -> MemoryType {
        let mut adjusted = *requested;
        adjusted.maximum = adjusted.maximum.or(Some(self.limit));
        adjusted
    }

    /// Checks that `ty` fits within the configured limit.
    ///
    /// Intended to be called on the output of `adjust_memory`.
    ///
    /// # Errors
    ///
    /// Returns `MemoryError::Generic` when the minimum exceeds the limit (checked first),
    /// when no maximum is set, or when the maximum exceeds the limit. A minimum larger
    /// than the maximum is not checked here.
    fn validate_memory(&self, ty: &MemoryType) -> Result<(), MemoryError> {
        let reject = |reason: &str| -> Result<(), MemoryError> {
            Err(MemoryError::Generic(reason.to_string()))
        };

        if ty.minimum > self.limit {
            return reject("Minimum exceeds the allowed memory limit");
        }

        match ty.maximum {
            None => reject("Maximum unset"),
            Some(max) if max > self.limit => reject("Maximum exceeds the allowed memory limit"),
            Some(_) => Ok(()),
        }
    }
}
impl<T: Tunables> Tunables for LimitingTunables<T> {
    /// Asks `base` for the memory style of the adjusted memory type, so the style
    /// reflects the maximum this wrapper fills in.
    fn memory_style(&self, memory: &MemoryType) -> MemoryStyle {
        self.base.memory_style(&self.adjust_memory(memory))
    }

    /// Tables are not limited; the request is forwarded to `base` untouched.
    fn table_style(&self, table: &TableType) -> TableStyle {
        self.base.table_style(table)
    }

    /// Creates a host-owned memory through `base` after adjusting and validating `ty`.
    ///
    /// # Errors
    ///
    /// Returns the validation error if the adjusted type violates the limit, otherwise
    /// whatever `base.create_host_memory` returns.
    fn create_host_memory(
        &self,
        ty: &MemoryType,
        style: &MemoryStyle,
    ) -> Result<VMMemory, MemoryError> {
        let limited = self.adjust_memory(ty);
        self.validate_memory(&limited)
            .and_then(|()| self.base.create_host_memory(&limited, style))
    }

    /// Creates a VM-owned memory through `base` after adjusting and validating `ty`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the same contract as `Tunables::create_vm_memory` on
    /// `base`. This wrapper adds no requirements of its own and forwards
    /// `vm_definition_location` unchanged.
    ///
    /// # Errors
    ///
    /// Returns the validation error if the adjusted type violates the limit, otherwise
    /// whatever `base.create_vm_memory` returns.
    unsafe fn create_vm_memory(
        &self,
        ty: &MemoryType,
        style: &MemoryStyle,
        vm_definition_location: NonNull<VMMemoryDefinition>,
    ) -> Result<VMMemory, MemoryError> {
        let limited = self.adjust_memory(ty);
        self.validate_memory(&limited)?;
        // SAFETY: the caller's obligations for this method are passed through verbatim
        // to `base`; only the (validated) memory type differs from the request.
        unsafe {
            self.base
                .create_vm_memory(&limited, style, vm_definition_location)
        }
    }

    /// Tables are not limited; host table creation is forwarded to `base`.
    fn create_host_table(&self, ty: &TableType, style: &TableStyle) -> Result<VMTable, String> {
        self.base.create_host_table(ty, style)
    }

    /// Tables are not limited; VM table creation is forwarded to `base`.
    ///
    /// # Safety
    ///
    /// Same contract as `Tunables::create_vm_table` on `base`; nothing is added here.
    unsafe fn create_vm_table(
        &self,
        ty: &TableType,
        style: &TableStyle,
        vm_definition_location: NonNull<VMTableDefinition>,
    ) -> Result<VMTable, String> {
        // SAFETY: all arguments and the caller's obligations are forwarded unchanged.
        unsafe { self.base.create_vm_table(ty, style, vm_definition_location) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use wasmer::{sys::BaseTunables, Target};

    /// Builds the wrapper under test: default-target base tunables capped at 12 pages.
    fn twelve_page_limit() -> LimitingTunables<BaseTunables> {
        LimitingTunables::new(BaseTunables::for_target(&Target::default()), Pages(12))
    }

    /// Asserts that `result` failed with `MemoryError::Generic` carrying exactly `expected`.
    fn expect_generic(result: Result<(), MemoryError>, expected: &str) {
        match result.unwrap_err() {
            MemoryError::Generic(msg) => assert_eq!(msg, expected),
            err => panic!("Unexpected error: {err:?}"),
        }
    }

    #[test]
    fn adjust_memory_works() {
        let limiting = twelve_page_limit();

        // A request without a maximum gets the limit filled in.
        assert_eq!(
            limiting.adjust_memory(&MemoryType::new(3, None, true)),
            MemoryType::new(3, Some(12), true)
        );

        // Requests that already carry a maximum are returned unchanged, even when they
        // exceed the limit or have minimum > maximum; rejecting those is not adjust's job.
        let pass_through = [
            MemoryType::new(3, Some(7), true),
            MemoryType::new(3, Some(12), true),
            MemoryType::new(3, Some(20), true),
            MemoryType::new(5, Some(3), true),
            MemoryType::new(20, Some(20), true),
        ];
        for requested in pass_through {
            assert_eq!(limiting.adjust_memory(&requested), requested);
        }
    }

    #[test]
    fn validate_memory_works() {
        let limiting = twelve_page_limit();

        // Within the limit (including exactly at it) is accepted.
        limiting
            .validate_memory(&MemoryType::new(3, Some(7), true))
            .unwrap();
        limiting
            .validate_memory(&MemoryType::new(3, Some(12), true))
            .unwrap();

        expect_generic(
            limiting.validate_memory(&MemoryType::new(3, Some(20), true)),
            "Maximum exceeds the allowed memory limit",
        );
        expect_generic(
            limiting.validate_memory(&MemoryType::new(3, None, true)),
            "Maximum unset",
        );

        // minimum > maximum is deliberately not this wrapper's concern.
        limiting
            .validate_memory(&MemoryType::new(5, Some(3), true))
            .unwrap();

        expect_generic(
            limiting.validate_memory(&MemoryType::new(20, Some(20), true)),
            "Minimum exceeds the allowed memory limit",
        );
    }
}