#include "tree/pattern/ParseTreePattern.h"
#include "tree/pattern/ParseTreeMatch.h"
#include "tree/TerminalNode.h"
#include "CommonTokenStream.h"
#include "ParserInterpreter.h"
#include "tree/pattern/TokenTagToken.h"
#include "ParserRuleContext.h"
#include "tree/pattern/RuleTagToken.h"
#include "tree/pattern/TagChunk.h"
#include "atn/ATN.h"
#include "Lexer.h"
#include "BailErrorStrategy.h"
#include "ListTokenSource.h"
#include "tree/pattern/TextChunk.h"
#include "ANTLRInputStream.h"
#include "support/Arrays.h"
#include "Exceptions.h"
#include "support/CPPUtils.h"
#include "tree/pattern/ParseTreePatternMatcher.h"
using namespace antlr4;
using namespace antlr4::tree;
using namespace antlr4::tree::pattern;
using namespace antlrcpp;
/// Builds the exception from an existing RuntimeException; only the source's
/// what() text is forwarded to the RuntimeException base.
ParseTreePatternMatcher::CannotInvokeStartRule::CannotInvokeStartRule(const RuntimeException &e)
  : RuntimeException(e.what())
{
}
/// Out-of-line destructor; it performs no work of its own.
ParseTreePatternMatcher::CannotInvokeStartRule::~CannotInvokeStartRule() = default;
/// Out-of-line destructor; it performs no work of its own.
ParseTreePatternMatcher::StartRuleDoesNotConsumeFullPattern::~StartRuleDoesNotConsumeFullPattern() = default;
/// Creates a matcher bound to the given lexer and parser.
///
/// Both pointers are stored as passed; the delimiter and escape strings are
/// then set to their defaults by InitializeInstanceFields().
ParseTreePatternMatcher::ParseTreePatternMatcher(Lexer *lexer, Parser *parser)
  : _lexer(lexer),
    _parser(parser)
{
  InitializeInstanceFields();
}
/// Out-of-line destructor; it performs no work of its own.
ParseTreePatternMatcher::~ParseTreePatternMatcher() = default;
/// Replaces the tag delimiters and the escape string used when splitting patterns.
///
/// @param start      Opening tag delimiter; must not be empty.
/// @param stop       Closing tag delimiter; must not be empty.
/// @param escapeLeft Escape string; stored without any validation.
/// @throws IllegalArgumentException if `start` or `stop` is empty. In that case
///         none of the three settings is modified.
void ParseTreePatternMatcher::setDelimiters(const std::string &start, const std::string &stop, const std::string &escapeLeft) {
  const bool haveStart = !start.empty();
  const bool haveStop = !stop.empty();

  if (haveStart && haveStop) {
    _start = start;
    _stop = stop;
    _escape = escapeLeft;
    return;
  }

  // `start` is reported first, so it wins when both delimiters are empty.
  if (!haveStart) {
    throw IllegalArgumentException("start cannot be null or empty");
  }
  throw IllegalArgumentException("stop cannot be null or empty");
}
/// Convenience overload: compiles `pattern` for rule `patternRuleIndex` and
/// reports whether `tree` matches the compiled pattern.
///
/// @return true when the compiled pattern matches `tree`.
bool ParseTreePatternMatcher::matches(ParseTree *tree, const std::string &pattern, int patternRuleIndex) {
  const ParseTreePattern compiled = compile(pattern, patternRuleIndex);
  const bool isMatch = matches(tree, compiled);
  return isMatch;
}
/// Reports whether `tree` matches an already compiled pattern.
///
/// The labels gathered during matching are collected into a scratch map and
/// dropped; only the yes/no outcome is returned.
///
/// @return true when matchImpl() finds no mismatched node.
bool ParseTreePatternMatcher::matches(ParseTree *tree, const ParseTreePattern &pattern) {
  std::map<std::string, std::vector<ParseTree *>> discardedLabels;
  return matchImpl(tree, pattern.getPatternTree(), discardedLabels) == nullptr;
}
/// Convenience overload: compiles `pattern` for rule `patternRuleIndex` and
/// matches `tree` against the compiled pattern.
///
/// @return the ParseTreeMatch produced by the compiled-pattern overload.
ParseTreeMatch ParseTreePatternMatcher::match(ParseTree *tree, const std::string &pattern, int patternRuleIndex) {
  const ParseTreePattern compiled = compile(pattern, patternRuleIndex);
  return match(tree, compiled);
}
/// Matches `tree` against an already compiled pattern.
///
/// @return a ParseTreeMatch built from the tree, the pattern, the labels
///         gathered by matchImpl() and the first mismatched node (nullptr
///         when matchImpl() reported none).
ParseTreeMatch ParseTreePatternMatcher::match(ParseTree *tree, const ParseTreePattern &pattern) {
  std::map<std::string, std::vector<ParseTree *>> collected;
  ParseTree *firstMismatch = matchImpl(tree, pattern.getPatternTree(), collected);
  return ParseTreeMatch(tree, pattern, collected, firstMismatch);
}
/// Compiles a textual tree pattern into a ParseTreePattern.
///
/// The pattern is tokenized and parsed, starting at rule `patternRuleIndex`,
/// by a ParserInterpreter built from this matcher's parser (grammar file name,
/// vocabulary, rule names and bypass-alternative ATN) with a BailErrorStrategy
/// installed.
///
/// @param pattern          Pattern text, possibly containing tags.
/// @param patternRuleIndex Rule index at which parsing of the pattern starts.
/// @return a ParseTreePattern wrapping the parsed pattern tree.
/// @throws StartRuleDoesNotConsumeFullPattern if tokens remain after the parse.
/// @throws RecognitionException propagated unchanged from the parse.
/// @throws the exception nested inside a ParseCancellationException, or the
///         ParseCancellationException itself when nothing is nested.
/// @throws RuntimeException ("Cannot invoke start rule") wrapping any other
///         std::exception (on toolchains where nested exceptions are used).
ParseTreePattern ParseTreePatternMatcher::compile(const std::string &pattern, int patternRuleIndex) {
  ListTokenSource tokenSrc(tokenize(pattern));
  CommonTokenStream tokens(&tokenSrc);
  ParserInterpreter parserInterp(_parser->getGrammarFileName(), _parser->getVocabulary(),
                                 _parser->getRuleNames(), _parser->getATNWithBypassAlts(), &tokens);

  // NOTE(review): `tree` is produced by the function-local `parserInterp`;
  // confirm that the returned tree stays valid after the interpreter is destroyed.
  ParserRuleContext *tree = nullptr;
  try {
    parserInterp.setErrorHandler(std::make_shared<BailErrorStrategy>());
    tree = parserInterp.parse(patternRuleIndex);
  } catch (ParseCancellationException &e) {
#if defined(_MSC_FULL_VER) && _MSC_FULL_VER < 190023026
    // Older MSVC (per the version check): no unwrapping, propagate as caught.
    (void)e;
    throw;
#else
    // Surface the underlying cause when one is nested ...
    std::rethrow_if_nested(e);
    // ... otherwise propagate the cancellation itself instead of silently
    // continuing with a null tree.
    throw;
#endif
  } catch (RecognitionException &) {
    // `throw;` keeps the dynamic type; `throw re;` would copy and slice it.
    throw;
#if defined(_MSC_FULL_VER) && _MSC_FULL_VER < 190023026
  } catch (std::exception &) {
    throw;
#else
  } catch (std::exception &) {
    std::throw_with_nested(RuntimeException("Cannot invoke start rule"));
#endif
  }

  // The start rule must have consumed the whole pattern.
  if (tokens.LA(1) != Token::EOF) {
    throw StartRuleDoesNotConsumeFullPattern();
  }

  return ParseTreePattern(this, pattern, patternRuleIndex, tree);
}
/// @return the lexer pointer this matcher was constructed with.
Lexer* ParseTreePatternMatcher::getLexer()
{
  Lexer *const lexer = _lexer;
  return lexer;
}
/// @return the parser pointer this matcher was constructed with.
Parser* ParseTreePatternMatcher::getParser()
{
  Parser *const parser = _parser;
  return parser;
}
/// Recursively compares `tree` with `patternTree`, recording tagged nodes.
///
/// - Two terminal nodes match when their token types are equal and either the
///   pattern token is a TokenTagToken (the tree node is then recorded under
///   the token name and, if present, the label) or their texts are equal.
/// - Two rule contexts match when the pattern node is a rule tag with the same
///   rule index (the tree node is recorded under the rule name and, if
///   present, the label), or when they have the same number of children and
///   all children match pairwise.
/// - Any other combination is a mismatch.
///
/// @param tree        Node of the tree being tested; must not be null.
/// @param patternTree Corresponding node of the pattern tree; must not be null.
/// @param labels      Receives the nodes matched by tags, keyed by tag name/label.
/// @return nullptr when the subtrees match, otherwise the first mismatched
///         node of `tree`.
/// @throws IllegalArgumentException if either tree argument is null.
ParseTree* ParseTreePatternMatcher::matchImpl(ParseTree *tree, ParseTree *patternTree,
                                              std::map<std::string, std::vector<ParseTree *>> &labels) {
  if (tree == nullptr) {
    throw IllegalArgumentException("tree cannot be null");
  }
  if (patternTree == nullptr) {
    throw IllegalArgumentException("patternTree cannot be null");
  }

  // x and <ID>, x and y, or x and x; or could be mismatched types.
  TerminalNode *t1 = dynamic_cast<TerminalNode *>(tree);
  TerminalNode *t2 = dynamic_cast<TerminalNode *>(patternTree);
  if (t1 != nullptr && t2 != nullptr) {
    if (t1->getSymbol()->getType() != t2->getSymbol()->getType()) {
      return t1;
    }

    // Same token type: a token tag in the pattern matches any text.
    TokenTagToken *tokenTagToken = dynamic_cast<TokenTagToken *>(t2->getSymbol());
    if (tokenTagToken != nullptr) {
      labels[tokenTagToken->getTokenName()].push_back(tree);
      if (!tokenTagToken->getLabel().empty()) {
        labels[tokenTagToken->getLabel()].push_back(tree);
      }
      return nullptr;
    }

    // Plain tokens must also agree on their text.
    if (t1->getText() == t2->getText()) {
      return nullptr;
    }
    return t1;
  }

  ParserRuleContext *r1 = dynamic_cast<ParserRuleContext *>(tree);
  ParserRuleContext *r2 = dynamic_cast<ParserRuleContext *>(patternTree);
  if (r1 != nullptr && r2 != nullptr) {
    // (expr ...) and <expr>: a rule tag matches a whole subtree of that rule.
    RuleTagToken *ruleTagToken = getRuleTagToken(r2);
    if (ruleTagToken != nullptr) {
      if (r1->getRuleIndex() != r2->getRuleIndex()) {
        return r1;
      }
      labels[ruleTagToken->getRuleName()].push_back(tree);
      if (!ruleTagToken->getLabel().empty()) {
        labels[ruleTagToken->getLabel()].push_back(tree);
      }
      return nullptr;
    }

    // (expr ...) and (expr ...): compare the children one by one.
    const size_t childCount = r1->children.size();
    if (childCount != r2->children.size()) {
      return r1;
    }
    for (size_t i = 0; i < childCount; i++) {
      ParseTree *childMatch = matchImpl(r1->children[i], r2->children[i], labels);
      if (childMatch != nullptr) {
        return childMatch;
      }
    }
    return nullptr;
  }

  // Node kinds differ (or are neither terminal nor rule context): mismatch.
  return tree;
}
/// Returns the rule tag carried by `t`, if `t` has that shape.
///
/// @param t Node to inspect.
/// @return the RuleTagToken when `t` has exactly one child, that child is a
///         TerminalNode and its symbol is a RuleTagToken; nullptr otherwise.
RuleTagToken* ParseTreePatternMatcher::getRuleTagToken(ParseTree *t) {
  if (t->children.size() != 1) {
    return nullptr;
  }

  TerminalNode *onlyChild = dynamic_cast<TerminalNode *>(t->children[0]);
  if (onlyChild == nullptr) {
    return nullptr;
  }

  // The cast yields nullptr when the symbol is some other kind of token.
  return dynamic_cast<RuleTagToken *>(onlyChild->getSymbol());
}
std::vector<std::unique_ptr<Token>> ParseTreePatternMatcher::tokenize(const std::string &pattern) {
std::vector<Chunk> chunks = split(pattern);
std::vector<std::unique_ptr<Token>> tokens;
for (auto chunk : chunks) {
if (is<TagChunk *>(&chunk)) {
TagChunk &tagChunk = (TagChunk&)chunk;
if (isupper(tagChunk.getTag()[0])) {
size_t ttype = _parser->getTokenType(tagChunk.getTag());
if (ttype == Token::INVALID_TYPE) {
throw IllegalArgumentException("Unknown token " + tagChunk.getTag() + " in pattern: " + pattern);
}
tokens.emplace_back(new TokenTagToken(tagChunk.getTag(), (int)ttype, tagChunk.getLabel()));
} else if (islower(tagChunk.getTag()[0])) {
size_t ruleIndex = _parser->getRuleIndex(tagChunk.getTag());
if (ruleIndex == INVALID_INDEX) {
throw IllegalArgumentException("Unknown rule " + tagChunk.getTag() + " in pattern: " + pattern);
}
size_t ruleImaginaryTokenType = _parser->getATNWithBypassAlts().ruleToTokenType[ruleIndex];
tokens.emplace_back(new RuleTagToken(tagChunk.getTag(), ruleImaginaryTokenType, tagChunk.getLabel()));
} else {
throw IllegalArgumentException("invalid tag: " + tagChunk.getTag() + " in pattern: " + pattern);
}
} else {
TextChunk &textChunk = (TextChunk&)chunk;
ANTLRInputStream input(textChunk.getText());
_lexer->setInputStream(&input);
std::unique_ptr<Token> t(_lexer->nextToken());
while (t->getType() != Token::EOF) {
tokens.push_back(std::move(t));
t = _lexer->nextToken();
}
_lexer->setInputStream(nullptr);
}
}
return tokens;
}
std::vector<Chunk> ParseTreePatternMatcher::split(const std::string &pattern) {
size_t p = 0;
size_t n = pattern.length();
std::vector<Chunk> chunks;
std::vector<size_t> starts;
std::vector<size_t> stops;
while (p < n) {
if (p == pattern.find(_escape + _start,p)) {
p += _escape.length() + _start.length();
} else if (p == pattern.find(_escape + _stop,p)) {
p += _escape.length() + _stop.length();
} else if (p == pattern.find(_start,p)) {
starts.push_back(p);
p += _start.length();
} else if (p == pattern.find(_stop,p)) {
stops.push_back(p);
p += _stop.length();
} else {
p++;
}
}
if (starts.size() > stops.size()) {
throw IllegalArgumentException("unterminated tag in pattern: " + pattern);
}
if (starts.size() < stops.size()) {
throw IllegalArgumentException("missing start tag in pattern: " + pattern);
}
size_t ntags = starts.size();
for (size_t i = 0; i < ntags; i++) {
if (starts[i] >= stops[i]) {
throw IllegalArgumentException("tag delimiters out of order in pattern: " + pattern);
}
}
if (ntags == 0) {
std::string text = pattern.substr(0, n);
chunks.push_back(TextChunk(text));
}
if (ntags > 0 && starts[0] > 0) { std::string text = pattern.substr(0, starts[0]);
chunks.push_back(TextChunk(text));
}
for (size_t i = 0; i < ntags; i++) {
std::string tag = pattern.substr(starts[i] + _start.length(), stops[i] - (starts[i] + _start.length()));
std::string ruleOrToken = tag;
std::string label = "";
size_t colon = tag.find(':');
if (colon != std::string::npos) {
label = tag.substr(0,colon);
ruleOrToken = tag.substr(colon + 1, tag.length() - (colon + 1));
}
chunks.push_back(TagChunk(label, ruleOrToken));
if (i + 1 < ntags) {
std::string text = pattern.substr(stops[i] + _stop.length(), starts[i + 1] - (stops[i] + _stop.length()));
chunks.push_back(TextChunk(text));
}
}
if (ntags > 0) {
size_t afterLastTag = stops[ntags - 1] + _stop.length();
if (afterLastTag < n) { std::string text = pattern.substr(afterLastTag, n - afterLastTag);
chunks.push_back(TextChunk(text));
}
}
for (size_t i = 0; i < chunks.size(); i++) {
Chunk &c = chunks[i];
if (is<TextChunk *>(&c)) {
TextChunk &tc = (TextChunk&)c;
std::string unescaped = tc.getText();
unescaped.erase(std::remove(unescaped.begin(), unescaped.end(), '\\'), unescaped.end());
if (unescaped.length() < tc.getText().length()) {
chunks[i] = TextChunk(unescaped);
}
}
}
return chunks;
}
/// Installs the default pattern syntax: tags are written as `<...>` and a
/// single backslash is the escape string.
void ParseTreePatternMatcher::InitializeInstanceFields()
{
  _escape = "\\";
  _stop = ">";
  _start = "<";
}