Skip to main content

github_actions_expressions/
context.rs

1//! Parsing and matching APIs for GitHub Actions expressions
2//! contexts (e.g. `github.event.name`).
3
4use crate::Literal;
5
6use super::{Expr, SpannedExpr};
7
8/// Errors that can occur when parsing a context pattern.
9#[derive(Debug, thiserror::Error)]
10#[error("invalid context pattern")]
11pub struct InvalidContextPattern;
12
13/// Represents a context in a GitHub Actions expression.
14///
15/// These typically look something like `github.actor` or `inputs.foo`,
16/// although they can also be a "call" context like `fromJSON(...).foo.bar`,
17/// i.e. where the head of the context is a function call rather than an
18/// identifier.
19#[derive(Debug, PartialEq)]
20pub struct Context<'src> {
21    /// The individual parts of the context.
22    pub parts: Vec<SpannedExpr<'src>>,
23}
24
25impl<'src> Context<'src> {
26    pub(crate) fn new(parts: impl Into<Vec<SpannedExpr<'src>>>) -> Self {
27        Self {
28            parts: parts.into(),
29        }
30    }
31
32    /// Parse a context from the given string.
33    ///
34    /// Returns `None` if the string is not a valid context.
35    pub fn parse(raw: &'src str) -> Option<Self> {
36        let expr = Expr::parse(raw).ok()?;
37
38        match expr.inner {
39            Expr::Context(ctx) => Some(ctx),
40            _ => None,
41        }
42    }
43
44    /// Returns whether the context matches the given pattern exactly.
45    pub fn matches(&self, pattern: impl TryInto<ContextPattern<'src>>) -> bool {
46        let Ok(pattern) = pattern.try_into() else {
47            return false;
48        };
49
50        pattern.matches(self)
51    }
52
53    /// Returns whether the context is a child of the given pattern.
54    ///
55    /// A context is considered its own child, i.e. `foo.bar` is a child of
56    /// `foo.bar`.
57    pub fn child_of(&self, parent: impl TryInto<ContextPattern<'src>>) -> bool {
58        let Ok(parent) = parent.try_into() else {
59            return false;
60        };
61
62        parent.parent_of(self)
63    }
64
65    /// Return this context's "single tail," if it has one.
66    ///
67    /// This is useful primarily for contexts under `env` and `inputs`,
68    /// where we expect only a single tail part, e.g. `env.FOO` or
69    /// `inputs['bar']`.
70    ///
71    /// Returns `None` if the context has more than one tail part,
72    /// or if the context's head part is not an identifier.
73    pub fn single_tail(&self) -> Option<&str> {
74        if self.parts.len() != 2 || !matches!(*self.parts[0], Expr::Identifier(_)) {
75            return None;
76        }
77
78        match &self.parts[1].inner {
79            Expr::Identifier(ident) => Some(ident.as_str()),
80            Expr::Index(idx) => match &idx.inner {
81                Expr::Literal(Literal::String(idx)) => Some(idx),
82                _ => None,
83            },
84            _ => None,
85        }
86    }
87
88    /// Returns the "pattern equivalent" of this context.
89    ///
90    /// This is a string that can be used to efficiently match the context,
91    /// such as is done in `zizmor`'s template-injection audit via a
92    /// finite state transducer.
93    ///
94    /// Returns None if the context doesn't have a sensible pattern
95    /// equivalent, e.g. if it starts with a call.
96    pub fn as_pattern(&self) -> Option<String> {
97        fn push_part(part: &Expr<'_>, pattern: &mut String) {
98            match part {
99                Expr::Identifier(ident) => pattern.push_str(ident.0),
100                Expr::Star => pattern.push('*'),
101                Expr::Index(idx) => match &idx.inner {
102                    // foo['bar'] -> foo.bar
103                    Expr::Literal(Literal::String(idx)) => pattern.push_str(idx),
104                    // any kind of numeric or computed index, e.g.:
105                    // foo[0], foo[1 + 2], foo[bar]
106                    _ => pattern.push('*'),
107                },
108                _ => unreachable!("unexpected part in context pattern"),
109            }
110        }
111
112        // TODO: Optimization ideas:
113        // 1. Add a happy path for contexts that contain only
114        //    identifiers? Problem: case normalization.
115        // 2. Use `regex-automata` to return a case insensitive
116        //    automation here?
117        let mut pattern = String::new();
118
119        let mut parts = self.parts.iter().peekable();
120
121        let head = parts.next()?;
122        if matches!(**head, Expr::Call { .. }) {
123            return None;
124        }
125
126        push_part(head, &mut pattern);
127        for part in parts {
128            pattern.push('.');
129            push_part(part, &mut pattern);
130        }
131
132        pattern.make_ascii_lowercase();
133        Some(pattern)
134    }
135}
136
/// How a context relates to a pattern it was compared against.
///
/// "No relationship" is not a variant: the comparison returns `None`
/// in that case.
enum Comparison {
    /// The pattern ran out of parts first, e.g. `foo` against `foo.bar`.
    Child,
    /// The pattern and the context ran out of parts at the same time.
    Match,
}
141
/// A `ContextPattern` is a pattern that matches one or more contexts.
///
/// It uses a restricted subset of the syntax used by contexts themselves:
/// a pattern is always in dotted form and can only contain identifiers
/// and wildcards. A `*` stands for exactly one part of the context, and
/// parts are compared ASCII case-insensitively, so `foo.*` matches
/// `FOO.Bar` but not `foo.bar.baz`.
///
/// Indices are not allowed in patterns themselves, although contexts
/// that contain indices can be matched against patterns. For example,
/// `github.event.pull_request.assignees.*.name` will match the context
/// `github.event.pull_request.assignees[0].name`.
///
/// `PartialEq` is deliberately not derived: two patterns that differ only
/// in case match the same contexts but would compare unequal as strings.
#[derive(Debug, Clone, Copy)]
pub struct ContextPattern<'src>(
    // NOTE: Kept as a string as a potentially premature optimization;
    // re-parsing should be faster in terms of locality.
    // TODO: Vec instead?
    &'src str,
);
158
159impl<'src> TryFrom<&'src str> for ContextPattern<'src> {
160    type Error = InvalidContextPattern;
161
162    fn try_from(val: &'src str) -> Result<Self, Self::Error> {
163        Self::try_new(val).ok_or(InvalidContextPattern)
164    }
165}
166
167impl<'src> ContextPattern<'src> {
168    /// Creates a new [`ContextPattern`] from the given string.
169    ///
170    /// Panics if the pattern is invalid.
171    pub const fn new(pattern: &'src str) -> Self {
172        Self::try_new(pattern).expect("invalid context pattern; use try_new to handle errors")
173    }
174
175    /// Creates a new [`ContextPattern`] from the given string.
176    ///
177    /// Returns `None` if the pattern is invalid.
178    pub const fn try_new(pattern: &'src str) -> Option<Self> {
179        let raw_pattern = pattern.as_bytes();
180        if raw_pattern.is_empty() {
181            return None;
182        }
183
184        let len = raw_pattern.len();
185
186        // State machine:
187        // - accept_reg: whether the next character can be a regular identifier character
188        // - accept_dot: whether the next character can be a dot
189        // - accept_star: whether the next character can be a star
190        let mut accept_reg = true;
191        let mut accept_dot = false;
192        let mut accept_star = false;
193
194        let mut idx = 0;
195        while idx < len {
196            accept_dot = accept_dot && idx != len - 1;
197
198            match raw_pattern[idx] {
199                b'.' => {
200                    if !accept_dot {
201                        return None;
202                    }
203
204                    accept_reg = true;
205                    accept_dot = false;
206                    accept_star = true;
207                }
208                b'*' => {
209                    if !accept_star {
210                        return None;
211                    }
212
213                    accept_reg = false;
214                    accept_star = false;
215                    accept_dot = true;
216                }
217                c if c.is_ascii_alphanumeric() || c == b'-' || c == b'_' => {
218                    if !accept_reg {
219                        return None;
220                    }
221
222                    accept_reg = true;
223                    accept_dot = true;
224                    accept_star = false;
225                }
226                _ => return None, // invalid character
227            }
228
229            idx += 1;
230        }
231
232        Some(Self(pattern))
233    }
234
235    fn compare_part(pattern: &str, part: &Expr<'src>) -> bool {
236        if pattern == "*" {
237            true
238        } else {
239            match part {
240                Expr::Identifier(part) => pattern.eq_ignore_ascii_case(part.0),
241                Expr::Index(part) => match &part.inner {
242                    Expr::Literal(Literal::String(part)) => pattern.eq_ignore_ascii_case(part),
243                    _ => false,
244                },
245                _ => false,
246            }
247        }
248    }
249
250    fn compare(&self, ctx: &Context<'src>) -> Option<Comparison> {
251        let mut pattern_parts = self.0.split('.').peekable();
252        let mut ctx_parts = ctx.parts.iter().peekable();
253
254        while let (Some(pattern), Some(part)) = (pattern_parts.peek(), ctx_parts.peek()) {
255            if !Self::compare_part(pattern, part) {
256                return None;
257            }
258
259            pattern_parts.next();
260            ctx_parts.next();
261        }
262
263        match (pattern_parts.next(), ctx_parts.next()) {
264            // If both are exhausted, we have an exact match.
265            (None, None) => Some(Comparison::Match),
266            // If the pattern is exhausted but the context isn't, then
267            // the context is a child of the pattern.
268            (None, Some(_)) => Some(Comparison::Child),
269            _ => None,
270        }
271    }
272
273    /// Returns true if the given context is a child of the pattern.
274    ///
275    /// This is a loose parent-child relationship; for example, `foo` is its
276    /// own parent, as well as the parent of `foo.bar` and `foo.bar.baz`.
277    pub fn parent_of(&self, ctx: &Context<'src>) -> bool {
278        matches!(
279            self.compare(ctx),
280            Some(Comparison::Child | Comparison::Match)
281        )
282    }
283
284    /// Returns true if the given context exactly matches the pattern.
285    ///
286    /// See [`ContextPattern`] for a description of the matching rules.
287    pub fn matches(&self, ctx: &Context<'src>) -> bool {
288        matches!(self.compare(ctx), Some(Comparison::Match))
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::{Context, ContextPattern};

    #[test]
    fn test_context_child_of() {
        let ctx = Context::parse("foo.bar.baz").unwrap();

        for (case, child) in &[
            // Trivial child cases.
            ("foo", true),
            ("foo.bar", true),
            // Case-insensitive cases.
            ("FOO", true),
            ("FOO.BAR", true),
            ("Foo", true),
            ("Foo.Bar", true),
            // We consider a context to be a child of itself.
            ("foo.bar.baz", true),
            // Trivial non-child cases.
            ("foo.bar.baz.qux", false),
            ("foo.bar.qux", false),
            ("foo.qux", false),
            ("qux", false),
            // Invalid cases.
            ("foo.", false),
            (".", false),
            ("", false),
        ] {
            assert_eq!(ctx.child_of(*case), *child);
        }
    }

    #[test]
    fn test_context_matches() {
        let ctx = Context::parse("foo.bar").unwrap();

        assert!(ctx.matches("foo.bar"));
        assert!(ctx.matches("FOO.*"));
        // A parent pattern is not an exact match.
        assert!(!ctx.matches("foo"));
        // Invalid patterns never match, rather than panicking.
        assert!(!ctx.matches("foo."));
        assert!(!ctx.matches(""));
    }

    #[test]
    fn test_single_tail() {
        for (case, expected) in &[
            // Valid cases.
            ("foo.bar", Some("bar")),
            ("foo['bar']", Some("bar")),
            ("inputs.test", Some("test")),
            // Invalid cases.
            ("foo", None),                // no tail at all
            ("foo.bar.baz", None),        // too many parts
            ("foo.bar.baz.qux", None),    // too many parts
            ("foo['bar']['baz']", None),  // too many parts
            ("foo.*", None),              // tail is a wildcard
            ("foo[0]", None),             // tail is a non-string index
            ("fromJSON('{}').bar", None), // head is a call, not an identifier
        ] {
            let ctx = Context::parse(*case).unwrap();
            assert_eq!(ctx.single_tail(), *expected);
        }
    }

    #[test]
    fn test_context_as_pattern() {
        for (case, expected) in &[
            // Basic cases.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Case sensitivity.
            ("FOO", Some("foo")),
            ("FOO.BAR", Some("foo.bar")),
            ("FOO.BAR.BAZ", Some("foo.bar.baz")),
            ("FOO.BAR.BAZ_BAZ", Some("foo.bar.baz_baz")),
            ("FOO.BAR.BAZ-BAZ", Some("foo.bar.baz-baz")),
            ("FOO.*", Some("foo.*")),
            ("FOO.BAR.*", Some("foo.bar.*")),
            ("FOO.*.BAZ", Some("foo.*.baz")),
            ("FOO.*.*", Some("foo.*.*")),
            // Indexes.
            ("foo.bar.baz[0]", Some("foo.bar.baz.*")),
            ("foo.bar.baz['abc']", Some("foo.bar.baz.abc")),
            ("foo.bar.baz[0].qux", Some("foo.bar.baz.*.qux")),
            ("foo.bar.baz[0].qux[1]", Some("foo.bar.baz.*.qux.*")),
            ("foo[1][2][3]", Some("foo.*.*.*")),
            ("foo.bar[abc]", Some("foo.bar.*")),
            (
                "foo.bar[join(github.event.issue.labels.*.name, ', ')]",
                Some("foo.bar.*"),
            ),
            // Whitespace.
            ("foo . bar", Some("foo.bar")),
            ("foo . bar . baz", Some("foo.bar.baz")),
            ("foo . bar . baz_baz", Some("foo.bar.baz_baz")),
            ("foo . bar . baz-baz", Some("foo.bar.baz-baz")),
            ("foo .*", Some("foo.*")),
            ("foo . bar .*", Some("foo.bar.*")),
            ("foo .* . baz", Some("foo.*.baz")),
            ("foo .* .*", Some("foo.*.*")),
            // Invalid cases
            ("fromJSON('{}').bar", None),
        ] {
            let ctx = Context::parse(*case).unwrap();
            assert_eq!(ctx.as_pattern().as_deref(), *expected);
        }
    }

    #[test]
    fn test_contextpattern_new() {
        for (case, expected) in &[
            // Well-formed patterns.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Invalid patterns.
            ("", None),
            ("*", None),
            ("**", None),
            (".**", None),
            (".foo", None),
            ("foo.", None),
            (".foo.", None),
            ("foo.**", None),
            (".", None),
            ("foo.bar.", None),
            ("foo..bar", None),
            ("foo*", None),     // a wildcard must be a whole part
            ("foo.*bar", None), // a wildcard must be a whole part
            ("foo.*.", None),   // trailing dot after a wildcard
            ("foo bar", None),  // whitespace is not allowed
            ("foo.bar.baz[0]", None),
            ("foo.bar.baz['abc']", None),
            ("foo.bar.baz[0].qux", None),
            ("foo.bar.baz[0].qux[1]", None),
            ("❤", None),
            ("❤.*", None),
        ] {
            assert_eq!(ContextPattern::try_new(case).map(|p| p.0), *expected);
        }
    }

    #[test]
    fn test_contextpattern_new_const() {
        // `new` is a `const fn`, so valid patterns can be built at compile time.
        const PATTERN: ContextPattern<'static> = ContextPattern::new("foo.*.bar");
        assert_eq!(PATTERN.0, "foo.*.bar");
    }

    #[test]
    #[should_panic(expected = "invalid context pattern")]
    fn test_contextpattern_new_panics() {
        let _ = ContextPattern::new("foo.");
    }

    #[test]
    fn test_contextpattern_parent_of() {
        for (pattern, ctx, expected) in &[
            // Exact contains.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar", "foo['bar']", true),
            ("foo.bar", "foo['BAR']", true),
            // Parent relationships
            ("foo", "foo.bar", true),
            ("foo.bar", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz.qux", true),
            ("foo", "foo.bar.baz.qux", true),
            ("foo.*", "foo.bar.baz.qux", true),
            (
                "secrets",
                "fromJson(steps.runs.outputs.data).workflow_runs[0].id",
                false,
            ),
        ] {
            let pattern = ContextPattern::try_new(pattern).unwrap();
            let ctx = Context::parse(*ctx).unwrap();
            assert_eq!(pattern.parent_of(&ctx), *expected);
        }
    }

    #[test]
    fn test_context_pattern_matches() {
        for (pattern, ctx, expected) in &[
            // Normal matches.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar.baz", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz", true),
            ("foo.*.*.*", "foo.bar.baz.qux", true),
            // Case-insensitive matches.
            ("foo.bar", "FOO.BAR", true),
            ("foo.bar.baz", "Foo.Bar.Baz", true),
            ("foo.*", "FOO.BAR", true),
            ("foo.*.baz", "Foo.Bar.Baz", true),
            ("foo.*.*", "FOO.BAR.BAZ", true),
            ("FOO.BAR", "foo.bar", true),
            ("FOO.BAR.BAZ", "foo.bar.baz", true),
            ("FOO.*", "foo.bar", true),
            ("FOO.*.BAZ", "foo.bar.baz", true),
            ("FOO.*.*", "foo.bar.baz", true),
            // Indices also match correctly.
            ("foo.bar.baz.*", "foo.bar.baz[0]", true),
            ("foo.bar.baz.*", "foo.bar.baz[123]", true),
            ("foo.bar.baz.*", "foo.bar.baz['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['baz']['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['BAZ']['abc']", true),
            // Contexts containing stars match correctly.
            ("foo.bar.baz.*", "foo.bar.baz.*", true),
            ("foo.bar.*.*", "foo.bar.*.*", true),
            ("foo.bar.baz.qux", "foo.bar.baz.*", false), // patterns are one way
            ("foo.bar.baz.qux", "foo.bar.baz[*]", false), // patterns are one way
            // False normal matches.
            ("foo", "bar", false),                     // different identifier
            ("foo.bar", "foo.baz", false),             // different identifier
            ("foo.bar", "foo['baz']", false),          // different index
            ("foo.bar.baz", "foo.bar.baz.qux", false), // pattern too short
            ("foo.bar.baz", "foo.bar", false),         // context too short
            ("foo.*.baz", "foo.bar.baz.qux", false),   // pattern too short
            ("foo.*.qux", "foo.bar.baz.qux", false),   // * does not match multiple parts
            ("foo.*.*", "foo.bar.baz.qux", false),     // pattern too short
            ("foo.1", "foo[1]", false),                // .1 means a string key, not an index
        ] {
            let pattern = ContextPattern::try_new(pattern)
                .unwrap_or_else(|| panic!("invalid pattern: {pattern}"));
            let ctx = Context::parse(*ctx).unwrap();
            assert_eq!(pattern.matches(&ctx), *expected);
        }
    }
}