use super::ast_lexer::{AstLexer, Token, TokenMatch};
use super::ast_parser::AstParser;
use crate::ast::{Ast, Node, VariableValue};
use crate::{ArcSourceCode, ModulePath, Result, Scope};
/// Parser for a code section: a sequence of top-level `fn` definitions,
/// `use` module imports and `name := expr` variable definitions.
pub(crate) struct CodeSectionParser<'a> {
    /// Token stream over the code section's text.
    lexer: AstLexer<'a>,
    /// Source the section belongs to; used to convert expression-parser
    /// errors into syntax errors.
    source_code: ArcSourceCode,
}

impl<'a> CodeSectionParser<'a> {
    /// Parses `input` as a complete code section.
    ///
    /// Returns the scope populated with every function and variable defined
    /// in the section, together with the modules required via `use`, in the
    /// order they appear in the source.
    ///
    /// # Errors
    ///
    /// Returns an error if the lexer cannot be created for `input`, or if any
    /// top-level item is malformed.
    pub(crate) fn parse(
        input: &'a str,
        source_code: ArcSourceCode,
    ) -> Result<(Scope, Vec<ModulePath>)> {
        let lexer = AstLexer::new(input, None, source_code.clone())
            .map_err(|e| source_code.code_syntax_error(e))?;
        CodeSectionParser { lexer, source_code }.parse_scope()
    }

    /// Consumes top-level items until end of input, dispatching on the first
    /// token of each item.
    fn parse_scope(&'a self) -> Result<(Scope, Vec<ModulePath>)> {
        let mut scope = Scope::default();
        let mut required_modules = vec![];
        loop {
            let next = self.lexer.next();
            match next.token {
                Token::Eof => break,
                Token::FunctionDefinition => {
                    let (fn_name, function) = self.parse_fn_definition()?;
                    scope.define_function(fn_name, function);
                }
                Token::UseModule => required_modules.push(self.parse_use_module()?),
                Token::Reference => {
                    // A bare name at top level must start a `name := expr` definition.
                    let var_assign = self.parse_variable_assign()?;
                    scope.define_variable(
                        next.str_match,
                        Node::var(next.str_match, VariableValue::Ast(var_assign)),
                    );
                }
                _ => {
                    return Err(
                        next.into_error("Expected an `fn` or variable definition (`:=`) operator")
                    )
                }
            }
        }
        Ok((scope, required_modules))
    }

    /// Parses the module path that follows a `use` keyword.
    fn parse_use_module(&self) -> Result<ModulePath> {
        self.lexer.next().try_into()
    }

    /// Parses `:= expr` after a variable name has been consumed, returning
    /// the expression's AST.
    fn parse_variable_assign(&'a self) -> Result<Ast> {
        let next = self.lexer.next();
        match next.token {
            Token::VarAssign => self.parse_expr(),
            _ => Err(next.into_error("Expected a variable definition operator (`:=`)")),
        }
    }

    /// Parses `name(arg, ...) body` after the `fn` keyword has been consumed.
    ///
    /// Arguments must be names separated by single commas; a trailing comma
    /// before `)` is accepted. Returns the function name and its
    /// `Node::Function` definition.
    fn parse_fn_definition(&'a self) -> Result<(String, Node)> {
        let name = match self.lexer.next() {
            TokenMatch {
                token: Token::Reference,
                str_match: r,
                ..
            } => r.to_owned(),
            token => return Err(token.into_error("Expected a function name")),
        };
        match self.lexer.next() {
            TokenMatch {
                token: Token::OpenParen,
                ..
            } => (),
            token => return Err(token.into_error("Expected `(` for a function definition")),
        };
        let mut fn_args = vec![];
        // `true` right after `(` or `,` (an argument name may follow); `false`
        // right after a name (only `,` or `)` may follow). This rejects
        // malformed lists such as `(a b)`, `(, a)` and `(a,, b)`.
        let mut expecting_arg = true;
        loop {
            let next = self.lexer.next();
            match next.token {
                Token::CloseParen => break,
                Token::Comma if !expecting_arg => expecting_arg = true,
                Token::Reference if expecting_arg => {
                    fn_args.push(next.str_match.to_string());
                    expecting_arg = false;
                }
                _ => {
                    return Err(
                        next.into_error("Expected comma-separated function arguments or `)`")
                    )
                }
            }
        }
        let function = Node::Function {
            name: name.clone(),
            args: fn_args,
            body: self.parse_expr()?,
        };
        Ok((name, function))
    }

    /// Parses a single expression from the current lexer position, mapping
    /// parser errors to syntax errors of this section's source.
    fn parse_expr(&'a self) -> Result<Ast> {
        AstParser::new(&self.lexer, true)
            .expr_bp(0)
            .map_err(|e| self.source_code.code_syntax_error(e))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Ast;
    use crate::test_utils::*;
    use crate::*;

    /// Parses `input` as a code section hosted in a CSV test source,
    /// panicking if parsing fails.
    fn parse_section(input: &str) -> (Scope, Vec<ModulePath>) {
        let source_code: SourceCode = (&TestSourceCode::new("csv", input)).into();
        let shared_source = ArcSourceCode::new(source_code);
        CodeSectionParser::parse(input, shared_source).unwrap()
    }

    #[test]
    fn parse_function() {
        let (scope, _) = parse_section("fn foo(a, b) a + b");
        let body = Node::infix_fn_call(Node::reference("a"), "+", Node::reference("b"));
        let expected: Ast = Ast::new(Node::fn_def("foo", &["a", "b"], body));
        let defined = scope.functions.get("foo").expect("`foo` should be defined");
        assert_eq!(defined, &expected);
    }

    #[test]
    fn parse_function_without_args() {
        let (scope, _) = parse_section("fn foo() 1 * 2");
        let body = Node::infix_fn_call(1, "*", 2);
        let expected: Ast = Ast::new(Node::fn_def("foo", &[], body));
        let defined = scope.functions.get("foo").expect("`foo` should be defined");
        assert_eq!(defined, &expected);
    }

    #[test]
    fn parse_multiple_functions() {
        let source = r#"
fn foo()
1 * 2
fn bar(a, b)
a + b
"#;
        let (scope, _) = parse_section(source);
        assert_eq!(scope.functions.len(), 2);
    }

    #[test]
    fn parse_variables() {
        let (scope, _) = parse_section("foo := \"bar\"");
        let foo = scope.variables.get("foo");
        assert!(foo.is_some(), "`foo` should be defined");
    }

    #[test]
    fn parse_variables_and_functions() {
        let source = r#"
fn foo_fn() 1 * 2
foo_var := 3 * 4 + 5
fn bar_fn(a, b) a + b
bar_var := D1
"#;
        let (scope, _) = parse_section(source);
        for name in ["foo_fn", "bar_fn"] {
            assert!(scope.functions.get(name).is_some(), "missing function `{}`", name);
        }
        for name in ["foo_var", "bar_var"] {
            assert!(scope.variables.get(name).is_some(), "missing variable `{}`", name);
        }
    }

    #[test]
    fn parse_use_module() {
        let source = r#"
use foo
"#;
        let (_, required_modules) = parse_section(source);
        assert_eq!(required_modules, vec![ModulePath::new("foo")]);
    }

    #[test]
    fn parse_use_module_multiple() {
        let source = r#"
use foo
use bar
"#;
        let (_, required_modules) = parse_section(source);
        assert_eq!(
            required_modules,
            vec![ModulePath::new("foo"), ModulePath::new("bar")]
        );
    }
}