// insta/redaction.rs

1use pest::Parser;
2use pest_derive::Parser;
3use std::borrow::Cow;
4use std::fmt;
5
6use crate::content::Content;
7
8#[derive(Debug)]
9pub struct SelectorParseError(Box<pest::error::Error<Rule>>);
10
11impl SelectorParseError {
12    /// Return the column of where the error occurred.
13    pub fn column(&self) -> usize {
14        match self.0.line_col {
15            pest::error::LineColLocation::Pos((_, col)) => col,
16            pest::error::LineColLocation::Span((_, col), _) => col,
17        }
18    }
19}
20
21/// Represents a path for a callback function.
22///
23/// This can be converted into a string with `to_string` to see a stringified
24/// path that the selector matched.
25#[derive(Clone, Debug)]
26#[cfg_attr(docsrs, doc(cfg(feature = "redactions")))]
27pub struct ContentPath<'a>(&'a [PathItem]);
28
29impl fmt::Display for ContentPath<'_> {
30    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
31        for item in self.0.iter() {
32            write!(f, ".")?;
33            match *item {
34                PathItem::Content(ref ctx) => {
35                    if let Some(s) = ctx.as_str() {
36                        write!(f, "{s}")?;
37                    } else {
38                        write!(f, "<content>")?;
39                    }
40                }
41                PathItem::Field(name) => write!(f, "{name}")?,
42                PathItem::Index(idx, _) => write!(f, "{idx}")?,
43            }
44        }
45        Ok(())
46    }
47}
48
49/// Replaces a value with another one.
50///
51/// Represents a redaction.
52#[cfg_attr(docsrs, doc(cfg(feature = "redactions")))]
53pub enum Redaction {
54    /// Static redaction with new content.
55    Static(Content),
56    /// Redaction with new content.
57    Dynamic(Box<dyn Fn(Content, ContentPath<'_>) -> Content + Sync + Send>),
58}
59
60macro_rules! impl_from {
61    ($ty:ty) => {
62        impl From<$ty> for Redaction {
63            fn from(value: $ty) -> Redaction {
64                Redaction::Static(Content::from(value))
65            }
66        }
67    };
68}
69
70impl_from!(());
71impl_from!(bool);
72impl_from!(u8);
73impl_from!(u16);
74impl_from!(u32);
75impl_from!(u64);
76impl_from!(i8);
77impl_from!(i16);
78impl_from!(i32);
79impl_from!(i64);
80impl_from!(f32);
81impl_from!(f64);
82impl_from!(char);
83impl_from!(String);
84impl_from!(Vec<u8>);
85
86impl<'a> From<&'a str> for Redaction {
87    fn from(value: &'a str) -> Redaction {
88        Redaction::Static(Content::from(value))
89    }
90}
91
92impl<'a> From<&'a [u8]> for Redaction {
93    fn from(value: &'a [u8]) -> Redaction {
94        Redaction::Static(Content::from(value))
95    }
96}
97
98/// Creates a dynamic redaction.
99///
100/// This can be used to redact a value with a different value but instead of
101/// statically declaring it a dynamic value can be computed.  This can also
102/// be used to perform assertions before replacing the value.
103///
104/// The closure is passed two arguments: the value as [`Content`]
105/// and the path that was selected (as [`ContentPath`])
106///
107/// Example:
108///
109/// ```rust
110/// # use insta::{Settings, dynamic_redaction};
111/// # let mut settings = Settings::new();
112/// settings.add_redaction(".id", dynamic_redaction(|value, path| {
113///     assert_eq!(path.to_string(), ".id");
114///     assert_eq!(
115///         value
116///             .as_str()
117///             .unwrap()
118///             .chars()
119///             .filter(|&c| c == '-')
120///             .count(),
121///         4
122///     );
123///     "[uuid]"
124/// }));
125/// ```
126#[cfg_attr(docsrs, doc(cfg(feature = "redactions")))]
127pub fn dynamic_redaction<I, F>(func: F) -> Redaction
128where
129    I: Into<Content>,
130    F: Fn(Content, ContentPath<'_>) -> I + Send + Sync + 'static,
131{
132    Redaction::Dynamic(Box::new(move |c, p| func(c, p).into()))
133}
134
135/// Creates a dynamic redaction that sorts the value at the selector.
136///
137/// This is useful to force something like a set or map to be ordered to make
138/// it deterministic.  This is necessary as insta's serialization support is
139/// based on [`serde`] which does not have native set support.  As a result vectors
140/// (which need to retain order) and sets (which should be given a stable order)
141/// look the same.
142///
143/// ```rust
144/// # use insta::{Settings, sorted_redaction};
145/// # let mut settings = Settings::new();
146/// settings.add_redaction(".flags", sorted_redaction());
147/// ```
148#[cfg_attr(docsrs, doc(cfg(feature = "redactions")))]
149pub fn sorted_redaction() -> Redaction {
150    fn sort(mut value: Content, _path: ContentPath) -> Content {
151        match value.resolve_inner_mut() {
152            Content::Seq(ref mut val) => {
153                val.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
154            }
155            Content::Map(ref mut val) => {
156                val.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
157            }
158            Content::Struct(_, ref mut fields)
159            | Content::StructVariant(_, _, _, ref mut fields) => {
160                fields.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
161            }
162            _ => {}
163        }
164        value
165    }
166    dynamic_redaction(sort)
167}
168
169/// Creates a redaction that rounds floating point numbers to a given
170/// number of decimal places.
171///
172/// ```rust
173/// # use insta::{Settings, rounded_redaction};
174/// # let mut settings = Settings::new();
175/// settings.add_redaction(".sum", rounded_redaction(2));
176/// ```
177#[cfg_attr(docsrs, doc(cfg(feature = "redactions")))]
178pub fn rounded_redaction(decimals: usize) -> Redaction {
179    dynamic_redaction(move |value: Content, _path: ContentPath| -> Content {
180        let f = match value.resolve_inner() {
181            Content::F32(f) => *f as f64,
182            Content::F64(f) => *f,
183            _ => return value,
184        };
185        let x = 10f64.powf(decimals as f64);
186        Content::F64((f * x).round() / x)
187    })
188}
189
190impl Redaction {
191    /// Performs the redaction of the value at the given path.
192    fn redact(&self, value: Content, path: &[PathItem]) -> Content {
193        match *self {
194            Redaction::Static(ref new_val) => new_val.clone(),
195            Redaction::Dynamic(ref callback) => callback(value, ContentPath(path)),
196        }
197    }
198}
199
200#[derive(Parser)]
201#[grammar = "select_grammar.pest"]
202pub struct SelectParser;
203
204#[derive(Debug)]
205pub enum PathItem {
206    Content(Content),
207    Field(&'static str),
208    Index(u64, u64),
209}
210
211impl PathItem {
212    fn as_str(&self) -> Option<&str> {
213        match *self {
214            PathItem::Content(ref content) => content.as_str(),
215            PathItem::Field(s) => Some(s),
216            PathItem::Index(..) => None,
217        }
218    }
219
220    fn as_u64(&self) -> Option<u64> {
221        match *self {
222            PathItem::Content(ref content) => content.as_u64(),
223            PathItem::Field(_) => None,
224            PathItem::Index(idx, _) => Some(idx),
225        }
226    }
227
228    fn range_check(&self, start: Option<i64>, end: Option<i64>) -> bool {
229        fn expand_range(sel: i64, len: i64) -> i64 {
230            if sel < 0 {
231                (len + sel).max(0)
232            } else {
233                sel
234            }
235        }
236        let (idx, len) = match *self {
237            PathItem::Index(idx, len) => (idx as i64, len as i64),
238            _ => return false,
239        };
240        match (start, end) {
241            (None, None) => true,
242            (None, Some(end)) => idx < expand_range(end, len),
243            (Some(start), None) => idx >= expand_range(start, len),
244            (Some(start), Some(end)) => {
245                idx >= expand_range(start, len) && idx < expand_range(end, len)
246            }
247        }
248    }
249}
250
/// One segment of a parsed selector, matched against one [`PathItem`]
/// (or, for the deep wildcard, a run of them).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Segment<'a> {
    /// Deep wildcard: spans the path elements between the segments before it
    /// and the segments after it.
    DeepWildcard,
    /// Matches any single path element.
    Wildcard,
    /// Matches a field name or map key by its string form.
    Key(Cow<'a, str>),
    /// Matches a sequence index, or a map key with that integer value.
    Index(u64),
    /// Matches sequence indices in `start..end`; negative bounds are resolved
    /// against the sequence length.
    Range(Option<i64>, Option<i64>),
}
259
260#[derive(Debug, Clone)]
261pub struct Selector<'a> {
262    selectors: Vec<Vec<Segment<'a>>>,
263}
264
265impl<'a> Selector<'a> {
266    pub fn parse(selector: &'a str) -> Result<Selector<'a>, SelectorParseError> {
267        let pair = SelectParser::parse(Rule::selectors, selector)
268            .map_err(Box::new)
269            .map_err(SelectorParseError)?
270            .next()
271            .unwrap();
272        let mut rv = vec![];
273
274        for selector_pair in pair.into_inner() {
275            match selector_pair.as_rule() {
276                Rule::EOI => break,
277                other => assert_eq!(other, Rule::selector),
278            }
279            let mut segments = vec![];
280            let mut have_deep_wildcard = false;
281            for segment_pair in selector_pair.into_inner() {
282                segments.push(match segment_pair.as_rule() {
283                    Rule::identity => continue,
284                    Rule::wildcard => Segment::Wildcard,
285                    Rule::deep_wildcard => {
286                        if have_deep_wildcard {
287                            return Err(SelectorParseError(Box::new(
288                                pest::error::Error::new_from_span(
289                                    pest::error::ErrorVariant::CustomError {
290                                        message: "deep wildcard used twice".into(),
291                                    },
292                                    segment_pair.as_span(),
293                                ),
294                            )));
295                        }
296                        have_deep_wildcard = true;
297                        Segment::DeepWildcard
298                    }
299                    Rule::key => Segment::Key(Cow::Borrowed(&segment_pair.as_str()[1..])),
300                    Rule::subscript => {
301                        let subscript_rule = segment_pair.into_inner().next().unwrap();
302                        match subscript_rule.as_rule() {
303                            Rule::int => Segment::Index(subscript_rule.as_str().parse().unwrap()),
304                            Rule::string => {
305                                let sq = subscript_rule.as_str();
306                                let s = &sq[1..sq.len() - 1];
307                                let mut was_backslash = false;
308                                Segment::Key(if s.bytes().any(|x| x == b'\\') {
309                                    Cow::Owned(
310                                        s.chars()
311                                            .filter_map(|c| {
312                                                let rv = match c {
313                                                    '\\' if !was_backslash => {
314                                                        was_backslash = true;
315                                                        return None;
316                                                    }
317                                                    other => other,
318                                                };
319                                                was_backslash = false;
320                                                Some(rv)
321                                            })
322                                            .collect(),
323                                    )
324                                } else {
325                                    Cow::Borrowed(s)
326                                })
327                            }
328                            _ => unreachable!(),
329                        }
330                    }
331                    Rule::full_range => Segment::Range(None, None),
332                    Rule::range => {
333                        let mut int_rule = segment_pair
334                            .into_inner()
335                            .map(|x| x.as_str().parse().unwrap());
336                        Segment::Range(int_rule.next(), int_rule.next())
337                    }
338                    Rule::range_to => {
339                        let int_rule = segment_pair.into_inner().next().unwrap();
340                        Segment::Range(None, int_rule.as_str().parse().ok())
341                    }
342                    Rule::range_from => {
343                        let int_rule = segment_pair.into_inner().next().unwrap();
344                        Segment::Range(int_rule.as_str().parse().ok(), None)
345                    }
346                    _ => unreachable!(),
347                });
348            }
349            rv.push(segments);
350        }
351
352        Ok(Selector { selectors: rv })
353    }
354
355    pub fn make_static(self) -> Selector<'static> {
356        Selector {
357            selectors: self
358                .selectors
359                .into_iter()
360                .map(|parts| {
361                    parts
362                        .into_iter()
363                        .map(|x| match x {
364                            Segment::Key(x) => Segment::Key(Cow::Owned(x.into_owned())),
365                            Segment::Index(x) => Segment::Index(x),
366                            Segment::Wildcard => Segment::Wildcard,
367                            Segment::DeepWildcard => Segment::DeepWildcard,
368                            Segment::Range(a, b) => Segment::Range(a, b),
369                        })
370                        .collect()
371                })
372                .collect(),
373        }
374    }
375
376    fn segment_is_match(&self, segment: &Segment, element: &PathItem) -> bool {
377        match *segment {
378            Segment::Wildcard => true,
379            Segment::DeepWildcard => true,
380            Segment::Key(ref k) => element.as_str() == Some(k),
381            Segment::Index(i) => element.as_u64() == Some(i),
382            Segment::Range(start, end) => element.range_check(start, end),
383        }
384    }
385
386    fn selector_is_match(&self, selector: &[Segment], path: &[PathItem]) -> bool {
387        if let Some(idx) = selector.iter().position(|x| *x == Segment::DeepWildcard) {
388            let forward_sel = &selector[..idx];
389            let backward_sel = &selector[idx + 1..];
390
391            if path.len() <= idx {
392                return false;
393            }
394
395            for (segment, element) in forward_sel.iter().zip(path.iter()) {
396                if !self.segment_is_match(segment, element) {
397                    return false;
398                }
399            }
400
401            for (segment, element) in backward_sel.iter().rev().zip(path.iter().rev()) {
402                if !self.segment_is_match(segment, element) {
403                    return false;
404                }
405            }
406
407            true
408        } else {
409            if selector.len() != path.len() {
410                return false;
411            }
412            for (segment, element) in selector.iter().zip(path.iter()) {
413                if !self.segment_is_match(segment, element) {
414                    return false;
415                }
416            }
417            true
418        }
419    }
420
421    pub fn is_match(&self, path: &[PathItem]) -> bool {
422        for selector in &self.selectors {
423            if self.selector_is_match(selector, path) {
424                return true;
425            }
426        }
427        false
428    }
429
430    pub fn redact(&self, value: Content, redaction: &Redaction) -> Content {
431        self.redact_impl(value, redaction, &mut vec![])
432    }
433
434    fn redact_seq(
435        &self,
436        seq: Vec<Content>,
437        redaction: &Redaction,
438        path: &mut Vec<PathItem>,
439    ) -> Vec<Content> {
440        let len = seq.len();
441        seq.into_iter()
442            .enumerate()
443            .map(|(idx, value)| {
444                path.push(PathItem::Index(idx as u64, len as u64));
445                let new_value = self.redact_impl(value, redaction, path);
446                path.pop();
447                new_value
448            })
449            .collect()
450    }
451
452    fn redact_struct(
453        &self,
454        seq: Vec<(&'static str, Content)>,
455        redaction: &Redaction,
456        path: &mut Vec<PathItem>,
457    ) -> Vec<(&'static str, Content)> {
458        seq.into_iter()
459            .map(|(key, value)| {
460                path.push(PathItem::Field(key));
461                let new_value = self.redact_impl(value, redaction, path);
462                path.pop();
463                (key, new_value)
464            })
465            .collect()
466    }
467
468    fn redact_impl(
469        &self,
470        value: Content,
471        redaction: &Redaction,
472        path: &mut Vec<PathItem>,
473    ) -> Content {
474        if self.is_match(path) {
475            redaction.redact(value, path)
476        } else {
477            match value {
478                Content::Map(map) => Content::Map(
479                    map.into_iter()
480                        .map(|(key, value)| {
481                            path.push(PathItem::Field("$key"));
482                            let new_key = self.redact_impl(key.clone(), redaction, path);
483                            path.pop();
484
485                            path.push(PathItem::Content(key));
486                            let new_value = self.redact_impl(value, redaction, path);
487                            path.pop();
488
489                            (new_key, new_value)
490                        })
491                        .collect(),
492                ),
493                Content::Seq(seq) => Content::Seq(self.redact_seq(seq, redaction, path)),
494                Content::Tuple(seq) => Content::Tuple(self.redact_seq(seq, redaction, path)),
495                Content::TupleStruct(name, seq) => {
496                    Content::TupleStruct(name, self.redact_seq(seq, redaction, path))
497                }
498                Content::TupleVariant(name, variant_index, variant, seq) => Content::TupleVariant(
499                    name,
500                    variant_index,
501                    variant,
502                    self.redact_seq(seq, redaction, path),
503                ),
504                Content::Struct(name, seq) => {
505                    Content::Struct(name, self.redact_struct(seq, redaction, path))
506                }
507                Content::StructVariant(name, variant_index, variant, seq) => {
508                    Content::StructVariant(
509                        name,
510                        variant_index,
511                        variant,
512                        self.redact_struct(seq, redaction, path),
513                    )
514                }
515                Content::NewtypeStruct(name, inner) => Content::NewtypeStruct(
516                    name,
517                    Box::new(self.redact_impl(*inner, redaction, path)),
518                ),
519                Content::NewtypeVariant(name, index, variant_name, inner) => {
520                    Content::NewtypeVariant(
521                        name,
522                        index,
523                        variant_name,
524                        Box::new(self.redact_impl(*inner, redaction, path)),
525                    )
526                }
527                Content::Some(contents) => {
528                    Content::Some(Box::new(self.redact_impl(*contents, redaction, path)))
529                }
530                other => other,
531            }
532        }
533    }
534}
535
#[test]
fn test_range_checks() {
    // Boolean results are asserted directly instead of comparing against
    // `true`/`false` (clippy::bool_assert_comparison).
    assert!(PathItem::Index(0, 10).range_check(None, Some(-1)));
    assert!(!PathItem::Index(9, 10).range_check(None, Some(-1)));
    assert!(!PathItem::Index(0, 10).range_check(Some(1), Some(-1)));
    assert!(PathItem::Index(1, 10).range_check(Some(1), Some(-1)));
    assert!(!PathItem::Index(9, 10).range_check(Some(1), Some(-1)));
    assert!(!PathItem::Index(0, 10).range_check(Some(1), None));
    assert!(PathItem::Index(1, 10).range_check(Some(1), None));
    assert!(PathItem::Index(9, 10).range_check(Some(1), None));

    // A negative start counts back from the end of the sequence.
    assert!(PathItem::Index(8, 10).range_check(Some(-2), None));
    assert!(!PathItem::Index(7, 10).range_check(Some(-2), None));

    // A negative bound reaching past the front is clamped to zero.
    assert!(PathItem::Index(0, 3).range_check(Some(-10), None));

    // An unbounded range accepts any index, but never a non-index element.
    assert!(PathItem::Index(5, 10).range_check(None, None));
    assert!(!PathItem::Field("x").range_check(None, None));
}