Skip to main content

rlx_ir/
reflect.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Reflection over HIR/MIR/LIR: inspect structure and layout without executing anything.
//!
//! Host code can introspect unspecialized templates (analogous to the Slang front-end /
//! reflection API) as well as specialized layouts, independently of backend codegen.
8
9use crate::Shape;
10use crate::binding_manifest::BindingManifest;
11use crate::component::ModelComponent;
12use crate::hir::{HirModule, HirNodeId, HirOp};
13use crate::lir::LirModule;
14use crate::mir::MirModule;
15use crate::shape::DimBinding;
16
17/// Introspection of an unspecialized [`HirModule`] (loadModule analogue).
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct HirReflection {
20    pub name: String,
21    pub node_count: usize,
22    pub fusion_policy: String,
23    pub inputs: Vec<(String, Shape)>,
24    pub params: Vec<(String, Shape)>,
25    pub outputs: Vec<Shape>,
26    pub block_labels: Vec<String>,
27}
28
29impl HirReflection {
30    pub fn from_hir(hir: &HirModule) -> Self {
31        let mut inputs = Vec::new();
32        let mut params = Vec::new();
33        let mut block_labels = Vec::new();
34        for node in hir.nodes().iter() {
35            let label = node
36                .name
37                .clone()
38                .unwrap_or_else(|| format!("{:?}", node.op));
39            match &node.op {
40                HirOp::Input { name } => inputs.push((name.clone(), node.shape.clone())),
41                HirOp::Param { name } => params.push((name.clone(), node.shape.clone())),
42                HirOp::LlamaDecoderBlock { .. }
43                | HirOp::SwiGLU
44                | HirOp::Attention { .. }
45                | HirOp::GatedDeltaNet { .. }
46                | HirOp::Qwen35MtpHead { .. } => block_labels.push(label),
47                _ => {}
48            }
49        }
50        let outputs = hir
51            .outputs
52            .iter()
53            .map(|&id| hir.node(id).shape.clone())
54            .collect();
55        HirReflection {
56            name: hir.name.clone(),
57            node_count: hir.nodes().len(),
58            fusion_policy: format!("{:?}", hir.fusion_policy),
59            inputs,
60            params,
61            outputs,
62            block_labels,
63        }
64    }
65}
66
67/// MIR-level summary after HIR lower (specializeType / graph shape probe).
68#[derive(Debug, Clone, PartialEq, Eq)]
69pub struct MirReflection {
70    pub name: String,
71    pub node_count: usize,
72    pub op_kinds: Vec<(String, usize)>,
73}
74
75impl MirReflection {
76    pub fn from_mir(mir: &MirModule) -> Self {
77        let g = mir.as_graph();
78        let mut counts = std::collections::HashMap::new();
79        for node in g.nodes() {
80            *counts.entry(format!("{:?}", node.op.kind())).or_default() += 1;
81        }
82        let mut op_kinds: Vec<_> = counts.into_iter().collect();
83        op_kinds.sort_by(|a, b| a.0.cmp(&b.0));
84        MirReflection {
85            name: g.name.clone(),
86            node_count: g.nodes().len(),
87            op_kinds,
88        }
89    }
90}
91
92/// Layout reflection from specialized LIR (getTypeLayout / parameter block).
93pub fn layout_from_lir(lir: &LirModule) -> BindingManifest {
94    BindingManifest::from_lir(lir)
95}
96
97/// Layout for a concrete [`ModelComponent`] binding without retaining the graph.
98pub fn layout_for_binding(lir: &LirModule, _component: &ModelComponent) -> BindingManifest {
99    layout_from_lir(lir)
100}
101
102/// Compare template vs specialized manifests (dims / arena may differ).
103#[derive(Debug, Clone, PartialEq, Eq)]
104pub struct ManifestDiff {
105    pub template_arena: usize,
106    pub specialized_arena: usize,
107    pub params_only_in_template: Vec<String>,
108    pub params_only_in_specialized: Vec<String>,
109}
110
111impl ManifestDiff {
112    pub fn compare(template: &BindingManifest, specialized: &BindingManifest) -> Self {
113        let t: std::collections::HashSet<_> = template.param_names().collect();
114        let s: std::collections::HashSet<_> = specialized.param_names().collect();
115        Self {
116            template_arena: template.arena_size,
117            specialized_arena: specialized.arena_size,
118            params_only_in_template: t.difference(&s).map(|x| (*x).to_string()).collect(),
119            params_only_in_specialized: s.difference(&t).map(|x| (*x).to_string()).collect(),
120        }
121    }
122}
123
/// Coarse-grained block specialization choice (a type-argument analogue).
///
/// Supplied to `probe_block_specialization` and echoed back in each
/// `SpecializeBlockRecord` it produces.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BlockSpecialization {
    /// Baseline choice (NOTE(review): exact policy is defined by the consumer).
    Default,
    /// Presumably keeps transformer layers as fused ops — confirm with the lowering pass.
    FusedTransformerLayer,
    /// Presumably expands blocks into primitive ops — confirm with the lowering pass.
    UnfusedPrimitives,
}

132/// Record of a specialization decision for tooling (specializeType analogue).
133#[derive(Debug, Clone, PartialEq, Eq)]
134pub struct SpecializeBlockRecord {
135    pub node: HirNodeId,
136    pub label: String,
137    pub choice: BlockSpecialization,
138}
139
140/// Probe which HIR blocks would take a given specialization (static; no MIR mutate).
141pub fn probe_block_specialization(
142    hir: &HirModule,
143    choice: BlockSpecialization,
144) -> Vec<SpecializeBlockRecord> {
145    hir.nodes()
146        .iter()
147        .filter_map(|node| {
148            let fused = matches!(
149                node.op,
150                HirOp::LlamaDecoderBlock { .. } | HirOp::SwiGLU | HirOp::GatedDeltaNet { .. }
151            );
152            if !fused {
153                return None;
154            }
155            let effective = choice;
156            Some(SpecializeBlockRecord {
157                node: node.id,
158                label: node
159                    .name
160                    .clone()
161                    .unwrap_or_else(|| format!("{:?}", node.op)),
162                choice: effective,
163            })
164        })
165        .collect()
166}
167
168/// Binding-only layout probe when only [`DimBinding`] is known (no full compile).
169pub fn symbolic_layout_hint(binding: &DimBinding) -> String {
170    format!("DimBinding({:?})", binding)
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::HirMut;
    use crate::{DType, HirModule};

    /// Module `t` with one input `x: [1, 4]` and one param `w: [4, 2]` (both F32).
    fn input_param_module() -> HirModule {
        let mut hir = HirModule::new("t");
        {
            // Scope the builder so its mutable borrow of `hir` ends before we return it.
            let mut gb = HirMut::new(&mut hir);
            let _x = gb.input("x", Shape::new(&[1, 4], DType::F32));
            let _w = gb.param("w", Shape::new(&[4, 2], DType::F32));
        }
        hir
    }

    #[test]
    fn hir_reflection_lists_inputs() {
        let hir = input_param_module();
        let r = HirReflection::from_hir(&hir);
        assert_eq!(r.name, "t");
        assert_eq!(r.inputs.len(), 1);
        assert_eq!(r.inputs[0].0, "x");
        assert_eq!(r.inputs[0].1, Shape::new(&[1, 4], DType::F32));
        assert_eq!(r.params.len(), 1);
        assert_eq!(r.params[0].0, "w");
        assert_eq!(r.params[0].1, Shape::new(&[4, 2], DType::F32));
        // Input/Param nodes are not block ops, so no labels are collected.
        assert!(r.block_labels.is_empty());
    }

    #[test]
    fn probe_skips_non_block_nodes() {
        let hir = input_param_module();
        let records = probe_block_specialization(&hir, BlockSpecialization::Default);
        assert!(records.is_empty());
    }
}