1mod builtins;
7mod expr;
8mod pattern;
9mod toplevel;
10mod types;
11
12use std::collections::HashSet;
13
14use crate::ast::{FnDef, TopLevel};
15use crate::call_graph;
16use crate::codegen::{CodegenContext, ProjectOutput};
17
18pub fn transpile(ctx: &CodegenContext) -> ProjectOutput {
20 let mut sections = Vec::new();
21
22 sections.push(generate_prelude());
24 sections.push(String::new());
25
26 let recursive_fns = call_graph::find_recursive_fns(&ctx.items);
28
29 for module in &ctx.modules {
31 for td in &module.type_defs {
32 sections.push(toplevel::emit_type_def(td));
33 if toplevel::is_recursive_type_def(td) {
35 sections.push(toplevel::emit_recursive_decidable_eq(
36 toplevel::type_def_name(td),
37 ));
38 }
39 sections.push(String::new());
40 }
41 }
42
43 for td in &ctx.type_defs {
45 sections.push(toplevel::emit_type_def(td));
46 if toplevel::is_recursive_type_def(td) {
48 sections.push(toplevel::emit_recursive_decidable_eq(
49 toplevel::type_def_name(td),
50 ));
51 }
52 sections.push(String::new());
53 }
54
55 emit_pure_functions(ctx, &recursive_fns, &mut sections);
58
59 for item in &ctx.items {
61 if let TopLevel::Decision(db) = item {
62 sections.push(toplevel::emit_decision(db));
63 sections.push(String::new());
64 }
65 }
66
67 for item in &ctx.items {
69 if let TopLevel::Verify(vb) = item {
70 sections.push(toplevel::emit_verify_block(vb, ctx));
71 sections.push(String::new());
72 }
73 }
74
75 let lean_source = sections.join("\n");
76
77 let project_name = capitalize_first(&ctx.project_name);
79 let lakefile = generate_lakefile(&project_name);
80 let toolchain = generate_toolchain();
81
82 ProjectOutput {
83 files: vec![
84 ("lakefile.lean".to_string(), lakefile),
85 ("lean-toolchain".to_string(), toolchain),
86 (format!("{}.lean", project_name), lean_source),
87 ],
88 }
89}
90
91fn emit_pure_functions(
92 ctx: &CodegenContext,
93 recursive_fns: &HashSet<String>,
94 sections: &mut Vec<String>,
95) {
96 let all_fns: Vec<&FnDef> = ctx
97 .modules
98 .iter()
99 .flat_map(|m| m.fn_defs.iter())
100 .chain(ctx.fn_defs.iter())
101 .filter(|fd| toplevel::is_pure_fn(fd))
102 .collect();
103 if all_fns.is_empty() {
104 return;
105 }
106
107 let components = call_graph::ordered_fn_components(&all_fns);
108 for fns in components {
109 if fns.is_empty() {
110 continue;
111 }
112
113 if fns.len() > 1 {
115 sections.push(toplevel::emit_mutual_group(&fns, ctx));
116 sections.push(String::new());
117 continue;
118 }
119
120 if let Some(code) = toplevel::emit_fn_def(fns[0], recursive_fns, ctx) {
123 sections.push(code);
124 sections.push(String::new());
125 }
126 }
127}
128
/// Builds the Lean prelude placed at the top of every generated file.
///
/// The prelude defines the helpers the transpiled code relies on:
/// - decidable equality for `Float` and `Except`,
/// - `Except` / `Option` conveniences,
/// - string concatenation via `HAdd`,
/// - association-list maps (`AverMap`), sorting (`AverList`),
/// - string, int, float, char and byte conversions.
///
/// Lines are joined with `"\n"`, and the result has no trailing newline.
fn generate_prelude() -> String {
    const PRELUDE: &[&str] = &[
        "-- Generated by the Aver → Lean 4 transpiler",
        "-- Pure core logic only (effectful functions are omitted)",
        "",
        "-- Prelude: helper definitions for Aver builtins",
        "",
        "instance : Coe Int Float := ⟨fun n => Float.ofInt n⟩",
        "",
        "private unsafe def Float.unsafeDecEq (a b : Float) : Decidable (a = b) :=",
        "  if a == b then isTrue (unsafeCast ()) else isFalse (unsafeCast ())",
        "@[implemented_by Float.unsafeDecEq]",
        "private opaque Float.compDecEq (a b : Float) : Decidable (a = b)",
        "instance : DecidableEq Float := Float.compDecEq",
        "",
        "instance [DecidableEq ε] [DecidableEq α] : DecidableEq (Except ε α)",
        "  | .ok a, .ok b =>",
        "    if h : a = b then isTrue (h ▸ rfl) else isFalse (by intro h'; cases h'; exact h rfl)",
        "  | .error a, .error b =>",
        "    if h : a = b then isTrue (h ▸ rfl) else isFalse (by intro h'; cases h'; exact h rfl)",
        "  | .ok _, .error _ => isFalse (by intro h; cases h)",
        "  | .error _, .ok _ => isFalse (by intro h; cases h)",
        "",
        "namespace Except",
        "def withDefault (r : Except ε α) (d : α) : α :=",
        "  match r with",
        "  | .ok v => v",
        "  | .error _ => d",
        "end Except",
        "",
        "def Option.toExcept (o : Option α) (e : ε) : Except ε α :=",
        "  match o with",
        "  | some v => .ok v",
        "  | none => .error e",
        "",
        "instance : HAdd String String String := ⟨String.append⟩",
        "",
        "namespace AverMap",
        "def empty : List (α × β) := []",
        "def get [BEq α] (m : List (α × β)) (k : α) : Option β :=",
        "  match m with",
        "  | [] => none",
        "  | (k', v) :: rest => if k == k' then some v else AverMap.get rest k",
        "def set [BEq α] (m : List (α × β)) (k : α) (v : β) : List (α × β) :=",
        "  let rec go : List (α × β) → List (α × β)",
        "    | [] => [(k, v)]",
        "    | (k', v') :: rest => if k == k' then (k, v) :: rest else (k', v') :: go rest",
        "  go m",
        "def has [BEq α] (m : List (α × β)) (k : α) : Bool :=",
        "  m.any (fun p => p.1 == k)",
        "def remove [BEq α] (m : List (α × β)) (k : α) : List (α × β) :=",
        "  m.filter (fun p => !(p.1 == k))",
        "def keys (m : List (α × β)) : List α := m.map Prod.fst",
        "def values (m : List (α × β)) : List β := m.map Prod.snd",
        "def entries (m : List (α × β)) : List (α × β) := m",
        "def len (m : List (α × β)) : Nat := m.length",
        "def fromList (entries : List (α × β)) : List (α × β) := entries",
        "end AverMap",
        "",
        "namespace AverList",
        "private def insertSorted [Ord α] (x : α) : List α → List α",
        "  | [] => [x]",
        "  | y :: ys =>",
        "    if compare x y == Ordering.lt || compare x y == Ordering.eq then",
        "      x :: y :: ys",
        "    else",
        "      y :: insertSorted x ys",
        "def sort [Ord α] (xs : List α) : List α :=",
        "  xs.foldl (fun acc x => insertSorted x acc) []",
        "end AverList",
        "",
        "def String.charAt (s : String) (i : Int) : Option String :=",
        "  if i < 0 then none",
        "  else (s.toList.get? i.toNat).map Char.toString",
        "def String.slice (s : String) (start stop : Int) : String :=",
        "  let startN := if start < 0 then 0 else start.toNat",
        "  let stopN := if stop < 0 then 0 else stop.toNat",
        "  let chars := s.toList",
        "  String.mk ((chars.drop startN).take (stopN - startN))",
        "def String.fromInt (n : Int) : String := toString n",
        "def String.fromFloat (f : Float) : String := toString f",
        "def String.chars (s : String) : List String := s.toList.map Char.toString",
        "namespace AverString",
        "def split (s delim : String) : List String :=",
        "  if delim.isEmpty then",
        "    \"\" :: (s.toList.map Char.toString) ++ [\"\"]",
        "  else",
        "    s.splitOn delim",
        "end AverString",
        "",
        "def Int.fromString (s : String) : Except String Int :=",
        "  match s.toInt? with",
        "  | some n => .ok n",
        "  | none => .error (\"Invalid integer: \" ++ s)",
        "",
        "private def charDigitsToNat (cs : List Char) : Nat :=",
        "  cs.foldl (fun acc c => acc * 10 + (c.toNat - '0'.toNat)) 0",
        "",
        // Returns (negative?, exponent digits, unconsumed remainder). The
        // remainder lets `Float.fromString` reject trailing garbage.
        "private def parseExpPart : List Char → (Bool × List Char × List Char)",
        "  | '-' :: rest => (true, rest.takeWhile Char.isDigit, rest.dropWhile Char.isDigit)",
        "  | '+' :: rest => (false, rest.takeWhile Char.isDigit, rest.dropWhile Char.isDigit)",
        "  | rest => (false, rest.takeWhile Char.isDigit, rest.dropWhile Char.isDigit)",
        "",
        "def Float.fromString (s : String) : Except String Float :=",
        "  let chars := s.toList",
        "  let (neg, chars) := match chars with",
        "    | '-' :: rest => (true, rest)",
        "    | _ => (false, chars)",
        "  let intPart := chars.takeWhile Char.isDigit",
        "  let rest := chars.dropWhile Char.isDigit",
        "  let (fracPart, rest) := match rest with",
        "    | '.' :: rest => (rest.takeWhile Char.isDigit, rest.dropWhile Char.isDigit)",
        "    | _ => ([], rest)",
        // An exponent marker must be followed by at least one digit ("1e" is
        // invalid), and nothing may remain after the number ("12abc" is invalid).
        "  let (expOk, expNeg, expDigits, rest) := match rest with",
        "    | 'e' :: rest => let (n, d, r) := parseExpPart rest; (!d.isEmpty, n, d, r)",
        "    | 'E' :: rest => let (n, d, r) := parseExpPart rest; (!d.isEmpty, n, d, r)",
        "    | _ => (true, false, [], rest)",
        "  if (intPart.isEmpty && fracPart.isEmpty) || !expOk || !rest.isEmpty then",
        "    .error (\"Invalid float: \" ++ s)",
        "  else",
        "    let mantissa := charDigitsToNat (intPart ++ fracPart)",
        "    let fracLen : Int := fracPart.length",
        "    let expVal : Int := charDigitsToNat expDigits",
        "    let shift : Int := (if expNeg then -expVal else expVal) - fracLen",
        "    let f := if shift >= 0 then Float.ofScientific mantissa false shift.toNat",
        "      else Float.ofScientific mantissa true ((-shift).toNat)",
        "    .ok (if neg then -f else f)",
        "",
        "def Char.toCode (s : String) : Int :=",
        "  match s.toList.head? with",
        "  | some c => (c.toNat : Int)",
        "  | none => panic! \"Char.toCode: string is empty\"",
        // `Char.ofNat` maps invalid code points (surrogates, > 0x10FFFF) to
        // '\0'; a failed round-trip through `toNat` detects that case.
        "def Char.fromCode (n : Int) : Option String :=",
        "  if n < 0 then none",
        "  else",
        "    let c := Char.ofNat n.toNat",
        "    if c.toNat == n.toNat then some (Char.toString c) else none",
        "",
        "def hexDigit (n : Int) : String :=",
        "  match n with",
        "  | 0 => \"0\" | 1 => \"1\" | 2 => \"2\" | 3 => \"3\"",
        "  | 4 => \"4\" | 5 => \"5\" | 6 => \"6\" | 7 => \"7\"",
        "  | 8 => \"8\" | 9 => \"9\" | 10 => \"a\" | 11 => \"b\"",
        "  | 12 => \"c\" | 13 => \"d\" | 14 => \"e\" | 15 => \"f\"",
        "  | _ => \"?\"",
        "",
        "def byteToHex (code : Int) : String :=",
        "  hexDigit (code / 16) ++ hexDigit (code % 16)",
        "",
        "namespace AverByte",
        "private def hexValue (c : Char) : Option Int :=",
        "  match c with",
        "  | '0' => some 0 | '1' => some 1 | '2' => some 2 | '3' => some 3",
        "  | '4' => some 4 | '5' => some 5 | '6' => some 6 | '7' => some 7",
        "  | '8' => some 8 | '9' => some 9 | 'a' => some 10 | 'b' => some 11",
        "  | 'c' => some 12 | 'd' => some 13 | 'e' => some 14 | 'f' => some 15",
        "  | 'A' => some 10 | 'B' => some 11 | 'C' => some 12 | 'D' => some 13",
        "  | 'E' => some 14 | 'F' => some 15",
        "  | _ => none",
        "def toHex (n : Int) : Except String String :=",
        "  if n < 0 || n > 255 then",
        "    .error (\"Byte.toHex: \" ++ toString n ++ \" is out of range 0-255\")",
        "  else",
        "    .ok (byteToHex n)",
        // The inner match is parenthesized so the outer `| _ =>` arm can never
        // be absorbed by it, whatever the column layout.
        "def fromHex (s : String) : Except String Int :=",
        "  match s.toList with",
        "  | [hi, lo] =>",
        "    (match hexValue hi, hexValue lo with",
        "     | some h, some l => .ok (h * 16 + l)",
        "     | _, _ => .error (\"Byte.fromHex: invalid hex '\" ++ s ++ \"'\"))",
        "  | _ => .error (\"Byte.fromHex: expected exactly 2 hex chars, got '\" ++ s ++ \"'\")",
        "end AverByte",
    ];

    PRELUDE.join("\n")
}
354
/// Renders the `lakefile.lean` for the generated Lake package.
///
/// The package is named after the lowercased `project_name`. The default
/// `lean_lib` target uses `project_name` verbatim and reads sources from
/// the package root (`srcDir := "."`).
fn generate_lakefile(project_name: &str) -> String {
    let package_name = project_name.to_lowercase();
    format!(
        r#"import Lake
open Lake DSL

package «{package}» where
  version := v!"0.1.0"

@[default_target]
lean_lib «{library}» where
  srcDir := "."
"#,
        package = package_name,
        library = project_name,
    )
}
371
/// Contents of the `lean-toolchain` file: the pinned Lean release,
/// terminated by a newline.
fn generate_toolchain() -> String {
    const PINNED_LEAN: &str = "leanprover/lean4:v4.15.0";
    [PINNED_LEAN, "\n"].concat()
}
375
/// Returns `s` with its first character uppercased and the rest unchanged.
///
/// Uppercasing may expand one character into several (e.g. `ß` becomes
/// `SS`). An empty input yields an empty string.
fn capitalize_first(s: &str) -> String {
    let mut rest = s.chars();
    rest.next()
        .map(|head| head.to_uppercase().chain(rest).collect())
        .unwrap_or_default()
}