;; cljrs-ir 0.1.28
;;
;; Intermediate representation types for the clojurust compiler and interpreter.
;; Documentation
(ns cljrs.compiler.escape)

;; ── Escape analysis on IR data ──────────────────────────────────────────────
;;
;; Operates on plain IR data maps (same structure as ir_convert.rs expects).
;; No Rust types needed — pure Clojure data transformation.

;; ── Effect classification ───────────────────────────────────────────────────

;; Opcodes this pass treats as allocations.
(def alloc-ops
  #{:alloc-closure :alloc-cons :alloc-list :alloc-map :alloc-set :alloc-vector})

(defn alloc-inst?
  "True when the `:op` of `inst` is one of `alloc-ops`, false otherwise
   (including when `inst` has no `:op`)."
  [inst]
  (let [op (:op inst)]
    (contains? alloc-ops op)))

;; ── Collect allocating instructions ─────────────────────────────────────────

(defn collect-allocs
  "Return a map of {alloc-var defining-block-id} for every allocation
   instruction (per `alloc-inst?`) that has a `:dst`.  Both the `:phis`
   and the `:insts` of every block are scanned; if the same dst were to
   appear more than once, the last occurrence wins.

   The defining block is *not* consulted by `classify-escape`, which is a
   purely semantic check (does the value leave the function frame?).  It
   is exposed through `analyze` as `:alloc-blocks` so that downstream
   passes (notably `cljrs.compiler.optimize`) can use the dominator tree
   to pick a region scope wide enough to cover every use of the
   allocation."
  [ir-func]
  (reduce (fn [acc block]
            (reduce (fn [a inst]
                      (if (and (alloc-inst? inst) (:dst inst))
                        (assoc a (:dst inst) (:id block))
                        a))
                    acc
                    (concat (:phis block) (:insts block))))
          {}
          (:blocks ir-func)))

;; ── Build def-use chains ────────────────────────────────────────────────────

(defn add-use
  "Append the use record `{:block block-id :kind kind}` to the vector of
   uses stored under `var-id` in `uses`, starting a new vector when
   `var-id` has none yet."
  [uses var-id block-id kind]
  (let [existing (or (get uses var-id) [])
        entry {:block block-id :kind kind}]
    (assoc uses var-id (conj existing entry))))

(defn add-uses-for-inst
  "Record in `uses` every variable read by instruction `inst`, which lives
   in block `block-id`.  Each use is tagged with a kind map whose `:type`
   says how the value is consumed (call argument, heap store, return of a
   phi input, ...).  Instructions without tracked operands leave `uses`
   unchanged."
  [uses inst block-id]
  (let [record (fn [acc v kind] (add-use acc v block-id kind))
        ;; Record every var in `vs` with the same kind, in order.
        record-all (fn [acc vs kind]
                     (reduce (fn [a v] (record a v kind)) acc vs))
        ;; Record every var in `vs` with a kind derived from its position.
        record-indexed (fn [acc vs kind-for]
                         (reduce (fn [a [i v]] (record a v (kind-for i)))
                                 acc
                                 (map-indexed vector vs)))
        heap {:type :stored-in-heap}]
    (case (:op inst)
      :call-known
      (record-indexed uses
                      (:args inst)
                      (fn [i] {:type :known-call-arg :func (:func inst) :arg-index i}))

      :call
      ;; The callee is recorded before the arguments.
      (record-indexed (record uses (:callee inst) {:type :call-callee})
                      (:args inst)
                      (fn [i] {:type :unknown-call-arg :callee (:callee inst) :arg-index i}))

      :alloc-closure
      (record-all uses (:captures inst) {:type :closure-capture})

      (:alloc-vector :alloc-set :alloc-list)
      (record-all uses (:elems inst) heap)

      :alloc-map
      (reduce (fn [a [k v]] (record-all a [k v] heap)) uses (:pairs inst))

      :alloc-cons
      (record-all uses [(:head inst) (:tail inst)] heap)

      :def-var
      (record uses (:value inst) {:type :def-var})

      :set!
      (record-all uses [(:var inst) (:value inst)] {:type :set-bang})

      :deref
      (record uses (:src inst) {:type :deref})

      :throw
      (record uses (:value inst) {:type :throw})

      :recur
      (record-all uses (:args inst) {:type :recur})

      :phi
      ;; Each entry is a [pred-block var] pair; only the var is a use.
      (record-all uses (map second (:entries inst)) {:type :phi-input})

      ;; const, load-local, load-global, source-loc — no uses
      uses)))

(defn add-uses-for-terminator
  "Record in `uses` the variables read by block terminator `term` (in block
   `block-id`): the returned var, the branch condition, or the recur-jump
   arguments.  Any other terminator (jump, unreachable, ...) reads nothing
   tracked here and leaves `uses` unchanged."
  [uses term block-id]
  (let [note (fn [acc v kind] (add-use acc v block-id kind))]
    (case (:op term)
      :return     (note uses (:var term) {:type :return})
      :branch     (note uses (:cond term) {:type :branch-cond})
      :recur-jump (reduce (fn [acc v] (note acc v {:type :recur}))
                          uses
                          (:args term))
      uses)))

(defn build-use-chains
  "Build def-use chains for `ir-func`: a map from each used VarId to the
   vector of its use records, in program order (blocks in order; within a
   block the phis, then the instructions, then the terminator)."
  [ir-func]
  (reduce (fn [acc {:keys [id phis insts terminator]}]
            (let [body-uses (reduce (fn [a inst] (add-uses-for-inst a inst id))
                                    acc
                                    (concat phis insts))]
              (add-uses-for-terminator body-uses terminator id)))
          {}
          (:blocks ir-func)))

;; ── Known function arg escape semantics ─────────────────────────────────────

;; Known functions where no argument escapes.
(def non-escaping-fns
  #{;; collection lookups
    :count :contains :first :get :nth
    ;; arithmetic and comparison
    :* :+ :- :/ :rem :< :<= := :> :>=
    ;; type and identity tests
    :identical? :map? :nil? :seq? :vector?
    ;; deref, printing, and string building
    :atom-deref :deref :pr :println :str})

(defn known-fn-arg-escapes?
  "Can the argument at position `arg-index` of known function `func` flow
   into the call's return value?

   - Functions in `non-escaping-fns`: never (false).
   - dissoc/disj, rest/next/seq, transient, assoc!/conj!, persistent!:
     only the first argument (index 0).
   - Anything else (vector, hash-map, assoc, conj, cons, ...): assumed to
     escape (true)."
  [func arg-index]
  (and (not (contains? non-escaping-fns func))
       (case func
         (:dissoc :disj :rest :next :seq :transient :assoc! :conj! :persistent!)
         (= 0 arg-index)

         true)))

;; ── Find call result for a use ──────────────────────────────────────────────

(defn find-call-result
  "Return the `:dst` of the first `:call-known` instruction in block
   `block-id` of `ir-func` that calls `known-fn` with `used-var` among its
   arguments.  Matching calls without a `:dst` are skipped.  Returns nil
   when the block does not exist or nothing matches."
  [used-var known-fn ir-func block-id]
  (let [target (first (filter (fn [b] (= (:id b) block-id)) (:blocks ir-func)))
        passes-var? (fn [inst] (some (fn [a] (= a used-var)) (:args inst)))
        result-of (fn [inst]
                    (when (and (= (:op inst) :call-known)
                               (= (:func inst) known-fn)
                               (passes-var? inst))
                      (:dst inst)))]
    (when target
      (some result-of (:insts target)))))

;; ── Classify escape state ───────────────────────────────────────────────────

(defn classify-escape
  "Classify whether the allocation bound to `var` escapes the function.

   This is a *semantic* check: it asks whether the allocation's value
   leaves the function frame (returned, stored in a Var, captured by a
   closure, written into a heap-allocated container, etc.).  It does
   *not* concern itself with control-flow lifetime — whether the value
   is consumed in the same block as the allocation, or flows through
   phi nodes / known-fn call results into descendant blocks.  That
   decision lives in `cljrs.compiler.optimize`, which inspects the
   dominator tree to choose a region scope wide enough to cover all
   uses.

   Starting from `var`, the walk also follows values derived from it:
   the results of known-fn calls that may return the argument (per
   `known-fn-arg-escapes?`) and the results of phi nodes that merge it.
   If any of these values is returned, stored via def-var or `set!`,
   captured by a closure, thrown, stored in a heap container, or passed
   to recur, the result is `:escapes`.  One use as an argument to an
   unknown call (counted over the whole derived set) gives `:arg-escape`;
   a second such use gives `:escapes`.  A known-fn call that may return
   the value but has no result var is treated conservatively as
   `:escapes`.

   `uses` is the def-use map produced by `build-use-chains`.

   Returns `:no-escape`, `:arg-escape`, or `:escapes`."
  [var uses ir-func]
  (let [;; First block for each id (same choice as a first/filter lookup),
        ;; built once instead of rescanning the block list for every use.
        block-index (reduce (fn [m b]
                              (if (contains? m (:id b)) m (assoc m (:id b) b)))
                            {}
                            (:blocks ir-func))
        ;; Result vars of *every* call-known to `func` in `block-id` that
        ;; takes `v` as an argument.  Collecting all of them — not just the
        ;; first — matters when one value feeds several calls to the same
        ;; function in a block, e.g. `(conj a 1)` and `(conj a 2)`: each
        ;; result may carry `a` out, so each must be followed.
        call-results (fn [v func block-id]
                       (reduce (fn [acc inst]
                                 (if (and (= (:op inst) :call-known)
                                          (= (:func inst) func)
                                          (some (fn [a] (= a v)) (:args inst))
                                          (:dst inst))
                                   (conj acc (:dst inst))
                                   acc))
                               []
                               (:insts (get block-index block-id))))
        ;; Result vars of the phis in `block-id` that have `v` as an input.
        ;; The phi result inherits the alloc's semantic-escape disposition.
        phi-results (fn [v block-id]
                      (reduce (fn [acc phi]
                                (if (and (= (:op phi) :phi)
                                         (some (fn [[_ x]] (= x v)) (:entries phi)))
                                  (conj acc (:dst phi))
                                  acc))
                              []
                              (:phis (get block-index block-id))))
        ;; Fold one use of `current` into {:state s :pending [derived vars]}.
        ;; Every transition to :escapes is `reduced`, stopping the fold.
        step (fn [current acc use-info]
               (let [kind (:kind use-info)]
                 (case (:type kind)
                   ;; Always causes escape
                   (:return :def-var :set-bang :closure-capture :throw :stored-in-heap :recur)
                   (reduced (assoc acc :state :escapes))

                   :unknown-call-arg
                   (if (= (:state acc) :no-escape)
                     (assoc acc :state :arg-escape)
                     (reduced (assoc acc :state :escapes)))

                   :known-call-arg
                   (if (known-fn-arg-escapes? (:func kind) (:arg-index kind))
                     (let [results (call-results current (:func kind) (:block use-info))]
                       (if (empty? results)
                         (reduced (assoc acc :state :escapes))
                         (update acc :pending into results)))
                     acc)

                   :phi-input
                   (update acc :pending into (phi-results current (:block use-info)))

                   ;; branch-cond, deref, call-callee — don't cause escape
                   acc)))]
    (loop [worklist [var]
           visited #{}
           state :no-escape]
      (if (empty? worklist)
        state
        (let [current (first worklist)
              remaining (rest worklist)]
          (if (contains? visited current)
            (recur remaining visited state)
            (let [visited (conj visited current)
                  outcome (reduce (fn [acc u] (step current acc u))
                                  {:state state :pending []}
                                  (get uses current))]
              (if (= (:state outcome) :escapes)
                :escapes
                (recur (into (vec remaining) (:pending outcome))
                       visited
                       (:state outcome))))))))))

;; ── Public API ──────────────────────────────────────────────────────────────

(defn analyze
  "Run escape analysis on an IR function (data map).

   Returns `{:states {var-id state} :uses {var-id [use-info]}
              :alloc-blocks {alloc-var defining-block-id}}`
   where `state` is `:no-escape`, `:arg-escape`, or `:escapes`, and
   `:states` has one entry per allocation found by `collect-allocs`.

   The `:alloc-blocks` map exposes where each allocation lives so that
   downstream passes (notably `cljrs.compiler.optimize`) can compute
   region scopes via the dominator tree."
  [ir-func]
  (let [uses (build-use-chains ir-func)
        alloc-blocks (collect-allocs ir-func)
        classify (fn [alloc-var]
                   [alloc-var (classify-escape alloc-var uses ir-func)])]
    {:states (into {} (map classify (keys alloc-blocks)))
     :uses uses
     :alloc-blocks alloc-blocks}))

;; ── Collection chain detection ──────────────────────────────────────────────

;; Known-fn ops whose first argument is the collection being updated.
(def chainable-ops #{:conj :assoc :disj :dissoc})

(defn detect-collection-chains
  "Detect assoc/conj chains where intermediate collections don't escape.

   Within each block, consecutive `:call-known` calls to a `chainable-ops`
   function are linked when each call's collection argument is the result
   of the previous call and that intermediate result has at most one
   recorded use in `(:uses escape-analysis)`.  Non-chainable instructions
   in between do not break a chain.  Only chains of two or more calls are
   reported.

   Returns a vector of chain maps `{:root var-id :ops [...] :result var-id}`,
   where each op is `{:func f :result dst :args [remaining args]}`."
  [ir-func escape-analysis]
  (let [uses (:uses escape-analysis)
        chain-call? (fn [inst]
                      (and (= (:op inst) :call-known)
                           (contains? chainable-ops (:func inst))
                           (seq (:args inst))))
        ;; Append `chain` to `found` only if it has at least two ops
        ;; (a nil chain has zero ops and is never kept).
        keep-if-long (fn [found chain]
                       (if (>= (count (:ops chain)) 2)
                         (conj found chain)
                         found))
        fresh-chain (fn [root op]
                      {:root root :ops [op] :result (:result op)})
        step (fn [{:keys [current-chain found-chains] :as state} inst]
               (if (chain-call? inst)
                 (let [call-args (:args inst)
                       coll (first call-args)
                       op {:func (:func inst)
                           :result (:dst inst)
                           :args (vec (rest call-args))}
                       continues? (and current-chain
                                       (= (:result current-chain) coll))
                       coll-uses (get uses coll)]
                   (if (and continues?
                            (or (nil? coll-uses) (= 1 (count coll-uses))))
                     ;; Intermediate is used only here: extend the chain.
                     {:current-chain (-> current-chain
                                         (update :ops conj op)
                                         (assoc :result (:dst inst)))
                      :found-chains found-chains}
                     ;; Otherwise close the current chain and start afresh.
                     {:current-chain (fresh-chain coll op)
                      :found-chains (keep-if-long found-chains current-chain)}))
                 state))
        scan-block (fn [block]
                     (let [{:keys [current-chain found-chains]}
                           (reduce step
                                   {:current-chain nil :found-chains []}
                                   (:insts block))]
                       ;; Flush at block end
                       (keep-if-long found-chains current-chain)))]
    (reduce (fn [all block] (into all (scan-block block)))
            []
            (:blocks ir-func))))