Skip to main content

rlx_runtime/
reflect.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Model reflection services (Slang compiler/runtime API §5).
17//!
18//! Introspect unspecialized templates and specialized layouts across eager/lazy/AOT
19//! while preserving the HIR → MIR → LIR pipeline.
20
21use rlx_ir::hir::HirModule;
22use rlx_ir::{
23    BindingManifest, HirReflection, ManifestDiff, MirReflection, ModelComponent,
24    apply_hir_extensions, layout_from_lir,
25};
26use rlx_opt::CompileResult;
27
28use crate::Device;
29use crate::model_pipeline::ModelCompilePipeline;
30use crate::options::CompileOptions;
31use crate::stages;
32
33/// Loaded template + HIR reflection (front-end load).
34pub struct ModelReflection {
35    pub hir: HirReflection,
36    template: Option<CompileResult>,
37}
38
39impl ModelReflection {
40    /// Build HIR reflection only (no compile).
41    pub fn from_hir(hir: &HirModule) -> Self {
42        Self {
43            hir: HirReflection::from_hir(hir),
44            template: None,
45        }
46    }
47
48    /// Compile symbolic template on `device` and retain for specialize/layout.
49    pub fn load_hir_template(
50        device: Device,
51        hir: HirModule,
52        options: &CompileOptions,
53    ) -> Result<Self, rlx_ir::hir::LowerError> {
54        let mut opts = options.clone();
55        opts.dim_binding = None;
56        let hir_ref = HirReflection::from_hir(&hir);
57        let pipe = stages::pipeline_for(device, &opts);
58        let template = pipe.compile_hir(hir)?;
59        Ok(Self {
60            hir: hir_ref,
61            template: Some(template),
62        })
63    }
64
65    pub fn has_template(&self) -> bool {
66        self.template.is_some()
67    }
68
69    pub fn mir_summary(&self) -> Option<MirReflection> {
70        self.template
71            .as_ref()
72            .map(|t| MirReflection::from_mir(&t.lir.mir))
73    }
74
75    /// Template layout (symbolic dims may be unresolved in arena sizes).
76    pub fn template_layout(&self) -> Option<BindingManifest> {
77        self.template.as_ref().map(|t| layout_from_lir(&t.lir))
78    }
79
80    /// Specialized layout for a [`ModelComponent`] (getTypeLayout after specialize).
81    pub fn layout_for_component(
82        &self,
83        component: &ModelComponent,
84        device: Device,
85        options: &CompileOptions,
86    ) -> Option<BindingManifest> {
87        let template = self.template.as_ref()?;
88        let mut opts = options.clone();
89        opts.dim_binding = None;
90        let pipe = stages::pipeline_for(device, &opts);
91        let specialized = template.specialize(&pipe, &component.dim_binding());
92        Some(layout_from_lir(&specialized.lir))
93    }
94
95    pub fn manifest_diff_for_component(
96        &self,
97        component: &ModelComponent,
98        device: Device,
99        options: &CompileOptions,
100    ) -> Option<ManifestDiff> {
101        let t = self.template_layout()?;
102        let s = self.layout_for_component(component, device, options)?;
103        Some(ManifestDiff::compare(&t, &s))
104    }
105}
106
107/// Full specialize + compile entry (specializeEntryPoint analogue).
108pub fn specialize_entry<'a>(
109    pipeline: &'a mut ModelCompilePipeline,
110    component: &ModelComponent,
111    build_hir: impl FnOnce() -> HirModule,
112    options: &CompileOptions,
113) -> Result<&'a mut crate::CompiledGraph, rlx_ir::hir::LowerError> {
114    let key = component.cache_key();
115    let binding = component.dim_binding();
116    pipeline.get_or_compile(key, &binding, build_hir, options)
117}
118
119/// Apply HIR extensions then load template.
120pub fn load_hir_template_with_extensions(
121    device: Device,
122    mut hir: HirModule,
123    options: &CompileOptions,
124) -> Result<ModelReflection, rlx_ir::hir::LowerError> {
125    apply_hir_extensions(&mut hir);
126    ModelReflection::load_hir_template(device, hir, options)
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::hir::HirMut;
    use rlx_ir::{DType, HirModule, ModelVariant, Shape};

    /// Build the module "refl": input `x` `[1, 4]`, param `w` `[4, 2]`, and a
    /// `linear` call on them producing the single `[1, 2]` output (all `F32`).
    fn linear_module() -> HirModule {
        let mut module = HirModule::new("refl");
        let mut builder = HirMut::new(&mut module);
        let x = builder.input("x", Shape::new(&[1, 4], DType::F32));
        let w = builder.param("w", Shape::new(&[4, 2], DType::F32));
        let y = module.linear(x, w, None, None, Shape::new(&[1, 2], DType::F32));
        module.set_outputs(vec![y]);
        module
    }

    /// Loading a template on CPU retains it, and both the template layout and
    /// the layout specialized for a prefill component list `w` as first param.
    #[test]
    fn reflection_loads_template_on_cpu() {
        let device = Device::Cpu;
        let reflection =
            ModelReflection::load_hir_template(device, linear_module(), &CompileOptions::new())
                .unwrap();
        assert!(reflection.has_template());

        let template_layout = reflection.template_layout().unwrap();
        assert_eq!(template_layout.params[0].name, "w");

        let component = ModelComponent::new(ModelVariant::prefill(1, 4));
        let specialized_layout = reflection
            .layout_for_component(&component, device, &CompileOptions::new())
            .unwrap();
        assert_eq!(specialized_layout.params[0].name, "w");
    }
}