Skip to main content

rlx_flow/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! ModelFlow builder and built output.
17
18use std::collections::HashMap;
19
20use anyhow::Result;
21use rlx_ir::hir::HirModule;
22use rlx_ir::{Graph, GraphModule, GraphStage, HirNodeId, Shape, hir_to_graph};
23
24use crate::context::{FlowCtx, FlowState};
25use crate::execution::ModelExecutionConfig;
26use crate::extension::FlowExtensionPlan;
27use crate::profile::CompileProfile;
28use crate::stage::FlowStage;
29use crate::value::FlowValue;
30use crate::weight::WeightSource;
31
32/// Block assembly-line builder — tier-0 model author surface.
33#[derive(Debug)]
34pub struct ModelFlow {
35    name: String,
36    pub(crate) profile: CompileProfile,
37    /// Graph inputs declared before block stages (do not participate in tensor flow).
38    inputs: Vec<(String, Shape)>,
39    pub(crate) stages: Vec<FlowStage>,
40    output_names: Vec<String>,
41    extra_outputs: Vec<HirNodeId>,
42    extension_plan: FlowExtensionPlan,
43}
44
45impl ModelFlow {
46    pub fn new(name: impl Into<String>) -> Self {
47        Self {
48            name: name.into(),
49            profile: CompileProfile::default(),
50            inputs: Vec::new(),
51            stages: Vec::new(),
52            output_names: vec!["output".into()],
53            extra_outputs: Vec::new(),
54            extension_plan: FlowExtensionPlan::default(),
55        }
56    }
57
58    /// HIR extensions to apply after assemble, before compile (retroactive plugins).
59    pub fn with_extensions(mut self, plan: FlowExtensionPlan) -> Self {
60        self.extension_plan = plan;
61        self
62    }
63
64    /// Declare a graph input. The first input starts the tensor flow; later
65    /// inputs are side declarations only (e.g. `last_token_idx`).
66    pub fn input(mut self, name: impl Into<String>, shape: Shape) -> Self {
67        self.inputs.push((name.into(), shape));
68        self
69    }
70
71    pub fn with_profile(mut self, profile: CompileProfile) -> Self {
72        self.profile = profile;
73        self
74    }
75
76    pub fn profile(&self) -> &CompileProfile {
77        &self.profile
78    }
79
80    pub fn stage(mut self, stage: FlowStage) -> Self {
81        self.stages.push(stage);
82        self
83    }
84
85    pub fn output(mut self, name: impl Into<String>) -> Self {
86        self.output_names = vec![name.into()];
87        self
88    }
89
90    pub fn outputs(mut self, names: impl IntoIterator<Item = impl Into<String>>) -> Self {
91        self.output_names = names.into_iter().map(Into::into).collect();
92        self
93    }
94
95    /// Append side outputs (e.g. per-layer KV taps) after the primary output.
96    pub fn with_extra_outputs(mut self, ids: Vec<HirNodeId>) -> Self {
97        self.extra_outputs = ids;
98        self
99    }
100
101    /// Build from a composable recipe, then optionally patch before compile.
102    pub fn from_recipe(recipe: &impl crate::recipe::ModelRecipe) -> Self {
103        recipe.assemble()
104    }
105
106    pub fn build(self, weights: &mut dyn WeightSource) -> Result<BuiltModel> {
107        let mut module =
108            GraphModule::hir(&self.name).with_fusion_policy(self.profile.fusion_policy());
109        let mut params = HashMap::new();
110        let mut state = FlowState::default();
111        let mut ctx = FlowCtx {
112            module,
113            params: &mut params,
114            weights,
115            profile: &self.profile,
116            state: &mut state,
117        };
118
119        let mut value: Option<FlowValue> = None;
120        for (i, (name, shape)) in self.inputs.iter().enumerate() {
121            let id = ctx.input(name, shape.clone());
122            ctx.state.inputs.insert(name.clone(), (id, shape.clone()));
123            if i == 0 {
124                value = Some(ctx.wrap(id, shape.clone()));
125            }
126        }
127        for stage in &self.stages {
128            value = stage.emit(&mut ctx, value)?;
129        }
130
131        let primary = value.ok_or_else(|| anyhow::anyhow!("ModelFlow produced no output"))?;
132        let mut outputs = vec![primary.id];
133        outputs.extend(self.extra_outputs);
134
135        ctx.module.set_outputs(outputs);
136        module = ctx.module;
137        if let Some(hir) = module.as_hir_mut() {
138            self.extension_plan.apply(hir);
139        }
140
141        Ok(BuiltModel {
142            module,
143            params,
144            typed_params: Vec::new(),
145            profile: self.profile,
146            output_names: self.output_names,
147            primary_shape: primary.shape,
148        })
149    }
150
151    /// Compatibility shim: older callers passed GGUF packed matmul params.
152    ///
153    /// The current flow builder ignores packed params; packed lowering lives in model crates.
154    pub fn build_with(
155        self,
156        weights: &mut dyn WeightSource,
157        _gguf_packed: Option<&crate::GgufPackedParams>,
158    ) -> Result<BuiltModel> {
159        self.build(weights)
160    }
161}
162
163/// Result of assembling a model flow.
164#[derive(Debug, Clone)]
165pub struct BuiltModel {
166    pub module: GraphModule,
167    pub params: HashMap<String, Vec<f32>>,
168    /// Packed U8 params (GGUF quant blobs) attached after compile via `set_param_typed`.
169    pub typed_params: Vec<(String, Vec<u8>, rlx_ir::DType)>,
170    pub profile: CompileProfile,
171    output_names: Vec<String>,
172    primary_shape: Shape,
173}
174
175impl BuiltModel {
176    /// Attach variant + execution preset (shader-component bundle).
177    pub fn with_execution_config(mut self, config: &ModelExecutionConfig) -> Self {
178        self.profile = config.compile_profile();
179        self
180    }
181
182    /// Wrap a legacy HIR builder product as tier-0 flow output (migration bridge).
183    pub fn from_hir(hir: HirModule, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
184        let primary = hir
185            .outputs
186            .first()
187            .copied()
188            .ok_or_else(|| anyhow::anyhow!("from_hir: module has no outputs"))?;
189        let primary_shape = hir.node(primary).shape.clone();
190        Ok(Self {
191            module: GraphModule::from_hir(hir),
192            params,
193            typed_params: Vec::new(),
194            profile: CompileProfile::default(),
195            output_names: vec!["output".into()],
196            primary_shape,
197        })
198    }
199
200    /// Wrap a legacy MIR graph builder product as tier-0 flow output (migration bridge).
201    pub fn from_graph(graph: Graph, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
202        let primary = graph
203            .outputs
204            .first()
205            .copied()
206            .ok_or_else(|| anyhow::anyhow!("from_graph: graph has no outputs"))?;
207        let primary_shape = graph.node(primary).shape.clone();
208        Ok(Self {
209            module: GraphModule::from_graph(graph),
210            params,
211            typed_params: Vec::new(),
212            profile: CompileProfile::default(),
213            output_names: vec!["output".into()],
214            primary_shape,
215        })
216    }
217
218    pub fn profile(&self) -> &CompileProfile {
219        &self.profile
220    }
221
222    pub fn params(&self) -> &HashMap<String, Vec<f32>> {
223        &self.params
224    }
225
226    pub fn primary_shape(&self) -> &Shape {
227        &self.primary_shape
228    }
229
230    pub fn output_names(&self) -> &[String] {
231        &self.output_names
232    }
233
234    /// `(Graph, params)` for legacy compile paths. Moves `params` (no clone).
235    pub fn into_graph_parts(mut self) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
236        let params = std::mem::take(&mut self.params);
237        let graph = self.into_graph()?;
238        Ok((graph, params))
239    }
240
241    pub fn into_graph_module(self) -> GraphModule {
242        self.module
243    }
244
245    pub fn into_hir(self) -> Option<HirModule> {
246        self.module.into_hir()
247    }
248
249    pub fn into_graph(self) -> Result<Graph> {
250        if self.module.stage() == GraphStage::Hir {
251            let hir = self
252                .module
253                .into_hir()
254                .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
255            hir_to_graph(hir).map_err(Into::into)
256        } else {
257            self.module.into_graph().map_err(Into::into)
258        }
259    }
260
261    pub fn lower(self) -> Result<GraphModule> {
262        self.module.lower().map_err(Into::into)
263    }
264
265    /// Append side outputs after the primary output node.
266    pub fn with_extra_hir_outputs(mut self, extra: impl IntoIterator<Item = HirNodeId>) -> Self {
267        let primary = self.module.as_hir().expect("HIR stage").outputs[0];
268        let mut outputs = vec![primary];
269        outputs.extend(extra);
270        self.module.set_outputs(outputs);
271        self
272    }
273
274    /// Split into HIR module + param map (common compile path).
275    pub fn into_parts(self) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
276        let params = self.params.clone();
277        let hir = self
278            .into_hir()
279            .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
280        Ok((hir, params))
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::LayerStack;
    use crate::weight::MapWeights;
    use rlx_ir::{DType, Shape};

    /// Weight map holding only a 2x2 identity matrix under `embed.weight`.
    fn identity_embed_weights() -> MapWeights {
        let mut weights = MapWeights::default();
        weights.insert("embed.weight", vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        weights
    }

    /// A lone embed stage on one input yields at least three HIR nodes.
    #[test]
    fn minimal_embed_flow() {
        let mut weights = identity_embed_weights();

        let built = ModelFlow::new("smoke")
            .input("ids", Shape::new(&[1, 2], DType::F32))
            .embed("embed.weight")
            .build(&mut weights)
            .unwrap();

        let hir = built.into_hir().unwrap();
        assert!(hir.len() >= 3);
    }

    /// A custom stage that hands its input back leaves the embed result as
    /// the rank-3 primary output.
    #[test]
    fn custom_stage_passthrough() {
        let mut weights = identity_embed_weights();

        let built = ModelFlow::new("custom")
            .input("ids", Shape::new(&[1, 2], DType::F32))
            .embed("embed.weight")
            .custom(|_emit, input| Ok(input))
            .build(&mut weights)
            .unwrap();

        assert_eq!(built.primary_shape().rank(), 3);
    }

    /// A stage produced by `LayerStack` can be added through `raw_stage`.
    #[test]
    fn layer_stack_builds_sequence() {
        let mut weights = MapWeights::default();
        weights.insert("ln.weight", vec![1.0; 4], vec![4]);

        let block = LayerStack::named("block")
            .rms_norm("ln.weight", 1e-5)
            .build();

        let built = ModelFlow::new("stack")
            .input("x", Shape::new(&[1, 2, 4], DType::F32))
            .zero_beta(4)
            .raw_stage(block)
            .build(&mut weights)
            .unwrap();

        let hir = built.into_hir().unwrap();
        assert!(hir.len() >= 4);
    }

    /// `when` runs its closure only when the condition holds.
    #[test]
    fn when_conditional_embed() {
        let mut enabled_weights = identity_embed_weights();
        let with_embed = ModelFlow::new("cond")
            .input("ids", Shape::new(&[1, 2], DType::F32))
            .when(true, |f| f.embed("embed.weight"))
            .build(&mut enabled_weights)
            .unwrap();
        assert!(with_embed.into_hir().unwrap().len() >= 3);

        let mut disabled_weights = identity_embed_weights();
        let skipped = ModelFlow::new("cond")
            .input("ids", Shape::new(&[1, 2], DType::F32))
            .when(false, |f| f.embed("embed.weight"))
            .build(&mut disabled_weights)
            .unwrap();
        // Skipped embed — graph is input-only passthrough.
        assert_eq!(skipped.into_hir().unwrap().len(), 1);
    }
}
359}