1use crate::Literal;
5
6use super::{Expr, SpannedExpr};
7
8#[derive(Debug, thiserror::Error)]
10#[error("invalid context pattern")]
11pub struct InvalidContextPattern;
12
13#[derive(Debug, PartialEq)]
20pub struct Context<'src> {
21 pub parts: Vec<SpannedExpr<'src>>,
23}
24
25impl<'src> Context<'src> {
26 pub(crate) fn new(parts: impl Into<Vec<SpannedExpr<'src>>>) -> Self {
27 Self {
28 parts: parts.into(),
29 }
30 }
31
32 pub fn parse(raw: &'src str) -> Option<Self> {
36 let expr = Expr::parse(raw).ok()?;
37
38 match expr.inner {
39 Expr::Context(ctx) => Some(ctx),
40 _ => None,
41 }
42 }
43
44 pub fn matches(&self, pattern: impl TryInto<ContextPattern<'src>>) -> bool {
46 let Ok(pattern) = pattern.try_into() else {
47 return false;
48 };
49
50 pattern.matches(self)
51 }
52
53 pub fn child_of(&self, parent: impl TryInto<ContextPattern<'src>>) -> bool {
58 let Ok(parent) = parent.try_into() else {
59 return false;
60 };
61
62 parent.parent_of(self)
63 }
64
65 pub fn single_tail(&self) -> Option<&str> {
74 if self.parts.len() != 2 || !matches!(*self.parts[0], Expr::Identifier(_)) {
75 return None;
76 }
77
78 match &self.parts[1].inner {
79 Expr::Identifier(ident) => Some(ident.as_str()),
80 Expr::Index(idx) => match &idx.inner {
81 Expr::Literal(Literal::String(idx)) => Some(idx),
82 _ => None,
83 },
84 _ => None,
85 }
86 }
87
88 pub fn as_pattern(&self) -> Option<String> {
97 fn push_part(part: &Expr<'_>, pattern: &mut String) {
98 match part {
99 Expr::Identifier(ident) => pattern.push_str(ident.0),
100 Expr::Star => pattern.push('*'),
101 Expr::Index(idx) => match &idx.inner {
102 Expr::Literal(Literal::String(idx)) => pattern.push_str(idx),
104 _ => pattern.push('*'),
107 },
108 _ => unreachable!("unexpected part in context pattern"),
109 }
110 }
111
112 let mut pattern = String::new();
118
119 let mut parts = self.parts.iter().peekable();
120
121 let head = parts.next()?;
122 if matches!(**head, Expr::Call { .. }) {
123 return None;
124 }
125
126 push_part(head, &mut pattern);
127 for part in parts {
128 pattern.push('.');
129 push_part(part, &mut pattern);
130 }
131
132 pattern.make_ascii_lowercase();
133 Some(pattern)
134 }
135}
136
/// The outcome of a successful comparison between a [`ContextPattern`] and a
/// [`Context`].
enum Comparison {
    /// Pattern and context have the same number of parts, and every part matches.
    Match,
    /// Every pattern part matches, but the context has further trailing parts.
    Child,
}
141
/// A validated, dot-separated pattern for matching contexts, e.g. `foo.bar`
/// or `foo.*.baz`.
///
/// Each part is either a run of ASCII alphanumerics, `-` and `_`, or a single
/// `*` wildcard, which may not be the first part. Build one with
/// [`ContextPattern::new`], [`ContextPattern::try_new`] or `TryFrom<&str>`.
///
/// The pattern only borrows its text, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct ContextPattern<'src>(
    /// The raw pattern text, validated by `try_new` when built through the
    /// public constructors.
    &'src str,
);
158
159impl<'src> TryFrom<&'src str> for ContextPattern<'src> {
160 type Error = InvalidContextPattern;
161
162 fn try_from(val: &'src str) -> Result<Self, Self::Error> {
163 Self::try_new(val).ok_or(InvalidContextPattern)
164 }
165}
166
167impl<'src> ContextPattern<'src> {
168 pub const fn new(pattern: &'src str) -> Self {
172 Self::try_new(pattern).expect("invalid context pattern; use try_new to handle errors")
173 }
174
175 pub const fn try_new(pattern: &'src str) -> Option<Self> {
179 let raw_pattern = pattern.as_bytes();
180 if raw_pattern.is_empty() {
181 return None;
182 }
183
184 let len = raw_pattern.len();
185
186 let mut accept_reg = true;
191 let mut accept_dot = false;
192 let mut accept_star = false;
193
194 let mut idx = 0;
195 while idx < len {
196 accept_dot = accept_dot && idx != len - 1;
197
198 match raw_pattern[idx] {
199 b'.' => {
200 if !accept_dot {
201 return None;
202 }
203
204 accept_reg = true;
205 accept_dot = false;
206 accept_star = true;
207 }
208 b'*' => {
209 if !accept_star {
210 return None;
211 }
212
213 accept_reg = false;
214 accept_star = false;
215 accept_dot = true;
216 }
217 c if c.is_ascii_alphanumeric() || c == b'-' || c == b'_' => {
218 if !accept_reg {
219 return None;
220 }
221
222 accept_reg = true;
223 accept_dot = true;
224 accept_star = false;
225 }
226 _ => return None, }
228
229 idx += 1;
230 }
231
232 Some(Self(pattern))
233 }
234
235 fn compare_part(pattern: &str, part: &Expr<'src>) -> bool {
236 if pattern == "*" {
237 true
238 } else {
239 match part {
240 Expr::Identifier(part) => pattern.eq_ignore_ascii_case(part.0),
241 Expr::Index(part) => match &part.inner {
242 Expr::Literal(Literal::String(part)) => pattern.eq_ignore_ascii_case(part),
243 _ => false,
244 },
245 _ => false,
246 }
247 }
248 }
249
250 fn compare(&self, ctx: &Context<'src>) -> Option<Comparison> {
251 let mut pattern_parts = self.0.split('.').peekable();
252 let mut ctx_parts = ctx.parts.iter().peekable();
253
254 while let (Some(pattern), Some(part)) = (pattern_parts.peek(), ctx_parts.peek()) {
255 if !Self::compare_part(pattern, part) {
256 return None;
257 }
258
259 pattern_parts.next();
260 ctx_parts.next();
261 }
262
263 match (pattern_parts.next(), ctx_parts.next()) {
264 (None, None) => Some(Comparison::Match),
266 (None, Some(_)) => Some(Comparison::Child),
269 _ => None,
270 }
271 }
272
273 pub fn parent_of(&self, ctx: &Context<'src>) -> bool {
278 matches!(
279 self.compare(ctx),
280 Some(Comparison::Child | Comparison::Match)
281 )
282 }
283
284 pub fn matches(&self, ctx: &Context<'src>) -> bool {
288 matches!(self.compare(ctx), Some(Comparison::Match))
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::{Context, ContextPattern};

    #[test]
    fn test_context_child_of() {
        let ctx = Context::parse("foo.bar.baz").unwrap();

        for (case, child) in &[
            // Parents, including case-insensitive variants.
            ("foo", true),
            ("foo.bar", true),
            ("FOO", true),
            ("FOO.BAR", true),
            ("Foo", true),
            ("Foo.Bar", true),
            // A context is a child of itself.
            ("foo.bar.baz", true),
            // Non-parents.
            ("foo.bar.baz.qux", false),
            ("foo.bar.qux", false),
            ("foo.qux", false),
            ("qux", false),
            // Invalid patterns never match.
            ("foo.", false),
            (".", false),
            ("", false),
        ] {
            assert_eq!(ctx.child_of(*case), *child);
        }
    }

    #[test]
    fn test_context_matches() {
        let ctx = Context::parse("foo.bar.baz").unwrap();

        for (case, expected) in &[
            ("foo.bar.baz", true),
            ("FOO.BAR.BAZ", true),
            ("foo.*.baz", true),
            ("foo.*.*", true),
            // Parents and descendants are not exact matches.
            ("foo", false),
            ("foo.bar", false),
            ("foo.bar.baz.qux", false),
            // Invalid patterns never match.
            ("foo.bar.baz.", false),
            ("foo..bar.baz", false),
            ("", false),
        ] {
            assert_eq!(ctx.matches(*case), *expected, "pattern: {case}");
        }
    }

    #[test]
    fn test_single_tail() {
        for (case, expected) in &[
            ("foo.bar", Some("bar")),
            ("foo['bar']", Some("bar")),
            ("inputs.test", Some("test")),
            ("foo.bar.baz", None),
            ("foo.bar.baz.qux", None),
            ("foo['bar']['baz']", None),
            ("fromJSON('{}').bar", None),
        ] {
            let ctx = Context::parse(*case).unwrap();
            assert_eq!(ctx.single_tail(), *expected);
        }
    }

    #[test]
    fn test_context_as_pattern() {
        for (case, expected) in &[
            // Plain identifiers and stars.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Case is normalized to lowercase.
            ("FOO", Some("foo")),
            ("FOO.BAR", Some("foo.bar")),
            ("FOO.BAR.BAZ", Some("foo.bar.baz")),
            ("FOO.BAR.BAZ_BAZ", Some("foo.bar.baz_baz")),
            ("FOO.BAR.BAZ-BAZ", Some("foo.bar.baz-baz")),
            ("FOO.*", Some("foo.*")),
            ("FOO.BAR.*", Some("foo.bar.*")),
            ("FOO.*.BAZ", Some("foo.*.baz")),
            ("FOO.*.*", Some("foo.*.*")),
            // Indices.
            ("foo.bar.baz[0]", Some("foo.bar.baz.*")),
            ("foo.bar.baz['abc']", Some("foo.bar.baz.abc")),
            ("foo.bar.baz[0].qux", Some("foo.bar.baz.*.qux")),
            ("foo.bar.baz[0].qux[1]", Some("foo.bar.baz.*.qux.*")),
            ("foo[1][2][3]", Some("foo.*.*.*")),
            ("foo.bar[abc]", Some("foo.bar.*")),
            (
                "foo.bar[join(github.event.issue.labels.*.name, ', ')]",
                Some("foo.bar.*"),
            ),
            // Whitespace around dots.
            ("foo . bar", Some("foo.bar")),
            ("foo . bar . baz", Some("foo.bar.baz")),
            ("foo . bar . baz_baz", Some("foo.bar.baz_baz")),
            ("foo . bar . baz-baz", Some("foo.bar.baz-baz")),
            ("foo .*", Some("foo.*")),
            ("foo . bar .*", Some("foo.bar.*")),
            ("foo .* . baz", Some("foo.*.baz")),
            ("foo .* .*", Some("foo.*.*")),
            // Function-call heads have no pattern.
            ("fromJSON('{}').bar", None),
        ] {
            let ctx = Context::parse(*case).unwrap();
            assert_eq!(ctx.as_pattern().as_deref(), *expected);
        }
    }

    #[test]
    fn test_contextpattern_new() {
        for (case, expected) in &[
            // Valid patterns.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Invalid patterns.
            ("", None),
            ("*", None),
            ("**", None),
            (".**", None),
            (".foo", None),
            ("foo.", None),
            (".foo.", None),
            ("foo.**", None),
            (".", None),
            ("foo.bar.", None),
            ("foo..bar", None),
            ("foo.bar.baz[0]", None),
            ("foo.bar.baz['abc']", None),
            ("foo.bar.baz[0].qux", None),
            ("foo.bar.baz[0].qux[1]", None),
            ("❤", None),
            ("❤.*", None),
        ] {
            assert_eq!(ContextPattern::try_new(case).map(|p| p.0), *expected);
        }
    }

    #[test]
    fn test_contextpattern_new_const() {
        const PATTERN: ContextPattern<'static> = ContextPattern::new("foo.*.baz");
        assert_eq!(PATTERN.0, "foo.*.baz");
    }

    #[test]
    #[should_panic(expected = "invalid context pattern")]
    fn test_contextpattern_new_invalid_panics() {
        let _ = ContextPattern::new("foo.");
    }

    #[test]
    fn test_contextpattern_parent_of() {
        for (pattern, ctx, expected) in &[
            // Exact matches are parents.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar", "foo['bar']", true),
            ("foo.bar", "foo['BAR']", true),
            // Strict parents.
            ("foo", "foo.bar", true),
            ("foo.bar", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz.qux", true),
            ("foo", "foo.bar.baz.qux", true),
            ("foo.*", "foo.bar.baz.qux", true),
            (
                "secrets",
                "fromJson(steps.runs.outputs.data).workflow_runs[0].id",
                false,
            ),
        ] {
            let pattern = ContextPattern::try_new(pattern).unwrap();
            let ctx = Context::parse(*ctx).unwrap();
            assert_eq!(pattern.parent_of(&ctx), *expected);
        }
    }

    #[test]
    fn test_context_pattern_matches() {
        for (pattern, ctx, expected) in &[
            // Exact and wildcard matches.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar.baz", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz", true),
            ("foo.*.*.*", "foo.bar.baz.qux", true),
            // Case-insensitive matching, in both directions.
            ("foo.bar", "FOO.BAR", true),
            ("foo.bar.baz", "Foo.Bar.Baz", true),
            ("foo.*", "FOO.BAR", true),
            ("foo.*.baz", "Foo.Bar.Baz", true),
            ("foo.*.*", "FOO.BAR.BAZ", true),
            ("FOO.BAR", "foo.bar", true),
            ("FOO.BAR.BAZ", "foo.bar.baz", true),
            ("FOO.*", "foo.bar", true),
            ("FOO.*.BAZ", "foo.bar.baz", true),
            ("FOO.*.*", "foo.bar.baz", true),
            // Indices.
            ("foo.bar.baz.*", "foo.bar.baz[0]", true),
            ("foo.bar.baz.*", "foo.bar.baz[123]", true),
            ("foo.bar.baz.*", "foo.bar.baz['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['baz']['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['BAZ']['abc']", true),
            // Stars in the context.
            ("foo.bar.baz.*", "foo.bar.baz.*", true),
            ("foo.bar.*.*", "foo.bar.*.*", true),
            // Non-matches.
            ("foo.bar.baz.qux", "foo.bar.baz.*", false),
            ("foo.bar.baz.qux", "foo.bar.baz[*]", false),
            ("foo", "bar", false),
            ("foo.bar", "foo.baz", false),
            ("foo.bar", "foo['baz']", false),
            ("foo.bar.baz", "foo.bar.baz.qux", false),
            ("foo.bar.baz", "foo.bar", false),
            ("foo.*.baz", "foo.bar.baz.qux", false),
            ("foo.*.qux", "foo.bar.baz.qux", false),
            ("foo.*.*", "foo.bar.baz.qux", false),
            ("foo.1", "foo[1]", false),
        ] {
            let pattern = ContextPattern::try_new(pattern)
                .unwrap_or_else(|| panic!("invalid pattern: {pattern}"));
            let ctx = Context::parse(*ctx).unwrap();
            assert_eq!(pattern.matches(&ctx), *expected);
        }
    }
}