Skip to main content

cfn_guard/rules/
values.rs

1use std::{
2    convert::TryFrom,
3    fmt,
4    fmt::Display,
5    hash::{Hash, Hasher},
6};
7
8use indexmap::map::IndexMap;
9use nom::lib::std::fmt::Formatter;
10
11use crate::rules::{
12    errors::{Error, InternalError},
13    libyaml::loader::Loader,
14    parser::Span,
15    path_value::Location,
16    short_form_to_long, SEQUENCE_VALUE_FUNC_REF, SINGLE_VALUE_FUNC_REF,
17};
18
19use serde::{Deserialize, Serialize};
20
/// Comparison operators that can appear in a rule clause.
///
/// `Display` renders each operator as an upper-case label (e.g. `Ge` is shown as
/// `GREATER THAN EQUALS`), and [`CmpOperator::is_unary`] reports which operators are unary.
#[derive(Eq, PartialEq, Debug, Clone, Serialize, Deserialize, Hash, Copy)]
pub enum CmpOperator {
    // Binary comparisons (`is_unary` returns false for these).
    Eq,
    In,
    Gt,
    Lt,
    Le,
    Ge,
    // Unary checks (`is_unary` returns true).
    Exists,
    Empty,

    // Unary type checks (`is_unary` returns true).
    IsString,
    IsList,
    IsMap,
    IsBool,
    IsInt,
    IsFloat,
    IsNull,
}
40
41impl CmpOperator {
42    pub(crate) fn is_unary(&self) -> bool {
43        matches!(
44            self,
45            CmpOperator::Exists
46                | CmpOperator::Empty
47                | CmpOperator::IsString
48                | CmpOperator::IsBool
49                | CmpOperator::IsList
50                | CmpOperator::IsInt
51                | CmpOperator::IsMap
52                | CmpOperator::IsFloat
53                | CmpOperator::IsNull
54        )
55    }
56}
57
58impl Display for CmpOperator {
59    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
60        match self {
61            CmpOperator::Eq => f.write_str("EQUALS")?,
62            CmpOperator::In => f.write_str("IN")?,
63            CmpOperator::Gt => f.write_str("GREATER THAN")?,
64            CmpOperator::Lt => f.write_str("LESS THAN")?,
65            CmpOperator::Ge => f.write_str("GREATER THAN EQUALS")?,
66            CmpOperator::Le => f.write_str("LESS THAN EQUALS")?,
67            CmpOperator::Exists => f.write_str("EXISTS")?,
68            CmpOperator::Empty => f.write_str("EMPTY")?,
69            CmpOperator::IsString => f.write_str("IS STRING")?,
70            CmpOperator::IsBool => f.write_str("IS BOOL")?,
71            CmpOperator::IsInt => f.write_str("IS INT")?,
72            CmpOperator::IsList => f.write_str("IS LIST")?,
73            CmpOperator::IsMap => f.write_str("IS MAP")?,
74            CmpOperator::IsNull => f.write_str("IS NULL")?,
75            CmpOperator::IsFloat => f.write_str("IS FLOAT")?,
76        }
77        Ok(())
78    }
79}
80
/// The common in-memory value model.
///
/// Values are produced by the `TryFrom` conversions from `serde_yaml::Value`,
/// `serde_json::Value` and rule-source text (`&str`, via the parser's `parse_value`).
/// Only `PartialEq` is derived: the `f64` payloads rule out `Eq`.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum Value {
    Null,
    // Displayed wrapped in double quotes.
    String(String),
    // Regex pattern text; displayed as `/pattern/`.
    Regex(String),
    Bool(bool),
    Int(i64),
    Float(f64),
    Char(char),
    List(Vec<Value>),
    // String-keyed map that keeps insertion order (`IndexMap`).
    Map(indexmap::IndexMap<String, Value>),
    // Intervals; `RangeType::inclusive` holds the open/closed bound flags.
    RangeInt(RangeType<i64>),
    RangeFloat(RangeType<f64>),
    RangeChar(RangeType<char>),
}
96
97impl Hash for Value {
98    fn hash<H: Hasher>(&self, state: &mut H) {
99        match self {
100            Value::String(s) | Value::Regex(s) => {
101                s.hash(state);
102            }
103
104            Value::Char(c) => {
105                c.hash(state);
106            }
107            Value::Int(i) => {
108                i.hash(state);
109            }
110            Value::Null => {
111                "NULL".hash(state);
112            }
113            Value::Float(f) => {
114                (*f as u64).hash(state);
115            }
116
117            Value::RangeChar(r) => {
118                r.lower.hash(state);
119                r.upper.hash(state);
120                r.inclusive.hash(state);
121            }
122
123            Value::RangeInt(r) => {
124                r.lower.hash(state);
125                r.upper.hash(state);
126                r.inclusive.hash(state);
127            }
128
129            Value::RangeFloat(r) => {
130                (r.lower as u64).hash(state);
131                (r.upper as u64).hash(state);
132                r.inclusive.hash(state);
133            }
134
135            Value::Bool(b) => {
136                b.hash(state);
137            }
138
139            Value::List(l) => {
140                for each in l {
141                    each.hash(state);
142                }
143            }
144
145            Value::Map(map) => {
146                for (key, value) in map.iter() {
147                    key.hash(state);
148                    value.hash(state);
149                }
150            }
151        }
152    }
153}
154
155impl Display for Value {
156    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
157        match self {
158            Value::String(s) => write!(f, "\"{}\"", s),
159            Value::Regex(s) => write!(f, "/{}/", s),
160            Value::Int(int) => write!(f, "{}", int),
161            Value::Float(float) => write!(f, "{}", float),
162            Value::Bool(bool) => write!(f, "{}", bool),
163            Value::List(list) => {
164                let result: Vec<String> = list.iter().map(|item| format!("{}", item)).collect();
165                write!(f, "[{}]", result.join(", "))
166            }
167            Value::Map(map) => {
168                let key_values: Vec<String> = map
169                    .into_iter()
170                    .map(|(key, value)| format!("\"{}\": {}", key, value))
171                    .collect();
172                write!(f, "{{{}}}", key_values.join(", "))
173            }
174            Value::Null => {
175                write!(f, "null")
176            }
177            Value::RangeChar(range) => {
178                if (range.inclusive & LOWER_INCLUSIVE) == LOWER_INCLUSIVE {
179                    write!(f, "[")?;
180                } else {
181                    write!(f, "(")?;
182                }
183                write!(f, "{},{}", range.lower, range.upper)?;
184
185                if (range.inclusive & UPPER_INCLUSIVE) == UPPER_INCLUSIVE {
186                    write!(f, "]")
187                } else {
188                    write!(f, ")")
189                }
190            }
191            Value::RangeFloat(range) => {
192                if (range.inclusive & LOWER_INCLUSIVE) == LOWER_INCLUSIVE {
193                    write!(f, "[")?;
194                } else {
195                    write!(f, "(")?;
196                }
197                write!(f, "{},{}", range.lower, range.upper)?;
198
199                if (range.inclusive & UPPER_INCLUSIVE) == UPPER_INCLUSIVE {
200                    write!(f, "]")
201                } else {
202                    write!(f, ")")
203                }
204            }
205            Value::RangeInt(range) => {
206                if (range.inclusive & LOWER_INCLUSIVE) == LOWER_INCLUSIVE {
207                    write!(f, "[")?;
208                } else {
209                    write!(f, "(")?;
210                }
211                write!(f, "{},{}", range.lower, range.upper)?;
212
213                if (range.inclusive & UPPER_INCLUSIVE) == UPPER_INCLUSIVE {
214                    write!(f, "]")
215                } else {
216                    write!(f, ")")
217                }
218            }
219            Value::Char(c) => {
220                write!(f, "\"{}\"", c)
221            }
222        }
223    }
224}
225
/// A bounded interval over an ordered type `T`, used for range literals.
///
/// Rule syntax this relates to:
///
/// ```text
///    .X > 10
///    .X <= 20
///
///    .X in r(10, 20]
///    .X in r(10, 20)
/// ```
///
/// `inclusive` is a bit set built from [`LOWER_INCLUSIVE`] and [`UPPER_INCLUSIVE`].
/// A set bit makes that bound closed (`<=` / `>=`, displayed as `[` / `]`). A clear bit
/// makes it open (`<` / `>`, displayed as `(` / `)`).
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct RangeType<T: PartialOrd> {
    pub upper: T,
    pub lower: T,
    pub inclusive: u8,
}

/// Bit in [`RangeType::inclusive`] marking the lower bound as inclusive.
pub const LOWER_INCLUSIVE: u8 = 0x01;
/// Bit in [`RangeType::inclusive`] marking the upper bound as inclusive.
pub const UPPER_INCLUSIVE: u8 = 0x01 << 1;
241
242pub(crate) trait WithinRange<RHS: PartialOrd = Self> {
243    fn is_within(&self, range: &RangeType<RHS>) -> bool;
244}
245
246impl WithinRange for i64 {
247    fn is_within(&self, range: &RangeType<i64>) -> bool {
248        is_within(range, self)
249    }
250}
251
252impl WithinRange for f64 {
253    fn is_within(&self, range: &RangeType<f64>) -> bool {
254        is_within(range, self)
255    }
256}
257
258impl WithinRange for char {
259    fn is_within(&self, range: &RangeType<char>) -> bool {
260        is_within(range, self)
261    }
262}
263
264//impl WithinRange for
265
266fn is_within<T: PartialOrd>(range: &RangeType<T>, other: &T) -> bool {
267    let lower = if (range.inclusive & LOWER_INCLUSIVE) > 0 {
268        range.lower.le(other)
269    } else {
270        range.lower.lt(other)
271    };
272    let upper = if (range.inclusive & UPPER_INCLUSIVE) > 0 {
273        range.upper.ge(other)
274    } else {
275        range.upper.gt(other)
276    };
277    lower && upper
278}
279
280impl<'a> TryFrom<&'a serde_yaml::Value> for Value {
281    type Error = Error;
282
283    fn try_from(value: &'a serde_yaml::Value) -> Result<Self, Self::Error> {
284        match value {
285            serde_yaml::Value::String(s) => Ok(Value::String(s.to_owned())),
286            serde_yaml::Value::Number(num) => {
287                if num.is_i64() {
288                    Ok(Value::Int(num.as_i64().unwrap()))
289                } else if num.is_u64() {
290                    //
291                    // Yes we are losing precision here. TODO fix this
292                    //
293                    Ok(Value::Int(num.as_u64().unwrap() as i64))
294                } else {
295                    Ok(Value::Float(num.as_f64().unwrap()))
296                }
297            }
298            serde_yaml::Value::Bool(b) => Ok(Value::Bool(*b)),
299            serde_yaml::Value::Sequence(sequence) => Ok(Value::List(sequence.iter().try_fold(
300                vec![],
301                |mut res, val| -> Result<Vec<Self>, Self::Error> {
302                    res.push(Value::try_from(val)?);
303                    Ok(res)
304                },
305            )?)),
306            serde_yaml::Value::Mapping(mapping) => Ok(Value::Map(mapping.iter().try_fold(
307                IndexMap::with_capacity(mapping.len()),
308                |mut res, (key, val)| -> Result<IndexMap<String, Self>, Self::Error> {
309                    match key {
310                        serde_yaml::Value::String(key) => {
311                            res.insert(key.to_string(), Value::try_from(val)?);
312                        }
313                        _ => {
314                            // NOTE: can't provide a location for our error here since serde_yaml
315                            // doesn't provide that for us
316                            return Err(Error::InternalError(InternalError::InvalidKeyType(
317                                String::default(),
318                            )));
319                        }
320                    }
321                    Ok(res)
322                },
323            )?)),
324            serde_yaml::Value::Tagged(tag) => {
325                let prefix = tag.tag.to_string();
326                let value = tag.value.clone();
327
328                match prefix.matches('!').count() {
329                    1 => {
330                        let stripped_prefix = prefix.strip_prefix('!').unwrap();
331                        Ok(handle_tagged_value(value, stripped_prefix)?)
332                    }
333                    _ => Ok(Value::try_from(value)?),
334                }
335            }
336            serde_yaml::Value::Null => Ok(Value::Null),
337        }
338    }
339}
340
341impl<'a> TryFrom<&'a serde_json::Value> for Value {
342    type Error = Error;
343
344    fn try_from(value: &'a serde_json::Value) -> Result<Self, Self::Error> {
345        match value {
346            serde_json::Value::String(s) => Ok(Value::String(s.to_owned())),
347            serde_json::Value::Number(num) => {
348                if num.is_i64() {
349                    Ok(Value::Int(num.as_i64().unwrap()))
350                } else if num.is_u64() {
351                    //
352                    // Yes we are losing precision here. TODO fix this
353                    //
354                    Ok(Value::Int(num.as_u64().unwrap() as i64))
355                } else {
356                    Ok(Value::Float(num.as_f64().unwrap()))
357                }
358            }
359            serde_json::Value::Bool(b) => Ok(Value::Bool(*b)),
360            serde_json::Value::Null => Ok(Value::Null),
361            serde_json::Value::Array(v) => {
362                let mut result: Vec<Value> = Vec::with_capacity(v.len());
363                for each in v {
364                    result.push(Value::try_from(each)?)
365                }
366                Ok(Value::List(result))
367            }
368            serde_json::Value::Object(map) => {
369                let mut result = IndexMap::with_capacity(map.len());
370                for (key, value) in map.iter() {
371                    result.insert(key.to_owned(), Value::try_from(value)?);
372                }
373                Ok(Value::Map(result))
374            }
375        }
376    }
377}
378
379impl TryFrom<serde_json::Value> for Value {
380    type Error = Error;
381
382    fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
383        Value::try_from(&value)
384    }
385}
386
387impl TryFrom<serde_yaml::Value> for Value {
388    type Error = Error;
389
390    fn try_from(value: serde_yaml::Value) -> Result<Self, Self::Error> {
391        Value::try_from(&value)
392    }
393}
394
395impl<'a> TryFrom<&'a str> for Value {
396    type Error = Error;
397
398    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
399        Ok(super::parser::parse_value(Span::new_extra(value, ""))?.1)
400    }
401}
402
/// A [`Value`]-like tree in which every node (and every map key) carries the
/// [`Location`] it was read from.
///
/// `read_from` produces one of these via the libyaml `Loader`.
#[derive(PartialEq, Debug, Clone)]
// NOTE(review): presumably some variants or fields are not read in every build
// configuration, hence the lint allowance — confirm which and narrow it.
#[allow(dead_code)]
pub(crate) enum MarkedValue {
    Null(Location),
    // NOTE(review): appears to hold raw text the loader could not classify — confirm in libyaml::loader.
    BadValue(String, Location),
    String(String, Location),
    Regex(String, Location),
    Bool(bool, Location),
    Int(i64, Location),
    Float(f64, Location),
    Char(char, Location),
    List(Vec<MarkedValue>, Location),
    // Keys are (name, location-of-key) pairs; the outer Location belongs to the map itself.
    Map(
        indexmap::IndexMap<(String, Location), MarkedValue>,
        Location,
    ),
    RangeInt(RangeType<i64>, Location),
    RangeFloat(RangeType<f64>, Location),
    RangeChar(RangeType<char>, Location),
}
423
424impl MarkedValue {
425    pub(crate) fn location(&self) -> &Location {
426        match self {
427            Self::Null(loc)
428            | Self::BadValue(_, loc)
429            | Self::String(_, loc)
430            | Self::Regex(_, loc)
431            | Self::Bool(_, loc)
432            | Self::Int(_, loc)
433            | Self::Float(_, loc)
434            | Self::Char(_, loc)
435            | Self::List(_, loc)
436            | Self::Map(_, loc)
437            | Self::RangeInt(_, loc)
438            | Self::RangeFloat(_, loc)
439            | Self::RangeChar(_, loc) => loc,
440        }
441    }
442}
443
444pub(crate) fn read_from(from_reader: &str) -> crate::rules::Result<MarkedValue> {
445    let mut loader = Loader::new();
446    match loader.load(from_reader.to_string()) {
447        Ok(doc) => Ok(doc),
448        Err(e) => match e {
449            Error::InternalError(..) => Err(e),
450            _ => Err(Error::ParseError(format!("{}", e))),
451        },
452    }
453}
454
/// Test helper: builds an insertion-ordered map from `(&str, Value)` pairs.
///
/// A repeated key keeps its first position and takes the later value.
#[cfg(test)]
pub(super) fn make_linked_hashmap<'a, I>(values: I) -> IndexMap<String, Value>
where
    I: IntoIterator<Item = (&'a str, Value)>,
{
    let mut linked = IndexMap::new();
    for (name, value) in values {
        linked.insert(String::from(name), value);
    }
    linked
}
462
463fn handle_tagged_value(val: serde_yaml::Value, fn_ref: &str) -> crate::rules::Result<Value> {
464    if SINGLE_VALUE_FUNC_REF.contains(fn_ref) || SEQUENCE_VALUE_FUNC_REF.contains(fn_ref) {
465        let mut map = indexmap::IndexMap::new();
466        let fn_ref = short_form_to_long(fn_ref);
467        map.insert(fn_ref.to_string(), Value::try_from(val)?);
468
469        return Ok(Value::Map(map));
470    }
471
472    Value::try_from(val)
473}
474
475#[cfg(test)]
476#[path = "values_tests.rs"]
477mod values_tests;