(ns cljrs.compiler.escape)
;; ── Escape analysis on IR data ──────────────────────────────────────────────
;;
;; Operates on plain IR data maps (same structure as ir_convert.rs expects).
;; No Rust types needed — pure Clojure data transformation.
;; ── Effect classification ───────────────────────────────────────────────────
;; Instruction opcodes that are treated as allocations by this pass.
(def alloc-ops
  #{:alloc-vector
    :alloc-map
    :alloc-set
    :alloc-list
    :alloc-cons
    :alloc-closure})

(defn alloc-inst?
  "True when the `:op` of `inst` is one of the opcodes in `alloc-ops`."
  [inst]
  (let [op (:op inst)]
    (contains? alloc-ops op)))
;; ── Collect allocating instructions ─────────────────────────────────────────
(defn collect-allocs
  "Return a map of {alloc-var defining-block-id} for every allocation
  instruction in the function.

  Each block's `:phis` and `:insts` are scanned in order; an instruction is
  recorded when `alloc-inst?` accepts it and it has a truthy `:dst`. The
  value stored for that `:dst` is the `:id` of the block holding the
  instruction. `analyze` returns this map under `:alloc-blocks`."
  [ir-func]
  (let [recorder (fn [block-id]
                   ;; Reducing step bound to one block: remembers which block
                   ;; defines each allocation's destination var.
                   (fn [found inst]
                     (let [dst (:dst inst)]
                       (if (and (alloc-inst? inst) dst)
                         (assoc found dst block-id)
                         found))))]
    (reduce (fn [found block]
              (reduce (recorder (:id block))
                      found
                      (concat (:phis block) (:insts block))))
            {}
            (:blocks ir-func))))
;; ── Build def-use chains ────────────────────────────────────────────────────
(defn add-use
  "Append a use record `{:block block-id :kind kind}` to the vector stored
  under `var-id` in `uses`, creating that vector when there is none yet.
  Returns the updated map."
  [uses var-id block-id kind]
  (let [entry {:block block-id :kind kind}
        existing (or (get uses var-id) [])]
    (assoc uses var-id (conj existing entry))))
(defn add-uses-for-inst
  "Record in `uses` every variable that `inst` reads, attributing each use
  to `block-id`.

  The instruction's operands are first listed as ordered `[var-id kind]`
  pairs and then folded through `add-use`, so uses are recorded in operand
  order. The use `:type` depends on the opcode:
    :call-known            -> :known-call-arg (with :func and :arg-index)
    :call                  -> :call-callee for the callee, then
                              :unknown-call-arg (with :callee and :arg-index)
    :alloc-closure         -> :closure-capture for each capture
    :alloc-vector/-set/-list, :alloc-map, :alloc-cons
                           -> :stored-in-heap for each stored operand
    :def-var, :deref, :throw, :recur, :set!, :phi
                           -> :def-var, :deref, :throw, :recur, :set-bang,
                              :phi-input respectively
  Any other opcode leaves `uses` untouched."
  [uses inst block-id]
  (let [tagged (fn [kind vars]
                 (map (fn [v] [v kind]) vars))
        heap {:type :stored-in-heap}
        operands
        (case (:op inst)
          :call-known
          (map-indexed (fn [i arg]
                         [arg {:type :known-call-arg
                               :func (:func inst)
                               :arg-index i}])
                       (:args inst))

          :call
          (let [callee (:callee inst)]
            ;; The callee itself is recorded before any argument.
            (concat [[callee {:type :call-callee}]]
                    (map-indexed (fn [i arg]
                                   [arg {:type :unknown-call-arg
                                         :callee callee
                                         :arg-index i}])
                                 (:args inst))))

          :alloc-closure
          (tagged {:type :closure-capture} (:captures inst))

          (:alloc-vector :alloc-set :alloc-list)
          (tagged heap (:elems inst))

          :alloc-map
          ;; Key first, then value, for each pair.
          (reduce (fn [acc [k v]]
                    (-> acc
                        (conj [k heap])
                        (conj [v heap])))
                  []
                  (:pairs inst))

          :alloc-cons
          (tagged heap [(:head inst) (:tail inst)])

          :def-var
          [[(:value inst) {:type :def-var}]]

          :set!
          (tagged {:type :set-bang} [(:var inst) (:value inst)])

          :deref
          [[(:src inst) {:type :deref}]]

          :throw
          [[(:value inst) {:type :throw}]]

          :recur
          (tagged {:type :recur} (:args inst))

          :phi
          (map (fn [[_ v]] [v {:type :phi-input}]) (:entries inst))

          ;; const, load-local, load-global, source-loc — no uses
          [])]
    (reduce (fn [acc [var-id kind]]
              (add-use acc var-id block-id kind))
            uses
            operands)))
(defn add-uses-for-terminator
  "Record in `uses` the variables read by the block terminator `term`,
  attributing each use to `block-id`.

  `:return` records its `:var` as a `:return` use, `:branch` records its
  `:cond` as a `:branch-cond` use, and `:recur-jump` records each of its
  `:args` as a `:recur` use. Any other terminator leaves `uses` untouched."
  [uses term block-id]
  (let [record (fn [acc var-id use-type]
                 (add-use acc var-id block-id {:type use-type}))]
    (case (:op term)
      :return     (record uses (:var term) :return)
      :branch     (record uses (:cond term) :branch-cond)
      :recur-jump (reduce (fn [acc arg] (record acc arg :recur))
                          uses
                          (:args term))
      ;; jump, unreachable — no uses
      uses)))
(defn build-use-chains
  "Build def-use chains: a map from each used VarId to the vector of use
  records produced by `add-use`.

  Blocks are visited in the order of `(:blocks ir-func)`; within a block the
  phis are scanned first, then the ordinary instructions, then the
  terminator."
  [ir-func]
  (reduce (fn [uses block]
            (let [block-id (:id block)
                  scan (fn [acc insts]
                         (reduce (fn [u inst]
                                   (add-uses-for-inst u inst block-id))
                                 acc
                                 insts))]
              (-> uses
                  (scan (:phis block))
                  (scan (:insts block))
                  (add-uses-for-terminator (:terminator block) block-id))))
          {}
          (:blocks ir-func)))
;; ── Known function arg escape semantics ─────────────────────────────────────
;; Known functions where no argument escapes.
(def non-escaping-fns
  #{:get :nth :first :count :contains :+ :- :* :/ :rem
    := :< :> :<= :>= :nil? :seq? :vector? :map? :identical?
    :str :deref :atom-deref :println :pr})

(defn known-fn-arg-escapes?
  "Does the known function `func` allow its argument at `arg-index` to escape
  into the call's return value?

  Returns false for every function in `non-escaping-fns`. For the functions
  listed explicitly below only the collection argument (index 0) flows into
  the result. Every other function is treated conservatively: all of its
  arguments are assumed to escape into the result.

  `:assoc!` and `:conj!` deliberately fall through to that default. Their
  key/value/element arguments are stored into the returned collection just
  like those of `:assoc` / `:conj`, so reporting them as non-escaping would
  be unsound."
  [func arg-index]
  (if (contains? non-escaping-fns func)
    false
    (case func
      ;; Only the collection survives; the removed keys are not retained.
      (:dissoc :disj) (= arg-index 0)
      (:rest :next :seq) (= arg-index 0)
      :transient (= arg-index 0)
      :persistent! (= arg-index 0)
      ;; Default: escapes (vector, hash-map, assoc, conj, assoc!, conj!,
      ;; cons, etc.)
      true)))
;; ── Find call result for a use ──────────────────────────────────────────────
(defn find-call-result
  "Find the result var of a known call that consumes `used-var`.

  Looks up the block whose `:id` equals `block-id` and returns the `:dst` of
  the first instruction in its `:insts` that is a `:call-known` of
  `known-fn`, has `used-var` somewhere among its `:args`, and has a truthy
  `:dst`. Returns nil when the block does not exist or nothing matches.

  Note that only the first match is reported and the argument position of
  `used-var` is not checked."
  [used-var known-fn ir-func block-id]
  (let [target-block? (fn [b] (= block-id (:id b)))
        block (first (filter target-block? (:blocks ir-func)))
        passes-var? (fn [inst]
                      (some (fn [arg] (= arg used-var)) (:args inst)))
        match? (fn [inst]
                 (and (= :call-known (:op inst))
                      (= known-fn (:func inst))
                      (passes-var? inst)
                      ;; Calls without a result var are skipped.
                      (:dst inst)))]
    (when block
      (:dst (first (filter match? (:insts block)))))))
;; ── Classify escape state ───────────────────────────────────────────────────
(defn known-call-results
  "Collect the result vars of every `:call-known` instruction in block
  `block-id` that calls `known-fn` with `used-var` at argument position
  `arg-index`.

  Returns a vector of the non-nil `:dst` values of the matching
  instructions, in instruction order. The vector is empty when the block
  does not exist, nothing matches, or no matching call has a `:dst`."
  [used-var known-fn arg-index ir-func block-id]
  (let [block (first (filter (fn [b] (= (:id b) block-id)) (:blocks ir-func)))
        var-at-index? (fn [inst]
                        (some (fn [[i a]]
                                (and (= i arg-index) (= a used-var)))
                              (map-indexed vector (:args inst))))]
    (reduce (fn [found inst]
              (if (and (= (:op inst) :call-known)
                       (= (:func inst) known-fn)
                       (var-at-index? inst)
                       (:dst inst))
                (conj found (:dst inst))
                found))
            []
            (:insts block))))

(defn classify-escape
  "Classify whether an allocation escapes the function.

  This is a *semantic* check: it asks whether the allocation's value
  leaves the function frame (returned, stored in a Var, captured by a
  closure, written into a heap-allocated container, etc.). It does
  *not* concern itself with control-flow lifetime — whether the value
  is consumed in the same block as the allocation, or flows through
  phi nodes / known-fn call results into descendant blocks. That
  decision lives in `cljrs.compiler.optimize`, which inspects the
  dominator tree to choose a region scope wide enough to cover all
  uses.

  Arguments:
    var     - the allocation's destination var.
    uses    - the def-use map produced by `build-use-chains`.
    ir-func - the IR function map; its `:blocks` are consulted to follow
              known-call results and phi destinations.

  The walk is a worklist over `var` and every var derived from it: the
  results of known calls that let the argument flow into their return
  value, and the destinations of phis that take it as an input.

  Returns `:no-escape`, `:arg-escape`, or `:escapes`. `:arg-escape` means
  the only escaping use found was a single argument position of a single
  unknown call; a second such use upgrades the result to `:escapes`."
  [var uses ir-func]
  (loop [worklist [var]
         visited #{}
         result :no-escape]
    (if (empty? worklist)
      result
      (let [current (first worklist)
            rest-wl (rest worklist)]
        (if (contains? visited current)
          (recur rest-wl visited result)
          (let [visited (conj visited current)
                use-list (get uses current)]
            (if (nil? use-list)
              (recur rest-wl visited result)
              ;; Check all uses of current
              (let [check-result
                    (reduce
                      (fn [acc use-info]
                        (if (= (:state acc) :escapes)
                          (reduced acc)
                          (let [kind (:kind use-info)
                                kt (:type kind)]
                            (case kt
                              ;; Always causes escape
                              (:return :def-var :set-bang :closure-capture :throw :stored-in-heap :recur)
                              (reduced (assoc acc :state :escapes))

                              :unknown-call-arg
                              (if (= (:state acc) :no-escape)
                                (assoc acc :state :arg-escape
                                       :callee (:callee kind)
                                       :arg-index (:arg-index kind))
                                (reduced (assoc acc :state :escapes)))

                              :known-call-arg
                              (if (known-fn-arg-escapes? (:func kind) (:arg-index kind))
                                ;; The value may flow into the call's result.
                                ;; Follow the result of EVERY call in the block
                                ;; that takes `current` at this argument
                                ;; position: the same var can feed several
                                ;; calls to the same function, and each result
                                ;; may escape independently.
                                (let [call-results (known-call-results current
                                                                       (:func kind)
                                                                       (:arg-index kind)
                                                                       ir-func
                                                                       (:block use-info))]
                                  (if (empty? call-results)
                                    ;; Result cannot be tracked — be conservative.
                                    (reduced (assoc acc :state :escapes))
                                    (update acc :new-worklist into call-results)))
                                acc)

                              :phi-input
                              ;; The phi result inherits the alloc's
                              ;; semantic-escape disposition: if the phi
                              ;; result later escapes, so does the alloc.
                              ;; Propagate to the phi's dst.
                              (let [block (first (filter (fn [b] (= (:id b) (:block use-info))) (:blocks ir-func)))]
                                (if block
                                  (reduce (fn [a phi]
                                            (if (and (= (:op phi) :phi)
                                                     (some (fn [[_ v]] (= v current)) (:entries phi)))
                                              (update a :new-worklist conj (:dst phi))
                                              a))
                                          acc
                                          (:phis block))
                                  acc))

                              ;; branch-cond, deref, call-callee — don't cause escape
                              acc))))
                      ;; The state found so far carries over, so an
                      ;; :arg-escape seen on an earlier var still counts here.
                      {:state result :new-worklist [] :callee nil :arg-index nil}
                      use-list)]
                (if (= (:state check-result) :escapes)
                  :escapes
                  (recur (into (vec rest-wl) (:new-worklist check-result))
                         visited
                         (:state check-result)))))))))))
;; ── Public API ──────────────────────────────────────────────────────────────
(defn analyze
  "Run escape analysis on an IR function (data map).

  Returns `{:states {var-id state} :uses {var-id [use-info]}
            :alloc-blocks {alloc-var defining-block-id}}`
  where `state` is `:no-escape`, `:arg-escape`, or `:escapes`.

  `:states` has one entry per allocation found by `collect-allocs`, each
  classified with `classify-escape` against the use chains from
  `build-use-chains`. The `:alloc-blocks` map exposes where each allocation
  lives so that downstream passes (notably `cljrs.compiler.optimize`) can
  compute region scopes via the dominator tree."
  [ir-func]
  (let [alloc-blocks (collect-allocs ir-func)
        uses (build-use-chains ir-func)
        classify (fn [states alloc-var]
                   (assoc states
                          alloc-var
                          (classify-escape alloc-var uses ir-func)))]
    {:states (reduce classify {} (keys alloc-blocks))
     :uses uses
     :alloc-blocks alloc-blocks}))
;; ── Collection chain detection ──────────────────────────────────────────────
;; Known functions whose calls can be linked into a collection chain.
(def chainable-ops #{:assoc :conj :dissoc :disj})

(defn detect-collection-chains
  "Detect chains of `chainable-ops` calls inside each block, where every call
  takes the previous call's result as its first argument.

  A call extends the chain in progress only when its first argument is that
  chain's current result and the use chains in `(:uses escape-analysis)`
  list at most one use for that result. Otherwise the chain in progress is
  closed and a new one is started at this call. Instructions that are not
  chainable calls neither extend nor close the chain. Chains never span
  blocks, and only chains with at least two ops are reported.

  Returns a vector of chain maps: {:root var-id :ops [...] :result var-id},
  where each op is {:func f :result dst :args remaining-args}."
  [ir-func escape-analysis]
  (let [uses (:uses escape-analysis)
        ;; Keep a finished chain only if it is long enough to be useful.
        keep-if-long (fn [found chain]
                       (if (and chain (>= (count (:ops chain)) 2))
                         (conj found chain)
                         found))
        chainable? (fn [inst]
                     (and (= (:op inst) :call-known)
                          (contains? chainable-ops (:func inst))
                          (seq (:args inst))))
        ;; Can a call on `coll` continue `chain`?
        extends? (fn [chain coll]
                   (and chain
                        (= (:result chain) coll)
                        (let [tail-uses (get uses (:result chain))]
                          (or (nil? tail-uses)
                              (= 1 (count tail-uses))))))
        step (fn [state inst]
               (if (chainable? inst)
                 (let [chain (:current-chain state)
                       found (:found-chains state)
                       coll (first (:args inst))
                       dst (:dst inst)
                       link {:func (:func inst)
                             :result dst
                             :args (vec (rest (:args inst)))}]
                   (if (extends? chain coll)
                     {:current-chain (assoc (update chain :ops conj link)
                                            :result dst)
                      :found-chains found}
                     ;; Close the old chain (if any) and start a new one here.
                     {:current-chain {:root coll
                                      :ops [link]
                                      :result dst}
                      :found-chains (keep-if-long found chain)}))
                 ;; Non-chainable instruction — don't flush
                 state))]
    (reduce (fn [chains block]
              (let [end-state (reduce step
                                      {:current-chain nil :found-chains []}
                                      (:insts block))]
                ;; Flush at block end
                (into chains
                      (keep-if-long (:found-chains end-state)
                                    (:current-chain end-state)))))
            []
            (:blocks ir-func))))