1use rlx_ir::hir::HirModule;
22use rlx_ir::{
23 BindingManifest, HirReflection, ManifestDiff, MirReflection, ModelComponent,
24 apply_hir_extensions, layout_from_lir,
25};
26use rlx_opt::CompileResult;
27
28use crate::Device;
29use crate::model_pipeline::ModelCompilePipeline;
30use crate::options::CompileOptions;
31use crate::stages;
32
33pub struct ModelReflection {
35 pub hir: HirReflection,
36 template: Option<CompileResult>,
37}
38
39impl ModelReflection {
40 pub fn from_hir(hir: &HirModule) -> Self {
42 Self {
43 hir: HirReflection::from_hir(hir),
44 template: None,
45 }
46 }
47
48 pub fn load_hir_template(
50 device: Device,
51 hir: HirModule,
52 options: &CompileOptions,
53 ) -> Result<Self, rlx_ir::hir::LowerError> {
54 let mut opts = options.clone();
55 opts.dim_binding = None;
56 let hir_ref = HirReflection::from_hir(&hir);
57 let pipe = stages::pipeline_for(device, &opts);
58 let template = pipe.compile_hir(hir)?;
59 Ok(Self {
60 hir: hir_ref,
61 template: Some(template),
62 })
63 }
64
65 pub fn has_template(&self) -> bool {
66 self.template.is_some()
67 }
68
69 pub fn mir_summary(&self) -> Option<MirReflection> {
70 self.template
71 .as_ref()
72 .map(|t| MirReflection::from_mir(&t.lir.mir))
73 }
74
75 pub fn template_layout(&self) -> Option<BindingManifest> {
77 self.template.as_ref().map(|t| layout_from_lir(&t.lir))
78 }
79
80 pub fn layout_for_component(
82 &self,
83 component: &ModelComponent,
84 device: Device,
85 options: &CompileOptions,
86 ) -> Option<BindingManifest> {
87 let template = self.template.as_ref()?;
88 let mut opts = options.clone();
89 opts.dim_binding = None;
90 let pipe = stages::pipeline_for(device, &opts);
91 let specialized = template.specialize(&pipe, &component.dim_binding());
92 Some(layout_from_lir(&specialized.lir))
93 }
94
95 pub fn manifest_diff_for_component(
96 &self,
97 component: &ModelComponent,
98 device: Device,
99 options: &CompileOptions,
100 ) -> Option<ManifestDiff> {
101 let t = self.template_layout()?;
102 let s = self.layout_for_component(component, device, options)?;
103 Some(ManifestDiff::compare(&t, &s))
104 }
105}
106
107pub fn specialize_entry<'a>(
109 pipeline: &'a mut ModelCompilePipeline,
110 component: &ModelComponent,
111 build_hir: impl FnOnce() -> HirModule,
112 options: &CompileOptions,
113) -> Result<&'a mut crate::CompiledGraph, rlx_ir::hir::LowerError> {
114 let key = component.cache_key();
115 let binding = component.dim_binding();
116 pipeline.get_or_compile(key, &binding, build_hir, options)
117}
118
119pub fn load_hir_template_with_extensions(
121 device: Device,
122 mut hir: HirModule,
123 options: &CompileOptions,
124) -> Result<ModelReflection, rlx_ir::hir::LowerError> {
125 apply_hir_extensions(&mut hir);
126 ModelReflection::load_hir_template(device, hir, options)
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::hir::HirMut;
    use rlx_ir::{DType, HirModule, ModelVariant, Shape};

    /// Builds a one-layer module: input `x: [1, 4]`, param `w: [4, 2]`, and the
    /// single output `y = linear(x, w)` of shape `[1, 2]`, all F32.
    fn linear_hir() -> HirModule {
        let mut hir = HirModule::new("refl");
        let mut gb = HirMut::new(&mut hir);
        let x = gb.input("x", Shape::new(&[1, 4], DType::F32));
        let w = gb.param("w", Shape::new(&[4, 2], DType::F32));
        let y = hir.linear(x, w, None, None, Shape::new(&[1, 2], DType::F32));
        hir.set_outputs(vec![y]);
        hir
    }

    #[test]
    fn reflection_loads_template_on_cpu() {
        let device = Device::Cpu;
        let options = CompileOptions::new();
        let refl = ModelReflection::load_hir_template(device, linear_hir(), &options).unwrap();
        assert!(refl.has_template());
        assert!(refl.mir_summary().is_some());

        let layout = refl.template_layout().unwrap();
        assert_eq!(layout.params[0].name, "w");

        let comp = ModelComponent::new(ModelVariant::prefill(1, 4));
        let spec_layout = refl
            .layout_for_component(&comp, device, &options)
            .unwrap();
        assert_eq!(spec_layout.params[0].name, "w");

        // Both layouts exist, so a diff must be produced.
        assert!(refl
            .manifest_diff_for_component(&comp, device, &options)
            .is_some());
    }

    #[test]
    fn reflection_from_hir_has_no_template() {
        let refl = ModelReflection::from_hir(&linear_hir());
        assert!(!refl.has_template());
        assert!(refl.mir_summary().is_none());
        assert!(refl.template_layout().is_none());

        let comp = ModelComponent::new(ModelVariant::prefill(1, 4));
        let options = CompileOptions::new();
        assert!(refl
            .layout_for_component(&comp, Device::Cpu, &options)
            .is_none());
        assert!(refl
            .manifest_diff_for_component(&comp, Device::Cpu, &options)
            .is_none());
    }
}