use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Root node of a restricted program: the list of functions it defines and
/// the name of the function designated as its entry point.
///
/// `RestrictedAst::validate` requires that `entry_point` matches the name of
/// at least one element of `functions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RestrictedAst {
    /// Every function defined by the program.
    pub functions: Vec<Function>,
    /// Name of the entry-point function; validation fails when no function in
    /// `functions` carries this name.
    pub entry_point: String,
}
impl RestrictedAst {
pub fn validate(&self) -> Result<(), String> {
if !self.functions.iter().any(|f| f.name == self.entry_point) {
return Err(format!(
"Entry point function '{}' not found",
self.entry_point
));
}
for function in &self.functions {
function.validate()?;
}
Ok(())
}
fn _check_no_recursion(&self) -> Result<(), String> {
let mut call_graph: HashMap<String, Vec<String>> = HashMap::new();
for function in &self.functions {
let mut calls = Vec::new();
function.collect_function_calls(&mut calls);
call_graph.insert(function.name.clone(), calls);
}
for function in &self.functions {
let mut visited = std::collections::HashSet::new();
let mut rec_stack = std::collections::HashSet::new();
if self.has_cycle(&call_graph, &function.name, &mut visited, &mut rec_stack) {
return Err(format!(
"Recursion detected involving function '{}'",
function.name
));
}
}
Ok(())
}
#[allow(dead_code, clippy::only_used_in_recursion)]
fn has_cycle(
&self,
graph: &HashMap<String, Vec<String>>,
node: &str,
visited: &mut std::collections::HashSet<String>,
rec_stack: &mut std::collections::HashSet<String>,
) -> bool {
if rec_stack.contains(node) {
return true;
}
if visited.contains(node) {
return false;
}
visited.insert(node.to_string());
rec_stack.insert(node.to_string());
if let Some(neighbors) = graph.get(node) {
for neighbor in neighbors {
if self.has_cycle(graph, neighbor, visited, rec_stack) {
return true;
}
}
}
rec_stack.remove(node);
false
}
}
/// A single function definition in the restricted language.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Function {
    /// Function name; checked as an identifier by `Function::validate`.
    pub name: String,
    /// Declared parameters; `Function::validate` rejects duplicate names.
    pub params: Vec<Parameter>,
    /// Declared return type.
    pub return_type: Type,
    /// Statements making up the function body, in order.
    pub body: Vec<Stmt>,
}
impl Function {
pub fn validate(&self) -> Result<(), String> {
Self::validate_identifier(&self.name)?;
let mut param_names = std::collections::HashSet::new();
for param in &self.params {
Self::validate_identifier(¶m.name)?;
if !param_names.insert(¶m.name) {
return Err(format!("Duplicate parameter name: {}", param.name));
}
}
for stmt in &self.body {
stmt.validate()?;
}
Ok(())
}
fn validate_identifier(name: &str) -> Result<(), String> {
if name.is_empty() {
return Err("Identifiers cannot be empty".to_string());
}
if name.contains('\0') {
return Err("Null characters not allowed in identifiers".to_string());
}
if name.contains('$') || name.contains('`') || name.contains('\\') {
return Err(format!("Unsafe characters in identifier: {}", name));
}
Ok(())
}
pub fn collect_function_calls(&self, calls: &mut Vec<String>) {
for stmt in &self.body {
stmt.collect_function_calls(calls);
}
}
}
/// A named, typed function parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Parameter {
    /// Parameter name; checked as an identifier when the owning function is
    /// validated.
    pub name: String,
    /// Declared type of the parameter.
    pub param_type: Type,
}
/// Types expressible in the restricted language: five leaf variants plus the
/// `Result` and `Option` wrappers, which nest other `Type`s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Type {
    /// Leaf type `Void`.
    Void,
    /// Leaf type `Bool`.
    Bool,
    /// Leaf type `U16`.
    U16,
    /// Leaf type `U32`.
    U32,
    /// Leaf type `Str`.
    Str,
    /// Wrapper carrying a success type and an error type.
    Result {
        /// Type of the success case.
        ok_type: Box<Type>,
        /// Type of the error case.
        err_type: Box<Type>,
    },
    /// Wrapper around a single inner type.
    Option {
        /// The wrapped type.
        inner_type: Box<Type>,
    },
}
impl Type {
pub fn is_allowed(&self) -> bool {
match self {
Type::Void | Type::Bool | Type::U16 | Type::U32 | Type::Str => true,
Type::Result { ok_type, err_type } => ok_type.is_allowed() && err_type.is_allowed(),
Type::Option { inner_type } => inner_type.is_allowed(),
}
}
}
/// A statement in the restricted language.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Stmt {
    /// Associates `value` with `name`.
    Let {
        /// Target identifier; checked by `Stmt::validate`.
        name: String,
        /// Expression supplying the value.
        value: Expr,
        /// Falls back to `true` (via `default_declaration`) when the field is
        /// missing from deserialized input.
        /// NOTE(review): presumably distinguishes a fresh declaration from a
        /// reassignment — confirm against the consumer of this flag.
        #[serde(default = "default_declaration")]
        declaration: bool,
    },
    /// An expression used as a statement.
    Expr(Expr),
    /// Return, with an optional value expression.
    Return(Option<Expr>),
    /// Conditional with a mandatory `then` block and an optional `else` block.
    If {
        /// Expression deciding which block applies.
        condition: Expr,
        /// Statements of the `then` branch.
        then_block: Vec<Stmt>,
        /// Statements of the `else` branch, if there is one.
        else_block: Option<Vec<Stmt>>,
    },
    /// Pattern match over `scrutinee`.
    Match {
        /// Expression being matched.
        scrutinee: Expr,
        /// Match arms, in order.
        arms: Vec<MatchArm>,
    },
    /// `for` loop over `iter`.
    For {
        /// Pattern applied to each item.
        pattern: Pattern,
        /// Expression producing the items.
        iter: Expr,
        /// Loop body.
        body: Vec<Stmt>,
        /// Iteration bound; `Stmt::validate` rejects the loop when this is
        /// `None`.
        max_iterations: Option<u32>,
    },
    /// `while` loop.
    While {
        /// Loop condition.
        condition: Expr,
        /// Loop body.
        body: Vec<Stmt>,
        /// Iteration bound; `Stmt::validate` rejects the loop when this is
        /// `None`.
        max_iterations: Option<u32>,
    },
    /// `break` statement; always passes validation.
    Break,
    /// `continue` statement; always passes validation.
    Continue,
}
/// Serde default for `Stmt::Let::declaration`: a `Let` whose `declaration`
/// field is absent from the input deserializes with `declaration == true`.
fn default_declaration() -> bool {
    true
}
impl Stmt {
pub fn validate(&self) -> Result<(), String> {
match self {
Stmt::Let { name, value, .. } => {
Self::validate_identifier(name)?;
value.validate()
}
Stmt::Expr(expr) => expr.validate(),
Stmt::Return(Some(expr)) => expr.validate(),
Stmt::Return(None) => Ok(()),
Stmt::If {
condition,
then_block,
else_block,
} => self.validate_if_stmt(condition, then_block, else_block.as_ref()),
Stmt::Match { scrutinee, arms } => self.validate_match_stmt(scrutinee, arms),
Stmt::For {
pattern,
iter,
body,
max_iterations,
} => self.validate_for_stmt(pattern, iter, body, *max_iterations),
Stmt::While {
condition,
body,
max_iterations,
} => self.validate_while_stmt(condition, body, *max_iterations),
Stmt::Break | Stmt::Continue => Ok(()),
}
}
fn validate_identifier(name: &str) -> Result<(), String> {
if name.is_empty() {
return Err("Identifiers cannot be empty".to_string());
}
if name.contains('\0') {
return Err("Null characters not allowed in identifiers".to_string());
}
if name.contains('$') || name.contains('`') || name.contains('\\') {
return Err(format!("Unsafe characters in identifier: {}", name));
}
Ok(())
}
fn validate_if_stmt(
&self,
condition: &Expr,
then_block: &[Stmt],
else_block: Option<&Vec<Stmt>>,
) -> Result<(), String> {
condition.validate()?;
self.validate_stmt_block(then_block)?;
if let Some(else_stmts) = else_block {
self.validate_stmt_block(else_stmts)?;
}
Ok(())
}
fn validate_match_stmt(&self, scrutinee: &Expr, arms: &[MatchArm]) -> Result<(), String> {
scrutinee.validate()?;
for arm in arms {
arm.pattern.validate()?;
if let Some(guard) = &arm.guard {
guard.validate()?;
}
self.validate_stmt_block(&arm.body)?;
}
Ok(())
}
fn validate_for_stmt(
&self,
pattern: &Pattern,
iter: &Expr,
body: &[Stmt],
max_iterations: Option<u32>,
) -> Result<(), String> {
self.validate_bounded_iteration(max_iterations, "For")?;
pattern.validate()?;
iter.validate()?;
self.validate_stmt_block(body)
}
fn validate_while_stmt(
&self,
condition: &Expr,
body: &[Stmt],
max_iterations: Option<u32>,
) -> Result<(), String> {
self.validate_bounded_iteration(max_iterations, "While")?;
condition.validate()?;
self.validate_stmt_block(body)
}
fn validate_bounded_iteration(
&self,
max_iterations: Option<u32>,
loop_type: &str,
) -> Result<(), String> {
if max_iterations.is_none() {
return Err(format!(
"{loop_type} loops must have bounded iterations for verification"
));
}
Ok(())
}
fn validate_stmt_block(&self, stmts: &[Stmt]) -> Result<(), String> {
for stmt in stmts {
stmt.validate()?;
}
Ok(())
}
pub fn collect_function_calls(&self, calls: &mut Vec<String>) {
match self {
Stmt::Let { value, .. } => value.collect_function_calls(calls),
Stmt::Expr(expr) => expr.collect_function_calls(calls),
Stmt::Return(Some(expr)) => expr.collect_function_calls(calls),
Stmt::Return(None) | Stmt::Break | Stmt::Continue => {}
Stmt::If {
condition,
then_block,
else_block,
} => {
condition.collect_function_calls(calls);
collect_calls_from_block(then_block, calls);
if let Some(else_stmts) = else_block {
collect_calls_from_block(else_stmts, calls);
}
}
Stmt::Match { scrutinee, arms } => {
scrutinee.collect_function_calls(calls);
collect_calls_from_match_arms(arms, calls);
}
Stmt::For { iter, body, .. } => {
iter.collect_function_calls(calls);
collect_calls_from_block(body, calls);
}
Stmt::While {
condition, body, ..
} => {
condition.collect_function_calls(calls);
collect_calls_from_block(body, calls);
}
}
}
}
include!("restricted_expr.rs");