use crate::ast::*;
use crate::codegen::CodegenContext;
use crate::codegen::common::{expr_to_dotted_name, is_user_type, resolve_module_call};
/// Dafny reserved words that may not be used as identifiers in generated code.
///
/// Aver names colliding with any entry are escaped by [`aver_name_to_dafny`].
/// The list also keeps a few keywords from older Dafny releases (`colemma`,
/// `copredicate`); escaping them is harmless.
///
/// Entries are kept in strictly ascending byte order. NOTE(review): confirm
/// whether `escape_reserved_word` relies on that (e.g. a binary search).
const DAFNY_RESERVED: &[&str] = &[
    "abstract",
    "allocated",
    "array",
    "as",
    "assert",
    "assume",
    "bool",
    "break",
    "by",
    "calc",
    "case",
    "char",
    "class",
    "codatatype",
    "colemma",
    "const",
    "constructor",
    "continue",
    "copredicate",
    "datatype",
    "decreases",
    "default",
    "else",
    "ensures",
    "exists",
    "expect",
    "export",
    "extends",
    "false",
    "for",
    "forall",
    "fresh",
    "function",
    "ghost",
    "if",
    "imap",
    "import",
    "in",
    "include",
    "int",
    "invariant",
    "is",
    "iset",
    "iterator",
    "label",
    "lemma",
    "map",
    "match",
    "method",
    "modifies",
    "modify",
    "module",
    "multiset",
    "nameonly",
    "nat",
    "new",
    "newtype",
    "null",
    "object",
    "old",
    "opaque",
    "opened",
    "predicate",
    "print",
    "provides",
    "reads",
    "real",
    "refines",
    "requires",
    "return",
    "returns",
    "reveal",
    "reveals",
    "seq",
    "set",
    "static",
    "string",
    "then",
    "this",
    "trait",
    "true",
    "twostate",
    "type",
    "unchanged",
    "var",
    "while",
    "witness",
    "yield",
    "yields",
];
/// Converts an Aver identifier into one that is legal in Dafny.
///
/// Names listed in `DAFNY_RESERVED` are rewritten by the shared
/// `escape_reserved_word` helper using the `"_"` marker. Whether the marker is
/// prepended or appended is decided by that helper, not here. Other names are
/// passed to the helper unchanged.
pub fn aver_name_to_dafny(name: &str) -> String {
    use crate::codegen::common::escape_reserved_word;
    let keywords = DAFNY_RESERVED;
    escape_reserved_word(name, keywords, "_")
}
/// Emits the Dafny expression text for a source-level Aver expression.
///
/// Binary operators and tuples are wrapped in parentheses here. Calls,
/// matches, constructors and interpolated strings are delegated to dedicated
/// helpers.
///
/// # Panics
///
/// Panics on `Expr::Resolved`, which is produced only by the resolver and is
/// not expected to reach this backend.
pub fn emit_expr(expr: &Spanned<Expr>, ctx: &CodegenContext) -> String {
    match &expr.node {
        Expr::Literal(lit) => emit_literal(lit),
        Expr::Ident(name) => aver_name_to_dafny(name),
        Expr::Resolved(slot) => {
            panic!(
                "Dafny codegen: encountered resolver-only Expr::Resolved({slot}). \
                 Compile pipeline should emit source-level AST."
            )
        }
        Expr::Attr(obj, field) => {
            if let Expr::Ident(type_name) = &obj.node {
                // `Option.None` is special-cased and always emitted qualified.
                if type_name == "Option" && field == "None" {
                    return "Option.None".to_string();
                }
                // `Type.Member` on a user-declared type keeps its qualification
                // and the field name is emitted without keyword escaping.
                if is_user_type(type_name, ctx) {
                    return format!("{}.{}", type_name, field);
                }
            }
            // Dotted path that `resolve_module_call` recognises: only its `bare`
            // part is emitted. It stays `Type.Variant` when the segment before
            // the first dot is a user type; otherwise it is emitted as an
            // escaped name.
            if let Some(full_dotted) = expr_to_dotted_name_spanned(expr)
                && let Some((_, bare)) = resolve_module_call(&full_dotted, ctx)
            {
                if let Some(dot_pos) = bare.find('.') {
                    let type_name = &bare[..dot_pos];
                    let variant = &bare[dot_pos + 1..];
                    if is_user_type(type_name, ctx) {
                        return format!("{}.{}", type_name, variant);
                    }
                }
                return aver_name_to_dafny(bare);
            }
            // Ordinary field access on a value.
            let obj_str = emit_expr(obj, ctx);
            format!("{}.{}", obj_str, aver_name_to_dafny(field))
        }
        Expr::FnCall(fn_expr, args) => emit_fn_call(fn_expr, args, ctx),
        Expr::BinOp(op, left, right) => {
            // `0 - x` is emitted as unary negation `(-x)`.
            if matches!(op, BinOp::Sub) && matches!(left.node, Expr::Literal(Literal::Int(0))) {
                let r = emit_expr(right, ctx);
                return format!("(-{})", r);
            }
            let l = emit_expr(left, ctx);
            let r = emit_expr(right, ctx);
            let op_str = match op {
                BinOp::Add => "+",
                BinOp::Sub => "-",
                BinOp::Mul => "*",
                BinOp::Div => "/",
                BinOp::Eq => "==",
                BinOp::Neq => "!=",
                BinOp::Lt => "<",
                BinOp::Gt => ">",
                BinOp::Lte => "<=",
                BinOp::Gte => ">=",
            };
            // Always parenthesised so the result is safe to nest as an operand.
            format!("({} {} {})", l, op_str, r)
        }
        Expr::Match { subject, arms, .. } => emit_match(subject, arms, ctx),
        Expr::Constructor(name, arg) => emit_constructor(name, arg.as_deref(), ctx),
        // `?` has no lowering here; a comment placeholder is emitted instead
        // of a Dafny expression.
        Expr::ErrorProp(_) => {
            "/* ERROR: ? operator not supported in Dafny pure functions */".to_string()
        }
        Expr::InterpolatedStr(parts) => emit_interpolated_str(parts, ctx),
        // Lists are emitted as Dafny sequence displays `[a, b, c]`.
        Expr::List(elems) => {
            let items: Vec<String> = elems.iter().map(|e| emit_expr(e, ctx)).collect();
            format!("[{}]", items.join(", "))
        }
        // Tuples and independent products share the parenthesised tuple form.
        Expr::Tuple(elems) | Expr::IndependentProduct(elems, _) => {
            let items: Vec<String> = elems.iter().map(|e| emit_expr(e, ctx)).collect();
            format!("({})", items.join(", "))
        }
        Expr::MapLiteral(entries) => {
            if entries.is_empty() {
                "map[]".to_string()
            } else if entries
                .iter()
                .all(|(_, v)| crate::codegen::common::is_unit_expr_spanned(v))
            {
                // Every value is unit: emitted as a set display of the keys.
                let items: Vec<String> = entries.iter().map(|(k, _)| emit_expr(k, ctx)).collect();
                format!("{{{}}}", items.join(", "))
            } else {
                let items: Vec<String> = entries
                    .iter()
                    .map(|(k, v)| format!("{} := {}", emit_expr(k, ctx), emit_expr(v, ctx)))
                    .collect();
                format!("map[{}]", items.join(", "))
            }
        }
        // Record construction uses named arguments: `Type(f := v, ...)`.
        Expr::RecordCreate { type_name, fields } => {
            let field_strs: Vec<String> = fields
                .iter()
                .map(|(name, expr)| {
                    format!("{} := {}", aver_name_to_dafny(name), emit_expr(expr, ctx))
                })
                .collect();
            format!("{}({})", type_name, field_strs.join(", "))
        }
        // Record update is emitted as `base.(f := v, ...)`.
        Expr::RecordUpdate { base, updates, .. } => {
            let base_str = emit_expr(base, ctx);
            let update_strs: Vec<String> = updates
                .iter()
                .map(|(name, expr)| {
                    format!("{} := {}", aver_name_to_dafny(name), emit_expr(expr, ctx))
                })
                .collect();
            format!("{}.({})", base_str, update_strs.join(", "))
        }
        // Tail calls are emitted as ordinary calls of the escaped name.
        Expr::TailCall(inner) => {
            let (name, args) = inner.as_ref();
            let arg_strs: Vec<String> = args.iter().map(|a| emit_expr(a, ctx)).collect();
            format!("{}({})", aver_name_to_dafny(name), arg_strs.join(", "))
        }
    }
}
fn expr_to_dotted_name_spanned(expr: &Spanned<Expr>) -> Option<String> {
expr_to_dotted_name(&expr.node)
}
/// Renders an Aver literal as Dafny source text.
///
/// Floats are rendered as `<digits> as real`; `.0` is appended when Rust's
/// `Display` output has no decimal point. Strings are quoted after passing
/// through the shared Unicode escaper. Unit becomes `()`.
fn emit_literal(lit: &Literal) -> String {
    match lit {
        Literal::Int(n) => n.to_string(),
        Literal::Float(f) => {
            let text = f.to_string();
            let fraction = if text.contains('.') { "" } else { ".0" };
            format!("{}{} as real", text, fraction)
        }
        Literal::Str(s) => {
            let escaped = crate::codegen::common::escape_string_literal_unicode(s);
            format!("\"{}\"", escaped)
        }
        Literal::Bool(flag) => String::from(if *flag { "true" } else { "false" }),
        Literal::Unit => String::from("()"),
    }
}
/// Emits a function call.
///
/// The cases are tried in this order:
/// 1. `Map.set(m, k, ())` with a unit value becomes set union `(m + {k})`.
/// 2. A callee recognised by `recognize_builtin` is rendered by `emit_dafny_builtin`.
/// 3. Anything else becomes a plain `callee(args...)` application.
fn emit_fn_call(fn_expr: &Spanned<Expr>, args: &[Spanned<Expr>], ctx: &CodegenContext) -> String {
    use crate::codegen::builtins::recognize_builtin;
    use crate::codegen::common::is_unit_expr_spanned;

    let dotted_name = expr_to_dotted_name_spanned(fn_expr);
    let callee = dotted_name.as_deref();

    let unit_valued_set =
        callee == Some("Map.set") && args.len() == 3 && is_unit_expr_spanned(&args[2]);
    if unit_valued_set {
        let target = emit_expr(&args[0], ctx);
        let key = emit_expr(&args[1], ctx);
        return format!("({} + {{{}}})", target, key);
    }

    match callee.and_then(recognize_builtin) {
        Some(builtin) => {
            let rendered: Vec<String> = args.iter().map(|arg| emit_expr(arg, ctx)).collect();
            emit_dafny_builtin(builtin, &rendered)
        }
        None => {
            let head = emit_expr(fn_expr, ctx);
            let rendered: Vec<String> = args.iter().map(|arg| emit_expr(arg, ctx)).collect();
            format!("{}({})", head, rendered.join(", "))
        }
    }
}
fn emit_dafny_builtin(b: crate::codegen::builtins::Builtin, a: &[String]) -> String {
use crate::codegen::builtins::Builtin::*;
match b {
ResultOk => format!("Result.Ok({})", a.first().map(|s| s.as_str()).unwrap_or("")),
ResultErr => format!(
"Result.Err({})",
a.first().map(|s| s.as_str()).unwrap_or("")
),
OptionSome => format!(
"Option.Some({})",
a.first().map(|s| s.as_str()).unwrap_or("")
),
ResultWithDefault => format!("ResultWithDefault({}, {})", a[0], a[1]),
OptionWithDefault => format!("OptionWithDefault({}, {})", a[0], a[1]),
OptionToResult => format!("OptionToResult({}, {})", a[0], a[1]),
IntAbs => format!("(if {} >= 0 then {} else -{})", a[0], a[0], a[0]),
IntToFloat => format!("({} as real)", a[0]),
IntToString | StringFromInt => format!("IntToString({})", a[0]),
IntFromString | IntParse => format!("IntFromString({})", a[0]),
IntMin => format!("(if {} <= {} then {} else {})", a[0], a[1], a[0], a[1]),
IntMax => format!("(if {} >= {} then {} else {})", a[0], a[1], a[0], a[1]),
IntRem | IntMod => format!("Result<int, string>.Ok(({} % {}))", a[0], a[1]),
FloatAbs => format!("(if {} >= 0.0 then {} else -{})", a[0], a[0], a[0]),
FloatSqrt => format!("FloatSqrt({})", a[0]),
FloatPow => format!("FloatPow({}, {})", a[0], a[1]),
FloatRound | FloatFloor | FloatCeil | FloatToInt => format!("FloatToInt({})", a[0]),
FloatToString | StringFromFloat => format!("FloatToString({})", a[0]),
FloatFromString | FloatParse => format!("FloatFromString({})", a[0]),
StringLen => format!("|{}|", a[0]),
StringConcat => format!("({} + {})", a[0], a[1]),
StringCharAt => format!("StringCharAt({}, {})", a[0], a[1]),
StringChars => format!("StringChars({})", a[0]),
StringSlice => format!("{}[{}..{}]", a[0], a[1], a[2]),
StringContains => format!("StringContains({}, {})", a[0], a[1]),
StringStartsWith => format!("StringStartsWith({}, {})", a[0], a[1]),
StringEndsWith => format!("StringEndsWith({}, {})", a[0], a[1]),
StringTrim => format!("StringTrim({})", a[0]),
StringSplit => format!("StringSplit({}, {})", a[0], a[1]),
StringJoin => format!("StringJoin({}, {})", a[1], a[0]), StringReplace => format!("StringReplace({}, {}, {})", a[0], a[1], a[2]),
StringRepeat => format!("StringRepeat({}, {})", a[0], a[1]),
StringIndexOf => format!("StringIndexOf({}, {})", a[0], a[1]),
StringToUpper => format!("StringToUpper({})", a[0]),
StringToLower => format!("StringToLower({})", a[0]),
StringFromBool => format!("StringFromBool({})", a[0]),
StringByteLength => format!("StringByteLength({})", a[0]),
BoolOr => format!("({} || {})", a[0], a[1]),
BoolAnd => format!("({} && {})", a[0], a[1]),
BoolNot => format!("(!{})", a[0]),
CharToCode => format!("CharToCode({})", a[0]),
CharFromCode => format!("CharFromCode({})", a[0]),
ByteToHex => format!("ByteToHex({})", a[0]),
ByteFromHex => format!("ByteFromHex({})", a[0]),
ListLen => format!("|{}|", a[0]),
ListHead => format!("ListHead({})", a[0]),
ListTail => format!("ListTail({})", a[0]),
ListPrepend => format!("[{}] + {}", a[0], a[1]),
ListTake => format!("ListTake({}, {})", a[0], a[1]),
ListDrop => format!("ListDrop({}, {})", a[0], a[1]),
ListConcat => format!("({} + {})", a[0], a[1]),
ListReverse => format!("ListReverse({})", a[0]),
ListContains => format!("({} in {})", a[1], a[0]),
ListFind => format!("ListFind({}, {})", a[0], a[1]),
ListAny => format!("ListAny({}, {})", a[0], a[1]),
ListZip => format!("ListZip({}, {})", a[0], a[1]),
VectorNew => format!("seq({}, _ => {})", a[0], a[1]),
VectorGet => format!(
"if 0 <= {} < |{}| then Some({}[{}]) else None",
a[1], a[0], a[0], a[1]
),
VectorSet => format!(
"if 0 <= {} < |{}| then Some({}[{} := {}]) else None",
a[1], a[0], a[0], a[1], a[2]
),
VectorLen => format!("|{}|", a[0]),
VectorFromList => a[0].clone(),
VectorToList => a[0].clone(),
MapEmpty => "map[]".to_string(),
MapGet => format!("MapGet({}, {})", a[0], a[1]),
MapSet => format!("{}[{} := {}]", a[0], a[1], a[2]),
MapHas => format!("({} in {})", a[1], a[0]),
MapRemove => format!("({} - {{{}}})", a[0], a[1]),
MapKeys => format!("MapKeys({})", a[0]),
MapValues => format!("MapValues({})", a[0]),
MapEntries => format!("MapEntries({})", a[0]),
MapLen => format!("|{}|", a[0]),
MapFromList => format!("MapFromList({})", a[0]),
}
}
/// Lowers an Aver `match` expression to Dafny.
///
/// Specialised lowerings are tried in priority order: list patterns first,
/// then an exact `true`/`false` pair, then arms made only of literals,
/// identifiers and wildcards. If none applies, a native Dafny `match` is
/// emitted and wrapped in parentheses so the last case body cannot absorb
/// surrounding text.
fn emit_match(subject: &Spanned<Expr>, arms: &[MatchArm], ctx: &CodegenContext) -> String {
    if has_list_patterns(arms) {
        emit_list_match(subject, arms, ctx)
    } else if is_bool_match(arms) {
        emit_bool_match(subject, arms, ctx)
    } else if should_emit_as_if_chain(arms) {
        emit_if_chain(subject, arms, ctx)
    } else {
        let mut out = String::from("(match ");
        out.push_str(&emit_expr(subject, ctx));
        for arm in arms {
            let pattern = emit_pattern(&arm.pattern);
            let body = emit_expr(&arm.body, ctx);
            // Two spaces precede every `case`: a separator plus the arm prefix.
            out.push_str(&format!("  case {} => {}", pattern, body));
        }
        out.push(')');
        out
    }
}
/// True when every arm is a literal, identifier or wildcard pattern. Such a
/// match can be lowered to a chain of equality tests. Vacuously true for no arms.
fn should_emit_as_if_chain(arms: &[MatchArm]) -> bool {
    !arms.iter().any(|arm| match arm.pattern {
        Pattern::Literal(_) | Pattern::Wildcard | Pattern::Ident(_) => false,
        _ => true,
    })
}
/// True when the match has exactly two arms, one on the literal `true` and one
/// on the literal `false`.
fn is_bool_match(arms: &[MatchArm]) -> bool {
    let covers = |wanted: bool| {
        arms.iter().any(|arm| {
            matches!(&arm.pattern, Pattern::Literal(Literal::Bool(v)) if *v == wanted)
        })
    };
    arms.len() == 2 && covers(true) && covers(false)
}
/// Lowers a `true`/`false` match to `(if subject then .. else ..)`.
///
/// The caller must already have checked `is_bool_match`. If either literal arm
/// is missing, this panics when unwrapping the arm lookup.
fn emit_bool_match(subject: &Spanned<Expr>, arms: &[MatchArm], ctx: &CodegenContext) -> String {
    let condition = emit_expr(subject, ctx);
    let arm_for = move |wanted: bool| {
        arms.iter()
            .find(|arm| matches!(&arm.pattern, Pattern::Literal(Literal::Bool(v)) if *v == wanted))
            .unwrap()
    };
    let (when_true, when_false) = (arm_for(true), arm_for(false));
    let then_branch = emit_expr(&when_true.body, ctx);
    let else_branch = emit_expr(&when_false.body, ctx);
    format!("(if {} then {} else {})", condition, then_branch, else_branch)
}
/// Lowers a literal/identifier/wildcard match into nested `if` expressions.
/// The subject is rendered once, and that text is repeated in every comparison.
fn emit_if_chain(subject: &Spanned<Expr>, arms: &[MatchArm], ctx: &CodegenContext) -> String {
    let rendered_subject = emit_expr(subject, ctx);
    emit_if_chain_inner(rendered_subject.as_str(), arms, 0, ctx)
}
/// Emits the if-chain starting at arm `idx`, in source order.
///
/// Each arm is handled as follows:
/// - A literal arm becomes `(if subj == lit then body else <rest>)`.
/// - A wildcard arm yields its body.
/// - An identifier arm binds the subject with `var` and ends the chain.
///
/// Arms after the first catch-all are never emitted. Running out of arms
/// produces a `/* unreachable */` placeholder.
fn emit_if_chain_inner(subj: &str, arms: &[MatchArm], idx: usize, ctx: &CodegenContext) -> String {
    let Some(arm) = arms.get(idx) else {
        return "/* unreachable */".to_string();
    };
    let body = emit_expr(&arm.body, ctx);
    match &arm.pattern {
        Pattern::Literal(lit) => {
            let fallback = emit_if_chain_inner(subj, arms, idx + 1, ctx);
            let literal = emit_literal(lit);
            format!("(if {} == {} then {} else {})", subj, literal, body, fallback)
        }
        Pattern::Ident(name) => {
            format!("(var {} := {}; {})", aver_name_to_dafny(name), subj, body)
        }
        Pattern::Wildcard => body,
        other => format!("/* unsupported pattern: {} */ {}", emit_pattern(other), body),
    }
}
/// True when at least one arm uses a list pattern (`[]` or `head :: tail`).
fn has_list_patterns(arms: &[MatchArm]) -> bool {
    for arm in arms {
        if let Pattern::EmptyList | Pattern::Cons(..) = arm.pattern {
            return true;
        }
    }
    false
}
/// Lowers a `match` whose arms use list patterns (`[]` / `head :: tail`).
///
/// Lists are emitted as Dafny sequences, so the match becomes a length test:
///
/// ```text
/// (if |s| == 0 then <empty case> else var h := s[0]; var t := s[1..]; <cons case>)
/// ```
///
/// Each case is served by the first arm, in source order, that can match it:
/// either the specific `[]`/cons pattern or a catch-all (`_` or an identifier).
/// This is the same first-match order the literal if-chain lowering uses. A
/// catch-all identifier is bound to the subject with `var` so its body can
/// refer to it. A case with no applicable arm becomes a placeholder comment.
fn emit_list_match(subject: &Spanned<Expr>, arms: &[MatchArm], ctx: &CodegenContext) -> String {
    let subj = emit_expr(subject, ctx);
    // Indexing and slicing bind tighter than binary operators, so a subject
    // such as `[x] + xs` is parenthesised before `[0]` / `[1..]` are applied.
    // Plain (possibly dotted) names are left as they are.
    let is_plain_name =
        !subj.is_empty() && subj.chars().all(|c| c.is_alphanumeric() || c == '_' || c == '.');
    let receiver = if is_plain_name {
        subj.clone()
    } else {
        format!("({})", subj)
    };

    let empty_arm = arms.iter().find(|a| {
        matches!(
            a.pattern,
            Pattern::EmptyList | Pattern::Wildcard | Pattern::Ident(_)
        )
    });
    let cons_arm = arms.iter().find(|a| {
        matches!(
            a.pattern,
            Pattern::Cons(_, _) | Pattern::Wildcard | Pattern::Ident(_)
        )
    });

    // Body of an arm that does not destructure the list. An identifier
    // catch-all is bound to the whole subject so its body can use the name.
    let plain_body = |arm: &MatchArm| {
        let body = emit_expr(&arm.body, ctx);
        match &arm.pattern {
            Pattern::Ident(name) => {
                format!("(var {} := {}; {})", aver_name_to_dafny(name), subj, body)
            }
            _ => body,
        }
    };

    let empty_body = match empty_arm {
        Some(arm) => plain_body(arm),
        None => "/* missing empty case */".to_string(),
    };
    let cons_body = match cons_arm {
        Some(arm) => match &arm.pattern {
            Pattern::Cons(head, tail) => format!(
                "var {} := {}[0]; var {} := {}[1..]; {}",
                aver_name_to_dafny(head),
                receiver,
                aver_name_to_dafny(tail),
                receiver,
                emit_expr(&arm.body, ctx)
            ),
            _ => plain_body(arm),
        },
        None => "/* missing cons case */".to_string(),
    };
    format!(
        "(if |{}| == 0 then {} else {})",
        subj, empty_body, cons_body
    )
}
/// Renders a pattern for a native Dafny `match` case.
///
/// A constructor pattern keeps only its variant name (the text after the last
/// `.`), followed by its escaped bindings if it has any. The list patterns
/// render as `Nil` / `Cons(h, t)`. List matches themselves are lowered
/// separately, by `emit_list_match`.
fn emit_pattern(pattern: &Pattern) -> String {
    match pattern {
        Pattern::Wildcard => String::from("_"),
        Pattern::Literal(lit) => emit_literal(lit),
        Pattern::Ident(name) => aver_name_to_dafny(name),
        Pattern::EmptyList => String::from("Nil"),
        Pattern::Cons(head, tail) => {
            let head = aver_name_to_dafny(head);
            let tail = aver_name_to_dafny(tail);
            format!("Cons({}, {})", head, tail)
        }
        Pattern::Tuple(pats) => {
            let inner = pats.iter().map(emit_pattern).collect::<Vec<_>>().join(", ");
            format!("({})", inner)
        }
        Pattern::Constructor(name, bindings) => {
            let variant = name.rsplit('.').next().unwrap_or(name.as_str());
            if bindings.is_empty() {
                variant.to_string()
            } else {
                let fields = bindings
                    .iter()
                    .map(|b| aver_name_to_dafny(b))
                    .collect::<Vec<_>>()
                    .join(", ");
                format!("{}({})", variant, fields)
            }
        }
    }
}
/// Emits a constructor expression, optionally applied to one payload argument.
///
/// For a dotted `Type.Variant`, the qualification is kept only when `Type` is a
/// user-declared type; otherwise just the variant name is emitted. An undotted
/// name is used as-is.
fn emit_constructor(name: &str, arg: Option<&Spanned<Expr>>, ctx: &CodegenContext) -> String {
    let qualified = match name.rsplit_once('.') {
        Some((type_name, variant)) if is_user_type(type_name, ctx) => {
            format!("{}.{}", type_name, variant)
        }
        Some((_, variant)) => variant.to_string(),
        None => name.to_string(),
    };
    match arg {
        Some(payload) => format!("{}({})", qualified, emit_expr(payload, ctx)),
        None => qualified,
    }
}
/// Emits an interpolated string as a Dafny string concatenation.
///
/// Literal parts become escaped string literals. Embedded expressions are
/// wrapped in `ToString(..)`, presumably a prelude helper. The result is:
/// - `""` when there are no parts, so it is always a valid expression;
/// - the bare piece when there is exactly one;
/// - a parenthesised `+` chain otherwise, so postfix operators such as the
///   `s[i..j]` slicing used for `String.slice` apply to the whole string.
fn emit_interpolated_str(parts: &[StrPart], ctx: &CodegenContext) -> String {
    let mut pieces: Vec<String> = parts
        .iter()
        .map(|part| match part {
            StrPart::Literal(s) => format!(
                "\"{}\"",
                crate::codegen::common::escape_string_literal_unicode(s)
            ),
            StrPart::Parsed(expr) => format!("ToString({})", emit_expr(expr, ctx)),
        })
        .collect();
    match pieces.len() {
        0 => String::from("\"\""),
        1 => pieces.swap_remove(0),
        _ => format!("({})", pieces.join(" + ")),
    }
}