Skip to main content

rlx_runtime/
reflect.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Model reflection services (Slang compiler/runtime API §5).
5//!
6//! Introspect unspecialized templates and specialized layouts across eager/lazy/AOT
7//! while preserving the HIR → MIR → LIR pipeline.
8
9use rlx_ir::hir::HirModule;
10use rlx_ir::{
11    BindingManifest, HirReflection, ManifestDiff, MirReflection, ModelComponent,
12    apply_hir_extensions, layout_from_lir,
13};
14use rlx_opt::CompileResult;
15
16use crate::Device;
17use crate::model_pipeline::ModelCompilePipeline;
18use crate::options::CompileOptions;
19use crate::stages;
20
21/// Loaded template + HIR reflection (front-end load).
22pub struct ModelReflection {
23    pub hir: HirReflection,
24    template: Option<CompileResult>,
25}
26
27impl ModelReflection {
28    /// Build HIR reflection only (no compile).
29    pub fn from_hir(hir: &HirModule) -> Self {
30        Self {
31            hir: HirReflection::from_hir(hir),
32            template: None,
33        }
34    }
35
36    /// Compile symbolic template on `device` and retain for specialize/layout.
37    pub fn load_hir_template(
38        device: Device,
39        hir: HirModule,
40        options: &CompileOptions,
41    ) -> Result<Self, rlx_ir::hir::LowerError> {
42        let mut opts = options.clone();
43        opts.dim_binding = None;
44        let hir_ref = HirReflection::from_hir(&hir);
45        let pipe = stages::pipeline_for(device, &opts);
46        let template = pipe.compile_hir(hir)?;
47        Ok(Self {
48            hir: hir_ref,
49            template: Some(template),
50        })
51    }
52
53    pub fn has_template(&self) -> bool {
54        self.template.is_some()
55    }
56
57    pub fn mir_summary(&self) -> Option<MirReflection> {
58        self.template
59            .as_ref()
60            .map(|t| MirReflection::from_mir(&t.lir.mir))
61    }
62
63    /// Template layout (symbolic dims may be unresolved in arena sizes).
64    pub fn template_layout(&self) -> Option<BindingManifest> {
65        self.template.as_ref().map(|t| layout_from_lir(&t.lir))
66    }
67
68    /// Specialized layout for a [`ModelComponent`] (getTypeLayout after specialize).
69    pub fn layout_for_component(
70        &self,
71        component: &ModelComponent,
72        device: Device,
73        options: &CompileOptions,
74    ) -> Option<BindingManifest> {
75        let template = self.template.as_ref()?;
76        let mut opts = options.clone();
77        opts.dim_binding = None;
78        let pipe = stages::pipeline_for(device, &opts);
79        let specialized = template.specialize(&pipe, &component.dim_binding());
80        Some(layout_from_lir(&specialized.lir))
81    }
82
83    pub fn manifest_diff_for_component(
84        &self,
85        component: &ModelComponent,
86        device: Device,
87        options: &CompileOptions,
88    ) -> Option<ManifestDiff> {
89        let t = self.template_layout()?;
90        let s = self.layout_for_component(component, device, options)?;
91        Some(ManifestDiff::compare(&t, &s))
92    }
93}
94
95/// Full specialize + compile entry (specializeEntryPoint analogue).
96pub fn specialize_entry<'a>(
97    pipeline: &'a mut ModelCompilePipeline,
98    component: &ModelComponent,
99    build_hir: impl FnOnce() -> HirModule,
100    options: &CompileOptions,
101) -> Result<&'a mut crate::CompiledGraph, rlx_ir::hir::LowerError> {
102    let key = component.cache_key();
103    let binding = component.dim_binding();
104    pipeline.get_or_compile(key, &binding, build_hir, options)
105}
106
107/// Apply HIR extensions then load template.
108pub fn load_hir_template_with_extensions(
109    device: Device,
110    mut hir: HirModule,
111    options: &CompileOptions,
112) -> Result<ModelReflection, rlx_ir::hir::LowerError> {
113    apply_hir_extensions(&mut hir);
114    ModelReflection::load_hir_template(device, hir, options)
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::hir::HirMut;
    use rlx_ir::{DType, HirModule, ModelVariant, Shape};

    /// Fixture: module "refl" with input `x` (1x4 F32), param `w` (4x2 F32)
    /// and a single `linear` output of shape 1x2 F32.
    fn tiny_linear_hir() -> HirModule {
        let mut module = HirModule::new("refl");
        let mut builder = HirMut::new(&mut module);
        let x = builder.input("x", Shape::new(&[1, 4], DType::F32));
        let w = builder.param("w", Shape::new(&[4, 2], DType::F32));
        let y = module.linear(x, w, None, None, Shape::new(&[1, 2], DType::F32));
        module.set_outputs(vec![y]);
        module
    }

    /// Loading a template on CPU retains it, and both the template layout and
    /// the layout specialized for a prefill component list `w` as first param.
    #[test]
    fn reflection_loads_template_on_cpu() {
        let device = Device::Cpu;
        let options = CompileOptions::new();

        let refl =
            ModelReflection::load_hir_template(device, tiny_linear_hir(), &options).unwrap();
        assert!(refl.has_template());

        let template_layout = refl.template_layout().unwrap();
        assert_eq!(template_layout.params[0].name, "w");

        let component = ModelComponent::new(ModelVariant::prefill(1, 4));
        let specialized_layout = refl
            .layout_for_component(&component, device, &options)
            .unwrap();
        assert_eq!(specialized_layout.params[0].name, "w");
    }
}