Skip to main content

cljrs_ir/
lib.rs

1//! Intermediate representation for clojurust program analysis and optimization.
2//!
3//! The IR is a control-flow graph of basic blocks containing instructions in
4//! A-normal form (all sub-expressions bound to named temporaries). It supports
5//! SSA construction via phi nodes at join points.
6//!
7//! The IR serves multiple purposes:
8//! 1. **Escape analysis** and optimization hints
9//! 2. **IR interpreter** (Tier 1 execution)
10//! 3. **Cranelift-based JIT/AOT code generation** (Tier 2 execution)
11
12#![allow(clippy::result_large_err)]
13
14pub mod lower;
15
16use cljrs_types::error::CljxError::SerializationError;
17use cljrs_types::error::CljxResult;
18use cljrs_types::span::Span;
19use serde::{Deserialize, Serialize};
20use std::cell::RefCell;
21use std::collections::HashMap;
22use std::fmt;
23use std::fmt::Display;
24use std::sync::Arc;
25
26// Display helpers
27
thread_local! {
    /// Current indentation depth (in spaces) shared by the debug `Display`
    /// impls on this thread. Adjusted in steps of two by `indent_inc` /
    /// `indent_dec`.
    static INDENT: RefCell<i16> = const { RefCell::new(0) }
}

/// Build the indentation prefix for the current depth: one space per unit.
/// A zero or negative depth yields an empty string.
fn indent() -> String {
    let depth = INDENT.with(|cell| *cell.borrow());
    // Clamp first so a negative depth maps to "no indentation" rather than
    // wrapping around in the cast.
    " ".repeat(depth.max(0) as usize)
}

/// Make subsequent `indent()` calls on this thread two spaces deeper.
fn indent_inc() {
    INDENT.with(|cell| {
        let mut depth = cell.borrow_mut();
        *depth += 2;
    });
}

/// Undo one `indent_inc`, making `indent()` two spaces shallower.
fn indent_dec() {
    INDENT.with(|cell| {
        let mut depth = cell.borrow_mut();
        *depth -= 2;
    });
}
43
44// ── Variable IDs ─────────────────────────────────────────────────────────────
45
/// A unique variable identifier within an IR function.
///
/// Ids are allocated sequentially by [`IrFunction::fresh_var`] and display as
/// `v<N>` (e.g. `v0`, `v1`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct VarId(pub u32);
49
50impl Display for VarId {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        write!(f, "v{}", self.0)
53    }
54}
55
/// A basic block identifier within an IR function.
///
/// Ids are allocated sequentially by [`IrFunction::fresh_block`] and display as
/// `bb<N>`. They usually equal the block's position in `IrFunction::blocks`,
/// but that is not guaranteed; see [`IrFunction::block_index`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BlockId(pub u32);
59
60impl fmt::Display for BlockId {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        write!(f, "bb{}", self.0)
63    }
64}
65
66// ── Known functions ──────────────────────────────────────────────────────────
67
/// Built-in functions the IR knows about for precise effect tracking.
///
/// When the IR can identify a call target as a known function, it uses this
/// enum instead of a generic `Call` — enabling escape analysis to reason
/// precisely about argument flow and allocation behavior.
///
/// The effect of each variant is reported by [`KnownFn::effect`].
///
/// NOTE: this enum is serialized with `postcard` (through `Inst::CallKnown`
/// inside an `IrFunction`). Postcard encodes enum variants by index, so add new
/// variants at the end. Inserting or reordering variants changes the encoding
/// of previously serialized functions and bundles.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum KnownFn {
    // Collection constructors
    Vector,
    HashMap,
    HashSet,
    List,

    // Collection operations (updates return a new persistent collection;
    // lookups/queries return an existing value or a scalar)
    Assoc,
    Dissoc,
    Conj,
    Disj,
    Get,
    Nth,
    Count,
    Contains,

    // Transient operations
    Transient,
    AssocBang,
    ConjBang,
    PersistentBang,

    // Sequence operations
    First,
    Rest,
    Next,
    Cons,
    Seq,
    LazySeq,
    Peek,
    Pop,
    Vec,

    // Arithmetic (pure, no alloc for i64/f64)
    Add,
    Sub,
    Mul,
    Div,
    Rem,

    // Comparison (pure)
    Eq,
    Lt,
    Gt,
    Lte,
    Gte,

    // Type checks (pure)
    IsNil,
    IsSeq,
    IsVector,
    IsMap,
    IsEmpty,

    // String
    Str,

    // Identity / deref
    Deref,
    Identical,

    // I/O and side effects
    Println,
    Pr,

    // Atom operations
    AtomDeref,
    AtomReset,
    AtomSwap,

    // Apply
    Apply,

    // Higher-order functions
    // NOTE(review): a trailing digit (Reduce2/Reduce3, Into3, Partition2..4,
    // Range1..3) appears to encode the call arity, e.g. `Reduce3` =
    // `(reduce f init coll)` — TODO confirm against the known-fn resolution
    // in `known.cljrs`.
    Reduce2,
    Reduce3,
    Map,
    Filter,
    Mapv,
    Filterv,
    Mapcat,
    Some,
    Every,
    Into,
    Into3,
    Repeatedly,

    // More HOFs
    GroupBy,
    Partition2,
    Partition3,
    Partition4,
    Frequencies,
    Keep,
    Remove,
    MapIndexed,
    Zipmap,
    Juxt,
    Comp,
    Partial,
    Complement,

    // Sequence operations
    Concat,
    Range1,
    Range2,
    Range3,
    Take,
    Drop,
    Reverse,
    Sort,
    SortBy,

    // Collection operations
    Keys,
    Vals,
    Merge,
    Update,
    GetIn,
    AssocIn,

    // Type predicates
    IsNumber,
    IsString,
    IsKeyword,
    IsSymbol,
    IsBool,
    IsInt,

    // Additional I/O
    Prn,
    Print,

    // Atom construction
    Atom,

    // Exception handling
    TryCatchFinally,

    // Dynamic binding
    SetBangVar,
    WithBindings,

    // Output capture
    WithOutStr,
}
221
222// ── Effect metadata ──────────────────────────────────────────────────────────
223
/// Effect classification for IR instructions.
///
/// Used by escape analysis and optimization passes to reason about what
/// side effects an instruction may have.
///
/// Computed by [`Inst::effect`] (and, for `Inst::CallKnown`, by
/// [`KnownFn::effect`]). It is an analysis result only: unlike the IR types, it
/// does not derive the serde traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Effect {
    /// No observable side effects; result depends only on inputs.
    Pure,
    /// Allocates a new heap object (GC or region).
    Alloc,
    /// Reads from a heap object (may observe mutations).
    HeapRead,
    /// Writes to a heap object (atoms, volatiles, vars).
    HeapWrite,
    /// Performs I/O.
    IO,
    /// Calls an unknown function — must assume any effect.
    UnknownCall,
}
243
244// ── Constant values ──────────────────────────────────────────────────────────
245
/// A constant value in the IR. Kept separate from `Value` to avoid requiring
/// GC allocation for IR analysis.
///
/// Only `PartialEq` is derived: the `Double(f64)` variant rules out `Eq`.
/// This enum is part of the serialized IR, so keep the variant order stable.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Const {
    Nil,
    Bool(bool),
    Long(i64),
    Double(f64),
    Str(Arc<str>),
    Keyword(Arc<str>),
    Symbol(Arc<str>),
    Char(char),
}
259
260// ── Instructions ─────────────────────────────────────────────────────────────
261
/// An IR instruction. Each instruction produces at most one result (the `dst`
/// field in variants that have one).
///
/// Instructions are in A-normal form: all operands are `VarId` references to
/// previously computed values, never nested expressions.
///
/// See [`Inst::dst`], [`Inst::uses`] and [`Inst::effect`] for the def/use and
/// effect view of each variant. This enum is serialized with `postcard`, which
/// encodes variants by index — append new variants rather than inserting or
/// reordering them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Inst {
    /// Load a constant value.
    Const(VarId, Const),

    /// Load a local variable by name (from the interpreter's Env).
    LoadLocal(VarId, Arc<str>),

    /// Load a global var by namespace-qualified name (returns the dereferenced value).
    LoadGlobal(VarId, Arc<str>, Arc<str>), // dst, ns, name

    /// Load a global Var object (not its value) — for `set!` and `binding`.
    LoadVar(VarId, Arc<str>, Arc<str>), // dst, ns, name

    /// Allocate a vector from elements.
    AllocVector(VarId, Vec<VarId>),

    /// Allocate a map from key-value pairs.
    AllocMap(VarId, Vec<(VarId, VarId)>),

    /// Allocate a set from elements.
    AllocSet(VarId, Vec<VarId>),

    /// Allocate a list from elements.
    AllocList(VarId, Vec<VarId>),

    /// Allocate a cons cell.
    AllocCons(VarId, VarId, VarId), // dst, head, tail

    /// Allocate a closure, capturing the given variables.
    // NOTE(review): the captured VarIds presumably line up positionally with
    // `ClosureTemplate::capture_names` — TODO confirm in the lowering code.
    AllocClosure(VarId, ClosureTemplate, Vec<VarId>),

    /// Call a known built-in function.
    CallKnown(VarId, KnownFn, Vec<VarId>),

    /// Call an unknown function value.
    Call(VarId, VarId, Vec<VarId>), // dst, callee, args

    /// Call a compiled function directly by name (bypasses dynamic dispatch).
    /// Generated by the direct-call optimization pass when a defn in the same
    /// compilation unit is called with a matching arity.
    CallDirect(VarId, Arc<str>, Vec<VarId>), // dst, compiled_fn_name, args

    /// Dereference (@ operator).
    Deref(VarId, VarId),

    /// Store to a var's root binding (`def`).
    DefVar(VarId, Arc<str>, Arc<str>, VarId), // dst(=var), ns, name, value

    /// `set!` on a var. Produces no result value.
    SetBang(VarId, VarId), // var, value

    /// Throw an exception. Produces no result value.
    Throw(VarId),

    /// SSA phi node — value depends on which predecessor block we came from.
    Phi(VarId, Vec<(BlockId, VarId)>),

    /// Recur with new values (in a loop context).
    Recur(Vec<VarId>),

    /// No-op marker with a source span (for debugging / source mapping).
    SourceLoc(Span),

    // ── Region allocation nodes ─────────────────────────────────────────
    /// Begin a region scope.  The `VarId` identifies the region handle,
    /// used by subsequent `RegionAlloc` instructions.  Paired with
    /// `RegionEnd`.
    RegionStart(VarId),

    /// Allocate an object in a region instead of the GC heap.
    /// `(dst, region_handle, alloc_kind, operands)`.
    ///
    /// `alloc_kind` mirrors the collection `Alloc*` instructions but
    /// produces region-backed `GcPtr`s.
    RegionAlloc(VarId, VarId, RegionAllocKind, Vec<VarId>),

    /// End a region scope — all region-allocated objects are freed.
    /// The `VarId` is the region handle from `RegionStart`.
    RegionEnd(VarId),

    /// Bind a VarId to the region handle inherited from the caller.
    ///
    /// Emitted at the entry block of a region-parameterised callee variant
    /// (produced by stage 4 of the escape-optimisation pipeline).  The bound
    /// VarId is referenced as the `region` operand of subsequent
    /// `RegionAlloc` instructions in the callee body, but the actual
    /// allocation target comes from the thread-local region stack — so the
    /// runtime can treat this as a placeholder bind to nil.
    RegionParam(VarId),

    /// Direct call to a region-parameterised callee variant.
    ///
    /// `(dst, target_name, args)` — semantically identical to `CallDirect`,
    /// but signals that the call site has wrapped itself in
    /// `RegionStart`/`RegionEnd` so that the callee's `RegionAlloc`
    /// instructions can allocate into the caller's region.  Generated by the
    /// stage-4 cross-function region-promotion pass.
    CallWithRegion(VarId, Arc<str>, Vec<VarId>),
}
367
/// The kind of object allocated in a region.
///
/// Displayed as the lowercase kind name (`vector`, `map`, `set`, `list`,
/// `cons`). Part of the serialized IR via `Inst::RegionAlloc`, so keep the
/// variant order stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegionAllocKind {
    /// `[elem ...]` — vector from elements.
    Vector,
    /// `{k v ...}` — map from key-value pairs.
    Map,
    /// `#{elem ...}` — set from elements.
    Set,
    /// `(elem ...)` — list from elements.
    List,
    /// `(cons head tail)` — cons cell.
    Cons,
}
382
383impl fmt::Display for RegionAllocKind {
384    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
385        match self {
386            Self::Vector => write!(f, "vector"),
387            Self::Map => write!(f, "map"),
388            Self::Set => write!(f, "set"),
389            Self::List => write!(f, "list"),
390            Self::Cons => write!(f, "cons"),
391        }
392    }
393}
394
/// Template for a closure — the static parts of an `fn*` form.
///
/// `arity_fn_names`, `param_counts` and `is_variadic` are parallel vectors
/// indexed by arity.
// NOTE(review): the three per-arity vectors are presumably always the same
// length; nothing in this type enforces that — TODO confirm in the lowering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClosureTemplate {
    /// Function name (if named).
    pub name: Option<Arc<str>>,
    /// Compiled function names for each arity (indices match `param_counts`).
    pub arity_fn_names: Vec<Arc<str>>,
    /// Fixed parameter count for each arity (excludes rest param for variadic arities).
    pub param_counts: Vec<usize>,
    /// Whether each arity is variadic (has a `& rest` parameter).
    /// Variadic arities accept `param_counts[i]` or more arguments; extra args
    /// are packed into a list for the rest parameter.
    pub is_variadic: Vec<bool>,
    /// Names of the captured variables (in order).
    pub capture_names: Vec<Arc<str>>,
}
411
412// ── Terminators ──────────────────────────────────────────────────────────────
413
/// A block terminator — controls flow between basic blocks.
///
/// Every [`Block`] ends with exactly one terminator. This enum is part of the
/// serialized IR, so keep the variant order stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Terminator {
    /// Unconditional jump.
    Jump(BlockId),

    /// Conditional branch.
    Branch {
        cond: VarId,
        then_block: BlockId,
        else_block: BlockId,
    },

    /// Return a value from the function.
    Return(VarId),

    /// Recur (tail-call back to loop header).
    RecurJump { target: BlockId, args: Vec<VarId> },

    /// Unreachable (e.g., after a `throw`).
    Unreachable,
}
436
437// ── Basic blocks and functions ───────────────────────────────────────────────
438
/// A basic block: a linear sequence of instructions followed by a terminator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Block {
    pub id: BlockId,
    /// Phi nodes at the top of this block (only at join points).
    // NOTE(review): presumably holds only `Inst::Phi` values; the type does not
    // enforce it.
    pub phis: Vec<Inst>,
    /// Non-phi instructions, in order.
    pub insts: Vec<Inst>,
    /// How this block transfers control.
    pub terminator: Terminator,
}
450
/// An IR function — the unit of analysis.
///
/// Serialized with `postcard` (field order is part of the encoding); see
/// [`IrFunction::serialize`] / [`IrFunction::deserialize`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IrFunction {
    /// Function name (for diagnostics).
    pub name: Option<Arc<str>>,
    /// Parameters (mapped to VarIds).
    pub params: Vec<(Arc<str>, VarId)>,
    /// All basic blocks. `blocks[0]` is the entry block.
    /// Block ids are not guaranteed to equal positions here; use
    /// [`IrFunction::block_index`] to map ids to indices.
    pub blocks: Vec<Block>,
    /// Next VarId to allocate (advanced by [`IrFunction::fresh_var`]).
    pub next_var: u32,
    /// Next BlockId to allocate (advanced by [`IrFunction::fresh_block`]).
    pub next_block: u32,
    /// Source span of the original function definition.
    pub span: Option<Span>,
    /// Nested function bodies (from `fn*` forms), each compiled separately.
    pub subfunctions: Vec<IrFunction>,
}
469
470impl IrFunction {
471    /// Create a new empty IR function.
472    pub fn new(name: Option<Arc<str>>, span: Option<Span>) -> Self {
473        Self {
474            name,
475            params: Vec::new(),
476            blocks: Vec::new(),
477            next_var: 0,
478            next_block: 0,
479            span,
480            subfunctions: Vec::new(),
481        }
482    }
483
484    /// Allocate a fresh variable ID.
485    pub fn fresh_var(&mut self) -> VarId {
486        let id = VarId(self.next_var);
487        self.next_var += 1;
488        id
489    }
490
491    /// Allocate a fresh block ID.
492    pub fn fresh_block(&mut self) -> BlockId {
493        let id = BlockId(self.next_block);
494        self.next_block += 1;
495        id
496    }
497
498    /// Build a block index: `block_id.0` → index in `self.blocks`.
499    ///
500    /// If block IDs are dense and match array indices (the common case from
501    /// the compiler), returns `None` — callers can use `block_id.0 as usize`
502    /// directly.  Otherwise returns a lookup table.
503    pub fn block_index(&self) -> Option<Vec<usize>> {
504        // Check if block IDs are dense and sequential (0, 1, 2, ...).
505        let is_identity = self
506            .blocks
507            .iter()
508            .enumerate()
509            .all(|(i, b)| b.id.0 as usize == i);
510        if is_identity {
511            return None; // Use block_id.0 directly as index.
512        }
513        // Sparse case: build a lookup table.
514        let max_id = self.blocks.iter().map(|b| b.id.0).max().unwrap_or(0);
515        let mut table = vec![0usize; max_id as usize + 1];
516        for (i, b) in self.blocks.iter().enumerate() {
517            table[b.id.0 as usize] = i;
518        }
519        Some(table)
520    }
521
522    pub fn serialize(&self) -> CljxResult<Vec<u8>> {
523        postcard::to_allocvec(self).map_err(|e| SerializationError {
524            message: e.to_string(),
525        })
526    }
527
528    pub fn deserialize(bytes: &[u8]) -> CljxResult<Self> {
529        postcard::from_bytes(bytes).map_err(|e| SerializationError {
530            message: e.to_string(),
531        })
532    }
533}
534
535// ── IR Bundle ───────────────────────────────────────────────────────────────
536
/// A bundle of pre-lowered IR functions, keyed by a string identifier.
///
/// Used to serialize multiple functions (e.g. an entire namespace) into a
/// single blob that can be loaded at startup without running the Clojure
/// compiler. See [`serialize_bundle`] / [`deserialize_bundle`].
#[derive(Debug, Serialize, Deserialize)]
pub struct IrBundle {
    /// Bundle entries keyed by identifier (typically `"ns/name:arity"` or
    /// the arity ID as a string).
    /// Being a `HashMap`, iteration order is unspecified.
    pub functions: HashMap<String, IrFunction>,
}
548
549impl Display for IrBundle {
550    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
551        f.write_str("IrBundle {\n")?;
552        f.write_str("  functions: {\n")?;
553        indent_inc();
554        self.functions.iter().try_for_each(|(name, function)| {
555            f.write_fmt(format_args!("  \"{}\": {}", name, function))?;
556            Ok(())
557        })?;
558        indent_dec();
559        f.write_str("  }\n")?;
560        f.write_str("}\n")?;
561        Ok(())
562    }
563}
564
565impl IrBundle {
566    pub fn new() -> Self {
567        Self {
568            functions: HashMap::new(),
569        }
570    }
571
572    /// Insert a function into the bundle.
573    pub fn insert(&mut self, key: String, func: IrFunction) {
574        self.functions.insert(key, func);
575    }
576
577    /// Look up a function by key.
578    pub fn get(&self, key: &str) -> Option<&IrFunction> {
579        self.functions.get(key)
580    }
581
582    /// Number of functions in the bundle.
583    pub fn len(&self) -> usize {
584        self.functions.len()
585    }
586
587    /// Whether the bundle is empty.
588    pub fn is_empty(&self) -> bool {
589        self.functions.is_empty()
590    }
591}
592
593impl Default for IrBundle {
594    fn default() -> Self {
595        Self::new()
596    }
597}
598
599/// Serialize an [`IrBundle`] to bytes.
600pub fn serialize_bundle(bundle: &IrBundle) -> CljxResult<Vec<u8>> {
601    postcard::to_allocvec(bundle).map_err(|e| SerializationError {
602        message: e.to_string(),
603    })
604}
605
606/// Deserialize an [`IrBundle`] from bytes.
607pub fn deserialize_bundle(bytes: &[u8]) -> CljxResult<IrBundle> {
608    postcard::from_bytes(bytes).map_err(|e| SerializationError {
609        message: e.to_string(),
610    })
611}
612
613// ── Effect classification ────────────────────────────────────────────────────
614
615impl Inst {
616    /// Return the primary effect of this instruction.
617    pub fn effect(&self) -> Effect {
618        match self {
619            Inst::Const(..) | Inst::LoadLocal(..) | Inst::Phi(..) | Inst::SourceLoc(..) => {
620                Effect::Pure
621            }
622            Inst::LoadGlobal(..) | Inst::LoadVar(..) => Effect::HeapRead,
623            Inst::AllocVector(..)
624            | Inst::AllocMap(..)
625            | Inst::AllocSet(..)
626            | Inst::AllocList(..)
627            | Inst::AllocCons(..)
628            | Inst::AllocClosure(..) => Effect::Alloc,
629            Inst::CallKnown(_, known, _) => known.effect(),
630            Inst::Call(..) | Inst::CallDirect(..) => Effect::UnknownCall,
631            Inst::Deref(..) => Effect::HeapRead,
632            Inst::DefVar(..) => Effect::HeapWrite,
633            Inst::SetBang(..) => Effect::HeapWrite,
634            Inst::Throw(..) => Effect::UnknownCall, // conservative
635            Inst::Recur(..) => Effect::Pure,
636            Inst::RegionStart(..) | Inst::RegionEnd(..) => Effect::Alloc,
637            Inst::RegionAlloc(..) => Effect::Alloc,
638            Inst::RegionParam(..) => Effect::Pure,
639            Inst::CallWithRegion(..) => Effect::UnknownCall,
640        }
641    }
642
643    /// Return the destination VarId, if this instruction produces one.
644    pub fn dst(&self) -> Option<VarId> {
645        match self {
646            Inst::Const(v, _)
647            | Inst::LoadLocal(v, _)
648            | Inst::LoadGlobal(v, _, _)
649            | Inst::LoadVar(v, _, _)
650            | Inst::AllocVector(v, _)
651            | Inst::AllocMap(v, _)
652            | Inst::AllocSet(v, _)
653            | Inst::AllocList(v, _)
654            | Inst::AllocCons(v, _, _)
655            | Inst::AllocClosure(v, _, _)
656            | Inst::CallKnown(v, _, _)
657            | Inst::Call(v, _, _)
658            | Inst::CallDirect(v, _, _)
659            | Inst::Deref(v, _)
660            | Inst::DefVar(v, _, _, _)
661            | Inst::Phi(v, _)
662            | Inst::RegionStart(v)
663            | Inst::RegionAlloc(v, _, _, _)
664            | Inst::RegionParam(v)
665            | Inst::CallWithRegion(v, _, _) => Some(*v),
666            Inst::SetBang(..)
667            | Inst::Throw(..)
668            | Inst::Recur(..)
669            | Inst::SourceLoc(..)
670            | Inst::RegionEnd(..) => None,
671        }
672    }
673
674    /// Return all VarIds used (read) by this instruction.
675    pub fn uses(&self) -> Vec<VarId> {
676        match self {
677            Inst::Const(..)
678            | Inst::LoadLocal(..)
679            | Inst::LoadGlobal(..)
680            | Inst::LoadVar(..)
681            | Inst::SourceLoc(..) => vec![],
682            Inst::AllocVector(_, elems) | Inst::AllocSet(_, elems) | Inst::AllocList(_, elems) => {
683                elems.clone()
684            }
685            Inst::AllocMap(_, pairs) => pairs.iter().flat_map(|(k, v)| [*k, *v]).collect(),
686            Inst::AllocCons(_, h, t) => vec![*h, *t],
687            Inst::AllocClosure(_, _, captures) => captures.clone(),
688            Inst::CallKnown(_, _, args) => args.clone(),
689            Inst::Call(_, callee, args) => {
690                let mut v = vec![*callee];
691                v.extend(args);
692                v
693            }
694            Inst::CallDirect(_, _, args) => args.clone(),
695            Inst::Deref(_, src) => vec![*src],
696            Inst::DefVar(_, _, _, val) => vec![*val],
697            Inst::SetBang(var, val) => vec![*var, *val],
698            Inst::Throw(val) => vec![*val],
699            Inst::Phi(_, entries) => entries.iter().map(|(_, v)| *v).collect(),
700            Inst::Recur(args) => args.clone(),
701            Inst::RegionStart(..) => vec![],
702            Inst::RegionAlloc(_, region, _, operands) => {
703                let mut v = vec![*region];
704                v.extend(operands);
705                v
706            }
707            Inst::RegionEnd(region) => vec![*region],
708            Inst::RegionParam(..) => vec![],
709            Inst::CallWithRegion(_, _, args) => args.clone(),
710        }
711    }
712}
713
714impl KnownFn {
715    /// Return the effect of calling this known function.
716    pub fn effect(&self) -> Effect {
717        match self {
718            // Pure functions — no side effects, no allocation (result is scalar or reuses input)
719            KnownFn::Get
720            | KnownFn::Nth
721            | KnownFn::Count
722            | KnownFn::Contains
723            | KnownFn::First
724            | KnownFn::Add
725            | KnownFn::Sub
726            | KnownFn::Mul
727            | KnownFn::Div
728            | KnownFn::Rem
729            | KnownFn::Eq
730            | KnownFn::Lt
731            | KnownFn::Gt
732            | KnownFn::Lte
733            | KnownFn::Gte
734            | KnownFn::IsNil
735            | KnownFn::IsSeq
736            | KnownFn::IsVector
737            | KnownFn::IsMap
738            | KnownFn::IsEmpty
739            | KnownFn::Peek
740            | KnownFn::Identical => Effect::Pure,
741
742            // Allocating — return a new persistent collection
743            KnownFn::Vector
744            | KnownFn::HashMap
745            | KnownFn::HashSet
746            | KnownFn::List
747            | KnownFn::Assoc
748            | KnownFn::Dissoc
749            | KnownFn::Conj
750            | KnownFn::Disj
751            | KnownFn::Cons
752            | KnownFn::Rest
753            | KnownFn::Next
754            | KnownFn::Seq
755            | KnownFn::LazySeq
756            | KnownFn::Pop
757            | KnownFn::Vec
758            | KnownFn::Str
759            | KnownFn::Transient
760            | KnownFn::PersistentBang => Effect::Alloc,
761
762            // Transient mutation — heap write on the transient, but doesn't escape
763            KnownFn::AssocBang | KnownFn::ConjBang => Effect::HeapWrite,
764
765            // Deref reads from heap
766            KnownFn::Deref | KnownFn::AtomDeref => Effect::HeapRead,
767
768            // Atom mutation
769            KnownFn::AtomReset | KnownFn::AtomSwap => Effect::HeapWrite,
770
771            // I/O
772            KnownFn::Println | KnownFn::Pr => Effect::IO,
773
774            // Apply calls an unknown function
775            KnownFn::Apply => Effect::UnknownCall,
776
777            // Sequence operations (allocating)
778            KnownFn::Concat
779            | KnownFn::Range1
780            | KnownFn::Range2
781            | KnownFn::Range3
782            | KnownFn::Take
783            | KnownFn::Drop
784            | KnownFn::Reverse => Effect::Alloc,
785
786            // Sort calls comparator (unknown call)
787            KnownFn::Sort | KnownFn::SortBy => Effect::UnknownCall,
788
789            // Collection operations
790            KnownFn::Keys | KnownFn::Vals => Effect::Alloc,
791            KnownFn::Merge | KnownFn::Update | KnownFn::GetIn | KnownFn::AssocIn => Effect::Alloc,
792
793            // Type predicates
794            KnownFn::IsNumber
795            | KnownFn::IsString
796            | KnownFn::IsKeyword
797            | KnownFn::IsSymbol
798            | KnownFn::IsBool
799            | KnownFn::IsInt => Effect::Pure,
800
801            // Additional I/O
802            KnownFn::Prn | KnownFn::Print => Effect::IO,
803
804            // Atom construction
805            KnownFn::Atom => Effect::Alloc,
806
807            // More HOFs call unknown functions
808            KnownFn::GroupBy
809            | KnownFn::Partition2
810            | KnownFn::Partition3
811            | KnownFn::Partition4
812            | KnownFn::Keep
813            | KnownFn::Remove
814            | KnownFn::MapIndexed => Effect::UnknownCall,
815
816            // Function combinators (return new fns, call unknown fns)
817            KnownFn::Juxt | KnownFn::Comp | KnownFn::Partial | KnownFn::Complement => {
818                Effect::UnknownCall
819            }
820
821            // Pure collection ops
822            KnownFn::Frequencies | KnownFn::Zipmap => Effect::Alloc,
823
824            // HOFs call unknown functions
825            KnownFn::Reduce2
826            | KnownFn::Reduce3
827            | KnownFn::Map
828            | KnownFn::Filter
829            | KnownFn::Mapv
830            | KnownFn::Filterv
831            | KnownFn::Mapcat
832            | KnownFn::Some
833            | KnownFn::Every
834            | KnownFn::Into
835            | KnownFn::Into3
836            | KnownFn::Repeatedly => Effect::UnknownCall,
837
838            // Try/catch calls unknown closures
839            KnownFn::TryCatchFinally => Effect::UnknownCall,
840
841            // Dynamic binding
842            KnownFn::SetBangVar => Effect::HeapWrite,
843            KnownFn::WithBindings | KnownFn::WithOutStr => Effect::UnknownCall,
844        }
845    }
846}
847
848// ── Display for debugging ────────────────────────────────────────────────────
849
850impl fmt::Display for IrFunction {
851    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
852        writeln!(
853            f,
854            "{}fn {}({}):",
855            indent(),
856            self.name.as_deref().unwrap_or("<anon>"),
857            self.params
858                .iter()
859                .map(|(name, id)| format!("{name}: {id}"))
860                .collect::<Vec<_>>()
861                .join(", ")
862        )?;
863        indent_inc();
864        for block in &self.blocks {
865            writeln!(f, "{}  {block}", indent())?;
866        }
867        indent_dec();
868        Ok(())
869    }
870}
871
872impl Display for Block {
873    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
874        writeln!(f, "{}{}:", self.id, indent())?;
875        for phi in &self.phis {
876            writeln!(f, "{}    {phi}", indent())?;
877        }
878        for inst in &self.insts {
879            writeln!(f, "{}    {inst}", indent())?;
880        }
881        write!(f, "{}    {}", indent(), self.terminator)
882    }
883}
884
885impl Display for Inst {
886    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
887        match self {
888            Inst::Const(dst, c) => write!(f, "{dst} = const {c:?}"),
889            Inst::LoadLocal(dst, name) => write!(f, "{dst} = load_local {name:?}"),
890            Inst::LoadGlobal(dst, ns, name) => write!(f, "{dst} = load_global {ns}/{name}"),
891            Inst::LoadVar(dst, ns, name) => write!(f, "{dst} = load_var {ns}/{name}"),
892            Inst::AllocVector(dst, elems) => write!(f, "{dst} = alloc_vec {elems:?}"),
893            Inst::AllocMap(dst, pairs) => write!(f, "{dst} = alloc_map {pairs:?}"),
894            Inst::AllocSet(dst, elems) => write!(f, "{dst} = alloc_set {elems:?}"),
895            Inst::AllocList(dst, elems) => write!(f, "{dst} = alloc_list {elems:?}"),
896            Inst::AllocCons(dst, h, t) => write!(f, "{dst} = cons {h} {t}"),
897            Inst::AllocClosure(dst, tmpl, captures) => {
898                write!(f, "{dst} = closure {:?} captures={captures:?}", tmpl.name)
899            }
900            Inst::CallKnown(dst, func, args) => write!(f, "{dst} = call_known {func:?} {args:?}"),
901            Inst::Call(dst, callee, args) => write!(f, "{dst} = call {callee} {args:?}"),
902            Inst::CallDirect(dst, name, args) => write!(f, "{dst} = call_direct {name} {args:?}"),
903            Inst::Deref(dst, src) => write!(f, "{dst} = deref {src}"),
904            Inst::DefVar(dst, ns, name, val) => write!(f, "{dst} = def {ns}/{name} {val}"),
905            Inst::SetBang(var, val) => write!(f, "set! {var} {val}"),
906            Inst::Throw(val) => write!(f, "throw {val}"),
907            Inst::Phi(dst, entries) => write!(f, "{dst} = phi {entries:?}"),
908            Inst::Recur(args) => write!(f, "recur {args:?}"),
909            Inst::SourceLoc(span) => write!(f, "# {}:{}:{}", span.file, span.line, span.col),
910            Inst::RegionStart(dst) => write!(f, "{dst} = region_start"),
911            Inst::RegionAlloc(dst, region, kind, operands) => {
912                write!(f, "{dst} = region_alloc {region} {kind} {operands:?}")
913            }
914            Inst::RegionEnd(region) => write!(f, "region_end {region}"),
915            Inst::RegionParam(dst) => write!(f, "{dst} = region_param"),
916            Inst::CallWithRegion(dst, name, args) => {
917                write!(f, "{dst} = call_with_region {name} {args:?}")
918            }
919        }
920    }
921}
922
923impl fmt::Display for Terminator {
924    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
925        match self {
926            Terminator::Jump(target) => write!(f, "jump {target}"),
927            Terminator::Branch {
928                cond,
929                then_block,
930                else_block,
931            } => write!(f, "branch {cond} then={then_block} else={else_block}"),
932            Terminator::Return(val) => write!(f, "return {val}"),
933            Terminator::RecurJump { target, args } => {
934                write!(f, "recur_jump {target} {args:?}")
935            }
936            Terminator::Unreachable => write!(f, "unreachable"),
937        }
938    }
939}
940
// ── Embedded Clojure compiler sources ───────────────────────────────────────
//
// Each constant below holds the full text of one `.cljrs` file, embedded into
// the binary at compile time by `include_str!`. The paths are resolved relative
// to this source file, so a missing file is a build error, not a runtime one.

/// Clojure source for the IR builder namespace.
///
/// Embedded from `cljrs/compiler/ir.cljrs`.
pub const COMPILER_IR_SOURCE: &str = include_str!("cljrs/compiler/ir.cljrs");

/// Clojure source for the known function resolution namespace.
///
/// Embedded from `cljrs/compiler/known.cljrs`.
pub const COMPILER_KNOWN_SOURCE: &str = include_str!("cljrs/compiler/known.cljrs");

/// Clojure source for the ANF lowering namespace.
///
/// Embedded from `cljrs/compiler/anf.cljrs`.
pub const COMPILER_ANF_SOURCE: &str = include_str!("cljrs/compiler/anf.cljrs");

/// Clojure source for the escape analysis namespace.
///
/// Embedded from `cljrs/compiler/escape.cljrs`.
pub const COMPILER_ESCAPE_SOURCE: &str = include_str!("cljrs/compiler/escape.cljrs");

/// Clojure source for the optimization pass namespace.
///
/// Embedded from `cljrs/compiler/optimize.cljrs`.
pub const COMPILER_OPTIMIZE_SOURCE: &str = include_str!("cljrs/compiler/optimize.cljrs");
957
958// ── Tests ───────────────────────────────────────────────────────────────────
959
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a simple IR function for testing: one block that returns a constant.
    fn make_test_fn(name: &str, const_val: i64) -> IrFunction {
        let mut f = IrFunction::new(Some(Arc::from(name)), None);
        let dst = f.fresh_var();
        let block_id = f.fresh_block();
        f.blocks.push(Block {
            id: block_id,
            phis: vec![],
            insts: vec![Inst::Const(dst, Const::Long(const_val))],
            terminator: Terminator::Return(dst),
        });
        f
    }

    #[test]
    fn test_ir_function_serialize_roundtrip() {
        let f = make_test_fn("identity", 42);
        let bytes = f.serialize().unwrap();
        let f2 = IrFunction::deserialize(&bytes).unwrap();
        assert_eq!(f2.name.as_deref(), Some("identity"));
        assert_eq!(f2.blocks.len(), 1);
        assert_eq!(f2.next_var, 1);
        match &f2.blocks[0].insts[0] {
            Inst::Const(_, Const::Long(v)) => assert_eq!(*v, 42),
            other => panic!("expected Const(Long(42)), got {other:?}"),
        }
    }

    #[test]
    fn test_ir_function_with_closure_template() {
        let mut f = IrFunction::new(Some(Arc::from("outer")), None);
        let dst = f.fresh_var();
        let capture = f.fresh_var();
        let block_id = f.fresh_block();
        f.blocks.push(Block {
            id: block_id,
            phis: vec![],
            insts: vec![
                Inst::Const(capture, Const::Str(Arc::from("hello"))),
                Inst::AllocClosure(
                    dst,
                    ClosureTemplate {
                        name: Some(Arc::from("inner")),
                        arity_fn_names: vec![Arc::from("inner__0")],
                        param_counts: vec![1],
                        is_variadic: vec![false],
                        capture_names: vec![Arc::from("x")],
                    },
                    vec![capture],
                ),
            ],
            terminator: Terminator::Return(dst),
        });

        let bytes = f.serialize().unwrap();
        let f2 = IrFunction::deserialize(&bytes).unwrap();

        // The captured string constant must survive with its payload intact.
        match &f2.blocks[0].insts[0] {
            Inst::Const(v, Const::Str(s)) => {
                assert_eq!(*v, capture);
                assert_eq!(&**s, "hello");
            }
            other => panic!("expected Const(Str(\"hello\")), got {other:?}"),
        }

        // Every template field we set, plus the capture operands, must round-trip.
        match &f2.blocks[0].insts[1] {
            Inst::AllocClosure(closure_dst, tmpl, captures) => {
                assert_eq!(*closure_dst, dst);
                assert_eq!(tmpl.name.as_deref(), Some("inner"));
                assert_eq!(tmpl.arity_fn_names.len(), 1);
                assert_eq!(&*tmpl.arity_fn_names[0], "inner__0");
                assert_eq!(tmpl.param_counts, vec![1]);
                assert_eq!(tmpl.is_variadic, vec![false]);
                assert_eq!(tmpl.capture_names.len(), 1);
                assert_eq!(&*tmpl.capture_names[0], "x");
                assert_eq!(*captures, vec![capture]);
            }
            other => panic!("expected AllocClosure, got {other:?}"),
        }
    }

    #[test]
    fn test_empty_bundle_roundtrip() {
        let bundle = IrBundle::new();
        assert!(bundle.is_empty());
        let bytes = serialize_bundle(&bundle).unwrap();
        let bundle2 = deserialize_bundle(&bytes).unwrap();
        assert!(bundle2.is_empty());
        assert_eq!(bundle2.len(), 0);
    }

    #[test]
    fn test_bundle_single_function() {
        let mut bundle = IrBundle::new();
        bundle.insert("clojure.core/inc:1".to_string(), make_test_fn("inc", 1));
        assert_eq!(bundle.len(), 1);

        let bytes = serialize_bundle(&bundle).unwrap();
        let bundle2 = deserialize_bundle(&bytes).unwrap();
        assert_eq!(bundle2.len(), 1);

        let f = bundle2.get("clojure.core/inc:1").unwrap();
        assert_eq!(f.name.as_deref(), Some("inc"));
    }

    #[test]
    fn test_bundle_multiple_functions() {
        let mut bundle = IrBundle::new();
        bundle.insert("clojure.core/inc:1".to_string(), make_test_fn("inc", 1));
        bundle.insert("clojure.core/dec:1".to_string(), make_test_fn("dec", -1));
        bundle.insert(
            "clojure.core/identity:1".to_string(),
            make_test_fn("identity", 0),
        );
        assert_eq!(bundle.len(), 3);

        let bytes = serialize_bundle(&bundle).unwrap();
        let bundle2 = deserialize_bundle(&bytes).unwrap();
        assert_eq!(bundle2.len(), 3);

        assert_eq!(
            bundle2.get("clojure.core/inc:1").unwrap().name.as_deref(),
            Some("inc")
        );
        assert_eq!(
            bundle2.get("clojure.core/dec:1").unwrap().name.as_deref(),
            Some("dec")
        );
        assert_eq!(
            bundle2
                .get("clojure.core/identity:1")
                .unwrap()
                .name
                .as_deref(),
            Some("identity")
        );
        assert!(bundle2.get("nonexistent").is_none());
    }

    #[test]
    fn test_bundle_with_complex_ir() {
        let mut f = IrFunction::new(Some(Arc::from("complex")), None);
        let p0 = f.fresh_var();
        let p1 = f.fresh_var();
        f.params = vec![(Arc::from("x"), p0), (Arc::from("y"), p1)];

        // Entry block: branch on (nil? x)
        let entry = f.fresh_block();
        let then_bb = f.fresh_block();
        let else_bb = f.fresh_block();
        let join_bb = f.fresh_block();

        let cond_dst = f.fresh_var();
        f.blocks.push(Block {
            id: entry,
            phis: vec![],
            insts: vec![Inst::CallKnown(cond_dst, KnownFn::IsNil, vec![p0])],
            terminator: Terminator::Branch {
                cond: cond_dst,
                then_block: then_bb,
                else_block: else_bb,
            },
        });

        // Then block (x is nil): jump to the join, where the phi selects y.
        f.blocks.push(Block {
            id: then_bb,
            phis: vec![],
            insts: vec![],
            terminator: Terminator::Jump(join_bb),
        });

        // Else block (x is non-nil): jump to the join, where the phi selects x.
        f.blocks.push(Block {
            id: else_bb,
            phis: vec![],
            insts: vec![],
            terminator: Terminator::Jump(join_bb),
        });

        // Join block: phi + return
        let phi_dst = f.fresh_var();
        f.blocks.push(Block {
            id: join_bb,
            phis: vec![Inst::Phi(phi_dst, vec![(then_bb, p1), (else_bb, p0)])],
            insts: vec![],
            terminator: Terminator::Return(phi_dst),
        });

        let mut bundle = IrBundle::new();
        bundle.insert("test/complex:2".to_string(), f);

        let bytes = serialize_bundle(&bundle).unwrap();
        let bundle2 = deserialize_bundle(&bytes).unwrap();

        let f2 = bundle2.get("test/complex:2").unwrap();
        assert_eq!(f2.blocks.len(), 4);

        // Params keep both their names and their variable bindings.
        assert_eq!(f2.params.len(), 2);
        assert_eq!(&*f2.params[0].0, "x");
        assert_eq!(f2.params[0].1, p0);
        assert_eq!(&*f2.params[1].0, "y");
        assert_eq!(f2.params[1].1, p1);

        // Verify branch terminator survived roundtrip
        match &f2.blocks[0].terminator {
            Terminator::Branch {
                cond,
                then_block,
                else_block,
            } => {
                assert_eq!(*cond, cond_dst);
                assert_eq!(*then_block, then_bb);
                assert_eq!(*else_block, else_bb);
            }
            other => panic!("expected Branch, got {other:?}"),
        }

        // Both arms must still jump to the join block.
        for arm in &f2.blocks[1..3] {
            assert!(
                matches!(arm.terminator, Terminator::Jump(t) if t == join_bb),
                "expected Jump({join_bb}), got {:?}",
                arm.terminator
            );
        }

        // Verify phi survived roundtrip, including which value flows from which edge.
        assert_eq!(f2.blocks[3].phis.len(), 1);
        match &f2.blocks[3].phis[0] {
            Inst::Phi(dst, entries) => {
                assert_eq!(*dst, phi_dst);
                assert_eq!(*entries, vec![(then_bb, p1), (else_bb, p0)]);
            }
            other => panic!("expected Phi, got {other:?}"),
        }
    }

    #[test]
    fn test_bundle_with_subfunctions() {
        let mut outer = make_test_fn("outer", 100);
        let inner = make_test_fn("inner", 200);
        outer.subfunctions.push(inner);

        let mut bundle = IrBundle::new();
        bundle.insert("test/outer:0".to_string(), outer);

        let bytes = serialize_bundle(&bundle).unwrap();
        let bundle2 = deserialize_bundle(&bytes).unwrap();

        let f = bundle2.get("test/outer:0").unwrap();
        assert_eq!(f.subfunctions.len(), 1);
        assert_eq!(f.subfunctions[0].name.as_deref(), Some("inner"));
    }

    #[test]
    fn test_deserialize_invalid_bytes() {
        let result = IrFunction::deserialize(&[0xFF, 0xFE, 0xFD]);
        assert!(result.is_err());

        let result = deserialize_bundle(&[0xFF, 0xFE, 0xFD]);
        assert!(result.is_err());
    }

    /// Pins the textual dump format of terminators, which builds on the
    /// `v<N>` / `bb<N>` Display forms of `VarId` and `BlockId`.
    #[test]
    fn test_terminator_display() {
        assert_eq!(Terminator::Jump(BlockId(3)).to_string(), "jump bb3");
        assert_eq!(
            Terminator::Branch {
                cond: VarId(1),
                then_block: BlockId(2),
                else_block: BlockId(3),
            }
            .to_string(),
            "branch v1 then=bb2 else=bb3"
        );
        assert_eq!(Terminator::Return(VarId(0)).to_string(), "return v0");
        assert_eq!(Terminator::Unreachable.to_string(), "unreachable");
    }
}