use shiftkit::{extract, grammar::*, lr::InternalToken, parser::generator::generate_parser, rule};
use std::io::{self, Write};
use std::iter::Peekable;
use std::str::Chars;
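// A lexing failure: the unexpected character and its byte offset in the input.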
#[derive(Debug, Clone, PartialEq, Eq)]
enum LexError {
UnexpectedChar(char, usize),
}
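// A minimal hand-rolled lexer: a peekable char iterator plus the current byte
// offset, kept for error reporting.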
struct Lexer<'a> {
chars: Peekable<Chars<'a>>,
position: usize,
}
impl<'a> Lexer<'a> {
fn new(input: &'a str) -> Self {
Self {
chars: input.chars().peekable(),
position: 0,
}
}
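    // Consumes one character and advances the byte offset by its UTF-8 length.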
fn advance(&mut self) -> Option<char> {
let c = self.chars.next()?;
self.position += c.len_utf8();
Some(c)
}
fn peek(&mut self) -> Option<&char> {
self.chars.peek()
}
fn skip_whitespace(&mut self) {
while let Some(&c) = self.peek() {
if c.is_whitespace() {
self.advance();
} else {
break;
}
}
}
fn read_number(&mut self, first: char) -> String {
let mut num = String::new();
num.push(first);
while let Some(&c) = self.peek() {
if c.is_ascii_digit() {
num.push(c);
self.advance();
} else {
break;
}
}
num
}
fn read_identifier(&mut self, first: char) -> String {
let mut ident = String::new();
ident.push(first);
while let Some(&c) = self.peek() {
if c.is_alphanumeric() || c == '_' {
ident.push(c);
self.advance();
} else {
break;
}
}
ident
}
const fn position(&self) -> usize {
self.position
}
}
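// Yields the next token, or `Ok(None)` once the input is exhausted.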
trait Tokenize: Sized {
    fn tokenize(lexer: &mut Lexer<'_>) -> Result<Option<Self>, LexError>;
}
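// Runs the lexer to completion, collecting every token.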
fn tokenize_all<T: Tokenize>(input: &str) -> Result<Vec<T>, LexError> {
let mut lexer = Lexer::new(input);
let mut tokens = Vec::new();
while let Some(token) = T::tokenize(&mut lexer)? {
tokens.push(token);
}
Ok(tokens)
}
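// Keywords are recognized by the lexer, but no grammar rule below consumes
// them, so feeding one to the parser exercises the error path.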
#[derive(Debug, Clone, PartialEq, Eq)]
enum Keyword {
Let,
Fn,
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum Punctuation {
    Assign,
    Plus,
    Minus,
    Star,
    Slash,
    LParen,
    RParen,
    Semicolon,
}
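// Terminal token types used by the grammar, one per distinguishable token.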
const TOKEN_NUMBER: TokenType = TokenType(0);
const TOKEN_IDENTIFIER: TokenType = TokenType(1);
const TOKEN_KEYWORD_LET: TokenType = TokenType(2);
const TOKEN_KEYWORD_FN: TokenType = TokenType(3);
const TOKEN_PUNCTUATION_ASSIGN: TokenType = TokenType(4);
const TOKEN_PUNCTUATION_PLUS: TokenType = TokenType(5);
const TOKEN_PUNCTUATION_MINUS: TokenType = TokenType(6);
const TOKEN_PUNCTUATION_STAR: TokenType = TokenType(7);
const TOKEN_PUNCTUATION_SLASH: TokenType = TokenType(8);
const TOKEN_PUNCTUATION_LPAREN: TokenType = TokenType(9);
const TOKEN_PUNCTUATION_RPAREN: TokenType = TokenType(10);
const TOKEN_PUNCTUATION_SEMICOLON: TokenType = TokenType(11);
#[derive(Debug, Clone, PartialEq, Eq)]
enum Token {
Number(i64),
Identifier(String),
Keyword(Keyword),
Punctuation(Punctuation),
}
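// Maps each concrete token to its terminal type so the parser can dispatch on it.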
impl HasTokenType for Token {
fn token_type(&self) -> TokenType {
match self {
Token::Number(_) => TOKEN_NUMBER,
Token::Identifier(_) => TOKEN_IDENTIFIER,
Token::Keyword(Keyword::Let) => TOKEN_KEYWORD_LET,
Token::Keyword(Keyword::Fn) => TOKEN_KEYWORD_FN,
Token::Punctuation(Punctuation::Assign) => TOKEN_PUNCTUATION_ASSIGN,
Token::Punctuation(Punctuation::Plus) => TOKEN_PUNCTUATION_PLUS,
Token::Punctuation(Punctuation::Minus) => TOKEN_PUNCTUATION_MINUS,
Token::Punctuation(Punctuation::Star) => TOKEN_PUNCTUATION_STAR,
Token::Punctuation(Punctuation::Slash) => TOKEN_PUNCTUATION_SLASH,
Token::Punctuation(Punctuation::LParen) => TOKEN_PUNCTUATION_LPAREN,
Token::Punctuation(Punctuation::RParen) => TOKEN_PUNCTUATION_RPAREN,
Token::Punctuation(Punctuation::Semicolon) => TOKEN_PUNCTUATION_SEMICOLON,
}
}
}
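// The tokenizer proper: skip whitespace, then dispatch on the first character.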
impl Tokenize for Token {
    fn tokenize(lexer: &mut Lexer<'_>) -> Result<Option<Self>, LexError> {
lexer.skip_whitespace();
let Some(&c) = lexer.peek() else {
return Ok(None);
};
let token = match c {
'0'..='9' => {
let first = lexer.advance().unwrap();
let num_str = lexer.read_number(first);
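                // Digits were validated above, but a literal larger than
                // i64::MAX would still panic here.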
let num = num_str.parse::<i64>().expect("valid number");
Token::Number(num)
}
'a'..='z' | 'A'..='Z' | '_' => {
let first = lexer.advance().unwrap();
let ident = lexer.read_identifier(first);
match ident.as_str() {
"let" => Token::Keyword(Keyword::Let),
"fn" => Token::Keyword(Keyword::Fn),
_ => Token::Identifier(ident),
}
}
'=' => {
lexer.advance();
Token::Punctuation(Punctuation::Assign)
}
'+' => {
lexer.advance();
Token::Punctuation(Punctuation::Plus)
}
'-' => {
lexer.advance();
Token::Punctuation(Punctuation::Minus)
}
'*' => {
lexer.advance();
Token::Punctuation(Punctuation::Star)
}
'/' => {
lexer.advance();
Token::Punctuation(Punctuation::Slash)
}
'(' => {
lexer.advance();
Token::Punctuation(Punctuation::LParen)
}
')' => {
lexer.advance();
Token::Punctuation(Punctuation::RParen)
}
';' => {
lexer.advance();
Token::Punctuation(Punctuation::Semicolon)
}
_ => return Err(LexError::UnexpectedChar(c, lexer.position())),
};
Ok(Some(token))
}
}
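// The AST. Semantic actions build nodes into a flat arena; an AstNodeId is an
// index into that arena.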
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BinOp {
Add,
Sub,
Mul,
Div,
}
type ExprId = AstNodeId;
#[derive(Debug, Clone, PartialEq, Eq)]
enum Expr {
Number(i64),
Identifier(String),
    BinOp(ExprId, BinOp, ExprId),
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum Statement {
Expr(ExprId),
Assign(String, ExprId),
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum AstNode {
Expr(Expr),
Statement(Statement),
Statements(Statement, Option<AstNodeId>),
}
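// Nonterminals. Expressions are stratified into sum / product / atom levels to
// encode operator precedence directly in the grammar.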
const NODE_STATEMENTS: AstNodeType = AstNodeType(0);
const NODE_STATEMENT: AstNodeType = AstNodeType(1);
const NODE_EXPR: AstNodeType = AstNodeType(2);
const NODE_EXPR_ATOM: AstNodeType = AstNodeType(3);
const NODE_EXPR_PRODUCT: AstNodeType = AstNodeType(4);
const NODE_EXPR_SUM: AstNodeType = AstNodeType(5);
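// Pretty-printers for tokens, rules, and ASTs, used by the conflict report and
// the REPL below.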
fn display_token_type(token_type: TokenType) -> String {
match token_type {
TOKEN_NUMBER => "NUMBER".into(),
TOKEN_IDENTIFIER => "IDENTIFIER".into(),
TOKEN_KEYWORD_LET => "KEYWORD_LET".into(),
TOKEN_KEYWORD_FN => "KEYWORD_FN".into(),
TOKEN_PUNCTUATION_ASSIGN => "PUNCTUATION_ASSIGN".into(),
TOKEN_PUNCTUATION_PLUS => "PUNCTUATION_PLUS".into(),
TOKEN_PUNCTUATION_MINUS => "PUNCTUATION_MINUS".into(),
TOKEN_PUNCTUATION_STAR => "PUNCTUATION_STAR".into(),
TOKEN_PUNCTUATION_SLASH => "PUNCTUATION_SLASH".into(),
TOKEN_PUNCTUATION_LPAREN => "PUNCTUATION_LPAREN".into(),
TOKEN_PUNCTUATION_RPAREN => "PUNCTUATION_RPAREN".into(),
TOKEN_PUNCTUATION_SEMICOLON => "PUNCTUATION_SEMICOLON".into(),
TokenType(t) => format!("TOKEN({t})"),
}
}
fn display_internal_token(token: InternalToken) -> String {
match token {
InternalToken::User(token_type) => display_token_type(token_type),
InternalToken::Eof => "EOF".into(),
InternalToken::Special => "SPECIAL".into(),
}
}
fn display_node_type(node_type: AstNodeType) -> String {
match node_type {
NODE_STATEMENTS => String::from("statements"),
NODE_STATEMENT => String::from("statement"),
NODE_EXPR => String::from("expr"),
NODE_EXPR_ATOM => String::from("expr_atom"),
NODE_EXPR_PRODUCT => String::from("expr_product"),
NODE_EXPR_SUM => String::from("expr_sum"),
AstNodeType(n) => format!("node({n})"),
}
}
fn display_grammar_item(grammar_item: GrammarItem) -> String {
match grammar_item {
GrammarItem::Token(token_type) => display_token_type(token_type),
GrammarItem::AstNode(ast_node_type) => display_node_type(ast_node_type),
}
}
fn display_rule<T: HasTokenType, A>(grammar_rule: &GrammarRule<T, A>) -> String {
format!(
"{} => {}",
display_node_type(grammar_rule.result),
grammar_rule
.components
.iter()
.copied()
.map(display_grammar_item)
.collect::<Vec<_>>()
.join(" ")
)
}
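// Renders an LR(0) item: the rule with a "." marking how far the parser has
// progressed through it.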
fn display_lr0_item<T: HasTokenType, A>(grammar_rule: &GrammarRule<T, A>, index: usize) -> String {
format!(
"{} => {}",
display_node_type(grammar_rule.result),
grammar_rule
.components
.iter()
.copied()
.take(index)
.map(display_grammar_item)
.chain(std::iter::once(String::from(".")))
.chain(
grammar_rule
.components
.iter()
.copied()
.skip(index)
.map(display_grammar_item)
)
.collect::<Vec<_>>()
.join(" ")
)
}
fn display_bin_op(bin_op: &BinOp) -> String {
match bin_op {
BinOp::Add => "+".into(),
BinOp::Sub => "-".into(),
BinOp::Mul => "*".into(),
BinOp::Div => "/".into(),
}
}
fn display_expr(ast_nodes: &[AstNode], expr: &Expr) -> String {
match expr {
Expr::Number(n) => n.to_string(),
Expr::Identifier(i) => i.to_string(),
Expr::BinOp(lhs, bin_op, rhs) => format!(
"({} {} {})",
display_ast(ast_nodes, *lhs),
display_bin_op(bin_op),
display_ast(ast_nodes, *rhs)
),
}
}
fn display_statement(ast_nodes: &[AstNode], statement: &Statement) -> String {
match statement {
Statement::Expr(ast_node_id) => display_ast(ast_nodes, *ast_node_id),
Statement::Assign(ident, ast_node_id) => {
format!("{ident} = {}", display_ast(ast_nodes, *ast_node_id))
}
}
}
fn display_statements(
ast_nodes: &[AstNode],
statement: &Statement,
statements: &Option<AstNodeId>,
) -> String {
if let Some(statements) = statements {
let (s, ss) = extract!(&ast_nodes[statements.0], AstNode::Statements(s, ss) => (s, ss));
format!(
"{}; {}",
display_statements(ast_nodes, s, ss),
display_statement(ast_nodes, statement)
)
} else {
display_statement(ast_nodes, statement)
}
}
fn display_ast(ast_nodes: &[AstNode], node_id: AstNodeId) -> String {
match &ast_nodes[node_id.0] {
AstNode::Expr(expr) => display_expr(ast_nodes, expr),
AstNode::Statement(statement) => display_statement(ast_nodes, statement),
AstNode::Statements(statement, statements) => {
display_statements(ast_nodes, statement, statements)
}
}
}
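// Builds the grammar, generates the LR parser (reporting any conflicts), then
// runs a small read-parse-print loop.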
fn main() {
let grammar = {
let mut grammar: Grammar<Token, AstNode> = Grammar::new(NODE_STATEMENTS);
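        // A program is one or more statements separated by semicolons, with an
        // optional trailing semicolon.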
grammar.add_rule(
NODE_STATEMENTS,
&[NODE_STATEMENT.into()],
rule!(|list| {
let statement =
extract!(list.ast_node(0), AstNode::Statement(statement) => statement);
AstNode::Statements(statement.clone(), None)
}),
);
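        // A trailing semicolon reduces to the same statements node.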
grammar.add_rule(
NODE_STATEMENTS,
&[NODE_STATEMENTS.into(), TOKEN_PUNCTUATION_SEMICOLON.into()],
rule!(|list| list.node(0)),
);
grammar.add_rule(
NODE_STATEMENTS,
&[
NODE_STATEMENTS.into(),
TOKEN_PUNCTUATION_SEMICOLON.into(),
NODE_STATEMENT.into(),
],
rule!(|list| {
let statement =
extract!(list.ast_node(2), AstNode::Statement(statement) => statement);
AstNode::Statements(statement.clone(), Some(list.node(0)))
}),
);
grammar.add_rule(
NODE_STATEMENT,
&[NODE_EXPR.into()],
rule!(|list| AstNode::Statement(Statement::Expr(list.node(0)))),
);
grammar.add_rule(
NODE_STATEMENT,
&[
TOKEN_IDENTIFIER.into(),
TOKEN_PUNCTUATION_ASSIGN.into(),
NODE_EXPR.into(),
],
rule!(|list| {
let name = extract!(list.token(0), Token::Identifier(s) => s.clone());
AstNode::Statement(Statement::Assign(name, list.node(2)))
}),
);
grammar.add_rule(
NODE_EXPR,
&[NODE_EXPR_SUM.into()],
rule!(|list| list.node(0)),
);
grammar.add_rule(
NODE_EXPR_SUM,
&[NODE_EXPR_PRODUCT.into()],
rule!(|list| list.node(0)),
);
grammar.add_rule(
NODE_EXPR_PRODUCT,
&[NODE_EXPR_ATOM.into()],
rule!(|list| list.node(0)),
);
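        // Atoms: identifiers, number literals, and parenthesized expressions.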
grammar.add_rule(
NODE_EXPR_ATOM,
&[TOKEN_IDENTIFIER.into()],
rule!(|list| {
let name = extract!(list.token(0), Token::Identifier(s) => s.clone());
AstNode::Expr(Expr::Identifier(name))
}),
);
grammar.add_rule(
NODE_EXPR_ATOM,
&[TOKEN_NUMBER.into()],
rule!(|list| {
let n = extract!(list.token(0), Token::Number(n) => *n);
AstNode::Expr(Expr::Number(n))
}),
);
grammar.add_rule(
NODE_EXPR_ATOM,
&[
TOKEN_PUNCTUATION_LPAREN.into(),
NODE_EXPR.into(),
TOKEN_PUNCTUATION_RPAREN.into(),
],
rule!(|list| list.node(1)),
);
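        // Left recursion (e.g. sum => sum '+' product) makes the operators
        // left-associative; the sum/product split gives '*' and '/' higher
        // precedence than '+' and '-'.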
grammar.add_rule(
NODE_EXPR_SUM,
&[
NODE_EXPR_SUM.into(),
TOKEN_PUNCTUATION_PLUS.into(),
NODE_EXPR_PRODUCT.into(),
],
rule!(|list| AstNode::Expr(Expr::BinOp(list.node(0), BinOp::Add, list.node(2)))),
);
grammar.add_rule(
NODE_EXPR_SUM,
&[
NODE_EXPR_SUM.into(),
TOKEN_PUNCTUATION_MINUS.into(),
NODE_EXPR_PRODUCT.into(),
],
rule!(|list| AstNode::Expr(Expr::BinOp(list.node(0), BinOp::Sub, list.node(2)))),
);
grammar.add_rule(
NODE_EXPR_PRODUCT,
&[
NODE_EXPR_PRODUCT.into(),
TOKEN_PUNCTUATION_STAR.into(),
NODE_EXPR_ATOM.into(),
],
rule!(|list| AstNode::Expr(Expr::BinOp(list.node(0), BinOp::Mul, list.node(2)))),
);
grammar.add_rule(
NODE_EXPR_PRODUCT,
&[
NODE_EXPR_PRODUCT.into(),
TOKEN_PUNCTUATION_SLASH.into(),
NODE_EXPR_ATOM.into(),
],
rule!(|list| AstNode::Expr(Expr::BinOp(list.node(0), BinOp::Div, list.node(2)))),
);
grammar
};
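    // generate_parser returns the parser together with any conflicts found
    // while building the tables; report them as warnings rather than failing.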
let (parser, ambiguities) = generate_parser(grammar);
for ambiguity in &ambiguities {
use shiftkit::parser::AmbiguityKind;
match &ambiguity.kind {
AmbiguityKind::ReduceReduce { rules, chosen } => {
eprintln!(
"Warning: REDUCE/REDUCE conflict in state {} for token {}",
ambiguity.state,
display_internal_token(ambiguity.token)
);
for &rule_id in rules {
let rule = parser.grammar.get_rule(rule_id);
eprintln!(" Rule {}: {}", rule_id.0, display_rule(rule));
}
eprintln!(" Chosen: Rule {}", chosen.0);
}
AmbiguityKind::ShiftReduce {
shift_items,
reduce_rules,
} => {
eprintln!(
"Warning: SHIFT/REDUCE conflict in state {} for token {}",
ambiguity.state,
display_internal_token(ambiguity.token)
);
for item in shift_items {
let rule = parser.grammar.get_rule(item.rule);
eprintln!(
" Rule {}: Shift {}",
item.rule.0,
display_lr0_item(rule, item.index)
);
}
for &rule_id in reduce_rules {
let rule = parser.grammar.get_rule(rule_id);
eprintln!(" Rule {}: Reduce {}", rule_id.0, display_rule(rule));
}
}
}
}
println!("Rules:");
parser
.grammar
.rules
.iter()
.enumerate()
.for_each(|(id, r)| println!(" {id}: {}", display_rule(r)));
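    // REPL: lex each line, parse the token stream, and pretty-print the AST.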
loop {
print!("> ");
io::stdout().flush().unwrap();
let mut input = String::new();
match io::stdin().read_line(&mut input) {
            Ok(0) => break,
            Ok(_) => {
let input = input.trim();
if input.is_empty() {
continue;
}
let tokens = match tokenize_all::<Token>(input) {
Ok(tokens) => tokens,
Err(e) => {
eprintln!("Lexer error: {:?}", e);
continue;
}
};
match parser.parse(&tokens) {
Ok((root_id, ast_nodes)) => {
println!("Ok - '{}'", display_ast(&ast_nodes, root_id));
}
Err(e) => {
eprintln!("Parse error: {:?}", e);
}
}
}
Err(e) => {
eprintln!("Error reading input: {}", e);
break;
}
}
}
}
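
// A couple of lexer-level sanity checks added for illustration (the parser is
// built inside `main`, so parsing itself is exercised through the REPL above).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lexes_an_assignment() {
        let tokens = tokenize_all::<Token>("x = 1 + 2").unwrap();
        assert_eq!(
            tokens,
            vec![
                Token::Identifier("x".into()),
                Token::Punctuation(Punctuation::Assign),
                Token::Number(1),
                Token::Punctuation(Punctuation::Plus),
                Token::Number(2),
            ]
        );
    }

    #[test]
    fn reports_unexpected_characters() {
        // '?' sits at byte offset 2 of the input.
        assert_eq!(
            tokenize_all::<Token>("1 ? 2"),
            Err(LexError::UnexpectedChar('?', 2))
        );
    }
}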