use rlx_ir::hir::HirModule;
use rlx_ir::{
BindingManifest, HirReflection, ManifestDiff, MirReflection, ModelComponent,
apply_hir_extensions, layout_from_lir,
};
use rlx_opt::CompileResult;
use crate::Device;
use crate::model_pipeline::ModelCompilePipeline;
use crate::options::CompileOptions;
use crate::stages;
pub struct ModelReflection {
pub hir: HirReflection,
template: Option<CompileResult>,
}
impl ModelReflection {
pub fn from_hir(hir: &HirModule) -> Self {
Self {
hir: HirReflection::from_hir(hir),
template: None,
}
}
pub fn load_hir_template(
device: Device,
hir: HirModule,
options: &CompileOptions,
) -> Result<Self, rlx_ir::hir::LowerError> {
let mut opts = options.clone();
opts.dim_binding = None;
let hir_ref = HirReflection::from_hir(&hir);
let pipe = stages::pipeline_for(device, &opts);
let template = pipe.compile_hir(hir)?;
Ok(Self {
hir: hir_ref,
template: Some(template),
})
}
pub fn has_template(&self) -> bool {
self.template.is_some()
}
pub fn mir_summary(&self) -> Option<MirReflection> {
self.template
.as_ref()
.map(|t| MirReflection::from_mir(&t.lir.mir))
}
pub fn template_layout(&self) -> Option<BindingManifest> {
self.template.as_ref().map(|t| layout_from_lir(&t.lir))
}
pub fn layout_for_component(
&self,
component: &ModelComponent,
device: Device,
options: &CompileOptions,
) -> Option<BindingManifest> {
let template = self.template.as_ref()?;
let mut opts = options.clone();
opts.dim_binding = None;
let pipe = stages::pipeline_for(device, &opts);
let specialized = template.specialize(&pipe, &component.dim_binding());
Some(layout_from_lir(&specialized.lir))
}
pub fn manifest_diff_for_component(
&self,
component: &ModelComponent,
device: Device,
options: &CompileOptions,
) -> Option<ManifestDiff> {
let t = self.template_layout()?;
let s = self.layout_for_component(component, device, options)?;
Some(ManifestDiff::compare(&t, &s))
}
}
pub fn specialize_entry<'a>(
pipeline: &'a mut ModelCompilePipeline,
component: &ModelComponent,
build_hir: impl FnOnce() -> HirModule,
options: &CompileOptions,
) -> Result<&'a mut crate::CompiledGraph, rlx_ir::hir::LowerError> {
let key = component.cache_key();
let binding = component.dim_binding();
pipeline.get_or_compile(key, &binding, build_hir, options)
}
pub fn load_hir_template_with_extensions(
device: Device,
mut hir: HirModule,
options: &CompileOptions,
) -> Result<ModelReflection, rlx_ir::hir::LowerError> {
apply_hir_extensions(&mut hir);
ModelReflection::load_hir_template(device, hir, options)
}
#[cfg(test)]
mod tests {
use super::*;
use rlx_ir::hir::HirMut;
use rlx_ir::{DType, HirModule, ModelVariant, Shape};
#[test]
fn reflection_loads_template_on_cpu() {
let device = Device::Cpu;
let hir = || {
let mut hir = HirModule::new("refl");
let mut gb = HirMut::new(&mut hir);
let x = gb.input("x", Shape::new(&[1, 4], DType::F32));
let w = gb.param("w", Shape::new(&[4, 2], DType::F32));
let y = hir.linear(x, w, None, None, Shape::new(&[1, 2], DType::F32));
hir.set_outputs(vec![y]);
hir
};
let refl =
ModelReflection::load_hir_template(device, hir(), &CompileOptions::new()).unwrap();
assert!(refl.has_template());
let layout = refl.template_layout().unwrap();
assert_eq!(layout.params[0].name, "w");
let comp = ModelComponent::new(ModelVariant::prefill(1, 4));
let spec_layout = refl
.layout_for_component(&comp, device, &CompileOptions::new())
.unwrap();
assert_eq!(spec_layout.params[0].name, "w");
}
}