use std::{
collections::{BTreeMap, BTreeSet},
rc::Rc,
};
use inc_complete::{
Db, DbHandle, define_input, define_intermediate, impl_storage,
storage::{HashMapStorage, SingletonStorage},
};
/// Database storage for the toy compiler's queries.
///
/// Holds one storage slot per query type. `input`, `parse`, and `execute_all`
/// use `SingletonStorage`; their key types (`Input`, `Parse`, `ExecuteAll`)
/// are field-less structs. `check` and `execute` use `HashMapStorage`, since
/// their keys carry an (AST, environment) pair.
/// NOTE(review): presumably one entry is stored per distinct key — confirm
/// against the `inc_complete` storage docs.
#[derive(Default)]
struct Compiler {
input: SingletonStorage<Input>,
parse: SingletonStorage<Parse>,
check: HashMapStorage<Check>,
execute: HashMapStorage<Execute>,
execute_all: SingletonStorage<ExecuteAll>,
}
// Pairs each `Compiler` field with the query type whose results it stores
// (written `field: QueryType`), e.g. field `check` holds `Check` results.
// NOTE(review): the generated code comes from `inc_complete::impl_storage!`;
// consult that macro's docs for the exact semantics.
impl_storage!(Compiler,
input: Input,
parse: Parse,
check: Check,
execute: Execute,
execute_all: ExecuteAll,
);
/// Query key for the program's source text.
///
/// This is an input query: its `String` value is supplied from outside via
/// `Input.set` (see `set_input`) rather than computed from other queries.
#[derive(Debug, Clone)]
struct Input;
// Declares `Input` as an input query of `Compiler` with a `String` value.
// NOTE(review): the leading `0` looks like a per-query id (the queries in this
// file use 0 through 4) — presumably it must be unique; confirm in the
// `inc_complete` docs.
define_input!(0, Input -> String, Compiler);
/// Query key for parsing the `Input` source text into an [`Ast`].
#[derive(Debug, Clone)]
struct Parse;
// Derived query (first argument `1`) producing `Result<Ast, Error>`,
// computed by `parse_program`.
define_intermediate!(1, Parse -> Result<Ast, Error>, Compiler, parse_program);
/// Query key for name-checking one subtree: field 0 is the subtree, field 1
/// the set of variable names in scope at that point.
///
/// NOTE(review): `Hash`/`Eq` are presumably required because this key is kept
/// in a `HashMapStorage` — confirm against `inc_complete`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct Check(Rc<Ast>, Rc<CheckEnv>);
// Derived query (first argument `2`) producing `Ok(())` when every variable in
// the subtree is bound; computed by `check_impl`.
define_intermediate!(2, Check -> Result<(), Error>, Compiler, check_impl);
/// Query key for evaluating one subtree: field 0 is the subtree, field 1 the
/// variable bindings (name -> value) in scope at that point.
///
/// NOTE(review): `Hash`/`Eq` are presumably required because this key is kept
/// in a `HashMapStorage` — confirm against `inc_complete`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct Execute(Rc<Ast>, Rc<ExecEnv>);
// Derived query (first argument `3`) producing the subtree's `i64` value;
// computed by `execute_impl`.
define_intermediate!(3, Execute -> Result<i64, Error>, Compiler, execute_impl);
/// Query key for the whole pipeline: parse the input, check it, and evaluate it.
#[derive(Debug, Clone)]
struct ExecuteAll;
// Derived query (first argument `4`) producing the program's result; computed
// by `execute_all_impl`.
define_intermediate!(4, ExecuteAll -> Result<i64, Error>, Compiler, execute_all_impl);
/// Errors produced while parsing, checking, or executing a program.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
enum Error {
/// `+` was not given exactly 2 arguments, or `let` not exactly 3.
IncorrectArgumentCount,
/// A parenthesized form was parsed but no closing `)` followed it; carries
/// the form that was parsed.
UnterminatedLParen(Ast),
/// The word after `(` is not a known operator (`+` or `let`); carries that word.
InvalidOperation(String),
/// A word starting with a numeric character could not be parsed as an `i64`;
/// carries the word.
InvalidIntegerLiteral(String),
/// The first argument of `let` was not a variable; carries that argument.
LetVarIsNotAnIdent(Ast),
/// The input was empty, began with a character that cannot start a value,
/// or had non-whitespace text left over after the program; carries the
/// remaining text.
InputEmptyOrUnparsedOutput(String),
/// A variable was used where no enclosing `let` binds it; carries its name.
NameNotDefined(String),
}
/// Abstract syntax tree of the toy language, as produced by `parse_value`.
///
/// Child nodes are held in `Rc`. The `Check` and `Execute` query keys also
/// store `Rc<Ast>`, so subtrees are shared with them by reference-count
/// bumps rather than deep copies.
#[derive(Debug, Hash, PartialEq, Eq, Clone, PartialOrd, Ord)]
enum Ast {
/// A variable reference: a word starting with an ASCII letter.
Var {
name: String,
},
/// An integer literal.
Int(i64),
/// `(+ lhs rhs)`: the sum of two sub-expressions.
Add(Rc<Ast>, Rc<Ast>),
/// `(let name rhs body)`: `body` with `name` bound to the value of `rhs`.
/// `rhs` itself is checked and evaluated in the outer scope.
Let {
name: String,
rhs: Rc<Ast>,
body: Rc<Ast>,
},
}
// The environments are BTree collections rather than Hash ones because they
// are fields of query keys that derive `Hash`. `BTreeSet`/`BTreeMap` implement
// `Hash`; `HashSet`/`HashMap` do not.
/// Names in scope while checking a subtree.
type CheckEnv = BTreeSet<String>;
/// Variable bindings (name -> value) in scope while evaluating a subtree.
type ExecEnv = BTreeMap<String, i64>;
fn parse_program(_: &Parse, db: &DbHandle<Compiler>) -> Result<Ast, Error> {
let program: String = Input.get(db);
let (ast, rest) = parse_value(&program)?;
if !rest.trim().is_empty() {
Err(Error::InputEmptyOrUnparsedOutput(rest.to_string()))
} else {
Ok(ast)
}
}
/// Parses one value from the start of `text`, skipping surrounding whitespace.
///
/// Supported forms:
/// - an integer literal: a word starting with a numeric character, e.g. `42`;
/// - a variable: a word starting with an ASCII letter, e.g. `foo`;
/// - `(+ lhs rhs)`, producing [`Ast::Add`];
/// - `(let name rhs body)`, producing [`Ast::Let`].
///
/// Words end at whitespace or `)`. Whitespace is allowed between `(` and the
/// operator.
///
/// Returns the parsed AST and the unconsumed remainder of the trimmed input.
///
/// # Errors
///
/// - [`Error::InputEmptyOrUnparsedOutput`]: the input is empty or starts with a
///   character that cannot begin a value.
/// - [`Error::InvalidIntegerLiteral`]: a numeric word does not parse as `i64`.
/// - [`Error::InvalidOperation`]: unknown operator after `(`.
/// - [`Error::IncorrectArgumentCount`]: `+` without exactly 2 arguments, or
///   `let` without exactly 3.
/// - [`Error::LetVarIsNotAnIdent`]: the first `let` argument is not a variable.
/// - [`Error::UnterminatedLParen`]: the closing `)` is missing.
/// - Any error raised while parsing a nested argument, propagated unchanged.
fn parse_value(mut text: &str) -> Result<(Ast, &str), Error> {
    text = text.trim();
    let Some(first_char) = text.chars().next() else {
        return Err(Error::InputEmptyOrUnparsedOutput(text.to_string()));
    };
    let ast = match first_char {
        '(' => {
            // `(` is a single byte, so `[1..]` stays on a char boundary.
            // Whitespace before the operator is skipped like everywhere else.
            text = text[1..].trim_start();
            let (operation, new_text) = parse_word(text);
            text = new_text;
            // Collect arguments up to the closing `)` (or end of input). An
            // error inside an argument is propagated with `?` rather than being
            // mistaken for the end of the argument list. Otherwise a malformed
            // nested form would be reported as a misleading error on this form.
            let mut args = Vec::new();
            loop {
                let remaining = text.trim_start();
                if remaining.is_empty() || remaining.starts_with(')') {
                    break;
                }
                let (arg, rest) = parse_value(remaining)?;
                args.push(arg);
                text = rest;
            }
            let ast = match operation {
                "+" => {
                    if args.len() != 2 {
                        return Err(Error::IncorrectArgumentCount);
                    }
                    // Pop in reverse: the last argument is the right operand.
                    let rhs = args.pop().expect("length checked above");
                    let lhs = args.pop().expect("length checked above");
                    Ast::Add(Rc::new(lhs), Rc::new(rhs))
                }
                "let" => {
                    if args.len() != 3 {
                        return Err(Error::IncorrectArgumentCount);
                    }
                    let body = Rc::new(args.pop().expect("length checked above"));
                    let rhs = Rc::new(args.pop().expect("length checked above"));
                    let let_var = args.pop().expect("length checked above");
                    let Ast::Var { name } = let_var else {
                        return Err(Error::LetVarIsNotAnIdent(let_var));
                    };
                    Ast::Let { name, rhs, body }
                }
                other => return Err(Error::InvalidOperation(other.to_string())),
            };
            let Some(rest) = text.trim().strip_prefix(')') else {
                return Err(Error::UnterminatedLParen(ast));
            };
            text = rest;
            ast
        }
        s if s.is_ascii_alphabetic() => {
            let (word, rest) = parse_word(text);
            text = rest;
            Ast::Var {
                name: word.to_string(),
            }
        }
        s if s.is_numeric() => {
            let (word, rest) = parse_word(text);
            text = rest;
            let int = word
                .parse::<i64>()
                .map_err(|_| Error::InvalidIntegerLiteral(word.to_string()))?;
            Ast::Int(int)
        }
        _ => return Err(Error::InputEmptyOrUnparsedOutput(text.to_string())),
    };
    Ok((ast, text))
}
/// Splits `text` at the end of its first word, returning `(word, rest)`.
///
/// A word ends at the first whitespace character or `)`. `rest` begins with
/// that delimiter, or is empty if the word runs to the end of `text`. The word
/// is empty when `text` starts with a delimiter.
fn parse_word(text: &str) -> (&str, &str) {
    // `next_whitespace_or_rparen_index` returns either a match position or
    // `text.len()`, both char boundaries, so `split_at` cannot panic. When the
    // index equals the length, `rest` is already the empty string.
    text.split_at(next_whitespace_or_rparen_index(text))
}
/// Returns the byte offset of the first whitespace character or `)` in
/// `text`, or `text.len()` when neither occurs.
fn next_whitespace_or_rparen_index(text: &str) -> usize {
    let is_delimiter = |c: char| c == ')' || c.is_whitespace();
    text.find(is_delimiter).unwrap_or(text.len())
}
fn check_impl(check: &Check, db: &DbHandle<Compiler>) -> Result<(), Error> {
let ast = check.0.as_ref();
let env = &check.1;
match ast {
Ast::Var { name } => {
if env.contains(name) {
Ok(())
} else {
Err(Error::NameNotDefined(name.clone()))
}
}
Ast::Int(_) => Ok(()),
Ast::Add(lhs, rhs) => {
Check(lhs.clone(), env.clone()).get(db)?;
Check(rhs.clone(), env.clone()).get(db)
}
Ast::Let { name, rhs, body } => {
Check(rhs.clone(), env.clone()).get(db)?;
let mut new_env = env.as_ref().clone();
new_env.insert(name.clone());
Check(body.clone(), Rc::new(new_env)).get(db)
}
}
}
fn execute_impl(execute: &Execute, db: &DbHandle<Compiler>) -> Result<i64, Error> {
let ast = execute.0.as_ref();
let env = &execute.1;
match ast {
Ast::Var { name } => {
Ok(env[name])
}
Ast::Int(x) => Ok(*x),
Ast::Add(lhs, rhs) => {
let lhs = Execute(lhs.clone(), env.clone()).get(db)?;
let rhs = Execute(rhs.clone(), env.clone()).get(db)?;
Ok(lhs + rhs)
}
Ast::Let { name, rhs, body } => {
let rhs = Execute(rhs.clone(), env.clone()).get(db)?;
let mut new_env = env.as_ref().clone();
new_env.insert(name.clone(), rhs);
Execute(body.clone(), Rc::new(new_env)).get(db)
}
}
}
/// Query implementation for [`ExecuteAll`]: parses the input and checks that
/// all variables are bound. Only if both succeed does it evaluate the program,
/// starting from an empty environment.
///
/// # Errors
///
/// Returns the first error from parsing, checking, or evaluation.
fn execute_all_impl(_: &ExecuteAll, db: &DbHandle<Compiler>) -> Result<i64, Error> {
    let program = Rc::new(Parse.get(db)?);
    // Name checking must succeed before evaluation is attempted.
    Check(Rc::clone(&program), Rc::new(CheckEnv::new())).get(db)?;
    Execute(program, Rc::new(ExecEnv::new())).get(db)
}
/// Stores `source_program` as the value of the [`Input`] query in `db`.
fn set_input(db: &mut Db<Compiler>, source_program: &str) {
    let source = String::from(source_program);
    Input.set(db, source);
}
/// Tests for the incremental compiler pipeline.
// `#[cfg(test)]` keeps this module (and its otherwise-unused imports) out of
// non-test builds.
#[cfg(test)]
mod compiler {
    use crate::*;
    use inc_complete::Db;

    /// End-to-end results for several small programs, reusing one database
    /// across input changes.
    #[test]
    fn basic_programs() {
        let db = &mut Db::<Compiler>::new();
        set_input(db, "42");
        let result = ExecuteAll.get(db);
        assert_eq!(result, Ok(42));
        set_input(db, "(+ 42 58)");
        let result = ExecuteAll.get(db);
        assert_eq!(result, Ok(100));
        set_input(db, "(let foo 42 (+ 58 foo))");
        let result = ExecuteAll.get(db);
        assert_eq!(result, Ok(100));
        set_input(db, "(let foo 42 (+ 58 foo)) foo");
        let result = ExecuteAll.get(db);
        assert_eq!(
            result,
            Err(Error::InputEmptyOrUnparsedOutput(" foo".to_string()))
        );
        set_input(db, "(let foo 42 (+ foo bar))");
        let result = db.get(ExecuteAll);
        assert_eq!(result, Err(Error::NameNotDefined("bar".to_string())));
        set_input(db, "(let foo 42 (let bar 8 (+ foo bar)))");
        let result = db.get(ExecuteAll);
        assert_eq!(result, Ok(50));
    }

    /// Parse errors surface through `ExecuteAll` with the expected variant.
    #[test]
    fn parse_errors() {
        let db = &mut Db::<Compiler>::new();
        set_input(db, "");
        assert_eq!(
            ExecuteAll.get(db),
            Err(Error::InputEmptyOrUnparsedOutput(String::new()))
        );
        set_input(db, "(+ 1 2");
        assert_eq!(
            ExecuteAll.get(db),
            Err(Error::UnterminatedLParen(Ast::Add(
                Rc::new(Ast::Int(1)),
                Rc::new(Ast::Int(2))
            )))
        );
        set_input(db, "(foo 1 2)");
        assert_eq!(
            ExecuteAll.get(db),
            Err(Error::InvalidOperation("foo".to_string()))
        );
        set_input(db, "(+ 1 2 3)");
        assert_eq!(ExecuteAll.get(db), Err(Error::IncorrectArgumentCount));
        set_input(db, "(let 1 2 3)");
        assert_eq!(
            ExecuteAll.get(db),
            Err(Error::LetVarIsNotAnIdent(Ast::Int(1)))
        );
    }

    /// Runs `(+ 42 58)`, then switches the input to `42`. Asserts that the
    /// `Check`/`Execute` entries for the `42` leaf under an empty environment,
    /// first computed as sub-queries of the earlier program, are not reported
    /// stale.
    #[test]
    fn cached() {
        let db = &mut Db::<Compiler>::new();
        set_input(db, "(+ 42 58)");
        let result = ExecuteAll.get(db);
        assert_eq!(result, Ok(100));
        set_input(db, "42");
        let ast = Parse.get(db).unwrap();
        assert_eq!(ast, Ast::Int(42));
        let ast = Rc::new(ast);
        assert!(!db.is_stale(&Check(ast.clone(), Default::default())));
        assert!(!db.is_stale(&Execute(ast, Default::default())));
    }
}