Skip to main content

rlx_ir/
reflect.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Reflection over HIR/MIR/LIR — layout and structure without executing.
17//!
18//! Host code can introspect unspecialized templates (Slang front-end / reflection API)
19//! and specialized layouts independently of backend codegen.
20
21use crate::Shape;
22use crate::binding_manifest::BindingManifest;
23use crate::component::ModelComponent;
24use crate::hir::{HirModule, HirNodeId, HirOp};
25use crate::lir::LirModule;
26use crate::mir::MirModule;
27use crate::shape::DimBinding;
28
29/// Introspection of an unspecialized [`HirModule`] (loadModule analogue).
30#[derive(Debug, Clone, PartialEq, Eq)]
31pub struct HirReflection {
32    pub name: String,
33    pub node_count: usize,
34    pub fusion_policy: String,
35    pub inputs: Vec<(String, Shape)>,
36    pub params: Vec<(String, Shape)>,
37    pub outputs: Vec<Shape>,
38    pub block_labels: Vec<String>,
39}
40
41impl HirReflection {
42    pub fn from_hir(hir: &HirModule) -> Self {
43        let mut inputs = Vec::new();
44        let mut params = Vec::new();
45        let mut block_labels = Vec::new();
46        for node in hir.nodes().iter() {
47            let label = node
48                .name
49                .clone()
50                .unwrap_or_else(|| format!("{:?}", node.op));
51            match &node.op {
52                HirOp::Input { name } => inputs.push((name.clone(), node.shape.clone())),
53                HirOp::Param { name } => params.push((name.clone(), node.shape.clone())),
54                HirOp::LlamaDecoderBlock { .. }
55                | HirOp::SwiGLU
56                | HirOp::Attention { .. }
57                | HirOp::GatedDeltaNet { .. }
58                | HirOp::Qwen35MtpHead { .. } => block_labels.push(label),
59                _ => {}
60            }
61        }
62        let outputs = hir
63            .outputs
64            .iter()
65            .map(|&id| hir.node(id).shape.clone())
66            .collect();
67        HirReflection {
68            name: hir.name.clone(),
69            node_count: hir.nodes().len(),
70            fusion_policy: format!("{:?}", hir.fusion_policy),
71            inputs,
72            params,
73            outputs,
74            block_labels,
75        }
76    }
77}
78
79/// MIR-level summary after HIR lower (specializeType / graph shape probe).
80#[derive(Debug, Clone, PartialEq, Eq)]
81pub struct MirReflection {
82    pub name: String,
83    pub node_count: usize,
84    pub op_kinds: Vec<(String, usize)>,
85}
86
87impl MirReflection {
88    pub fn from_mir(mir: &MirModule) -> Self {
89        let g = mir.as_graph();
90        let mut counts = std::collections::HashMap::new();
91        for node in g.nodes() {
92            *counts.entry(format!("{:?}", node.op.kind())).or_default() += 1;
93        }
94        let mut op_kinds: Vec<_> = counts.into_iter().collect();
95        op_kinds.sort_by(|a, b| a.0.cmp(&b.0));
96        MirReflection {
97            name: g.name.clone(),
98            node_count: g.nodes().len(),
99            op_kinds,
100        }
101    }
102}
103
104/// Layout reflection from specialized LIR (getTypeLayout / parameter block).
105pub fn layout_from_lir(lir: &LirModule) -> BindingManifest {
106    BindingManifest::from_lir(lir)
107}
108
109/// Layout for a concrete [`ModelComponent`] binding without retaining the graph.
110pub fn layout_for_binding(lir: &LirModule, _component: &ModelComponent) -> BindingManifest {
111    layout_from_lir(lir)
112}
113
114/// Compare template vs specialized manifests (dims / arena may differ).
115#[derive(Debug, Clone, PartialEq, Eq)]
116pub struct ManifestDiff {
117    pub template_arena: usize,
118    pub specialized_arena: usize,
119    pub params_only_in_template: Vec<String>,
120    pub params_only_in_specialized: Vec<String>,
121}
122
123impl ManifestDiff {
124    pub fn compare(template: &BindingManifest, specialized: &BindingManifest) -> Self {
125        let t: std::collections::HashSet<_> = template.param_names().collect();
126        let s: std::collections::HashSet<_> = specialized.param_names().collect();
127        Self {
128            template_arena: template.arena_size,
129            specialized_arena: specialized.arena_size,
130            params_only_in_template: t.difference(&s).map(|x| (*x).to_string()).collect(),
131            params_only_in_specialized: s.difference(&t).map(|x| (*x).to_string()).collect(),
132        }
133    }
134}
135
/// Coarse-grained specialization choice for a block (acts like a type argument).
///
/// NOTE(review): the variant descriptions below are inferred from the variant
/// names only — confirm against the code that consumes this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BlockSpecialization {
    /// No explicit preference (presumably defers to the existing policy).
    Default,
    /// Presumably: treat the block as one fused transformer layer.
    FusedTransformerLayer,
    /// Presumably: expand the block into its primitive ops.
    UnfusedPrimitives,
}
143
144/// Record of a specialization decision for tooling (specializeType analogue).
145#[derive(Debug, Clone, PartialEq, Eq)]
146pub struct SpecializeBlockRecord {
147    pub node: HirNodeId,
148    pub label: String,
149    pub choice: BlockSpecialization,
150}
151
152/// Probe which HIR blocks would take a given specialization (static; no MIR mutate).
153pub fn probe_block_specialization(
154    hir: &HirModule,
155    choice: BlockSpecialization,
156) -> Vec<SpecializeBlockRecord> {
157    hir.nodes()
158        .iter()
159        .filter_map(|node| {
160            let fused = matches!(
161                node.op,
162                HirOp::LlamaDecoderBlock { .. } | HirOp::SwiGLU | HirOp::GatedDeltaNet { .. }
163            );
164            if !fused {
165                return None;
166            }
167            let effective = choice;
168            Some(SpecializeBlockRecord {
169                node: node.id,
170                label: node
171                    .name
172                    .clone()
173                    .unwrap_or_else(|| format!("{:?}", node.op)),
174                choice: effective,
175            })
176        })
177        .collect()
178}
179
180/// Binding-only layout probe when only [`DimBinding`] is known (no full compile).
181pub fn symbolic_layout_hint(binding: &DimBinding) -> String {
182    format!("DimBinding({:?})", binding)
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::HirMut;
    use crate::{DType, HirModule};

    /// A module with one input and one param reflects exactly one of each.
    #[test]
    fn hir_reflection_lists_inputs() {
        // Arrange: build a tiny module through the mutable HIR builder.
        let mut module = HirModule::new("t");
        let mut builder = HirMut::new(&mut module);
        let _input = builder.input("x", Shape::new(&[1, 4], DType::F32));
        let _weight = builder.param("w", Shape::new(&[4, 2], DType::F32));

        // Act
        let reflection = HirReflection::from_hir(&module);

        // Assert
        assert_eq!(reflection.inputs.len(), 1);
        assert_eq!(reflection.params.len(), 1);
    }
}