use crate::{
BasicComposer, CommentPreservingComposer, CommentedValue, Composer, Error, Limits, Position,
Result, Value,
};
/// Builds plain [`Value`] documents from some input, one document per call.
pub trait Constructor {
    /// Returns the next constructed document, or `Ok(None)` when the underlying
    /// source yields no further document.
    ///
    /// # Errors
    /// Implementation-defined; typically parse/compose failures or limit violations.
    fn construct(&mut self) -> Result<Option<Value>>;
    /// Reports whether a further document is expected to be available.
    ///
    /// How precise this answer is depends on the implementation.
    fn check_data(&self) -> bool;
    /// Returns the constructor to its starting state; exactly which state is
    /// reset is implementation-defined.
    fn reset(&mut self);
}
/// Builds [`CommentedValue`] documents, keeping comments for round-tripping.
///
/// Mirrors [`Constructor`], but yields comment-carrying values instead of plain ones.
pub trait CommentPreservingConstructor {
    /// Returns the next document with its comments, or `Ok(None)` when the
    /// underlying source yields no further document.
    ///
    /// # Errors
    /// Implementation-defined; typically parse/compose failures.
    fn construct_commented(&mut self) -> Result<Option<CommentedValue>>;
    /// Reports whether a further document is expected to be available.
    ///
    /// How precise this answer is depends on the implementation.
    fn check_data(&self) -> bool;
    /// Returns the constructor to its starting state; exactly which state is
    /// reset is implementation-defined.
    fn reset(&mut self);
}
/// Constructor that composes documents with a [`BasicComposer`] and then checks
/// every constructed value against its [`Limits`] before returning it.
#[derive(Debug)]
pub struct SafeConstructor {
    /// Source of composed documents.
    composer: BasicComposer,
    // Set to `Position::start()` on creation and on `reset`; not otherwise read
    // in this file.
    position: Position,
    /// Collection-size and string-length limits checked on every constructed value.
    limits: Limits,
}
impl SafeConstructor {
pub fn new(input: String) -> Self {
Self::with_limits(input, Limits::default())
}
pub fn with_limits(input: String, limits: Limits) -> Self {
let composer = BasicComposer::new_eager_with_limits(input, limits.clone());
let position = Position::start();
Self {
composer,
position,
limits,
}
}
pub fn from_composer(composer: BasicComposer) -> Self {
let position = Position::start();
let limits = Limits::default();
Self {
composer,
position,
limits,
}
}
pub fn from_composer_with_limits(composer: BasicComposer, limits: Limits) -> Self {
let position = Position::start();
Self {
composer,
position,
limits,
}
}
fn validate_value(&self, value: Value) -> Result<Value> {
match value {
Value::Null | Value::Bool(_) | Value::Int(_) | Value::Float(_) | Value::String(_) => {
Ok(value)
}
Value::Sequence(seq) => {
if seq.len() > self.limits.max_collection_size {
return Err(Error::limit_exceeded(format!(
"Sequence size {} exceeds max_collection_size limit of {}",
seq.len(),
self.limits.max_collection_size
)));
}
let mut safe_seq = Vec::with_capacity(seq.len());
for item in seq {
safe_seq.push(self.validate_value(item)?);
}
Ok(Value::Sequence(safe_seq))
}
Value::Mapping(map) => {
if map.len() > self.limits.max_collection_size {
return Err(Error::limit_exceeded(format!(
"Mapping size {} exceeds max_collection_size limit of {}",
map.len(),
self.limits.max_collection_size
)));
}
let mut safe_map = indexmap::IndexMap::new();
for (key, val) in map {
let safe_key = self.validate_value(key)?;
let safe_val = self.validate_value(val)?;
safe_map.insert(safe_key, safe_val);
}
Ok(Value::Mapping(safe_map))
}
}
}
fn apply_safety_rules(&self, value: Value) -> Result<Value> {
match value {
Value::String(ref s) if s.len() > self.limits.max_string_length => {
Err(Error::limit_exceeded(format!(
"String too long: {} bytes (max: {})",
s.len(),
self.limits.max_string_length
)))
}
Value::Sequence(ref seq) if seq.len() > self.limits.max_collection_size => {
Err(Error::limit_exceeded(format!(
"Sequence too long: {} elements (max: {})",
seq.len(),
self.limits.max_collection_size
)))
}
Value::Mapping(ref map) if map.len() > self.limits.max_collection_size => {
Err(Error::limit_exceeded(format!(
"Mapping too large: {} entries (max: {})",
map.len(),
self.limits.max_collection_size
)))
}
Value::Sequence(seq) => {
let mut safe_seq = Vec::with_capacity(seq.len());
for item in seq {
safe_seq.push(self.apply_safety_rules(item)?);
}
Ok(Value::Sequence(safe_seq))
}
Value::Mapping(map) => {
let mut safe_map = indexmap::IndexMap::new();
for (key, val) in map {
let safe_key = self.apply_safety_rules(key)?;
let safe_val = self.apply_safety_rules(val)?;
safe_map.insert(safe_key, safe_val);
}
Ok(Value::Mapping(safe_map))
}
_ => Ok(value),
}
}
}
impl Default for SafeConstructor {
    /// An empty-input constructor with default limits, equivalent to
    /// `SafeConstructor::new` on an empty string.
    fn default() -> Self {
        let empty_input = String::default();
        Self::new(empty_input)
    }
}
impl Constructor for SafeConstructor {
    /// Composes the next document and runs it through both limit passes.
    ///
    /// Returns `Ok(None)` once the composer yields no document.
    ///
    /// # Errors
    /// Propagates composer errors, then any violation reported by `validate_value`
    /// (collection sizes) or `apply_safety_rules` (string lengths and collection sizes).
    fn construct(&mut self) -> Result<Option<Value>> {
        self.composer
            .compose_document()?
            .map(|document| {
                self.validate_value(document)
                    .and_then(|checked| self.apply_safety_rules(checked))
            })
            .transpose()
    }

    /// Delegates to the composer's `check_document`.
    fn check_data(&self) -> bool {
        self.composer.check_document()
    }

    /// Resets the composer, then rewinds the stored position to the start.
    fn reset(&mut self) {
        let composer = &mut self.composer;
        composer.reset();
        self.position = Position::start();
    }
}
/// Constructor that yields comment-carrying [`CommentedValue`] documents from a
/// [`CommentPreservingComposer`], for round-trip (edit-and-re-emit) use.
#[derive(Debug)]
pub struct RoundTripConstructor {
    /// Source of composed, comment-carrying documents.
    composer: CommentPreservingComposer,
    // Set to `Position::start()` on creation and on `reset`; not otherwise read
    // in this file.
    position: Position,
    // A copy of these limits is handed to the composer; the stored value is not
    // otherwise read in this file. NOTE(review): confirm whether constructor-side
    // checks (as in `SafeConstructor`) were intended.
    limits: Limits,
}
impl RoundTripConstructor {
    /// Creates a round-trip constructor over `input` with [`Limits::default`].
    pub fn new(input: String) -> Self {
        let limits = Limits::default();
        Self::with_limits(input, limits)
    }

    /// Creates a round-trip constructor over `input`. A copy of `limits` is
    /// forwarded to the comment-preserving composer, and the original is kept on `self`.
    pub fn with_limits(input: String, limits: Limits) -> Self {
        Self {
            composer: CommentPreservingComposer::with_limits(input, limits.clone()),
            position: Position::start(),
            limits,
        }
    }

    /// Composes the next document, comments included, straight from the composer.
    fn parse_with_comments(&mut self) -> Result<Option<CommentedValue>> {
        let composer = &mut self.composer;
        composer.compose_document()
    }
}
impl CommentPreservingConstructor for RoundTripConstructor {
    /// Composes the next document with its comments attached.
    ///
    /// The composer's result is returned as-is; unlike `SafeConstructor::construct`,
    /// no additional limit checks are applied here.
    fn construct_commented(&mut self) -> Result<Option<CommentedValue>> {
        self.parse_with_comments()
    }
    /// Always reports `true`.
    ///
    /// NOTE(review): this does not consult the composer, so it keeps returning `true`
    /// even after the input is exhausted. Confirm whether `CommentPreservingComposer`
    /// exposes a `check_document`-style query to delegate to.
    fn check_data(&self) -> bool {
        true
    }
    /// Rewinds the stored position to the start.
    ///
    /// NOTE(review): the composer itself is not reset, so later `construct_commented`
    /// calls presumably continue from where parsing stopped. Confirm this is intended.
    fn reset(&mut self) {
        self.position = Position::start();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs the first document of `input` with default limits. Panics if
    /// construction fails or yields no document.
    fn construct_one(input: &str) -> Value {
        SafeConstructor::new(input.to_string())
            .construct()
            .expect("construction should succeed")
            .unwrap_or_else(|| panic!("expected a document for input: {}", input))
    }

    /// Shorthand for building a `Value::String`.
    fn string(s: &str) -> Value {
        Value::String(s.to_string())
    }

    #[test]
    fn test_safe_scalar_construction() {
        assert_eq!(construct_one("42"), Value::Int(42));
    }

    #[test]
    fn test_safe_sequence_construction() {
        let expected = Value::Sequence(vec![Value::Int(1), Value::Int(2), Value::Int(3)]);
        assert_eq!(construct_one("[1, 2, 3]"), expected);
    }

    #[test]
    fn test_safe_mapping_construction() {
        let mut expected_map = indexmap::IndexMap::new();
        expected_map.insert(string("key"), string("value"));
        assert_eq!(
            construct_one("{'key': 'value'}"),
            Value::Mapping(expected_map)
        );
    }

    #[test]
    fn test_nested_construction() {
        let yaml_content = "{'users': [{'name': 'Alice', 'age': 30}]}";
        // Each level must have the expected shape; a mismatch fails the test
        // instead of silently skipping the assertions.
        let map = match construct_one(yaml_content) {
            Value::Mapping(map) => map,
            other => panic!("Expected mapping, got {:?}", other),
        };
        let users = match map.get(&string("users")) {
            Some(Value::Sequence(users)) => users,
            other => panic!("Expected 'users' to be a sequence, got {:?}", other),
        };
        assert_eq!(users.len(), 1);
        let user = match &users[0] {
            Value::Mapping(user) => user,
            other => panic!("Expected user entry to be a mapping, got {:?}", other),
        };
        assert_eq!(user.get(&string("name")), Some(&string("Alice")));
        assert_eq!(user.get(&string("age")), Some(&Value::Int(30)));
    }

    #[test]
    fn test_check_data() {
        let constructor = SafeConstructor::new("42".to_string());
        assert!(constructor.check_data());
    }

    #[test]
    fn test_multiple_types() {
        let yaml_content = "{'string': 'hello', 'int': 42, 'bool': true, 'null_key': null}";
        let map = match construct_one(yaml_content) {
            Value::Mapping(map) => map,
            other => panic!("Expected mapping, got {:?}", other),
        };
        assert_eq!(map.get(&string("string")), Some(&string("hello")));
        assert_eq!(map.get(&string("int")), Some(&Value::Int(42)));
        assert_eq!(map.get(&string("bool")), Some(&Value::Bool(true)));
        assert_eq!(map.get(&string("null_key")), Some(&Value::Null));
    }

    #[test]
    fn test_safety_limits() {
        // A string longer than `max_string_length` must be rejected, whichever
        // layer (composer or constructor) catches it first.
        let mut limits = Limits::default();
        limits.max_string_length = 10;
        let large_string = "a".repeat(1000);
        let yaml_content = format!("value: '{}'", large_string);
        let mut constructor = SafeConstructor::with_limits(yaml_content, limits);
        assert!(constructor.construct().is_err());
    }

    #[test]
    fn test_collection_size_limit() {
        // A sequence with more elements than `max_collection_size` must be rejected.
        let mut limits = Limits::default();
        limits.max_collection_size = 2;
        let mut constructor = SafeConstructor::with_limits("[1, 2, 3]".to_string(), limits);
        assert!(constructor.construct().is_err());
    }

    #[test]
    fn test_boolean_values() {
        let test_cases = [
            ("true", true),
            ("false", false),
            ("yes", true),
            ("no", false),
            ("on", true),
            ("off", false),
        ];
        for (input, expected) in test_cases.iter().copied() {
            assert_eq!(
                construct_one(input),
                Value::Bool(expected),
                "Failed for input: {}",
                input
            );
        }
    }

    #[test]
    fn test_null_values() {
        for input in ["null", "~"].iter().copied() {
            assert_eq!(construct_one(input), Value::Null, "Failed for input: {}", input);
        }
    }
}