use std::collections::HashSet;
use crate::ast::{
Expr, FnBody, FnDef, Spanned, Stmt, StrPart, TailCallData, TopLevel, TypeDef, TypeVariant,
VerifyBlock, VerifyGivenDomain, VerifyKind,
};
use crate::codegen::CodegenContext;
use crate::types::Type;
/// Reports whether `fd` counts as pure: it declares no effects and is not
/// the function named `main`.
pub fn is_pure_fn(fd: &FnDef) -> bool {
    if fd.name == "main" {
        return false;
    }
    fd.effects.is_empty()
}
pub fn is_recursive_type_def(td: &TypeDef) -> bool {
match td {
TypeDef::Sum { name, variants, .. } => is_recursive_sum(name, variants),
TypeDef::Product { name, fields, .. } => is_recursive_product(name, fields),
}
}
/// Returns the declared name of a type definition, whichever kind it is.
pub fn type_def_name(td: &TypeDef) -> &str {
    match td {
        TypeDef::Sum { name, .. } => name,
        TypeDef::Product { name, .. } => name,
    }
}
/// Reports whether any field annotation of any variant references the type
/// called `name`.
pub fn is_recursive_sum(name: &str, variants: &[TypeVariant]) -> bool {
    variants
        .iter()
        .flat_map(|variant| variant.fields.iter())
        .any(|field_type| type_ref_contains(field_type, name))
}
/// Reports whether any `(field name, type annotation)` pair has an annotation
/// that references the type called `name`.
pub fn is_recursive_product(name: &str, fields: &[(String, String)]) -> bool {
    for (_, field_type) in fields {
        if type_ref_contains(field_type, name) {
            return true;
        }
    }
    false
}
/// Reports whether the type annotation `annotation` references `type_name`.
///
/// A reference is recognised when the name is the whole annotation, or when
/// it appears directly after `<`, directly before `>`, directly after `", "`
/// or directly before `,`. In every case the occurrence must be a complete
/// identifier: it may not be glued to a neighbouring letter, digit or `_`.
/// That last requirement prevents `Tree` from being found inside
/// `List<TreeNode>` or `Option<SubTree>`.
///
/// An empty `type_name` only matches an empty annotation.
fn type_ref_contains(annotation: &str, type_name: &str) -> bool {
    if type_name.is_empty() {
        return annotation.is_empty();
    }
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    let mut search_from = 0;
    while let Some(offset) = annotation[search_from..].find(type_name) {
        let start = search_from + offset;
        let end = start + type_name.len();
        let before = &annotation[..start];
        let after = &annotation[end..];

        // The occurrence must not be a fragment of a longer identifier.
        let whole_identifier = !before.chars().next_back().is_some_and(is_ident_char)
            && !after.chars().next().is_some_and(is_ident_char);
        // The syntactic positions that count as a type reference.
        let in_reference_position = (before.is_empty() && after.is_empty())
            || before.ends_with('<')
            || after.starts_with('>')
            || before.ends_with(", ")
            || after.starts_with(',');
        if whole_identifier && in_reference_position {
            return true;
        }

        // Step one character forward so overlapping occurrences are still seen.
        search_from = start
            + annotation[start..]
                .chars()
                .next()
                .map_or(1, char::len_utf8);
    }
    false
}
pub(crate) fn is_user_type(name: &str, ctx: &CodegenContext) -> bool {
let check_td = |td: &TypeDef| match td {
TypeDef::Sum { name: n, .. } => n == name,
TypeDef::Product { name: n, .. } => n == name,
};
ctx.type_defs.iter().any(check_td)
|| ctx.modules.iter().any(|m| m.type_defs.iter().any(check_td))
}
/// Splits `dotted_name` into `(module prefix, remainder)` using the longest
/// entry of `ctx.module_prefixes` that is followed by a `.` at the start of
/// the name. Returns `None` when no prefix matches.
pub(crate) fn resolve_module_call<'a>(
    dotted_name: &'a str,
    ctx: &'a CodegenContext,
) -> Option<(&'a str, &'a str)> {
    ctx.module_prefixes
        .iter()
        .map(|prefix| prefix.as_str())
        .filter(|candidate| {
            // Equivalent to `dotted_name.starts_with("<candidate>.")`.
            dotted_name
                .strip_prefix(*candidate)
                .is_some_and(|rest| rest.starts_with('.'))
        })
        .max_by_key(|candidate| candidate.len())
        // `+ 1` skips the dot separating the prefix from the member name.
        .map(|longest| (longest, &dotted_name[longest.len() + 1..]))
}
/// Splits a dotted module prefix on `.` and converts every segment with
/// `module_segment_to_rust`, returning the converted segments in order.
pub(crate) fn module_prefix_to_rust_segments(prefix: &str) -> Vec<String> {
prefix.split('.').map(module_segment_to_rust).collect()
}
/// Turns a dotted module prefix into a slash-separated relative path
/// (`A.B.C` becomes `A/B/C`).
pub(crate) fn module_prefix_to_filename(prefix: &str) -> String {
    let segments: Vec<&str> = prefix.split('.').collect();
    segments.join("/")
}
/// The set of effects declared by functions, split by how they were written.
pub(crate) struct DeclaredEffects {
    /// Effects written without a dot (a whole namespace).
    pub bare_namespaces: HashSet<String>,
    /// Effects written with a dot (a single `Namespace.method`).
    pub methods: HashSet<String>,
}

impl DeclaredEffects {
    /// Reports whether `c_method` is covered: either it is listed verbatim in
    /// `methods`, or the text before its first `.` is a declared bare
    /// namespace. A name without a dot is only covered by an exact `methods`
    /// entry.
    pub fn includes(&self, c_method: &str) -> bool {
        self.methods.contains(c_method)
            || c_method
                .split_once('.')
                .is_some_and(|(namespace, _)| self.bare_namespaces.contains(namespace))
    }
}
/// Gathers every effect declared by the top-level function definitions in
/// `ctx.items` and by the function definitions of every module in
/// `ctx.modules`. Effects containing a `.` are stored as methods; all others
/// are stored as bare namespaces.
pub(crate) fn collect_declared_effects(ctx: &CodegenContext) -> DeclaredEffects {
    let mut declared = DeclaredEffects {
        bare_namespaces: HashSet::new(),
        methods: HashSet::new(),
    };

    let entry_effects = ctx.items.iter().filter_map(|item| match item {
        TopLevel::FnDef(fd) => Some(&fd.effects),
        _ => None,
    });
    let module_effects = ctx
        .modules
        .iter()
        .flat_map(|module| module.fn_defs.iter().map(|fd| &fd.effects));

    for effects in entry_effects.chain(module_effects) {
        for eff in effects {
            let effect: &str = &eff.node;
            let bucket = if effect.contains('.') {
                &mut declared.methods
            } else {
                &mut declared.bare_namespaces
            };
            bucket.insert(effect.to_string());
        }
    }
    declared
}
/// Returns the name of the first `TopLevel::Module` item, or, when there is
/// none, `ctx.project_name` with its first character upper-cased.
pub fn entry_basename(ctx: &CodegenContext) -> String {
    for item in &ctx.items {
        if let TopLevel::Module(m) = item {
            return m.name.clone();
        }
    }

    let mut rest = ctx.project_name.chars();
    let Some(first) = rest.next() else {
        return String::new();
    };
    let mut capitalized: String = first.to_uppercase().collect();
    capitalized.push_str(rest.as_str());
    capitalized
}
/// Maps every function name to the scope that owns it: the module prefix for
/// module functions, the empty string for functions in `ctx.fn_defs`.
/// Entries from `ctx.fn_defs` are added last, so they win on a name clash.
pub(crate) fn fn_owning_scope(ctx: &CodegenContext) -> std::collections::HashMap<String, String> {
    let module_fns = ctx.modules.iter().flat_map(|module| {
        module
            .fn_defs
            .iter()
            .map(move |fd| (fd.name.clone(), module.prefix.clone()))
    });
    let entry_fns = ctx
        .fn_defs
        .iter()
        .map(|fd| (fd.name.clone(), String::new()));
    // Collecting inserts in iteration order; later pairs replace earlier ones.
    module_fns.chain(entry_fns).collect()
}
/// Builds the path `crate::aver_generated::<seg>::<seg>...` for a dotted
/// module prefix, using the converted Rust segment names.
pub(crate) fn module_prefix_to_rust_path(prefix: &str) -> String {
    let segments = module_prefix_to_rust_segments(prefix);
    let mut path = String::from("crate::aver_generated::");
    path.push_str(&segments.join("::"));
    path
}
/// Converts one segment of a dotted module prefix into a Rust module
/// identifier.
///
/// * ASCII letters and digits are kept, lower-cased.
/// * An underscore is inserted before an upper-case letter that follows a
///   lower-case letter or digit, or that is followed by a lower-case letter
///   (`HttpClient` -> `http_client`, `HTTPServer` -> `http_server`).
/// * Every run of other characters collapses to a single `_`; leading and
///   trailing underscores are removed.
/// * An empty result becomes `module`.
/// * A result equal to a Rust strict or reserved keyword gets the suffix
///   `_mod`, because such a word cannot be used as a plain module name.
fn module_segment_to_rust(segment: &str) -> String {
    // Strict and reserved keywords across editions. The output is always
    // lower-case, so the upper-case keyword `Self` can never be produced.
    const RESERVED: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "crate", "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen",
        "if", "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "self", "static", "struct", "super", "trait", "true",
        "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while",
        "yield",
    ];

    let chars: Vec<char> = segment.chars().collect();
    let mut out = String::with_capacity(segment.len() + 4);
    for (idx, &ch) in chars.iter().enumerate() {
        if !ch.is_ascii_alphanumeric() {
            // Separator run: emit at most one underscore.
            if !out.ends_with('_') {
                out.push('_');
            }
            continue;
        }
        if ch.is_ascii_uppercase() && idx > 0 {
            let prev = chars[idx - 1];
            let prev_is_lower_or_digit = prev.is_ascii_lowercase() || prev.is_ascii_digit();
            let next_is_lower = chars
                .get(idx + 1)
                .is_some_and(|next| next.is_ascii_lowercase());
            // Word boundary inside a CamelCase / acronym sequence.
            if (prev_is_lower_or_digit || next_is_lower) && !out.ends_with('_') {
                out.push('_');
            }
        }
        out.push(ch.to_ascii_lowercase());
    }

    let trimmed = out.trim_matches('_');
    let mut normalized = if trimmed.is_empty() {
        "module".to_string()
    } else {
        trimmed.to_string()
    };
    if RESERVED.contains(&normalized.as_str()) {
        normalized.push_str("_mod");
    }
    normalized
}
/// Splits `s` on `delim`, ignoring delimiters nested inside `<...>` or
/// `(...)`.
///
/// Each part is trimmed. Empty parts between two delimiters are kept, but an
/// empty trailing part is dropped, so an empty input yields an empty vector.
/// The `>` of a `->` arrow is not a closing bracket and does not change the
/// nesting depth; without this, `Map<Fn(A) -> B, C>, D` would be split
/// inside the `Map<...>`.
pub(crate) fn split_type_params(s: &str, delim: char) -> Vec<String> {
    let mut parts = Vec::new();
    let mut depth = 0usize;
    let mut current = String::new();
    let mut prev: Option<char> = None;
    for ch in s.chars() {
        match ch {
            // Second half of `->`: plain text, not a bracket.
            '>' if prev == Some('-') => current.push(ch),
            '<' | '(' => {
                depth += 1;
                current.push(ch);
            }
            '>' | ')' => {
                // Saturate so unbalanced input cannot underflow.
                depth = depth.saturating_sub(1);
                current.push(ch);
            }
            _ if ch == delim && depth == 0 => {
                parts.push(current.trim().to_string());
                current.clear();
            }
            _ => current.push(ch),
        }
        prev = Some(ch);
    }
    let rest = current.trim();
    if !rest.is_empty() {
        parts.push(rest.to_string());
    }
    parts
}
/// Escapes `s` for embedding between double quotes in generated source.
///
/// Backslash, double quote, newline, carriage return, tab and NUL get their
/// two-character escapes. Any other control character is written as
/// `\U{xxxxxx}` (six hex digits) when `unicode_escapes` is true and as
/// `\xNN` otherwise. Everything else is copied unchanged.
pub(crate) fn escape_string_literal_ext(s: &str, unicode_escapes: bool) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        let short_form = match ch {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            '\0' => Some("\\0"),
            _ => None,
        };
        if let Some(sequence) = short_form {
            escaped.push_str(sequence);
        } else if ch.is_control() {
            let code = ch as u32;
            let sequence = if unicode_escapes {
                format!("\\U{{{:06x}}}", code)
            } else {
                format!("\\x{:02x}", code)
            };
            escaped.push_str(&sequence);
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
/// Escapes `s` via `escape_string_literal_ext` with `unicode_escapes` set to
/// `false`, so remaining control characters are written as `\xNN`.
pub(crate) fn escape_string_literal(s: &str) -> String {
escape_string_literal_ext(s, false)
}
/// Escapes `s` via `escape_string_literal_ext` with `unicode_escapes` set to
/// `true`, so remaining control characters are written as `\U{xxxxxx}`.
pub(crate) fn escape_string_literal_unicode(s: &str) -> String {
escape_string_literal_ext(s, true)
}
/// Parses a textual type annotation into a `Type` by delegating to
/// `crate::types::parse_type_str`.
pub(crate) fn parse_type_annotation(ann: &str) -> Type {
crate::types::parse_type_str(ann)
}
/// Reports whether `ty` is a map whose value type is `Unit`, which is the
/// shape this module treats as a set.
pub(crate) fn is_set_type(ty: &Type) -> bool {
    match ty {
        Type::Map(_, value_type) => matches!(value_type.as_ref(), Type::Unit),
        _ => false,
    }
}
/// Parses `ann` and reports whether the resulting type satisfies
/// `is_set_type`.
pub(crate) fn is_set_annotation(ann: &str) -> bool {
    let parsed = parse_type_annotation(ann);
    is_set_type(&parsed)
}
pub(crate) fn is_unit_expr(expr: &crate::ast::Expr) -> bool {
matches!(expr, crate::ast::Expr::Literal(crate::ast::Literal::Unit))
}
/// Same check as `is_unit_expr`, applied to the `node` of a spanned
/// expression.
pub(crate) fn is_unit_expr_spanned(expr: &crate::ast::Spanned<crate::ast::Expr>) -> bool {
is_unit_expr(&expr.node)
}
/// Returns `name` followed by `suffix` when `name` is listed in `reserved`,
/// and `name` unchanged otherwise.
pub(crate) fn escape_reserved_word(name: &str, reserved: &[&str], suffix: &str) -> String {
    let mut escaped = name.to_string();
    if reserved.iter().any(|word| *word == name) {
        escaped.push_str(suffix);
    }
    escaped
}
/// Returns `prefix` followed by `name` when `name` is listed in `reserved`,
/// and `name` unchanged otherwise.
pub(crate) fn escape_reserved_word_prefix(name: &str, reserved: &[&str], prefix: &str) -> String {
    let is_reserved = reserved.iter().any(|word| *word == name);
    let lead = if is_reserved { prefix } else { "" };
    let mut escaped = String::with_capacity(lead.len() + name.len());
    escaped.push_str(lead);
    escaped.push_str(name);
    escaped
}
/// Lower-cases the first character of `s` and leaves the rest untouched.
/// An empty input yields an empty string.
pub(crate) fn to_lower_first(s: &str) -> String {
    let Some(first) = s.chars().next() else {
        return String::new();
    };
    let tail = &s[first.len_utf8()..];
    let mut lowered = String::with_capacity(s.len());
    // `to_lowercase` may expand to more than one character.
    lowered.extend(first.to_lowercase());
    lowered.push_str(tail);
    lowered
}
/// Delegates to `crate::ir::expr_to_dotted_name` and returns its result
/// unchanged.
pub(crate) fn expr_to_dotted_name(expr: &Expr) -> Option<String> {
crate::ir::expr_to_dotted_name(expr)
}
/// Selects which expression `rewrite_effectful_calls_in_law` injects for each
/// `given` of a law when it rewrites calls to effectful functions.
#[derive(Debug, Clone)]
pub(crate) enum OracleInjectionMode<'a> {
/// Inject an identifier carrying the given's own name.
LemmaBinding,
/// Inject the same identifier as `LemmaBinding`; afterwards, references to
/// givens whose effect classifies as `Generative` or `GenerativeOutput`
/// are rewritten to access their `.val` attribute.
LemmaBindingProjected,
/// Inject the first value of the given's explicit domain; givens without
/// an explicit domain contribute nothing.
// NOTE(review): the lint allowance suggests no caller constructs this
// variant at present — confirm before removing.
#[allow(dead_code)]
SampleValue,
/// Inject the expression bound to the given's name in the supplied
/// `(name, expression)` list; givens without a binding contribute nothing.
SampleCaseBinding(&'a [(String, crate::ast::Spanned<Expr>)]),
}
/// Rewrites `expr` so that calls to effectful functions receive the oracle
/// arguments derived from the law's `givens`.
///
/// For every given, `mode` decides which expression is injected; the result
/// is keyed by the given's `type_name`. Givens for which the mode produces
/// no expression are skipped. In `LemmaBindingProjected` mode the rewritten
/// expression is additionally passed through `project_oracle_direct_calls`
/// for the givens whose effect classifies as generative, if there are any.
pub(crate) fn rewrite_effectful_calls_in_law(
    expr: &crate::ast::Spanned<Expr>,
    law: &crate::ast::VerifyLaw,
    ctx: &CodegenContext,
    mode: OracleInjectionMode,
) -> crate::ast::Spanned<Expr> {
    use crate::ast::{Spanned, VerifyGivenDomain};
    use crate::types::checker::effect_classification::{EffectDimension, classify};

    let mut injection_by_effect: std::collections::HashMap<String, Spanned<Expr>> =
        std::collections::HashMap::new();
    for given in &law.givens {
        let injected = match &mode {
            // Both lemma modes inject the binding's identifier.
            OracleInjectionMode::LemmaBinding | OracleInjectionMode::LemmaBindingProjected => {
                Some(Spanned {
                    node: Expr::Ident(given.name.clone()),
                    line: expr.line,
                })
            }
            OracleInjectionMode::SampleValue => match &given.domain {
                VerifyGivenDomain::Explicit(vals) => vals.first().cloned(),
                _ => None,
            },
            OracleInjectionMode::SampleCaseBinding(case_bindings) => case_bindings
                .iter()
                .find(|(name, _)| name == &given.name)
                .map(|(_, value)| value.clone()),
        };
        if let Some(arg_expr) = injected {
            // A later given with the same type name replaces an earlier one.
            injection_by_effect.insert(given.type_name.clone(), arg_expr);
        }
    }

    let rewritten = rewrite_effectful_call(expr, &injection_by_effect, ctx);
    if !matches!(mode, OracleInjectionMode::LemmaBindingProjected) {
        return rewritten;
    }

    let oracle_names: std::collections::HashSet<String> = law
        .givens
        .iter()
        .filter(|given| {
            classify(&given.type_name).is_some_and(|class| {
                matches!(
                    class.dimension,
                    EffectDimension::Generative | EffectDimension::GenerativeOutput
                )
            })
        })
        .map(|given| given.name.clone())
        .collect();

    if oracle_names.is_empty() {
        rewritten
    } else {
        project_oracle_direct_calls(&rewritten, &oracle_names)
    }
}
/// Replaces every identifier listed in `oracle_names` with an access to its
/// `val` attribute (`name` becomes `name.val`).
///
/// The walk descends through function calls (callee and arguments),
/// constructors with a payload, attribute accesses and binary operations.
/// Any other expression is cloned without being inspected.
fn project_oracle_direct_calls(
    expr: &crate::ast::Spanned<Expr>,
    oracle_names: &std::collections::HashSet<String>,
) -> crate::ast::Spanned<Expr> {
    use crate::ast::Spanned;

    let node = match &expr.node {
        Expr::Ident(name) if oracle_names.contains(name) => Expr::Attr(
            Box::new(Spanned {
                node: Expr::Ident(name.clone()),
                line: expr.line,
            }),
            "val".to_string(),
        ),
        Expr::FnCall(callee, args) => {
            // An oracle identifier in callee position is handled by the
            // `Ident` arm of the recursive call.
            let projected_callee = project_oracle_direct_calls(callee, oracle_names);
            let mut projected_args: Vec<Spanned<Expr>> = Vec::with_capacity(args.len());
            for arg in args {
                projected_args.push(project_oracle_direct_calls(arg, oracle_names));
            }
            Expr::FnCall(Box::new(projected_callee), projected_args)
        }
        Expr::Constructor(name, Some(payload)) => {
            let projected = project_oracle_direct_calls(payload, oracle_names);
            Expr::Constructor(name.clone(), Some(Box::new(projected)))
        }
        Expr::Attr(object, field) => {
            let projected = project_oracle_direct_calls(object, oracle_names);
            Expr::Attr(Box::new(projected), field.clone())
        }
        Expr::BinOp(op, lhs, rhs) => {
            let projected_lhs = project_oracle_direct_calls(lhs, oracle_names);
            let projected_rhs = project_oracle_direct_calls(rhs, oracle_names);
            Expr::BinOp(*op, Box::new(projected_lhs), Box::new(projected_rhs))
        }
        untouched => untouched.clone(),
    };
    Spanned {
        node,
        line: expr.line,
    }
}
/// Recursively rewrites calls to effectful functions so that injected
/// arguments are passed ahead of the call's own arguments.
///
/// `injection_by_effect` maps an effect name to the expression to inject for
/// it. Only `FnCall`, `BinOp` and `Tuple` nodes are traversed; every other
/// expression is returned as an unchanged clone.
fn rewrite_effectful_call(
expr: &crate::ast::Spanned<Expr>,
injection_by_effect: &std::collections::HashMap<String, crate::ast::Spanned<Expr>>,
ctx: &CodegenContext,
) -> crate::ast::Spanned<Expr> {
use crate::ast::Spanned;
use crate::types::checker::effect_classification::{EffectDimension, classify};
match &expr.node {
Expr::FnCall(callee, args) => {
// Arguments and callee are rewritten first so nested calls are handled.
let rewritten_args: Vec<Spanned<Expr>> = args
.iter()
.map(|a| rewrite_effectful_call(a, injection_by_effect, ctx))
.collect();
let rewritten_callee =
Box::new(rewrite_effectful_call(callee, injection_by_effect, ctx));
// Only a plain or resolved identifier can be looked up as a function.
let callee_name = match &callee.node {
Expr::Ident(name) => Some(name.clone()),
Expr::Resolved { name, .. } => Some(name.clone()),
_ => None,
};
// Inject only when the callee is found in `ctx.fn_defs`, declares at
// least one effect, and every declared effect is classified.
if let Some(name) = callee_name
&& let Some(fd) = ctx.fn_defs.iter().find(|fd| fd.name == name)
&& !fd.effects.is_empty()
&& fd
.effects
.iter()
.all(|e| crate::types::checker::effect_classification::is_classified(&e.node))
{
let mut injected: Vec<Spanned<Expr>> = Vec::new();
// A `BranchPath.Root` argument goes first when any declared effect is
// Generative or GenerativeOutput.
let needs_path = fd.effects.iter().any(|e| {
matches!(
classify(&e.node).map(|c| c.dimension),
Some(EffectDimension::Generative | EffectDimension::GenerativeOutput)
)
});
if needs_path {
injected.push(Spanned {
node: Expr::Attr(
Box::new(Spanned {
node: Expr::Ident("BranchPath".to_string()),
line: expr.line,
}),
"Root".to_string(),
),
line: expr.line,
});
}
// Each distinct effect is considered once, in declaration order.
let mut seen = std::collections::HashSet::new();
for e in &fd.effects {
if !seen.insert(e.node.clone()) {
continue;
}
let Some(c) = classify(&e.node) else { continue };
// Output-dimension effects receive no injected argument.
if matches!(c.dimension, EffectDimension::Output) {
continue;
}
// Effects without an entry in the injection map are silently skipped.
if let Some(inj) = injection_by_effect.get(&e.node) {
injected.push(inj.clone());
}
}
// The call's own (already rewritten) arguments follow the injected ones.
injected.extend(rewritten_args);
return Spanned {
node: Expr::FnCall(rewritten_callee, injected),
line: expr.line,
};
}
// Not an injectable callee: rebuild the call from the rewritten parts.
Spanned {
node: Expr::FnCall(rewritten_callee, rewritten_args),
line: expr.line,
}
}
Expr::BinOp(op, l, r) => Spanned {
node: Expr::BinOp(
*op,
Box::new(rewrite_effectful_call(l, injection_by_effect, ctx)),
Box::new(rewrite_effectful_call(r, injection_by_effect, ctx)),
),
line: expr.line,
},
Expr::Tuple(items) => Spanned {
node: Expr::Tuple(
items
.iter()
.map(|i| rewrite_effectful_call(i, injection_by_effect, ctx))
.collect(),
),
line: expr.line,
},
_ => expr.clone(),
}
}
/// Computes the set of names reachable from the verify blocks in `items`.
///
/// The set starts with every name collected from the verify blocks. It is
/// then repeatedly extended with the names referenced in the body of each
/// top-level function whose own name is already in the set, until a full
/// pass adds nothing new.
pub(crate) fn verify_reachable_fn_names(items: &[TopLevel]) -> HashSet<String> {
    let mut reachable: HashSet<String> = HashSet::new();
    for item in items {
        let TopLevel::Verify(vb) = item else { continue };
        collect_verify_block_refs(vb, &mut reachable);
    }

    let mut grew = true;
    while grew {
        grew = false;
        for item in items {
            let TopLevel::FnDef(fd) = item else { continue };
            if !reachable.contains(&fd.name) {
                continue;
            }
            let mut called = HashSet::new();
            collect_called_idents_in_body(&fd.body, &mut called);
            for name in called {
                grew |= reachable.insert(name);
            }
        }
    }
    reachable
}
fn collect_verify_block_refs(vb: &VerifyBlock, out: &mut HashSet<String>) {
out.insert(vb.fn_name.clone());
for (lhs, rhs) in &vb.cases {
collect_called_idents(lhs, out);
collect_called_idents(rhs, out);
}
if let VerifyKind::Law(law) = &vb.kind {
collect_called_idents(&law.lhs, out);
collect_called_idents(&law.rhs, out);
if let Some(when) = &law.when {
collect_called_idents(when, out);
}
for given in &law.givens {
if let VerifyGivenDomain::Explicit(values) = &given.domain {
for v in values {
collect_called_idents(v, out);
}
}
}
}
for given in &vb.cases_givens {
if let VerifyGivenDomain::Explicit(values) = &given.domain {
for v in values {
collect_called_idents(v, out);
}
}
}
}
/// Collects the names referenced by every statement of a function body,
/// whether the statement is a binding or a bare expression.
fn collect_called_idents_in_body(body: &FnBody, out: &mut HashSet<String>) {
    for stmt in body.stmts() {
        // Both statement kinds carry exactly one expression to inspect.
        let (Stmt::Binding(_, _, value) | Stmt::Expr(value)) = stmt;
        collect_called_idents(value, out);
    }
}
/// Walks `expr` and inserts into `out` every identifier it references:
/// plain and resolved identifiers, callee names of calls, and tail-call
/// targets. Literals and payload-free constructors contribute nothing.
/// Match arms contribute through their bodies only.
fn collect_called_idents(expr: &Spanned<Expr>, out: &mut HashSet<String>) {
    match &expr.node {
        // Leaves that reference nothing.
        Expr::Literal(_) | Expr::Constructor(_, None) => {}

        // Leaves that are a reference themselves.
        Expr::Ident(name) | Expr::Resolved { name, .. } => {
            out.insert(name.clone());
        }

        // Calls.
        Expr::FnCall(callee, args) => {
            match &callee.node {
                Expr::Ident(name) | Expr::Resolved { name, .. } => {
                    out.insert(name.clone());
                }
                _ => collect_called_idents(callee, out),
            }
            for arg in args {
                collect_called_idents(arg, out);
            }
        }
        Expr::TailCall(tail) => {
            let TailCallData { target, args, .. } = tail.as_ref();
            out.insert(target.clone());
            for arg in args {
                collect_called_idents(arg, out);
            }
        }

        // Single-child wrappers.
        Expr::ErrorProp(inner) | Expr::Attr(inner, _) => collect_called_idents(inner, out),
        Expr::Constructor(_, Some(payload)) => collect_called_idents(payload, out),

        // Compound expressions.
        Expr::BinOp(_, lhs, rhs) => {
            collect_called_idents(lhs, out);
            collect_called_idents(rhs, out);
        }
        Expr::Match { subject, arms, .. } => {
            collect_called_idents(subject, out);
            for arm in arms {
                collect_called_idents(&arm.body, out);
            }
        }
        Expr::InterpolatedStr(parts) => {
            for part in parts {
                if let StrPart::Parsed(inner) = part {
                    collect_called_idents(inner, out);
                }
            }
        }
        Expr::List(elements)
        | Expr::Tuple(elements)
        | Expr::IndependentProduct(elements, _) => {
            for element in elements {
                collect_called_idents(element, out);
            }
        }
        Expr::MapLiteral(entries) => {
            for (key, value) in entries {
                collect_called_idents(key, out);
                collect_called_idents(value, out);
            }
        }
        Expr::RecordCreate { fields, .. } => {
            for (_, value) in fields {
                collect_called_idents(value, out);
            }
        }
        Expr::RecordUpdate { base, updates, .. } => {
            collect_called_idents(base, out);
            for (_, value) in updates {
                collect_called_idents(value, out);
            }
        }
    }
}
/// Emitted sections grouped by the scope they belong to.
pub(crate) struct PerScopeSections {
    /// Sections keyed by scope name.
    pub by_scope: std::collections::HashMap<String, Vec<String>>,
}

impl PerScopeSections {
    /// Removes and returns the sections stored for `scope`. Returns an empty
    /// vector when the scope has no entry, including after an earlier `take`.
    pub(crate) fn take(&mut self, scope: &str) -> Vec<String> {
        match self.by_scope.remove(scope) {
            Some(sections) => sections,
            None => Vec::new(),
        }
    }
}
/// Runs `emit` over the ordered call-graph components of the pure functions
/// in each scope and groups the emitted sections by scope.
///
/// Every module in `ctx.modules` is processed in order under its prefix,
/// followed by `ctx.fn_defs` under the empty-string scope. A function is
/// included when `is_pure` accepts it. Each scope gets an entry even when
/// nothing is emitted for it.
pub(crate) fn route_pure_components_per_scope<F, G>(
    ctx: &CodegenContext,
    is_pure: F,
    mut emit: G,
) -> PerScopeSections
where
    F: Fn(&FnDef) -> bool,
    G: FnMut(&[&FnDef]) -> Vec<String>,
{
    let mut by_scope: std::collections::HashMap<String, Vec<String>> =
        std::collections::HashMap::new();

    let module_scopes = ctx.modules.iter().map(|module| {
        let pure_fns: Vec<&FnDef> = module.fn_defs.iter().filter(|fd| is_pure(fd)).collect();
        (module.prefix.clone(), pure_fns)
    });
    // Built lazily so the entry scope is filtered only after all modules ran.
    let entry_scope = std::iter::once_with(|| {
        let pure_fns: Vec<&FnDef> = ctx.fn_defs.iter().filter(|fd| is_pure(fd)).collect();
        (String::new(), pure_fns)
    });

    for (scope, pure_fns) in module_scopes.chain(entry_scope) {
        let components =
            crate::call_graph::ordered_fn_components(&pure_fns, &ctx.module_prefixes);
        let bucket = by_scope.entry(scope).or_default();
        for component in components {
            bucket.extend(emit(&component));
        }
    }

    PerScopeSections { by_scope }
}