1use rlx_ir::hir::HirModule;
10use rlx_ir::{
11 BindingManifest, HirReflection, ManifestDiff, MirReflection, ModelComponent,
12 apply_hir_extensions, layout_from_lir,
13};
14use rlx_opt::CompileResult;
15
16use crate::Device;
17use crate::model_pipeline::ModelCompilePipeline;
18use crate::options::CompileOptions;
19use crate::stages;
20
21pub struct ModelReflection {
23 pub hir: HirReflection,
24 template: Option<CompileResult>,
25}
26
27impl ModelReflection {
28 pub fn from_hir(hir: &HirModule) -> Self {
30 Self {
31 hir: HirReflection::from_hir(hir),
32 template: None,
33 }
34 }
35
36 pub fn load_hir_template(
38 device: Device,
39 hir: HirModule,
40 options: &CompileOptions,
41 ) -> Result<Self, rlx_ir::hir::LowerError> {
42 let mut opts = options.clone();
43 opts.dim_binding = None;
44 let hir_ref = HirReflection::from_hir(&hir);
45 let pipe = stages::pipeline_for(device, &opts);
46 let template = pipe.compile_hir(hir)?;
47 Ok(Self {
48 hir: hir_ref,
49 template: Some(template),
50 })
51 }
52
53 pub fn has_template(&self) -> bool {
54 self.template.is_some()
55 }
56
57 pub fn mir_summary(&self) -> Option<MirReflection> {
58 self.template
59 .as_ref()
60 .map(|t| MirReflection::from_mir(&t.lir.mir))
61 }
62
63 pub fn template_layout(&self) -> Option<BindingManifest> {
65 self.template.as_ref().map(|t| layout_from_lir(&t.lir))
66 }
67
68 pub fn layout_for_component(
70 &self,
71 component: &ModelComponent,
72 device: Device,
73 options: &CompileOptions,
74 ) -> Option<BindingManifest> {
75 let template = self.template.as_ref()?;
76 let mut opts = options.clone();
77 opts.dim_binding = None;
78 let pipe = stages::pipeline_for(device, &opts);
79 let specialized = template.specialize(&pipe, &component.dim_binding());
80 Some(layout_from_lir(&specialized.lir))
81 }
82
83 pub fn manifest_diff_for_component(
84 &self,
85 component: &ModelComponent,
86 device: Device,
87 options: &CompileOptions,
88 ) -> Option<ManifestDiff> {
89 let t = self.template_layout()?;
90 let s = self.layout_for_component(component, device, options)?;
91 Some(ManifestDiff::compare(&t, &s))
92 }
93}
94
95pub fn specialize_entry<'a>(
97 pipeline: &'a mut ModelCompilePipeline,
98 component: &ModelComponent,
99 build_hir: impl FnOnce() -> HirModule,
100 options: &CompileOptions,
101) -> Result<&'a mut crate::CompiledGraph, rlx_ir::hir::LowerError> {
102 let key = component.cache_key();
103 let binding = component.dim_binding();
104 pipeline.get_or_compile(key, &binding, build_hir, options)
105}
106
107pub fn load_hir_template_with_extensions(
109 device: Device,
110 mut hir: HirModule,
111 options: &CompileOptions,
112) -> Result<ModelReflection, rlx_ir::hir::LowerError> {
113 apply_hir_extensions(&mut hir);
114 ModelReflection::load_hir_template(device, hir, options)
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::hir::HirMut;
    use rlx_ir::{DType, HirModule, ModelVariant, Shape};

    /// Builds a tiny module named "refl": input `x` [1, 4], param `w` [4, 2],
    /// output `linear(x, w)` [1, 2], all f32.
    fn linear_module() -> HirModule {
        let mut module = HirModule::new("refl");
        let mut builder = HirMut::new(&mut module);
        let input = builder.input("x", Shape::new(&[1, 4], DType::F32));
        let weight = builder.param("w", Shape::new(&[4, 2], DType::F32));
        let output = module.linear(input, weight, None, None, Shape::new(&[1, 2], DType::F32));
        module.set_outputs(vec![output]);
        module
    }

    #[test]
    fn reflection_loads_template_on_cpu() {
        let device = Device::Cpu;
        let options = CompileOptions::new();

        let reflection =
            ModelReflection::load_hir_template(device, linear_module(), &options).unwrap();
        assert!(reflection.has_template());

        let template_layout = reflection.template_layout().unwrap();
        assert_eq!(template_layout.params[0].name, "w");

        let component = ModelComponent::new(ModelVariant::prefill(1, 4));
        let specialized_layout = reflection
            .layout_for_component(&component, device, &options)
            .unwrap();
        assert_eq!(specialized_layout.params[0].name, "w");
    }
}