github_actions_expressions/
context.rs

//! Parsing and matching APIs for GitHub Actions expression
//! contexts (e.g. `github.event.name`).
3
4use crate::Literal;
5
6use super::{Expr, SpannedExpr};
7
8/// Represents a context in a GitHub Actions expression.
9///
10/// These typically look something like `github.actor` or `inputs.foo`,
11/// although they can also be a "call" context like `fromJSON(...).foo.bar`,
12/// i.e. where the head of the context is a function call rather than an
13/// identifier.
14#[derive(Debug, PartialEq)]
15pub struct Context<'src> {
16    /// The individual parts of the context.
17    pub parts: Vec<SpannedExpr<'src>>,
18}
19
20impl<'src> Context<'src> {
21    pub(crate) fn new(parts: impl Into<Vec<SpannedExpr<'src>>>) -> Self {
22        Self {
23            parts: parts.into(),
24        }
25    }
26
27    /// Returns whether the context matches the given pattern exactly.
28    pub fn matches(&self, pattern: impl TryInto<ContextPattern<'src>>) -> bool {
29        let Ok(pattern) = pattern.try_into() else {
30            return false;
31        };
32
33        pattern.matches(self)
34    }
35
36    /// Returns whether the context is a child of the given pattern.
37    ///
38    /// A context is considered its own child, i.e. `foo.bar` is a child of
39    /// `foo.bar`.
40    pub fn child_of(&self, parent: impl TryInto<ContextPattern<'src>>) -> bool {
41        let Ok(parent) = parent.try_into() else {
42            return false;
43        };
44
45        parent.parent_of(self)
46    }
47
48    /// Return this context's "single tail," if it has one.
49    ///
50    /// This is useful primarily for contexts under `env` and `inputs`,
51    /// where we expect only a single tail part, e.g. `env.FOO` or
52    /// `inputs['bar']`.
53    ///
54    /// Returns `None` if the context has more than one tail part,
55    /// or if the context's head part is not an identifier.
56    pub fn single_tail(&self) -> Option<&str> {
57        if self.parts.len() != 2 || !matches!(*self.parts[0], Expr::Identifier(_)) {
58            return None;
59        }
60
61        match &self.parts[1].inner {
62            Expr::Identifier(ident) => Some(ident.as_str()),
63            Expr::Index(idx) => match &idx.inner {
64                Expr::Literal(Literal::String(idx)) => Some(idx),
65                _ => None,
66            },
67            _ => None,
68        }
69    }
70
71    /// Returns the "pattern equivalent" of this context.
72    ///
73    /// This is a string that can be used to efficiently match the context,
74    /// such as is done in `zizmor`'s template-injection audit via a
75    /// finite state transducer.
76    ///
77    /// Returns None if the context doesn't have a sensible pattern
78    /// equivalent, e.g. if it starts with a call.
79    pub fn as_pattern(&self) -> Option<String> {
80        fn push_part(part: &Expr<'_>, pattern: &mut String) {
81            match part {
82                Expr::Identifier(ident) => pattern.push_str(ident.0),
83                Expr::Star => pattern.push('*'),
84                Expr::Index(idx) => match &idx.inner {
85                    // foo['bar'] -> foo.bar
86                    Expr::Literal(Literal::String(idx)) => pattern.push_str(idx),
87                    // any kind of numeric or computed index, e.g.:
88                    // foo[0], foo[1 + 2], foo[bar]
89                    _ => pattern.push('*'),
90                },
91                _ => unreachable!("unexpected part in context pattern"),
92            }
93        }
94
95        // TODO: Optimization ideas:
96        // 1. Add a happy path for contexts that contain only
97        //    identifiers? Problem: case normalization.
98        // 2. Use `regex-automata` to return a case insensitive
99        //    automation here?
100        let mut pattern = String::new();
101
102        let mut parts = self.parts.iter().peekable();
103
104        let head = parts.next()?;
105        if matches!(**head, Expr::Call { .. }) {
106            return None;
107        }
108
109        push_part(head, &mut pattern);
110        for part in parts {
111            pattern.push('.');
112            push_part(part, &mut pattern);
113        }
114
115        pattern.make_ascii_lowercase();
116        Some(pattern)
117    }
118}
119
/// How a context relates to a pattern it falls under.
///
/// Derives the usual value traits so comparison results can be copied,
/// compared in assertions, and printed while debugging.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Comparison {
    /// The context has more parts than the pattern, and every pattern part
    /// matched, e.g. `foo.bar.baz` against the pattern `foo.bar`.
    Child,
    /// The context has exactly as many parts as the pattern, and every
    /// part matched.
    Match,
}
124
/// A `ContextPattern` is a pattern that matches one or more contexts.
///
/// It uses a restricted subset of the syntax used by contexts themselves:
/// a pattern is always in dotted form and can only contain identifiers
/// and wildcards. Each `*` stands for exactly one context part, and
/// identifiers are compared ASCII case-insensitively.
///
/// Indices are not allowed in patterns themselves, although contexts
/// that contain indices can be matched against patterns. For example,
/// `github.event.pull_request.assignees.*.name` will match the context
/// `github.event.pull_request.assignees[0].name`.
#[derive(Debug, Clone, Copy)]
pub struct ContextPattern<'src>(
    // NOTE: Kept as a string as a potentially premature optimization;
    // re-parsing should be faster in terms of locality.
    // TODO: Vec instead?
    &'src str,
);
141
142impl<'src> TryFrom<&'src str> for ContextPattern<'src> {
143    type Error = anyhow::Error;
144
145    fn try_from(val: &'src str) -> anyhow::Result<Self> {
146        Self::try_new(val).ok_or_else(|| anyhow::anyhow!("invalid context pattern"))
147    }
148}
149
150impl<'src> ContextPattern<'src> {
151    /// Creates a new [`ContextPattern`] from the given string.
152    ///
153    /// Panics if the pattern is invalid.
154    pub const fn new(pattern: &'src str) -> Self {
155        Self::try_new(pattern).expect("invalid context pattern; use try_new to handle errors")
156    }
157
158    /// Creates a new [`ContextPattern`] from the given string.
159    ///
160    /// Returns `None` if the pattern is invalid.
161    pub const fn try_new(pattern: &'src str) -> Option<Self> {
162        let raw_pattern = pattern.as_bytes();
163        if raw_pattern.is_empty() {
164            return None;
165        }
166
167        let len = raw_pattern.len();
168
169        // State machine:
170        // - accept_reg: whether the next character can be a regular identifier character
171        // - accept_dot: whether the next character can be a dot
172        // - accept_star: whether the next character can be a star
173        let mut accept_reg = true;
174        let mut accept_dot = false;
175        let mut accept_star = false;
176
177        let mut idx = 0;
178        while idx < len {
179            accept_dot = accept_dot && idx != len - 1;
180
181            match raw_pattern[idx] {
182                b'.' => {
183                    if !accept_dot {
184                        return None;
185                    }
186
187                    accept_reg = true;
188                    accept_dot = false;
189                    accept_star = true;
190                }
191                b'*' => {
192                    if !accept_star {
193                        return None;
194                    }
195
196                    accept_reg = false;
197                    accept_star = false;
198                    accept_dot = true;
199                }
200                c if c.is_ascii_alphanumeric() || c == b'-' || c == b'_' => {
201                    if !accept_reg {
202                        return None;
203                    }
204
205                    accept_reg = true;
206                    accept_dot = true;
207                    accept_star = false;
208                }
209                _ => return None, // invalid character
210            }
211
212            idx += 1;
213        }
214
215        Some(Self(pattern))
216    }
217
218    fn compare_part(pattern: &str, part: &Expr<'src>) -> bool {
219        if pattern == "*" {
220            true
221        } else {
222            match part {
223                Expr::Identifier(part) => pattern.eq_ignore_ascii_case(part.0),
224                Expr::Index(part) => match &part.inner {
225                    Expr::Literal(Literal::String(part)) => pattern.eq_ignore_ascii_case(part),
226                    _ => false,
227                },
228                _ => false,
229            }
230        }
231    }
232
233    fn compare(&self, ctx: &Context<'src>) -> Option<Comparison> {
234        let mut pattern_parts = self.0.split('.').peekable();
235        let mut ctx_parts = ctx.parts.iter().peekable();
236
237        while let (Some(pattern), Some(part)) = (pattern_parts.peek(), ctx_parts.peek()) {
238            if !Self::compare_part(pattern, part) {
239                return None;
240            }
241
242            pattern_parts.next();
243            ctx_parts.next();
244        }
245
246        match (pattern_parts.next(), ctx_parts.next()) {
247            // If both are exhausted, we have an exact match.
248            (None, None) => Some(Comparison::Match),
249            // If the pattern is exhausted but the context isn't, then
250            // the context is a child of the pattern.
251            (None, Some(_)) => Some(Comparison::Child),
252            _ => None,
253        }
254    }
255
256    /// Returns true if the given context is a child of the pattern.
257    ///
258    /// This is a loose parent-child relationship; for example, `foo` is its
259    /// own parent, as well as the parent of `foo.bar` and `foo.bar.baz`.
260    pub fn parent_of(&self, ctx: &Context<'src>) -> bool {
261        matches!(
262            self.compare(ctx),
263            Some(Comparison::Child | Comparison::Match)
264        )
265    }
266
267    /// Returns true if the given context exactly matches the pattern.
268    ///
269    /// See [`ContextPattern`] for a description of the matching rules.
270    pub fn matches(&self, ctx: &Context<'src>) -> bool {
271        matches!(self.compare(ctx), Some(Comparison::Match))
272    }
273}
274
#[cfg(test)]
mod tests {
    use crate::Expr;

    use super::{Context, ContextPattern};

    impl<'a> TryFrom<&'a str> for Context<'a> {
        type Error = anyhow::Error;

        fn try_from(val: &'a str) -> anyhow::Result<Self> {
            let expr = Expr::parse(val)?;

            match expr.inner {
                Expr::Context(ctx) => Ok(ctx),
                _ => Err(anyhow::anyhow!("expected context, found {:?}", expr)),
            }
        }
    }

    #[test]
    fn test_context_child_of() {
        let ctx = Context::try_from("foo.bar.baz").unwrap();

        for (case, child) in &[
            // Trivial child cases.
            ("foo", true),
            ("foo.bar", true),
            // Case-insensitive cases.
            ("FOO", true),
            ("FOO.BAR", true),
            ("Foo", true),
            ("Foo.Bar", true),
            // Wildcard parents.
            ("foo.*", true),
            ("foo.*.baz", true),
            ("foo.*.qux", false),
            // We consider a context to be a child of itself.
            ("foo.bar.baz", true),
            // Trivial non-child cases.
            ("foo.bar.baz.qux", false),
            ("foo.bar.qux", false),
            ("foo.qux", false),
            ("qux", false),
            // Invalid cases.
            ("foo.", false),
            (".", false),
            ("", false),
            ("*", false),
        ] {
            assert_eq!(ctx.child_of(*case), *child, "pattern: {case}");
        }
    }

    #[test]
    fn test_context_matches() {
        let ctx = Context::try_from("foo.bar.baz").unwrap();

        for (pattern, expected) in &[
            // Exact matches, including case-insensitive and wildcard ones.
            ("foo.bar.baz", true),
            ("FOO.BAR.BAZ", true),
            ("foo.*.baz", true),
            ("foo.*.*", true),
            // A parent is not an exact match.
            ("foo.bar", false),
            ("foo", false),
            // Longer patterns never match.
            ("foo.bar.baz.qux", false),
            // Invalid patterns never match.
            ("foo.bar.", false),
            ("", false),
        ] {
            assert_eq!(ctx.matches(*pattern), *expected, "pattern: {pattern}");
        }
    }

    #[test]
    fn test_single_tail() {
        for (case, expected) in &[
            // Valid cases.
            ("foo.bar", Some("bar")),
            ("foo['bar']", Some("bar")),
            ("inputs.test", Some("test")),
            // Invalid cases.
            ("foo", None),               // no tail at all
            ("foo.bar.baz", None),       // too many parts
            ("foo.bar.baz.qux", None),   // too many parts
            ("foo['bar']['baz']", None), // too many parts
            ("foo().bar", None),         // head is a call, not an identifier
            ("foo[0]", None),            // numeric index, not a string key
            ("foo[bar]", None),          // computed index, not a string key
        ] {
            let ctx = Context::try_from(*case).unwrap();
            assert_eq!(ctx.single_tail(), *expected, "context: {case}");
        }
    }

    #[test]
    fn test_context_as_pattern() {
        for (case, expected) in &[
            // Basic cases.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Case sensitivity.
            ("FOO", Some("foo")),
            ("FOO.BAR", Some("foo.bar")),
            ("FOO.BAR.BAZ", Some("foo.bar.baz")),
            ("FOO.BAR.BAZ_BAZ", Some("foo.bar.baz_baz")),
            ("FOO.BAR.BAZ-BAZ", Some("foo.bar.baz-baz")),
            ("FOO.*", Some("foo.*")),
            ("FOO.BAR.*", Some("foo.bar.*")),
            ("FOO.*.BAZ", Some("foo.*.baz")),
            ("FOO.*.*", Some("foo.*.*")),
            // Indexes.
            ("foo.bar.baz[0]", Some("foo.bar.baz.*")),
            ("foo.bar.baz['abc']", Some("foo.bar.baz.abc")),
            ("foo.bar.baz[0].qux", Some("foo.bar.baz.*.qux")),
            ("foo.bar.baz[0].qux[1]", Some("foo.bar.baz.*.qux.*")),
            ("foo[1][2][3]", Some("foo.*.*.*")),
            ("foo.bar[abc]", Some("foo.bar.*")),
            ("foo.bar[abc()]", Some("foo.bar.*")),
            // Whitespace.
            ("foo . bar", Some("foo.bar")),
            ("foo . bar . baz", Some("foo.bar.baz")),
            ("foo . bar . baz_baz", Some("foo.bar.baz_baz")),
            ("foo . bar . baz-baz", Some("foo.bar.baz-baz")),
            ("foo .*", Some("foo.*")),
            ("foo . bar .*", Some("foo.bar.*")),
            ("foo .* . baz", Some("foo.*.baz")),
            ("foo .* .*", Some("foo.*.*")),
            // Invalid cases
            ("foo().bar", None),
        ] {
            let ctx = Context::try_from(*case).unwrap();
            assert_eq!(ctx.as_pattern().as_deref(), *expected, "context: {case}");
        }
    }

    #[test]
    fn test_contextpattern_new() {
        for (case, expected) in &[
            // Well-formed patterns.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            ("FOO.Bar", Some("FOO.Bar")),
            ("_foo", Some("_foo")),
            ("foo.1", Some("foo.1")),
            // Invalid patterns.
            ("", None),
            ("*", None),
            ("**", None),
            (".**", None),
            (".foo", None),
            ("foo.", None),
            (".foo.", None),
            ("foo.**", None),
            ("foo.*.", None),
            ("foo.*bar", None),
            ("foo.b*", None),
            ("*.foo", None),
            ("foo bar", None),
            (".", None),
            ("foo.bar.", None),
            ("foo..bar", None),
            ("foo.bar.baz[0]", None),
            ("foo.bar.baz['abc']", None),
            ("foo.bar.baz[0].qux", None),
            ("foo.bar.baz[0].qux[1]", None),
            ("❤", None),
            ("❤.*", None),
        ] {
            assert_eq!(
                ContextPattern::try_new(case).map(|p| p.0),
                *expected,
                "pattern: {case}"
            );
        }
    }

    #[test]
    #[should_panic(expected = "invalid context pattern")]
    fn test_contextpattern_new_panics_on_invalid() {
        let _ = ContextPattern::new("foo..bar");
    }

    #[test]
    fn test_contextpattern_new_in_const() {
        const PATTERN: ContextPattern<'static> = ContextPattern::new("github.event.*");

        let ctx = Context::try_from("github.event.issue").unwrap();
        assert!(PATTERN.parent_of(&ctx));
        assert!(PATTERN.matches(&ctx));
    }

    #[test]
    fn test_contextpattern_parent_of() {
        for (pattern, ctx, expected) in &[
            // Exact contains.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar", "foo['bar']", true),
            ("foo.bar", "foo['BAR']", true),
            // Parent relationships
            ("foo", "foo.bar", true),
            ("foo.bar", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz.qux", true),
            ("foo", "foo.bar.baz.qux", true),
            ("foo.*", "foo.bar.baz.qux", true),
            (
                "secrets",
                "fromJson(steps.runs.outputs.data).workflow_runs[0].id",
                false,
            ),
        ] {
            let pattern = ContextPattern::try_new(pattern).unwrap();
            let ctx = Context::try_from(*ctx).unwrap();
            assert_eq!(pattern.parent_of(&ctx), *expected);
        }
    }

    #[test]
    fn test_context_pattern_matches() {
        for (pattern, ctx, expected) in &[
            // Normal matches.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar.baz", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz", true),
            ("foo.*.*.*", "foo.bar.baz.qux", true),
            // Case-insensitive matches.
            ("foo.bar", "FOO.BAR", true),
            ("foo.bar.baz", "Foo.Bar.Baz", true),
            ("foo.*", "FOO.BAR", true),
            ("foo.*.baz", "Foo.Bar.Baz", true),
            ("foo.*.*", "FOO.BAR.BAZ", true),
            ("FOO.BAR", "foo.bar", true),
            ("FOO.BAR.BAZ", "foo.bar.baz", true),
            ("FOO.*", "foo.bar", true),
            ("FOO.*.BAZ", "foo.bar.baz", true),
            ("FOO.*.*", "foo.bar.baz", true),
            // Indices also match correctly.
            ("foo.bar.baz.*", "foo.bar.baz[0]", true),
            ("foo.bar.baz.*", "foo.bar.baz[123]", true),
            ("foo.bar.baz.*", "foo.bar.baz['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['baz']['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['BAZ']['abc']", true),
            // Contexts containing stars match correctly.
            ("foo.bar.baz.*", "foo.bar.baz.*", true),
            ("foo.bar.*.*", "foo.bar.*.*", true),
            ("foo.bar.baz.qux", "foo.bar.baz.*", false), // patterns are one way
            ("foo.bar.baz.qux", "foo.bar.baz[*]", false), // patterns are one way
            // False normal matches.
            ("foo", "bar", false),                     // different identifier
            ("foo.bar", "foo.baz", false),             // different identifier
            ("foo.bar", "foo['baz']", false),          // different index
            ("foo.bar.baz", "foo.bar.baz.qux", false), // pattern too short
            ("foo.bar.baz", "foo.bar", false),         // context too short
            ("foo.*.baz", "foo.bar.baz.qux", false),   // pattern too short
            ("foo.*.qux", "foo.bar.baz.qux", false),   // * does not match multiple parts
            ("foo.*.*", "foo.bar.baz.qux", false),     // pattern too short
            ("foo.1", "foo[1]", false),                // .1 means a string key, not an index
        ] {
            let pattern = ContextPattern::try_new(pattern)
                .unwrap_or_else(|| panic!("invalid pattern: {pattern}"));
            let ctx = Context::try_from(*ctx).unwrap();
            assert_eq!(pattern.matches(&ctx), *expected);
        }
    }
}
503}