use crate::Literal;
use super::{Expr, SpannedExpr};
#[derive(Debug, thiserror::Error)]
#[error("invalid context pattern")]
pub struct InvalidContextPattern;
/// A context expression, e.g. `github.event.issue.title`, `foo['bar']`,
/// `foo.*.name` or `fromJSON(x).bar`.
#[derive(PartialEq, Debug)]
pub struct Context<'src> {
    /// The parts making up this context, ordered from head to tail.
    ///
    /// Non-head parts are expected to be identifiers, `*` wildcards or
    /// indices; the head may also be a function call.
    pub parts: Vec<SpannedExpr<'src>>,
}
impl<'src> Context<'src> {
    /// Creates a context from its parts.
    pub(crate) fn new(parts: impl Into<Vec<SpannedExpr<'src>>>) -> Self {
        Self {
            parts: parts.into(),
        }
    }

    /// Parses `raw` as an expression, returning it only if it is a context.
    ///
    /// Returns `None` if `raw` fails to parse, or if it parses to an
    /// expression that is not a context.
    pub fn parse(raw: &'src str) -> Option<Self> {
        let expr = Expr::parse(raw).ok()?;

        match expr.inner {
            Expr::Context(ctx) => Some(ctx),
            _ => None,
        }
    }

    /// Returns whether this context exactly matches `pattern`.
    ///
    /// Matching is ASCII case-insensitive, and a `*` pattern segment matches
    /// any single part. If `pattern` cannot be converted into a valid
    /// [`ContextPattern`], this returns `false`.
    pub fn matches(&self, pattern: impl TryInto<ContextPattern<'src>>) -> bool {
        let Ok(pattern) = pattern.try_into() else {
            return false;
        };

        pattern.matches(self)
    }

    /// Returns whether this context is `parent` itself or one of its
    /// descendants. For example, `foo.bar.baz` is a child of `foo`, of
    /// `foo.bar` and of `foo.bar.baz`.
    ///
    /// If `parent` cannot be converted into a valid [`ContextPattern`], this
    /// returns `false`.
    pub fn child_of(&self, parent: impl TryInto<ContextPattern<'src>>) -> bool {
        let Ok(parent) = parent.try_into() else {
            return false;
        };

        parent.parent_of(self)
    }

    /// Returns the tail of a two-part context whose head is an identifier.
    ///
    /// `foo.bar` and `foo['bar']` both yield `Some("bar")`. The result is
    /// `None` in every other case:
    /// - the context has a number of parts other than two;
    /// - the head is not an identifier (e.g. a function call);
    /// - the tail is an index that is not a string literal.
    pub fn single_tail(&self) -> Option<&str> {
        if self.parts.len() != 2 || !matches!(*self.parts[0], Expr::Identifier(_)) {
            return None;
        }

        match &self.parts[1].inner {
            Expr::Identifier(ident) => Some(ident.as_str()),
            Expr::Index(idx) => match &idx.inner {
                Expr::Literal(Literal::String(idx)) => Some(idx),
                _ => None,
            },
            _ => None,
        }
    }

    /// Normalizes this context into a lowercase, dot-separated pattern string.
    ///
    /// Identifiers and `*` are kept as-is. String-literal indices become plain
    /// segments (`foo['bar']` becomes `foo.bar`). Any other index (numeric,
    /// identifier or computed) becomes `*`. The whole result is ASCII-lowercased.
    ///
    /// Returns `None` if the context has no parts or its head is a function
    /// call.
    ///
    /// # Panics
    ///
    /// Panics if any part is not an identifier, `*` or an index; this check
    /// excludes a function-call head, which returns `None` instead.
    /// NOTE(review): this relies on the parser never producing such
    /// contexts. Confirm against the grammar.
    pub fn as_pattern(&self) -> Option<String> {
        /// Appends the pattern segment for a single context part.
        fn push_part(part: &Expr<'_>, pattern: &mut String) {
            match part {
                Expr::Identifier(ident) => pattern.push_str(ident.0),
                Expr::Star => pattern.push('*'),
                Expr::Index(idx) => match &idx.inner {
                    // foo['bar'] -> foo.bar
                    Expr::Literal(Literal::String(idx)) => pattern.push_str(idx),
                    // Numeric or computed indices (foo[0], foo[bar], ...)
                    // can't be resolved statically, so they become wildcards.
                    _ => pattern.push('*'),
                },
                _ => unreachable!("unexpected part in context pattern"),
            }
        }

        let mut pattern = String::new();
        let mut parts = self.parts.iter();

        let head = parts.next()?;
        // Contexts rooted in a call (e.g. `fromJSON(...).bar`) have no
        // static pattern form.
        if matches!(**head, Expr::Call { .. }) {
            return None;
        }

        push_part(head, &mut pattern);
        for part in parts {
            pattern.push('.');
            push_part(part, &mut pattern);
        }

        pattern.make_ascii_lowercase();
        Some(pattern)
    }
}
/// The result of successfully comparing a [`ContextPattern`] against a
/// [`Context`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Comparison {
    /// The context is a strict descendant of the pattern.
    Child,
    /// The context matches the pattern exactly, part for part.
    Match,
}
/// A validated, dot-separated pattern for matching [`Context`]s, such as
/// `foo.bar` or `foo.*.baz`.
///
/// Each segment is made of ASCII alphanumerics, `-` and `_`, or is a lone
/// `*` wildcard (except for the first segment). Comparisons against contexts
/// are ASCII case-insensitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ContextPattern<'src>(&'src str);
/// Fallible conversion from a raw pattern string.
///
/// Accepts exactly the strings that [`ContextPattern::try_new`] accepts.
impl<'src> TryFrom<&'src str> for ContextPattern<'src> {
    type Error = InvalidContextPattern;

    fn try_from(val: &'src str) -> Result<Self, Self::Error> {
        match Self::try_new(val) {
            Some(pattern) => Ok(pattern),
            None => Err(InvalidContextPattern),
        }
    }
}
impl<'src> ContextPattern<'src> {
    /// Creates a new [`ContextPattern`] from `pattern`. Usable in `const`
    /// contexts.
    ///
    /// # Panics
    ///
    /// Panics if `pattern` is not a valid context pattern. Use
    /// [`ContextPattern::try_new`] to handle invalid input instead.
    pub const fn new(pattern: &'src str) -> Self {
        match Self::try_new(pattern) {
            Some(valid) => valid,
            None => panic!("invalid context pattern; use try_new to handle errors"),
        }
    }

    /// Validates `pattern` and wraps it in a [`ContextPattern`].
    ///
    /// A valid pattern is a non-empty, `.`-separated list of non-empty
    /// segments. Each segment either:
    /// - consists of ASCII alphanumerics, `-` and `_`; or
    /// - is a lone `*` wildcard, which is not allowed as the first segment.
    ///
    /// Returns `None` if `pattern` is invalid.
    pub const fn try_new(pattern: &'src str) -> Option<Self> {
        // The kind of the previously consumed byte. It determines which
        // bytes may legally come next.
        #[derive(Clone, Copy)]
        enum Prev {
            // Nothing consumed yet.
            Start,
            // An identifier byte: ASCII alphanumeric, `-` or `_`.
            Ident,
            // A `.` separator.
            Dot,
            // A `*` wildcard.
            Star,
        }

        let mut remaining = pattern.as_bytes();
        if remaining.is_empty() {
            return None;
        }

        let mut prev = Prev::Start;
        while let [byte, rest @ ..] = remaining {
            prev = match *byte {
                // A separator must follow an identifier or a wildcard. It may
                // not end the pattern, since that would leave an empty segment.
                b'.' if matches!(prev, Prev::Ident | Prev::Star) && !rest.is_empty() => {
                    Prev::Dot
                }
                // A wildcard must be a whole segment, so it may only directly
                // follow a separator.
                b'*' if matches!(prev, Prev::Dot) => Prev::Star,
                // Identifier bytes may begin or continue a segment, but may
                // not be glued onto a wildcard.
                c if (c.is_ascii_alphanumeric() || c == b'-' || c == b'_')
                    && !matches!(prev, Prev::Star) =>
                {
                    Prev::Ident
                }
                // Anything else is invalid: a misplaced `.` or `*`, or a
                // byte outside the allowed set.
                _ => return None,
            };
            remaining = rest;
        }

        Some(Self(pattern))
    }

    /// Compares a single pattern segment against a single context part.
    ///
    /// A `*` segment matches any part. Any other segment matches, ASCII
    /// case-insensitively, either an identifier or a string-literal index.
    /// Every other kind of part never matches a non-wildcard segment.
    fn compare_part(pattern: &str, part: &Expr<'src>) -> bool {
        if pattern == "*" {
            return true;
        }

        match part {
            Expr::Identifier(ident) => pattern.eq_ignore_ascii_case(ident.0),
            Expr::Index(index) => match &index.inner {
                Expr::Literal(Literal::String(text)) => pattern.eq_ignore_ascii_case(text),
                _ => false,
            },
            _ => false,
        }
    }

    /// Walks this pattern's segments and `ctx`'s parts in lockstep.
    ///
    /// Returns:
    /// - [`Comparison::Match`] when both run out together;
    /// - [`Comparison::Child`] when the pattern runs out first;
    /// - `None` on the first mismatching pair, or when the context runs
    ///   out first.
    fn compare(&self, ctx: &Context<'src>) -> Option<Comparison> {
        let mut segments = self.0.split('.');
        let mut parts = ctx.parts.iter();

        loop {
            match (segments.next(), parts.next()) {
                (Some(segment), Some(part)) => {
                    if !Self::compare_part(segment, part) {
                        return None;
                    }
                }
                (None, None) => return Some(Comparison::Match),
                (None, Some(_)) => return Some(Comparison::Child),
                (Some(_), None) => return None,
            }
        }
    }

    /// Returns whether `ctx` is this pattern's context itself or one of its
    /// descendants.
    pub fn parent_of(&self, ctx: &Context<'src>) -> bool {
        match self.compare(ctx) {
            Some(Comparison::Child) | Some(Comparison::Match) => true,
            None => false,
        }
    }

    /// Returns whether `ctx` matches this pattern exactly, part for part.
    pub fn matches(&self, ctx: &Context<'src>) -> bool {
        self.compare(ctx)
            .is_some_and(|comparison| matches!(comparison, Comparison::Match))
    }
}
#[cfg(test)]
mod tests {
    use super::{Context, ContextPattern};

    #[test]
    fn test_context_child_of() {
        let ctx = Context::parse("foo.bar.baz").unwrap();

        for (case, child) in &[
            ("foo", true),
            ("foo.bar", true),
            ("FOO", true),
            ("FOO.BAR", true),
            ("Foo", true),
            ("Foo.Bar", true),
            ("foo.bar.baz", true),
            ("foo.bar.baz.qux", false),
            ("foo.bar.qux", false),
            ("foo.qux", false),
            ("qux", false),
            // Invalid patterns are never parents.
            ("foo.", false),
            (".", false),
            ("", false),
        ] {
            assert_eq!(ctx.child_of(*case), *child);
        }
    }

    #[test]
    fn test_context_matches() {
        let ctx = Context::parse("foo.bar.baz").unwrap();

        for (case, expected) in &[
            ("foo.bar.baz", true),
            ("FOO.BAR.BAZ", true),
            ("foo.*.baz", true),
            ("foo.*.*", true),
            // Ancestors and descendants are not exact matches.
            ("foo", false),
            ("foo.bar", false),
            ("foo.bar.baz.qux", false),
            ("foo.bar.qux", false),
            // Invalid patterns never match.
            ("foo.bar.baz.", false),
            ("", false),
        ] {
            assert_eq!(ctx.matches(*case), *expected, "{case}");
        }
    }

    #[test]
    fn test_single_tail() {
        for (case, expected) in &[
            ("foo.bar", Some("bar")),
            ("foo['bar']", Some("bar")),
            ("inputs.test", Some("test")),
            ("foo.bar.baz", None),
            ("foo.bar.baz.qux", None),
            ("foo['bar']['baz']", None),
            ("fromJSON('{}').bar", None),
        ] {
            let ctx = Context::parse(*case).unwrap();
            assert_eq!(ctx.single_tail(), *expected);
        }
    }

    #[test]
    fn test_context_as_pattern() {
        for (case, expected) in &[
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            ("FOO", Some("foo")),
            ("FOO.BAR", Some("foo.bar")),
            ("FOO.BAR.BAZ", Some("foo.bar.baz")),
            ("FOO.BAR.BAZ_BAZ", Some("foo.bar.baz_baz")),
            ("FOO.BAR.BAZ-BAZ", Some("foo.bar.baz-baz")),
            ("FOO.*", Some("foo.*")),
            ("FOO.BAR.*", Some("foo.bar.*")),
            ("FOO.*.BAZ", Some("foo.*.baz")),
            ("FOO.*.*", Some("foo.*.*")),
            ("foo.bar.baz[0]", Some("foo.bar.baz.*")),
            ("foo.bar.baz['abc']", Some("foo.bar.baz.abc")),
            ("foo.bar.baz[0].qux", Some("foo.bar.baz.*.qux")),
            ("foo.bar.baz[0].qux[1]", Some("foo.bar.baz.*.qux.*")),
            ("foo[1][2][3]", Some("foo.*.*.*")),
            ("foo.bar[abc]", Some("foo.bar.*")),
            (
                "foo.bar[join(github.event.issue.labels.*.name, ', ')]",
                Some("foo.bar.*"),
            ),
            ("foo . bar", Some("foo.bar")),
            ("foo . bar . baz", Some("foo.bar.baz")),
            ("foo . bar . baz_baz", Some("foo.bar.baz_baz")),
            ("foo . bar . baz-baz", Some("foo.bar.baz-baz")),
            ("foo .*", Some("foo.*")),
            ("foo . bar .*", Some("foo.bar.*")),
            ("foo .* . baz", Some("foo.*.baz")),
            ("foo .* .*", Some("foo.*.*")),
            ("fromJSON('{}').bar", None),
        ] {
            let ctx = Context::parse(*case).unwrap();
            assert_eq!(ctx.as_pattern().as_deref(), *expected);
        }
    }

    #[test]
    fn test_contextpattern_new() {
        for (case, expected) in &[
            ("foo", Some("foo")),
            ("foo.bar", Some("foo.bar")),
            ("foo.bar.baz", Some("foo.bar.baz")),
            ("foo.bar.baz_baz", Some("foo.bar.baz_baz")),
            ("foo.bar.baz-baz", Some("foo.bar.baz-baz")),
            ("foo.*", Some("foo.*")),
            ("foo.bar.*", Some("foo.bar.*")),
            ("foo.*.baz", Some("foo.*.baz")),
            ("foo.*.*", Some("foo.*.*")),
            ("", None),
            ("*", None),
            ("**", None),
            (".**", None),
            (".foo", None),
            ("foo.", None),
            (".foo.", None),
            ("foo.**", None),
            (".", None),
            ("foo.bar.", None),
            ("foo..bar", None),
            ("foo.bar.baz[0]", None),
            ("foo.bar.baz['abc']", None),
            ("foo.bar.baz[0].qux", None),
            ("foo.bar.baz[0].qux[1]", None),
            ("❤", None),
            ("❤.*", None),
        ] {
            assert_eq!(ContextPattern::try_new(case).map(|p| p.0), *expected);
        }
    }

    #[test]
    fn test_contextpattern_new_const() {
        // `new` is a `const fn`, so valid patterns can be built at compile time.
        const PATTERN: ContextPattern<'static> = ContextPattern::new("foo.*.baz");

        let ctx = Context::parse("foo.bar.baz").unwrap();
        assert!(PATTERN.matches(&ctx));
    }

    #[test]
    #[should_panic(expected = "invalid context pattern")]
    fn test_contextpattern_new_invalid_panics() {
        let _ = ContextPattern::new("foo.");
    }

    #[test]
    fn test_contextpattern_parent_of() {
        for (pattern, ctx, expected) in &[
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar", "foo['bar']", true),
            ("foo.bar", "foo['BAR']", true),
            ("foo", "foo.bar", true),
            ("foo.bar", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz.qux", true),
            ("foo", "foo.bar.baz.qux", true),
            ("foo.*", "foo.bar.baz.qux", true),
            (
                "secrets",
                "fromJson(steps.runs.outputs.data).workflow_runs[0].id",
                false,
            ),
        ] {
            let pattern = ContextPattern::try_new(pattern).unwrap();
            let ctx = Context::parse(*ctx).unwrap();
            assert_eq!(pattern.parent_of(&ctx), *expected);
        }
    }

    #[test]
    fn test_context_pattern_matches() {
        for (pattern, ctx, expected) in &[
            ("foo", "foo", true),
            ("foo.bar", "foo.bar", true),
            ("foo.bar.baz", "foo.bar.baz", true),
            ("foo.*", "foo.bar", true),
            ("foo.*.baz", "foo.bar.baz", true),
            ("foo.*.*", "foo.bar.baz", true),
            ("foo.*.*.*", "foo.bar.baz.qux", true),
            ("foo.bar", "FOO.BAR", true),
            ("foo.bar.baz", "Foo.Bar.Baz", true),
            ("foo.*", "FOO.BAR", true),
            ("foo.*.baz", "Foo.Bar.Baz", true),
            ("foo.*.*", "FOO.BAR.BAZ", true),
            ("FOO.BAR", "foo.bar", true),
            ("FOO.BAR.BAZ", "foo.bar.baz", true),
            ("FOO.*", "foo.bar", true),
            ("FOO.*.BAZ", "foo.bar.baz", true),
            ("FOO.*.*", "foo.bar.baz", true),
            ("foo.bar.baz.*", "foo.bar.baz[0]", true),
            ("foo.bar.baz.*", "foo.bar.baz[123]", true),
            ("foo.bar.baz.*", "foo.bar.baz['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['baz']['abc']", true),
            ("foo.bar.baz.*", "foo['bar']['BAZ']['abc']", true),
            ("foo.bar.baz.*", "foo.bar.baz.*", true),
            ("foo.bar.*.*", "foo.bar.*.*", true),
            // A wildcard part in the context only matches a wildcard segment.
            ("foo.bar.baz.qux", "foo.bar.baz.*", false),
            ("foo.bar.baz.qux", "foo.bar.baz[*]", false),
            ("foo", "bar", false),
            ("foo.bar", "foo.baz", false),
            ("foo.bar", "foo['baz']", false),
            // Length mismatches are never exact matches.
            ("foo.bar.baz", "foo.bar.baz.qux", false),
            ("foo.bar.baz", "foo.bar", false),
            ("foo.*.baz", "foo.bar.baz.qux", false),
            ("foo.*.qux", "foo.bar.baz.qux", false),
            ("foo.*.*", "foo.bar.baz.qux", false),
            // Numeric indices only match wildcard segments.
            ("foo.1", "foo[1]", false),
        ] {
            let pattern = ContextPattern::try_new(pattern)
                .unwrap_or_else(|| panic!("invalid pattern: {pattern}"));
            let ctx = Context::parse(*ctx).unwrap();
            assert_eq!(pattern.matches(&ctx), *expected);
        }
    }
}