Skip to main content

rlx_runtime/
model_pipeline.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Three-step model compile pipeline (template → specialize → backend).
//!
//! Host code loads symbolic HIR once, specializes per [`ModelVariant`] /
//! [`DimBinding`], then lowers to a device executable. Pair with
//! [`BindingManifest`] for parameter-block style binding.
21
22use std::collections::VecDeque;
23
24use rlx_ir::hir::HirModule;
25use rlx_ir::{BindingManifest, DimBinding, ModelComponent};
26use rlx_opt::CompileResult;
27
28use crate::stages;
29use crate::{CompileOptions, CompiledGraph, Device};
30
31/// Compile-once / specialize-per-variant pipeline with optional FIFO cache.
32pub struct ModelCompilePipeline {
33    device: Device,
34    capacity: usize,
35    template: Option<CompileResult>,
36    entries: Vec<(u64, CompiledGraph)>,
37    order: VecDeque<u64>,
38}
39
40impl ModelCompilePipeline {
41    pub fn new(device: Device) -> Self {
42        Self::with_capacity(device, 8)
43    }
44
45    pub fn with_capacity(device: Device, capacity: usize) -> Self {
46        assert!(capacity > 0, "ModelCompilePipeline capacity must be ≥ 1");
47        Self {
48            device,
49            capacity,
50            template: None,
51            entries: Vec::new(),
52            order: VecDeque::new(),
53        }
54    }
55
56    pub fn device(&self) -> Device {
57        self.device
58    }
59
60    pub fn has_template(&self) -> bool {
61        self.template.is_some()
62    }
63
64    /// **Step 1** — run fusion pipeline on symbolic HIR (dynamic dims allowed).
65    pub fn build_template<F>(
66        &mut self,
67        build_hir: F,
68        options: &CompileOptions,
69    ) -> Result<&CompileResult, rlx_ir::hir::LowerError>
70    where
71        F: FnOnce() -> HirModule,
72    {
73        if self.template.is_none() {
74            let pipe = stages::pipeline_for(self.device, options);
75            self.template = Some(pipe.compile_hir(build_hir())?);
76        }
77        Ok(self.template.as_ref().expect("template set"))
78    }
79
80    pub fn template_binding_manifest(&self) -> BindingManifest {
81        let template = self.template.as_ref().expect("call build_template first");
82        BindingManifest::from_lir(&template.lir)
83    }
84
85    /// **Step 2** — bind symbolic dims and replan buffers.
86    pub fn specialize_template(
87        &self,
88        binding: &DimBinding,
89        options: &CompileOptions,
90    ) -> CompileResult {
91        let template = self
92            .template
93            .as_ref()
94            .expect("call build_template before specialize_template");
95        let pipe = stages::pipeline_for(self.device, options);
96        template.specialize(&pipe, binding)
97    }
98
99    /// **Step 3** — backend executable from specialized LIR.
100    pub fn compile_lir(
101        &self,
102        specialized: CompileResult,
103        options: &CompileOptions,
104    ) -> CompiledGraph {
105        let backend = crate::registry::backend_for(self.device).expect("backend registered");
106        let executable = backend.compile_lir(specialized.lir, options);
107        CompiledGraph::new(executable, self.device)
108    }
109
110    /// Full pipeline: template (once) → specialize → compile; cached by `key`.
111    pub fn get_or_compile<F>(
112        &mut self,
113        key: u64,
114        binding: &DimBinding,
115        build_hir: F,
116        options: &CompileOptions,
117    ) -> Result<&mut CompiledGraph, rlx_ir::hir::LowerError>
118    where
119        F: FnOnce() -> HirModule,
120    {
121        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
122            return Ok(&mut self.entries[idx].1);
123        }
124        let mut template_opts = options.clone();
125        template_opts.dim_binding = None;
126        self.build_template(build_hir, &template_opts)?;
127        let specialized = self.specialize_template(binding, &template_opts);
128        let mut compile_opts = options.clone();
129        compile_opts.dim_binding = None;
130        let compiled = self.compile_lir(specialized, &compile_opts);
131
132        if self.entries.len() >= self.capacity
133            && let Some(evict) = self.order.pop_front()
134        {
135            self.entries.retain(|(k, _)| *k != evict);
136        }
137        self.entries.push((key, compiled));
138        self.order.push_back(key);
139        Ok(&mut self.entries.last_mut().unwrap().1)
140    }
141
142    /// Manifest for a variant without storing specialized LIR in the cache.
143    pub fn binding_manifest_for_binding(
144        &self,
145        binding: &DimBinding,
146        options: &CompileOptions,
147    ) -> BindingManifest {
148        let specialized = self.specialize_template(binding, options);
149        BindingManifest::from_lir(&specialized.lir)
150    }
151
152    /// Layout for a full [`ModelComponent`] (specialized parameter block).
153    pub fn binding_manifest_for_component(
154        &self,
155        component: &ModelComponent,
156        options: &CompileOptions,
157    ) -> BindingManifest {
158        self.binding_manifest_for_binding(&component.dim_binding(), options)
159    }
160
161    /// Template → specialize → compile; keyed by [`ModelComponent::cache_key`].
162    pub fn get_or_compile_component<F>(
163        &mut self,
164        component: &ModelComponent,
165        build_hir: F,
166        options: &CompileOptions,
167    ) -> Result<(&mut CompiledGraph, BindingManifest), rlx_ir::hir::LowerError>
168    where
169        F: FnOnce() -> HirModule,
170    {
171        let key = component.cache_key();
172        let binding = component.dim_binding();
173        let manifest = self.binding_manifest_for_component(component, options);
174        let compiled = self.get_or_compile(key, &binding, build_hir, options)?;
175        Ok((compiled, manifest))
176    }
177
178    pub fn contains(&self, key: u64) -> bool {
179        self.entries.iter().any(|(k, _)| *k == key)
180    }
181
182    pub fn len(&self) -> usize {
183        self.entries.len()
184    }
185
186    pub fn is_empty(&self) -> bool {
187        self.entries.is_empty()
188    }
189
190    /// Symbolic template from [`Self::build_template`] / [`Self::get_or_compile`].
191    pub fn template_result(&self) -> Option<&CompileResult> {
192        self.template.as_ref()
193    }
194
195    /// Build the symbolic template once (no specialization).
196    pub fn ensure_template<F: FnOnce() -> HirModule>(
197        &mut self,
198        build_hir: F,
199        options: &CompileOptions,
200    ) -> Result<&CompileResult, rlx_ir::hir::LowerError> {
201        self.build_template(build_hir, options)
202    }
203
204    /// Disk-backed specialize ([`CompilationMode::Aot`]); caches by `key`.
205    pub fn get_or_specialize_aot<F: FnOnce() -> HirModule>(
206        &mut self,
207        aot: &crate::AotCache,
208        disk_base: &str,
209        key: u64,
210        binding: &DimBinding,
211        build_hir: F,
212        options: &CompileOptions,
213    ) -> Result<&mut CompiledGraph, crate::AotCacheError> {
214        if let Some(idx) = self.entries.iter().position(|(k, _)| *k == key) {
215            return Ok(&mut self.entries[idx].1);
216        }
217        let device = self.device;
218        let template = self.ensure_template(build_hir, options)?;
219        let compiled = aot.specialize_cached(disk_base, binding, device, template, options)?;
220        if self.entries.len() >= self.capacity
221            && let Some(evict_key) = self.order.pop_front()
222        {
223            self.entries.retain(|(k, _)| *k != evict_key);
224        }
225        self.entries.push((key, compiled));
226        self.order.push_back(key);
227        Ok(&mut self.entries.last_mut().unwrap().1)
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::hir::HirMut;
    use rlx_ir::{DType, HirModule, Shape};

    /// Builds a one-op module: `y = linear(x, w)` with `x: [1, 8, 4]` and `w: [4, 2]`.
    fn tiny_linear_hir() -> HirModule {
        let mut hir = HirModule::new("dyn");
        let mut gb = HirMut::new(&mut hir);
        let x = gb.input("x", Shape::new(&[1, 8, 4], DType::F32));
        let w = gb.param("w", Shape::new(&[4, 2], DType::F32));
        let y = hir.linear(x, w, None, None, Shape::new(&[1, 8, 2], DType::F32));
        hir.set_outputs(vec![y]);
        hir
    }

    /// Runs the three steps by hand and executes the resulting graph once.
    #[test]
    fn template_specialize_compile_smoke() {
        let device = Device::Cpu;
        let mut pipe = ModelCompilePipeline::new(device);
        let opts = CompileOptions::new();

        pipe.build_template(tiny_linear_hir, &opts).unwrap();
        let binding = DimBinding::new();
        let spec = pipe.specialize_template(&binding, &opts);
        let manifest = BindingManifest::from_lir(&spec.lir);
        assert_eq!(manifest.params[0].name, "w");
        let mut compiled = pipe.compile_lir(spec, &opts);
        // 8 = 4 * 2 weight elements; 32 = 1 * 8 * 4 input elements.
        compiled.set_param("w", &[0.0f32; 8]);
        let outs = compiled.run(&[("x", &[0.0f32; 32])]);
        assert_eq!(outs.len(), 1);
    }

    /// A capacity-1 cache serves hits without rebuilding and evicts the oldest key.
    #[test]
    fn get_or_compile_caches_and_evicts_fifo() {
        let mut pipe = ModelCompilePipeline::with_capacity(Device::Cpu, 1);
        let opts = CompileOptions::new();
        let binding = DimBinding::new();

        pipe.get_or_compile(1, &binding, tiny_linear_hir, &opts)
            .unwrap();
        assert!(pipe.has_template());
        assert!(pipe.contains(1));
        assert_eq!(pipe.len(), 1);

        // Cache hit: the HIR builder must not run again.
        pipe.get_or_compile(1, &binding, || panic!("HIR must not be rebuilt"), &opts)
            .unwrap();
        assert_eq!(pipe.len(), 1);

        // New key with a full cache: the template is reused and key 1 is evicted.
        pipe.get_or_compile(2, &binding, || panic!("HIR must not be rebuilt"), &opts)
            .unwrap();
        assert!(!pipe.contains(1));
        assert!(pipe.contains(2));
        assert_eq!(pipe.len(), 1);
    }
}
263}