(ns cljrs.compiler.optimize
(:require [cljrs.compiler.escape :as escape]
[cljrs.compiler.ir :as ir]))
;; ── Region allocation optimization ──────────────────────────────────────────
;;
;; Rewrites non-escaping allocations (alloc-vector, alloc-map, etc.) into
;; region-backed allocations (region-start, region-alloc, region-end).
;;
;; Strategy:
;; - Run escape analysis on each IR function
;; - For each block, identify non-escaping allocations
;; - Group them into a region scope per block
;; - Rewrite alloc instructions to region-alloc instructions
;; - Insert region-start at block entry and region-end before terminator
;;
;; Maps each allocation op keyword to the RegionAllocKind keyword passed to
;; region-alloc.
(def alloc-op-to-region-kind
  ;; Only ops listed here are candidates for region rewriting (see the
  ;; `contains?` checks in rewrite-block). Entry order is kept as-is.
  {:alloc-vector :vector
   :alloc-map    :map
   :alloc-set    :set
   :alloc-list   :list
   :alloc-cons   :cons})
(defn alloc-operands
  "Return the operand VarIds of allocation instruction `inst`, in order.

  - :alloc-cons                              -> [head tail]
  - :alloc-map                               -> :pairs flattened to [k1 v1 k2 v2 ...]
  - :alloc-vector / :alloc-set / :alloc-list -> the :elems value exactly as stored
                                                (nil if the key is absent)
  - any other op                             -> []"
  [inst]
  (let [op (:op inst)]
    (case op
      :alloc-cons [(:head inst) (:tail inst)]
      ;; Each pair is destructured as [k v]; both are appended in order.
      :alloc-map (reduce (fn [flat [k v]] (into flat [k v]))
                         []
                         (:pairs inst))
      (:alloc-vector :alloc-set :alloc-list) (:elems inst)
      [])))
(defn rewrite-block
  "Rewrite one IR block so that its non-escaping allocations use a region.

  An instruction is a region candidate when its :op is a key of
  `alloc-op-to-region-kind` and `escape-states` maps its :dst to :no-escape.

  If the block has no candidates, it is returned unchanged and
  `next-var-atom` is not touched. Otherwise:
  - the current value of `next-var-atom` becomes the region var, and the
    atom is incremented;
  - each candidate is replaced by an `ir/inst-region-alloc` carrying its
    :dst, the region var, the mapped kind and its operands;
  - `:insts` becomes [region-start ...rewritten insts... region-end].

  NOTE(review): region-end is appended after the last element of :insts.
  The file header describes it as going 'before terminator'. That only holds
  if the terminator is not stored inside :insts. Confirm against the IR."
  [block escape-states next-var-atom]
  (let [insts (:insts block)
        region-candidate? (fn [inst]
                            (and (contains? alloc-op-to-region-kind (:op inst))
                                 (= :no-escape (get escape-states (:dst inst)))))]
    (if (empty? (filterv region-candidate? insts))
      block
      (let [region-var @next-var-atom]
        ;; Claim a fresh var id for the region handle.
        (swap! next-var-atom inc)
        (let [start (ir/inst-region-start region-var)
              body (mapv (fn [inst]
                           (if (region-candidate? inst)
                             (ir/inst-region-alloc (:dst inst)
                                                   region-var
                                                   (get alloc-op-to-region-kind (:op inst))
                                                   (alloc-operands inst))
                             inst))
                         insts)
              end (ir/inst-region-end region-var)]
          (assoc block :insts (conj (into [start] body) end)))))))
(defn optimize-regions
  "Apply region allocation to a single IR function data map.

  Runs `escape/analyze` on `ir-func` and reads its :states, which is looked up
  by instruction :dst and iterated as [var state] entries. If no entry is
  :no-escape, `ir-func` is returned as-is. Otherwise every block in :blocks
  goes through `rewrite-block`. The blocks share one counter seeded from
  :next-var, so region vars never collide, and the final counter value is
  written back to :next-var."
  [ir-func]
  (let [states (:states (escape/analyze ir-func))
        no-escape-entry? (fn [[_ s]] (= :no-escape s))]
    (if (empty? (filterv no-escape-entry? states))
      ir-func
      (let [var-counter (atom (:next-var ir-func))
            rewritten (mapv (fn [b] (rewrite-block b states var-counter))
                            (:blocks ir-func))]
        (assoc ir-func
               :blocks rewritten
               :next-var @var-counter)))))
(defn optimize
  "Entry point for IR optimization.

  Applies `optimize-regions` to `ir-func`, then recursively optimizes each
  entry of the result's :subfunctions. The result's :subfunctions is always
  a vector ([] when absent)."
  [ir-func]
  (let [result (optimize-regions ir-func)
        subs (:subfunctions result)]
    (assoc result :subfunctions (mapv optimize (if subs subs [])))))