use std::collections::HashMap;
use std::sync::Arc;
use crate::lexer::{Delimiter, Keyword, TokenKind};
use super::{
MacroError, MacroPattern, MacroResult, MetaVarKind, PatternElement, RepetitionKind, TokenTree,
};
/// What a metavariable captured during one successful pattern match.
#[derive(Debug, Clone)]
pub enum Binding {
    /// A metavariable matched outside of any repetition.
    Single(BindingValue),
    /// One entry per iteration of an enclosing repetition. Entries are
    /// themselves `Repeated` when the metavariable sits inside nested
    /// repetitions.
    Repeated(Vec<Binding>),
}
/// The token trees captured for a single metavariable match.
#[derive(Debug, Clone)]
pub enum BindingValue {
    /// Exactly one tree, produced for the `TokenTree`, `Ident`, `Literal`
    /// and `Lifetime` metavariable kinds.
    TokenTree(TokenTree),
    /// A non-empty run of trees, produced for the multi-token kinds
    /// (`Expr`, `Type`, `Path`, `Pat`, `Stmt`, `Block`, `Item`, `Meta`, `Vis`).
    TokenTrees(Vec<TokenTree>),
}
impl Binding {
    /// Returns the captured value when this is a non-repeated binding.
    pub fn as_single(&self) -> Option<&BindingValue> {
        if let Binding::Single(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Returns the per-iteration bindings when this came from a repetition.
    pub fn as_repeated(&self) -> Option<&[Binding]> {
        if let Binding::Repeated(items) = self {
            Some(items.as_slice())
        } else {
            None
        }
    }

    /// Number of values held: `1` for a single binding, otherwise the
    /// number of repetition iterations.
    pub fn count(&self) -> usize {
        self.as_repeated().map_or(1, <[Binding]>::len)
    }
}
/// All captures from one successful pattern match, keyed by metavariable
/// name.
#[derive(Debug, Clone, Default)]
pub struct Bindings {
    /// Metavariable name -> what it captured.
    pub(crate) bindings: HashMap<Arc<str>, Binding>,
}
impl Bindings {
pub fn new() -> Self {
Self::default()
}
pub fn insert(&mut self, name: Arc<str>, binding: Binding) {
self.bindings.insert(name, binding);
}
pub fn get(&self, name: &str) -> Option<&Binding> {
self.bindings.get(name)
}
pub fn merge(&mut self, other: Bindings) {
self.bindings.extend(other.bindings);
}
}
/// Cursor that walks a slice of token trees and matches macro pattern
/// elements against it.
pub struct PatternMatcher<'a> {
    /// The token trees being matched.
    input: &'a [TokenTree],
    /// Index of the next unconsumed tree in `input`.
    pos: usize,
}
impl<'a> PatternMatcher<'a> {
    /// Creates a matcher positioned at the first tree of `input`.
    pub fn new(input: &'a [TokenTree]) -> Self {
        Self { input, pos: 0 }
    }

    /// Matches `pattern` against the whole input.
    ///
    /// # Errors
    ///
    /// Returns `MacroError::UnexpectedToken` when any element fails to match,
    /// or when input remains after the last element.
    pub fn match_pattern(&mut self, pattern: &MacroPattern) -> MacroResult<Bindings> {
        let mut bindings = Bindings::new();
        for element in &pattern.elements {
            self.match_element(element, &mut bindings)?;
        }
        self.expect_end("end of macro input")?;
        Ok(bindings)
    }

    /// Matches a single pattern element at the cursor and records any
    /// captures in `bindings`.
    fn match_element(
        &mut self,
        element: &PatternElement,
        bindings: &mut Bindings,
    ) -> MacroResult<()> {
        match element {
            PatternElement::Token(kind) => self.expect_token(kind),
            PatternElement::MetaVar { name, kind } => {
                let value = self.match_metavar(*kind)?;
                bindings.insert(name.clone(), Binding::Single(value));
                Ok(())
            }
            PatternElement::Repetition {
                elements,
                separator,
                repetition,
            } => self.match_repetition(elements, separator.as_ref(), *repetition, bindings),
            PatternElement::Delimited {
                delimiter,
                elements,
            } => self.match_delimited(*delimiter, elements, bindings),
        }
    }

    /// Captures the input for one metavariable of the given `kind`.
    ///
    /// Single-tree kinds (`TokenTree`, `Ident`, `Literal`, `Lifetime`) take
    /// exactly one tree. Every other kind collects trees until
    /// `at_fragment_end` reports a terminator.
    fn match_metavar(&mut self, kind: MetaVarKind) -> MacroResult<BindingValue> {
        if self.at_end() {
            return Err(MacroError::UnexpectedToken {
                expected: format!("{:?}", kind),
                found: TokenKind::Eof,
            });
        }
        match kind {
            MetaVarKind::TokenTree => {
                let tree = self.input[self.pos].clone();
                self.pos += 1;
                Ok(BindingValue::TokenTree(tree))
            }
            MetaVarKind::Ident => {
                self.expect_ident()?;
                Ok(BindingValue::TokenTree(self.input[self.pos - 1].clone()))
            }
            MetaVarKind::Literal => {
                self.expect_literal()?;
                Ok(BindingValue::TokenTree(self.input[self.pos - 1].clone()))
            }
            MetaVarKind::Lifetime => {
                self.expect_lifetime()?;
                Ok(BindingValue::TokenTree(self.input[self.pos - 1].clone()))
            }
            MetaVarKind::Expr
            | MetaVarKind::Type
            | MetaVarKind::Path
            | MetaVarKind::Pat
            | MetaVarKind::Stmt
            | MetaVarKind::Block
            | MetaVarKind::Item
            | MetaVarKind::Meta
            | MetaVarKind::Vis => Ok(BindingValue::TokenTrees(self.collect_fragment(kind)?)),
        }
    }

    /// Greedily matches `elements` as many times as `repetition` allows.
    ///
    /// Every metavariable inside `elements` is bound to a `Binding::Repeated`
    /// with one entry per successful iteration. When the repetition matched
    /// zero times, that entry list is empty.
    ///
    /// A separator is consumed only when the iteration that follows it also
    /// matches. An unmatched (for example trailing) separator is therefore
    /// left for the rest of the pattern. `ZeroOrOne` stops after one
    /// iteration.
    ///
    /// # Errors
    ///
    /// Fails when `OneOrMore` matched zero times.
    fn match_repetition(
        &mut self,
        elements: &[PatternElement],
        separator: Option<&TokenKind>,
        repetition: RepetitionKind,
        bindings: &mut Bindings,
    ) -> MacroResult<()> {
        // Seed every name so a zero-iteration match still binds it (to an
        // empty sequence) rather than leaving it undefined.
        let mut names = Vec::new();
        Self::collect_metavar_names(elements, &mut names);
        let mut all_bindings: HashMap<Arc<str>, Vec<Binding>> =
            names.into_iter().map(|name| (name, Vec::new())).collect();

        let at_most_once = matches!(repetition, RepetitionKind::ZeroOrOne);
        let mut count = 0usize;
        loop {
            if at_most_once && count == 1 {
                break;
            }
            let iteration_start = self.pos;
            if count > 0 {
                if let Some(sep) = separator {
                    if !self.check_token(sep) {
                        break;
                    }
                    self.pos += 1;
                }
            }
            let mut iteration = Bindings::new();
            let matched = self.can_match_elements(elements)
                && self.try_match_elements(elements, &mut iteration);
            // Roll back, including any separator just consumed, when the body
            // does not match. Also roll back when it matched without consuming
            // anything: repeating an empty match would never terminate.
            if !matched || self.pos == iteration_start {
                self.pos = iteration_start;
                break;
            }
            for (name, binding) in iteration.bindings {
                all_bindings.entry(name).or_default().push(binding);
            }
            count += 1;
        }

        match repetition {
            RepetitionKind::OneOrMore if count == 0 => {
                return Err(MacroError::UnexpectedToken {
                    expected: "at least one repetition".to_string(),
                    found: self.current_kind(),
                });
            }
            RepetitionKind::OneOrMore | RepetitionKind::ZeroOrMore | RepetitionKind::ZeroOrOne => {}
        }

        for (name, values) in all_bindings {
            bindings.insert(name, Binding::Repeated(values));
        }
        Ok(())
    }

    /// Appends the name of every metavariable in `elements`, descending into
    /// nested repetitions and delimited groups.
    fn collect_metavar_names(elements: &[PatternElement], names: &mut Vec<Arc<str>>) {
        for element in elements {
            match element {
                PatternElement::Token(_) => {}
                PatternElement::MetaVar { name, .. } => names.push(name.clone()),
                PatternElement::Repetition { elements, .. } => {
                    Self::collect_metavar_names(elements, names);
                }
                PatternElement::Delimited { elements, .. } => {
                    Self::collect_metavar_names(elements, names);
                }
            }
        }
    }

    /// Matches a delimited group opened by `delimiter`. The group's contents
    /// must be consumed entirely by `elements`.
    fn match_delimited(
        &mut self,
        delimiter: Delimiter,
        elements: &[PatternElement],
        bindings: &mut Bindings,
    ) -> MacroResult<()> {
        // Copy the `&'a` slice so the borrow of `tokens` is independent of `self`.
        let input = self.input;
        let tokens = match input.get(self.pos) {
            Some(TokenTree::Delimited {
                delimiter: found,
                tokens,
                ..
            }) if *found == delimiter => tokens,
            _ => {
                // `current_kind` reports `Eof` when the cursor is past the end.
                return Err(MacroError::UnexpectedToken {
                    expected: format!("{:?}", delimiter),
                    found: self.current_kind(),
                });
            }
        };
        self.pos += 1;

        let mut inner = PatternMatcher::new(tokens);
        for element in elements {
            inner.match_element(element, bindings)?;
        }
        // Leftover trees mean the pattern did not describe the whole group:
        // `( $x:ident )` must not match `(a b)`.
        inner.expect_end("end of delimited group")
    }

    /// Tries to match every element in order; returns `false` on the first
    /// failure. The caller is responsible for restoring `pos`.
    fn try_match_elements(&mut self, elements: &[PatternElement], bindings: &mut Bindings) -> bool {
        elements
            .iter()
            .all(|element| self.match_element(element, bindings).is_ok())
    }

    /// Cheap look-ahead: can the first element of `elements` start here?
    /// Always `false` at end of input.
    fn can_match_elements(&self, elements: &[PatternElement]) -> bool {
        match elements.first() {
            Some(first) => self.can_match_element(first),
            None => !self.at_end(),
        }
    }

    /// Cheap look-ahead for a single element without consuming input.
    fn can_match_element(&self, element: &PatternElement) -> bool {
        let current = match self.input.get(self.pos) {
            Some(tree) => tree,
            None => return false,
        };
        match element {
            PatternElement::Token(kind) => self.check_token(kind),
            PatternElement::MetaVar { kind, .. } => self.can_match_metavar(*kind),
            PatternElement::Repetition { .. } => true,
            PatternElement::Delimited { delimiter, .. } => matches!(
                current,
                TokenTree::Delimited { delimiter: d, .. } if d == delimiter
            ),
        }
    }

    /// Cheap look-ahead for a metavariable. Multi-token kinds optimistically
    /// report `true`.
    fn can_match_metavar(&self, kind: MetaVarKind) -> bool {
        if self.at_end() {
            return false;
        }
        match kind {
            MetaVarKind::TokenTree => true,
            MetaVarKind::Ident => Self::is_ident_kind(&self.current_kind()),
            MetaVarKind::Literal => matches!(self.current_kind(), TokenKind::Literal { .. }),
            MetaVarKind::Lifetime => matches!(self.current_kind(), TokenKind::Lifetime),
            MetaVarKind::Expr
            | MetaVarKind::Type
            | MetaVarKind::Path
            | MetaVarKind::Pat
            | MetaVarKind::Stmt
            | MetaVarKind::Block
            | MetaVarKind::Item
            | MetaVarKind::Meta
            | MetaVarKind::Vis => true,
        }
    }

    /// Collects trees for a multi-token fragment up to (but not including)
    /// its terminator.
    ///
    /// # Errors
    ///
    /// Fails if the fragment would be empty.
    fn collect_fragment(&mut self, kind: MetaVarKind) -> MacroResult<Vec<TokenTree>> {
        let start = self.pos;
        // `at_fragment_end` is true at end of input, so this terminates.
        while !self.at_fragment_end(kind) {
            self.pos += 1;
        }
        if self.pos == start {
            return Err(MacroError::UnexpectedToken {
                expected: format!("{:?}", kind),
                found: self.current_kind(),
            });
        }
        Ok(self.input[start..self.pos].to_vec())
    }

    /// Whether the current token terminates a fragment of the given `kind`.
    /// Nested groups are single trees, so only top-level tokens are tested.
    fn at_fragment_end(&self, kind: MetaVarKind) -> bool {
        if self.at_end() {
            return true;
        }
        let token = self.current_kind();
        match kind {
            MetaVarKind::Expr | MetaVarKind::Stmt => matches!(
                token,
                TokenKind::Semi | TokenKind::Comma | TokenKind::FatArrow | TokenKind::CloseDelim(_)
            ),
            // NOTE(review): `Gt`, `Comma` and `Eq` end the fragment even inside
            // generic arguments, so `Vec<u8>` or `HashMap<K, V>` is cut short.
            // Fixing this needs angle-bracket depth tracking. Confirm how the
            // lexer tokenizes `<`, `>` and `>>` before adding it.
            MetaVarKind::Type | MetaVarKind::Path => matches!(
                token,
                TokenKind::Comma
                    | TokenKind::Semi
                    | TokenKind::Eq
                    | TokenKind::Gt
                    | TokenKind::FatArrow
                    | TokenKind::CloseDelim(_)
            ),
            // `=>` and `,` must end a pattern, otherwise `$p:pat => $e:expr`
            // would swallow the arrow.
            MetaVarKind::Pat => matches!(
                token,
                TokenKind::Eq
                    | TokenKind::Or
                    | TokenKind::Comma
                    | TokenKind::FatArrow
                    | TokenKind::CloseDelim(_)
                    | TokenKind::Keyword(Keyword::If)
            ),
            _ => matches!(
                token,
                TokenKind::Semi | TokenKind::Comma | TokenKind::CloseDelim(_)
            ),
        }
    }

    /// Fails with `expected` unless every input tree has been consumed.
    fn expect_end(&self, expected: &str) -> MacroResult<()> {
        if self.at_end() {
            Ok(())
        } else {
            Err(MacroError::UnexpectedToken {
                expected: expected.to_string(),
                found: self.current_kind(),
            })
        }
    }

    /// Consumes the current tree if it is a token of exactly `expected` kind.
    fn expect_token(&mut self, expected: &TokenKind) -> MacroResult<()> {
        if self.check_token(expected) {
            self.pos += 1;
            Ok(())
        } else {
            Err(MacroError::UnexpectedToken {
                expected: format!("{:?}", expected),
                found: self.current_kind(),
            })
        }
    }

    /// Consumes the current tree if its kind satisfies `accepts`. Otherwise
    /// reports `expected`.
    fn expect_kind(
        &mut self,
        expected: &str,
        accepts: impl FnOnce(&TokenKind) -> bool,
    ) -> MacroResult<()> {
        let found = self.current_kind();
        if accepts(&found) {
            self.pos += 1;
            Ok(())
        } else {
            Err(MacroError::UnexpectedToken {
                expected: expected.to_string(),
                found,
            })
        }
    }

    /// Consumes an identifier, raw identifier or keyword. This is the same
    /// set `can_match_metavar` accepts for `Ident`.
    fn expect_ident(&mut self) -> MacroResult<()> {
        self.expect_kind("identifier", Self::is_ident_kind)
    }

    /// Consumes a literal token.
    fn expect_literal(&mut self) -> MacroResult<()> {
        self.expect_kind("literal", |kind| matches!(kind, TokenKind::Literal { .. }))
    }

    /// Consumes a lifetime token.
    fn expect_lifetime(&mut self) -> MacroResult<()> {
        self.expect_kind("lifetime", |kind| matches!(kind, TokenKind::Lifetime))
    }

    /// Token kinds accepted by an `Ident` metavariable. Keywords are included,
    /// as Rust's `ident` fragment accepts them.
    fn is_ident_kind(kind: &TokenKind) -> bool {
        matches!(
            kind,
            TokenKind::Ident | TokenKind::RawIdent | TokenKind::Keyword(_)
        )
    }

    /// Whether the current tree is a single token whose kind equals `expected`.
    // NOTE(review): comparison is by `TokenKind` only, and `TokenKind::Ident`
    // carries no text. A literal identifier in a pattern therefore matches any
    // identifier. Confirm whether the identifier text should also be compared.
    fn check_token(&self, expected: &TokenKind) -> bool {
        matches!(
            self.input.get(self.pos),
            Some(TokenTree::Token(token)) if token.kind == *expected
        )
    }

    /// Whether every input tree has been consumed.
    fn at_end(&self) -> bool {
        self.pos >= self.input.len()
    }

    /// Kind of the current tree: `Eof` past the end, `OpenDelim` for a
    /// delimited group, or the token's own kind.
    fn current_kind(&self) -> TokenKind {
        match self.input.get(self.pos) {
            None => TokenKind::Eof,
            Some(TokenTree::Token(token)) => token.kind.clone(),
            Some(TokenTree::Delimited { delimiter, .. }) => TokenKind::OpenDelim(*delimiter),
        }
    }
}
/// Matches `input` against `pattern` starting from the first tree.
///
/// This is a convenience wrapper that builds a fresh [`PatternMatcher`] and
/// calls [`PatternMatcher::match_pattern`] on it.
pub fn match_macro_pattern(pattern: &MacroPattern, input: &[TokenTree]) -> MacroResult<Bindings> {
    PatternMatcher::new(input).match_pattern(pattern)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::{Lexer, SourceFile, Token};
    use crate::macro_expand::tokens_to_tree;

    /// Lexes `source` as an anonymous file, panicking on lexer errors.
    fn lex(source: &str) -> Vec<Token> {
        let file = SourceFile::anonymous(source);
        let mut lexer = Lexer::new(&file);
        lexer.tokenize().unwrap()
    }

    /// Lexes `source`, builds token trees, and matches them against a
    /// pattern made of `elements`.
    fn run(source: &str, elements: Vec<PatternElement>) -> MacroResult<Bindings> {
        let tokens = lex(source);
        let trees = tokens_to_tree(&tokens);
        match_macro_pattern(&MacroPattern { elements }, &trees)
    }

    /// Shorthand for a `$name:kind` pattern element.
    fn metavar(name: &str, kind: MetaVarKind) -> PatternElement {
        PatternElement::MetaVar {
            name: name.into(),
            kind,
        }
    }

    #[test]
    fn test_match_literal_token() {
        let bindings = run("foo", vec![PatternElement::Token(TokenKind::Ident)]).unwrap();
        assert!(bindings.bindings.is_empty());
    }

    #[test]
    fn test_match_metavar_ident() {
        let bindings = run("foo", vec![metavar("x", MetaVarKind::Ident)]).unwrap();
        assert!(bindings.get("x").is_some());
    }

    #[test]
    fn test_match_metavar_tt() {
        let bindings = run("(1 + 2)", vec![metavar("e", MetaVarKind::TokenTree)]).unwrap();
        assert!(bindings.get("e").is_some());
    }
}