(ns cljrs.compiler.escape)
;; ── Escape analysis on IR data ──────────────────────────────────────────────
;;
;; Operates on plain IR data maps (same structure as ir_convert.rs expects).
;; No Rust types needed — pure Clojure data transformation.
;; ── Effect classification ───────────────────────────────────────────────────
(def alloc-ops
  ;; IR ops treated as allocation sites by this analysis (see
  ;; `alloc-inst?`); each one produces a fresh value in its `:dst`.
  #{:alloc-list :alloc-cons
    :alloc-vector :alloc-set :alloc-map
    :alloc-closure})
(defn alloc-inst?
  "True when the `:op` of instruction `inst` is one of `alloc-ops`."
  [inst]
  (let [op (:op inst)]
    (contains? alloc-ops op)))
;; ── Collect allocating instructions ─────────────────────────────────────────
(defn collect-allocs
  "Return a map `{alloc-var defining-block-id}` for every allocation
  instruction (see `alloc-inst?`) in `ir-func` that has a `:dst`.
  Both the `:phis` and the `:insts` of each block are scanned; if the
  same `:dst` were defined twice, the later definition wins.
  `analyze` uses the keys of this map as the set of allocations to
  classify and returns the whole map as `:alloc-blocks`. The block ids
  are not consulted by `classify-escape` itself; they are exposed so
  that downstream passes can decide where region scopes for each
  allocation should start and end."
  [ir-func]
  (reduce (fn [acc block]
            (reduce (fn [a inst]
                      (if (and (alloc-inst? inst) (:dst inst))
                        (assoc a (:dst inst) (:id block))
                        a))
                    acc
                    (concat (:phis block) (:insts block))))
          {}
          (:blocks ir-func)))
;; ── Build def-use chains ────────────────────────────────────────────────────
(defn add-use
  "Append the record `{:block block-id :kind kind}` to the vector of
  uses stored under `var-id` in `uses`, starting a fresh vector when
  `var-id` has no uses yet."
  [uses var-id block-id kind]
  (let [record {:block block-id :kind kind}
        existing (or (get uses var-id) [])]
    (assoc uses var-id (conj existing record))))
(defn add-uses-for-inst
  "Record in `uses` every variable that `inst` (found in block
  `block-id`) reads. Each use is tagged with a kind map whose `:type`
  tells `classify-escape-with-ctx` how the value is consumed. Ops that
  read no variables (const, load-local, load-global, source-loc, and
  anything else not listed) return `uses` unchanged."
  [uses inst block-id]
  (let [tag-all (fn [acc vars kind]
                  ;; Give every var in `vars` the same kind, in order.
                  (reduce (fn [u v] (add-use u v block-id kind)) acc vars))
        in-heap {:type :stored-in-heap}]
    (case (:op inst)
      :call-known
      (reduce (fn [u [idx arg]]
                (add-use u arg block-id
                         {:type :known-call-arg :func (:func inst) :arg-index idx}))
              uses
              (map-indexed vector (:args inst)))

      :call
      (let [callee (:callee inst)]
        ;; The callee use is recorded before any argument use.
        (reduce (fn [u [idx arg]]
                  (add-use u arg block-id
                           {:type :unknown-call-arg :callee callee :arg-index idx}))
                (add-use uses callee block-id {:type :call-callee})
                (map-indexed vector (:args inst))))

      :alloc-closure
      (tag-all uses (:captures inst) {:type :closure-capture})

      (:alloc-vector :alloc-set :alloc-list)
      (tag-all uses (:elems inst) in-heap)

      :alloc-map
      ;; Flatten the pairs into k1 v1 k2 v2 ... so each key precedes its value.
      (tag-all uses (mapcat (fn [[k v]] [k v]) (:pairs inst)) in-heap)

      :alloc-cons
      (tag-all uses [(:head inst) (:tail inst)] in-heap)

      :def-var
      (add-use uses (:value inst) block-id {:type :def-var})

      :set!
      (tag-all uses [(:var inst) (:value inst)] {:type :set-bang})

      :deref
      (add-use uses (:src inst) block-id {:type :deref})

      :throw
      (add-use uses (:value inst) block-id {:type :throw})

      :recur
      (tag-all uses (:args inst) {:type :recur})

      :phi
      (tag-all uses (map second (:entries inst)) {:type :phi-input})

      uses)))
(defn add-uses-for-terminator
  "Record in `uses` the variables read by the block terminator `term`:
  the returned var, the branch condition, or the recur-jump arguments.
  `:jump`, `:unreachable`, and any other terminator add nothing."
  [uses term block-id]
  (let [op (:op term)]
    (cond
      (= op :return)
      (add-use uses (:var term) block-id {:type :return})

      (= op :branch)
      (add-use uses (:cond term) block-id {:type :branch-cond})

      (= op :recur-jump)
      (reduce (fn [acc arg] (add-use acc arg block-id {:type :recur}))
              uses
              (:args term))

      :else
      uses)))
(defn build-use-chains
  "Build def-use chains for `ir-func`: a map from each VarId to a vector
  of `{:block block-id :kind kind}` records describing where it is
  read. Blocks are visited in order; within a block, phis come first,
  then ordinary instructions, then the terminator."
  [ir-func]
  (reduce (fn [uses block]
            (let [bid (:id block)
                  step (fn [u inst] (add-uses-for-inst u inst bid))]
              (add-uses-for-terminator
               (reduce step uses (concat (:phis block) (:insts block)))
               (:terminator block)
               bid)))
          {}
          (:blocks ir-func)))
;; ── Known function arg escape semantics ─────────────────────────────────────
;; Known functions where no argument escapes — i.e. the function reads
;; its arguments but does not retain a reference to them (or to anything
;; reachable from them) past the call. An alloc passed solely to such
;; functions stays `:no-escape` and remains eligible for region
;; allocation.
;;
;; When extending this set, only add functions whose return value cannot
;; alias the argument or its contents. Predicates and arithmetic are
;; safe because they return a fresh primitive. Pretty-printers
;; (`println`/`pr`/`prn`/`print`) are safe because they format and
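;; discard. Lookups like `:get`/`:nth`/`:first` are safe because the
;; element they return is itself an independently-allocated `GcPtr`,
;; not a pointer into the argument's storage — freeing the container
;; does not invalidate elements pulled out of it. The exception is the
;; optional not-found argument of `get`/`nth` (`(get m k not-found)`):
;; on a miss it is returned unchanged, so that argument position does
;; alias the result and must not be treated as non-escaping.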
(def non-escaping-fns
  (into #{}
        (concat
         ;; Lookups: the returned element is allocated independently of
         ;; the container. NOTE: `(get m k not-found)` / `(nth c i not-found)`
         ;; return `not-found` itself on a miss, so that argument position
         ;; is not safe to treat as non-escaping.
         [:get :nth :first :count :contains]
         ;; Arithmetic: the result is a fresh primitive
         [:+ :- :* :/ :rem]
         ;; Comparisons: the result is a bool
         [:= :< :> :<= :>=]
         ;; Type predicates: the result is a bool
         [:nil? :seq? :vector? :map? :identical?
          :number? :string? :keyword? :symbol? :boolean? :int?]
         ;; String formatting: consumes its arguments and discards them
         [:str]
         ;; Reads
         [:deref :atom-deref]
         ;; I/O: formats to stdout and retains nothing
         [:println :pr :prn :print])))
(defn known-fn-arg-escapes?
  "Does the known builtin `func` let the argument at `arg-index` flow
  into its return value?
  - The optional not-found argument (index 2) of `:get` and `:nth` is
    returned unchanged when the lookup misses, so the call result
    aliases it: this position escapes even though both functions are
    otherwise listed in `non-escaping-fns`.
  - Any other argument of a `non-escaping-fns` member does not escape.
  - For `dissoc`/`disj`/`rest`/`next`/`seq`/`transient`/`assoc!`/
    `conj!`/`persistent!`, only the collection argument (index 0)
    escapes.
  - Every other builtin (vector, hash-map, assoc, conj, cons, ...)
    conservatively lets all of its arguments escape."
  [func arg-index]
  (cond
    ;; (get m k not-found) / (nth coll i not-found) return `not-found`
    ;; itself on a miss, so the result may be the argument.
    (and (or (= func :get) (= func :nth)) (= arg-index 2))
    true

    (contains? non-escaping-fns func)
    false

    :else
    (case func
      (:dissoc :disj) (= arg-index 0)
      (:rest :next :seq) (= arg-index 0)
      :transient (= arg-index 0)
      (:assoc! :conj!) (= arg-index 0)
      :persistent! (= arg-index 0)
      ;; Default: escapes (vector, hash-map, assoc, conj, cons, etc.)
      true)))
;; ── Find call result for a use ──────────────────────────────────────────────
(defn find-call-result
  "Find the first `:call-known` instruction of `known-fn` in block
  `block-id` of `ir-func` that has `used-var` among its `:args` (in any
  position) and a truthy `:dst`, and return that `:dst`. Returns nil when
  the block does not exist or no such call is present."
  [used-var known-fn ir-func block-id]
  (let [block (some (fn [b] (when (= block-id (:id b)) b)) (:blocks ir-func))
        passes-var? (fn [inst]
                      (and (= :call-known (:op inst))
                           (= known-fn (:func inst))
                           (some (fn [a] (= used-var a)) (:args inst))))]
    (when block
      (first (filter (fn [d] d)
                     (map :dst (filter passes-var? (:insts block))))))))
(defn find-unknown-call-with-arg
  "Return the first `:call` instruction in block `block-id` of `ir-func`
  whose `:callee` is `callee-var` and whose `:args` include `arg-var`
  (at any position), or nil when the block or such a call is missing.
  The whole instruction map is returned so that callers can read both
  `:dst` and the argument list."
  [ir-func callee-var arg-var block-id]
  (let [block (some (fn [b] (when (= block-id (:id b)) b)) (:blocks ir-func))
        calls-with-arg? (fn [inst]
                          (and (= :call (:op inst))
                               (= callee-var (:callee inst))
                               (some (fn [a] (= arg-var a)) (:args inst))))]
    (when block
      (first (filter calls-with-arg? (:insts block))))))
;; ── Inter-procedural summary support ────────────────────────────────────────
;;
;; A function summary maps each parameter index to one of:
;;
;; :no-escape — the parameter is read but no value derived from it
;; leaks past the function body
;; :returns — the parameter (or a value derived from it) flows into
;; the function's return value; the caller can treat the
;; call result as the parameter's downstream owner
;; :escapes — the parameter (or a derived value) is stored into a
;; heap collection or a global Var, captured by any closure,
;; thrown, passed to `recur`, or passed to a call whose target
;; the analysis cannot resolve
;;
;; Summaries let alloc-site analysis treat calls to user-defined
;; functions in the same compilation unit precisely instead of bailing
;; with the conservative `:arg-escape`.
(defn walk-functions
  "Return a seq of `root` followed by every function nested under it via
  `:subfunctions`, visited depth-first in pre-order (each parent before
  its children, siblings left to right)."
  [root]
  (let [children (or (:subfunctions root) [])]
    (concat [root] (mapcat walk-functions children))))
(defn build-defn-map
  "Scan every function of the IR tree rooted at `root` for
  `(def-var ns name v)` instructions whose `v` is the `:dst` of an
  `:alloc-closure` in the same block. Returns
  `{[ns name] {:arity-fn-names ... :param-counts ... :is-variadic ...}}`
  so that a call through the global Var can later be resolved to a
  specific arity body. When a name is defined more than once, the last
  definition visited wins."
  [root]
  (let [closure-info (fn [inst]
                       {:arity-fn-names (:arity-fn-names inst)
                        :param-counts (:param-counts inst)
                        :is-variadic (:is-variadic inst)})
        scan-block (fn [acc block]
                     (let [insts (:insts block)
                           ;; alloc-closure dst -> arity info, for this block only
                           closures (reduce (fn [m inst]
                                              (if (= :alloc-closure (:op inst))
                                                (assoc m (:dst inst) (closure-info inst))
                                                m))
                                            {}
                                            insts)]
                       (reduce (fn [a inst]
                                 (if (and (= :def-var (:op inst))
                                          (contains? closures (:value inst)))
                                   (assoc a [(:ns inst) (:name inst)]
                                          (get closures (:value inst)))
                                   a))
                               acc
                               insts)))]
    (reduce (fn [acc f] (reduce scan-block acc (:blocks f)))
            {}
            (walk-functions root))))
(defn build-fn-registry
  "Map the `:name` of every named function in the IR tree rooted at
  `root` (an arity-fn-name such as `__cljrs_fn_user_foo_42_arity1`) to
  that function's IR map. Unnamed functions are skipped; for a repeated
  name, the last function visited wins."
  [root]
  (into {}
        (map (fn [f] [(:name f) f])
             (filter (fn [f] (:name f)) (walk-functions root)))))
(defn build-var-defs
  "Return `{var-id defining-instruction}` for every phi or instruction in
  `ir-func` that has a `:dst`. Call-target resolution uses it to chase
  the chain of instructions that produced a callee var."
  [ir-func]
  (let [all-insts (mapcat (fn [block] (concat (:phis block) (:insts block)))
                          (:blocks ir-func))
        defining (filter (fn [inst] (:dst inst)) all-insts)]
    (into {} (map (fn [inst] [(:dst inst) inst]) defining))))
(defn pick-arity
  "Return the arity-fn-name of the first non-variadic arity in `info`
  whose param count equals `arg-count`, or nil when `info` is nil or no
  arity matches. Variadic arities are never chosen, because packing the
  rest list would require knowing where to split the arguments, which
  is not tracked."
  [info arg-count]
  (when info
    (let [names (:arity-fn-names info)
          counts (:param-counts info)
          variadic (:is-variadic info)
          n (count names)]
      (loop [i 0]
        (when (< i n)
          (let [candidate (nth names i)]
            (if (and (= arg-count (nth counts i))
                     (not (nth variadic i))
                     candidate)
              candidate
              (recur (inc i)))))))))
(defn resolve-call-target
  "Try to resolve the callee var of a `:call` with `arg-count` arguments
  to a concrete arity-fn-name, using `var-defs` (var -> defining
  instruction) and `defn-map` (`[ns name]` -> closure arity info).
  Recognised shapes of the callee's defining instruction:
  1. `:alloc-closure`: a let-bound function; its own arity info is used.
  2. `:load-global`: the IR's value-load of a global, the usual
     lowering of `(my-fn ...)` against a top-level `defn`.
  3. `:deref` whose source was produced by `:load-global` or
     `:load-var`: the explicit Var-deref shape, kept for completeness.
  Shapes 2 and 3 look up the global's `[ns name]` in `defn-map`. Any
  other shape, or a callee with no known definition, yields nil."
  [callee-var arg-count var-defs defn-map]
  (let [def-inst (get var-defs callee-var)
        via-global (fn [inst]
                     (pick-arity (get defn-map [(:ns inst) (:name inst)]) arg-count))]
    (when def-inst
      (let [op (:op def-inst)]
        (cond
          (= op :alloc-closure)
          (pick-arity {:arity-fn-names (:arity-fn-names def-inst)
                       :param-counts (:param-counts def-inst)
                       :is-variadic (:is-variadic def-inst)}
                      arg-count)

          (= op :load-global)
          (via-global def-inst)

          (= op :deref)
          (let [src-def (get var-defs (:src def-inst))]
            (when (and src-def
                       (contains? #{:load-global :load-var} (:op src-def)))
              (via-global src-def)))

          :else
          nil)))))
;; The mutual recursion between `classify-escape-with-ctx` (defined
;; below) and per-function summary computation is broken by routing
;; the summary call through `(:summary-fn ctx)`. `make-context`
;; installs `compute-fn-summary` as the summary-fn after both are
;; defined. This avoids a forward declaration that the runtime
;; doesn't currently support.
(defn- max-state
  "Join of two escape states in the lattice
  :no-escape ⊑ :arg-escape ⊑ :returns ⊑ :escapes.
  Any value outside the lattice is treated as :no-escape. (`:arg-escape`
  is unused in `:param` mode, where summaries promote unresolvable calls
  straight to `:escapes`.)"
  [a b]
  (let [highest (some (fn [s] (when (or (= a s) (= b s)) s))
                      [:escapes :returns :arg-escape])]
    (or highest :no-escape)))
;; ── Classify escape state ───────────────────────────────────────────────────
(defn- escape-block-by-id
  "Return the block of `ir-func` whose `:id` equals `block-id`, or nil."
  [ir-func block-id]
  (first (filter (fn [b] (= (:id b) block-id)) (:blocks ir-func))))

(defn- escape-calls-passing
  "Return every instruction in block `block-id` of `ir-func` whose `:op`
  is `op`, for which `(match? inst)` is truthy, and whose `:args` hold
  `arg-var` at exactly position `arg-index`.
  All matches are returned, not just the first. When a value is passed
  in the same slot to the same function several times in one block,
  each call has its own result, and every one of them must be followed."
  [ir-func block-id op match? arg-var arg-index]
  (filter (fn [inst]
            (and (= (:op inst) op)
                 (match? inst)
                 (= (nth (vec (:args inst)) arg-index nil) arg-var)))
          (:insts (escape-block-by-id ir-func block-id))))

(defn- escape-follow-known-call
  "Worklist step for a `:known-call-arg` use of `current`. If
  `known-fn-arg-escapes?` says the argument can flow into the call's
  result, the `:dst` of every matching call in the use's block is
  queued. If none of those calls has a `:dst`, the value is
  conservatively reported as `:escapes`."
  [acc kind use-info current ir-func]
  (let [func (:func kind)
        arg-index (:arg-index kind)]
    (if (known-fn-arg-escapes? func arg-index)
      (let [calls (escape-calls-passing ir-func (:block use-info) :call-known
                                        (fn [inst] (= (:func inst) func))
                                        current arg-index)
            dsts (filter (fn [d] d) (map :dst calls))]
        (if (seq dsts)
          (update acc :new-worklist into dsts)
          (reduced (assoc acc :state :escapes))))
      acc)))

(defn- escape-follow-unknown-call
  "Worklist step for an `:unknown-call-arg` use of `current`.
  Every `:call` in the use's block that passes `current` to
  `(:callee kind)` in slot `(:arg-index kind)` is resolved with
  `resolve-call-target`. If all of them resolve to a function in
  `(:registry ctx)` and `(:summary-fn ctx)` is present, each call's
  summary entry for that slot decides the outcome:
  - `:no-escape` leaves `acc` unchanged;
  - `:returns` queues that call's `:dst`;
  - anything else yields `:escapes`.
  Otherwise the conservative rule applies: `:escapes` in `:param` mode.
  In `:alloc` mode the state becomes `:arg-escape` if it was still
  `:no-escape`, and `:escapes` otherwise."
  [acc kind use-info current ir-func ctx mode]
  (let [callee (:callee kind)
        arg-idx (:arg-index kind)
        var-defs (:var-defs ctx)
        summary-fn (:summary-fn ctx)
        calls (vec (escape-calls-passing ir-func (:block use-info) :call
                                         (fn [inst] (= (:callee inst) callee))
                                         current arg-idx))
        ;; One resolved target ir-func per call; nil when unresolvable.
        targets (mapv (fn [inst]
                        (when (and var-defs summary-fn)
                          (let [target-name (resolve-call-target
                                             callee (count (:args inst))
                                             var-defs (:defn-map ctx))]
                            (when target-name
                              (get (:registry ctx) target-name)))))
                      calls)]
    (if (and (seq calls) (not (some nil? targets)))
      (let [stepped (reduce
                     (fn [a [inst target-fn]]
                       (case (nth (summary-fn target-fn ctx) arg-idx :escapes)
                         :no-escape a
                         :returns (if (:dst inst)
                                    (update a :new-worklist conj (:dst inst))
                                    a)
                         (reduced (assoc a :state :escapes))))
                     acc
                     (map vector calls targets))]
        (if (= (:state stepped) :escapes)
          (reduced stepped)
          stepped))
      ;; Could not resolve every call: conservative fallback.
      (if (= mode :param)
        (reduced (assoc acc :state :escapes))
        (if (= (:state acc) :no-escape)
          (assoc acc :state :arg-escape :callee callee :arg-index arg-idx)
          (reduced (assoc acc :state :escapes)))))))

(defn- escape-follow-phi
  "Worklist step for a `:phi-input` use of `current`: queue the `:dst` of
  every phi in the use's block that lists `current` among its entries."
  [acc use-info current ir-func]
  (reduce (fn [a phi]
            (if (and (= (:op phi) :phi)
                     (some (fn [[_ v]] (= v current)) (:entries phi)))
              (update a :new-worklist conj (:dst phi))
              a))
          acc
          (:phis (escape-block-by-id ir-func (:block use-info)))))

(defn classify-escape-with-ctx
  "Generalised escape classification of the value `var`. It runs a
  worklist over `var` and every value it may flow into, using the
  def-use chains in `uses`.
  `mode` selects how a `:return` use is treated:
  - `:alloc` (alloc-site analysis): `:return` is an escape.
  - `:param` (per-parameter summary computation): `:return` records
    `:returns` and analysis continues. `:returns` is strictly better
    than `:escapes`, and other uses may still force `:escapes`.
  `ctx` may carry inter-procedural lookup tables:
    :var-defs   — `{var-id defining-instruction}` for `ir-func`
    :registry   — `{arity-fn-name ir-func}` across the IR tree
    :defn-map   — `{[ns name] {:arity-fn-names ...}}`
    :summary-fn — `(fn [ir-func ctx] summary-vector)`
    :cache      — atom holding `{arity-fn-name summary-vector}`
    :computing  — atom holding the set of summaries currently in flight
  When the tables are present, an `:unknown-call-arg` use whose callee
  resolves to a known function consults that function's summary.
  Otherwise the conservative behaviour of the single-arg form applies.
  Every call in the use's block that passes the value in the same
  argument slot is followed, not only the first one.
  Returns `:no-escape`, `:arg-escape` (`:alloc` mode only), `:returns`
  (`:param` mode only), or `:escapes`."
  [var uses ir-func ctx mode]
  (loop [worklist [var]
         visited #{}
         result :no-escape]
    (if (empty? worklist)
      result
      (let [current (first worklist)
            rest-wl (rest worklist)]
        (if (contains? visited current)
          (recur rest-wl visited result)
          (let [visited (conj visited current)
                use-list (get uses current)]
            (if (nil? use-list)
              (recur rest-wl visited result)
              (let [check-result
                    (reduce
                     (fn [acc use-info]
                       (if (= (:state acc) :escapes)
                         (reduced acc)
                         (let [kind (:kind use-info)]
                           (case (:type kind)
                             :return
                             (if (= mode :param)
                               (assoc acc :state (max-state (:state acc) :returns))
                               (reduced (assoc acc :state :escapes)))

                             (:def-var :set-bang :closure-capture :throw :stored-in-heap :recur)
                             (reduced (assoc acc :state :escapes))

                             :unknown-call-arg
                             (escape-follow-unknown-call acc kind use-info current ir-func ctx mode)

                             :known-call-arg
                             (escape-follow-known-call acc kind use-info current ir-func)

                             :phi-input
                             (escape-follow-phi acc use-info current ir-func)

                             ;; :branch-cond, :deref, :call-callee (and any
                             ;; unlisted kind) are treated as non-escaping.
                             acc))))
                     {:state result :new-worklist [] :callee nil :arg-index nil}
                     use-list)]
                (if (= (:state check-result) :escapes)
                  :escapes
                  (recur (into (vec rest-wl) (:new-worklist check-result))
                         visited
                         (:state check-result)))))))))))
(defn classify-escape
  "Classify whether the allocation `var` escapes `ir-func`, in
  alloc-site mode and with an empty context. Without inter-procedural
  tables, every call to a user-defined function conservatively yields
  `:arg-escape`. `classify-escape-with-ctx` is the context-aware form."
  [var uses ir-func]
  (let [no-ctx {}
        mode :alloc]
    (classify-escape-with-ctx var uses ir-func no-ctx mode)))
(defn compute-fn-summary
  "Compute the per-parameter escape summary for `ir-func`: a vector with
  one entry per parameter, each `:no-escape`, `:returns`, or `:escapes`.
  Summaries are memoised in the `:cache` atom of `ctx`, keyed by
  `(:name ir-func)`.
  Recursion handling: if `(:name ir-func)` is already in the
  `:computing` atom, an all-`:escapes` summary is returned. That is the
  conservative answer and is sound for any caller. It loses precision on
  directly- or mutually-recursive functions but avoids non-termination."
  [ir-func ctx]
  (let [fn-name (:name ir-func)
        cache (:cache ctx)
        computing (:computing ctx)
        params (:params ir-func)]
    (cond
      (and fn-name cache (contains? @cache fn-name))
      (get @cache fn-name)

      (and fn-name computing (contains? @computing fn-name))
      ;; Re-entered while this summary is still being built.
      (mapv (fn [_] :escapes) params)

      :else
      (let [tracking? (and fn-name computing)]
        (when tracking? (swap! computing conj fn-name))
        (let [uses (build-use-chains ir-func)
              fn-ctx (assoc ctx :var-defs (build-var-defs ir-func))
              ;; Each param entry carries its VarId in second position.
              summary (mapv (fn [param]
                              (classify-escape-with-ctx (second param) uses ir-func fn-ctx :param))
                            params)]
          (when (and fn-name cache) (swap! cache assoc fn-name summary))
          (when tracking? (swap! computing disj fn-name))
          summary)))))
;; ── Public API ──────────────────────────────────────────────────────────────
(defn make-context
  "Build an inter-procedural analysis context for the IR tree rooted at
  `root`. The map contains the function registry, the defn map, empty
  `:cache` / `:computing` atoms for summary memoisation, and
  `compute-fn-summary` as `:summary-fn`. Pass it to `analyze` or
  `classify-escape-with-ctx` to share these across the whole tree."
  [root]
  (let [registry (build-fn-registry root)
        defn-map (build-defn-map root)]
    {:registry registry
     :defn-map defn-map
     :cache (atom {})
     :computing (atom #{})
     :summary-fn compute-fn-summary}))
(defn analyze
  "Run escape analysis on an IR function (data map).
  Returns `{:states {var-id state} :uses {var-id [use-info]}
           :alloc-blocks {alloc-var defining-block-id}}`,
  where `state` is `:no-escape`, `:arg-escape`, or `:escapes`.
  `:alloc-blocks` exposes where each allocation lives, so that
  downstream passes (notably `cljrs.compiler.optimize`) can compute
  region scopes via the dominator tree.
  The two-arg form takes an inter-procedural context (see
  `make-context`). The single-arg form analyses `ir-func` in isolation:
  every call to a user-defined function conservatively yields
  `:arg-escape`."
  ([ir-func]
   (analyze ir-func {}))
  ([ir-func ctx]
   (let [alloc-blocks (collect-allocs ir-func)
         uses (build-use-chains ir-func)
         fn-ctx (assoc ctx :var-defs (build-var-defs ir-func))
         classify (fn [alloc-var]
                    (classify-escape-with-ctx alloc-var uses ir-func fn-ctx :alloc))
         states (into {} (map (fn [alloc-var] [alloc-var (classify alloc-var)])
                              (keys alloc-blocks)))]
     {:states states :uses uses :alloc-blocks alloc-blocks})))
;; ── Collection chain detection ──────────────────────────────────────────────
(def chainable-ops
  ;; Builtins that `detect-collection-chains` may link into a chain,
  ;; threading the collection argument (index 0) from call to call.
  #{:conj :assoc
    :disj :dissoc})
(defn detect-collection-chains
  "Detect runs of `:call-known` assoc/conj/dissoc/disj calls (see
  `chainable-ops`) within a single block, where each call takes the
  previous call's `:dst` as its collection argument (index 0) and that
  intermediate result has at most one recorded use in
  `(:uses escape-analysis)`. Only `:uses` is read from
  `escape-analysis`; the escape states are not consulted. Chains never
  cross block boundaries and are reported only when they contain at
  least two ops.
  Returns a vector of chain maps:
  `{:root var-id :ops [{:func f :result var-id :args [...]} ...] :result var-id}`
  where `:root` is the collection fed to the first op, each op's `:args`
  are its non-collection arguments, and `:result` is the last op's `:dst`."
  [ir-func escape-analysis]
  (let [uses (:uses escape-analysis)]
    (reduce
     (fn [chains block]
       (let [result
             (reduce
              ;; State per instruction: the chain currently being grown,
              ;; plus chains (>= 2 ops) already flushed in this block.
              (fn [{:keys [current-chain found-chains]} inst]
                (if (and (= (:op inst) :call-known)
                         (contains? chainable-ops (:func inst))
                         (seq (:args inst)))
                  (let [collection-arg (first (:args inst))
                        other-args (vec (rest (:args inst)))]
                    (if (and current-chain (= (:result current-chain) collection-arg))
                      ;; Try to extend chain
                      (let [intermediate-uses (get uses (:result current-chain))
                            ;; One recorded use means this call is the only
                            ;; consumer of the intermediate collection.
                            single-use? (or (nil? intermediate-uses) (= 1 (count intermediate-uses)))]
                        (if single-use?
                          {:current-chain (-> current-chain
                                              (update :ops conj {:func (:func inst)
                                                                 :result (:dst inst)
                                                                 :args other-args})
                                              (assoc :result (:dst inst)))
                           :found-chains found-chains}
                          ;; Can't extend — flush and start new
                          {:current-chain {:root collection-arg
                                           :ops [{:func (:func inst)
                                                  :result (:dst inst)
                                                  :args other-args}]
                                           :result (:dst inst)}
                           :found-chains (if (>= (count (:ops current-chain)) 2)
                                           (conj found-chains current-chain)
                                           found-chains)}))
                      ;; Start new chain (the previous one, if long enough,
                      ;; is kept; a single-op chain is dropped)
                      (let [new-chains (if (and current-chain (>= (count (:ops current-chain)) 2))
                                         (conj found-chains current-chain)
                                         found-chains)]
                        {:current-chain {:root collection-arg
                                         :ops [{:func (:func inst)
                                                :result (:dst inst)
                                                :args other-args}]
                                         :result (:dst inst)}
                         :found-chains new-chains})))
                  ;; Non-chainable instruction — don't flush. If it reads
                  ;; the intermediate, the single-use check above stops
                  ;; the chain from being extended through it.
                  {:current-chain current-chain :found-chains found-chains}))
              {:current-chain nil :found-chains []}
              (:insts block))
             ;; Flush at block end
             final-chains (if (and (:current-chain result)
                                   (>= (count (:ops (:current-chain result))) 2))
                            (conj (:found-chains result) (:current-chain result))
                            (:found-chains result))]
         (into chains final-chains)))
     []
     (:blocks ir-func))))