Skip to main content

cljrs_compiler/
aot.rs

1//! AOT compilation driver.
2//!
3//! Orchestrates the full pipeline from source file to native binary:
4//!
5//! 1. Parse source → `Vec<Form>`
6//! 2. Boot a standard environment (for macro expansion)
7//! 3. Macro-expand each top-level form
//! 4. ANF-lower the compilable forms as a zero-arg `__cljrs_main` function
//!    (forms that need the interpreter are embedded as a startup preamble instead)
9//! 5. Cranelift codegen → `.o` object bytes
10//! 6. Generate a Cargo harness project that links the object + runtime
11//! 7. `cargo build --release` the harness → standalone binary
12
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15
16use cljrs_reader::Parser;
17
18use crate::codegen::Compiler;
19use crate::ir::IrFunction;
20
21// ── Error type ──────────────────────────────────────────────────────────────
22
23#[derive(Debug)]
24pub enum AotError {
25    Io(std::io::Error),
26    Parse(cljrs_types::error::CljxError),
27    Codegen(crate::codegen::CodegenError),
28    Eval(String),
29    Link(String),
30    /// One or more no-gc memory-safety violations were found by the blacklist
31    /// analysis.  Only emitted when the `no-gc` Cargo feature is active.
32    #[cfg(feature = "no-gc")]
33    NoGcBlacklist(Vec<crate::escape::BlacklistViolation>),
34}
35
36impl std::fmt::Display for AotError {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            AotError::Io(e) => write!(f, "I/O error: {e}"),
40            AotError::Parse(e) => write!(f, "parse error: {e}"),
41            AotError::Codegen(e) => write!(f, "codegen error: {e:?}"),
42            AotError::Eval(e) => write!(f, "eval/lowering error: {e}"),
43            AotError::Link(e) => write!(f, "link error: {e}"),
44            #[cfg(feature = "no-gc")]
45            AotError::NoGcBlacklist(vs) => {
46                writeln!(f, "no-gc blacklist violations:")?;
47                for v in vs {
48                    writeln!(f, "  • {v}")?;
49                }
50                Ok(())
51            }
52        }
53    }
54}
55
56impl std::error::Error for AotError {}
57
58impl From<std::io::Error> for AotError {
59    fn from(e: std::io::Error) -> Self {
60        AotError::Io(e)
61    }
62}
63impl From<cljrs_types::error::CljxError> for AotError {
64    fn from(e: cljrs_types::error::CljxError) -> Self {
65        AotError::Parse(e)
66    }
67}
68impl From<crate::codegen::CodegenError> for AotError {
69    fn from(e: crate::codegen::CodegenError) -> Self {
70        AotError::Codegen(e)
71    }
72}
73
74pub type AotResult<T> = Result<T, AotError>;
75
76// ── Rust-native lowering ────────────────────────────────────────────────────
77
78/// Lower forms directly via the native Rust compiler pipeline.
79///
80/// Replaces the old `lower_via_clojure` path: no interpreter round-trip,
81/// no `callback::invoke`, no `ir_convert`.
82pub fn lower_via_rust(
83    name: Option<&str>,
84    ns: &str,
85    params: &[Arc<str>],
86    compilable_forms: &[cljrs_reader::Form],
87    _env: &mut cljrs_eval::Env,
88) -> AotResult<IrFunction> {
89    let ns_arc: Arc<str> = Arc::from(ns);
90    let ir = cljrs_ir::lower::lower_fn_body(name, &ns_arc, params, compilable_forms)
91        .map_err(|e| AotError::Eval(format!("lowering: {e:?}")))?;
92    let ir = cljrs_ir::lower::optimize(ir);
93
94    #[cfg(feature = "no-gc")]
95    {
96        let violations = crate::escape::check(&ir);
97        if !violations.is_empty() {
98            return Err(AotError::NoGcBlacklist(violations));
99        }
100    }
101
102    Ok(ir)
103}
104
105// ── Direct call optimization ────────────────────────────────────────────────
106
/// One compiled arity of a top-level function.
///
/// It records the compiled body's name, how many fixed parameters it takes,
/// and whether it also accepts rest arguments.
#[derive(Clone, Debug)]
struct ArityInfo {
    /// Name of the compiled function that implements this arity.
    fn_name: Arc<str>,
    /// Number of declared (fixed) parameters.
    param_count: usize,
    /// `true` when this arity collects surplus arguments into a rest list.
    is_variadic: bool,
}
114
115/// Collect top-level function definitions from an IR function.
116///
117/// Scans for `AllocClosure(...) + DefVar(_, ns, name, closure_var)` patterns
118/// and returns a map from `(ns, name)` → list of arity infos.
119fn collect_defn_arities(
120    ir_func: &IrFunction,
121) -> std::collections::HashMap<(Arc<str>, Arc<str>), Vec<ArityInfo>> {
122    use crate::ir::{ClosureTemplate, Inst, VarId};
123    use std::collections::HashMap;
124
125    let mut closure_templates: HashMap<VarId, ClosureTemplate> = HashMap::new();
126    let mut defns: HashMap<(Arc<str>, Arc<str>), Vec<ArityInfo>> = HashMap::new();
127
128    for block in &ir_func.blocks {
129        for inst in &block.insts {
130            match inst {
131                Inst::AllocClosure(dst, template, captures) if captures.is_empty() => {
132                    // Only consider zero-capture closures (top-level defns).
133                    closure_templates.insert(*dst, template.clone());
134                }
135                Inst::DefVar(_, ns, name, val) => {
136                    if let Some(template) = closure_templates.get(val) {
137                        let arities: Vec<ArityInfo> = template
138                            .arity_fn_names
139                            .iter()
140                            .zip(template.param_counts.iter())
141                            .zip(template.is_variadic.iter())
142                            .map(|((fn_name, &param_count), &is_variadic)| ArityInfo {
143                                fn_name: fn_name.clone(),
144                                param_count,
145                                is_variadic,
146                            })
147                            .collect();
148                        defns.insert((ns.clone(), name.clone()), arities);
149                    }
150                }
151                _ => {}
152            }
153        }
154    }
155
156    defns
157}
158
159/// Find the arity function name that matches a given argument count.
160///
161/// Only matches fixed arities — variadic functions cannot be called directly
162/// because the runtime needs to pack extra args into a rest list.
163fn find_matching_arity(arities: &[ArityInfo], arg_count: usize) -> Option<&ArityInfo> {
164    arities
165        .iter()
166        .find(|arity| !arity.is_variadic && arity.param_count == arg_count)
167}
168
169/// Rewrite `LoadGlobal + Call` sequences to `CallDirect` for functions
170/// defined in the same compilation unit.
171///
172/// This is a perf optimization: instead of going through `rt_call` (which
173/// looks up the var in the interpreter and dispatches dynamically), we call
174/// the compiled function pointer directly.
175fn optimize_direct_calls(ir_func: &mut IrFunction) {
176    // Collect defn arities from this function AND all subfunctions (recursively).
177    let mut all_defns = collect_defn_arities(ir_func);
178    for sub in &ir_func.subfunctions {
179        // Subfunctions don't typically DefVar, but recurse just in case.
180        all_defns.extend(collect_defn_arities(sub));
181    }
182
183    if all_defns.is_empty() {
184        return;
185    }
186
187    let rewrites = rewrite_calls_to_direct(ir_func, &all_defns);
188    if rewrites > 0 {
189        eprintln!("[aot] optimized {rewrites} call(s) to direct function calls");
190    }
191
192    // Recursively optimize subfunctions too.
193    // Subfunctions can call top-level defns, so pass the defn map down.
194    for sub in &mut ir_func.subfunctions {
195        optimize_direct_calls_with_defns(sub, &all_defns);
196    }
197}
198
199/// Like `optimize_direct_calls` but uses a pre-built defn map (for subfunctions).
200fn optimize_direct_calls_with_defns(
201    ir_func: &mut IrFunction,
202    defns: &std::collections::HashMap<(Arc<str>, Arc<str>), Vec<ArityInfo>>,
203) {
204    // Merge parent defns with any defns from this function.
205    let mut all_defns = defns.clone();
206    all_defns.extend(collect_defn_arities(ir_func));
207
208    if all_defns.is_empty() {
209        return;
210    }
211
212    let rewrites = rewrite_calls_to_direct(ir_func, &all_defns);
213    if rewrites > 0 {
214        eprintln!("[aot] optimized {rewrites} direct call(s) in subfunction");
215    }
216
217    for sub in &mut ir_func.subfunctions {
218        optimize_direct_calls_with_defns(sub, &all_defns);
219    }
220}
221
222/// Rewrite `LoadGlobal + Call` → `CallDirect` in a single IR function.
223/// Returns the number of rewrites performed.
224fn rewrite_calls_to_direct(
225    ir_func: &mut IrFunction,
226    defns: &std::collections::HashMap<(Arc<str>, Arc<str>), Vec<ArityInfo>>,
227) -> usize {
228    use crate::ir::{Inst, VarId};
229    use std::collections::HashMap;
230
231    // Build a map of VarId → (ns, name) for LoadGlobal instructions that load known defns.
232    let mut loadglobal_targets: HashMap<VarId, (Arc<str>, Arc<str>)> = HashMap::new();
233    for block in &ir_func.blocks {
234        for inst in &block.insts {
235            if let Inst::LoadGlobal(dst, ns, name) = inst
236                && defns.contains_key(&(ns.clone(), name.clone()))
237            {
238                loadglobal_targets.insert(*dst, (ns.clone(), name.clone()));
239            }
240        }
241    }
242
243    let mut rewrites = 0;
244    for block in &mut ir_func.blocks {
245        for inst in &mut block.insts {
246            if let Inst::Call(dst, callee, args) = inst
247                && let Some((ns, name)) = loadglobal_targets.get(callee)
248                && let Some(arities) = defns.get(&(ns.clone(), name.clone()))
249                && let Some(arity_info) = find_matching_arity(arities, args.len())
250            {
251                *inst = Inst::CallDirect(*dst, arity_info.fn_name.clone(), args.clone());
252                rewrites += 1;
253            }
254        }
255    }
256
257    rewrites
258}
259
260/// Tally region vs heap allocations across an IR function tree, so the
261/// AOT pipeline can report the impact of escape analysis at a glance.
262#[derive(Default)]
263struct AllocStats {
264    region: usize,
265    heap: usize,
266    closures: usize,
267    functions: usize,
268}
269
270fn count_alloc_stats(ir_func: &IrFunction) -> AllocStats {
271    use crate::ir::Inst;
272    let mut stats = AllocStats {
273        functions: 1,
274        ..Default::default()
275    };
276    for block in &ir_func.blocks {
277        for inst in &block.insts {
278            match inst {
279                Inst::AllocVector(..)
280                | Inst::AllocMap(..)
281                | Inst::AllocSet(..)
282                | Inst::AllocList(..)
283                | Inst::AllocCons(..) => stats.heap += 1,
284                Inst::AllocClosure(..) => stats.closures += 1,
285                Inst::RegionAlloc(..) => stats.region += 1,
286                _ => {}
287            }
288        }
289    }
290    for sub in &ir_func.subfunctions {
291        let s = count_alloc_stats(sub);
292        stats.region += s.region;
293        stats.heap += s.heap;
294        stats.closures += s.closures;
295        stats.functions += s.functions;
296    }
297    stats
298}
299
300// ── Public API ──────────────────────────────────────────────────────────────
301
302/// Run the AOT pipeline up to (and including) ANF lowering + region
303/// optimization, but stop before code generation.  Returns the source text
304/// and the optimized `IrFunction` so tools like `cljrs-ir-viz` can inspect
305/// exactly what the AOT compiler would lower.
306///
307/// The `silent` flag suppresses the usual `[aot] ...` progress output.
308pub fn lower_file_to_ir(
309    src_path: &Path,
310    src_dirs: &[PathBuf],
311    silent: bool,
312) -> AotResult<(String, IrFunction)> {
313    macro_rules! note {
314        ($($arg:tt)*) => { if !silent { eprintln!($($arg)*); } };
315    }
316
317    note!("[aot] reading {}", src_path.display());
318    let source = std::fs::read_to_string(src_path)?;
319    let filename = src_path.display().to_string();
320
321    let mut parser = Parser::new(source.clone(), filename);
322    let forms = parser.parse_all()?;
323    note!("[aot] parsed {} top-level form(s)", forms.len());
324
325    let globals = if src_dirs.is_empty() {
326        cljrs_stdlib::standard_env()
327    } else {
328        cljrs_stdlib::standard_env_with_paths(src_dirs.to_vec())
329    };
330    let mut env = cljrs_eval::Env::new(globals, "user");
331
332    let mut expanded = Vec::with_capacity(forms.len());
333    for form in &forms {
334        if needs_interpreter(form) {
335            match cljrs_eval::eval(form, &mut env) {
336                Ok(_) => {}
337                Err(e) => return Err(AotError::Eval(format!("{e:?}"))),
338            }
339        }
340        match cljrs_interp::macros::macroexpand_all(form, &mut env) {
341            Ok(f) => expanded.push(f),
342            Err(e) => return Err(AotError::Eval(format!("{e:?}"))),
343        }
344    }
345    note!("[aot] macro-expanded {} form(s)", expanded.len());
346
347    let mut compilable = Vec::new();
348    for (i, form) in expanded.iter().enumerate() {
349        if needs_interpreter(&forms[i]) || expanded_needs_interpreter(form) {
350            continue;
351        }
352        compilable.push(form.clone());
353    }
354
355    let params: Vec<Arc<str>> = vec![];
356    let compilable_forms = if compilable.is_empty() {
357        let nil_form = cljrs_reader::Form::new(
358            cljrs_reader::form::FormKind::Nil,
359            cljrs_types::span::Span::new(Arc::new("<aot>".to_string()), 0, 0, 1, 1),
360        );
361        vec![nil_form]
362    } else {
363        compilable
364    };
365
366    let current_ns = env.current_ns.to_string();
367    let ir_func = lower_via_rust(
368        Some("__cljrs_main"),
369        &current_ns,
370        &params,
371        &compilable_forms,
372        &mut env,
373    )?;
374    note!(
375        "[aot] lowered to {} block(s), {} var(s)",
376        ir_func.blocks.len(),
377        ir_func.next_var
378    );
379    Ok((source, ir_func))
380}
381
382/// Compile a `.cljrs` / `.cljc` source file to a standalone native binary.
383///
384/// `src_path` is the input source file.  `out_path` is the desired output
385/// binary.  `src_dirs` are additional directories for `require` resolution
386/// during macro expansion.
387pub fn compile_file(src_path: &Path, out_path: &Path, src_dirs: &[PathBuf]) -> AotResult<()> {
388    eprintln!("[aot] reading {}", src_path.display());
389    let source = std::fs::read_to_string(src_path)?;
390    let filename = src_path.display().to_string();
391
392    // ── 1. Parse ────────────────────────────────────────────────────────
393    let mut parser = Parser::new(source.clone(), filename);
394    let forms = parser.parse_all()?;
395    eprintln!("[aot] parsed {} top-level form(s)", forms.len());
396
397    // ── 2. Macro-expand ─────────────────────────────────────────────────
398    // Boot a full environment so macros resolve correctly.
399    let globals = if src_dirs.is_empty() {
400        cljrs_stdlib::standard_env()
401    } else {
402        cljrs_stdlib::standard_env_with_paths(src_dirs.to_vec())
403    };
404    let mut env = cljrs_eval::Env::new(globals, "user");
405
406    // Snapshot loaded namespaces before expansion so we can detect
407    // which user namespaces were pulled in by require.
408    let pre_loaded: std::collections::HashSet<Arc<str>> =
409        env.globals.loaded.lock().unwrap().clone();
410
411    let mut expanded = Vec::with_capacity(forms.len());
412    for form in &forms {
413        // For forms that need the interpreter (ns, require, defmacro, etc.),
414        // evaluate them immediately so that required namespaces get loaded
415        // and macros from dependencies are available for later forms.
416        if needs_interpreter(form) {
417            match cljrs_eval::eval(form, &mut env) {
418                Ok(_) => {}
419                Err(e) => return Err(AotError::Eval(format!("{e:?}"))),
420            }
421        }
422        match cljrs_interp::macros::macroexpand_all(form, &mut env) {
423            Ok(f) => expanded.push(f),
424            Err(e) => return Err(AotError::Eval(format!("{e:?}"))),
425        }
426    }
427    eprintln!("[aot] macro-expanded {} form(s)", expanded.len());
428
429    // Discover user namespaces loaded during expansion (transitive deps).
430    let bundled_sources = discover_bundled_sources(&env.globals, &pre_loaded, src_dirs);
431    if !bundled_sources.is_empty() {
432        eprintln!(
433            "[aot] bundling {} required namespace(s): {}",
434            bundled_sources.len(),
435            bundled_sources
436                .iter()
437                .map(|(ns, _)| ns.as_ref())
438                .collect::<Vec<_>>()
439                .join(", ")
440        );
441    }
442
443    // ── 2b. Partition: interpreted preamble vs compiled body ─────────
444    // Forms that define functions (defn, defmacro) or require interpreter
445    // features (closures) are evaluated at startup via the interpreter.
446    // The rest is AOT-compiled.
447    let mut interpreted_source = String::new();
448    let mut compilable = Vec::new();
449    for (i, form) in expanded.iter().enumerate() {
450        if needs_interpreter(&forms[i]) || expanded_needs_interpreter(form) {
451            // Extract the original source text using span byte offsets.
452            let span = &forms[i].span;
453            let src_text = &source[span.start..span.end];
454            interpreted_source.push_str(src_text);
455            interpreted_source.push('\n');
456        } else {
457            compilable.push(form.clone());
458        }
459    }
460    if !interpreted_source.is_empty() {
461        eprintln!(
462            "[aot] {} form(s) will be interpreted at startup",
463            expanded.len() - compilable.len()
464        );
465    }
466
467    // ── 3. ANF-lower ────────────────────────────────────────────────────
468    // Treat compilable top-level forms as the body of a zero-arg `__cljrs_main`.
469    let params: Vec<Arc<str>> = vec![];
470    let compilable_forms = if compilable.is_empty() {
471        // If everything is interpreted, emit a simple nil-returning main.
472        let nil_form = cljrs_reader::Form::new(
473            cljrs_reader::form::FormKind::Nil,
474            cljrs_types::span::Span::new(Arc::new("<aot>".to_string()), 0, 0, 1, 1),
475        );
476        vec![nil_form]
477    } else {
478        compilable
479    };
480
481    // Use the current namespace (may have been changed by ns form in preamble).
482    let current_ns = env.current_ns.to_string();
483    let mut ir_func = lower_via_rust(
484        Some("__cljrs_main"),
485        &current_ns,
486        &params,
487        &compilable_forms,
488        &mut env,
489    )?;
490    eprintln!(
491        "[aot] lowered to {} block(s), {} var(s)",
492        ir_func.blocks.len(),
493        ir_func.next_var
494    );
495
496    // Region allocation stats — show how many heap allocs the optimizer
497    // managed to lift onto the bump arena.
498    let stats = count_alloc_stats(&ir_func);
499    eprintln!(
500        "[aot] allocation stats: {} region-allocated, {} heap, {} closures (across {} functions)",
501        stats.region, stats.heap, stats.closures, stats.functions,
502    );
503
504    // ── 3b. Direct call optimization ────────────────────────────────────
505    // Rewrite `LoadGlobal + Call` → `CallDirect` for functions defined in
506    // the same compilation unit. This avoids going through rt_call.
507    optimize_direct_calls(&mut ir_func);
508
509    // ── 4. Cranelift codegen → .o ───────────────────────────────────────
510    let mut compiler = Compiler::new()?;
511
512    // Declare all subfunctions first (so they can reference each other).
513    declare_subfunctions(&ir_func, &mut compiler)?;
514
515    // Compile subfunctions before the main function.
516    compile_subfunctions(&ir_func, &mut compiler)?;
517
518    let func_id = compiler.declare_function("__cljrs_main", 0)?;
519    compiler.compile_function(&ir_func, func_id)?;
520    let obj_bytes = compiler.finish();
521    eprintln!("[aot] generated {} bytes of object code", obj_bytes.len());
522
523    // ── 5. Generate harness project & build ─────────────────────────────
524    let harness_dir = build_harness(out_path, &obj_bytes, &interpreted_source, &bundled_sources)?;
525    link_with_cargo(&harness_dir, out_path)?;
526
527    eprintln!("[aot] wrote {}", out_path.display());
528    Ok(())
529}
530
531/// Check if a top-level form needs the interpreter (can't be AOT-compiled yet).
532fn needs_interpreter(form: &cljrs_reader::Form) -> bool {
533    use cljrs_reader::form::FormKind;
534    match &form.kind {
535        FormKind::List(parts) => {
536            if let Some(head) = parts.first()
537                && let FormKind::Symbol(s) = &head.kind
538            {
539                // defmacro/defonce need the interpreter (macros must be
540                // available at compile time). ns/require are module-level.
541                // Protocol/multimethod forms modify global dispatch tables
542                // and are best handled by the interpreter at startup.
543                return matches!(
544                    s.as_str(),
545                    "defmacro"
546                        | "defonce"
547                        | "ns"
548                        | "require"
549                        | "defprotocol"
550                        | "extend-type"
551                        | "extend-protocol"
552                        | "defmulti"
553                        | "defmethod"
554                        | "defrecord"
555                );
556            }
557            false
558        }
559        _ => false,
560    }
561}
562
/// Report whether a symbol name, optionally namespace-qualified, names a
/// metadata-mutating builtin that only the interpreter supports.
///
/// Only the segment after the last `/` is compared, so
/// `clojure.core/alter-meta!` matches the same as `alter-meta!`.
fn is_interpreter_only_sym(s: &str) -> bool {
    const INTERPRETER_ONLY: [&str; 4] = ["alter-meta!", "vary-meta", "reset-meta!", "with-meta"];

    let base = match s.rfind('/') {
        Some(slash) => &s[slash + 1..],
        None => s,
    };
    INTERPRETER_ONLY.contains(&base)
}
573
574/// Check the macro-expanded form for constructs that the AOT compiler
575/// cannot handle (e.g. alter-meta!, vary-meta). This recurses
576/// into the form tree so that e.g. `(do (def x ...) (alter-meta! ...))` is caught.
577fn expanded_needs_interpreter(form: &cljrs_reader::Form) -> bool {
578    use cljrs_reader::form::FormKind;
579    match &form.kind {
580        FormKind::List(parts) => {
581            if let Some(head) = parts.first()
582                && let FormKind::Symbol(s) = &head.kind
583                && is_interpreter_only_sym(s)
584            {
585                return true;
586            }
587            parts.iter().any(expanded_needs_interpreter)
588        }
589        FormKind::Vector(elems) | FormKind::Set(elems) => {
590            elems.iter().any(expanded_needs_interpreter)
591        }
592        FormKind::Map(elems) => elems.iter().any(expanded_needs_interpreter),
593        _ => false,
594    }
595}
596
597// ── Subfunction compilation ─────────────────────────────────────────────────
598
599/// Recursively declare all subfunctions in the compiler module.
600fn declare_subfunctions(ir_func: &IrFunction, compiler: &mut Compiler) -> AotResult<()> {
601    for sub in &ir_func.subfunctions {
602        let name = sub.name.as_deref().unwrap_or("__cljrs_anon");
603        compiler.declare_function(name, sub.params.len())?;
604        declare_subfunctions(sub, compiler)?;
605    }
606    Ok(())
607}
608
609/// Recursively compile all subfunctions.
610fn compile_subfunctions(ir_func: &IrFunction, compiler: &mut Compiler) -> AotResult<()> {
611    for sub in &ir_func.subfunctions {
612        compile_subfunctions(sub, compiler)?;
613        let name = sub.name.as_deref().unwrap_or("__cljrs_anon");
614        let func_id = compiler.declare_function(name, sub.params.len())?;
615        compiler.compile_function(sub, func_id)?;
616    }
617    Ok(())
618}
619
620// ── Bundled source discovery ─────────────────────────────────────────────────
621
622/// Discover user namespaces loaded during macro expansion that need to be
623/// bundled into the compiled binary.
624///
625/// Compares the set of loaded namespaces before and after expansion. For each
626/// newly loaded namespace that isn't a builtin source, resolves and reads
627/// its source file from `src_dirs`.
628fn discover_bundled_sources(
629    globals: &Arc<cljrs_env::env::GlobalEnv>,
630    pre_loaded: &std::collections::HashSet<Arc<str>>,
631    src_dirs: &[PathBuf],
632) -> Vec<(Arc<str>, String)> {
633    let post_loaded = globals.loaded.lock().unwrap().clone();
634    let mut bundled = Vec::new();
635
636    for ns in post_loaded.difference(pre_loaded) {
637        // Skip namespaces that are already available as builtins at runtime.
638        if globals.builtin_source(ns).is_some() {
639            continue;
640        }
641        // Skip compiler-internal namespaces.
642        if ns.starts_with("cljrs.compiler.") {
643            continue;
644        }
645        // Resolve the source file from src_dirs.
646        let rel_path = ns.replace('.', "/").replace('-', "_");
647        if let Some(src) = find_user_source(&rel_path, src_dirs) {
648            bundled.push((ns.clone(), src));
649        }
650    }
651
652    bundled
653}
654
/// Find and read the user source file for the extension-less relative path
/// `rel` (e.g. `my/lib_name`).
///
/// Directories are searched in order, trying `.cljrs` before `.cljc` in each.
/// The contents of the first *regular file* found are returned.  A directory
/// that merely has a matching name is skipped instead of ending the search.
/// Returns `None` when no candidate exists, or when the chosen file cannot be
/// read as UTF-8 text.
fn find_user_source(rel: &str, src_dirs: &[PathBuf]) -> Option<String> {
    src_dirs
        .iter()
        .flat_map(|dir| [".cljrs", ".cljc"].map(|ext| dir.join(format!("{rel}{ext}"))))
        .find(|path| path.is_file())
        .and_then(|path| std::fs::read_to_string(path).ok())
}
667
668// ── Harness generation ──────────────────────────────────────────────────────
669
670/// Create a temporary Cargo project that links the compiled object code with
671/// the clojurust runtime and produces a binary.
672fn build_harness(
673    out_path: &Path,
674    obj_bytes: &[u8],
675    interpreted_source: &str,
676    bundled_sources: &[(Arc<str>, String)],
677) -> AotResult<PathBuf> {
678    // Place the harness in a temp dir next to the output.
679    let harness_dir = out_path
680        .parent()
681        .unwrap_or(Path::new("."))
682        .join(".cljrs-aot-harness");
683
684    // Clean any previous harness.
685    if harness_dir.exists() {
686        std::fs::remove_dir_all(&harness_dir)?;
687    }
688    std::fs::create_dir_all(harness_dir.join("src"))?;
689
690    // Write the object file.
691    let obj_path = harness_dir.join("__cljrs_main.o");
692    std::fs::write(&obj_path, obj_bytes)?;
693
694    // Find the workspace root (where the top-level Cargo.toml lives).
695    let workspace_root = find_workspace_root()?;
696
697    // Write Cargo.toml.
698    // The empty [workspace] table prevents Cargo from thinking this is
699    // part of a parent workspace.
700    let cargo_toml = format!(
701        r#"[package]
702name = "cljrs-aot-harness"
703version = "0.1.0"
704edition = "2024"
705
706[workspace]
707
708[dependencies]
709cljrs-types    = {{ path = "{ws}/crates/cljrs-types" }}
710cljrs-gc       = {{ path = "{ws}/crates/cljrs-gc" }}
711cljrs-value    = {{ path = "{ws}/crates/cljrs-value" }}
712cljrs-reader   = {{ path = "{ws}/crates/cljrs-reader" }}
713cljrs-env      = {{ path = "{ws}/crates/cljrs-env" }}
714cljrs-eval     = {{ path = "{ws}/crates/cljrs-eval" }}
715cljrs-stdlib   = {{ path = "{ws}/crates/cljrs-stdlib" }}
716cljrs-compiler = {{ path = "{ws}/crates/cljrs-compiler" }}
717"#,
718        ws = workspace_root.display()
719    );
720    std::fs::write(harness_dir.join("Cargo.toml"), cargo_toml)?;
721
722    // Write build.rs — tells Cargo to link our object file.
723    let obj_abs = std::fs::canonicalize(&obj_path)?;
724    let build_rs = format!(
725        r#"fn main() {{
726    // Link the AOT-compiled object file.
727    println!("cargo:rustc-link-arg={obj}");
728    println!("cargo:rerun-if-changed={obj}");
729}}"#,
730        obj = obj_abs.display()
731    );
732    std::fs::write(harness_dir.join("build.rs"), build_rs)?;
733
734    // Write the interpreted preamble source (if any).
735    let has_preamble = !interpreted_source.is_empty();
736    if has_preamble {
737        std::fs::write(harness_dir.join("src/preamble.cljrs"), interpreted_source)?;
738    }
739
740    // Write bundled dependency sources.
741    for (i, (ns, src)) in bundled_sources.iter().enumerate() {
742        let filename = format!("bundled_{i}.cljrs");
743        std::fs::write(harness_dir.join("src").join(&filename), src)?;
744        eprintln!("[aot] bundled {ns} → src/{filename}");
745    }
746
747    // Generate registration code for bundled sources.
748    let mut bundled_registration = String::new();
749    for (i, (ns, _)) in bundled_sources.iter().enumerate() {
750        bundled_registration.push_str(&format!(
751            "    globals.register_builtin_source(\"{ns}\", \
752             include_str!(\"bundled_{i}.cljrs\"));\n"
753        ));
754    }
755
756    // Write main.rs — calls into the compiled __cljrs_main.
757    let preamble_code = if has_preamble {
758        r#"
759    // Evaluate interpreted preamble (ns, require, defn, defmacro, etc.).
760    let preamble = include_str!("preamble.cljrs");
761    let mut parser = cljrs_reader::Parser::new(preamble.to_string(), "<preamble>".to_string());
762    let forms = parser.parse_all().expect("preamble parse error");
763    for form in &forms {
764        cljrs_eval::eval(form, &mut env).expect("preamble eval error");
765    }
766    // Re-push eval context with updated namespace (ns form may have changed it).
767    cljrs_env::callback::pop_eval_context();
768    cljrs_env::callback::push_eval_context(&env);
769"#
770    } else {
771        ""
772    };
773
774    let main_rs = format!(
775        r#"//! Auto-generated AOT harness for clojurust.
776//!
777//! Initializes the runtime, then calls the compiled `__cljrs_main`.
778
779#![allow(improper_ctypes)]
780
781use cljrs_value::Value;
782
783unsafe extern "C" {{
784    fn __cljrs_main() -> *const Value;
785}}
786
787fn main() {{
788    // Ensure all rt_* symbols are linked into the binary.
789    cljrs_compiler::rt_abi::anchor_rt_symbols();
790
791    // Initialize the standard environment so that rt_call and other
792    // runtime bridge functions can look up builtins.
793    let globals = cljrs_stdlib::standard_env();
794
795    // Register bundled dependency sources so require can find them
796    // without needing source files on disk.
797{bundled}
798    let mut env = cljrs_eval::Env::new(globals, "user");
799
800    // Push an eval context so rt_call can dispatch through the interpreter.
801    cljrs_env::callback::push_eval_context(&env);
802{preamble}
803    // Call the compiled code.
804    let _result = unsafe {{ __cljrs_main() }};
805
806    // Pop the eval context.
807    cljrs_env::callback::pop_eval_context();
808
809    // If CLJRS_GC_STATS is set, dump GC stats to its target (stdout/file).
810    cljrs_gc::dump_stats_from_env();
811}}
812"#,
813        preamble = preamble_code,
814        bundled = bundled_registration
815    );
816    std::fs::write(harness_dir.join("src/main.rs"), main_rs)?;
817
818    Ok(harness_dir)
819}
820
821/// Build the harness with Cargo and copy the resulting binary to `out_path`.
822fn link_with_cargo(harness_dir: &Path, out_path: &Path) -> AotResult<()> {
823    eprintln!("[aot] building harness with cargo...");
824
825    let output = std::process::Command::new("cargo")
826        .arg("build")
827        .arg("--release")
828        .arg("--offline")
829        .current_dir(harness_dir)
830        .output()?;
831
832    if !output.status.success() {
833        let stderr = String::from_utf8_lossy(&output.stderr);
834        return Err(AotError::Link(format!("cargo build failed:\n{stderr}")));
835    }
836
837    // The binary is at target/release/cljrs-aot-harness.
838    let bin_name = if cfg!(target_os = "windows") {
839        "cljrs-aot-harness.exe"
840    } else {
841        "cljrs-aot-harness"
842    };
843    let built = harness_dir.join("target/release").join(bin_name);
844    std::fs::copy(&built, out_path)?;
845
846    // Clean up the harness directory.
847    let _ = std::fs::remove_dir_all(harness_dir);
848
849    Ok(())
850}
851
852/// Build the harness with Cargo and copy the resulting binary to `out_path`.
853/// Keeps the harness directory for debugging test harnesses.
854fn link_with_cargo_test_harness(harness_dir: &Path, out_path: &Path) -> AotResult<()> {
855    eprintln!("[aot] building harness with cargo...");
856
857    let output = std::process::Command::new("cargo")
858        .arg("build")
859        .arg("--release")
860        .arg("--offline")
861        .current_dir(harness_dir)
862        .output()?;
863
864    if !output.status.success() {
865        let stderr = String::from_utf8_lossy(&output.stderr);
866        return Err(AotError::Link(format!("cargo build failed:\n{stderr}")));
867    }
868
869    // The binary is at target/release/cljrs-aot-harness.
870    let bin_name = if cfg!(target_os = "windows") {
871        "cljrs-aot-harness.exe"
872    } else {
873        "cljrs-aot-harness"
874    };
875    let built = harness_dir.join("target/release").join(bin_name);
876    std::fs::copy(&built, out_path)?;
877
878    // Keep the harness directory for debugging
879    eprintln!("[aot] harness directory kept at {}", harness_dir.display());
880
881    Ok(())
882}
883
884/// Walk up from the current directory to find the workspace root
885/// (the directory containing Cargo.toml with [workspace]).
886fn find_workspace_root() -> AotResult<PathBuf> {
887    let mut dir = std::env::current_dir()?;
888    loop {
889        let cargo_toml = dir.join("Cargo.toml");
890        if cargo_toml.exists() {
891            let contents = std::fs::read_to_string(&cargo_toml)?;
892            if contents.contains("[workspace") {
893                return Ok(dir);
894            }
895        }
896        if !dir.pop() {
897            return Err(AotError::Link(
898                "could not find workspace root (no Cargo.toml with [workspace])".to_string(),
899            ));
900        }
901    }
902}
903
904// ── Test harness generation ─────────────────────────────────────────────────
905
906/// Discover test namespaces from a directory of test files.
907/// Returns a sorted list of namespace names.
908fn discover_test_namespaces(test_dir: &Path, src_dirs: &[PathBuf]) -> AotResult<Vec<String>> {
909    let mut namespaces = Vec::new();
910
911    // First, try to discover from the test directory directly
912    if test_dir.is_dir() {
913        discover_in_dir(test_dir, test_dir, &mut namespaces);
914    }
915
916    // If no tests found in test_dir, also search src_dirs
917    if namespaces.is_empty() {
918        for dir in src_dirs {
919            if dir.is_dir() {
920                discover_in_dir(dir, dir, &mut namespaces);
921            }
922        }
923    }
924
925    namespaces.sort();
926    Ok(namespaces)
927}
928
929/// Discover all namespace names from `.cljc` / `.cljrs` files in the given source paths.
930fn discover_in_dir(root: &Path, dir: &Path, out: &mut Vec<String>) {
931    let Ok(entries) = std::fs::read_dir(dir) else {
932        return;
933    };
934    let mut entries: Vec<_> = entries.filter_map(|e| e.ok()).collect();
935    entries.sort_by_key(|e| e.file_name());
936    for entry in entries {
937        let path = entry.path();
938        if path.is_dir() {
939            discover_in_dir(root, &path, out);
940        } else if let Some(ext) = path.extension()
941            && (ext == "cljc" || ext == "cljrs")
942            && let Some(ns) = file_to_namespace(root, &path)
943        {
944            out.push(ns);
945        }
946    }
947}
948
/// Convert a source file path into the Clojure namespace it defines.
///
/// The steps are:
/// 1. Take the path relative to `root` and drop its extension.
/// 2. Turn each remaining path component into one dot-separated segment.
/// 3. Map `_` back to `-`, the usual Clojure file-naming convention.
///
/// For example, `test/clojure/core_test/juxt.cljc` relative to `test/` yields
/// `clojure.core-test.juxt`.
///
/// The path is split with `Path::components`, so every separator the
/// platform accepts is handled (on Windows both `\` and `/`). Returns `None`
/// if `file` is not under `root`, or if nothing remains after stripping the
/// prefix.
fn file_to_namespace(root: &Path, file: &Path) -> Option<String> {
    let rel = file.strip_prefix(root).ok()?.with_extension("");
    let segments: Vec<String> = rel
        .components()
        .filter_map(|component| match component {
            std::path::Component::Normal(part) => Some(part.to_string_lossy().replace('_', "-")),
            _ => None,
        })
        .collect();
    if segments.is_empty() {
        None
    } else {
        Some(segments.join("."))
    }
}
960
/// Generate the `src/main.rs` source of a standalone clojure.test runner.
///
/// The generated program:
///
/// 1. builds the standard environment and runs `bundled_registration`. This
///    text is pasted verbatim into `main`, so it must consist of complete,
///    indented Rust statements that use the `globals` binding;
/// 2. requires `clojure.test`, then every namespace in `namespaces`;
/// 3. calls `clojure.test/run-tests` on each namespace that loaded, summing
///    the `:test`, `:pass`, `:fail` and `:error` counts of the returned map;
/// 4. prints a summary and exits with status 1 if any test failed or errored,
///    or if any namespace could not be loaded or run.
///
/// A namespace whose `require`, or whose `run-tests` call, evaluates to an
/// error is reported on stderr and counted as a namespace error. Such a
/// namespace is skipped when running tests. Without this, a test file that
/// fails to load would be ignored silently and the run would still succeed.
fn generate_test_harness_code(namespaces: &[String], bundled_registration: &str) -> String {
    // Render the namespace list as the body of a `&[&str]` literal. `{:?}`
    // yields a properly escaped Rust string literal for every name.
    let ns_entries: String = namespaces
        .iter()
        .map(|ns| format!("    {ns:?},\n"))
        .collect();

    let mut code = String::new();

    code.push_str(
        r#"//! Auto-generated AOT test harness for clojurust.
//!
//! Discovers and runs all clojure.test tests in the bundled namespaces.

use cljrs_value::Value;

/// Test namespaces to load and run, in order.
const TEST_NAMESPACES: &[&str] = &[
"#,
    );
    code.push_str(&ns_entries);
    code.push_str(
        r#"];

fn main() {
    // Initialize the standard environment.
    let globals = cljrs_stdlib::standard_env();

    // Register bundled dependency sources so require can find them
    // without needing source files on disk.
"#,
    );

    code.push_str(bundled_registration);

    code.push_str(
        r#"    let mut env = cljrs_eval::Env::new(globals, "user");

    // Push an eval context so rt_call can dispatch through the interpreter.
    cljrs_env::callback::push_eval_context(&env);

    // Load clojure.test if not already loaded
    let _ = cljrs_eval::eval(
        &cljrs_reader::Parser::new(
            "(require 'clojure.test)".to_string(),
            "<test-harness>".to_string()
        ).parse_all().unwrap()[0],
        &mut env
    );

    let mut total_pass = 0i64;
    let mut total_fail = 0i64;
    let mut total_error = 0i64;
    let mut total_test_count = 0i64;
    // Namespaces that could not be loaded or whose run-tests call errored.
    let mut ns_errors = 0i64;

    // Load all test namespaces. A namespace that fails to load is reported
    // and skipped when running tests.
    let mut unloaded: Vec<&str> = Vec::new();
    for &ns_str in TEST_NAMESPACES {
        let loaded = cljrs_eval::eval(
            &cljrs_reader::Parser::new(
                format!("(require '{})", ns_str),
                "<test-harness>".to_string()
            ).parse_all().unwrap()[0],
            &mut env
        );
        if loaded.is_err() {
            eprintln!("error: failed to load test namespace {}", ns_str);
            ns_errors += 1;
            unloaded.push(ns_str);
        }
    }

    // Run tests for each namespace separately
    for &ns_str in TEST_NAMESPACES {
        if unloaded.contains(&ns_str) {
            continue;
        }
        let run_result = cljrs_eval::eval(
            &cljrs_reader::Parser::new(
                format!("(clojure.test/run-tests '{})", ns_str),
                "<run-tests>".to_string()
            ).parse_all().unwrap()[0],
            &mut env
        );
        match run_result {
            Ok(Value::Map(m)) => {
                let mut pass = 0i64;
                let mut fail = 0i64;
                let mut error = 0i64;
                let mut test_count = 0i64;
                m.for_each(|k, v| {
                    if let (Value::Keyword(kw), Value::Long(count)) = (k, v) {
                        match kw.get().name.as_ref() {
                            "pass" => pass = *count,
                            "fail" => fail = *count,
                            "error" => error = *count,
                            "test" => test_count = *count,
                            _ => {}
                        }
                    }
                });
                total_pass += pass;
                total_fail += fail;
                total_error += error;
                total_test_count += test_count;
            }
            Ok(_) => {}
            Err(_) => {
                eprintln!("error: run-tests failed for namespace {}", ns_str);
                ns_errors += 1;
            }
        }
    }

    // Flush output before exiting
    std::io::Write::flush(&mut std::io::stdout()).unwrap();
    println!("Ran {} tests containing {} assertions.", total_test_count, total_pass + total_fail + total_error);
    std::io::Write::flush(&mut std::io::stdout()).unwrap();
    println!("{} passed, {} failed, {} errors.", total_pass, total_fail, total_error);
    if ns_errors > 0 {
        println!("{} namespace(s) could not be loaded or run.", ns_errors);
    }
    std::io::Write::flush(&mut std::io::stdout()).unwrap();

    // Pop the eval context.
    cljrs_env::callback::pop_eval_context();

    // If CLJRS_GC_STATS is set, dump GC stats to its target (stdout/file).
    cljrs_gc::dump_stats_from_env();

    if total_fail > 0 || total_error > 0 || ns_errors > 0 {
        std::process::exit(1);
    }
}
"#,
    );

    code
}
1084
1085/// Compile a directory of test files to a standalone native binary.
1086/// The resulting binary will discover and run all clojure.test tests found.
1087pub fn compile_test_harness(
1088    test_dir: &Path,
1089    out_path: &Path,
1090    src_dirs: &[PathBuf],
1091) -> AotResult<()> {
1092    eprintln!("[aot] discovering tests in {}", test_dir.display());
1093
1094    // Discover test namespaces
1095    let test_namespaces = discover_test_namespaces(test_dir, src_dirs)?;
1096    if test_namespaces.is_empty() {
1097        return Err(AotError::Eval(format!(
1098            "No test files found in {}",
1099            test_dir.display()
1100        )));
1101    }
1102    eprintln!(
1103        "[aot] discovered {} test namespace(s)",
1104        test_namespaces.len()
1105    );
1106
1107    // Also discover source namespaces from src_dirs so they get bundled
1108    let mut src_namespaces = Vec::new();
1109    for dir in src_dirs {
1110        if dir.is_dir() {
1111            discover_in_dir(dir, dir, &mut src_namespaces);
1112        }
1113    }
1114    src_namespaces.sort();
1115    eprintln!(
1116        "[aot] discovered {} source namespace(s)",
1117        src_namespaces.len()
1118    );
1119
1120    // Combine: source namespaces first (so they're registered before tests require them),
1121    // then test namespaces. Deduplicate in case of overlap.
1122    let mut all_namespaces = Vec::new();
1123    let mut seen = std::collections::HashSet::new();
1124    for ns in src_namespaces.iter().chain(test_namespaces.iter()) {
1125        if seen.insert(ns.clone()) {
1126            all_namespaces.push(ns.clone());
1127        }
1128    }
1129
1130    // Generate registration code for bundled sources
1131    let mut bundled_registration = String::new();
1132    for (i, ns) in all_namespaces.iter().enumerate() {
1133        bundled_registration.push_str(&format!(
1134            "    globals.register_builtin_source(\"{ns}\", include_str!(\"bundled_{i}.cljrs\"));\n"
1135        ));
1136    }
1137
1138    // Create the harness directory
1139    let harness_dir = out_path
1140        .parent()
1141        .unwrap_or(Path::new("."))
1142        .join(".cljrs-aot-test-harness");
1143
1144    // Clean any previous harness.
1145    if harness_dir.exists() {
1146        std::fs::remove_dir_all(&harness_dir)?;
1147    }
1148    std::fs::create_dir_all(harness_dir.join("src"))?;
1149
1150    // Generate the main.rs file (only test namespaces get run-tests called)
1151    let main_rs = generate_test_harness_code(&test_namespaces, &bundled_registration);
1152    std::fs::write(harness_dir.join("src/main.rs"), &main_rs)?;
1153
1154    // Write all namespace sources for bundling
1155    // Include test_dir as a search path for test sources
1156    let mut search_dirs = src_dirs.to_vec();
1157    search_dirs.push(test_dir.to_path_buf());
1158
1159    for (i, ns) in all_namespaces.iter().enumerate() {
1160        let rel_path = ns.replace('.', "/").replace('-', "_");
1161        if let Some(src) = find_user_source(&rel_path, &search_dirs) {
1162            std::fs::write(
1163                harness_dir.join("src").join(format!("bundled_{i}.cljrs")),
1164                &src,
1165            )?;
1166            eprintln!("[aot] bundled {ns} → src/bundled_{i}.cljrs");
1167        } else {
1168            return Err(AotError::Eval(format!(
1169                "Could not find source for namespace {ns}"
1170            )));
1171        }
1172    }
1173
1174    // Write Cargo.toml
1175    let workspace_root = find_workspace_root()?;
1176    let cargo_toml = format!(
1177        r#"[package]
1178name = "cljrs-aot-harness"
1179version = "0.1.0"
1180edition = "2021"
1181
1182[workspace]
1183
1184[dependencies]
1185cljrs-types    = {{ path = "{ws}/crates/cljrs-types" }}
1186cljrs-gc       = {{ path = "{ws}/crates/cljrs-gc" }}
1187cljrs-value    = {{ path = "{ws}/crates/cljrs-value" }}
1188cljrs-reader   = {{ path = "{ws}/crates/cljrs-reader" }}
1189cljrs-env      = {{ path = "{ws}/crates/cljrs-env" }}
1190cljrs-eval     = {{ path = "{ws}/crates/cljrs-eval" }}
1191cljrs-stdlib   = {{ path = "{ws}/crates/cljrs-stdlib" }}
1192cljrs-compiler = {{ path = "{ws}/crates/cljrs-compiler" }}
1193"#,
1194        ws = workspace_root.display()
1195    );
1196    std::fs::write(harness_dir.join("Cargo.toml"), cargo_toml)?;
1197
1198    // Write build.rs - minimal, no object file linking needed
1199    let build_rs = r#"fn main() {
1200    // No special linking needed for test harness
1201}"#;
1202    std::fs::write(harness_dir.join("build.rs"), build_rs)?;
1203
1204    // Build with cargo
1205    link_with_cargo_test_harness(&harness_dir, out_path)?;
1206
1207    eprintln!("[aot] wrote {}", out_path.display());
1208    Ok(())
1209}