use std::collections::HashSet;
use std::fs;
use colored::Colorize;
use aver::ast::TopLevel;
use aver::call_graph::{find_recursive_fns, recursive_callsite_counts};
use aver::interpreter::{Interpreter, Value};
use aver::resolver;
use aver::source::{parse_source, require_module_declaration};
use aver::tco;
use aver::types;
use aver::types::checker::{TypeCheckResult, TypeError, run_type_check_full};
use aver::vm;
/// Reads the whole file at `path` into a `String`.
///
/// # Errors
/// Returns a human-readable message naming the path and the underlying I/O
/// error when the file cannot be read.
pub(super) fn read_file(path: &str) -> Result<String, String> {
    match fs::read_to_string(path) {
        Ok(contents) => Ok(contents),
        Err(io_err) => Err(format!("Cannot open file '{}': {}", path, io_err)),
    }
}
/// Parses Aver source text into its top-level items.
///
/// Thin wrapper over [`parse_source`]; parse failures are returned unchanged
/// as the error string.
pub(super) fn parse_file(source: &str) -> Result<Vec<TopLevel>, String> {
    let items = parse_source(source)?;
    Ok(items)
}
/// Picks the directory used as the module root.
///
/// An explicit `module_root` wins. Otherwise the current working directory
/// is used; if it cannot be determined or is not valid UTF-8, `"."` is
/// returned instead.
pub(super) fn resolve_module_root(module_root: Option<&str>) -> String {
    match module_root {
        Some(root) => root.to_owned(),
        None => std::env::current_dir()
            .ok()
            .and_then(|dir| dir.into_os_string().into_string().ok())
            .unwrap_or_else(|| String::from(".")),
    }
}
/// Loads the project configuration from `module_root`, if one is present.
///
/// Returns `Ok(None)` when `ProjectConfig::load_from_dir` reports no config.
///
/// # Errors
/// Load failures are prefixed with `"aver.toml: "`.
pub(super) fn load_runtime_policy(
    module_root: &str,
) -> Result<Option<aver::config::ProjectConfig>, String> {
    let root_dir = std::path::Path::new(module_root);
    aver::config::ProjectConfig::load_from_dir(root_dir)
        .map_err(|load_err| format!("aver.toml: {}", load_err))
}
/// Installs the project's runtime policy (if any) on the given VM.
///
/// When no configuration is found the VM is left untouched.
///
/// # Errors
/// Propagates the error from [`load_runtime_policy`].
pub(super) fn apply_runtime_policy_to_vm(
    machine: &mut vm::VM,
    module_root: &str,
) -> Result<(), String> {
    let policy = load_runtime_policy(module_root)?;
    if let Some(policy) = policy {
        machine.set_runtime_policy(policy);
    }
    Ok(())
}
pub(super) fn load_dep_modules(
interp: &mut Interpreter,
items: &[TopLevel],
module_root: &str,
) -> Result<(), String> {
let mut loading = Vec::new();
let mut loading_set = std::collections::HashSet::new();
if let Some(module) = items.iter().find_map(|i| {
if let TopLevel::Module(m) = i {
Some(m)
} else {
None
}
}) {
for dep_name in &module.depends {
let ns = interp
.load_module(dep_name, module_root, &mut loading, &mut loading_set)
.map_err(|e| e.to_string())?;
interp
.define_module_path(dep_name, ns)
.map_err(|e| e.to_string())?;
}
}
Ok(())
}
/// Writes each type error to stderr in red, one per line, as
/// `error[<line>:<col>]: <message>`.
pub(super) fn print_type_errors(errors: &[TypeError]) {
    for err in errors {
        let rendered = format!("error[{}:{}]: {}", err.line, err.col, err.message);
        eprintln!("{}", rendered.red());
    }
}
/// Selects the recursive functions the interpreter should memoize.
///
/// A function qualifies when all of the following hold:
/// - it appears in `find_recursive_fns(items)`,
/// - it has a signature in `tc_result.fn_sigs` with an empty effect list,
/// - `recursive_callsite_counts` reports at least 2 for it,
/// - every parameter type passes [`is_memo_safe_type`].
///
/// Functions without a recorded signature are skipped.
pub(super) fn compute_memo_fns(items: &[TopLevel], tc_result: &TypeCheckResult) -> HashSet<String> {
    let recursive = find_recursive_fns(items);
    let callsite_counts = recursive_callsite_counts(items);
    let mut selected = HashSet::new();
    for fn_name in &recursive {
        let eligible = match tc_result.fn_sigs.get(fn_name) {
            Some((params, _ret, effects)) => {
                effects.is_empty()
                    && callsite_counts.get(fn_name).copied().unwrap_or(0) >= 2
                    && params
                        .iter()
                        .all(|ty| is_memo_safe_type(ty, &tc_result.memo_safe_types))
            }
            None => false,
        };
        if eligible {
            selected.insert(fn_name.clone());
        }
    }
    selected
}
/// Reports whether values of type `ty` may be used as a memoization key.
///
/// Scalars (`Int`, `Float`, `Bool`, `Unit`) are safe; tuples are safe when
/// every element is; named types are safe only if listed in `safe_named`.
/// Everything else — strings, collections, functions, `Result`, `Option`,
/// and `Unknown` — is rejected, regardless of payload types.
pub(super) fn is_memo_safe_type(ty: &types::Type, safe_named: &HashSet<String>) -> bool {
    use aver::types::Type;
    match ty {
        Type::Named(type_name) => safe_named.contains(type_name),
        Type::Tuple(elems) => elems
            .iter()
            .all(|elem| is_memo_safe_type(elem, safe_named)),
        Type::Int | Type::Float | Type::Bool | Type::Unit => true,
        Type::Str
        | Type::List(_)
        | Type::Vector(_)
        | Type::Map(_, _)
        | Type::Fn(_, _, _)
        | Type::Unknown
        | Type::Result(_, _)
        | Type::Option(_) => false,
    }
}
/// Renders type errors as newline-separated `error[<line>:<col>]: <message>`
/// lines (no trailing newline; empty input yields an empty string).
pub(super) fn format_type_errors(errors: &[TypeError]) -> String {
    errors
        .iter()
        .map(|err| format!("error[{}:{}]: {}", err.line, err.col, err.message))
        .collect::<Vec<_>>()
        .join("\n")
}
/// Runs the front-end pipeline on `file` and prepares an interpreter for it.
///
/// Steps, in order: read and parse the file, require a module declaration,
/// apply the TCO transform, type-check, resolve, enable memoization for the
/// functions chosen by [`compute_memo_fns`], apply the runtime policy from
/// the module root (if any), load dependency modules, register all type
/// definitions, then execute all function definitions. Top-level statements
/// are *not* executed here.
///
/// Returns the prepared interpreter, the transformed items, and the module
/// root that was used.
///
/// # Errors
/// Any failure from the steps above; type errors are joined with
/// [`format_type_errors`].
pub(super) fn compile_program_for_exec(
    file: &str,
    module_root_override: Option<&str>,
) -> Result<(Interpreter, Vec<TopLevel>, String), String> {
    let module_root = resolve_module_root(module_root_override);

    let source_text = read_file(file)?;
    let mut program = parse_file(&source_text)?;
    require_module_declaration(&program, file)?;

    // The TCO rewrite runs before type checking; resolution runs after.
    tco::transform_program(&mut program);
    let checked = run_type_check_full(&program, Some(&module_root));
    if !checked.errors.is_empty() {
        return Err(format_type_errors(&checked.errors));
    }
    resolver::resolve_program(&mut program);

    let memoized = compute_memo_fns(&program, &checked);
    let mut interp = Interpreter::new();
    interp.enable_memo(memoized);
    if let Some(policy) = load_runtime_policy(&module_root)? {
        interp.set_runtime_policy(policy);
    }
    load_dep_modules(&mut interp, &program, &module_root)?;

    // All type definitions are registered before any function is defined.
    let type_defs = program.iter().filter_map(|item| match item {
        TopLevel::TypeDef(td) => Some(td),
        _ => None,
    });
    for td in type_defs {
        interp.register_type_def(td);
    }

    let fn_defs = program.iter().filter_map(|item| match item {
        TopLevel::FnDef(fd) => Some(fd),
        _ => None,
    });
    for fd in fn_defs {
        interp.exec_fn_def(fd).map_err(|err| err.to_string())?;
    }

    Ok((interp, program, module_root))
}
/// Executes every `TopLevel::Stmt` in `items`, in source order.
///
/// # Errors
/// Stops at the first failing statement and returns its error as a string.
pub(super) fn run_top_level_statements(
    interp: &mut Interpreter,
    items: &[TopLevel],
) -> Result<(), String> {
    items
        .iter()
        .filter_map(|item| match item {
            TopLevel::Stmt(stmt) => Some(stmt),
            _ => None,
        })
        .try_for_each(|stmt| {
            interp
                .exec_stmt(stmt)
                .map(|_| ())
                .map_err(|err| err.to_string())
        })
}
/// Looks up `entry_fn` and calls it with `args`, permitting only the effects
/// the callee itself declares.
///
/// The call is labelled `<entry_fn>` for the interpreter.
///
/// # Errors
/// - `"Entry function '<name>' not found"` if the lookup fails (the lookup
///   error itself is discarded).
/// - The call's error, rendered with `to_string()`.
pub(super) fn run_entry_function(
    interp: &mut Interpreter,
    entry_fn: &str,
    args: Vec<Value>,
) -> Result<Value, String> {
    let callee = match interp.lookup(entry_fn) {
        Ok(found) => found,
        Err(_) => return Err(format!("Entry function '{}' not found", entry_fn)),
    };
    let declared_effects = Interpreter::callable_declared_effects(&callee);
    let call_label = format!("<{}>", entry_fn);
    interp
        .call_value_with_effects_pub(callee, args, &call_label, declared_effects)
        .map_err(|err| err.to_string())
}