use crate::options::MergeKeyPolicy;
use ahash::RandomState;
use granit_parser::{Event, Parser, ScalarStyle, ScanError, Tag};
use smallvec::SmallVec;
use std::collections::HashSet;
/// Default ceiling for accumulated scalar bytes: 64 MiB (`64 << 20` bytes).
const DEFAULT_MAX_SCALAR_BYTES: usize = 64 << 20;
/// Default ceiling for accumulated comment bytes: 64 MiB (`64 << 20` bytes).
const DEFAULT_MAX_TOTAL_COMMENT_BYTES: usize = 64 << 20;
/// Serde fallback for `Budget::max_total_comment_bytes`.
///
/// Referenced by the `serde(default = ...)` attribute on that field, so a
/// serialized `Budget` without the field deserializes with
/// `DEFAULT_MAX_TOTAL_COMMENT_BYTES`.
#[cfg(feature = "serde_derived_types")]
fn default_max_total_comment_bytes() -> usize {
DEFAULT_MAX_TOTAL_COMMENT_BYTES
}
/// Resource limits applied while YAML parser events are observed.
///
/// Every limit is compared with the matching running counter using a strict
/// "greater than" test, so a counter that is exactly equal to its limit is
/// still within budget.
///
/// All fields are marked deprecated to steer callers away from direct
/// construction (see the deprecation notes); `Budget::default()` supplies the
/// stock values.
#[derive(Clone, Debug)]
#[cfg_attr(
feature = "serde_derived_types",
derive(serde::Serialize, serde::Deserialize)
)]
pub struct Budget {
/// Optional cap on the size of the input in bytes.
///
/// NOTE(review): not consulted anywhere in this file; presumably enforced by
/// the reader layer and reported as `BudgetBreach::InputBytes` - confirm.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub max_reader_input_bytes: Option<usize>,
/// Maximum number of events that may be observed. Alias references counted
/// through `BudgetEnforcer::observe_alias_reference` count as events too.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub max_events: usize,
/// Maximum number of alias events / alias references.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub max_aliases: usize,
/// Maximum number of distinct anchors (anchor id `0` is not counted).
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub max_anchors: usize,
/// Maximum nesting depth of simultaneously open sequences and mappings.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub max_depth: usize,
/// Maximum inclusion depth.
///
/// NOTE(review): not consulted anywhere in this file; presumably enforced by
/// the include-handling code and reported as `BudgetBreach::InclusionDepth`
/// - confirm.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub max_inclusion_depth: u32,
/// Maximum number of documents in the stream. Documents are only counted
/// under `EnforcingPolicy::AllContent`.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub max_documents: usize,
/// Maximum number of nodes, where every scalar, sequence start and mapping
/// start counts as one node.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub max_nodes: usize,
/// Maximum accumulated scalar bytes: the length of every scalar value plus
/// the handle and suffix lengths of every tag on a scalar or container.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub max_total_scalar_bytes: usize,
/// Maximum accumulated length of comment text in bytes. When the serde
/// feature is enabled and the field is missing from the serialized form,
/// `default_max_total_comment_bytes` supplies the value.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
#[cfg_attr(
feature = "serde_derived_types",
serde(default = "default_max_total_comment_bytes")
)]
pub max_total_comment_bytes: usize,
/// Maximum number of merge keys: plain, untagged `<<` scalars in mapping-key
/// position, counted only when the enforcer uses `MergeKeyPolicy::Merge`.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub max_merge_keys: usize,
/// Enables the alias/anchor ratio check performed by
/// `BudgetEnforcer::finalize`.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub enforce_alias_anchor_ratio: bool,
/// The ratio check only applies once at least this many aliases were seen.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub alias_anchor_min_aliases: usize,
/// The ratio check fails when there are no anchors at all or when the alias
/// count is greater than this multiplier times the anchor count.
#[deprecated(
note = "Direct construction of `Budget` will be disabled from 1.0.0, use macro `budget!`"
)]
pub alias_anchor_ratio_multiplier: usize,
}
impl Default for Budget {
    /// Returns the stock limits used when the caller does not tune anything.
    // The fields are deprecated only to discourage direct construction; this
    // impl is the place that is meant to initialise them.
    #[allow(deprecated)]
    fn default() -> Self {
        Self {
            // Raw input volume.
            max_reader_input_bytes: Some(256 * 1024 * 1024),
            max_events: 1_000_000,

            // Anchors and aliases.
            max_aliases: 50_000,
            max_anchors: 50_000,

            // Structural limits.
            max_depth: 64,
            max_inclusion_depth: 24,
            max_documents: 1_024,
            max_nodes: 250_000,

            // Accumulated text.
            max_total_scalar_bytes: DEFAULT_MAX_SCALAR_BYTES,
            max_total_comment_bytes: DEFAULT_MAX_TOTAL_COMMENT_BYTES,

            // Merge keys.
            max_merge_keys: 10_000,

            // Alias-to-anchor ratio check run when the enforcer is finalized.
            enforce_alias_anchor_ratio: true,
            alias_anchor_min_aliases: 100,
            alias_anchor_ratio_multiplier: 10,
        }
    }
}
/// The first limit that was exceeded, together with the counter value(s)
/// observed at the moment of the breach.
#[non_exhaustive]
#[derive(Clone, Debug)]
#[cfg_attr(
feature = "serde_derived_types",
derive(serde::Serialize, serde::Deserialize)
)]
pub enum BudgetBreach {
/// The event counter exceeded `Budget::max_events`.
Events {
/// Event count at the time of the breach.
events: usize,
},
/// The alias counter exceeded `Budget::max_aliases`.
Aliases {
/// Alias count at the time of the breach.
aliases: usize,
},
/// The number of distinct anchors exceeded `Budget::max_anchors`.
Anchors {
/// Distinct anchor count at the time of the breach.
anchors: usize,
},
/// Container nesting exceeded `Budget::max_depth`.
Depth {
/// Deepest nesting level reached.
depth: usize,
},
/// An inclusion depth limit was exceeded.
///
/// NOTE(review): never produced in this file; presumably raised by the
/// include-handling code against `Budget::max_inclusion_depth` - confirm.
InclusionDepth {
/// Inclusion depth reported by the producer of this breach.
depth: u32,
},
/// The document counter exceeded `Budget::max_documents`.
Documents {
/// Document count at the time of the breach.
documents: usize,
},
/// The node counter exceeded `Budget::max_nodes`.
Nodes {
/// Node count at the time of the breach.
nodes: usize,
},
/// Accumulated scalar (and tag) bytes exceeded
/// `Budget::max_total_scalar_bytes`.
ScalarBytes {
/// Accumulated byte total at the time of the breach.
total_scalar_bytes: usize,
},
/// Accumulated comment bytes exceeded `Budget::max_total_comment_bytes`.
CommentBytes {
/// Accumulated comment byte total at the time of the breach.
total_comment_bytes: usize,
},
/// The merge-key counter exceeded `Budget::max_merge_keys`.
MergeKeys {
/// Merge-key count at the time of the breach.
merge_keys: usize,
},
/// The alias/anchor ratio check in `BudgetEnforcer::finalize` failed.
AliasAnchorRatio {
/// Total aliases seen.
aliases: usize,
/// Total distinct anchors seen.
anchors: usize,
},
/// A sequence or mapping end event did not match the innermost open
/// container, or arrived when no container was open.
SequenceUnbalanced,
/// An input size limit was exceeded.
///
/// NOTE(review): never produced in this file; presumably raised by the
/// reader layer against `Budget::max_reader_input_bytes` - confirm.
InputBytes {
/// Input size reported by the producer of this breach.
input_bytes: usize,
},
}
/// Running counters collected by `BudgetEnforcer`, plus the breach (if any)
/// that ended the scan.
#[derive(Clone, Debug, Default)]
#[cfg_attr(
feature = "serde_derived_types",
derive(serde::Serialize, serde::Deserialize)
)]
pub struct BudgetReport {
/// The limit that was exceeded, or `None` when everything stayed in budget.
pub breached: Option<BudgetBreach>,
/// Number of events observed.
pub events: usize,
/// Number of alias events / alias references observed.
pub aliases: usize,
/// Number of distinct anchors defined.
pub anchors: usize,
/// Number of documents started. Only advanced under
/// `EnforcingPolicy::AllContent`; never cleared by `reset`.
pub documents: usize,
/// Number of nodes (scalars, sequences and mappings) observed.
pub nodes: usize,
/// Deepest container nesting level reached.
pub max_depth: usize,
/// Accumulated scalar value bytes plus tag bytes.
pub total_scalar_bytes: usize,
/// Accumulated comment text bytes. Defaults to zero when missing from a
/// serialized report.
#[cfg_attr(feature = "serde_derived_types", serde(default))]
pub total_comment_bytes: usize,
/// Number of merge keys (`<<`) counted.
pub merge_keys: usize,
}
impl BudgetReport {
    /// Zeroes every per-document counter.
    ///
    /// `breached` and `documents` are deliberately left as they are. The
    /// exhaustive destructuring makes the compiler flag this method whenever a
    /// field is added to the struct.
    fn reset(&mut self) {
        let Self {
            breached: _,
            documents: _,
            events,
            aliases,
            anchors,
            nodes,
            max_depth,
            total_scalar_bytes,
            total_comment_bytes,
            merge_keys,
        } = self;

        for counter in [
            events,
            aliases,
            anchors,
            nodes,
            max_depth,
            total_scalar_bytes,
            total_comment_bytes,
            merge_keys,
        ] {
            *counter = 0;
        }
    }
}
/// Scope over which a `Budget` is enforced.
///
/// The enum is a plain, fieldless selector, so it derives `Clone`, `Copy` and
/// `Eq` in addition to `Debug` and `PartialEq`; a policy value can be stored
/// and passed by value repeatedly without being consumed.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EnforcingPolicy {
    /// Counters accumulate over the whole stream and every document start is
    /// counted against the document limit.
    AllContent,
    /// Per-document counters and the set of known anchors are cleared at every
    /// document start.
    PerDocument,
}
/// Stateful checker that is fed parser events one at a time and reports the
/// first limit of its `Budget` that is exceeded.
#[derive(Debug)]
pub(crate) struct BudgetEnforcer {
// Limits to enforce.
budget: Budget,
// Running counters; returned to the caller by `into_report` / `finalize`.
report: BudgetReport,
// Number of currently open containers.
depth: usize,
// Ids of anchors seen so far (id 0 is never stored).
defined_anchors: HashSet<usize, RandomState>,
// Stack of open containers, innermost last; tracks key/value position in
// mappings so merge keys can be recognised.
containers: SmallVec<[ContainerState; 64]>,
// Whether counters are reset at each document start.
policy: EnforcingPolicy,
// Merge keys are only counted when this is `MergeKeyPolicy::Merge`.
merge_keys: MergeKeyPolicy,
}
/// Bookkeeping for one open container on the enforcer's container stack.
#[derive(Clone, Copy, Debug)]
enum ContainerState {
/// An open sequence.
Sequence {
// True when this sequence sits in the value slot of its parent mapping,
// so closing it completes that parent's current key/value pair.
from_mapping_value: bool,
},
/// An open mapping.
Mapping {
// True while the next child node is a key, false while it is a value.
expecting_key: bool,
// True when this mapping sits in the value slot of its parent mapping,
// so closing it completes that parent's current key/value pair.
from_mapping_value: bool,
},
}
/// Kind of container an end event claims to close; compared with the
/// innermost open `ContainerState` when a container is left.
#[derive(Clone, Copy)]
enum ContainerKind {
/// Closed by a sequence end event.
Sequence,
/// Closed by a mapping end event.
Mapping,
}
/// Returns the byte length charged for a tag, or `0` when there is no tag.
///
/// The length is the sum of the handle and suffix lengths returned by
/// `Tag::parts`, computed with saturating addition and without allocating.
fn tag_display_len(tag: Option<&Tag>) -> usize {
    match tag {
        None => 0,
        Some(tag) => {
            let (handle, suffix) = tag.parts();
            handle.len().saturating_add(suffix.len())
        }
    }
}
impl BudgetEnforcer {
    /// Creates an enforcer with all counters at zero.
    ///
    /// `budget` holds the limits, `policy` selects whether counters are reset
    /// at each document start, and `merge_keys` decides whether `<<` keys are
    /// counted as merge keys.
    pub(crate) fn new(budget: Budget, policy: EnforcingPolicy, merge_keys: MergeKeyPolicy) -> Self {
        Self {
            budget,
            report: BudgetReport::default(),
            depth: 0,
            defined_anchors: HashSet::with_capacity_and_hasher(256, RandomState::default()),
            containers: SmallVec::new(),
            policy,
            merge_keys,
        }
    }

    /// Accounts for one parser event.
    ///
    /// Every event advances the event counter; the remaining bookkeeping
    /// depends on the event kind.
    ///
    /// # Errors
    ///
    /// Returns the first [`BudgetBreach`] triggered by this event. The
    /// enforcer's internal state is not rolled back after a breach, so callers
    /// are expected to stop feeding events once an error is returned.
    pub fn observe(&mut self, ev: &Event) -> Result<(), BudgetBreach> {
        self.bump_events()?;
        match ev {
            Event::Scalar(value, style, anchor_id, tag_opt) => {
                self.bump_nodes()?;
                // A scalar is charged for its text and for its tag, if any.
                let charged = value
                    .len()
                    .saturating_add(tag_display_len(tag_opt.as_deref()));
                self.bump_total_scalar_bytes(charged)?;
                self.record_anchor(*anchor_id)?;
                self.handle_scalar(value, style, tag_opt.is_some())?;
            }
            Event::MappingStart(_style, anchor_id, tag_opt) => {
                self.enter_container(*anchor_id, tag_opt.as_deref(), |from_mapping_value| {
                    ContainerState::Mapping {
                        expecting_key: true,
                        from_mapping_value,
                    }
                })?;
            }
            Event::MappingEnd => self.leave_container(ContainerKind::Mapping)?,
            Event::SequenceStart(_style, anchor_id, tag_opt) => {
                self.enter_container(*anchor_id, tag_opt.as_deref(), |from_mapping_value| {
                    ContainerState::Sequence { from_mapping_value }
                })?;
            }
            Event::SequenceEnd => self.leave_container(ContainerKind::Sequence)?,
            Event::Alias(_anchor_id) => self.observe_alias_event(true)?,
            Event::DocumentStart(..) => self.start_document()?,
            Event::Comment(text, _) => self.bump_total_comment_bytes(text.len())?,
            // These events only count towards the event total.
            Event::DocumentEnd | Event::Nothing | Event::StreamStart | Event::StreamEnd => {}
        }
        Ok(())
    }

    /// Accounts for an alias reference that did not arrive as a parser event.
    ///
    /// Counts one event and one alias, but leaves the key/value position of
    /// the enclosing mapping untouched.
    ///
    /// # Errors
    ///
    /// Returns [`BudgetBreach::Events`] or [`BudgetBreach::Aliases`] when the
    /// respective limit is exceeded.
    pub(crate) fn observe_alias_reference(&mut self) -> Result<(), BudgetBreach> {
        self.bump_events()?;
        self.observe_alias_event(false)
    }

    /// Counts one alias and, when `advance_mapping_state` is set, lets the
    /// alias occupy the current key or value slot of the enclosing mapping.
    fn observe_alias_event(&mut self, advance_mapping_state: bool) -> Result<(), BudgetBreach> {
        // Saturating so that a counter already at `usize::MAX` neither panics
        // (debug builds) nor wraps to zero (release builds).
        self.report.aliases = self.report.aliases.saturating_add(1);
        if self.report.aliases > self.budget.max_aliases {
            return Err(BudgetBreach::Aliases {
                aliases: self.report.aliases,
            });
        }
        if advance_mapping_state {
            self.handle_alias();
        }
        Ok(())
    }

    /// Counts one event and checks it against `max_events`.
    fn bump_events(&mut self) -> Result<(), BudgetBreach> {
        self.report.events = self.report.events.saturating_add(1);
        if self.report.events > self.budget.max_events {
            return Err(BudgetBreach::Events {
                events: self.report.events,
            });
        }
        Ok(())
    }

    /// Counts one node and checks it against `max_nodes`.
    fn bump_nodes(&mut self) -> Result<(), BudgetBreach> {
        self.report.nodes = self.report.nodes.saturating_add(1);
        if self.report.nodes > self.budget.max_nodes {
            return Err(BudgetBreach::Nodes {
                nodes: self.report.nodes,
            });
        }
        Ok(())
    }

    /// Adds `bytes` to the scalar byte total and checks it against
    /// `max_total_scalar_bytes`.
    fn bump_total_scalar_bytes(&mut self, bytes: usize) -> Result<(), BudgetBreach> {
        self.report.total_scalar_bytes = self.report.total_scalar_bytes.saturating_add(bytes);
        if self.report.total_scalar_bytes > self.budget.max_total_scalar_bytes {
            return Err(BudgetBreach::ScalarBytes {
                total_scalar_bytes: self.report.total_scalar_bytes,
            });
        }
        Ok(())
    }

    /// Adds `bytes` to the comment byte total and checks it against
    /// `max_total_comment_bytes`.
    fn bump_total_comment_bytes(&mut self, bytes: usize) -> Result<(), BudgetBreach> {
        self.report.total_comment_bytes = self.report.total_comment_bytes.saturating_add(bytes);
        if self.report.total_comment_bytes > self.budget.max_total_comment_bytes {
            return Err(BudgetBreach::CommentBytes {
                total_comment_bytes: self.report.total_comment_bytes,
            });
        }
        Ok(())
    }

    /// Handles a document start according to the enforcing policy.
    ///
    /// Under `PerDocument` the per-document counters and the anchor set are
    /// cleared; otherwise the document is counted against `max_documents`.
    // NOTE(review): under `PerDocument` the document counter is never
    // advanced, so `max_documents` is not enforced there - confirm that this
    // is intended.
    fn start_document(&mut self) -> Result<(), BudgetBreach> {
        if self.policy == EnforcingPolicy::PerDocument {
            self.report.reset();
            self.defined_anchors.clear();
            return Ok(());
        }
        self.report.documents = self.report.documents.saturating_add(1);
        if self.report.documents > self.budget.max_documents {
            return Err(BudgetBreach::Documents {
                documents: self.report.documents,
            });
        }
        Ok(())
    }

    /// Records one more level of nesting and checks the deepest level reached
    /// so far against `max_depth`.
    fn enter_depth(&mut self) -> Result<(), BudgetBreach> {
        self.depth = self.depth.saturating_add(1);
        self.report.max_depth = self.report.max_depth.max(self.depth);
        if self.report.max_depth > self.budget.max_depth {
            return Err(BudgetBreach::Depth {
                depth: self.report.max_depth,
            });
        }
        Ok(())
    }

    /// Opens a sequence or mapping.
    ///
    /// The container counts as a node, deepens the nesting, is charged for its
    /// tag, takes the current key or value slot of its parent mapping and
    /// finally has its anchor recorded. `container` builds the stack entry
    /// from the "sits in a mapping value slot" flag.
    fn enter_container(
        &mut self,
        anchor_id: usize,
        tag: Option<&Tag>,
        container: impl FnOnce(bool) -> ContainerState,
    ) -> Result<(), BudgetBreach> {
        self.bump_nodes()?;
        self.enter_depth()?;
        self.bump_total_scalar_bytes(tag_display_len(tag))?;
        let from_mapping_value = self.entering_container();
        self.containers.push(container(from_mapping_value));
        self.record_anchor(anchor_id)
    }

    /// Remembers `anchor_id` (unless it is `0`) and refreshes the reported
    /// anchor count.
    ///
    /// The limit is only checked when the id was not known before, so a
    /// repeated id never triggers a breach on its own.
    fn record_anchor(&mut self, anchor_id: usize) -> Result<(), BudgetBreach> {
        let newly_defined = anchor_id != 0 && self.defined_anchors.insert(anchor_id);
        let count = self.defined_anchors.len();
        self.report.anchors = count;
        if newly_defined && count > self.budget.max_anchors {
            return Err(BudgetBreach::Anchors { anchors: count });
        }
        Ok(())
    }

    /// Places a scalar into the enclosing mapping, counting it as a merge key
    /// when it is a plain, untagged `<<` in key position and the merge-key
    /// policy is `Merge`.
    ///
    /// Scalars outside a mapping need no bookkeeping.
    fn handle_scalar(
        &mut self,
        value: &str,
        style: &ScalarStyle,
        has_tag: bool,
    ) -> Result<(), BudgetBreach> {
        let in_key_position = matches!(
            self.containers.last(),
            Some(ContainerState::Mapping {
                expecting_key: true,
                ..
            })
        );
        let is_merge_key = in_key_position
            && matches!(self.merge_keys, MergeKeyPolicy::Merge)
            && !has_tag
            && matches!(style, ScalarStyle::Plain)
            && value == "<<";
        if is_merge_key {
            self.report.merge_keys = self.report.merge_keys.saturating_add(1);
            // The breach is reported before the key slot is consumed.
            if self.report.merge_keys > self.budget.max_merge_keys {
                return Err(BudgetBreach::MergeKeys {
                    merge_keys: self.report.merge_keys,
                });
            }
        }
        self.advance_mapping_state();
        Ok(())
    }

    /// Places an alias into the current key or value slot of the enclosing
    /// mapping, if there is one.
    fn handle_alias(&mut self) {
        self.advance_mapping_state();
    }

    /// Moves the innermost mapping from "expecting key" to "expecting value"
    /// or back; does nothing when the innermost container is not a mapping.
    fn advance_mapping_state(&mut self) {
        if let Some(ContainerState::Mapping { expecting_key, .. }) = self.containers.last_mut() {
            // A key is followed by its value, a completed value by the next key.
            *expecting_key = !*expecting_key;
        }
    }

    /// Registers that a nested container is about to open inside the current
    /// innermost container.
    ///
    /// Returns `true` when the new container fills the value slot of a
    /// mapping. When it fills the key slot, the mapping switches to expecting
    /// a value. The value slot is only marked complete once the nested
    /// container is closed again.
    fn entering_container(&mut self) -> bool {
        match self.containers.last_mut() {
            Some(ContainerState::Mapping { expecting_key, .. }) => {
                let fills_value_slot = !*expecting_key;
                *expecting_key = false;
                fills_value_slot
            }
            _ => false,
        }
    }

    /// Closes the innermost container, which must be of kind `expected`.
    ///
    /// # Errors
    ///
    /// Returns [`BudgetBreach::SequenceUnbalanced`] when no container is open
    /// or when the innermost one is of a different kind.
    fn leave_container(&mut self, expected: ContainerKind) -> Result<(), BudgetBreach> {
        self.depth = self
            .depth
            .checked_sub(1)
            .ok_or(BudgetBreach::SequenceUnbalanced)?;
        let from_mapping_value = match (expected, self.containers.pop()) {
            (ContainerKind::Sequence, Some(ContainerState::Sequence { from_mapping_value }))
            | (
                ContainerKind::Mapping,
                Some(ContainerState::Mapping {
                    from_mapping_value, ..
                }),
            ) => from_mapping_value,
            _ => return Err(BudgetBreach::SequenceUnbalanced),
        };
        if from_mapping_value {
            // The closed container was a mapping value: the parent is ready
            // for its next key.
            self.finish_value();
        }
        Ok(())
    }

    /// Marks the current value of the innermost mapping as complete, so the
    /// next child is treated as a key.
    fn finish_value(&mut self) {
        if let Some(ContainerState::Mapping { expecting_key, .. }) = self.containers.last_mut() {
            *expecting_key = true;
        }
    }

    /// Consumes the enforcer and returns its counters without running the
    /// end-of-input alias/anchor ratio check.
    pub fn into_report(mut self) -> BudgetReport {
        self.report.anchors = self.defined_anchors.len();
        self.report
    }

    /// Consumes the enforcer, runs the alias/anchor ratio check and returns
    /// the counters.
    ///
    /// When the check is enabled and at least `alias_anchor_min_aliases`
    /// aliases were seen, the report is marked as breached if there are no
    /// anchors at all or if the alias count is greater than
    /// `alias_anchor_ratio_multiplier` times the anchor count (the product
    /// saturates instead of overflowing).
    pub fn finalize(mut self) -> BudgetReport {
        let anchors = self.defined_anchors.len();
        let aliases = self.report.aliases;
        self.report.anchors = anchors;

        let enough_aliases = aliases >= self.budget.alias_anchor_min_aliases;
        let ratio_exceeded = anchors == 0
            || aliases
                > self
                    .budget
                    .alias_anchor_ratio_multiplier
                    .saturating_mul(anchors);

        if self.budget.enforce_alias_anchor_ratio && enough_aliases && ratio_exceeded {
            self.report.breached = Some(BudgetBreach::AliasAnchorRatio { aliases, anchors });
        }
        self.report
    }
}
pub fn check_yaml_budget(
input: &str,
budget: Budget,
policy: EnforcingPolicy,
) -> Result<BudgetReport, ScanError> {
let parser = Parser::new_from_str(input);
let mut enforcer = BudgetEnforcer::new(budget, policy, MergeKeyPolicy::Merge);
for item in parser {
let (ev, _span) = item?;
if let Err(breach) = enforcer.observe(&ev) {
let mut report = enforcer.into_report();
report.breached = Some(breach);
return Ok(report);
}
}
Ok(enforcer.finalize())
}
/// Reports whether `input` exceeds `budget` when checked over the whole
/// stream (`EnforcingPolicy::AllContent`).
///
/// Returns `Ok(true)` when some limit was breached and `Ok(false)` when the
/// input stayed within budget.
///
/// # Errors
///
/// Returns the parser's [`ScanError`] when `input` cannot be scanned.
pub fn parse_yaml(input: &str, budget: Budget) -> Result<bool, ScanError> {
    check_yaml_budget(input, budget, EnforcingPolicy::AllContent)
        .map(|report| report.breached.is_some())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn tiny_yaml_ok() {
let b = Budget::default();
let y = "a: [1, 2, 3]\n";
let r = check_yaml_budget(y, b, EnforcingPolicy::AllContent).unwrap();
assert!(r.breached.is_none());
assert_eq!(r.documents, 1);
assert!(r.nodes > 0);
}
#[test]
fn alias_bomb_trips_alias_limit() {
let y = r#"root: &A [1, 2]
a: *A
b: *A
c: *A
d: *A
e: *A
"#;
let b = Budget {
max_aliases: 3, ..Default::default()
};
let rep = check_yaml_budget(y, b, EnforcingPolicy::AllContent).unwrap();
assert!(matches!(rep.breached, Some(BudgetBreach::Aliases { .. })));
}
#[test]
fn deep_nesting_trips_depth() {
let mut y = String::new();
for _ in 0..200 {
y.push('[');
}
for _ in 0..200 {
y.push(']');
}
let b = Budget {
max_depth: 150,
..Default::default()
};
let rep = check_yaml_budget(&y, b, EnforcingPolicy::AllContent).unwrap();
assert!(matches!(rep.breached, Some(BudgetBreach::Depth { .. })));
}
#[test]
fn anchors_limit_trips() {
let y = "a: &A 1\nb: &B 2\nc: &C 3\n";
let b = Budget {
max_anchors: 2,
..Default::default()
};
let rep = check_yaml_budget(y, b, EnforcingPolicy::AllContent).unwrap();
assert!(matches!(
rep.breached,
Some(BudgetBreach::Anchors { anchors: 3 })
));
}
#[test]
fn merge_key_limit_trips() {
let mut y = String::from("base: &B\n k: 1\nitems:\n");
for idx in 0..3 {
y.push_str(&format!(" item{idx}:\n <<: *B\n extra: {idx}\n"));
}
let b = Budget {
max_merge_keys: 2,
..Default::default()
};
let rep = check_yaml_budget(&y, b, EnforcingPolicy::AllContent).unwrap();
assert!(matches!(
rep.breached,
Some(BudgetBreach::MergeKeys { merge_keys }) if merge_keys == 3
));
assert_eq!(rep.merge_keys, 3);
}
#[test]
fn merge_key_limit_is_ignored_when_policy_is_as_ordinary() {
let y = "base: &B\n k: 1\nroot:\n <<: *B\n";
let budget = Budget {
max_merge_keys: 0,
..Default::default()
};
let mut enforcer = BudgetEnforcer::new(
budget,
EnforcingPolicy::AllContent,
MergeKeyPolicy::AsOrdinary,
);
for item in Parser::new_from_str(y) {
let (event, _span) = item.unwrap();
enforcer.observe(&event).unwrap();
}
let report = enforcer.finalize();
assert!(report.breached.is_none());
assert_eq!(report.merge_keys, 0);
}
#[test]
fn alias_anchor_ratio_trips_when_excessive() {
let yaml = "root: &A [1]\na: *A\nb: *A\nc: *A\n";
let budget = Budget {
alias_anchor_min_aliases: 1,
alias_anchor_ratio_multiplier: 2,
..Default::default()
};
let report = check_yaml_budget(yaml, budget, EnforcingPolicy::AllContent).unwrap();
assert!(matches!(
report.breached,
Some(BudgetBreach::AliasAnchorRatio {
aliases: 3,
anchors: 1
})
));
assert_eq!(report.aliases, 3);
assert_eq!(report.anchors, 1);
}
#[test]
fn alias_anchor_ratio_respects_minimum_alias_threshold() {
let yaml = "root: &A [1]\na: *A\nb: *A\nc: *A\n";
let budget = Budget {
alias_anchor_min_aliases: 5,
alias_anchor_ratio_multiplier: 1,
..Default::default()
};
let report = check_yaml_budget(yaml, budget, EnforcingPolicy::AllContent).unwrap();
assert!(report.breached.is_none());
assert_eq!(report.aliases, 3);
assert_eq!(report.anchors, 1);
}
#[test]
fn alias_anchor_ratio_multiplier_overflow_does_not_panic() {
let budget = Budget {
alias_anchor_min_aliases: 1,
alias_anchor_ratio_multiplier: usize::MAX,
..Default::default()
};
let mut enforcer =
BudgetEnforcer::new(budget, EnforcingPolicy::AllContent, MergeKeyPolicy::Merge);
enforcer.report.aliases = usize::MAX;
enforcer.defined_anchors.insert(1);
enforcer.defined_anchors.insert(2);
let report = enforcer.finalize();
assert!(report.breached.is_none());
assert_eq!(report.aliases, usize::MAX);
assert_eq!(report.anchors, 2);
}
#[test]
fn budget_default_sets_max_inclusion_depth() {
let budget = Budget::default();
assert_eq!(budget.max_inclusion_depth, 24);
}
#[test]
fn scalar_budget_counts_tag_bytes() {
let yaml = "root: !!str tagged\n";
let budget = Budget {
max_total_scalar_bytes: 14,
..Default::default()
};
let report = check_yaml_budget(yaml, budget, EnforcingPolicy::AllContent).unwrap();
assert!(matches!(
report.breached,
Some(BudgetBreach::ScalarBytes {
total_scalar_bytes
}) if total_scalar_bytes > 14
));
}
#[test]
fn scalar_budget_counts_container_tag_bytes() {
for yaml in ["root: !!seq [a]\n", "root: !!map {a: b}\n"] {
let budget = Budget {
max_total_scalar_bytes: 24,
..Default::default()
};
let report = check_yaml_budget(yaml, budget, EnforcingPolicy::AllContent).unwrap();
assert!(
matches!(
report.breached,
Some(BudgetBreach::ScalarBytes {
total_scalar_bytes
}) if total_scalar_bytes > 24
),
"yaml: {yaml:?}, report: {report:?}"
);
}
}
#[test]
fn tag_display_len_matches_display_without_allocating() {
for tag in [
Tag::with_original_handle("tag:yaml.org,2002:", "str", "!!"),
Tag::with_original_handle("!", "local", "!"),
Tag::with_original_handle("", "tag:example.com,2000:thing", ""),
] {
assert_eq!(tag_display_len(Some(&tag)), tag.to_string().len());
}
}
#[test]
fn comment_budget_counts_comment_bytes() {
let yaml = "#abcdef\nroot: ok\n";
let budget = Budget {
max_total_comment_bytes: 5,
..Default::default()
};
let report = check_yaml_budget(yaml, budget, EnforcingPolicy::AllContent).unwrap();
assert!(matches!(
report.breached,
Some(BudgetBreach::CommentBytes {
total_comment_bytes
}) if total_comment_bytes > 5
));
assert_eq!(report.total_scalar_bytes, 0);
}
}