Skip to main content

cljrs_compiler/
aot.rs

//! AOT compilation driver.
//!
//! Orchestrates the full pipeline from source file to native binary:
//!
//! 1. Parse source → `Vec<Form>`
//! 2. Boot a standard environment (for macro expansion)
//! 3. Macro-expand each top-level form; module-level forms (`ns`, `require`,
//!    `defmacro`, …) are also evaluated so that later forms see their effects
//! 4. ANF-lower the compilable forms as a zero-arg `__cljrs_main` function;
//!    forms that need the interpreter are instead embedded as source text and
//!    evaluated at startup
//! 5. Cranelift codegen → `.o` object bytes
//! 6. Generate a Cargo harness project that links the object + runtime
//! 7. `cargo build --release` the harness → standalone binary
12
13use std::path::{Path, PathBuf};
14use std::sync::Arc;
15
16use cljrs_reader::Parser;
17
18use crate::codegen::Compiler;
19use crate::ir::IrFunction;
20use cljrs_eval::ir_convert;
21
22// ── Error type ──────────────────────────────────────────────────────────────
23
24#[derive(Debug)]
25pub enum AotError {
26    Io(std::io::Error),
27    Parse(cljrs_types::error::CljxError),
28    Codegen(crate::codegen::CodegenError),
29    Eval(String),
30    Link(String),
31    /// One or more no-gc memory-safety violations were found by the blacklist
32    /// analysis.  Only emitted when the `no-gc` Cargo feature is active.
33    #[cfg(feature = "no-gc")]
34    NoGcBlacklist(Vec<crate::escape::BlacklistViolation>),
35}
36
37impl std::fmt::Display for AotError {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match self {
40            AotError::Io(e) => write!(f, "I/O error: {e}"),
41            AotError::Parse(e) => write!(f, "parse error: {e}"),
42            AotError::Codegen(e) => write!(f, "codegen error: {e:?}"),
43            AotError::Eval(e) => write!(f, "eval/lowering error: {e}"),
44            AotError::Link(e) => write!(f, "link error: {e}"),
45            #[cfg(feature = "no-gc")]
46            AotError::NoGcBlacklist(vs) => {
47                writeln!(f, "no-gc blacklist violations:")?;
48                for v in vs {
49                    writeln!(f, "  • {v}")?;
50                }
51                Ok(())
52            }
53        }
54    }
55}
56
57impl std::error::Error for AotError {}
58
59impl From<std::io::Error> for AotError {
60    fn from(e: std::io::Error) -> Self {
61        AotError::Io(e)
62    }
63}
64impl From<cljrs_types::error::CljxError> for AotError {
65    fn from(e: cljrs_types::error::CljxError) -> Self {
66        AotError::Parse(e)
67    }
68}
69impl From<crate::codegen::CodegenError> for AotError {
70    fn from(e: crate::codegen::CodegenError) -> Self {
71        AotError::Codegen(e)
72    }
73}
74
75pub type AotResult<T> = Result<T, AotError>;
76
77// ── Clojure-based lowering ──────────────────────────────────────────────────
78
79/// Lower forms via the Clojure compiler front-end.
80///
81/// 1. Ensures compiler namespaces are loaded
82/// 2. Converts Forms → Values via `form_to_value`
83/// 3. Calls `cljrs.compiler.anf/lower-fn-body` via `callback::invoke`
84/// 4. Converts returned Value → `IrFunction` via `ir_convert`
85pub fn lower_via_clojure(
86    name: Option<&str>,
87    ns: &str,
88    params: &[Arc<str>],
89    compilable_forms: &[cljrs_reader::Form],
90    env: &mut cljrs_eval::Env,
91) -> AotResult<IrFunction> {
92    // Register compiler sources so require can find them.
93    crate::register_compiler_sources(&env.globals);
94
95    // Push eval context so callback::invoke can work.
96    cljrs_env::callback::push_eval_context(env);
97
98    // Set IR_LOWERING_ACTIVE to prevent eager lowering of closures
99    // created inside the Clojure compiler during this lowering call.
100    use cljrs_eval::apply::IR_LOWERING_ACTIVE;
101    let was_active = IR_LOWERING_ACTIVE.with(|c| c.get());
102    IR_LOWERING_ACTIVE.with(|c| c.set(true));
103
104    let result = lower_via_clojure_inner(name, ns, params, compilable_forms, env);
105
106    // Restore lowering flag and pop eval context.
107    IR_LOWERING_ACTIVE.with(|c| c.set(was_active));
108    cljrs_env::callback::pop_eval_context();
109
110    result
111}
112
113fn lower_via_clojure_inner(
114    name: Option<&str>,
115    ns: &str,
116    params: &[Arc<str>],
117    compilable_forms: &[cljrs_reader::Form],
118    env: &mut cljrs_eval::Env,
119) -> AotResult<IrFunction> {
120    use cljrs_gc::GcPtr;
121    use cljrs_value::Value;
122    use cljrs_value::collections::vector::PersistentVector;
123
124    // Ensure the ANF and optimize namespaces are loaded.
125    let span = || cljrs_types::span::Span::new(Arc::new("<aot>".to_string()), 0, 0, 1, 1);
126    for ns_name in &["cljrs.compiler.anf", "cljrs.compiler.optimize"] {
127        let require_form = cljrs_reader::Form::new(
128            cljrs_reader::form::FormKind::List(vec![
129                cljrs_reader::Form::new(
130                    cljrs_reader::form::FormKind::Symbol("require".into()),
131                    span(),
132                ),
133                cljrs_reader::Form::new(
134                    cljrs_reader::form::FormKind::Quote(Box::new(cljrs_reader::Form::new(
135                        cljrs_reader::form::FormKind::Symbol((*ns_name).into()),
136                        span(),
137                    ))),
138                    span(),
139                ),
140            ]),
141            span(),
142        );
143        cljrs_eval::eval(&require_form, env).map_err(|e| AotError::Eval(format!("{e:?}")))?;
144    }
145
146    // Look up the lower-fn-body function.
147    let lower_fn = env
148        .globals
149        .lookup_var_in_ns("cljrs.compiler.anf", "lower-fn-body")
150        .ok_or_else(|| AotError::Eval("cljrs.compiler.anf/lower-fn-body not found".to_string()))?;
151    let lower_fn_val = lower_fn.get().deref().unwrap_or(Value::Nil);
152
153    // Build arguments:
154    // 1. fname (string or nil)
155    let fname_val = match name {
156        Some(n) => Value::string(n.to_string()),
157        None => Value::Nil,
158    };
159
160    // 2. ns (string)
161    let ns_val = Value::string(ns.to_string());
162
163    // 3. params (vector of strings)
164    let params_val = Value::Vector(GcPtr::new(PersistentVector::from_iter(
165        params.iter().map(|p| Value::string(p.to_string())),
166    )));
167
168    // 4. body-forms (vector of form values)
169    let body_forms_val = Value::Vector(GcPtr::new(PersistentVector::from_iter(
170        compilable_forms
171            .iter()
172            .map(cljrs_builtins::form::form_to_value),
173    )));
174
175    // Call the Clojure lowering function.
176    let ir_data = cljrs_env::callback::invoke(
177        &lower_fn_val,
178        vec![fname_val, ns_val, params_val, body_forms_val],
179    )
180    .map_err(|e| AotError::Eval(format!("Clojure lowering failed: {e:?}")))?;
181
182    // Run the optimization pass (escape analysis → region allocation).
183    let optimize_fn = env
184        .globals
185        .lookup_var_in_ns("cljrs.compiler.optimize", "optimize")
186        .ok_or_else(|| AotError::Eval("cljrs.compiler.optimize/optimize not found".to_string()))?;
187    let optimize_fn_val = optimize_fn.get().deref().unwrap_or(Value::Nil);
188
189    let optimized = cljrs_env::callback::invoke(&optimize_fn_val, vec![ir_data])
190        .map_err(|e| AotError::Eval(format!("Optimization failed: {e:?}")))?;
191
192    // Convert the result Value → IrFunction.
193    let ir_func = ir_convert::value_to_ir_function(&optimized)
194        .map_err(|e| AotError::Eval(format!("IR conversion failed: {e}")))?;
195
196    // Under no-gc: run the blacklist analysis and reject programs that contain
197    // patterns the AOT code generator cannot safely compile.
198    #[cfg(feature = "no-gc")]
199    {
200        let violations = crate::escape::check(&ir_func);
201        if !violations.is_empty() {
202            return Err(AotError::NoGcBlacklist(violations));
203        }
204    }
205
206    Ok(ir_func)
207}
208
209// ── Direct call optimization ────────────────────────────────────────────────
210
/// Describes one compiled arity of a function defined in this compilation unit.
#[derive(Clone, Debug)]
struct ArityInfo {
    /// Name of the compiled arity function; used as the `CallDirect` target.
    fn_name: Arc<str>,
    /// Parameter count recorded for this arity in its closure template.
    param_count: usize,
    /// Whether the closure template marks this arity as variadic.
    is_variadic: bool,
}
218
219/// Collect top-level function definitions from an IR function.
220///
221/// Scans for `AllocClosure(...) + DefVar(_, ns, name, closure_var)` patterns
222/// and returns a map from `(ns, name)` → list of arity infos.
223fn collect_defn_arities(
224    ir_func: &IrFunction,
225) -> std::collections::HashMap<(Arc<str>, Arc<str>), Vec<ArityInfo>> {
226    use crate::ir::{ClosureTemplate, Inst, VarId};
227    use std::collections::HashMap;
228
229    let mut closure_templates: HashMap<VarId, ClosureTemplate> = HashMap::new();
230    let mut defns: HashMap<(Arc<str>, Arc<str>), Vec<ArityInfo>> = HashMap::new();
231
232    for block in &ir_func.blocks {
233        for inst in &block.insts {
234            match inst {
235                Inst::AllocClosure(dst, template, captures) if captures.is_empty() => {
236                    // Only consider zero-capture closures (top-level defns).
237                    closure_templates.insert(*dst, template.clone());
238                }
239                Inst::DefVar(_, ns, name, val) => {
240                    if let Some(template) = closure_templates.get(val) {
241                        let arities: Vec<ArityInfo> = template
242                            .arity_fn_names
243                            .iter()
244                            .zip(template.param_counts.iter())
245                            .zip(template.is_variadic.iter())
246                            .map(|((fn_name, &param_count), &is_variadic)| ArityInfo {
247                                fn_name: fn_name.clone(),
248                                param_count,
249                                is_variadic,
250                            })
251                            .collect();
252                        defns.insert((ns.clone(), name.clone()), arities);
253                    }
254                }
255                _ => {}
256            }
257        }
258    }
259
260    defns
261}
262
263/// Find the arity function name that matches a given argument count.
264///
265/// Only matches fixed arities — variadic functions cannot be called directly
266/// because the runtime needs to pack extra args into a rest list.
267fn find_matching_arity(arities: &[ArityInfo], arg_count: usize) -> Option<&ArityInfo> {
268    arities
269        .iter()
270        .find(|arity| !arity.is_variadic && arity.param_count == arg_count)
271}
272
273/// Rewrite `LoadGlobal + Call` sequences to `CallDirect` for functions
274/// defined in the same compilation unit.
275///
276/// This is a perf optimization: instead of going through `rt_call` (which
277/// looks up the var in the interpreter and dispatches dynamically), we call
278/// the compiled function pointer directly.
279fn optimize_direct_calls(ir_func: &mut IrFunction) {
280    // Collect defn arities from this function AND all subfunctions (recursively).
281    let mut all_defns = collect_defn_arities(ir_func);
282    for sub in &ir_func.subfunctions {
283        // Subfunctions don't typically DefVar, but recurse just in case.
284        all_defns.extend(collect_defn_arities(sub));
285    }
286
287    if all_defns.is_empty() {
288        return;
289    }
290
291    let rewrites = rewrite_calls_to_direct(ir_func, &all_defns);
292    if rewrites > 0 {
293        eprintln!("[aot] optimized {rewrites} call(s) to direct function calls");
294    }
295
296    // Recursively optimize subfunctions too.
297    // Subfunctions can call top-level defns, so pass the defn map down.
298    for sub in &mut ir_func.subfunctions {
299        optimize_direct_calls_with_defns(sub, &all_defns);
300    }
301}
302
303/// Like `optimize_direct_calls` but uses a pre-built defn map (for subfunctions).
304fn optimize_direct_calls_with_defns(
305    ir_func: &mut IrFunction,
306    defns: &std::collections::HashMap<(Arc<str>, Arc<str>), Vec<ArityInfo>>,
307) {
308    // Merge parent defns with any defns from this function.
309    let mut all_defns = defns.clone();
310    all_defns.extend(collect_defn_arities(ir_func));
311
312    if all_defns.is_empty() {
313        return;
314    }
315
316    let rewrites = rewrite_calls_to_direct(ir_func, &all_defns);
317    if rewrites > 0 {
318        eprintln!("[aot] optimized {rewrites} direct call(s) in subfunction");
319    }
320
321    for sub in &mut ir_func.subfunctions {
322        optimize_direct_calls_with_defns(sub, &all_defns);
323    }
324}
325
326/// Rewrite `LoadGlobal + Call` → `CallDirect` in a single IR function.
327/// Returns the number of rewrites performed.
328fn rewrite_calls_to_direct(
329    ir_func: &mut IrFunction,
330    defns: &std::collections::HashMap<(Arc<str>, Arc<str>), Vec<ArityInfo>>,
331) -> usize {
332    use crate::ir::{Inst, VarId};
333    use std::collections::HashMap;
334
335    // Build a map of VarId → (ns, name) for LoadGlobal instructions that load known defns.
336    let mut loadglobal_targets: HashMap<VarId, (Arc<str>, Arc<str>)> = HashMap::new();
337    for block in &ir_func.blocks {
338        for inst in &block.insts {
339            if let Inst::LoadGlobal(dst, ns, name) = inst
340                && defns.contains_key(&(ns.clone(), name.clone()))
341            {
342                loadglobal_targets.insert(*dst, (ns.clone(), name.clone()));
343            }
344        }
345    }
346
347    let mut rewrites = 0;
348    for block in &mut ir_func.blocks {
349        for inst in &mut block.insts {
350            if let Inst::Call(dst, callee, args) = inst
351                && let Some((ns, name)) = loadglobal_targets.get(callee)
352                && let Some(arities) = defns.get(&(ns.clone(), name.clone()))
353                && let Some(arity_info) = find_matching_arity(arities, args.len())
354            {
355                *inst = Inst::CallDirect(*dst, arity_info.fn_name.clone(), args.clone());
356                rewrites += 1;
357            }
358        }
359    }
360
361    rewrites
362}
363
364// ── Public API ──────────────────────────────────────────────────────────────
365
366/// Compile a `.cljrs` / `.cljc` source file to a standalone native binary.
367///
368/// `src_path` is the input source file.  `out_path` is the desired output
369/// binary.  `src_dirs` are additional directories for `require` resolution
370/// during macro expansion.
371pub fn compile_file(src_path: &Path, out_path: &Path, src_dirs: &[PathBuf]) -> AotResult<()> {
372    eprintln!("[aot] reading {}", src_path.display());
373    let source = std::fs::read_to_string(src_path)?;
374    let filename = src_path.display().to_string();
375
376    // ── 1. Parse ────────────────────────────────────────────────────────
377    let mut parser = Parser::new(source.clone(), filename);
378    let forms = parser.parse_all()?;
379    eprintln!("[aot] parsed {} top-level form(s)", forms.len());
380
381    // ── 2. Macro-expand ─────────────────────────────────────────────────
382    // Boot a full environment so macros resolve correctly.
383    let globals = if src_dirs.is_empty() {
384        cljrs_stdlib::standard_env()
385    } else {
386        cljrs_stdlib::standard_env_with_paths(src_dirs.to_vec())
387    };
388    let mut env = cljrs_eval::Env::new(globals, "user");
389
390    // Snapshot loaded namespaces before expansion so we can detect
391    // which user namespaces were pulled in by require.
392    let pre_loaded: std::collections::HashSet<Arc<str>> =
393        env.globals.loaded.lock().unwrap().clone();
394
395    let mut expanded = Vec::with_capacity(forms.len());
396    for form in &forms {
397        // For forms that need the interpreter (ns, require, defmacro, etc.),
398        // evaluate them immediately so that required namespaces get loaded
399        // and macros from dependencies are available for later forms.
400        if needs_interpreter(form) {
401            match cljrs_eval::eval(form, &mut env) {
402                Ok(_) => {}
403                Err(e) => return Err(AotError::Eval(format!("{e:?}"))),
404            }
405        }
406        match cljrs_interp::macros::macroexpand_all(form, &mut env) {
407            Ok(f) => expanded.push(f),
408            Err(e) => return Err(AotError::Eval(format!("{e:?}"))),
409        }
410    }
411    eprintln!("[aot] macro-expanded {} form(s)", expanded.len());
412
413    // Discover user namespaces loaded during expansion (transitive deps).
414    let bundled_sources = discover_bundled_sources(&env.globals, &pre_loaded, src_dirs);
415    if !bundled_sources.is_empty() {
416        eprintln!(
417            "[aot] bundling {} required namespace(s): {}",
418            bundled_sources.len(),
419            bundled_sources
420                .iter()
421                .map(|(ns, _)| ns.as_ref())
422                .collect::<Vec<_>>()
423                .join(", ")
424        );
425    }
426
427    // ── 2b. Partition: interpreted preamble vs compiled body ─────────
428    // Forms that define functions (defn, defmacro) or require interpreter
429    // features (closures) are evaluated at startup via the interpreter.
430    // The rest is AOT-compiled.
431    let mut interpreted_source = String::new();
432    let mut compilable = Vec::new();
433    for (i, form) in expanded.iter().enumerate() {
434        if needs_interpreter(&forms[i]) || expanded_needs_interpreter(form) {
435            // Extract the original source text using span byte offsets.
436            let span = &forms[i].span;
437            let src_text = &source[span.start..span.end];
438            interpreted_source.push_str(src_text);
439            interpreted_source.push('\n');
440        } else {
441            compilable.push(form.clone());
442        }
443    }
444    if !interpreted_source.is_empty() {
445        eprintln!(
446            "[aot] {} form(s) will be interpreted at startup",
447            expanded.len() - compilable.len()
448        );
449    }
450
451    // ── 3. ANF-lower ────────────────────────────────────────────────────
452    // Treat compilable top-level forms as the body of a zero-arg `__cljrs_main`.
453    let params: Vec<Arc<str>> = vec![];
454    let compilable_forms = if compilable.is_empty() {
455        // If everything is interpreted, emit a simple nil-returning main.
456        let nil_form = cljrs_reader::Form::new(
457            cljrs_reader::form::FormKind::Nil,
458            cljrs_types::span::Span::new(Arc::new("<aot>".to_string()), 0, 0, 1, 1),
459        );
460        vec![nil_form]
461    } else {
462        compilable
463    };
464
465    // Use the current namespace (may have been changed by ns form in preamble).
466    let current_ns = env.current_ns.to_string();
467    let mut ir_func = lower_via_clojure(
468        Some("__cljrs_main"),
469        &current_ns,
470        &params,
471        &compilable_forms,
472        &mut env,
473    )?;
474    eprintln!(
475        "[aot] lowered to {} block(s), {} var(s)",
476        ir_func.blocks.len(),
477        ir_func.next_var
478    );
479
480    // ── 3b. Direct call optimization ────────────────────────────────────
481    // Rewrite `LoadGlobal + Call` → `CallDirect` for functions defined in
482    // the same compilation unit. This avoids going through rt_call.
483    optimize_direct_calls(&mut ir_func);
484
485    // ── 4. Cranelift codegen → .o ───────────────────────────────────────
486    let mut compiler = Compiler::new()?;
487
488    // Declare all subfunctions first (so they can reference each other).
489    declare_subfunctions(&ir_func, &mut compiler)?;
490
491    // Compile subfunctions before the main function.
492    compile_subfunctions(&ir_func, &mut compiler)?;
493
494    let func_id = compiler.declare_function("__cljrs_main", 0)?;
495    compiler.compile_function(&ir_func, func_id)?;
496    let obj_bytes = compiler.finish();
497    eprintln!("[aot] generated {} bytes of object code", obj_bytes.len());
498
499    // ── 5. Generate harness project & build ─────────────────────────────
500    let harness_dir = build_harness(out_path, &obj_bytes, &interpreted_source, &bundled_sources)?;
501    link_with_cargo(&harness_dir, out_path)?;
502
503    eprintln!("[aot] wrote {}", out_path.display());
504    Ok(())
505}
506
507/// Check if a top-level form needs the interpreter (can't be AOT-compiled yet).
508fn needs_interpreter(form: &cljrs_reader::Form) -> bool {
509    use cljrs_reader::form::FormKind;
510    match &form.kind {
511        FormKind::List(parts) => {
512            if let Some(head) = parts.first()
513                && let FormKind::Symbol(s) = &head.kind
514            {
515                // defmacro/defonce need the interpreter (macros must be
516                // available at compile time). ns/require are module-level.
517                // Protocol/multimethod forms modify global dispatch tables
518                // and are best handled by the interpreter at startup.
519                return matches!(
520                    s.as_str(),
521                    "defmacro"
522                        | "defonce"
523                        | "ns"
524                        | "require"
525                        | "defprotocol"
526                        | "extend-type"
527                        | "extend-protocol"
528                        | "defmulti"
529                        | "defmethod"
530                        | "defrecord"
531                );
532            }
533            false
534        }
535        _ => false,
536    }
537}
538
/// Return `true` when `s` names one of the metadata functions that only the
/// interpreter supports.
///
/// Any namespace qualifier (everything up to the last `/`) is ignored, so
/// `clojure.core/vary-meta` and `vary-meta` are treated alike.
fn is_interpreter_only_sym(s: &str) -> bool {
    const META_FNS: [&str; 4] = ["alter-meta!", "vary-meta", "reset-meta!", "with-meta"];

    let unqualified = match s.rsplit_once('/') {
        Some((_, tail)) => tail,
        None => s,
    };
    META_FNS.contains(&unqualified)
}
549
550/// Check the macro-expanded form for constructs that the AOT compiler
551/// cannot handle (e.g. alter-meta!, vary-meta). This recurses
552/// into the form tree so that e.g. `(do (def x ...) (alter-meta! ...))` is caught.
553fn expanded_needs_interpreter(form: &cljrs_reader::Form) -> bool {
554    use cljrs_reader::form::FormKind;
555    match &form.kind {
556        FormKind::List(parts) => {
557            if let Some(head) = parts.first()
558                && let FormKind::Symbol(s) = &head.kind
559                && is_interpreter_only_sym(s)
560            {
561                return true;
562            }
563            parts.iter().any(expanded_needs_interpreter)
564        }
565        FormKind::Vector(elems) | FormKind::Set(elems) => {
566            elems.iter().any(expanded_needs_interpreter)
567        }
568        FormKind::Map(elems) => elems.iter().any(expanded_needs_interpreter),
569        _ => false,
570    }
571}
572
573// ── Subfunction compilation ─────────────────────────────────────────────────
574
575/// Recursively declare all subfunctions in the compiler module.
576fn declare_subfunctions(ir_func: &IrFunction, compiler: &mut Compiler) -> AotResult<()> {
577    for sub in &ir_func.subfunctions {
578        let name = sub.name.as_deref().unwrap_or("__cljrs_anon");
579        compiler.declare_function(name, sub.params.len())?;
580        declare_subfunctions(sub, compiler)?;
581    }
582    Ok(())
583}
584
585/// Recursively compile all subfunctions.
586fn compile_subfunctions(ir_func: &IrFunction, compiler: &mut Compiler) -> AotResult<()> {
587    for sub in &ir_func.subfunctions {
588        compile_subfunctions(sub, compiler)?;
589        let name = sub.name.as_deref().unwrap_or("__cljrs_anon");
590        let func_id = compiler.declare_function(name, sub.params.len())?;
591        compiler.compile_function(sub, func_id)?;
592    }
593    Ok(())
594}
595
596// ── Bundled source discovery ─────────────────────────────────────────────────
597
598/// Discover user namespaces loaded during macro expansion that need to be
599/// bundled into the compiled binary.
600///
601/// Compares the set of loaded namespaces before and after expansion. For each
602/// newly loaded namespace that isn't a builtin source, resolves and reads
603/// its source file from `src_dirs`.
604fn discover_bundled_sources(
605    globals: &Arc<cljrs_env::env::GlobalEnv>,
606    pre_loaded: &std::collections::HashSet<Arc<str>>,
607    src_dirs: &[PathBuf],
608) -> Vec<(Arc<str>, String)> {
609    let post_loaded = globals.loaded.lock().unwrap().clone();
610    let mut bundled = Vec::new();
611
612    for ns in post_loaded.difference(pre_loaded) {
613        // Skip namespaces that are already available as builtins at runtime.
614        if globals.builtin_source(ns).is_some() {
615            continue;
616        }
617        // Skip compiler-internal namespaces.
618        if ns.starts_with("cljrs.compiler.") {
619            continue;
620        }
621        // Resolve the source file from src_dirs.
622        let rel_path = ns.replace('.', "/").replace('-', "_");
623        if let Some(src) = find_user_source(&rel_path, src_dirs) {
624            bundled.push((ns.clone(), src));
625        }
626    }
627
628    bundled
629}
630
/// Find and read a user source file from the given directories.
///
/// `rel` is a path relative to each directory, without extension (e.g.
/// `my/lib` for namespace `my.lib`).  Directories are searched in order, and
/// within each directory `.cljrs` is tried before `.cljc`.  The first
/// candidate that is a regular file wins.  Its contents are returned, or
/// `None` if reading it fails.
///
/// Candidates are checked with `is_file` rather than `exists`.  With
/// `exists`, a *directory* that happens to be named e.g. `lib.cljrs` would
/// end the search with `None`, even if a later directory holds the real file.
fn find_user_source(rel: &str, src_dirs: &[PathBuf]) -> Option<String> {
    for dir in src_dirs {
        for ext in &[".cljrs", ".cljc"] {
            let path = dir.join(format!("{rel}{ext}"));
            if path.is_file() {
                return std::fs::read_to_string(&path).ok();
            }
        }
    }
    None
}
643
644// ── Harness generation ──────────────────────────────────────────────────────
645
646/// Create a temporary Cargo project that links the compiled object code with
647/// the clojurust runtime and produces a binary.
648fn build_harness(
649    out_path: &Path,
650    obj_bytes: &[u8],
651    interpreted_source: &str,
652    bundled_sources: &[(Arc<str>, String)],
653) -> AotResult<PathBuf> {
654    // Place the harness in a temp dir next to the output.
655    let harness_dir = out_path
656        .parent()
657        .unwrap_or(Path::new("."))
658        .join(".cljrs-aot-harness");
659
660    // Clean any previous harness.
661    if harness_dir.exists() {
662        std::fs::remove_dir_all(&harness_dir)?;
663    }
664    std::fs::create_dir_all(harness_dir.join("src"))?;
665
666    // Write the object file.
667    let obj_path = harness_dir.join("__cljrs_main.o");
668    std::fs::write(&obj_path, obj_bytes)?;
669
670    // Find the workspace root (where the top-level Cargo.toml lives).
671    let workspace_root = find_workspace_root()?;
672
673    // Write Cargo.toml.
674    // The empty [workspace] table prevents Cargo from thinking this is
675    // part of a parent workspace.
676    let cargo_toml = format!(
677        r#"[package]
678name = "cljrs-aot-harness"
679version = "0.1.0"
680edition = "2024"
681
682[workspace]
683
684[dependencies]
685cljrs-types    = {{ path = "{ws}/crates/cljrs-types" }}
686cljrs-gc       = {{ path = "{ws}/crates/cljrs-gc" }}
687cljrs-value    = {{ path = "{ws}/crates/cljrs-value" }}
688cljrs-reader   = {{ path = "{ws}/crates/cljrs-reader" }}
689cljrs-env      = {{ path = "{ws}/crates/cljrs-env" }}
690cljrs-eval     = {{ path = "{ws}/crates/cljrs-eval" }}
691cljrs-stdlib   = {{ path = "{ws}/crates/cljrs-stdlib" }}
692cljrs-compiler = {{ path = "{ws}/crates/cljrs-compiler" }}
693
694[build-dependencies]
695cc = "1"
696"#,
697        ws = workspace_root.display()
698    );
699    std::fs::write(harness_dir.join("Cargo.toml"), cargo_toml)?;
700
701    // Write build.rs — tells Cargo to link our object file.
702    let obj_abs = std::fs::canonicalize(&obj_path)?;
703    let build_rs = format!(
704        r#"fn main() {{
705    // Link the AOT-compiled object file.
706    println!("cargo:rustc-link-arg={obj}");
707    println!("cargo:rerun-if-changed={obj}");
708}}"#,
709        obj = obj_abs.display()
710    );
711    std::fs::write(harness_dir.join("build.rs"), build_rs)?;
712
713    // Write the interpreted preamble source (if any).
714    let has_preamble = !interpreted_source.is_empty();
715    if has_preamble {
716        std::fs::write(harness_dir.join("src/preamble.cljrs"), interpreted_source)?;
717    }
718
719    // Write bundled dependency sources.
720    for (i, (ns, src)) in bundled_sources.iter().enumerate() {
721        let filename = format!("bundled_{i}.cljrs");
722        std::fs::write(harness_dir.join("src").join(&filename), src)?;
723        eprintln!("[aot] bundled {ns} → src/{filename}");
724    }
725
726    // Generate registration code for bundled sources.
727    let mut bundled_registration = String::new();
728    for (i, (ns, _)) in bundled_sources.iter().enumerate() {
729        bundled_registration.push_str(&format!(
730            "    globals.register_builtin_source(\"{ns}\", \
731             include_str!(\"bundled_{i}.cljrs\"));\n"
732        ));
733    }
734
735    // Write main.rs — calls into the compiled __cljrs_main.
736    let preamble_code = if has_preamble {
737        r#"
738    // Evaluate interpreted preamble (ns, require, defn, defmacro, etc.).
739    let preamble = include_str!("preamble.cljrs");
740    let mut parser = cljrs_reader::Parser::new(preamble.to_string(), "<preamble>".to_string());
741    let forms = parser.parse_all().expect("preamble parse error");
742    for form in &forms {
743        cljrs_eval::eval(form, &mut env).expect("preamble eval error");
744    }
745    // Re-push eval context with updated namespace (ns form may have changed it).
746    cljrs_env::callback::pop_eval_context();
747    cljrs_env::callback::push_eval_context(&env);
748"#
749    } else {
750        ""
751    };
752
753    let main_rs = format!(
754        r#"//! Auto-generated AOT harness for clojurust.
755//!
756//! Initializes the runtime, then calls the compiled `__cljrs_main`.
757
758#![allow(improper_ctypes)]
759
760use cljrs_value::Value;
761
762unsafe extern "C" {{
763    fn __cljrs_main() -> *const Value;
764}}
765
766fn main() {{
767    // Ensure all rt_* symbols are linked into the binary.
768    cljrs_compiler::rt_abi::anchor_rt_symbols();
769
770    // Initialize the standard environment so that rt_call and other
771    // runtime bridge functions can look up builtins.
772    let globals = cljrs_stdlib::standard_env();
773
774    // Register bundled dependency sources so require can find them
775    // without needing source files on disk.
776{bundled}
777    let mut env = cljrs_eval::Env::new(globals, "user");
778
779    // Push an eval context so rt_call can dispatch through the interpreter.
780    cljrs_env::callback::push_eval_context(&env);
781{preamble}
782    // Call the compiled code.
783    let _result = unsafe {{ __cljrs_main() }};
784
785    // Pop the eval context.
786    cljrs_env::callback::pop_eval_context();
787}}
788"#,
789        preamble = preamble_code,
790        bundled = bundled_registration
791    );
792    std::fs::write(harness_dir.join("src/main.rs"), main_rs)?;
793
794    Ok(harness_dir)
795}
796
797/// Build the harness with Cargo and copy the resulting binary to `out_path`.
798fn link_with_cargo(harness_dir: &Path, out_path: &Path) -> AotResult<()> {
799    eprintln!("[aot] building harness with cargo...");
800
801    let output = std::process::Command::new("cargo")
802        .arg("build")
803        .arg("--release")
804        .arg("--offline")
805        .current_dir(harness_dir)
806        .output()?;
807
808    if !output.status.success() {
809        let stderr = String::from_utf8_lossy(&output.stderr);
810        return Err(AotError::Link(format!("cargo build failed:\n{stderr}")));
811    }
812
813    // The binary is at target/release/cljrs-aot-harness.
814    let bin_name = if cfg!(target_os = "windows") {
815        "cljrs-aot-harness.exe"
816    } else {
817        "cljrs-aot-harness"
818    };
819    let built = harness_dir.join("target/release").join(bin_name);
820    std::fs::copy(&built, out_path)?;
821
822    // Clean up the harness directory.
823    let _ = std::fs::remove_dir_all(harness_dir);
824
825    Ok(())
826}
827
828/// Build the harness with Cargo and copy the resulting binary to `out_path`.
829/// Keeps the harness directory for debugging test harnesses.
830fn link_with_cargo_test_harness(harness_dir: &Path, out_path: &Path) -> AotResult<()> {
831    eprintln!("[aot] building harness with cargo...");
832
833    let output = std::process::Command::new("cargo")
834        .arg("build")
835        .arg("--release")
836        .arg("--offline")
837        .current_dir(harness_dir)
838        .output()?;
839
840    if !output.status.success() {
841        let stderr = String::from_utf8_lossy(&output.stderr);
842        return Err(AotError::Link(format!("cargo build failed:\n{stderr}")));
843    }
844
845    // The binary is at target/release/cljrs-aot-harness.
846    let bin_name = if cfg!(target_os = "windows") {
847        "cljrs-aot-harness.exe"
848    } else {
849        "cljrs-aot-harness"
850    };
851    let built = harness_dir.join("target/release").join(bin_name);
852    std::fs::copy(&built, out_path)?;
853
854    // Keep the harness directory for debugging
855    eprintln!("[aot] harness directory kept at {}", harness_dir.display());
856
857    Ok(())
858}
859
860/// Walk up from the current directory to find the workspace root
861/// (the directory containing Cargo.toml with [workspace]).
862fn find_workspace_root() -> AotResult<PathBuf> {
863    let mut dir = std::env::current_dir()?;
864    loop {
865        let cargo_toml = dir.join("Cargo.toml");
866        if cargo_toml.exists() {
867            let contents = std::fs::read_to_string(&cargo_toml)?;
868            if contents.contains("[workspace") {
869                return Ok(dir);
870            }
871        }
872        if !dir.pop() {
873            return Err(AotError::Link(
874                "could not find workspace root (no Cargo.toml with [workspace])".to_string(),
875            ));
876        }
877    }
878}
879
880// ── Test harness generation ─────────────────────────────────────────────────
881
882/// Discover test namespaces from a directory of test files.
883/// Returns a sorted list of namespace names.
884fn discover_test_namespaces(test_dir: &Path, src_dirs: &[PathBuf]) -> AotResult<Vec<String>> {
885    let mut namespaces = Vec::new();
886
887    // First, try to discover from the test directory directly
888    if test_dir.is_dir() {
889        discover_in_dir(test_dir, test_dir, &mut namespaces);
890    }
891
892    // If no tests found in test_dir, also search src_dirs
893    if namespaces.is_empty() {
894        for dir in src_dirs {
895            if dir.is_dir() {
896                discover_in_dir(dir, dir, &mut namespaces);
897            }
898        }
899    }
900
901    namespaces.sort();
902    Ok(namespaces)
903}
904
905/// Discover all namespace names from `.cljc` / `.cljrs` files in the given source paths.
906fn discover_in_dir(root: &Path, dir: &Path, out: &mut Vec<String>) {
907    let Ok(entries) = std::fs::read_dir(dir) else {
908        return;
909    };
910    let mut entries: Vec<_> = entries.filter_map(|e| e.ok()).collect();
911    entries.sort_by_key(|e| e.file_name());
912    for entry in entries {
913        let path = entry.path();
914        if path.is_dir() {
915            discover_in_dir(root, &path, out);
916        } else if let Some(ext) = path.extension()
917            && (ext == "cljc" || ext == "cljrs")
918            && let Some(ns) = file_to_namespace(root, &path)
919        {
920            out.push(ns);
921        }
922    }
923}
924
/// Derive the Clojure namespace name for `file` from its path relative to
/// `root`.
///
/// The extension is dropped, path separators become `.`, and underscores
/// become `-`. For example, `test/clojure/core_test/juxt.cljc` under root
/// `test/` maps to `clojure.core-test.juxt`.
///
/// Returns `None` if `file` does not lie under `root`.
fn file_to_namespace(root: &Path, file: &Path) -> Option<String> {
    let relative = match file.strip_prefix(root) {
        Ok(rel) => rel,
        Err(_) => return None,
    };

    let dotted = relative
        .with_extension("")
        .to_string_lossy()
        .replace(std::path::MAIN_SEPARATOR, ".");

    Some(dotted.replace('_', "-"))
}
936
/// Generate the `src/main.rs` source of the AOT test-harness binary.
///
/// The generated program does the following:
/// - boots the standard environment;
/// - splices in `bundled_registration` (Rust statements, one per line,
///   indented four spaces, operating on `globals`);
/// - loads `clojure.test` and every namespace in `namespaces`;
/// - calls `clojure.test/run-tests` on each namespace that loaded;
/// - prints an aggregated summary.
///
/// The generated binary exits with status 1 if any assertion failed or
/// errored. It also exits with status 1 if a test namespace could not be
/// loaded or its `run-tests` call failed or returned a non-map, so a broken
/// test file cannot produce a passing run.
///
/// Namespace names are emitted with `{:?}` formatting, which yields a
/// correctly escaped Rust string literal. Unusual file names therefore
/// cannot corrupt the generated source. An empty `namespaces` slice still
/// produces compilable code, because the table has an explicit type.
fn generate_test_harness_code(namespaces: &[String], bundled_registration: &str) -> String {
    // Module header, imports, and the opening of the namespace table.
    const PRELUDE: &str = r#"//! Auto-generated AOT test harness for clojurust.
//!
//! Discovers and runs all clojure.test tests in the bundled namespaces.

use cljrs_value::Value;

/// Namespaces whose clojure.test tests are run, in order.
const TEST_NAMESPACES: &[&str] = &[
"#;

    // Closes the table, opens `main`, and boots the standard environment.
    // The bundled-source registration lines are spliced in right after this.
    const MAIN_START: &str = r#"];

fn main() {
    // Initialize the standard environment.
    let globals = cljrs_stdlib::standard_env();

    // Register bundled dependency sources so require can find them
    // without needing source files on disk.
"#;

    // The rest of `main`: load namespaces, run the tests, report, exit.
    const MAIN_REST: &str = r#"    let mut env = cljrs_eval::Env::new(globals, "user");

    // Push an eval context so rt_call can dispatch through the interpreter.
    cljrs_env::callback::push_eval_context(&env);

    // Problems with the harness run itself: test namespaces that fail to
    // load, and run-tests calls that fail or return something unexpected.
    // Any of these fails the run.
    let mut harness_failures = 0i64;

    // Load clojure.test if not already loaded. If this fails, the run-tests
    // calls below fail as well and are counted there, so only report it.
    let forms = cljrs_reader::Parser::new(
        "(require 'clojure.test)".to_string(),
        "<test-harness>".to_string(),
    )
    .parse_all()
    .unwrap();
    if let Err(e) = cljrs_eval::eval(&forms[0], &mut env) {
        eprintln!("[test-harness] failed to load clojure.test: {e:?}");
    }

    // Load all test namespaces, remembering which ones loaded successfully.
    let mut loaded: Vec<&str> = Vec::new();
    for &ns in TEST_NAMESPACES {
        let forms = cljrs_reader::Parser::new(
            format!("(require '{ns})"),
            "<test-harness>".to_string(),
        )
        .parse_all()
        .unwrap();
        match cljrs_eval::eval(&forms[0], &mut env) {
            Ok(_) => loaded.push(ns),
            Err(e) => {
                eprintln!("[test-harness] failed to load namespace {ns}: {e:?}");
                harness_failures += 1;
            }
        }
    }

    // Run tests for each namespace separately and sum the counters from
    // each run-tests result map.
    let mut total_pass = 0i64;
    let mut total_fail = 0i64;
    let mut total_error = 0i64;
    let mut total_test_count = 0i64;

    for ns in loaded {
        let forms = cljrs_reader::Parser::new(
            format!("(clojure.test/run-tests '{ns})"),
            "<run-tests>".to_string(),
        )
        .parse_all()
        .unwrap();
        match cljrs_eval::eval(&forms[0], &mut env) {
            Ok(Value::Map(m)) => {
                let mut pass = 0i64;
                let mut fail = 0i64;
                let mut error = 0i64;
                let mut test_count = 0i64;
                m.for_each(|k, v| {
                    if let (Value::Keyword(kw), Value::Long(count)) = (k, v) {
                        match kw.get().name.as_ref() {
                            "pass" => pass = *count,
                            "fail" => fail = *count,
                            "error" => error = *count,
                            "test" => test_count = *count,
                            _ => {}
                        }
                    }
                });
                total_pass += pass;
                total_fail += fail;
                total_error += error;
                total_test_count += test_count;
            }
            Ok(_) => {
                eprintln!("[test-harness] run-tests for {ns} did not return a result map");
                harness_failures += 1;
            }
            Err(e) => {
                eprintln!("[test-harness] run-tests failed for {ns}: {e:?}");
                harness_failures += 1;
            }
        }
    }

    // Pop the eval context; the interpreter is not used after this point.
    cljrs_env::callback::pop_eval_context();

    // Flush test output before printing the summary.
    std::io::Write::flush(&mut std::io::stdout()).unwrap();
    println!("Ran {} tests containing {} assertions.", total_test_count, total_pass + total_fail + total_error);
    println!("{} passed, {} failed, {} errors.", total_pass, total_fail, total_error);
    if harness_failures > 0 {
        println!("{} test namespace(s) failed to load or run.", harness_failures);
    }
    // `process::exit` skips destructors, so flush explicitly before exiting.
    std::io::Write::flush(&mut std::io::stdout()).unwrap();
    if total_fail > 0 || total_error > 0 || harness_failures > 0 {
        std::process::exit(1);
    }
}
"#;

    // One escaped Rust string literal per namespace, one per line.
    let ns_literals: String = namespaces
        .iter()
        .map(|ns| format!("    {ns:?},\n"))
        .collect();

    let mut code = String::with_capacity(
        PRELUDE.len()
            + ns_literals.len()
            + MAIN_START.len()
            + bundled_registration.len()
            + MAIN_REST.len(),
    );
    code.push_str(PRELUDE);
    code.push_str(&ns_literals);
    code.push_str(MAIN_START);
    code.push_str(bundled_registration);
    code.push_str(MAIN_REST);
    code
}
1056
1057/// Compile a directory of test files to a standalone native binary.
1058/// The resulting binary will discover and run all clojure.test tests found.
1059pub fn compile_test_harness(
1060    test_dir: &Path,
1061    out_path: &Path,
1062    src_dirs: &[PathBuf],
1063) -> AotResult<()> {
1064    eprintln!("[aot] discovering tests in {}", test_dir.display());
1065
1066    // Discover test namespaces
1067    let test_namespaces = discover_test_namespaces(test_dir, src_dirs)?;
1068    if test_namespaces.is_empty() {
1069        return Err(AotError::Eval(format!(
1070            "No test files found in {}",
1071            test_dir.display()
1072        )));
1073    }
1074    eprintln!(
1075        "[aot] discovered {} test namespace(s)",
1076        test_namespaces.len()
1077    );
1078
1079    // Also discover source namespaces from src_dirs so they get bundled
1080    let mut src_namespaces = Vec::new();
1081    for dir in src_dirs {
1082        if dir.is_dir() {
1083            discover_in_dir(dir, dir, &mut src_namespaces);
1084        }
1085    }
1086    src_namespaces.sort();
1087    eprintln!(
1088        "[aot] discovered {} source namespace(s)",
1089        src_namespaces.len()
1090    );
1091
1092    // Combine: source namespaces first (so they're registered before tests require them),
1093    // then test namespaces. Deduplicate in case of overlap.
1094    let mut all_namespaces = Vec::new();
1095    let mut seen = std::collections::HashSet::new();
1096    for ns in src_namespaces.iter().chain(test_namespaces.iter()) {
1097        if seen.insert(ns.clone()) {
1098            all_namespaces.push(ns.clone());
1099        }
1100    }
1101
1102    // Generate registration code for bundled sources
1103    let mut bundled_registration = String::new();
1104    for (i, ns) in all_namespaces.iter().enumerate() {
1105        bundled_registration.push_str(&format!(
1106            "    globals.register_builtin_source(\"{ns}\", include_str!(\"bundled_{i}.cljrs\"));\n"
1107        ));
1108    }
1109
1110    // Create the harness directory
1111    let harness_dir = out_path
1112        .parent()
1113        .unwrap_or(Path::new("."))
1114        .join(".cljrs-aot-test-harness");
1115
1116    // Clean any previous harness.
1117    if harness_dir.exists() {
1118        std::fs::remove_dir_all(&harness_dir)?;
1119    }
1120    std::fs::create_dir_all(harness_dir.join("src"))?;
1121
1122    // Generate the main.rs file (only test namespaces get run-tests called)
1123    let main_rs = generate_test_harness_code(&test_namespaces, &bundled_registration);
1124    std::fs::write(harness_dir.join("src/main.rs"), &main_rs)?;
1125
1126    // Write all namespace sources for bundling
1127    // Include test_dir as a search path for test sources
1128    let mut search_dirs = src_dirs.to_vec();
1129    search_dirs.push(test_dir.to_path_buf());
1130
1131    for (i, ns) in all_namespaces.iter().enumerate() {
1132        let rel_path = ns.replace('.', "/").replace('-', "_");
1133        if let Some(src) = find_user_source(&rel_path, &search_dirs) {
1134            std::fs::write(
1135                harness_dir.join("src").join(format!("bundled_{i}.cljrs")),
1136                &src,
1137            )?;
1138            eprintln!("[aot] bundled {ns} → src/bundled_{i}.cljrs");
1139        } else {
1140            return Err(AotError::Eval(format!(
1141                "Could not find source for namespace {ns}"
1142            )));
1143        }
1144    }
1145
1146    // Write Cargo.toml
1147    let workspace_root = find_workspace_root()?;
1148    let cargo_toml = format!(
1149        r#"[package]
1150name = "cljrs-aot-harness"
1151version = "0.1.0"
1152edition = "2021"
1153
1154[workspace]
1155
1156[dependencies]
1157cljrs-types    = {{ path = "{ws}/crates/cljrs-types" }}
1158cljrs-gc       = {{ path = "{ws}/crates/cljrs-gc" }}
1159cljrs-value    = {{ path = "{ws}/crates/cljrs-value" }}
1160cljrs-reader   = {{ path = "{ws}/crates/cljrs-reader" }}
1161cljrs-env      = {{ path = "{ws}/crates/cljrs-env" }}
1162cljrs-eval     = {{ path = "{ws}/crates/cljrs-eval" }}
1163cljrs-stdlib   = {{ path = "{ws}/crates/cljrs-stdlib" }}
1164cljrs-compiler = {{ path = "{ws}/crates/cljrs-compiler" }}
1165"#,
1166        ws = workspace_root.display()
1167    );
1168    std::fs::write(harness_dir.join("Cargo.toml"), cargo_toml)?;
1169
1170    // Write build.rs - minimal, no object file linking needed
1171    let build_rs = r#"fn main() {
1172    // No special linking needed for test harness
1173}"#;
1174    std::fs::write(harness_dir.join("build.rs"), build_rs)?;
1175
1176    // Build with cargo
1177    link_with_cargo_test_harness(&harness_dir, out_path)?;
1178
1179    eprintln!("[aot] wrote {}", out_path.display());
1180    Ok(())
1181}