rust_yaml/
constructor.rs

//! YAML constructor that turns composed documents into `Value` trees
2
3use crate::{BasicComposer, Composer, Error, Limits, Position, Result, Value};
4
5/// Trait for YAML constructors that convert document nodes to Rust objects
6pub trait Constructor {
7    /// Construct a single value
8    fn construct(&mut self) -> Result<Option<Value>>;
9
10    /// Check if there are more values to construct
11    fn check_data(&self) -> bool;
12
13    /// Reset the constructor state
14    fn reset(&mut self);
15}
16
17/// Safe constructor that only constructs basic YAML types
18#[derive(Debug)]
19pub struct SafeConstructor {
20    composer: BasicComposer,
21    position: Position,
22    limits: Limits,
23}
24
25impl SafeConstructor {
26    /// Create a new safe constructor with input text
27    pub fn new(input: String) -> Self {
28        Self::with_limits(input, Limits::default())
29    }
30
31    /// Create a new safe constructor with custom limits
32    pub fn with_limits(input: String, limits: Limits) -> Self {
33        // Use eager composer for better anchor/alias support
34        let composer = BasicComposer::new_eager_with_limits(input, limits.clone());
35        let position = Position::start();
36
37        Self {
38            composer,
39            position,
40            limits,
41        }
42    }
43
44    /// Create constructor from existing composer
45    pub fn from_composer(composer: BasicComposer) -> Self {
46        let position = Position::start();
47        let limits = Limits::default();
48
49        Self {
50            composer,
51            position,
52            limits,
53        }
54    }
55
56    /// Create constructor from existing composer with custom limits
57    pub fn from_composer_with_limits(composer: BasicComposer, limits: Limits) -> Self {
58        let position = Position::start();
59
60        Self {
61            composer,
62            position,
63            limits,
64        }
65    }
66
67    /// Validate and potentially transform a value for safety
68    fn validate_value(&self, value: Value) -> Result<Value> {
69        match value {
70            // Basic scalar types are always safe
71            Value::Null | Value::Bool(_) | Value::Int(_) | Value::Float(_) | Value::String(_) => {
72                Ok(value)
73            }
74
75            // Sequences are safe if all elements are safe
76            Value::Sequence(seq) => {
77                // Check collection size limit
78                if seq.len() > self.limits.max_collection_size {
79                    return Err(Error::limit_exceeded(format!(
80                        "Sequence size {} exceeds max_collection_size limit of {}",
81                        seq.len(),
82                        self.limits.max_collection_size
83                    )));
84                }
85                let mut safe_seq = Vec::with_capacity(seq.len());
86                for item in seq {
87                    safe_seq.push(self.validate_value(item)?);
88                }
89                Ok(Value::Sequence(safe_seq))
90            }
91
92            // Mappings are safe if all keys and values are safe
93            Value::Mapping(map) => {
94                // Check collection size limit
95                if map.len() > self.limits.max_collection_size {
96                    return Err(Error::limit_exceeded(format!(
97                        "Mapping size {} exceeds max_collection_size limit of {}",
98                        map.len(),
99                        self.limits.max_collection_size
100                    )));
101                }
102                let mut safe_map = indexmap::IndexMap::new();
103                for (key, val) in map {
104                    let safe_key = self.validate_value(key)?;
105                    let safe_val = self.validate_value(val)?;
106                    safe_map.insert(safe_key, safe_val);
107                }
108                Ok(Value::Mapping(safe_map))
109            }
110        }
111    }
112
113    /// Apply additional safety checks and transformations
114    fn apply_safety_rules(&self, value: Value) -> Result<Value> {
115        match value {
116            // Limit string length to prevent memory exhaustion
117            Value::String(ref s) if s.len() > self.limits.max_string_length => {
118                Err(Error::limit_exceeded(format!(
119                    "String too long: {} bytes (max: {})",
120                    s.len(),
121                    self.limits.max_string_length
122                )))
123            }
124
125            // Limit sequence length
126            Value::Sequence(ref seq) if seq.len() > self.limits.max_collection_size => {
127                Err(Error::limit_exceeded(format!(
128                    "Sequence too long: {} elements (max: {})",
129                    seq.len(),
130                    self.limits.max_collection_size
131                )))
132            }
133
134            // Limit mapping size
135            Value::Mapping(ref map) if map.len() > self.limits.max_collection_size => {
136                Err(Error::limit_exceeded(format!(
137                    "Mapping too large: {} entries (max: {})",
138                    map.len(),
139                    self.limits.max_collection_size
140                )))
141            }
142
143            // Recursively apply rules
144            Value::Sequence(seq) => {
145                let mut safe_seq = Vec::with_capacity(seq.len());
146                for item in seq {
147                    safe_seq.push(self.apply_safety_rules(item)?);
148                }
149                Ok(Value::Sequence(safe_seq))
150            }
151
152            Value::Mapping(map) => {
153                let mut safe_map = indexmap::IndexMap::new();
154                for (key, val) in map {
155                    let safe_key = self.apply_safety_rules(key)?;
156                    let safe_val = self.apply_safety_rules(val)?;
157                    safe_map.insert(safe_key, safe_val);
158                }
159                Ok(Value::Mapping(safe_map))
160            }
161
162            // Other types are fine as-is
163            _ => Ok(value),
164        }
165    }
166}
167
168impl Default for SafeConstructor {
169    fn default() -> Self {
170        Self::new(String::new())
171    }
172}
173
174impl Constructor for SafeConstructor {
175    fn construct(&mut self) -> Result<Option<Value>> {
176        // Get a document from the composer
177        let document = match self.composer.compose_document()? {
178            Some(doc) => doc,
179            None => return Ok(None),
180        };
181
182        // Validate and apply safety rules
183        let validated = self.validate_value(document)?;
184        let safe_value = self.apply_safety_rules(validated)?;
185
186        Ok(Some(safe_value))
187    }
188
189    fn check_data(&self) -> bool {
190        self.composer.check_document()
191    }
192
193    fn reset(&mut self) {
194        self.composer.reset();
195        self.position = Position::start();
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Default limits with one or more fields adjusted by `adjust`.
    ///
    /// Going through a closure avoids both clippy's
    /// `field_reassign_with_default` and struct-update syntax (which would
    /// fail to compile if `Limits` ever gains a field hidden from this module).
    fn limits_with(adjust: impl FnOnce(&mut Limits)) -> Limits {
        let mut limits = Limits::default();
        adjust(&mut limits);
        limits
    }

    #[test]
    fn test_safe_scalar_construction() {
        let mut constructor = SafeConstructor::new("42".to_string());
        let result = constructor.construct().unwrap().unwrap();
        assert_eq!(result, Value::Int(42));
    }

    #[test]
    fn test_safe_sequence_construction() {
        let mut constructor = SafeConstructor::new("[1, 2, 3]".to_string());
        let result = constructor.construct().unwrap().unwrap();

        let expected = Value::Sequence(vec![Value::Int(1), Value::Int(2), Value::Int(3)]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_safe_mapping_construction() {
        let mut constructor = SafeConstructor::new("{'key': 'value'}".to_string());
        let result = constructor.construct().unwrap().unwrap();

        let mut expected_map = indexmap::IndexMap::new();
        expected_map.insert(
            Value::String("key".to_string()),
            Value::String("value".to_string()),
        );
        let expected = Value::Mapping(expected_map);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_nested_construction() {
        let yaml_content = "{'users': [{'name': 'Alice', 'age': 30}]}";
        let mut constructor = SafeConstructor::new(yaml_content.to_string());
        let result = constructor.construct().unwrap().unwrap();

        // Every shape mismatch must fail the test rather than fall through.
        let map = match result {
            Value::Mapping(map) => map,
            other => panic!("Expected mapping, got {:?}", other),
        };
        let users = match map.get(&Value::String("users".to_string())) {
            Some(Value::Sequence(users)) => users,
            other => panic!("Expected `users` sequence, got {:?}", other),
        };
        assert_eq!(users.len(), 1);
        let user = match &users[0] {
            Value::Mapping(user) => user,
            other => panic!("Expected user mapping, got {:?}", other),
        };
        assert_eq!(
            user.get(&Value::String("name".to_string())),
            Some(&Value::String("Alice".to_string()))
        );
        assert_eq!(
            user.get(&Value::String("age".to_string())),
            Some(&Value::Int(30))
        );
    }

    #[test]
    fn test_check_data() {
        let constructor = SafeConstructor::new("42".to_string());
        assert!(constructor.check_data());
    }

    #[test]
    fn test_multiple_types() {
        let yaml_content = "{'string': 'hello', 'int': 42, 'bool': true, 'null_key': null}";
        let mut constructor = SafeConstructor::new(yaml_content.to_string());
        let result = constructor.construct().unwrap().unwrap();

        if let Value::Mapping(map) = result {
            assert_eq!(
                map.get(&Value::String("string".to_string())),
                Some(&Value::String("hello".to_string()))
            );
            assert_eq!(
                map.get(&Value::String("int".to_string())),
                Some(&Value::Int(42))
            );
            assert_eq!(
                map.get(&Value::String("bool".to_string())),
                Some(&Value::Bool(true))
            );
            // The key is "null_key" (string) and the value should be null (Null type)
            assert_eq!(
                map.get(&Value::String("null_key".to_string())),
                Some(&Value::Null)
            );
        } else {
            panic!("Expected mapping");
        }
    }

    #[test]
    fn test_safety_limits() {
        // A moderately sized string must pass the default limits intact.
        let large_string = "a".repeat(1000);
        let yaml_content = format!("value: '{}'", large_string);
        let mut constructor = SafeConstructor::new(yaml_content);

        let value = constructor
            .construct()
            .expect("a 1000-byte string should be within the default limits")
            .expect("document should not be empty");

        let mut expected = indexmap::IndexMap::new();
        expected.insert(
            Value::String("value".to_string()),
            Value::String(large_string),
        );
        assert_eq!(value, Value::Mapping(expected));
    }

    #[test]
    fn test_sequence_over_collection_limit_is_rejected() {
        let limits = limits_with(|l| l.max_collection_size = 2);
        let mut constructor = SafeConstructor::with_limits("[1, 2, 3]".to_string(), limits);
        assert!(constructor.construct().is_err());
    }

    #[test]
    fn test_mapping_over_collection_limit_is_rejected() {
        let limits = limits_with(|l| l.max_collection_size = 2);
        let mut constructor =
            SafeConstructor::with_limits("{'a': 1, 'b': 2, 'c': 3}".to_string(), limits);
        assert!(constructor.construct().is_err());
    }

    #[test]
    fn test_string_over_length_limit_is_rejected() {
        let limits = limits_with(|l| l.max_string_length = 10);
        let yaml_content = format!("'{}'", "a".repeat(20));
        let mut constructor = SafeConstructor::with_limits(yaml_content, limits);
        assert!(constructor.construct().is_err());
    }

    #[test]
    fn test_boolean_values() {
        let test_cases = vec![
            ("true", true),
            ("false", false),
            ("yes", true),
            ("no", false),
            ("on", true),
            ("off", false),
        ];

        for (input, expected) in test_cases {
            let mut constructor = SafeConstructor::new(input.to_string());
            let result = constructor.construct().unwrap().unwrap();
            assert_eq!(result, Value::Bool(expected), "Failed for input: {}", input);
        }
    }

    #[test]
    fn test_null_values() {
        let test_cases = vec!["null", "~"];

        for input in test_cases {
            let mut constructor = SafeConstructor::new(input.to_string());
            let result = constructor.construct().unwrap().unwrap();
            assert_eq!(result, Value::Null, "Failed for input: {}", input);
        }
    }
}