github_actions_expressions/
context.rs

1//! Parsing and matching APIs for GitHub Actions expressions
2//! contexts (e.g. `github.event.name`).
3
4use crate::Literal;
5
6use super::Expr;
7
/// Represents a context in a GitHub Actions expression.
///
/// These typically look something like `github.actor` or `inputs.foo`,
/// although they can also be a "call" context like `fromJSON(...).foo.bar`,
/// i.e. where the head of the context is a function call rather than an
/// identifier.
///
/// Equality (against another `Context` or against a `str`) is decided on
/// the raw text alone, ignoring ASCII case.
#[derive(Debug)]
pub struct Context<'src> {
    /// The raw text of the context, stored as supplied at construction
    /// and handed back unchanged by `as_str`.
    raw: &'src str,
    /// The individual parts of the context.
    pub parts: Vec<Expr<'src>>,
}
20
21impl<'src> Context<'src> {
22    pub(crate) fn new(raw: &'src str, parts: impl Into<Vec<Expr<'src>>>) -> Self {
23        Self {
24            raw,
25            parts: parts.into(),
26        }
27    }
28
29    /// Returns the raw string representation of the context.
30    pub fn as_str(&self) -> &str {
31        self.raw
32    }
33
34    /// Returns whether the context is a child of the given pattern.
35    ///
36    /// A context is considered its own child, i.e. `foo.bar` is a child of
37    /// `foo.bar`.
38    pub fn child_of(&self, parent: impl TryInto<ContextPattern<'src>>) -> bool {
39        let Ok(parent) = parent.try_into() else {
40            return false;
41        };
42
43        parent.parent_of(self)
44    }
45
46    /// Returns the tail of the context if the head matches the given string.
47    pub fn pop_if(&self, head: &str) -> Option<&str> {
48        match self.parts.first()? {
49            Expr::Identifier(ident) if ident == head => Some(self.raw.split_once('.')?.1),
50            _ => None,
51        }
52    }
53
54    /// Returns the "pattern equivalent" of this context.
55    ///
56    /// This is a string that can be used to efficiently match the context,
57    /// such as is done in `zizmor`'s template-injection audit via a
58    /// finite state transducer.
59    ///
60    /// Returns None if the context doesn't have a sensible pattern
61    /// equivalent, e.g. if it starts with a call.
62    pub fn as_pattern(&self) -> Option<String> {
63        fn push_part(part: &Expr<'_>, pattern: &mut String) {
64            match part {
65                Expr::Identifier(ident) => pattern.push_str(ident.0),
66                Expr::Star => pattern.push('*'),
67                Expr::Index(idx) => match idx.as_ref() {
68                    // foo['bar'] -> foo.bar
69                    Expr::Literal(Literal::String(idx)) => pattern.push_str(idx),
70                    // any kind of numeric or computed index, e.g.:
71                    // foo[0], foo[1 + 2], foo[bar]
72                    _ => pattern.push('*'),
73                },
74                _ => unreachable!("unexpected part in context pattern"),
75            }
76        }
77
78        // TODO: Optimization ideas:
79        // 1. Add a happy path for contexts that contain only
80        //    identifiers? Problem: case normalization.
81        // 2. Use `regex-automata` to return a case insensitive
82        //    automation here?
83        let mut pattern = String::with_capacity(self.raw.len());
84
85        let mut parts = self.parts.iter().peekable();
86
87        let head = parts.next()?;
88        if matches!(head, Expr::Call { .. }) {
89            return None;
90        }
91
92        push_part(head, &mut pattern);
93        for part in parts {
94            pattern.push('.');
95            push_part(part, &mut pattern);
96        }
97
98        pattern.make_ascii_lowercase();
99        Some(pattern)
100    }
101}
102
103impl PartialEq for Context<'_> {
104    fn eq(&self, other: &Self) -> bool {
105        self.raw.eq_ignore_ascii_case(other.raw)
106    }
107}
108
109impl PartialEq<str> for Context<'_> {
110    fn eq(&self, other: &str) -> bool {
111        self.raw.eq_ignore_ascii_case(other)
112    }
113}
114
/// The result of a successful comparison between a `ContextPattern` and a
/// `Context`, as produced by `ContextPattern::compare`.
enum Comparison {
    /// Every part of the pattern matched, and the context has further
    /// parts beyond the end of the pattern.
    Child,
    /// Every part of the pattern matched, and the pattern and the context
    /// have the same number of parts.
    Match,
}
119
/// A `ContextPattern` is a pattern that matches one or more contexts.
///
/// It uses a restricted subset of the syntax used by contexts themselves:
/// a pattern is always in dotted form and can only contain identifiers
/// and wildcards. A wildcard (`*`) stands for exactly one context part.
///
/// Indices are not allowed in patterns themselves, although contexts
/// that contain indices can be matched against patterns. For example,
/// `github.event.pull_request.assignees.*.name` will match the context
/// `github.event.pull_request.assignees[0].name`.
///
/// Identifier parts are compared ignoring ASCII case.
pub struct ContextPattern<'src>(
    // The validated pattern text; it is re-split on `.` for each comparison.
    //
    // NOTE: Kept as a string as a potentially premature optimization;
    // re-parsing should be faster in terms of locality.
    // TODO: Vec instead?
    &'src str,
);
136
137impl<'src> TryFrom<&'src str> for ContextPattern<'src> {
138    type Error = anyhow::Error;
139
140    fn try_from(val: &'src str) -> anyhow::Result<Self> {
141        Self::new(val).ok_or_else(|| anyhow::anyhow!("invalid context pattern"))
142    }
143}
144
145impl<'src> ContextPattern<'src> {
146    /// Creates a new `ContextPattern` from the given string.
147    ///
148    /// Returns `None` if the pattern is invalid.
149    pub fn new(pattern: &'src str) -> Option<Self> {
150        let parts = pattern.split('.');
151        let mut count = 0;
152        for part in parts {
153            if part.is_empty() {
154                return None;
155            }
156
157            match part {
158                "*" => {}
159                // TODO: `bytes()` is probably a little faster.
160                _ if part
161                    .chars()
162                    .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') => {}
163                _ => return None,
164            }
165            count += 1;
166        }
167
168        match count {
169            0 => None,
170            _ => Some(Self(pattern)),
171        }
172    }
173
174    fn compare_part(pattern: &str, part: &Expr<'src>) -> bool {
175        if pattern == "*" {
176            true
177        } else {
178            match part {
179                Expr::Identifier(part) => pattern.eq_ignore_ascii_case(part.0),
180                Expr::Index(part) => match part.as_ref() {
181                    Expr::Literal(Literal::String(part)) => pattern.eq_ignore_ascii_case(part),
182                    _ => false,
183                },
184                _ => false,
185            }
186        }
187    }
188
189    fn compare(&self, ctx: &Context<'src>) -> Option<Comparison> {
190        let mut pattern_parts = self.0.split('.').peekable();
191        let mut ctx_parts = ctx.parts.iter().peekable();
192
193        while let (Some(pattern), Some(part)) = (pattern_parts.peek(), ctx_parts.peek()) {
194            if !Self::compare_part(pattern, part) {
195                return None;
196            }
197
198            pattern_parts.next();
199            ctx_parts.next();
200        }
201
202        match (pattern_parts.next(), ctx_parts.next()) {
203            // If both are exhausted, we have an exact match.
204            (None, None) => Some(Comparison::Match),
205            // If the pattern is exhausted but the context isn't, then
206            // the context is a child of the pattern.
207            (None, Some(_)) => Some(Comparison::Child),
208            _ => None,
209        }
210    }
211
212    /// Returns true if the given context is a child of the pattern.
213    ///
214    /// This is a loose parent-child relationship; for example, `foo` is its
215    /// own parent, as well as the parent of `foo.bar` and `foo.bar.baz`.
216    pub fn parent_of(&self, ctx: &Context<'src>) -> bool {
217        matches!(
218            self.compare(ctx),
219            Some(Comparison::Child | Comparison::Match)
220        )
221    }
222
223    /// Returns true if the given context exactly matches the pattern.
224    ///
225    /// See [`ContextPattern`] for a description of the matching rules.
226    pub fn matches(&self, ctx: &Context<'src>) -> bool {
227        matches!(self.compare(ctx), Some(Comparison::Match))
228    }
229}
230
#[cfg(test)]
mod tests {
    use crate::Expr;

    use super::{Context, ContextPattern};

    // Test-only convenience: parse a string as an expression and require
    // that the result is a context.
    impl<'a> TryFrom<&'a str> for Context<'a> {
        type Error = anyhow::Error;

        fn try_from(val: &'a str) -> anyhow::Result<Self> {
            let expr = Expr::parse(val)?;

            match expr {
                Expr::Context(ctx) => Ok(ctx),
                _ => Err(anyhow::anyhow!("expected context, found {:?}", expr)),
            }
        }
    }

    // Equality against a string ignores ASCII case.
    #[test]
    fn test_context_eq() {
        let ctx = Context::try_from("foo.bar.baz").unwrap();
        assert_eq!(&ctx, "foo.bar.baz");
        assert_eq!(&ctx, "FOO.BAR.BAZ");
        assert_eq!(&ctx, "Foo.Bar.Baz");
    }

    // Each row is (parent pattern, whether `foo.bar.baz` is its child).
    #[test]
    fn test_context_child_of() {
        let ctx = Context::try_from("foo.bar.baz").unwrap();

        for (case, child) in &[
            // Trivial child cases.
            ("foo", true),
            ("foo.bar", true),
            // Case-insensitive cases.
            ("FOO", true),
            ("FOO.BAR", true),
            ("Foo", true),
            ("Foo.Bar", true),
            // We consider a context to be a child of itself.
            ("foo.bar.baz", true),
            // Trivial non-child cases.
            ("foo.bar.baz.qux", false),
            ("foo.bar.qux", false),
            ("foo.qux", false),
            ("qux", false),
            // Invalid cases.
            ("foo.", false),
            (".", false),
            ("", false),
        ] {
            assert_eq!(ctx.child_of(*case), *child, "parent pattern: {case:?}");
        }
    }

    // Each row is (head to pop, expected tail of `foo.bar.baz`).
    #[test]
    fn test_context_pop_if() {
        let ctx = Context::try_from("foo.bar.baz").unwrap();

        for (case, expected) in &[
            ("foo", Some("bar.baz")),
            ("Foo", Some("bar.baz")),
            ("FOO", Some("bar.baz")),
            ("foo.", None),
            ("bar", None),
        ] {
            assert_eq!(ctx.pop_if(case), *expected, "head: {case:?}");
        }
    }

    // Each row is (context source, expected pattern equivalent).
    #[test]
    fn test_context_as_pattern() {
        for (case, expected) in &[
            // Basic cases.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Case sensitivity.
            ("FOO", Some("foo")),
            ("FOO.BAR", Some("foo.bar")),
            ("FOO.BAR.BAZ", Some("foo.bar.baz")),
            ("FOO.BAR.BAZ_BAZ", Some("foo.bar.baz_baz")),
            ("FOO.BAR.BAZ-BAZ", Some("foo.bar.baz-baz")),
            ("FOO.*", Some("foo.*")),
            ("FOO.BAR.*", Some("foo.bar.*")),
            ("FOO.*.BAZ", Some("foo.*.baz")),
            ("FOO.*.*", Some("foo.*.*")),
            // Indexes.
            ("foo.bar.baz[0]", Some("foo.bar.baz.*")),
            ("foo.bar.baz['abc']", Some("foo.bar.baz.abc")),
            ("foo.bar.baz[0].qux", Some("foo.bar.baz.*.qux")),
            ("foo.bar.baz[0].qux[1]", Some("foo.bar.baz.*.qux.*")),
            ("foo[1][2][3]", Some("foo.*.*.*")),
            ("foo.bar[abc]", Some("foo.bar.*")),
            ("foo.bar[abc()]", Some("foo.bar.*")),
            // Invalid cases
            ("foo().bar", None),
        ] {
            let ctx = Context::try_from(*case).unwrap();
            assert_eq!(
                ctx.as_pattern().as_deref(),
                *expected,
                "context: {case:?}"
            );
        }
    }

    // Each row is (pattern source, expected accepted pattern text, if any).
    #[test]
    fn test_contextpattern_new() {
        for (case, expected) in &[
            // Well-formed patterns.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Invalid patterns.
            ("", None),
            ("foo.", None),
            ("foo.**", None),
            (".", None),
            ("foo.bar.", None),
            ("foo..bar", None),
            ("foo.bar.baz[0]", None),
            ("foo.bar.baz['abc']", None),
            ("foo.bar.baz[0].qux", None),
            ("foo.bar.baz[0].qux[1]", None),
            ("❤", None),
            ("❤.*", None),
        ] {
            assert_eq!(
                ContextPattern::new(case).map(|p| p.0),
                *expected,
                "pattern: {case:?}"
            );
        }
    }

    // Each row is (pattern, context, whether the pattern is a parent of it).
    #[test]
    fn test_contextpattern_parent_of() {
        for (pattern, ctx, expected) in &[
            // Exact contains.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar", "foo['bar']", true),
            ("foo.bar", "foo['BAR']", true),
            // Parent relationships
            ("foo", "foo.bar", true),
            ("foo.bar", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz.qux", true),
            ("foo", "foo.bar.baz.qux", true),
            ("foo.*", "foo.bar.baz.qux", true),
            (
                "secrets",
                "fromJson(steps.runs.outputs.data).workflow_runs[0].id",
                false,
            ),
        ] {
            let parsed_pattern = ContextPattern::new(pattern).unwrap();
            let parsed_ctx = Context::try_from(*ctx).unwrap();
            assert_eq!(
                parsed_pattern.parent_of(&parsed_ctx),
                *expected,
                "pattern: {pattern:?}, context: {ctx:?}"
            );
        }
    }

    // Each row is (pattern, context, whether the pattern matches it exactly).
    #[test]
    fn test_context_pattern_matches() {
        for (pattern, ctx, expected) in &[
            // Normal matches.
            ("foo", "foo", true),
            ("*", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar.baz", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz", true),
            ("foo.*.*.*", "foo.bar.baz.qux", true),
            // Case-insensitive matches.
            ("foo.bar", "FOO.BAR", true),
            ("foo.bar.baz", "Foo.Bar.Baz", true),
            ("foo.*", "FOO.BAR", true),
            ("foo.*.baz", "Foo.Bar.Baz", true),
            ("foo.*.*", "FOO.BAR.BAZ", true),
            ("FOO.BAR", "foo.bar", true),
            ("FOO.BAR.BAZ", "foo.bar.baz", true),
            ("FOO.*", "foo.bar", true),
            ("FOO.*.BAZ", "foo.bar.baz", true),
            ("FOO.*.*", "foo.bar.baz", true),
            // Indices also match correctly.
            ("foo.bar.baz.*", "foo.bar.baz[0]", true),
            ("foo.bar.baz.*", "foo.bar.baz[123]", true),
            ("foo.bar.baz.*", "foo.bar.baz['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['baz']['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['BAZ']['abc']", true),
            // Contexts containing stars match correctly.
            ("foo.bar.baz.*", "foo.bar.baz.*", true),
            ("foo.bar.*.*", "foo.bar.*.*", true),
            ("foo.bar.baz.qux", "foo.bar.baz.*", false), // patterns are one way
            ("foo.bar.baz.qux", "foo.bar.baz[*]", false), // patterns are one way
            // False normal matches.
            ("foo", "bar", false),                     // different identifier
            ("foo.bar", "foo.baz", false),             // different identifier
            ("foo.bar", "foo['baz']", false),          // different index
            ("foo.bar.baz", "foo.bar.baz.qux", false), // pattern too short
            ("foo.bar.baz", "foo.bar", false),         // context too short
            ("foo.*.baz", "foo.bar.baz.qux", false),   // pattern too short
            ("foo.*.qux", "foo.bar.baz.qux", false),   // * does not match multiple parts
            ("foo.*.*", "foo.bar.baz.qux", false),     // pattern too short
            ("foo.1", "foo[1]", false),                // .1 means a string key, not an index
        ] {
            let parsed_pattern = ContextPattern::new(pattern).unwrap();
            let parsed_ctx = Context::try_from(*ctx).unwrap();
            assert_eq!(
                parsed_pattern.matches(&parsed_ctx),
                *expected,
                "pattern: {pattern:?}, context: {ctx:?}"
            );
        }
    }
}