(ns cljrs.compiler.optimize
(:require [cljrs.compiler.escape :as escape]
[cljrs.compiler.ir :as ir]))
;; ── Region allocation optimization ──────────────────────────────────────────
;;
;; Rewrites non-escaping allocations into region-backed allocations. An
;; allocation can be region-allocated only if there exists a contiguous
;; subgraph of the CFG within which the allocation is born, used, and
;; never reachable from blocks that may unwind out of it. We implement
;; that by computing, for each `:no-escape` allocation:
;;
;; start = lca-of-doms({def-block} ∪ use-blocks)
;; end = lca-of-postdoms({def-block} ∪ use-blocks)
;;
;; If the path `start..end` is acyclic (no back-edges) and contains no
;; throw-style terminators, we wrap it in a region: `RegionStart` at the
;; head of `start`, `RegionEnd` at the tail of `end` (immediately before
;; its terminator, so uses inside `end` still run before cleanup), and
;; the alloc instruction is rewritten in place to `RegionAlloc`.
;;
;; Limitations of this first pass (deliberate; revisit later):
;;
;; - one region per allocation (no grouping of allocations sharing a
;; start/end pair)
;; - allocations whose region would span a back-edge are skipped
;; (loops not optimized)
;; - allocations whose region would contain a `:throw` or
;; `:unreachable` terminator are skipped (no unwind-aware cleanup)
;; - if `start` or `end` cannot be determined, the allocation is left
;; as a regular `Alloc*`
;; Lookup table from an allocation instruction's `:op` keyword to the
;; region-allocation kind keyword handed to `ir/inst-region-alloc`.
;; Membership in this map is also what `rewrite-alloc-in-block` uses to
;; decide whether an instruction is an allocation it may rewrite.
(def alloc-op-to-region-kind
{:alloc-vector :vector
:alloc-map :map
:alloc-set :set
:alloc-list :list
:alloc-cons :cons})
(defn alloc-operands
  "Return the operand VarIds read by allocation instruction `inst`.

  - vector / set / list allocations: the instruction's `:elems` as-is;
  - map allocations: the `:pairs` flattened to `[k1 v1 k2 v2 ...]`;
  - cons allocations: `[head tail]`;
  - anything else: an empty vector."
  [inst]
  (let [op (:op inst)]
    (cond
      (contains? #{:alloc-vector :alloc-set :alloc-list} op)
      (:elems inst)

      (= op :alloc-map)
      ;; Flatten key/value pairs, keeping each key directly before its value.
      (reduce (fn [flat [k v]] (into flat [k v])) [] (:pairs inst))

      (= op :alloc-cons)
      [(:head inst) (:tail inst)]

      :else
      [])))
;; ── CFG analysis ────────────────────────────────────────────────────────────
(defn block-successors
  "Return a vector of the block IDs that `block`'s terminator can
  transfer control to.

  `:jump` and `:recur-jump` yield their single `:target`; `:branch`
  yields `[then-block else-block]`. Every other terminator (including
  `:return`, `:throw`, `:unreachable` and a missing terminator) yields
  an empty vector."
  [block]
  (let [term (:terminator block)
        op   (:op term)]
    (cond
      (or (= op :jump) (= op :recur-jump))
      [(:target term)]

      (= op :branch)
      [(:then-block term) (:else-block term)]

      :else
      [])))
(defn block-by-id-map
  "Index the blocks of `ir-func` by their `:id`, returning
  `{block-id block}`. When two blocks share an id the later one wins."
  [ir-func]
  (loop [pending (:blocks ir-func)
         index   {}]
    (if (empty? pending)
      index
      (let [blk (first pending)]
        (recur (rest pending) (assoc index (:id blk) blk))))))
(defn predecessor-map
  "Return `{block-id #{pred-block-ids}}` for the IR function. Only
  blocks that have at least one incoming edge get an entry."
  [ir-func]
  (let [;; Record the CFG edge `from -> to` in the predecessor index.
        add-edge (fn [preds from to]
                   (assoc preds to (conj (get preds to #{}) from)))]
    (reduce (fn [preds blk]
              (let [from (:id blk)]
                (reduce (fn [p to] (add-edge p from to))
                        preds
                        (block-successors blk))))
            {}
            (:blocks ir-func))))
(defn reverse-postorder
  "Iterative depth-first walk starting at block 0; returns the visited
  block IDs as a vector in reverse-postorder. Blocks not reachable from
  block 0 are omitted.

  The explicit stack holds `[block-id exiting?]` frames: a frame with
  `exiting?` false means the block still has to be expanded, a frame
  with `exiting?` true marks the moment all of its successors have been
  handled and the block can be emitted in postorder."
  [ir-func]
  (let [blocks (block-by-id-map ir-func)]
    (loop [work     [[0 false]]
           seen     #{}
           finished []]
      (if (empty? work)
        (vec (reverse finished))
        (let [[id exiting?] (peek work)
              remaining     (pop work)]
          (if exiting?
            ;; Every successor has been processed: emit in postorder.
            (recur remaining seen (conj finished id))
            (if (contains? seen id)
              ;; Duplicate frame for a block that was already expanded.
              (recur remaining seen finished)
              (let [blk     (get blocks id)
                    targets (if blk (block-successors blk) [])
                    frames  (map (fn [t] [t false]) targets)]
                ;; Push the exit marker first so it pops after the successors.
                (recur (into (conj remaining [id true]) frames)
                       (conj seen id)
                       finished)))))))))
;; ── Dominator analysis (iterative fixed-point) ──────────────────────────────
(defn- intersect-sets
  "Intersection of a seq of sets, implemented without `clojure.set`.
  Returns nil when `sets` is empty."
  [sets]
  (let [;; Drop from `kept` every element that `other` does not contain.
        keep-shared (fn [kept other]
                      (reduce (fn [out x]
                                (if (contains? other x) out (disj out x)))
                              kept
                              kept))]
    (reduce (fn [kept other]
              ;; A nil accumulator means no set has been seen yet.
              (if (nil? kept)
                other
                (keep-shared kept other)))
            nil
            sets)))
(defn- dom-iterate
  "Generic iterative (fixed-point) dominator computation.

  `roots`     — set of block IDs whose dominator set is pinned to just
                themselves; for forward dominators this is `#{0}`.
  `block-ids` — all block IDs to consider (typically reverse-postorder).
  `pred-fn`   — `block-id -> #{pred-block-ids}`.

  Every non-root block starts with the full universe of `block-ids` and
  is repeatedly narrowed to `{itself} ∪ (∩ dominators of its preds)`
  until a full sweep changes nothing. Predecessors outside `block-ids`
  are ignored, and a non-root block with no considered predecessor
  keeps its current set.

  Returns `{block-id #{dominators}}`."
  [roots block-ids pred-fn]
  (let [universe (set block-ids)
        root?    (fn [b] (contains? roots b))
        initial  (reduce (fn [m b] (assoc m b (if (root? b) #{b} universe)))
                         {}
                         block-ids)
        ;; Recompute the dominator set of one block from its predecessors.
        refine   (fn [d b]
                   (if (root? b)
                     d
                     (let [known (filter (fn [p] (contains? d p)) (pred-fn b))]
                       (if (empty? known)
                         d
                         (let [shared (intersect-sets (map (fn [p] (get d p)) known))]
                           (assoc d b (conj (or shared #{}) b)))))))
        ;; One pass over every block, in the given order.
        sweep    (fn [d] (reduce refine d block-ids))]
    (loop [current initial]
      (let [updated (sweep current)]
        (if (= updated current)
          current
          (recur updated))))))
(defn dominators
  "Return `{block-id #{dominator-block-ids}}` for the blocks reachable
  from block 0. Block 0 is the single root and is dominated only by
  itself; a block's set always contains the block itself."
  [ir-func]
  (let [preds    (predecessor-map ir-func)
        preds-of (fn [b] (get preds b #{}))]
    (dom-iterate #{0} (reverse-postorder ir-func) preds-of)))
(defn- collect-exits
  "Return the set of IDs of blocks whose terminator ends the function:
  `:return`, `:throw` or `:unreachable`."
  [ir-func]
  (let [exit-ops #{:return :throw :unreachable}
        exit?    (fn [blk] (contains? exit-ops (:op (:terminator blk))))]
    (set (map (fn [blk] (:id blk))
              (filter exit? (:blocks ir-func))))))
(defn post-dominators
  "Return `{block-id #{post-dominator-block-ids}}` for the blocks
  reachable from block 0.

  This is the dominator computation run on the reversed CFG: every exit
  block (see `collect-exits`) is a root that post-dominates only
  itself, and a block's successors play the role of predecessors."
  [ir-func]
  (let [blocks        (block-by-id-map ir-func)
        successors-of (fn [b] (set (block-successors (get blocks b))))]
    (dom-iterate (collect-exits ir-func)
                 (reverse-postorder ir-func)
                 successors-of)))
;; ── Lowest-common-ancestor in a dominator-tree-shaped relation ──────────────
;;
;; `dom-of` is `{block #{dominators}}`. The LCA of two blocks A and B is
;; the dominator of both that is itself dominated by every other common
;; dominator — equivalently, the largest element of `(∩ dom-of[A]
;; dom-of[B])` under the dominator partial order.
(defn lca-of
  "Lowest common ancestor of `a` and `b` in the dominator tree induced
  by `dom-of` (`{block #{dominators}}`): the common dominator that every
  other common dominator dominates. Returns nil when the two blocks
  share no dominator (e.g. one of them has no entry in `dom-of`)."
  [dom-of a b]
  (let [doms-b (get dom-of b #{})
        shared (into #{}
                     (filter (fn [x] (contains? doms-b x))
                             (get dom-of a #{})))
        ;; Prefer `cand` when the current best dominates it, i.e. when
        ;; `cand` sits lower in the dominator tree.
        deeper (fn [best cand]
                 (if (or (nil? best)
                         (contains? (get dom-of cand) best))
                   cand
                   best))]
    (if (empty? shared)
      nil
      (reduce deeper (first shared) (rest shared)))))
(defn lca-of-many
  "Fold `lca-of` across `blocks`, left to right. Returns nil for an
  empty collection, and nil as soon as any intermediate pair has no
  common ancestor. A single-element collection yields that element."
  [dom-of blocks]
  (loop [acc     (first blocks)
         pending (rest blocks)]
    (if (or (nil? acc) (empty? pending))
      acc
      (recur (lca-of dom-of acc (first pending))
             (rest pending)))))
;; ── Reachability between two blocks (avoiding the end block) ───────────────
(defn- blocks-on-path
  "Return the set of block IDs reachable from `start` when traversal is
  never continued past `end`. `start` is always included; `end` is
  included when it is reached, but its successors are not followed, so
  they are only part of the result if reachable some other way. IDs
  with no matching block are included but not expanded. Used to scan
  the proposed region for back-edges and throw blocks."
  [ir-func start end]
  (let [blocks (block-by-id-map ir-func)
        ;; Successors to follow from `id`; nothing past `end` or for an
        ;; id that has no block.
        expand (fn [id]
                 (let [blk (get blocks id)]
                   (if (and blk (not (= id end)))
                     (block-successors blk)
                     [])))]
    (loop [todo   [start]
           region #{}]
      (cond
        (empty? todo)
        region

        (contains? region (peek todo))
        (recur (pop todo) region)

        :else
        (let [id (peek todo)]
          (recur (into (pop todo) (expand id))
                 (conj region id)))))))
(defn- has-back-edge?
  "True if some edge whose source is in `region-blocks` targets a block
  that is also in `region-blocks` and dominates the source (per `doms`)
  — i.e. the region contains a loop. A self-loop counts."
  [ir-func region-blocks doms]
  (let [blocks          (block-by-id-map ir-func)
        back-edge-from? (fn [src]
                          (let [src-doms (get doms src #{})]
                            (some (fn [dst]
                                    (and (contains? region-blocks dst)
                                         (contains? src-doms dst)))
                                  (block-successors (get blocks src)))))]
    (if (some back-edge-from? region-blocks)
      true
      false)))
(defn- region-contains-throw?
  "True if any block in `region-blocks` could unwind out of the region
  without running our `RegionEnd`: its terminator is `:throw` or
  `:unreachable`, or one of its instructions is a `:throw`. Region
  cleanup today has no exception-handling integration, so an unwind
  would leak the bump-allocated chunk."
  [ir-func region-blocks]
  (let [blocks        (block-by-id-map ir-func)
        unwinding-ops #{:throw :unreachable}
        may-unwind?   (fn [id]
                        (let [blk (get blocks id)]
                          (or (contains? unwinding-ops (:op (:terminator blk)))
                              (some (fn [inst] (= :throw (:op inst)))
                                    (:insts blk)))))]
    (if (some may-unwind? region-blocks)
      true
      false)))
;; ── Use-block collection ───────────────────────────────────────────────────
;;
;; For an allocation, walk the same propagation chain as `classify-escape`
;; — through phi nodes and through the results of escape-style known
;; calls (e.g. `:assoc`) — and collect every block in which the
;; allocation, or any value derived from it, has a use. These blocks
;; bound the region's required lifetime.
(defn- collect-use-blocks
  "Return the set of block IDs in which `alloc-var`, or any value
  derived from it, is used.

  `uses` maps a var to its use records (each with a `:block` and a
  `:kind`). Starting from `alloc-var`, every use contributes its block;
  in addition two kinds of use propagate to further vars that are then
  walked the same way:

  - `:known-call-arg` where `escape/known-fn-arg-escapes?` holds: the
    call's result var (via `escape/find-call-result`), when found;
  - `:phi-input`: the `:dst` of every phi in the use's block that has
    the current var among its entries."
  [alloc-var uses ir-func]
  (let [find-block
        (fn [id]
          (first (filter (fn [b] (= (:id b) id)) (:blocks ir-func))))

        ;; Vars that carry `current` onward through one particular use.
        derived-vars
        (fn [current use-info]
          (let [kind     (:kind use-info)
                use-type (:type kind)]
            (cond
              (= use-type :known-call-arg)
              (if (escape/known-fn-arg-escapes? (:func kind) (:arg-index kind))
                (let [result (escape/find-call-result current
                                                      (:func kind)
                                                      ir-func
                                                      (:block use-info))]
                  (if result [result] []))
                [])

              (= use-type :phi-input)
              (let [blk (find-block (:block use-info))]
                (if blk
                  (mapv (fn [phi] (:dst phi))
                        (filter (fn [phi]
                                  (and (= (:op phi) :phi)
                                       (some (fn [[_ v]] (= v current))
                                             (:entries phi))))
                                (:phis blk)))
                  []))

              :else
              [])))]
    (loop [pending [alloc-var]
           seen    #{}
           blocks  #{}]
      (if (empty? pending)
        blocks
        (let [current   (first pending)
              remaining (vec (rest pending))]
          (if (contains? seen current)
            (recur remaining seen blocks)
            (let [var-uses (or (get uses current) [])]
              (recur (reduce (fn [wl u] (into wl (derived-vars current u)))
                             remaining
                             var-uses)
                     (conj seen current)
                     (reduce (fn [bs u] (conj bs (:block u)))
                             blocks
                             var-uses)))))))))
;; ── The pass ───────────────────────────────────────────────────────────────
(defn- next-fresh-var!
  "Hand out the var id currently stored in `next-var-atom` and advance
  the atom by one."
  [next-var-atom]
  ;; `swap!` yields the incremented value; stepping back by one recovers
  ;; the id that was current before the bump.
  (dec (swap! next-var-atom inc)))
(defn- insert-region-start
  "Return `block` with a `RegionStart` for `region-var` placed in front
  of its existing instructions."
  [block region-var]
  (let [marker (ir/inst-region-start region-var)]
    (assoc block :insts (reduce conj [marker] (:insts block)))))
(defn- insert-region-end
  "Return `block` with a `RegionEnd` for `region-var` placed after its
  existing instructions, i.e. immediately before the terminator.

  The end block post-dominates every use, but it may itself contain
  uses (e.g. the `:count` call in min-key's join block). Those have to
  run before cleanup, so the tail of the block is the only safe spot."
  [block region-var]
  (let [existing (vec (:insts block))
        marker   (ir/inst-region-end region-var)]
    (assoc block :insts (into existing [marker]))))
(defn- rewrite-alloc-in-block
  "Return `block` in which every allocation instruction whose `:dst` is
  `alloc-var` has been replaced by its region-allocate counterpart
  targeting `region-var`. All other instructions are kept as they are."
  [block alloc-var region-var]
  (let [target?   (fn [inst]
                    (and (contains? alloc-op-to-region-kind (:op inst))
                         (= (:dst inst) alloc-var)))
        to-region (fn [inst]
                    (ir/inst-region-alloc (:dst inst)
                                          region-var
                                          (get alloc-op-to-region-kind (:op inst))
                                          (alloc-operands inst)))]
    (assoc block :insts
           (mapv (fn [inst] (if (target? inst) (to-region inst) inst))
                 (:insts block)))))
(defn- update-block-by-id
  "Return `ir-func` with `update-fn` applied to every block whose `:id`
  equals `block-id`; the remaining blocks are left untouched. `:blocks`
  is always a vector in the result."
  [ir-func block-id update-fn]
  (let [touch (fn [blk]
                (if (= (:id blk) block-id)
                  (update-fn blk)
                  blk))]
    (update ir-func :blocks (fn [blocks] (mapv touch blocks)))))
(defn- emit-region-for-alloc
  "Return an updated `ir-func` with one allocation rewritten into a
  region, or `ir-func` unchanged when the rewrite cannot be shown safe.

  `alloc-var` / `alloc-block` identify the allocation, `use-blocks` are
  the blocks holding its (transitive) uses, `doms` / `postdoms` are the
  dominator and post-dominator maps of `ir-func`, and `next-var-atom`
  supplies the fresh region var (only consumed when a region is
  actually emitted).

  With `start` / `end` the dominator / post-dominator LCA of the
  defining block and all use blocks, the rewrite is skipped when:

  - `start` or `end` cannot be determined;
  - `start` does not dominate the defining block;
  - `start` does not dominate `end` — `end` could then be entered
    without `RegionStart` having run;
  - `end` does not post-dominate `start` — some path could then leave
    `start` and reach an exit without running `RegionEnd`;
  - a relevant block lies outside the blocks reachable from `start`
    without passing `end` — `RegionEnd` could then run before a use;
  - the region contains a back-edge or a throw-style block."
  [ir-func alloc-var alloc-block use-blocks doms postdoms next-var-atom]
  (let [relevant (conj (set use-blocks) alloc-block)
        start    (lca-of-many doms relevant)
        end      (lca-of-many postdoms relevant)]
    (cond
      (or (nil? start) (nil? end))
      ir-func

      ;; Defining block must be dominated by `start` (start ≤ alloc-block
      ;; in the dominator tree) — otherwise inserting RegionStart there
      ;; doesn't actually precede the alloc.
      (not (contains? (get doms alloc-block #{}) start))
      ir-func

      ;; Single entry: every path into `end` must have passed `start`,
      ;; otherwise RegionEnd would run for a region that was never opened.
      (not (contains? (get doms end #{}) start))
      ir-func

      ;; Single exit: every path out of `start` must pass `end`,
      ;; otherwise an early exit on a side branch leaks the region.
      (not (contains? (get postdoms start #{}) end))
      ir-func

      :else
      (let [region (blocks-on-path ir-func start end)]
        (cond
          ;; Defensive: the region must cover the definition and every
          ;; use. This can fail when post-dominator sets are degenerate
          ;; (blocks that never reach an exit keep the initial universe).
          (some (fn [b] (not (contains? region b))) relevant)
          ir-func

          (has-back-edge? ir-func region doms)
          ir-func

          (region-contains-throw? ir-func region)
          ir-func

          :else
          (let [region-var (next-fresh-var! next-var-atom)]
            (-> ir-func
                (update-block-by-id alloc-block
                                    (fn [b] (rewrite-alloc-in-block b alloc-var region-var)))
                (update-block-by-id start
                                    (fn [b] (insert-region-start b region-var)))
                (update-block-by-id end
                                    (fn [b] (insert-region-end b region-var))))))))))
(defn optimize-regions
  "Rewrite every `:no-escape` allocation of `ir-func` into a
  region-allocate scoped over the dominator subgraph that covers the
  allocation and all of its (transitive) uses. Allocations that fail the
  safety constraints stay regular allocations. When any candidate
  exists, the result's `:next-var` reflects the region vars handed out.

  `ctx` is an inter-procedural analysis context (see
  `escape/make-context`) — supplying one lets escape analysis refine
  `:unknown-call-arg` uses against per-function summaries instead of
  bailing with `:arg-escape`. The one-argument arity uses an empty
  context."
  ([ir-func]
   (optimize-regions ir-func {}))
  ([ir-func ctx]
   (let [analysis   (escape/analyze ir-func ctx)
         states     (:states analysis)
         uses       (:uses analysis)
         candidates (filterv (fn [[v _]] (= :no-escape (get states v)))
                             (:alloc-blocks analysis))]
     (if (empty? candidates)
       ir-func
       (let [doms        (dominators ir-func)
             postdoms    (post-dominators ir-func)
             counter     (atom (:next-var ir-func))
             ;; CFG facts come from the original function: the rewrite
             ;; only inserts instructions, it never changes block shape.
             rewrite-one (fn [func [alloc-var alloc-block]]
                           (let [use-blocks (collect-use-blocks alloc-var uses ir-func)]
                             (emit-region-for-alloc func alloc-var alloc-block
                                                    use-blocks doms postdoms
                                                    counter)))
             rewritten   (reduce rewrite-one ir-func candidates)]
         (assoc rewritten :next-var @counter))))))
(defn- optimize-tree
  "Optimize `ir-func` first and then, recursively, each of its
  `:subfunctions`, passing the same `ctx` everywhere so summary lookups
  share state across the tree. The result always carries a
  `:subfunctions` vector."
  [ir-func ctx]
  (let [parent  (optimize-regions ir-func ctx)
        recurse (fn [sub] (optimize-tree sub ctx))]
    (assoc parent
           :subfunctions
           (mapv recurse (or (:subfunctions parent) [])))))
(defn optimize
  "Entry point: run all optimization passes on an IR function. Today
  that is only the region allocation pass, applied to `ir-func` and,
  recursively, its subfunctions, with one inter-procedural escape
  context (built from `ir-func`) shared across the whole tree."
  [ir-func]
  (optimize-tree ir-func (escape/make-context ir-func)))