use anyhow::{anyhow, bail, ensure, Result};
use crate::btrfs::definitions::STRUCTS;
use crate::lang::ast::*;
use crate::lang::variables::Variables;
pub struct SemanticAnalyzer {
variables: Variables<()>,
loop_depth: u32,
}
impl SemanticAnalyzer {
pub fn new() -> Self {
Self {
variables: Variables::new(|_| (), |_| ()),
loop_depth: 0,
}
}
fn analyze_primary_expr(&mut self, expr: &PrimaryExpression) -> Result<()> {
match expr {
PrimaryExpression::Identifier(ident) => {
let _ = self
.variables
.get(ident)
.ok_or_else(|| anyhow!("Unknown variable: {}", ident))?;
Ok(())
}
PrimaryExpression::Constant(_) => Ok(()),
PrimaryExpression::Str(_) => Ok(()),
PrimaryExpression::Paren(expr) => self.analyze_expr(expr),
}
}
fn analyze_binary_expr(&mut self, binop: &BinaryExpression) -> Result<()> {
match binop {
BinaryExpression::Plus(lhs, rhs)
| BinaryExpression::Minus(lhs, rhs)
| BinaryExpression::Multiply(lhs, rhs)
| BinaryExpression::Divide(lhs, rhs)
| BinaryExpression::Modulo(lhs, rhs)
| BinaryExpression::BitOr(lhs, rhs)
| BinaryExpression::BitAnd(lhs, rhs)
| BinaryExpression::BitXor(lhs, rhs)
| BinaryExpression::LeftShift(lhs, rhs)
| BinaryExpression::RightShift(lhs, rhs)
| BinaryExpression::LessThan(lhs, rhs)
| BinaryExpression::LessThanEquals(lhs, rhs)
| BinaryExpression::GreaterThan(lhs, rhs)
| BinaryExpression::GreaterThanEquals(lhs, rhs)
| BinaryExpression::Equals(lhs, rhs)
| BinaryExpression::NotEquals(lhs, rhs)
| BinaryExpression::LogicalOr(lhs, rhs)
| BinaryExpression::LogicalAnd(lhs, rhs) => {
self.analyze_expr(lhs)?;
self.analyze_expr(rhs)?;
Ok(())
}
}
}
fn analyze_unary_expr(&mut self, unop: &UnaryExpression) -> Result<()> {
match unop {
UnaryExpression::BitNot(e) | UnaryExpression::Not(e) | UnaryExpression::Minus(e) => {
self.analyze_expr(&*e)
}
UnaryExpression::Cast(t, e) => match t {
TypeSpecifier::Struct(ty) => match &**ty {
Expression::PrimaryExpression(PrimaryExpression::Identifier(Identifier(
ident,
))) => {
ensure!(
STRUCTS.iter().any(|s| s.name == ident),
"Unknown type: struct {}",
ident,
);
self.analyze_expr(e)?;
Ok(())
}
_ => bail!("Invalid type in type cast"),
},
},
}
}
fn analyze_expr(&mut self, expr: &Expression) -> Result<()> {
match expr {
Expression::PrimaryExpression(p) => self.analyze_primary_expr(p),
Expression::FieldAccess(expr, field) => {
match &**field {
Expression::PrimaryExpression(PrimaryExpression::Identifier(Identifier(_))) => {
}
_ => bail!("Field in a field access must be an identifier"),
};
self.analyze_expr(expr)
}
Expression::ArrayIndex(expr, index) => {
self.analyze_expr(expr)?;
self.analyze_expr(index)?;
Ok(())
}
Expression::FunctionCall(expr, args) => {
self.analyze_expr(expr)?;
for arg in args {
self.analyze_expr(arg)?;
}
Ok(())
}
Expression::BinaryExpression(b) => self.analyze_binary_expr(b),
Expression::UnaryExpression(u) => self.analyze_unary_expr(u),
}
}
fn analyze_stmts(&mut self, stmts: &[Statement]) -> Result<()> {
for stmt in stmts {
self.analyze_stmt(stmt)?;
}
Ok(())
}
fn analyze_if(
&mut self,
cond: &Expression,
true_body: &[Statement],
false_body: &[Statement],
) -> Result<()> {
self.analyze_expr(cond)?;
self.analyze_stmts(true_body)?;
self.analyze_stmts(false_body)?;
Ok(())
}
fn analyze_for(
&mut self,
ident: &Expression,
range: &Expression,
stmts: &[Statement],
) -> Result<()> {
let ident = match ident {
Expression::PrimaryExpression(PrimaryExpression::Identifier(i)) => i,
_ => bail!("Temporary variable in a for-loop must be an identifier"),
};
self.variables.insert(ident.clone(), ());
self.analyze_expr(range)?;
self.analyze_stmts(stmts)?;
Ok(())
}
fn analyze_while(&mut self, cond: &Expression, stmts: &[Statement]) -> Result<()> {
self.analyze_expr(cond)?;
self.analyze_stmts(stmts)?;
Ok(())
}
fn analyze_block_stmt(&mut self, block: &BlockStatement) -> Result<()> {
match block {
BlockStatement::If(cond, true_body, false_body) => {
self.variables.push_scope();
let ret = self.analyze_if(cond, true_body, false_body);
self.variables.pop_scope();
ret
}
BlockStatement::While(cond, stmts) => {
self.loop_depth += 1;
self.variables.push_scope();
let ret = self.analyze_while(cond, stmts);
self.loop_depth -= 1;
self.variables.pop_scope();
ret
}
BlockStatement::For(ident, range, stmts) => {
self.loop_depth += 1;
self.variables.push_scope();
let ret = self.analyze_for(ident, range, stmts);
self.loop_depth -= 1;
self.variables.pop_scope();
ret
}
}
}
fn analyze_stmt(&mut self, stmt: &Statement) -> Result<()> {
match stmt {
Statement::AssignStatement(lhs, rhs) => {
if let Expression::PrimaryExpression(PrimaryExpression::Identifier(ident)) = lhs {
self.variables.insert(ident.clone(), ());
} else {
self.analyze_expr(lhs)?;
}
self.analyze_expr(rhs)?;
Ok(())
}
Statement::BlockStatement(block) => self.analyze_block_stmt(block),
Statement::JumpStatement(jump) => {
ensure!(self.loop_depth > 0, "Cannot '{}' outside of a loop", jump);
Ok(())
}
Statement::BuiltinStatement(builtin) => match builtin {
BuiltinStatement::Help => Ok(()),
BuiltinStatement::Quit => Ok(()),
BuiltinStatement::Filesystem(expr) => match expr {
Expression::PrimaryExpression(PrimaryExpression::Str(_)) => Ok(()),
_ => bail!("'filesystem' builtin only takes string literal arguments"),
},
BuiltinStatement::Print(expr) => {
self.analyze_expr(expr)?;
Ok(())
}
},
Statement::ExpressionStatement(expr) => {
self.analyze_expr(expr)?;
Ok(())
}
}
}
pub fn analyze(&mut self, stmts: &[Statement]) -> Result<()> {
for stmt in stmts {
self.analyze_stmt(stmt)?;
}
Ok(())
}
}
#[cfg(test)]
fn analyze(stmts: &[Statement]) -> Result<()> {
SemanticAnalyzer::new().analyze(stmts)
}
#[cfg(test)]
fn parse(prog: &str) -> Vec<Statement> {
use crate::lang::parse::parse;
parse(prog).unwrap()
}
#[test]
fn test_jump_outside_loop() {
{
let prog = r#"
break;
"#;
let stmts = parse(prog);
assert!(analyze(&stmts).is_err());
}
{
let prog = r#"
continue;
"#;
let stmts = parse(prog);
assert!(analyze(&stmts).is_err());
}
{
let prog = r#"
while true {
break;
continue;
}
"#;
let stmts = parse(prog);
analyze(&stmts).expect("Semantic analysis failed")
}
{
let prog = r#"
for i in true {
break;
continue;
}
"#;
let stmts = parse(prog);
analyze(&stmts).expect("Semantic analysis failed");
}
}
#[test]
fn test_variable_scope() {
{
let prog = r#"
while true {
inner_var = 3;
}
outer_var = inner_var;
"#;
let stmts = parse(prog);
assert!(analyze(&stmts).is_err());
}
{
let prog = r#"
for i in true {
inner_var = 3;
}
outer_var = inner_var;
"#;
let stmts = parse(prog);
assert!(analyze(&stmts).is_err());
}
{
let prog = r#"
while true {
inner_var = 3;
inner_var = 4;
}
"#;
let stmts = parse(prog);
analyze(&stmts).expect("Semantic analysis failed")
}
{
let prog = r#"
for i in true {
inner_var = 3;
inner_var2 = inner_var;
}
"#;
let stmts = parse(prog);
analyze(&stmts).expect("Semantic analysis failed");
}
}
#[test]
fn test_builtin() {
{
let prog = r#"
help;
quit;
filesystem "/mnt/asdf/waa";
print 3;
print true;
print "asdf";
"#;
let stmts = parse(prog);
analyze(&stmts).expect("Semantic analysis failed")
}
}
#[test]
fn test_type_cast() {
{
let prog = r#"
(struct btrfs_key) 1;
"#;
let stmts = parse(prog);
analyze(&stmts).expect("Semantic analysis failed");
}
{
let prog = r#"
(struct btrfs_keyz) 1;
"#;
let stmts = parse(prog);
analyze(&stmts).expect_err("Semantic analysis passed");
}
}