Skip to main content

rlx_flow/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! ModelFlow builder and built output.
5
6use std::collections::HashMap;
7
8use anyhow::Result;
9use rlx_ir::hir::HirModule;
10use rlx_ir::{Graph, GraphModule, GraphStage, HirNodeId, Shape, hir_to_graph};
11
12use crate::context::{FlowCtx, FlowState};
13use crate::execution::ModelExecutionConfig;
14use crate::extension::FlowExtensionPlan;
15use crate::profile::CompileProfile;
16use crate::stage::FlowStage;
17use crate::value::FlowValue;
18use crate::weight::WeightSource;
19
20/// Block assembly-line builder — tier-0 model author surface.
21#[derive(Debug)]
22pub struct ModelFlow {
23    name: String,
24    pub(crate) profile: CompileProfile,
25    /// Graph inputs declared before block stages (do not participate in tensor flow).
26    inputs: Vec<(String, Shape)>,
27    pub(crate) stages: Vec<FlowStage>,
28    output_names: Vec<String>,
29    extra_outputs: Vec<HirNodeId>,
30    extension_plan: FlowExtensionPlan,
31}
32
33impl ModelFlow {
34    pub fn new(name: impl Into<String>) -> Self {
35        Self {
36            name: name.into(),
37            profile: CompileProfile::default(),
38            inputs: Vec::new(),
39            stages: Vec::new(),
40            output_names: vec!["output".into()],
41            extra_outputs: Vec::new(),
42            extension_plan: FlowExtensionPlan::default(),
43        }
44    }
45
46    /// HIR extensions to apply after assemble, before compile (retroactive plugins).
47    pub fn with_extensions(mut self, plan: FlowExtensionPlan) -> Self {
48        self.extension_plan = plan;
49        self
50    }
51
52    /// Declare a graph input. The first input starts the tensor flow; later
53    /// inputs are side declarations only (e.g. `last_token_idx`).
54    pub fn input(mut self, name: impl Into<String>, shape: Shape) -> Self {
55        self.inputs.push((name.into(), shape));
56        self
57    }
58
59    pub fn with_profile(mut self, profile: CompileProfile) -> Self {
60        self.profile = profile;
61        self
62    }
63
64    pub fn profile(&self) -> &CompileProfile {
65        &self.profile
66    }
67
68    pub fn stage(mut self, stage: FlowStage) -> Self {
69        self.stages.push(stage);
70        self
71    }
72
73    pub fn output(mut self, name: impl Into<String>) -> Self {
74        self.output_names = vec![name.into()];
75        self
76    }
77
78    pub fn outputs(mut self, names: impl IntoIterator<Item = impl Into<String>>) -> Self {
79        self.output_names = names.into_iter().map(Into::into).collect();
80        self
81    }
82
83    /// Append side outputs (e.g. per-layer KV taps) after the primary output.
84    pub fn with_extra_outputs(mut self, ids: Vec<HirNodeId>) -> Self {
85        self.extra_outputs = ids;
86        self
87    }
88
89    /// Build from a composable recipe, then optionally patch before compile.
90    pub fn from_recipe(recipe: &impl crate::recipe::ModelRecipe) -> Self {
91        recipe.assemble()
92    }
93
94    pub fn build(self, weights: &mut dyn WeightSource) -> Result<BuiltModel> {
95        let mut module =
96            GraphModule::hir(&self.name).with_fusion_policy(self.profile.fusion_policy());
97        let mut params = HashMap::new();
98        let mut state = FlowState::default();
99        let mut ctx = FlowCtx {
100            module,
101            params: &mut params,
102            weights,
103            profile: &self.profile,
104            state: &mut state,
105        };
106
107        let mut value: Option<FlowValue> = None;
108        for (i, (name, shape)) in self.inputs.iter().enumerate() {
109            let id = ctx.input(name, shape.clone());
110            ctx.state.inputs.insert(name.clone(), (id, shape.clone()));
111            if i == 0 {
112                value = Some(ctx.wrap(id, shape.clone()));
113            }
114        }
115        for stage in &self.stages {
116            value = stage.emit(&mut ctx, value)?;
117        }
118
119        let primary = value.ok_or_else(|| anyhow::anyhow!("ModelFlow produced no output"))?;
120        let mut outputs = vec![primary.id];
121        outputs.extend(self.extra_outputs);
122
123        ctx.module.set_outputs(outputs);
124        module = ctx.module;
125        if let Some(hir) = module.as_hir_mut() {
126            self.extension_plan.apply(hir);
127        }
128
129        Ok(BuiltModel {
130            module,
131            params,
132            typed_params: Vec::new(),
133            profile: self.profile,
134            output_names: self.output_names,
135            primary_shape: primary.shape,
136        })
137    }
138
139    /// Compatibility shim: older callers passed GGUF packed matmul params.
140    ///
141    /// The current flow builder ignores packed params; packed lowering lives in model crates.
142    pub fn build_with(
143        self,
144        weights: &mut dyn WeightSource,
145        _gguf_packed: Option<&crate::GgufPackedParams>,
146    ) -> Result<BuiltModel> {
147        self.build(weights)
148    }
149}
150
151/// Result of assembling a model flow.
152#[derive(Debug, Clone)]
153pub struct BuiltModel {
154    pub module: GraphModule,
155    pub params: HashMap<String, Vec<f32>>,
156    /// Packed U8 params (GGUF quant blobs) attached after compile via `set_param_typed`.
157    pub typed_params: Vec<(String, Vec<u8>, rlx_ir::DType)>,
158    pub profile: CompileProfile,
159    output_names: Vec<String>,
160    primary_shape: Shape,
161}
162
163impl BuiltModel {
164    /// Attach variant + execution preset (shader-component bundle).
165    pub fn with_execution_config(mut self, config: &ModelExecutionConfig) -> Self {
166        self.profile = config.compile_profile();
167        self
168    }
169
170    /// Wrap a legacy HIR builder product as tier-0 flow output (migration bridge).
171    pub fn from_hir(hir: HirModule, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
172        let primary = hir
173            .outputs
174            .first()
175            .copied()
176            .ok_or_else(|| anyhow::anyhow!("from_hir: module has no outputs"))?;
177        let primary_shape = hir.node(primary).shape.clone();
178        Ok(Self {
179            module: GraphModule::from_hir(hir),
180            params,
181            typed_params: Vec::new(),
182            profile: CompileProfile::default(),
183            output_names: vec!["output".into()],
184            primary_shape,
185        })
186    }
187
188    /// Wrap a legacy MIR graph builder product as tier-0 flow output (migration bridge).
189    pub fn from_graph(graph: Graph, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
190        let primary = graph
191            .outputs
192            .first()
193            .copied()
194            .ok_or_else(|| anyhow::anyhow!("from_graph: graph has no outputs"))?;
195        let primary_shape = graph.node(primary).shape.clone();
196        Ok(Self {
197            module: GraphModule::from_graph(graph),
198            params,
199            typed_params: Vec::new(),
200            profile: CompileProfile::default(),
201            output_names: vec!["output".into()],
202            primary_shape,
203        })
204    }
205
206    pub fn profile(&self) -> &CompileProfile {
207        &self.profile
208    }
209
210    pub fn params(&self) -> &HashMap<String, Vec<f32>> {
211        &self.params
212    }
213
214    pub fn primary_shape(&self) -> &Shape {
215        &self.primary_shape
216    }
217
218    pub fn output_names(&self) -> &[String] {
219        &self.output_names
220    }
221
222    /// `(Graph, params)` for legacy compile paths.
223    pub fn into_graph_parts(self) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
224        let params = self.params.clone();
225        let graph = self.into_graph()?;
226        Ok((graph, params))
227    }
228
229    pub fn into_graph_module(self) -> GraphModule {
230        self.module
231    }
232
233    pub fn into_hir(self) -> Option<HirModule> {
234        self.module.into_hir()
235    }
236
237    pub fn into_graph(self) -> Result<Graph> {
238        if self.module.stage() == GraphStage::Hir {
239            let hir = self
240                .module
241                .into_hir()
242                .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
243            hir_to_graph(hir).map_err(Into::into)
244        } else {
245            self.module.into_graph().map_err(Into::into)
246        }
247    }
248
249    pub fn lower(self) -> Result<GraphModule> {
250        self.module.lower().map_err(Into::into)
251    }
252
253    /// Append side outputs after the primary output node.
254    pub fn with_extra_hir_outputs(mut self, extra: impl IntoIterator<Item = HirNodeId>) -> Self {
255        let primary = self.module.as_hir().expect("HIR stage").outputs[0];
256        let mut outputs = vec![primary];
257        outputs.extend(extra);
258        self.module.set_outputs(outputs);
259        self
260    }
261
262    /// Split into HIR module + param map (common compile path).
263    pub fn into_parts(self) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
264        let params = self.params.clone();
265        let hir = self
266            .into_hir()
267            .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
268        Ok((hir, params))
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::LayerStack;
    use crate::weight::MapWeights;
    use rlx_ir::{DType, Shape};

    /// Weight map with one entry, `embed.weight`: values `[1, 0, 0, 1]` with dims `[2, 2]`.
    fn embed_weights() -> MapWeights {
        let mut weights = MapWeights::default();
        weights.insert("embed.weight", vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        weights
    }

    /// A single input followed by an embed stage yields at least three HIR nodes.
    #[test]
    fn minimal_embed_flow() {
        let mut weights = embed_weights();

        let built = ModelFlow::new("smoke")
            .input("ids", Shape::new(&[1, 2], DType::F32))
            .embed("embed.weight")
            .build(&mut weights)
            .unwrap();

        let hir = built.into_hir().unwrap();
        assert!(hir.len() >= 3);
    }

    /// A custom stage that returns its input leaves the embed output as the primary value.
    #[test]
    fn custom_stage_passthrough() {
        let mut weights = embed_weights();

        let built = ModelFlow::new("custom")
            .input("ids", Shape::new(&[1, 2], DType::F32))
            .embed("embed.weight")
            .custom(|_emit, input| Ok(input))
            .build(&mut weights)
            .unwrap();

        assert_eq!(built.primary_shape().rank(), 3);
    }

    /// A `LayerStack` stage can be attached as a raw stage and builds into the module.
    #[test]
    fn layer_stack_builds_sequence() {
        let mut weights = MapWeights::default();
        weights.insert("ln.weight", vec![1.0; 4], vec![4]);

        let block = LayerStack::named("block")
            .rms_norm("ln.weight", 1e-5)
            .build();

        let built = ModelFlow::new("stack")
            .input("x", Shape::new(&[1, 2, 4], DType::F32))
            .zero_beta(4)
            .raw_stage(block)
            .build(&mut weights)
            .unwrap();

        let hir = built.into_hir().unwrap();
        assert!(hir.len() >= 4);
    }

    /// `when(true, ..)` applies the closure; `when(false, ..)` leaves the flow untouched.
    #[test]
    fn when_conditional_embed() {
        let mut enabled_weights = embed_weights();
        let with_embed = ModelFlow::new("cond")
            .input("ids", Shape::new(&[1, 2], DType::F32))
            .when(true, |f| f.embed("embed.weight"))
            .build(&mut enabled_weights)
            .unwrap();
        assert!(with_embed.into_hir().unwrap().len() >= 3);

        let mut disabled_weights = embed_weights();
        let skipped = ModelFlow::new("cond")
            .input("ids", Shape::new(&[1, 2], DType::F32))
            .when(false, |f| f.embed("embed.weight"))
            .build(&mut disabled_weights)
            .unwrap();
        // Skipped embed — graph is input-only passthrough.
        assert_eq!(skipped.into_hir().unwrap().len(), 1);
    }
}