Skip to main content

aver/codegen/lean/
mod.rs

//! Lean 4 backend for the Aver transpiler.
//!
//! Transpiles only pure core logic: functions without effects, type definitions,
//! verify blocks (as `example ... := by sorry`), and decision blocks (as comments).
//! Effectful functions and `main` are skipped.
6mod builtins;
7mod expr;
8mod pattern;
9mod toplevel;
10mod types;
11
12use std::collections::HashSet;
13
14use crate::ast::{FnDef, TopLevel};
15use crate::call_graph;
16use crate::codegen::{CodegenContext, ProjectOutput};
17
18/// Transpile an Aver program to a Lean 4 project.
19pub fn transpile(ctx: &CodegenContext) -> ProjectOutput {
20    let mut sections = Vec::new();
21
22    // Prelude
23    sections.push(generate_prelude());
24    sections.push(String::new());
25
26    // Detect recursive functions for `partial` annotation
27    let recursive_fns = call_graph::find_recursive_fns(&ctx.items);
28
29    // Module type definitions (from depends)
30    for module in &ctx.modules {
31        for td in &module.type_defs {
32            sections.push(toplevel::emit_type_def(td));
33            // #18: Recursive types need unsafe DecidableEq instance
34            if toplevel::is_recursive_type_def(td) {
35                sections.push(toplevel::emit_recursive_decidable_eq(
36                    toplevel::type_def_name(td),
37                ));
38            }
39            sections.push(String::new());
40        }
41    }
42
43    // Type definitions
44    for td in &ctx.type_defs {
45        sections.push(toplevel::emit_type_def(td));
46        // #18: Recursive types need unsafe DecidableEq instance
47        if toplevel::is_recursive_type_def(td) {
48            sections.push(toplevel::emit_recursive_decidable_eq(
49                toplevel::type_def_name(td),
50            ));
51        }
52        sections.push(String::new());
53    }
54
55    // Emit pure functions in SCC-topological order:
56    // callees first, then callers; SCCs as `mutual` blocks.
57    emit_pure_functions(ctx, &recursive_fns, &mut sections);
58
59    // Decision blocks (as comments)
60    for item in &ctx.items {
61        if let TopLevel::Decision(db) = item {
62            sections.push(toplevel::emit_decision(db));
63            sections.push(String::new());
64        }
65    }
66
67    // Verify blocks → example ... := by sorry
68    for item in &ctx.items {
69        if let TopLevel::Verify(vb) = item {
70            sections.push(toplevel::emit_verify_block(vb, ctx));
71            sections.push(String::new());
72        }
73    }
74
75    let lean_source = sections.join("\n");
76
77    // Project files
78    let project_name = capitalize_first(&ctx.project_name);
79    let lakefile = generate_lakefile(&project_name);
80    let toolchain = generate_toolchain();
81
82    ProjectOutput {
83        files: vec![
84            ("lakefile.lean".to_string(), lakefile),
85            ("lean-toolchain".to_string(), toolchain),
86            (format!("{}.lean", project_name), lean_source),
87        ],
88    }
89}
90
91fn emit_pure_functions(
92    ctx: &CodegenContext,
93    recursive_fns: &HashSet<String>,
94    sections: &mut Vec<String>,
95) {
96    let all_fns: Vec<&FnDef> = ctx
97        .modules
98        .iter()
99        .flat_map(|m| m.fn_defs.iter())
100        .chain(ctx.fn_defs.iter())
101        .filter(|fd| toplevel::is_pure_fn(fd))
102        .collect();
103    if all_fns.is_empty() {
104        return;
105    }
106
107    let components = call_graph::ordered_fn_components(&all_fns);
108    for fns in components {
109        if fns.is_empty() {
110            continue;
111        }
112
113        // Multi-node SCC => emit `mutual ... end`.
114        if fns.len() > 1 {
115            sections.push(toplevel::emit_mutual_group(&fns, ctx));
116            sections.push(String::new());
117            continue;
118        }
119
120        // Singleton SCC => regular `def` (recursive singletons still get `partial`
121        // via `recursive_fns` in emit_fn_def).
122        if let Some(code) = toplevel::emit_fn_def(fns[0], recursive_fns, ctx) {
123            sections.push(code);
124            sections.push(String::new());
125        }
126    }
127}
128
/// Builds the Lean prelude placed at the top of every generated `.lean` file.
///
/// It defines the helpers and instances the emitted code uses to model Aver
/// builtins:
/// - an `Int → Float` coercion;
/// - `DecidableEq` for `Float` and `Except`;
/// - `Except.withDefault` and `Option.toExcept`;
/// - `String + String`;
/// - the `AverMap` (maps as association lists), `AverList`, `AverString` and
///   `AverByte` namespaces;
/// - code-point based `String`/`Char` helpers;
/// - `Int.fromString` / `Float.fromString` parsers returning `Except String _`.
///
/// Lines are joined with `"\n"`; the result has no trailing newline.
fn generate_prelude() -> String {
    let mut lines: Vec<&str> = Vec::new();

    lines.extend_from_slice(&[
        "-- Generated by the Aver → Lean 4 transpiler",
        "-- Pure core logic only (effectful functions are omitted)",
        "",
        "-- Prelude: helper definitions for Aver builtins",
        "",
    ]);

    // #2: Int → Float coercion
    lines.extend_from_slice(&[
        "instance : Coe Int Float := ⟨fun n => Float.ofInt n⟩",
        "",
    ]);

    // #6: Float DecidableEq (Float is opaque in Lean kernel).
    // unsafe + @[implemented_by] pattern: computable, native_decide works, deriving works.
    lines.extend_from_slice(&[
        "private unsafe def Float.unsafeDecEq (a b : Float) : Decidable (a = b) :=",
        "  if a == b then isTrue (unsafeCast ()) else isFalse (unsafeCast ())",
        "@[implemented_by Float.unsafeDecEq]",
        "private opaque Float.compDecEq (a b : Float) : Decidable (a = b)",
        "instance : DecidableEq Float := Float.compDecEq",
        "",
    ]);

    // #7: Except DecidableEq
    lines.extend_from_slice(&[
        "instance [DecidableEq ε] [DecidableEq α] : DecidableEq (Except ε α)",
        "  | .ok a, .ok b =>",
        "    if h : a = b then isTrue (h ▸ rfl) else isFalse (by intro h'; cases h'; exact h rfl)",
        "  | .error a, .error b =>",
        "    if h : a = b then isTrue (h ▸ rfl) else isFalse (by intro h'; cases h'; exact h rfl)",
        "  | .ok _, .error _ => isFalse (by intro h; cases h)",
        "  | .error _, .ok _ => isFalse (by intro h; cases h)",
        "",
    ]);

    // Except.withDefault
    lines.extend_from_slice(&[
        "namespace Except",
        "def withDefault (r : Except ε α) (d : α) : α :=",
        "  match r with",
        "  | .ok v => v",
        "  | .error _ => d",
        "end Except",
        "",
    ]);

    // Option.toExcept
    lines.extend_from_slice(&[
        "def Option.toExcept (o : Option α) (e : ε) : Except ε α :=",
        "  match o with",
        "  | some v => .ok v",
        "  | none => .error e",
        "",
    ]);

    // #13: String + → String.append (Lean uses ++ not +)
    lines.extend_from_slice(&[
        "instance : HAdd String String String := ⟨String.append⟩",
        "",
    ]);

    // Map helpers (Map<K,V> = List (K × V))
    lines.extend_from_slice(&[
        "namespace AverMap",
        "def empty : List (α × β) := []",
        "def get [BEq α] (m : List (α × β)) (k : α) : Option β :=",
        "  match m with",
        "  | [] => none",
        "  | (k', v) :: rest => if k == k' then some v else AverMap.get rest k",
        "def set [BEq α] (m : List (α × β)) (k : α) (v : β) : List (α × β) :=",
        "  let rec go : List (α × β) → List (α × β)",
        "    | [] => [(k, v)]",
        "    | (k', v') :: rest => if k == k' then (k, v) :: rest else (k', v') :: go rest",
        "  go m",
        "def has [BEq α] (m : List (α × β)) (k : α) : Bool :=",
        "  m.any (fun p => p.1 == k)",
        "def remove [BEq α] (m : List (α × β)) (k : α) : List (α × β) :=",
        "  m.filter (fun p => !(p.1 == k))",
        "def keys (m : List (α × β)) : List α := m.map Prod.fst",
        "def values (m : List (α × β)) : List β := m.map Prod.snd",
        "def entries (m : List (α × β)) : List (α × β) := m",
        "def len (m : List (α × β)) : Nat := m.length",
        "def fromList (entries : List (α × β)) : List (α × β) := entries",
        "end AverMap",
        "",
    ]);

    // List helpers: insertion sort driven by `Ord`.
    lines.extend_from_slice(&[
        "namespace AverList",
        "private def insertSorted [Ord α] (x : α) : List α → List α",
        "  | [] => [x]",
        "  | y :: ys =>",
        "    if compare x y == Ordering.lt || compare x y == Ordering.eq then",
        "      x :: y :: ys",
        "    else",
        "      y :: insertSorted x ys",
        "def sort [Ord α] (xs : List α) : List α :=",
        "  xs.foldl (fun acc x => insertSorted x acc) []",
        "end AverList",
        "",
    ]);

    // String helpers (Aver has no Char type — single chars are strings).
    // charAt/slice are code-point based to match interpreter semantics.
    lines.extend_from_slice(&[
        "def String.charAt (s : String) (i : Int) : Option String :=",
        "  if i < 0 then none",
        "  else (s.toList.get? i.toNat).map Char.toString",
        "def String.slice (s : String) (start stop : Int) : String :=",
        "  let startN := if start < 0 then 0 else start.toNat",
        "  let stopN := if stop < 0 then 0 else stop.toNat",
        "  let chars := s.toList",
        "  String.mk ((chars.drop startN).take (stopN - startN))",
        "def String.fromInt (n : Int) : String := toString n",
        "def String.fromFloat (f : Float) : String := toString f",
        "def String.chars (s : String) : List String := s.toList.map Char.toString",
        "namespace AverString",
        "def split (s delim : String) : List String :=",
        "  if delim.isEmpty then",
        "    \"\" :: (s.toList.map Char.toString) ++ [\"\"]",
        "  else",
        "    s.splitOn delim",
        "end AverString",
        "",
    ]);

    // Int/Float parsing (Aver returns Result<T, String>)
    lines.extend_from_slice(&[
        "def Int.fromString (s : String) : Except String Int :=",
        "  match s.toInt? with",
        "  | some n => .ok n",
        "  | none => .error (\"Invalid integer: \" ++ s)",
        "",
    ]);

    // #19: Float.fromString — real parser using Float.ofScientific.
    // The whole input must be consumed. Trailing characters ("1.5abc", "1 ")
    // and an exponent marker without digits ("1e") are rejected instead of
    // being silently ignored, as a full-string parser such as Rust's
    // `f64::from_str` does. NOTE(review): confirm against the interpreter's
    // Float.fromString.
    // `parseExpPart` returns (negative exponent?, exponent digits, leftover chars).
    lines.extend_from_slice(&[
        "private def charDigitsToNat (cs : List Char) : Nat :=",
        "  cs.foldl (fun acc c => acc * 10 + (c.toNat - '0'.toNat)) 0",
        "",
        "private def parseExpPart : List Char → (Bool × List Char × List Char)",
        "  | '-' :: rest => (true, rest.takeWhile Char.isDigit, rest.dropWhile Char.isDigit)",
        "  | '+' :: rest => (false, rest.takeWhile Char.isDigit, rest.dropWhile Char.isDigit)",
        "  | rest => (false, rest.takeWhile Char.isDigit, rest.dropWhile Char.isDigit)",
        "",
        "def Float.fromString (s : String) : Except String Float :=",
        "  let chars := s.toList",
        "  let (neg, chars) := match chars with",
        "    | '-' :: rest => (true, rest)",
        "    | _ => (false, chars)",
        "  let intPart := chars.takeWhile Char.isDigit",
        "  let rest := chars.dropWhile Char.isDigit",
        "  let (fracPart, rest) := match rest with",
        "    | '.' :: rest => (rest.takeWhile Char.isDigit, rest.dropWhile Char.isDigit)",
        "    | _ => ([], rest)",
        "  let (hasExp, expNeg, expDigits, rest) := match rest with",
        "    | 'e' :: rest => (true, parseExpPart rest)",
        "    | 'E' :: rest => (true, parseExpPart rest)",
        "    | _ => (false, false, [], rest)",
        "  let malformed := (intPart.isEmpty && fracPart.isEmpty) || (hasExp && expDigits.isEmpty)",
        "  if malformed || !rest.isEmpty then .error (\"Invalid float: \" ++ s)",
        "  else",
        "    let mantissa := charDigitsToNat (intPart ++ fracPart)",
        "    let fracLen : Int := fracPart.length",
        "    let expVal : Int := charDigitsToNat expDigits",
        "    let shift : Int := (if expNeg then -expVal else expVal) - fracLen",
        "    let f := if shift >= 0 then Float.ofScientific mantissa false shift.toNat",
        "             else Float.ofScientific mantissa true ((-shift).toNat)",
        "    .ok (if neg then -f else f)",
        "",
    ]);

    // Char helpers (Aver Char namespace operates on strings).
    // Lean's `Char.ofNat` maps invalid scalar values (surrogates, > 0x10FFFF)
    // to '\0', so `fromCode` checks the round trip. Invalid codes then yield
    // `none` instead of `some "\0"`.
    lines.extend_from_slice(&[
        "def Char.toCode (s : String) : Int :=",
        "  match s.toList.head? with",
        "  | some c => (c.toNat : Int)",
        "  | none => panic! \"Char.toCode: string is empty\"",
        "def Char.fromCode (n : Int) : Option String :=",
        "  if n < 0 then none",
        "  else",
        "    let c := Char.ofNat n.toNat",
        "    if c.toNat == n.toNat then some (Char.toString c) else none",
        "",
    ]);

    // #20: Hex conversion helpers (for Byte.toHex / Byte.fromHex)
    lines.extend_from_slice(&[
        "def hexDigit (n : Int) : String :=",
        "  match n with",
        "  | 0 => \"0\" | 1 => \"1\" | 2 => \"2\" | 3 => \"3\"",
        "  | 4 => \"4\" | 5 => \"5\" | 6 => \"6\" | 7 => \"7\"",
        "  | 8 => \"8\" | 9 => \"9\" | 10 => \"a\" | 11 => \"b\"",
        "  | 12 => \"c\" | 13 => \"d\" | 14 => \"e\" | 15 => \"f\"",
        "  | _ => \"?\"",
        "",
        "def byteToHex (code : Int) : String :=",
        "  hexDigit (code / 16) ++ hexDigit (code % 16)",
        "",
        "namespace AverByte",
        "private def hexValue (c : Char) : Option Int :=",
        "  match c with",
        "  | '0' => some 0  | '1' => some 1  | '2' => some 2  | '3' => some 3",
        "  | '4' => some 4  | '5' => some 5  | '6' => some 6  | '7' => some 7",
        "  | '8' => some 8  | '9' => some 9  | 'a' => some 10 | 'b' => some 11",
        "  | 'c' => some 12 | 'd' => some 13 | 'e' => some 14 | 'f' => some 15",
        "  | 'A' => some 10 | 'B' => some 11 | 'C' => some 12 | 'D' => some 13",
        "  | 'E' => some 14 | 'F' => some 15",
        "  | _ => none",
        "def toHex (n : Int) : Except String String :=",
        "  if n < 0 || n > 255 then",
        "    .error (\"Byte.toHex: \" ++ toString n ++ \" is out of range 0-255\")",
        "  else",
        "    .ok (byteToHex n)",
        "def fromHex (s : String) : Except String Int :=",
        "  match s.toList with",
        "  | [hi, lo] =>",
        "    match hexValue hi, hexValue lo with",
        "    | some h, some l => .ok (h * 16 + l)",
        "    | _, _ => .error (\"Byte.fromHex: invalid hex '\" ++ s ++ \"'\")",
        "  | _ => .error (\"Byte.fromHex: expected exactly 2 hex chars, got '\" ++ s ++ \"'\")",
        "end AverByte",
    ]);

    lines.join("\n")
}
354
/// Renders `lakefile.lean` for the generated project.
///
/// The Lake package is named after the lower-cased `project_name`. The
/// default-target library keeps `project_name` as given and uses the project
/// root (`"."`) as its source directory. Both names are wrapped in `«»`
/// guillemets so they are accepted as Lean identifiers.
fn generate_lakefile(project_name: &str) -> String {
    let package = project_name.to_lowercase();
    format!(
        concat!(
            "import Lake\n",
            "open Lake DSL\n",
            "\n",
            "package «{package}» where\n",
            "  version := v!\"0.1.0\"\n",
            "\n",
            "@[default_target]\n",
            "lean_lib «{lib}» where\n",
            "  srcDir := \".\"\n",
        ),
        package = package,
        lib = project_name,
    )
}
371
/// Contents of the `lean-toolchain` file, pinning the Lean 4 release that the
/// generated project is built with (newline-terminated).
///
/// NOTE(review): the prelude uses `List.get?`, which may be deprecated in later
/// Lean releases — re-check the prelude before bumping this pin.
fn generate_toolchain() -> String {
    const TOOLCHAIN: &str = "leanprover/lean4:v4.15.0\n";
    String::from(TOOLCHAIN)
}
375
/// Returns `s` with its first character upper-cased and the rest unchanged.
///
/// Upper-casing may expand one character into several (e.g. `ß` → `SS`).
/// An empty input yields an empty string.
fn capitalize_first(s: &str) -> String {
    match s.chars().next() {
        Some(first) => {
            // Slice past the first char's UTF-8 bytes to keep the tail intact.
            let tail = &s[first.len_utf8()..];
            first.to_uppercase().chain(tail.chars()).collect()
        }
        None => String::new(),
    }
}