1use crate::Literal;
5
6use super::{Expr, SpannedExpr};
7
8#[derive(Debug, PartialEq)]
15pub struct Context<'src> {
16 pub parts: Vec<SpannedExpr<'src>>,
18}
19
20impl<'src> Context<'src> {
21 pub(crate) fn new(parts: impl Into<Vec<SpannedExpr<'src>>>) -> Self {
22 Self {
23 parts: parts.into(),
24 }
25 }
26
27 pub fn matches(&self, pattern: impl TryInto<ContextPattern<'src>>) -> bool {
29 let Ok(pattern) = pattern.try_into() else {
30 return false;
31 };
32
33 pattern.matches(self)
34 }
35
36 pub fn child_of(&self, parent: impl TryInto<ContextPattern<'src>>) -> bool {
41 let Ok(parent) = parent.try_into() else {
42 return false;
43 };
44
45 parent.parent_of(self)
46 }
47
48 pub fn single_tail(&self) -> Option<&str> {
57 if self.parts.len() != 2 || !matches!(*self.parts[0], Expr::Identifier(_)) {
58 return None;
59 }
60
61 match &self.parts[1].inner {
62 Expr::Identifier(ident) => Some(ident.as_str()),
63 Expr::Index(idx) => match &idx.inner {
64 Expr::Literal(Literal::String(idx)) => Some(idx),
65 _ => None,
66 },
67 _ => None,
68 }
69 }
70
71 pub fn as_pattern(&self) -> Option<String> {
80 fn push_part(part: &Expr<'_>, pattern: &mut String) {
81 match part {
82 Expr::Identifier(ident) => pattern.push_str(ident.0),
83 Expr::Star => pattern.push('*'),
84 Expr::Index(idx) => match &idx.inner {
85 Expr::Literal(Literal::String(idx)) => pattern.push_str(idx),
87 _ => pattern.push('*'),
90 },
91 _ => unreachable!("unexpected part in context pattern"),
92 }
93 }
94
95 let mut pattern = String::new();
101
102 let mut parts = self.parts.iter().peekable();
103
104 let head = parts.next()?;
105 if matches!(**head, Expr::Call { .. }) {
106 return None;
107 }
108
109 push_part(head, &mut pattern);
110 for part in parts {
111 pattern.push('.');
112 push_part(part, &mut pattern);
113 }
114
115 pattern.make_ascii_lowercase();
116 Some(pattern)
117 }
118}
119
/// The outcome of successfully comparing a pattern against a context.
enum Comparison {
    /// The pattern matched a strict prefix of the context.
    Child,
    /// The pattern matched the context in its entirety.
    Match,
}
124
/// A pattern over dotted context paths, such as `foo.bar` or `foo.*.baz`.
///
/// The wrapped field is private, so outside this module a pattern can only be
/// obtained through the validating constructors ([`ContextPattern::new`],
/// [`ContextPattern::try_new`], or `TryFrom<&str>`).
///
/// The type is a thin wrapper around a borrowed string, so it is cheap to
/// copy. Equality is intentionally not derived: pattern comparison is
/// ASCII-case-insensitive, which a derived byte-wise `PartialEq` would not
/// reflect.
#[derive(Debug, Clone, Copy)]
pub struct ContextPattern<'src>(
    // The raw pattern text; it is split on `.` each time a comparison runs.
    &'src str,
);
141
142impl<'src> TryFrom<&'src str> for ContextPattern<'src> {
143 type Error = anyhow::Error;
144
145 fn try_from(val: &'src str) -> anyhow::Result<Self> {
146 Self::try_new(val).ok_or_else(|| anyhow::anyhow!("invalid context pattern"))
147 }
148}
149
150impl<'src> ContextPattern<'src> {
151 pub const fn new(pattern: &'src str) -> Self {
155 Self::try_new(pattern).expect("invalid context pattern; use try_new to handle errors")
156 }
157
158 pub const fn try_new(pattern: &'src str) -> Option<Self> {
162 let raw_pattern = pattern.as_bytes();
163 if raw_pattern.is_empty() {
164 return None;
165 }
166
167 let len = raw_pattern.len();
168
169 let mut accept_reg = true;
174 let mut accept_dot = false;
175 let mut accept_star = false;
176
177 let mut idx = 0;
178 while idx < len {
179 accept_dot = accept_dot && idx != len - 1;
180
181 match raw_pattern[idx] {
182 b'.' => {
183 if !accept_dot {
184 return None;
185 }
186
187 accept_reg = true;
188 accept_dot = false;
189 accept_star = true;
190 }
191 b'*' => {
192 if !accept_star {
193 return None;
194 }
195
196 accept_reg = false;
197 accept_star = false;
198 accept_dot = true;
199 }
200 c if c.is_ascii_alphanumeric() || c == b'-' || c == b'_' => {
201 if !accept_reg {
202 return None;
203 }
204
205 accept_reg = true;
206 accept_dot = true;
207 accept_star = false;
208 }
209 _ => return None, }
211
212 idx += 1;
213 }
214
215 Some(Self(pattern))
216 }
217
218 fn compare_part(pattern: &str, part: &Expr<'src>) -> bool {
219 if pattern == "*" {
220 true
221 } else {
222 match part {
223 Expr::Identifier(part) => pattern.eq_ignore_ascii_case(part.0),
224 Expr::Index(part) => match &part.inner {
225 Expr::Literal(Literal::String(part)) => pattern.eq_ignore_ascii_case(part),
226 _ => false,
227 },
228 _ => false,
229 }
230 }
231 }
232
233 fn compare(&self, ctx: &Context<'src>) -> Option<Comparison> {
234 let mut pattern_parts = self.0.split('.').peekable();
235 let mut ctx_parts = ctx.parts.iter().peekable();
236
237 while let (Some(pattern), Some(part)) = (pattern_parts.peek(), ctx_parts.peek()) {
238 if !Self::compare_part(pattern, part) {
239 return None;
240 }
241
242 pattern_parts.next();
243 ctx_parts.next();
244 }
245
246 match (pattern_parts.next(), ctx_parts.next()) {
247 (None, None) => Some(Comparison::Match),
249 (None, Some(_)) => Some(Comparison::Child),
252 _ => None,
253 }
254 }
255
256 pub fn parent_of(&self, ctx: &Context<'src>) -> bool {
261 matches!(
262 self.compare(ctx),
263 Some(Comparison::Child | Comparison::Match)
264 )
265 }
266
267 pub fn matches(&self, ctx: &Context<'src>) -> bool {
271 matches!(self.compare(ctx), Some(Comparison::Match))
272 }
273}
274
#[cfg(test)]
mod tests {
    use crate::Expr;

    use super::{Context, ContextPattern};

    // Test-only helper: parse a string and require the result to be a context.
    impl<'a> TryFrom<&'a str> for Context<'a> {
        type Error = anyhow::Error;

        fn try_from(val: &'a str) -> anyhow::Result<Self> {
            let parsed = Expr::parse(val)?;

            match parsed.inner {
                Expr::Context(ctx) => Ok(ctx),
                _ => Err(anyhow::anyhow!("expected context, found {:?}", parsed)),
            }
        }
    }

    #[test]
    fn test_context_child_of() {
        let ctx = Context::try_from("foo.bar.baz").unwrap();

        let cases = [
            // Prefixes of the context, in any casing.
            ("foo", true),
            ("foo.bar", true),
            ("FOO", true),
            ("FOO.BAR", true),
            ("Foo", true),
            ("Foo.Bar", true),
            // The context itself.
            ("foo.bar.baz", true),
            // Longer than, or diverging from, the context.
            ("foo.bar.baz.qux", false),
            ("foo.bar.qux", false),
            ("foo.qux", false),
            ("qux", false),
            // Invalid patterns.
            ("foo.", false),
            (".", false),
            ("", false),
        ];

        for (parent, expected) in cases {
            assert_eq!(ctx.child_of(parent), expected);
        }
    }

    #[test]
    fn test_single_tail() {
        let cases = [
            // Two parts with an identifier head.
            ("foo.bar", Some("bar")),
            ("foo['bar']", Some("bar")),
            ("inputs.test", Some("test")),
            // Too many parts, or a head that is not an identifier.
            ("foo.bar.baz", None),
            ("foo.bar.baz.qux", None),
            ("foo['bar']['baz']", None),
            ("foo().bar", None),
        ];

        for (source, expected) in cases {
            let ctx = Context::try_from(source).unwrap();
            assert_eq!(ctx.single_tail(), expected);
        }
    }

    #[test]
    fn test_context_as_pattern() {
        let cases = [
            // Plain contexts render as themselves.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Casing is normalized to lowercase.
            ("FOO", Some("foo")),
            ("FOO.BAR", Some("foo.bar")),
            ("FOO.BAR.BAZ", Some("foo.bar.baz")),
            ("FOO.BAR.BAZ_BAZ", Some("foo.bar.baz_baz")),
            ("FOO.BAR.BAZ-BAZ", Some("foo.bar.baz-baz")),
            ("FOO.*", Some("foo.*")),
            ("FOO.BAR.*", Some("foo.bar.*")),
            ("FOO.*.BAZ", Some("foo.*.baz")),
            ("FOO.*.*", Some("foo.*.*")),
            // Indices: string literals are kept, everything else becomes `*`.
            ("foo.bar.baz[0]", Some("foo.bar.baz.*")),
            ("foo.bar.baz['abc']", Some("foo.bar.baz.abc")),
            ("foo.bar.baz[0].qux", Some("foo.bar.baz.*.qux")),
            ("foo.bar.baz[0].qux[1]", Some("foo.bar.baz.*.qux.*")),
            ("foo[1][2][3]", Some("foo.*.*.*")),
            ("foo.bar[abc]", Some("foo.bar.*")),
            ("foo.bar[abc()]", Some("foo.bar.*")),
            // Whitespace around separators does not appear in the output.
            ("foo . bar", Some("foo.bar")),
            ("foo . bar . baz", Some("foo.bar.baz")),
            ("foo . bar . baz_baz", Some("foo.bar.baz_baz")),
            ("foo . bar . baz-baz", Some("foo.bar.baz-baz")),
            ("foo .*", Some("foo.*")),
            ("foo . bar .*", Some("foo.bar.*")),
            ("foo .* . baz", Some("foo.*.baz")),
            ("foo .* .*", Some("foo.*.*")),
            // A call in head position has no pattern.
            ("foo().bar", None),
        ];

        for (source, expected) in cases {
            let ctx = Context::try_from(source).unwrap();
            assert_eq!(ctx.as_pattern().as_deref(), expected);
        }
    }

    #[test]
    fn test_contextpattern_new() {
        let cases = [
            // Valid patterns are stored verbatim.
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            // Invalid patterns are rejected.
            ("", None),
            ("*", None),
            ("**", None),
            (".**", None),
            (".foo", None),
            ("foo.", None),
            (".foo.", None),
            ("foo.**", None),
            (".", None),
            ("foo.bar.", None),
            ("foo..bar", None),
            ("foo.bar.baz[0]", None),
            ("foo.bar.baz['abc']", None),
            ("foo.bar.baz[0].qux", None),
            ("foo.bar.baz[0].qux[1]", None),
            ("❤", None),
            ("❤.*", None),
        ];

        for (source, expected) in cases {
            assert_eq!(ContextPattern::try_new(source).map(|p| p.0), expected);
        }
    }

    #[test]
    fn test_contextpattern_parent_of() {
        let cases = [
            // Exact matches count as parents.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar", "foo['bar']", true),
            ("foo.bar", "foo['BAR']", true),
            // Proper prefixes, with and without wildcards.
            ("foo", "foo.bar", true),
            ("foo.bar", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz.qux", true),
            ("foo", "foo.bar.baz.qux", true),
            ("foo.*", "foo.bar.baz.qux", true),
            // A context rooted in a call is not a child of a named pattern.
            (
                "secrets",
                "fromJson(steps.runs.outputs.data).workflow_runs[0].id",
                false,
            ),
        ];

        for (raw_pattern, raw_ctx, expected) in cases {
            let compiled = ContextPattern::try_new(raw_pattern).unwrap();
            let parsed = Context::try_from(raw_ctx).unwrap();
            assert_eq!(compiled.parent_of(&parsed), expected);
        }
    }

    #[test]
    fn test_context_pattern_matches() {
        let cases = [
            // Exact and wildcard matches.
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar.baz", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz", true),
            ("foo.*.*.*", "foo.bar.baz.qux", true),
            // Casing is ignored on both sides.
            ("foo.bar", "FOO.BAR", true),
            ("foo.bar.baz", "Foo.Bar.Baz", true),
            ("foo.*", "FOO.BAR", true),
            ("foo.*.baz", "Foo.Bar.Baz", true),
            ("foo.*.*", "FOO.BAR.BAZ", true),
            ("FOO.BAR", "foo.bar", true),
            ("FOO.BAR.BAZ", "foo.bar.baz", true),
            ("FOO.*", "foo.bar", true),
            ("FOO.*.BAZ", "foo.bar.baz", true),
            ("FOO.*.*", "foo.bar.baz", true),
            // Indices in the context.
            ("foo.bar.baz.*", "foo.bar.baz[0]", true),
            ("foo.bar.baz.*", "foo.bar.baz[123]", true),
            ("foo.bar.baz.*", "foo.bar.baz['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['baz']['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['BAZ']['abc']", true),
            // Stars in the context.
            ("foo.bar.baz.*", "foo.bar.baz.*", true),
            ("foo.bar.*.*", "foo.bar.*.*", true),
            // Non-matches.
            ("foo.bar.baz.qux", "foo.bar.baz.*", false),
            ("foo.bar.baz.qux", "foo.bar.baz[*]", false),
            ("foo", "bar", false),
            ("foo.bar", "foo.baz", false),
            ("foo.bar", "foo['baz']", false),
            ("foo.bar.baz", "foo.bar.baz.qux", false),
            ("foo.bar.baz", "foo.bar", false),
            ("foo.*.baz", "foo.bar.baz.qux", false),
            ("foo.*.qux", "foo.bar.baz.qux", false),
            ("foo.*.*", "foo.bar.baz.qux", false),
            ("foo.1", "foo[1]", false),
        ];

        for (pattern, raw_ctx, expected) in cases {
            let compiled = ContextPattern::try_new(pattern)
                .unwrap_or_else(|| panic!("invalid pattern: {pattern}"));
            let parsed = Context::try_from(raw_ctx).unwrap();
            assert_eq!(compiled.matches(&parsed), expected);
        }
    }
}