use proc_macro2::TokenStream;
use quote::quote;
/// Compiles `pattern` into a boolean expression over an in-scope
/// `value: &str` binding. The emitted expression evaluates to `true`
/// when `value` does NOT match the pattern (a "reject" flag), matching
/// the convention used by all the emitters below.
///
/// Returns `None` when the pattern uses syntax the parser does not
/// support, so the caller can fall back to another strategy.
pub fn emit_pattern_check(pattern: &str) -> Option<TokenStream> {
    parse_pattern(pattern).map(|segments| emit_segments(&segments))
}
/// One parsed element of a pattern, in source order.
#[derive(Debug, Clone, PartialEq)]
enum Segment {
    /// A bracketed class such as `[A-Za-z0-9]`, with its repetition.
    CharClass {
        ranges: Vec<CharRange>,
        quantifier: Quantifier,
    },
    /// A single byte matched exactly once (a bare or `\`-escaped char).
    /// Characters with non-trivial quantifiers are stored as a
    /// one-element `CharClass` instead — see `parse_pattern`.
    Literal(u8),
    /// A parenthesized sub-pattern `(...)`, with its repetition.
    Group {
        segments: Vec<Segment>,
        quantifier: Quantifier,
    },
}
/// One alternative inside a character class.
#[derive(Debug, Clone, PartialEq)]
enum CharRange {
    /// Inclusive byte range, e.g. `A-Z` becomes `Range(b'A', b'Z')`.
    Range(u8, u8),
    /// A single byte, e.g. `x` or an escaped `\+`.
    Single(u8),
}
/// Repetition attached to a segment.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Quantifier {
    /// `{n}` / `{n,n}`, or the implicit single occurrence when no
    /// quantifier follows the segment.
    Exact(usize),
    /// Inclusive `{min,max}`; `*` is `Range(0, usize::MAX)` and `?` is
    /// `Range(0, 1)` — see `parse_quantifier`.
    Range(usize, usize),
    /// `+`.
    OneOrMore,
}
fn parse_pattern(pattern: &str) -> Option<Vec<Segment>> {
let bytes = pattern.as_bytes();
let mut pos = 0;
let mut segments = Vec::new();
while pos < bytes.len() {
match bytes[pos] {
b'[' => {
let (ranges, end) = parse_char_class(bytes, pos)?;
pos = end;
let quantifier = parse_quantifier(bytes, &mut pos);
segments.push(Segment::CharClass { ranges, quantifier });
}
b'(' => {
let (inner, end) = parse_group(bytes, pos)?;
pos = end;
let quantifier = parse_quantifier(bytes, &mut pos);
segments.push(Segment::Group {
segments: inner,
quantifier,
});
}
b'\\' => {
pos += 1;
if pos >= bytes.len() {
return None;
}
let ch = bytes[pos];
pos += 1;
let quantifier = parse_quantifier(bytes, &mut pos);
match quantifier {
Quantifier::Exact(1) => segments.push(Segment::Literal(ch)),
_ => {
segments.push(Segment::CharClass {
ranges: vec![CharRange::Single(ch)],
quantifier,
});
}
}
}
ch => {
pos += 1;
let quantifier = parse_quantifier(bytes, &mut pos);
match quantifier {
Quantifier::Exact(1) => segments.push(Segment::Literal(ch)),
_ => {
segments.push(Segment::CharClass {
ranges: vec![CharRange::Single(ch)],
quantifier,
});
}
}
}
}
}
Some(segments)
}
/// Reads one byte at `*pos`, treating a leading `\` as an escape that
/// yields the following byte verbatim. Advances `*pos` past everything
/// consumed; returns `None` when the input ends before a byte (or the
/// escaped byte) is available.
fn read_char_or_escape(bytes: &[u8], pos: &mut usize) -> Option<u8> {
    if bytes.get(*pos) == Some(&b'\\') {
        *pos += 1;
    }
    let ch = *bytes.get(*pos)?;
    *pos += 1;
    Some(ch)
}
/// Parses a character class starting at `start` (which must point at
/// `[`). Returns the parsed alternatives plus the index just past the
/// closing `]`.
///
/// Returns `None` for an unterminated class, a truncated escape, or an
/// empty class `[]` — an empty class can never match any byte, and
/// letting it through would make `emit_char_predicate_negated` fold
/// zero checks into an empty (syntactically invalid) predicate.
fn parse_char_class(bytes: &[u8], start: usize) -> Option<(Vec<CharRange>, usize)> {
    debug_assert_eq!(bytes[start], b'[');
    let mut pos = start + 1;
    let mut ranges = Vec::new();
    while pos < bytes.len() && bytes[pos] != b']' {
        let ch = read_char_or_escape(bytes, &mut pos)?;
        // `a-z` style range: a `-` followed by something other than the
        // closing bracket. A trailing dash (as in `[a-]`) is a literal.
        if pos < bytes.len()
            && bytes[pos] == b'-'
            && pos + 1 < bytes.len()
            && bytes[pos + 1] != b']'
        {
            pos += 1;
            let end_ch = read_char_or_escape(bytes, &mut pos)?;
            ranges.push(CharRange::Range(ch, end_ch));
        } else {
            ranges.push(CharRange::Single(ch));
        }
    }
    // Ran off the end without seeing `]`.
    if pos >= bytes.len() {
        return None;
    }
    // Reject the degenerate empty class `[]`.
    if ranges.is_empty() {
        return None;
    }
    pos += 1;
    Some((ranges, pos))
}
/// Parses a parenthesized group starting at `start` (which must point
/// at `(`). The group body is parsed recursively via `parse_pattern`;
/// returns the inner segments plus the index just past the closing `)`.
/// Returns `None` for an unbalanced group or an unparsable body.
fn parse_group(bytes: &[u8], start: usize) -> Option<(Vec<Segment>, usize)> {
    debug_assert_eq!(bytes[start], b'(');
    // Scan forward to the matching `)`, honoring nesting and `\` escapes.
    let mut pos = start + 1;
    let mut depth = 1usize;
    while pos < bytes.len() {
        match bytes[pos] {
            b'(' => depth += 1,
            b')' => {
                depth -= 1;
                if depth == 0 {
                    // `pos` rests on the matching close paren.
                    break;
                }
            }
            // Skip the escaped byte so an escaped paren is not counted.
            b'\\' => pos += 1,
            _ => {}
        }
        pos += 1;
    }
    if depth != 0 {
        return None;
    }
    let inner_str = std::str::from_utf8(&bytes[start + 1..pos]).ok()?;
    let inner_segments = parse_pattern(inner_str)?;
    Some((inner_segments, pos + 1))
}
/// Parses an optional quantifier at `*pos`, advancing past it when one
/// is present. With no quantifier (or at end of input) the segment
/// occurs exactly once.
///
/// Parsing is deliberately lenient: an unterminated `{` leaves `*pos`
/// untouched and returns `Exact(1)` (the `{` is then consumed later as
/// a literal by the caller's main loop), and unparsable counts fall
/// back to permissive defaults rather than failing the whole pattern.
fn parse_quantifier(bytes: &[u8], pos: &mut usize) -> Quantifier {
    if *pos >= bytes.len() {
        return Quantifier::Exact(1);
    }
    match bytes[*pos] {
        b'{' => {
            let start = *pos + 1;
            let mut end = start;
            while end < bytes.len() && bytes[end] != b'}' {
                end += 1;
            }
            if end >= bytes.len() {
                // No closing `}` — not a quantifier.
                return Quantifier::Exact(1);
            }
            let content = std::str::from_utf8(&bytes[start..end]).unwrap_or("");
            *pos = end + 1;
            if let Some(comma_pos) = content.find(',') {
                let min_str = &content[..comma_pos];
                let max_str = &content[comma_pos + 1..];
                let min = min_str.parse::<usize>().unwrap_or(0);
                // `{n,}` (open upper bound) means "at least n". The old
                // code let the failed parse collapse to `max == min`,
                // silently turning `{3,}` into exactly-3.
                let max = if max_str.trim().is_empty() {
                    usize::MAX
                } else {
                    max_str.parse::<usize>().unwrap_or(min)
                };
                if min == max {
                    Quantifier::Exact(min)
                } else {
                    Quantifier::Range(min, max)
                }
            } else {
                let n = content.parse::<usize>().unwrap_or(1);
                Quantifier::Exact(n)
            }
        }
        b'+' => {
            *pos += 1;
            Quantifier::OneOrMore
        }
        b'*' => {
            *pos += 1;
            Quantifier::Range(0, usize::MAX)
        }
        b'?' => {
            *pos += 1;
            Quantifier::Range(0, 1)
        }
        _ => Quantifier::Exact(1),
    }
}
fn emit_segments(segments: &[Segment]) -> TokenStream {
let (min_len, max_len) = compute_length_bounds(segments);
if min_len == max_len {
emit_fixed_length(segments, min_len)
} else {
emit_variable_length(segments, min_len, max_len)
}
}
/// Computes the (min, max) number of bytes a segment list can consume.
/// Uses saturating arithmetic throughout so unbounded quantifiers
/// (`usize::MAX`) behave as "infinity" instead of overflowing.
fn compute_length_bounds(segments: &[Segment]) -> (usize, usize) {
    segments.iter().fold((0usize, 0usize), |(lo_acc, hi_acc), seg| {
        let (seg_lo, seg_hi) = match seg {
            Segment::Literal(_) => (1, 1),
            Segment::CharClass { quantifier, .. } => quantifier_bounds(*quantifier),
            Segment::Group {
                segments: inner,
                quantifier,
            } => {
                // A group's contribution is its inner bounds scaled by
                // how many times the group may repeat.
                let (inner_lo, inner_hi) = compute_length_bounds(inner);
                let (reps_lo, reps_hi) = quantifier_bounds(*quantifier);
                (
                    inner_lo.saturating_mul(reps_lo),
                    inner_hi.saturating_mul(reps_hi),
                )
            }
        };
        (
            lo_acc.saturating_add(seg_lo),
            hi_acc.saturating_add(seg_hi),
        )
    })
}
/// Maps a quantifier to the inclusive (min, max) repetition count it
/// allows; `usize::MAX` stands in for "unbounded".
fn quantifier_bounds(q: Quantifier) -> (usize, usize) {
    match q {
        Quantifier::OneOrMore => (1, usize::MAX),
        Quantifier::Range(lo, hi) => (lo, hi),
        Quantifier::Exact(count) => (count, count),
    }
}
/// Builds a predicate over an in-scope byte binding `b` that is `true`
/// when `b` matches NONE of `ranges` (callers treat `true` as
/// "reject this byte"). The alternatives are AND-folded because the
/// class matches when ANY range matches.
fn emit_char_predicate_negated(ranges: &[CharRange]) -> TokenStream {
    // An empty class can never match any byte. Folding zero checks with
    // `#(#checks)&&*` would yield an empty TokenStream, and splicing
    // that into `if #pred { ... }` produces invalid generated code —
    // so emit an unconditional reject instead.
    if ranges.is_empty() {
        return quote! { true };
    }
    let checks: Vec<TokenStream> = ranges
        .iter()
        .map(|r| match r {
            CharRange::Range(lo, hi) => {
                let lo_lit = *lo;
                let hi_lit = *hi;
                quote! { !(#lo_lit..=#hi_lit).contains(&b) }
            }
            CharRange::Single(ch) => {
                let ch_lit = *ch;
                quote! { b != #ch_lit }
            }
        })
        .collect();
    if checks.len() == 1 {
        checks.into_iter().next().unwrap()
    } else {
        quote! { #(#checks)&&* }
    }
}
/// Emits the reject-expression for a pattern whose total length is a
/// compile-time constant: one exact length comparison, OR-folded with
/// one negated test per byte position.
fn emit_fixed_length(segments: &[Segment], len: usize) -> TokenStream {
    let mut offset = 0usize;
    let mut byte_checks = Vec::new();
    emit_fixed_checks(segments, &mut offset, &mut byte_checks);
    let len_check = quote! { bytes.len() != #len };
    // With no byte checks (pure-length pattern) the length test stands
    // alone; otherwise any failing byte check also rejects.
    let body = if byte_checks.is_empty() {
        len_check
    } else {
        quote! { #len_check || #(#byte_checks)||* }
    };
    quote! {
        {
            let bytes = value.as_bytes();
            #body
        }
    }
}
/// Appends one negated per-byte test for each position a fully exact
/// pattern consumes, advancing `offset` as positions are assigned.
/// Precondition (enforced by the caller's path selection): every
/// quantifier, including inside groups, is `Quantifier::Exact`.
fn emit_fixed_checks(segments: &[Segment], offset: &mut usize, checks: &mut Vec<TokenStream>) {
    for seg in segments {
        match seg {
            Segment::Literal(byte) => {
                let idx = *offset;
                let lit = *byte;
                *offset += 1;
                checks.push(quote! { bytes[#idx] != #lit });
            }
            Segment::CharClass { ranges, quantifier } => {
                let reps = match quantifier {
                    Quantifier::Exact(n) => *n,
                    _ => unreachable!("fixed-length path only for exact quantifiers"),
                };
                let neg_pred = emit_char_predicate_negated(ranges);
                // One independent check per repeated position.
                for _ in 0..reps {
                    let idx = *offset;
                    *offset += 1;
                    checks.push(quote! { ({ let b = bytes[#idx]; #neg_pred }) });
                }
            }
            Segment::Group {
                segments: inner,
                quantifier,
            } => {
                let reps = match quantifier {
                    Quantifier::Exact(n) => *n,
                    _ => unreachable!("fixed-length path only for exact quantifiers"),
                };
                // An exact group repetition is the inner sequence unrolled.
                for _ in 0..reps {
                    emit_fixed_checks(inner, offset, checks);
                }
            }
        }
    }
}
/// Emits the reject-expression for a variable-length pattern: a cheap
/// length pre-check, then cursor-based segment matching inside an
/// immediately-invoked closure (so the generated checks can use
/// `return true` to reject early), then a full-consumption check.
fn emit_variable_length(segments: &[Segment], min_len: usize, max_len: usize) -> TokenStream {
    let mut stmts = Vec::new();
    // When the pattern has a finite upper bound, check both ends of the
    // length range up front; otherwise only the lower bound applies.
    let len_guard = if max_len < usize::MAX {
        quote! {
            if !(#min_len..=#max_len).contains(&len) {
                return true;
            }
        }
    } else {
        quote! {
            if len < #min_len {
                return true;
            }
        }
    };
    stmts.push(len_guard);
    emit_cursor_stmts(segments, &mut stmts);
    // The whole input must have been consumed for a match.
    stmts.push(quote! {
        if pos != len {
            return true;
        }
    });
    quote! {
        {
            let bytes = value.as_bytes();
            let len = bytes.len();
            let result: bool = (|| -> bool {
                let mut pos: usize = 0;
                #(#stmts)*
                false
            })();
            result
        }
    }
}
/// Appends statements that walk a byte cursor `pos` over `bytes` while
/// checking `segments` (the generated code assumes `bytes`, `len`, and
/// a mutable `pos` are in scope). Following the emitter convention,
/// `return true` in the generated code means "value rejected".
/// Matching is greedy with no backtracking across segment boundaries.
///
/// Group repetition loops carry a zero-width-match guard: when an
/// iteration succeeds without advancing `pos` (the inner pattern
/// matched the empty string, e.g. `([a]*)+`), the loop stops instead of
/// spinning forever — the previous code emitted an infinite loop for
/// such patterns.
fn emit_cursor_stmts(segments: &[Segment], stmts: &mut Vec<TokenStream>) {
    for seg in segments {
        match seg {
            Segment::Literal(ch) => {
                let ch_lit = *ch;
                stmts.push(quote! {
                    if pos >= len || bytes[pos] != #ch_lit {
                        return true;
                    }
                    pos += 1;
                });
            }
            Segment::CharClass { ranges, quantifier } => {
                let neg_pred = emit_char_predicate_negated(ranges);
                match quantifier {
                    Quantifier::Exact(n) => {
                        // Exactly n bytes, each of which must satisfy
                        // the class.
                        let n_lit = *n;
                        stmts.push(quote! {
                            {
                                let end = pos + #n_lit;
                                if end > len {
                                    return true;
                                }
                                for &b in &bytes[pos..end] {
                                    if #neg_pred {
                                        return true;
                                    }
                                }
                                pos = end;
                            }
                        });
                    }
                    Quantifier::Range(min, max) => {
                        let min_lit = *min;
                        let max_lit = *max;
                        // Only track the start position when a lower
                        // bound actually needs enforcing.
                        let (start_binding, min_guard) = if min_lit > 0 {
                            (
                                quote! { let start = pos; },
                                quote! {
                                    let matched = pos - start;
                                    if matched < #min_lit {
                                        return true;
                                    }
                                },
                            )
                        } else {
                            (TokenStream::new(), TokenStream::new())
                        };
                        if *max == usize::MAX {
                            // Unbounded: consume while the class matches.
                            stmts.push(quote! {
                                {
                                    #start_binding
                                    while pos < len {
                                        let b = bytes[pos];
                                        if #neg_pred {
                                            break;
                                        }
                                        pos += 1;
                                    }
                                    #min_guard
                                }
                            });
                        } else {
                            // Bounded: never consume more than `max` bytes.
                            stmts.push(quote! {
                                {
                                    #start_binding
                                    let limit = if pos + #max_lit < len { pos + #max_lit } else { len };
                                    while pos < limit {
                                        let b = bytes[pos];
                                        if #neg_pred {
                                            break;
                                        }
                                        pos += 1;
                                    }
                                    #min_guard
                                }
                            });
                        }
                    }
                    Quantifier::OneOrMore => {
                        let min_lit = 1usize;
                        stmts.push(quote! {
                            {
                                let start = pos;
                                while pos < len {
                                    let b = bytes[pos];
                                    if #neg_pred {
                                        break;
                                    }
                                    pos += 1;
                                }
                                let matched = pos - start;
                                if matched < #min_lit {
                                    return true;
                                }
                            }
                        });
                    }
                }
            }
            Segment::Group {
                segments: inner,
                quantifier,
            } => match quantifier {
                Quantifier::Exact(n) => {
                    // An exact repetition is the inner sequence unrolled.
                    for _ in 0..*n {
                        emit_cursor_stmts(inner, stmts);
                    }
                }
                Quantifier::Range(min, max) => {
                    let min_lit = *min;
                    if *min == 0 && *max == 1 {
                        // `(...)?`: try once; restore the cursor if the
                        // inner pattern fails. (The closure returns
                        // `true` on failure, hence the `failed` name —
                        // the old generated binding was misleadingly
                        // called `matched`.)
                        let mut inner_stmts = Vec::new();
                        emit_cursor_stmts(inner, &mut inner_stmts);
                        stmts.push(quote! {
                            {
                                let saved = pos;
                                let failed: bool = (|| -> bool {
                                    #(#inner_stmts)*
                                    false
                                })();
                                if failed {
                                    pos = saved;
                                }
                            }
                        });
                    } else {
                        let max_lit = *max;
                        let mut inner_stmts = Vec::new();
                        emit_cursor_stmts(inner, &mut inner_stmts);
                        let min_guard = if min_lit > 0 {
                            quote! {
                                if count < #min_lit {
                                    return true;
                                }
                            }
                        } else {
                            TokenStream::new()
                        };
                        stmts.push(quote! {
                            {
                                let mut count = 0usize;
                                while count < #max_lit {
                                    let saved = pos;
                                    let failed: bool = (|| -> bool {
                                        #(#inner_stmts)*
                                        false
                                    })();
                                    if failed {
                                        pos = saved;
                                        break;
                                    }
                                    count += 1;
                                    // Zero-width iteration: no further
                                    // progress is possible, stop instead
                                    // of counting up to `max`.
                                    if pos == saved {
                                        break;
                                    }
                                }
                                #min_guard
                            }
                        });
                    }
                }
                Quantifier::OneOrMore => {
                    let mut inner_stmts = Vec::new();
                    emit_cursor_stmts(inner, &mut inner_stmts);
                    stmts.push(quote! {
                        {
                            let mut count = 0usize;
                            loop {
                                let saved = pos;
                                let failed: bool = (|| -> bool {
                                    #(#inner_stmts)*
                                    false
                                })();
                                if failed {
                                    pos = saved;
                                    break;
                                }
                                count += 1;
                                // Zero-width match guard: without this,
                                // an inner pattern that matches the
                                // empty string loops forever.
                                if pos == saved {
                                    break;
                                }
                            }
                            if count < 1 {
                                return true;
                            }
                        }
                    });
                }
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `pattern` compiles to a TokenStream which is valid
    /// Rust when spliced into a `fn check(value: &str) -> bool` body
    /// (syntactic validity only — see `assert_pattern_semantics` for
    /// behavioral checks).
    fn assert_parses(pattern: &str) {
        let ts = emit_pattern_check(pattern)
            .unwrap_or_else(|| panic!("emit_pattern_check returned None for: {pattern}"));
        let full = quote! {
            fn check(value: &str) -> bool {
                #ts
            }
        };
        let src = full.to_string();
        syn::parse_file(&src).unwrap_or_else(|e| {
            panic!("invalid Rust for pattern '{pattern}':\n{src}\nerror: {e}");
        });
    }

    // --- Parser unit tests -------------------------------------------

    #[test]
    fn parse_simple_char_class() {
        let segs = parse_pattern("[A-Z]{3,3}").unwrap();
        assert_eq!(segs.len(), 1);
        // `{3,3}` is normalized to `Exact(3)` by parse_quantifier.
        assert!(matches!(
            &segs[0],
            Segment::CharClass {
                ranges,
                quantifier: Quantifier::Exact(3)
            } if ranges.len() == 1
        ));
    }

    #[test]
    fn parse_literal() {
        let segs = parse_pattern("-").unwrap();
        assert_eq!(segs.len(), 1);
        assert!(matches!(&segs[0], Segment::Literal(b'-')));
    }

    #[test]
    fn parse_escaped_plus() {
        // Escaped `+` inside a class, with an optional quantifier.
        let segs = parse_pattern(r"[\+]{0,1}").unwrap();
        assert_eq!(segs.len(), 1);
        assert!(matches!(
            &segs[0],
            Segment::CharClass {
                quantifier: Quantifier::Range(0, 1),
                ..
            }
        ));
    }

    #[test]
    fn parse_group() {
        let segs = parse_pattern("([A-Z]{2}){3}").unwrap();
        assert_eq!(segs.len(), 1);
        assert!(matches!(
            &segs[0],
            Segment::Group {
                quantifier: Quantifier::Exact(3),
                ..
            }
        ));
    }

    #[test]
    fn parse_one_or_more() {
        let segs = parse_pattern("[0-9a-fA-F]+").unwrap();
        assert_eq!(segs.len(), 1);
        assert!(matches!(
            &segs[0],
            Segment::CharClass {
                quantifier: Quantifier::OneOrMore,
                ..
            }
        ));
    }

    // --- Syntactic validity for a corpus of realistic patterns -------

    #[test]
    fn pattern_01_three_uppercase() {
        assert_parses("[A-Z]{3,3}");
    }
    #[test]
    fn pattern_02_two_uppercase() {
        assert_parses("[A-Z]{2,2}");
    }
    #[test]
    fn pattern_03_two_lowercase() {
        assert_parses("[a-z]{2,2}");
    }
    #[test]
    fn pattern_04_one_digit() {
        assert_parses("[0-9]");
    }
    #[test]
    fn pattern_05_two_digits() {
        assert_parses("[0-9]{2}");
    }
    #[test]
    fn pattern_06_three_digits() {
        assert_parses("[0-9]{3}");
    }
    #[test]
    fn pattern_07_one_to_three_digits() {
        assert_parses("[0-9]{1,3}");
    }
    #[test]
    fn pattern_08_one_to_five_digits() {
        assert_parses("[0-9]{1,5}");
    }
    #[test]
    fn pattern_09_two_to_three_digits() {
        assert_parses("[0-9]{2,3}");
    }
    #[test]
    fn pattern_10_three_to_four_digits() {
        assert_parses("[0-9]{3,4}");
    }
    #[test]
    fn pattern_11_one_to_fifteen_digits() {
        assert_parses("[0-9]{1,15}");
    }
    #[test]
    fn pattern_12_eight_to_twentyeight_digits() {
        assert_parses("[0-9]{8,28}");
    }
    #[test]
    fn pattern_13_four_alnum() {
        assert_parses("[a-zA-Z0-9]{4}");
    }
    #[test]
    fn pattern_14_hex_plus() {
        assert_parses("[0-9a-fA-F]+");
    }
    #[test]
    fn pattern_15_iban() {
        assert_parses("[A-Z]{2,2}[0-9]{2,2}[a-zA-Z0-9]{1,30}");
    }
    #[test]
    fn pattern_16_lei() {
        assert_parses("[A-Z]{2,2}[A-Z0-9]{9,9}[0-9]{1,1}");
    }
    #[test]
    fn pattern_17_eighteen_alnum_two_digit() {
        assert_parses("[A-Z0-9]{18,18}[0-9]{2,2}");
    }
    #[test]
    fn pattern_18_bic() {
        assert_parses("[A-Z0-9]{4,4}[A-Z]{2,2}[A-Z0-9]{2,2}([A-Z0-9]{3,3}){0,1}");
    }
    #[test]
    fn pattern_19_64_hex_upper() {
        assert_parses("([0-9A-F][0-9A-F]){32}");
    }
    #[test]
    fn pattern_20_uuid_v4() {
        assert_parses("[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}");
    }
    #[test]
    fn pattern_21_optional_plus_digits() {
        assert_parses(r"[\+]{0,1}[0-9]{1,15}");
    }
    #[test]
    fn pattern_22_phone() {
        assert_parses(r"\+[0-9]{1,3}-[0-9()+\-]{1,30}");
    }

    // --- Emission-path selection -------------------------------------

    #[test]
    fn fixed_length_pattern_emits_exact_len_check() {
        let ts = emit_pattern_check("[A-Z]{3,3}").unwrap();
        let src = ts.to_string();
        assert!(src.contains("!= 3"), "should check exact length 3: {src}");
    }

    #[test]
    fn variable_length_pattern_uses_cursor() {
        let ts = emit_pattern_check("[0-9]{1,3}").unwrap();
        let src = ts.to_string();
        assert!(src.contains("pos"), "should use cursor variable: {src}");
    }

    #[test]
    fn unsupported_pattern_returns_none() {
        // Unterminated char class.
        assert!(emit_pattern_check("[A-Z").is_none());
    }

    #[test]
    fn empty_pattern_returns_some() {
        let ts = emit_pattern_check("").unwrap();
        assert!(!ts.is_empty());
    }

    #[test]
    fn bic_pattern_handles_optional_suffix() {
        let ts =
            emit_pattern_check("[A-Z0-9]{4,4}[A-Z]{2,2}[A-Z0-9]{2,2}([A-Z0-9]{3,3}){0,1}").unwrap();
        let src = ts.to_string();
        assert!(src.contains("pos"), "BIC pattern should use cursor: {src}");
    }

    /// End-to-end semantic check: generates a standalone program that
    /// wraps the emitted expression in `fn check`, asserts each
    /// (input, should_accept) case in its `main`, then compiles and
    /// runs it with `rustc` (the emitted expression is `true` on
    /// REJECT, hence the inverted asserts). Requires `rustc` on PATH.
    fn assert_pattern_semantics(pattern: &str, cases: &[(&str, bool)]) {
        let ts = emit_pattern_check(pattern)
            .unwrap_or_else(|| panic!("emit_pattern_check returned None for: {pattern}"));
        let check_fn = quote! {
            fn check(value: &str) -> bool {
                #ts
            }
        };
        let mut main_body = String::new();
        // Escape braces so the pattern can appear inside format! output.
        let pat_escaped = pattern.replace('{', "{{").replace('}', "}}");
        for (i, (input, should_accept)) in cases.iter().enumerate() {
            // Escape the input for embedding in a string literal.
            let escaped = input.replace('\\', "\\\\").replace('"', "\\\"");
            if *should_accept {
                main_body.push_str(&format!(
                    " assert!(!check(\"{escaped}\"), \"case {i}: '{escaped}' should be accepted by pattern '{pat_escaped}' but was rejected\");\n"
                ));
            } else {
                main_body.push_str(&format!(
                    " assert!(check(\"{escaped}\"), \"case {i}: '{escaped}' should be rejected by pattern '{pat_escaped}' but was accepted\");\n"
                ));
            }
        }
        let program = format!("{}\nfn main() {{\n{main_body}}}\n", check_fn);
        // Scratch dir keyed by pid + pattern length to reduce collisions
        // between concurrently running tests.
        let dir = std::env::temp_dir().join(format!(
            "pattern_semantic_test_{}_{}",
            std::process::id(),
            pattern.len()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        let src_path = dir.join("test.rs");
        let bin_path = dir.join("test_bin");
        std::fs::write(&src_path, &program).unwrap();
        let compile = std::process::Command::new("rustc")
            .args(["--edition", "2021", "-o"])
            .arg(&bin_path)
            .arg(&src_path)
            .output()
            .expect("rustc not found");
        assert!(
            compile.status.success(),
            "compilation failed for pattern '{pattern}':\n--- source ---\n{program}\n--- stderr ---\n{}",
            String::from_utf8_lossy(&compile.stderr)
        );
        let run = std::process::Command::new(&bin_path).output().unwrap();
        assert!(
            run.status.success(),
            "runtime assertion failed for pattern '{pattern}':\n{}",
            String::from_utf8_lossy(&run.stderr)
        );
        // Best-effort cleanup; ignore errors.
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn semantic_three_uppercase() {
        assert_pattern_semantics(
            "[A-Z]{3}",
            &[
                ("USD", true),
                ("EUR", true),
                ("ABC", true),
                ("US", false),
                ("USDD", false),
                ("us1", false),
                ("123", false),
                ("", false),
            ],
        );
    }

    #[test]
    fn semantic_iban_pattern() {
        assert_pattern_semantics(
            "[A-Z]{2,2}[0-9]{2,2}[a-zA-Z0-9]{1,30}",
            &[
                ("GB82WEST12345698765432", true),
                ("DE89X", true),
                ("AB12x", true),
                ("gb82x", false),
                ("AB", false),
                ("AB12", false),
                ("", false),
                ("12XX1234", false),
            ],
        );
    }

    #[test]
    fn semantic_date_with_dashes() {
        assert_pattern_semantics(
            "[0-9]{4}-[0-9]{2}-[0-9]{2}",
            &[
                ("2026-03-04", true),
                ("0000-00-00", true),
                ("9999-12-31", true),
                ("26-3-4", false),
                ("2026/03/04", false),
                ("YYYY-MM-DD", false),
                ("", false),
            ],
        );
    }

    #[test]
    fn semantic_bic_with_optional_suffix() {
        assert_pattern_semantics(
            "[A-Z0-9]{4,4}[A-Z]{2,2}[A-Z0-9]{2,2}([A-Z0-9]{3,3}){0,1}",
            &[
                ("AAAAGB2L", true),
                ("AAAAGB2LXXX", true),
                ("DEUTDEFF500", true),
                ("AAAAGB2LX", false),
                ("aaaagb2l", false),
                ("", false),
                ("AAAAGB2LXXXX", false),
            ],
        );
    }
}