use ahash::HashSetExt;
use saphyr_parser::{Event, Parser, ScalarStyle, ScanError};
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use std::collections::HashSet;
/// Resource limits applied while scanning a YAML event stream.
///
/// Every `max_*` limit is inclusive: a counter that rises strictly above its
/// limit produces the matching [`BudgetBreach`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Budget {
/// Upper bound on input size in bytes, or `None` for no bound.
/// NOTE(review): not read by `check_yaml_budget` in this file; presumably
/// enforced by reader-based entry points elsewhere — confirm.
pub max_reader_input_bytes: Option<usize>,
/// Maximum number of parser events, counting every event (including stream
/// and document markers).
pub max_events: usize,
/// Maximum number of alias (`*name`) events.
pub max_aliases: usize,
/// Maximum number of distinct non-zero anchor ids defined.
pub max_anchors: usize,
/// Maximum nesting depth of mappings and sequences.
pub max_depth: usize,
/// Maximum number of documents; the count is taken at `DocumentStart` events.
pub max_documents: usize,
/// Maximum number of scalar, mapping and sequence nodes (aliases are not counted).
pub max_nodes: usize,
/// Maximum combined byte length of all scalar values.
pub max_total_scalar_bytes: usize,
/// Maximum number of merge keys: plain, untagged `<<` scalars in mapping-key position.
pub max_merge_keys: usize,
/// Whether the end-of-stream alias-to-anchor ratio check is applied.
pub enforce_alias_anchor_ratio: bool,
/// The ratio check only applies once at least this many aliases were seen.
pub alias_anchor_min_aliases: usize,
/// The ratio check breaches when aliases exceed this many times the anchor
/// count, or when no anchors were defined at all.
pub alias_anchor_ratio_multiplier: usize,
}
impl Default for Budget {
#[allow(deprecated)]
fn default() -> Self {
Self {
max_reader_input_bytes: Some(256 * 1024 * 1024), max_events: 1_000_000, max_aliases: 50_000, max_anchors: 50_000,
max_depth: 2_000, max_documents: 1_024, max_nodes: 250_000, max_total_scalar_bytes: 64 * 1024 * 1024, max_merge_keys: 10_000, enforce_alias_anchor_ratio: true,
alias_anchor_min_aliases: 100,
alias_anchor_ratio_multiplier: 10,
}
}
}
/// The limit that stopped a budget check, carrying the counter value observed
/// at the moment the limit was exceeded.
#[non_exhaustive]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum BudgetBreach {
/// `Budget::max_events` exceeded.
Events {
events: usize,
},
/// `Budget::max_aliases` exceeded.
Aliases {
aliases: usize,
},
/// `Budget::max_anchors` exceeded (distinct anchor ids).
Anchors {
anchors: usize,
},
/// `Budget::max_depth` exceeded; `depth` is the deepest nesting reached.
Depth {
depth: usize,
},
/// `Budget::max_documents` exceeded.
Documents {
documents: usize,
},
/// `Budget::max_nodes` exceeded.
Nodes {
nodes: usize,
},
/// `Budget::max_total_scalar_bytes` exceeded.
ScalarBytes {
total_scalar_bytes: usize,
},
/// `Budget::max_merge_keys` exceeded.
MergeKeys {
merge_keys: usize,
},
/// The end-of-stream alias-to-anchor ratio check failed.
AliasAnchorRatio {
aliases: usize,
anchors: usize,
},
/// An end event arrived with no open container, or with a container of the
/// other kind on top of the stack. Despite the name, it is reported for
/// both sequence and mapping ends.
SequenceUnbalanced,
/// Input size limit exceeded.
/// NOTE(review): never constructed in this file; presumably tied to
/// `Budget::max_reader_input_bytes` — confirm.
InputBytes {
input_bytes: usize,
},
}
/// Counters gathered while checking a stream, plus the breach that ended the
/// check, if any.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct BudgetReport {
/// The detected breach, or `None` if the input stayed within budget.
pub breached: Option<BudgetBreach>,
/// Parser events observed.
pub events: usize,
/// Alias events observed.
pub aliases: usize,
/// Distinct non-zero anchor ids defined.
pub anchors: usize,
/// Documents counted.
pub documents: usize,
/// Scalar, mapping and sequence nodes observed.
pub nodes: usize,
/// Deepest container nesting reached.
pub max_depth: usize,
/// Combined byte length of all scalar values.
pub total_scalar_bytes: usize,
/// Merge keys (`<<`) observed in key position.
pub merge_keys: usize,
}
impl BudgetReport {
    /// Zeroes every per-document counter.
    ///
    /// `breached` and `documents` are carried over unchanged, so a recorded
    /// breach survives and the document count keeps accumulating.
    fn reset(&mut self) {
        let documents = self.documents;
        let breached = self.breached.take();
        *self = Self {
            breached,
            documents,
            ..Self::default()
        };
    }
}
/// Selects whether budget counters span the whole stream or restart for each
/// document.
///
/// The enum is fieldless, so it is `Copy`: one policy value can be reused
/// across several checks.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum EnforcingPolicy {
    /// Counters accumulate across every document in the stream.
    AllContent,
    /// The report's per-document counters are reset at each `DocumentStart`
    /// event.
    PerDocument,
}
/// `HashSet` using ahash's `RandomState` hasher; in this file it backs the set of anchor ids.
type FastHashSet<T> = HashSet<T, ahash::RandomState>;
/// Streaming checker that accumulates a [`BudgetReport`] from parser events
/// and signals the first [`BudgetBreach`].
#[derive(Debug)]
pub(crate) struct BudgetEnforcer {
/// Limits being enforced.
budget: Budget,
/// Counters accumulated so far.
report: BudgetReport,
/// Current nesting depth of open mappings and sequences.
depth: usize,
/// Distinct non-zero anchor ids seen; its length is the reported anchor count.
defined_anchors: FastHashSet<usize>,
/// Stack of open containers, innermost last; used to tell keys from values.
containers: SmallVec<[ContainerState; 64]>,
/// Whether counters are reset at each document start.
policy: EnforcingPolicy,
}
/// State of one open container on the enforcer's stack.
#[derive(Clone, Copy, Debug)]
enum ContainerState {
/// An open sequence.
Sequence {
/// True when this sequence is the value of an enclosing mapping entry, so
/// closing it completes that entry.
from_mapping_value: bool,
},
/// An open mapping.
Mapping {
/// True when the next node fills a key slot; false when it fills a value slot.
expecting_key: bool,
/// True when this mapping is the value of an enclosing mapping entry, so
/// closing it completes that entry.
from_mapping_value: bool,
},
}
impl BudgetEnforcer {
    /// Creates an enforcer with zeroed counters that checks events against
    /// `budget`, resetting per-document state according to `policy`.
    pub fn new(budget: Budget, policy: EnforcingPolicy) -> Self {
        Self {
            budget,
            report: BudgetReport::default(),
            depth: 0,
            defined_anchors: FastHashSet::with_capacity(256),
            containers: SmallVec::new(),
            policy,
        }
    }

    /// Accounts for one parser event and enforces every per-event limit.
    ///
    /// # Errors
    ///
    /// Returns the first [`BudgetBreach`] this event triggers. The counters
    /// keep the values they had at the breach; `check_yaml_budget` stops
    /// feeding events after the first error.
    pub fn observe(&mut self, ev: &Event) -> Result<(), BudgetBreach> {
        // Every event counts, including stream and document markers.
        self.report.events += 1;
        if self.report.events > self.budget.max_events {
            return Err(BudgetBreach::Events {
                events: self.report.events,
            });
        }
        match ev {
            Event::Scalar(value, style, anchor_id, tag_opt) => {
                self.bump_nodes()?;
                self.report.total_scalar_bytes =
                    self.report.total_scalar_bytes.saturating_add(value.len());
                if self.report.total_scalar_bytes > self.budget.max_total_scalar_bytes {
                    return Err(BudgetBreach::ScalarBytes {
                        total_scalar_bytes: self.report.total_scalar_bytes,
                    });
                }
                self.record_anchor(*anchor_id)?;
                self.handle_scalar(value, style, tag_opt.is_some())?;
            }
            Event::MappingStart(anchor_id, _tag_opt) => {
                self.bump_nodes()?;
                self.enter_depth()?;
                let from_mapping_value = self.entering_container();
                self.containers.push(ContainerState::Mapping {
                    expecting_key: true,
                    from_mapping_value,
                });
                self.record_anchor(*anchor_id)?;
            }
            Event::MappingEnd => {
                self.leave_depth()?;
                self.leave_mapping()?;
            }
            Event::SequenceStart(anchor_id, _tag_opt) => {
                self.bump_nodes()?;
                self.enter_depth()?;
                let from_mapping_value = self.entering_container();
                self.containers
                    .push(ContainerState::Sequence { from_mapping_value });
                self.record_anchor(*anchor_id)?;
            }
            Event::SequenceEnd => {
                self.leave_depth()?;
                self.leave_sequence()?;
            }
            Event::Alias(_anchor_id) => {
                self.report.aliases += 1;
                if self.report.aliases > self.budget.max_aliases {
                    return Err(BudgetBreach::Aliases {
                        aliases: self.report.aliases,
                    });
                }
                self.handle_alias();
            }
            Event::DocumentStart(_explicit) => self.start_document()?,
            Event::DocumentEnd | Event::Nothing | Event::StreamStart | Event::StreamEnd => {}
        }
        Ok(())
    }

    /// Handles a `DocumentStart`: resets per-document state under
    /// [`EnforcingPolicy::PerDocument`], then counts the document against
    /// `max_documents` under either policy.
    fn start_document(&mut self) -> Result<(), BudgetBreach> {
        if self.policy == EnforcingPolicy::PerDocument {
            // YAML anchors are scoped to a single document, so the anchor set
            // must be cleared together with the report counters. Otherwise
            // anchors from earlier documents keep counting against
            // `max_anchors` and inflate `report.anchors`.
            self.report.reset();
            self.defined_anchors.clear();
        }
        // `BudgetReport::reset` deliberately keeps `documents`, so this limit
        // bounds the whole stream even when every other counter restarts.
        // Without it, a multi-document stream would be unbounded.
        self.report.documents += 1;
        if self.report.documents > self.budget.max_documents {
            return Err(BudgetBreach::Documents {
                documents: self.report.documents,
            });
        }
        Ok(())
    }

    /// Opens one nesting level, records the high-water mark and enforces
    /// `max_depth`.
    fn enter_depth(&mut self) -> Result<(), BudgetBreach> {
        self.depth = self.depth.saturating_add(1);
        self.report.max_depth = self.report.max_depth.max(self.depth);
        if self.report.max_depth > self.budget.max_depth {
            return Err(BudgetBreach::Depth {
                depth: self.report.max_depth,
            });
        }
        Ok(())
    }

    /// Closes one nesting level; an end event at depth zero is unbalanced.
    fn leave_depth(&mut self) -> Result<(), BudgetBreach> {
        self.depth = self
            .depth
            .checked_sub(1)
            .ok_or(BudgetBreach::SequenceUnbalanced)?;
        Ok(())
    }

    /// Counts one scalar/mapping/sequence node against `max_nodes`.
    fn bump_nodes(&mut self) -> Result<(), BudgetBreach> {
        self.report.nodes += 1;
        if self.report.nodes > self.budget.max_nodes {
            return Err(BudgetBreach::Nodes {
                nodes: self.report.nodes,
            });
        }
        Ok(())
    }

    /// Records a newly defined anchor id and enforces `max_anchors`.
    ///
    /// Id 0 is skipped and treated as "no anchor"; re-inserting a known id does
    /// not increase the count.
    fn record_anchor(&mut self, anchor_id: usize) -> Result<(), BudgetBreach> {
        if anchor_id != 0 && self.defined_anchors.insert(anchor_id) {
            let count = self.defined_anchors.len();
            if count > self.budget.max_anchors {
                self.report.anchors = count;
                return Err(BudgetBreach::Anchors { anchors: count });
            }
        }
        self.report.anchors = self.defined_anchors.len();
        Ok(())
    }

    /// Advances the key/value state of the innermost mapping for a scalar.
    /// Plain, untagged `<<` keys are also counted as merge keys.
    fn handle_scalar(
        &mut self,
        value: &str,
        style: &ScalarStyle,
        has_tag: bool,
    ) -> Result<(), BudgetBreach> {
        if let Some(ContainerState::Mapping { expecting_key, .. }) = self.containers.last_mut() {
            if *expecting_key {
                if !has_tag && matches!(style, ScalarStyle::Plain) && value == "<<" {
                    self.report.merge_keys += 1;
                    if self.report.merge_keys > self.budget.max_merge_keys {
                        return Err(BudgetBreach::MergeKeys {
                            merge_keys: self.report.merge_keys,
                        });
                    }
                }
                *expecting_key = false;
            } else {
                self.finish_value();
            }
        }
        Ok(())
    }

    /// Advances the key/value state of the innermost mapping for an alias,
    /// which occupies a key or value slot just like a scalar.
    fn handle_alias(&mut self) {
        if let Some(ContainerState::Mapping { expecting_key, .. }) = self.containers.last_mut() {
            if *expecting_key {
                *expecting_key = false;
            } else {
                self.finish_value();
            }
        }
    }

    /// Called before pushing a new container. Returns true when the container
    /// fills the parent mapping's value slot. A container in key position
    /// (complex key) flips the parent to expecting its value.
    fn entering_container(&mut self) -> bool {
        if let Some(ContainerState::Mapping { expecting_key, .. }) = self.containers.last_mut() {
            if *expecting_key {
                *expecting_key = false;
                false
            } else {
                true
            }
        } else {
            false
        }
    }

    /// Pops a sequence; the top of the stack must be a sequence.
    fn leave_sequence(&mut self) -> Result<(), BudgetBreach> {
        match self.containers.pop() {
            Some(ContainerState::Sequence { from_mapping_value }) => {
                if from_mapping_value {
                    self.finish_value();
                }
                Ok(())
            }
            _ => Err(BudgetBreach::SequenceUnbalanced),
        }
    }

    /// Pops a mapping; the top of the stack must be a mapping.
    fn leave_mapping(&mut self) -> Result<(), BudgetBreach> {
        match self.containers.pop() {
            Some(ContainerState::Mapping {
                from_mapping_value, ..
            }) => {
                if from_mapping_value {
                    self.finish_value();
                }
                Ok(())
            }
            _ => Err(BudgetBreach::SequenceUnbalanced),
        }
    }

    /// Marks the innermost mapping's current entry as complete, so the next
    /// node is a key.
    fn finish_value(&mut self) {
        if let Some(ContainerState::Mapping { expecting_key, .. }) = self.containers.last_mut() {
            *expecting_key = true;
        }
    }

    /// Consumes the enforcer and returns its report without running the
    /// end-of-stream ratio check; used when a breach already ended the scan.
    pub fn into_report(mut self) -> BudgetReport {
        self.report.anchors = self.defined_anchors.len();
        self.report
    }

    /// Consumes the enforcer and returns its report after applying the
    /// alias-to-anchor ratio check, which can only be judged at end of stream.
    pub fn finalize(mut self) -> BudgetReport {
        self.report.anchors = self.defined_anchors.len();
        let aliases = self.report.aliases;
        let anchors = self.report.anchors;
        // `saturating_mul`: a huge multiplier (e.g. from a deserialized config)
        // must not overflow. Overflow panics in debug builds and wraps in
        // release builds, where it could yield a spurious breach.
        let ratio_exceeded = anchors == 0
            || aliases > self.budget.alias_anchor_ratio_multiplier.saturating_mul(anchors);
        if self.budget.enforce_alias_anchor_ratio
            && aliases >= self.budget.alias_anchor_min_aliases
            && ratio_exceeded
        {
            self.report.breached = Some(BudgetBreach::AliasAnchorRatio { aliases, anchors });
        }
        self.report
    }
}
/// Scans `input` event by event and checks it against `budget` under `policy`.
///
/// The scan stops at the first breach. The returned report then carries that
/// breach in `breached` together with the counters reached so far. If the
/// whole stream is consumed, the end-of-stream alias/anchor ratio check runs
/// as well.
///
/// # Errors
///
/// Returns the parser's [`ScanError`] if `input` is not well-formed YAML.
/// Budget breaches are never errors; they are reported inside `Ok`.
pub fn check_yaml_budget(
    input: &str,
    budget: Budget,
    policy: EnforcingPolicy,
) -> Result<BudgetReport, ScanError> {
    let mut enforcer = BudgetEnforcer::new(budget, policy);
    for parsed in Parser::new_from_str(input) {
        let (event, _span) = parsed?;
        match enforcer.observe(&event) {
            Ok(()) => {}
            Err(breach) => {
                let mut partial = enforcer.into_report();
                partial.breached = Some(breach);
                return Ok(partial);
            }
        }
    }
    Ok(enforcer.finalize())
}
/// Checks `input` against `budget` over the whole stream.
///
/// Returns `Ok(true)` when the input **breached** the budget and `Ok(false)`
/// when it stayed within it. Despite the name, no parsed document is produced.
///
/// # Errors
///
/// Returns the parser's [`ScanError`] if `input` is not well-formed YAML.
pub fn parse_yaml(input: &str, budget: Budget) -> Result<bool, ScanError> {
    check_yaml_budget(input, budget, EnforcingPolicy::AllContent)
        .map(|report| report.breached.is_some())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a whole-stream budget check, panicking if the YAML fails to scan.
    fn check_all(yaml: &str, budget: Budget) -> BudgetReport {
        check_yaml_budget(yaml, budget, EnforcingPolicy::AllContent).unwrap()
    }

    /// One anchor aliased three times.
    const RATIO_YAML: &str = "root: &A [1]\na: *A\nb: *A\nc: *A\n";

    #[test]
    fn tiny_yaml_ok() {
        let report = check_all("a: [1, 2, 3]\n", Budget::default());
        assert!(report.breached.is_none());
        assert_eq!(report.documents, 1);
        assert!(report.nodes > 0);
    }

    #[test]
    fn alias_bomb_trips_alias_limit() {
        let yaml = r#"root: &A [1, 2]
a: *A
b: *A
c: *A
d: *A
e: *A
"#;
        let budget = Budget {
            max_aliases: 3,
            ..Default::default()
        };
        let report = check_all(yaml, budget);
        assert!(matches!(report.breached, Some(BudgetBreach::Aliases { .. })));
    }

    #[test]
    fn deep_nesting_trips_depth() {
        let mut yaml = "[".repeat(200);
        yaml.push_str(&"]".repeat(200));
        let budget = Budget {
            max_depth: 150,
            ..Default::default()
        };
        let report = check_all(&yaml, budget);
        assert!(matches!(report.breached, Some(BudgetBreach::Depth { .. })));
    }

    #[test]
    fn anchors_limit_trips() {
        let budget = Budget {
            max_anchors: 2,
            ..Default::default()
        };
        let report = check_all("a: &A 1\nb: &B 2\nc: &C 3\n", budget);
        assert!(matches!(
            report.breached,
            Some(BudgetBreach::Anchors { anchors: 3 })
        ));
    }

    #[test]
    fn merge_key_limit_trips() {
        let items: String = (0..3)
            .map(|idx| format!(" item{idx}:\n <<: *B\n extra: {idx}\n"))
            .collect();
        let yaml = format!("base: &B\n k: 1\nitems:\n{items}");
        let budget = Budget {
            max_merge_keys: 2,
            ..Default::default()
        };
        let report = check_all(&yaml, budget);
        assert!(matches!(
            report.breached,
            Some(BudgetBreach::MergeKeys { merge_keys: 3 })
        ));
        assert_eq!(report.merge_keys, 3);
    }

    #[test]
    fn alias_anchor_ratio_trips_when_excessive() {
        let budget = Budget {
            alias_anchor_min_aliases: 1,
            alias_anchor_ratio_multiplier: 2,
            ..Default::default()
        };
        let report = check_all(RATIO_YAML, budget);
        assert!(matches!(
            report.breached,
            Some(BudgetBreach::AliasAnchorRatio {
                aliases: 3,
                anchors: 1
            })
        ));
        assert_eq!(report.aliases, 3);
        assert_eq!(report.anchors, 1);
    }

    #[test]
    fn alias_anchor_ratio_respects_minimum_alias_threshold() {
        let budget = Budget {
            alias_anchor_min_aliases: 5,
            alias_anchor_ratio_multiplier: 1,
            ..Default::default()
        };
        let report = check_all(RATIO_YAML, budget);
        assert!(report.breached.is_none());
        assert_eq!(report.aliases, 3);
        assert_eq!(report.anchors, 1);
    }
}