use sim_kernel::{CodecId, Cx, Error, ReadPolicy, Result};
/// Upper bounds applied while decoding a single input.
///
/// Each field is an inclusive maximum: a measured value equal to the limit is
/// accepted, and only a value strictly greater than it is rejected by
/// `DecodeBudget`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct DecodeLimits {
    /// Maximum size of the whole encoded input, in bytes.
    pub max_input_bytes: usize,
    /// Maximum number of tokens.
    pub max_tokens: usize,
    /// Maximum number of expression nodes entered during one decode.
    pub max_expr_nodes: usize,
    /// Maximum recursion depth.
    pub max_depth: usize,
    /// Maximum length of a single string, in bytes.
    pub max_string_bytes: usize,
    /// Maximum length of a single blob, in bytes.
    pub max_blob_bytes: usize,
    /// Maximum number of elements in a single collection.
    pub max_collection_len: usize,
    /// Maximum number of trivia items recorded during one decode.
    pub max_trivia_items: usize,
}
impl Default for DecodeLimits {
fn default() -> Self {
Self {
max_input_bytes: 8 * 1024 * 1024,
max_tokens: 1_000_000,
max_expr_nodes: 200_000,
max_depth: 512,
max_string_bytes: 256 * 1024,
max_blob_bytes: 8 * 1024 * 1024,
max_collection_len: 65_536,
max_trivia_items: 16_384,
}
}
}
/// Running resource accounting for one decode, checked against a fixed
/// [`DecodeLimits`].
///
/// Per-value checks (input bytes, tokens, string/blob/collection sizes) are
/// stateless. Expression nodes and trivia items are accumulated in the
/// counters below.
#[derive(Debug)]
pub struct DecodeBudget {
    /// Limits every check is compared against.
    limits: DecodeLimits,
    /// Number of expression nodes entered so far.
    nodes: usize,
    /// Number of trivia items recorded so far.
    trivia: usize,
}
impl DecodeBudget {
    /// Creates a budget that enforces `limits`, with all running counters at zero.
    pub fn new(limits: DecodeLimits) -> Self {
        Self {
            limits,
            nodes: 0,
            trivia: 0,
        }
    }

    /// Returns a copy of the limits this budget enforces.
    pub fn limits(&self) -> DecodeLimits {
        self.limits
    }

    /// Checks that an input of `len` bytes is within `max_input_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `Error::CodecError` for `codec` when `len > max_input_bytes`.
    pub fn check_input_bytes(&self, codec: CodecId, len: usize) -> Result<()> {
        self.check(codec, "input bytes", len, self.limits.max_input_bytes)
    }

    /// Checks that `count` tokens is within `max_tokens`.
    ///
    /// # Errors
    ///
    /// Returns `Error::CodecError` for `codec` when `count > max_tokens`.
    pub fn check_tokens(&self, codec: CodecId, count: usize) -> Result<()> {
        self.check(codec, "tokens", count, self.limits.max_tokens)
    }

    /// Checks that a collection of `len` elements is within `max_collection_len`.
    ///
    /// # Errors
    ///
    /// Returns `Error::CodecError` for `codec` when `len > max_collection_len`.
    pub fn check_collection_len(&self, codec: CodecId, len: usize) -> Result<()> {
        self.check(
            codec,
            "collection length",
            len,
            self.limits.max_collection_len,
        )
    }

    /// Checks that a string of `len` bytes is within `max_string_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `Error::CodecError` for `codec` when `len > max_string_bytes`.
    pub fn check_string_bytes(&self, codec: CodecId, len: usize) -> Result<()> {
        self.check(codec, "string bytes", len, self.limits.max_string_bytes)
    }

    /// Checks that a blob of `len` bytes is within `max_blob_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `Error::CodecError` for `codec` when `len > max_blob_bytes`.
    pub fn check_blob_bytes(&self, codec: CodecId, len: usize) -> Result<()> {
        self.check(codec, "blob bytes", len, self.limits.max_blob_bytes)
    }

    /// Records one trivia item and checks the running total against
    /// `max_trivia_items`.
    ///
    /// The counter saturates instead of overflowing. That prevents a debug-build
    /// panic, and it prevents a release-build wrap back to zero that would let
    /// an already-exceeded limit pass again. Once exceeded, the limit stays
    /// exceeded.
    ///
    /// # Errors
    ///
    /// Returns `Error::CodecError` for `codec` when the running total exceeds
    /// `max_trivia_items`.
    pub fn add_trivia(&mut self, codec: CodecId) -> Result<()> {
        self.trivia = self.trivia.saturating_add(1);
        self.check(
            codec,
            "trivia items",
            self.trivia,
            self.limits.max_trivia_items,
        )
    }

    /// Records entry into one expression node at recursion `depth`.
    ///
    /// The node is counted before any check runs, so it is counted even when a
    /// check fails. The node-count check runs first; the depth check runs only
    /// if the node-count check passes. The counter saturates for the same
    /// reason as in `add_trivia`.
    ///
    /// # Errors
    ///
    /// Returns `Error::CodecError` for `codec` when the running node count
    /// exceeds `max_expr_nodes`, or when `depth > max_depth`.
    pub fn enter_node(&mut self, codec: CodecId, depth: usize) -> Result<()> {
        self.nodes = self.nodes.saturating_add(1);
        self.check(codec, "expr nodes", self.nodes, self.limits.max_expr_nodes)?;
        self.check(codec, "recursion depth", depth, self.limits.max_depth)
    }

    /// Fails with a `CodecError` for `codec` when `got` is strictly greater than
    /// `max` (limits are inclusive).
    ///
    /// `what` names the limited quantity and is embedded in the error message.
    fn check(&self, codec: CodecId, what: &str, got: usize, max: usize) -> Result<()> {
        if got > max {
            return Err(Error::CodecError {
                codec,
                message: format!("decode {what} limit exceeded: {got} > {max}"),
            });
        }
        Ok(())
    }
}
/// Groups the state used while reading with a codec.
///
/// Holds a mutable borrow of the kernel context together with the codec id,
/// the read policy, and the decode limits in effect.
pub struct ReadCx<'cx> {
    /// Mutable borrow of the kernel context, held for the lifetime `'cx`.
    pub cx: &'cx mut Cx,
    /// Identifier of the codec doing the read.
    pub codec: CodecId,
    /// Read policy in effect; its semantics are defined by `sim_kernel::ReadPolicy`.
    pub read_policy: ReadPolicy,
    /// Decode limits in effect for this read.
    pub limits: DecodeLimits,
}