github_actions_expressions/
context.rs

1//! Parsing and matching APIs for GitHub Actions expressions
2//! contexts (e.g. `github.event.name`).
3
4use crate::Literal;
5
6use super::{Expr, SpannedExpr};
7
8/// Represents a context in a GitHub Actions expression.
9///
10/// These typically look something like `github.actor` or `inputs.foo`,
11/// although they can also be a "call" context like `fromJSON(...).foo.bar`,
12/// i.e. where the head of the context is a function call rather than an
13/// identifier.
14#[derive(Debug, PartialEq)]
15pub struct Context<'src> {
16    /// The individual parts of the context.
17    pub parts: Vec<SpannedExpr<'src>>,
18}
19
20impl<'src> Context<'src> {
21    pub(crate) fn new(parts: impl Into<Vec<SpannedExpr<'src>>>) -> Self {
22        Self {
23            parts: parts.into(),
24        }
25    }
26
27    /// Parse a context from the given string.
28    pub fn parse(raw: &'src str) -> anyhow::Result<Self> {
29        let expr = Expr::parse(raw)?;
30
31        match expr.inner {
32            Expr::Context(ctx) => Ok(ctx),
33            _ => Err(anyhow::anyhow!("expected context, found {:?}", expr)),
34        }
35    }
36
37    /// Returns whether the context matches the given pattern exactly.
38    pub fn matches(&self, pattern: impl TryInto<ContextPattern<'src>>) -> bool {
39        let Ok(pattern) = pattern.try_into() else {
40            return false;
41        };
42
43        pattern.matches(self)
44    }
45
46    /// Returns whether the context is a child of the given pattern.
47    ///
48    /// A context is considered its own child, i.e. `foo.bar` is a child of
49    /// `foo.bar`.
50    pub fn child_of(&self, parent: impl TryInto<ContextPattern<'src>>) -> bool {
51        let Ok(parent) = parent.try_into() else {
52            return false;
53        };
54
55        parent.parent_of(self)
56    }
57
58    /// Return this context's "single tail," if it has one.
59    ///
60    /// This is useful primarily for contexts under `env` and `inputs`,
61    /// where we expect only a single tail part, e.g. `env.FOO` or
62    /// `inputs['bar']`.
63    ///
64    /// Returns `None` if the context has more than one tail part,
65    /// or if the context's head part is not an identifier.
66    pub fn single_tail(&self) -> Option<&str> {
67        if self.parts.len() != 2 || !matches!(*self.parts[0], Expr::Identifier(_)) {
68            return None;
69        }
70
71        match &self.parts[1].inner {
72            Expr::Identifier(ident) => Some(ident.as_str()),
73            Expr::Index(idx) => match &idx.inner {
74                Expr::Literal(Literal::String(idx)) => Some(idx),
75                _ => None,
76            },
77            _ => None,
78        }
79    }
80
81    /// Returns the "pattern equivalent" of this context.
82    ///
83    /// This is a string that can be used to efficiently match the context,
84    /// such as is done in `zizmor`'s template-injection audit via a
85    /// finite state transducer.
86    ///
87    /// Returns None if the context doesn't have a sensible pattern
88    /// equivalent, e.g. if it starts with a call.
89    pub fn as_pattern(&self) -> Option<String> {
90        fn push_part(part: &Expr<'_>, pattern: &mut String) {
91            match part {
92                Expr::Identifier(ident) => pattern.push_str(ident.0),
93                Expr::Star => pattern.push('*'),
94                Expr::Index(idx) => match &idx.inner {
95                    // foo['bar'] -> foo.bar
96                    Expr::Literal(Literal::String(idx)) => pattern.push_str(idx),
97                    // any kind of numeric or computed index, e.g.:
98                    // foo[0], foo[1 + 2], foo[bar]
99                    _ => pattern.push('*'),
100                },
101                _ => unreachable!("unexpected part in context pattern"),
102            }
103        }
104
105        // TODO: Optimization ideas:
106        // 1. Add a happy path for contexts that contain only
107        //    identifiers? Problem: case normalization.
108        // 2. Use `regex-automata` to return a case insensitive
109        //    automation here?
110        let mut pattern = String::new();
111
112        let mut parts = self.parts.iter().peekable();
113
114        let head = parts.next()?;
115        if matches!(**head, Expr::Call { .. }) {
116            return None;
117        }
118
119        push_part(head, &mut pattern);
120        for part in parts {
121            pattern.push('.');
122            push_part(part, &mut pattern);
123        }
124
125        pattern.make_ascii_lowercase();
126        Some(pattern)
127    }
128}
129
/// How a context relates to a pattern it was successfully compared with.
enum Comparison {
    /// The pattern and the context have the same number of parts and
    /// every part matched.
    Match,
    /// Every pattern part matched, and the context has further parts.
    Child,
}
134
/// A `ContextPattern` is a pattern that matches one or more contexts.
///
/// It uses a restricted subset of the syntax used by contexts themselves:
/// a pattern is always in dotted form and can only contain identifiers
/// and wildcards.
///
/// Indices are not allowed in patterns themselves, although contexts
/// that contain indices can be matched against patterns. For example,
/// `github.event.pull_request.assignees.*.name` will match the context
/// `github.event.pull_request.assignees[0].name`.
///
/// A pattern only borrows its source string, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct ContextPattern<'src>(
    // NOTE: Kept as a string as a potentially premature optimization;
    // re-parsing should be faster in terms of locality.
    // TODO: Vec instead?
    &'src str,
);
151
152impl<'src> TryFrom<&'src str> for ContextPattern<'src> {
153    type Error = anyhow::Error;
154
155    fn try_from(val: &'src str) -> anyhow::Result<Self> {
156        Self::try_new(val).ok_or_else(|| anyhow::anyhow!("invalid context pattern"))
157    }
158}
159
160impl<'src> ContextPattern<'src> {
161    /// Creates a new [`ContextPattern`] from the given string.
162    ///
163    /// Panics if the pattern is invalid.
164    pub const fn new(pattern: &'src str) -> Self {
165        Self::try_new(pattern).expect("invalid context pattern; use try_new to handle errors")
166    }
167
168    /// Creates a new [`ContextPattern`] from the given string.
169    ///
170    /// Returns `None` if the pattern is invalid.
171    pub const fn try_new(pattern: &'src str) -> Option<Self> {
172        let raw_pattern = pattern.as_bytes();
173        if raw_pattern.is_empty() {
174            return None;
175        }
176
177        let len = raw_pattern.len();
178
179        // State machine:
180        // - accept_reg: whether the next character can be a regular identifier character
181        // - accept_dot: whether the next character can be a dot
182        // - accept_star: whether the next character can be a star
183        let mut accept_reg = true;
184        let mut accept_dot = false;
185        let mut accept_star = false;
186
187        let mut idx = 0;
188        while idx < len {
189            accept_dot = accept_dot && idx != len - 1;
190
191            match raw_pattern[idx] {
192                b'.' => {
193                    if !accept_dot {
194                        return None;
195                    }
196
197                    accept_reg = true;
198                    accept_dot = false;
199                    accept_star = true;
200                }
201                b'*' => {
202                    if !accept_star {
203                        return None;
204                    }
205
206                    accept_reg = false;
207                    accept_star = false;
208                    accept_dot = true;
209                }
210                c if c.is_ascii_alphanumeric() || c == b'-' || c == b'_' => {
211                    if !accept_reg {
212                        return None;
213                    }
214
215                    accept_reg = true;
216                    accept_dot = true;
217                    accept_star = false;
218                }
219                _ => return None, // invalid character
220            }
221
222            idx += 1;
223        }
224
225        Some(Self(pattern))
226    }
227
228    fn compare_part(pattern: &str, part: &Expr<'src>) -> bool {
229        if pattern == "*" {
230            true
231        } else {
232            match part {
233                Expr::Identifier(part) => pattern.eq_ignore_ascii_case(part.0),
234                Expr::Index(part) => match &part.inner {
235                    Expr::Literal(Literal::String(part)) => pattern.eq_ignore_ascii_case(part),
236                    _ => false,
237                },
238                _ => false,
239            }
240        }
241    }
242
243    fn compare(&self, ctx: &Context<'src>) -> Option<Comparison> {
244        let mut pattern_parts = self.0.split('.').peekable();
245        let mut ctx_parts = ctx.parts.iter().peekable();
246
247        while let (Some(pattern), Some(part)) = (pattern_parts.peek(), ctx_parts.peek()) {
248            if !Self::compare_part(pattern, part) {
249                return None;
250            }
251
252            pattern_parts.next();
253            ctx_parts.next();
254        }
255
256        match (pattern_parts.next(), ctx_parts.next()) {
257            // If both are exhausted, we have an exact match.
258            (None, None) => Some(Comparison::Match),
259            // If the pattern is exhausted but the context isn't, then
260            // the context is a child of the pattern.
261            (None, Some(_)) => Some(Comparison::Child),
262            _ => None,
263        }
264    }
265
266    /// Returns true if the given context is a child of the pattern.
267    ///
268    /// This is a loose parent-child relationship; for example, `foo` is its
269    /// own parent, as well as the parent of `foo.bar` and `foo.bar.baz`.
270    pub fn parent_of(&self, ctx: &Context<'src>) -> bool {
271        matches!(
272            self.compare(ctx),
273            Some(Comparison::Child | Comparison::Match)
274        )
275    }
276
277    /// Returns true if the given context exactly matches the pattern.
278    ///
279    /// See [`ContextPattern`] for a description of the matching rules.
280    pub fn matches(&self, ctx: &Context<'src>) -> bool {
281        matches!(self.compare(ctx), Some(Comparison::Match))
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::{Context, ContextPattern};

    /// Test-only convenience so that cases can be written as plain strings.
    impl<'a> TryFrom<&'a str> for Context<'a> {
        type Error = anyhow::Error;

        fn try_from(val: &'a str) -> anyhow::Result<Self> {
            // Same behavior as the public parser; no need to duplicate it.
            Context::parse(val)
        }
    }

    #[test]
    fn test_context_child_of() {
        let ctx = Context::try_from("foo.bar.baz").unwrap();

        for (case, child) in &[
            // Trivial child cases.
            ("foo", true),
            ("foo.bar", true),
            // Case-insensitive cases.
            ("FOO", true),
            ("FOO.BAR", true),
            ("Foo", true),
            ("Foo.Bar", true),
            // We consider a context to be a child of itself.
            ("foo.bar.baz", true),
            // Trivial non-child cases.
            ("foo.bar.baz.qux", false),
            ("foo.bar.qux", false),
            ("foo.qux", false),
            ("qux", false),
            // Invalid cases.
            ("foo.", false),
            (".", false),
            ("", false),
        ] {
            assert_eq!(ctx.child_of(*case), *child, "parent: {case:?}");
        }
    }

    #[test]
    fn test_single_tail() {
        for (case, expected) in &[
            // Valid cases.
            ("foo.bar", Some("bar")),
            ("foo['bar']", Some("bar")),
            ("inputs.test", Some("test")),
            // Invalid cases.
            ("foo.bar.baz", None),       // too many parts
            ("foo.bar.baz.qux", None),   // too many parts
            ("foo['bar']['baz']", None), // too many parts
            ("foo().bar", None),         // head is a call, not an identifier
        ] {
            let ctx = Context::try_from(*case).unwrap();
            assert_eq!(ctx.single_tail(), *expected, "context: {case:?}");
        }
    }

    #[test]
    fn test_context_as_pattern() {
        for (case, expected) in &[
            // Basic cases.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Case sensitivity.
            ("FOO", Some("foo")),
            ("FOO.BAR", Some("foo.bar")),
            ("FOO.BAR.BAZ", Some("foo.bar.baz")),
            ("FOO.BAR.BAZ_BAZ", Some("foo.bar.baz_baz")),
            ("FOO.BAR.BAZ-BAZ", Some("foo.bar.baz-baz")),
            ("FOO.*", Some("foo.*")),
            ("FOO.BAR.*", Some("foo.bar.*")),
            ("FOO.*.BAZ", Some("foo.*.baz")),
            ("FOO.*.*", Some("foo.*.*")),
            // Indexes.
            ("foo.bar.baz[0]", Some("foo.bar.baz.*")),
            ("foo.bar.baz['abc']", Some("foo.bar.baz.abc")),
            ("foo.bar.baz[0].qux", Some("foo.bar.baz.*.qux")),
            ("foo.bar.baz[0].qux[1]", Some("foo.bar.baz.*.qux.*")),
            ("foo[1][2][3]", Some("foo.*.*.*")),
            ("foo.bar[abc]", Some("foo.bar.*")),
            ("foo.bar[abc()]", Some("foo.bar.*")),
            // Whitespace.
            ("foo . bar", Some("foo.bar")),
            ("foo . bar . baz", Some("foo.bar.baz")),
            ("foo . bar . baz_baz", Some("foo.bar.baz_baz")),
            ("foo . bar . baz-baz", Some("foo.bar.baz-baz")),
            ("foo .*", Some("foo.*")),
            ("foo . bar .*", Some("foo.bar.*")),
            ("foo .* . baz", Some("foo.*.baz")),
            ("foo .* .*", Some("foo.*.*")),
            // Invalid cases
            ("foo().bar", None),
        ] {
            let ctx = Context::try_from(*case).unwrap();
            assert_eq!(
                ctx.as_pattern().as_deref(),
                *expected,
                "context: {case:?}"
            );
        }
    }

    #[test]
    fn test_contextpattern_new() {
        for (case, expected) in &[
            // Well-formed patterns.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Invalid patterns.
            ("", None),
            ("*", None),
            ("**", None),
            (".**", None),
            (".foo", None),
            ("foo.", None),
            (".foo.", None),
            ("foo.**", None),
            (".", None),
            ("foo.bar.", None),
            ("foo..bar", None),
            ("foo.bar.baz[0]", None),
            ("foo.bar.baz['abc']", None),
            ("foo.bar.baz[0].qux", None),
            ("foo.bar.baz[0].qux[1]", None),
            ("❤", None),
            ("❤.*", None),
        ] {
            assert_eq!(
                ContextPattern::try_new(case).map(|p| p.0),
                *expected,
                "pattern: {case:?}"
            );
        }
    }

    #[test]
    fn test_contextpattern_parent_of() {
        for (raw_pattern, raw_ctx, expected) in &[
            // Exact contains.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar", "foo['bar']", true),
            ("foo.bar", "foo['BAR']", true),
            // Parent relationships
            ("foo", "foo.bar", true),
            ("foo.bar", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz.qux", true),
            ("foo", "foo.bar.baz.qux", true),
            ("foo.*", "foo.bar.baz.qux", true),
            (
                "secrets",
                "fromJson(steps.runs.outputs.data).workflow_runs[0].id",
                false,
            ),
        ] {
            let pattern = ContextPattern::try_new(raw_pattern).unwrap();
            let ctx = Context::try_from(*raw_ctx).unwrap();
            assert_eq!(
                pattern.parent_of(&ctx),
                *expected,
                "pattern: {raw_pattern:?}, context: {raw_ctx:?}"
            );
        }
    }

    #[test]
    fn test_context_pattern_matches() {
        for (raw_pattern, raw_ctx, expected) in &[
            // Normal matches.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar.baz", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz", true),
            ("foo.*.*.*", "foo.bar.baz.qux", true),
            // Case-insensitive matches.
            ("foo.bar", "FOO.BAR", true),
            ("foo.bar.baz", "Foo.Bar.Baz", true),
            ("foo.*", "FOO.BAR", true),
            ("foo.*.baz", "Foo.Bar.Baz", true),
            ("foo.*.*", "FOO.BAR.BAZ", true),
            ("FOO.BAR", "foo.bar", true),
            ("FOO.BAR.BAZ", "foo.bar.baz", true),
            ("FOO.*", "foo.bar", true),
            ("FOO.*.BAZ", "foo.bar.baz", true),
            ("FOO.*.*", "foo.bar.baz", true),
            // Indices also match correctly.
            ("foo.bar.baz.*", "foo.bar.baz[0]", true),
            ("foo.bar.baz.*", "foo.bar.baz[123]", true),
            ("foo.bar.baz.*", "foo.bar.baz['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['baz']['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['BAZ']['abc']", true),
            // Contexts containing stars match correctly.
            ("foo.bar.baz.*", "foo.bar.baz.*", true),
            ("foo.bar.*.*", "foo.bar.*.*", true),
            ("foo.bar.baz.qux", "foo.bar.baz.*", false), // patterns are one way
            ("foo.bar.baz.qux", "foo.bar.baz[*]", false), // patterns are one way
            // False normal matches.
            ("foo", "bar", false),                     // different identifier
            ("foo.bar", "foo.baz", false),             // different identifier
            ("foo.bar", "foo['baz']", false),          // different index
            ("foo.bar.baz", "foo.bar.baz.qux", false), // pattern too short
            ("foo.bar.baz", "foo.bar", false),         // context too short
            ("foo.*.baz", "foo.bar.baz.qux", false),   // pattern too short
            ("foo.*.qux", "foo.bar.baz.qux", false),   // * does not match multiple parts
            ("foo.*.*", "foo.bar.baz.qux", false),     // pattern too short
            ("foo.1", "foo[1]", false),                // .1 means a string key, not an index
        ] {
            let pattern = ContextPattern::try_new(raw_pattern)
                .unwrap_or_else(|| panic!("invalid pattern: {raw_pattern}"));
            let ctx = Context::try_from(*raw_ctx).unwrap();
            assert_eq!(
                pattern.matches(&ctx),
                *expected,
                "pattern: {raw_pattern:?}, context: {raw_ctx:?}"
            );
        }
    }
}