Skip to main content

rlx_runtime/
model_pipeline.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Three-step model compile pipeline (template → specialize → backend).
5//!
6//! Host code loads symbolic HIR once, specializes per [`ModelVariant`] /
7//! [`DimBinding`], then lowers to a device executable. Pair with
8//! [`BindingManifest`] for parameter-block style binding.
9
10use std::collections::VecDeque;
11
12use rlx_ir::hir::HirModule;
13use rlx_ir::{BindingManifest, DimBinding, ModelComponent};
14use rlx_opt::CompileResult;
15
16use crate::stages;
17use crate::{CompileOptions, CompiledGraph, Device};
18
19/// Compile-once / specialize-per-variant pipeline with optional FIFO cache.
20pub struct ModelCompilePipeline {
21    device: Device,
22    capacity: usize,
23    template: Option<CompileResult>,
24    entries: Vec<(u64, CompiledGraph)>,
25    order: VecDeque<u64>,
26}
27
28impl ModelCompilePipeline {
29    pub fn new(device: Device) -> Self {
30        Self::with_capacity(device, 8)
31    }
32
33    pub fn with_capacity(device: Device, capacity: usize) -> Self {
34        assert!(capacity > 0, "ModelCompilePipeline capacity must be ≥ 1");
35        Self {
36            device,
37            capacity,
38            template: None,
39            entries: Vec::new(),
40            order: VecDeque::new(),
41        }
42    }
43
44    pub fn device(&self) -> Device {
45        self.device
46    }
47
48    pub fn has_template(&self) -> bool {
49        self.template.is_some()
50    }
51
52    /// **Step 1** — run fusion pipeline on symbolic HIR (dynamic dims allowed).
53    pub fn build_template<F>(
54        &mut self,
55        build_hir: F,
56        options: &CompileOptions,
57    ) -> Result<&CompileResult, rlx_ir::hir::LowerError>
58    where
59        F: FnOnce() -> HirModule,
60    {
61        if self.template.is_none() {
62            let pipe = stages::pipeline_for(self.device, options);
63            self.template = Some(pipe.compile_hir(build_hir())?);
64        }
65        Ok(self.template.as_ref().expect("template set"))
66    }
67
68    pub fn template_binding_manifest(&self) -> BindingManifest {
69        let template = self.template.as_ref().expect("call build_template first");
70        BindingManifest::from_lir(&template.lir)
71    }
72
73    /// **Step 2** — bind symbolic dims and replan buffers.
74    pub fn specialize_template(
75        &self,
76        binding: &DimBinding,
77        options: &CompileOptions,
78    ) -> CompileResult {
79        let template = self
80            .template
81            .as_ref()
82            .expect("call build_template before specialize_template");
83        let pipe = stages::pipeline_for(self.device, options);
84        template.specialize(&pipe, binding)
85    }
86
87    /// **Step 3** — backend executable from specialized LIR.
88    pub fn compile_lir(
89        &self,
90        specialized: CompileResult,
91        options: &CompileOptions,
92    ) -> CompiledGraph {
93        let backend = crate::registry::backend_for(self.device).expect("backend registered");
94        let executable = backend.compile_lir(specialized.lir, options);
95        CompiledGraph::new(executable, self.device)
96    }
97
98    /// Full pipeline: template (once) → specialize → compile; cached by `key`.
99    pub fn get_or_compile<F>(
100        &mut self,
101        key: u64,
102        binding: &DimBinding,
103        build_hir: F,
104        options: &CompileOptions,
105    ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError>
106    where
107        F: FnOnce() -> HirModule,
108    {
109        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
110            return Ok(&mut self.entries[idx].1);
111        }
112        let mut template_opts = options.clone();
113        template_opts.dim_binding = None;
114        self.build_template(build_hir, &template_opts)?;
115        let specialized = self.specialize_template(binding, &template_opts);
116        let mut compile_opts = options.clone();
117        compile_opts.dim_binding = None;
118        let compiled = self.compile_lir(specialized, &compile_opts);
119
120        if self.entries.len() >= self.capacity
121            && let Some(evict) = self.order.pop_front()
122        {
123            self.entries.retain(|(k, _)| *k != evict);
124        }
125        self.entries.push((key, compiled));
126        self.order.push_back(key);
127        Ok(&mut self.entries.last_mut().unwrap().1)
128    }
129
130    /// Manifest for a variant without storing specialized LIR in the cache.
131    pub fn binding_manifest_for_binding(
132        &self,
133        binding: &DimBinding,
134        options: &CompileOptions,
135    ) -> BindingManifest {
136        let specialized = self.specialize_template(binding, options);
137        BindingManifest::from_lir(&specialized.lir)
138    }
139
140    /// Layout for a full [`ModelComponent`] (specialized parameter block).
141    pub fn binding_manifest_for_component(
142        &self,
143        component: &ModelComponent,
144        options: &CompileOptions,
145    ) -> BindingManifest {
146        self.binding_manifest_for_binding(&component.dim_binding(), options)
147    }
148
149    /// Template → specialize → compile; keyed by [`ModelComponent::cache_key`].
150    pub fn get_or_compile_component<F>(
151        &mut self,
152        component: &ModelComponent,
153        build_hir: F,
154        options: &CompileOptions,
155    ) -> Result<(&mut CompiledGraph, BindingManifest), rlx_ir::hir::LowerError>
156    where
157        F: FnOnce() -> HirModule,
158    {
159        let key = component.cache_key();
160        let binding = component.dim_binding();
161        let manifest = self.binding_manifest_for_component(component, options);
162        let compiled = self.get_or_compile(key, &binding, build_hir, options)?;
163        Ok((compiled, manifest))
164    }
165
166    pub fn contains(&self, key: u64) -> bool {
167        self.entries.iter().any(|(k, _)| *k == key)
168    }
169
170    pub fn len(&self) -> usize {
171        self.entries.len()
172    }
173
174    pub fn is_empty(&self) -> bool {
175        self.entries.is_empty()
176    }
177
178    /// Symbolic template from [`Self::build_template`] / [`Self::get_or_compile`].
179    pub fn template_result(&self) -> Option<&CompileResult> {
180        self.template.as_ref()
181    }
182
183    /// Build the symbolic template once (no specialization).
184    pub fn ensure_template<F: FnOnce() -> HirModule>(
185        &mut self,
186        build_hir: F,
187        options: &CompileOptions,
188    ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
189        self.build_template(build_hir, options)
190    }
191
192    /// Disk-backed specialize ([`CompilationMode::Aot`]); caches by `key`.
193    pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
194        &mut self,
195        aot: &crate::AotCache,
196        disk_base: &str,
197        key: u64,
198        binding: &DimBinding,
199        build_hir: F,
200        options: &CompileOptions,
201    ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
202        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
203            return Ok(&mut self.entries[idx].1);
204        }
205        let device = self.device;
206        let template = self.ensure_template(build_hir, options)?;
207        let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
208        if self.entries.len() >= self.capacity
209            && let Some(evict_key) = self.order.pop_front()
210        {
211            self.entries.retain(|(k, _)| *k != evict_key);
212        }
213        self.entries.push((key, compiled));
214        self.order.push_back(key);
215        Ok(&mut self.entries.last_mut().unwrap().1)
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::hir::HirMut;
    use rlx_ir::{DType, HirModule, Shape};

    /// End-to-end smoke test on `Device::Cpu`.
    ///
    /// Builds the template, specializes it with an empty binding, checks that
    /// the manifest lists `w` first, then compiles and runs the graph.
    #[test]
    fn template_specialize_compile_smoke() {
        let options = CompileOptions::new();
        let mut pipeline = ModelCompilePipeline::new(Device::Cpu);

        // One `linear` op: input `x` [1, 8, 4], parameter `w` [4, 2], output [1, 8, 2].
        let build_hir = || {
            let mut module = HirModule::new("dyn");
            let mut builder = HirMut::new(&mut module);
            let input = builder.input("x", Shape::new(&[1, 8, 4], DType::F32));
            let weight = builder.param("w", Shape::new(&[4, 2], DType::F32));
            // `builder` must not be used past this point: `module` is accessed directly below.
            let output =
                module.linear(input, weight, None, None, Shape::new(&[1, 8, 2], DType::F32));
            module.set_outputs(vec![output]);
            module
        };

        pipeline.build_template(build_hir, &options).unwrap();

        let specialized = pipeline.specialize_template(&rlx_ir::DimBinding::new(), &options);
        let manifest = BindingManifest::from_lir(&specialized.lir);
        assert_eq!(manifest.params[0].name, "w");

        let mut graph = pipeline.compile_lir(specialized, &options);
        graph.set_param("w", &[0.0f32; 8]);
        let outputs = graph.run(&[("x", &[0.0f32; 32])]);
        assert_eq!(outputs.len(), 1);
    }
}