Skip to main content

rlx_flow/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! ModelFlow builder and built output.
5
6use std::collections::HashMap;
7
8use anyhow::Result;
9use rlx_ir::hir::HirModule;
10use rlx_ir::{Graph, GraphModule, GraphStage, HirNodeId, Shape, hir_to_graph};
11
12use crate::context::{FlowCtx, FlowState};
13use crate::execution::ModelExecutionConfig;
14use crate::extension::FlowExtensionPlan;
15use crate::profile::CompileProfile;
16use crate::stage::FlowStage;
17use crate::value::FlowValue;
18use crate::weight::WeightSource;
19
20/// Block assembly-line builder — tier-0 model author surface.
21#[derive(Debug)]
22pub struct ModelFlow {
23    name: String,
24    pub(crate) profile: CompileProfile,
25    /// Graph inputs declared before block stages (do not participate in tensor flow).
26    inputs: Vec<(String, Shape)>,
27    pub(crate) stages: Vec<FlowStage>,
28    output_names: Vec<String>,
29    extra_outputs: Vec<HirNodeId>,
30    extension_plan: FlowExtensionPlan,
31}
32
33impl ModelFlow {
34    pub fn new(name: impl Into<String>) -> Self {
35        Self {
36            name: name.into(),
37            profile: CompileProfile::default(),
38            inputs: Vec::new(),
39            stages: Vec::new(),
40            output_names: vec!["output".into()],
41            extra_outputs: Vec::new(),
42            extension_plan: FlowExtensionPlan::default(),
43        }
44    }
45
46    /// HIR extensions to apply after assemble, before compile (retroactive plugins).
47    pub fn with_extensions(mut self, plan: FlowExtensionPlan) -> Self {
48        self.extension_plan = plan;
49        self
50    }
51
52    /// Declare a graph input. The first input starts the tensor flow; later
53    /// inputs are side declarations only (e.g. `last_token_idx`).
54    pub fn input(mut self, name: impl Into<String>, shape: Shape) -> Self {
55        self.inputs.push((name.into(), shape));
56        self
57    }
58
59    pub fn with_profile(mut self, profile: CompileProfile) -> Self {
60        self.profile = profile;
61        self
62    }
63
64    pub fn profile(&self) -> &CompileProfile {
65        &self.profile
66    }
67
68    pub fn stage(mut self, stage: FlowStage) -> Self {
69        self.stages.push(stage);
70        self
71    }
72
73    pub fn output(mut self, name: impl Into<String>) -> Self {
74        self.output_names = vec![name.into()];
75        self
76    }
77
78    pub fn outputs(mut self, names: impl IntoIterator<Item = impl Into<String>>) -> Self {
79        self.output_names = names.into_iter().map(Into::into).collect();
80        self
81    }
82
83    /// Append side outputs (e.g. per-layer KV taps) after the primary output.
84    pub fn with_extra_outputs(mut self, ids: Vec<HirNodeId>) -> Self {
85        self.extra_outputs = ids;
86        self
87    }
88
89    /// Build from a composable recipe, then optionally patch before compile.
90    pub fn from_recipe(recipe: &impl crate::recipe::ModelRecipe) -> Self {
91        recipe.assemble()
92    }
93
94    pub fn build(self, weights: &mut dyn WeightSource) -> Result<BuiltModel> {
95        let mut module =
96            GraphModule::hir(&self.name).with_fusion_policy(self.profile.fusion_policy());
97        let mut params = HashMap::new();
98        let mut state = FlowState::default();
99        let mut ctx = FlowCtx {
100            module,
101            params: &mut params,
102            weights,
103            profile: &self.profile,
104            state: &mut state,
105        };
106
107        let mut value: Option<FlowValue> = None;
108        for (i, (name, shape)) in self.inputs.iter().enumerate() {
109            let id = ctx.input(name, shape.clone());
110            ctx.state.inputs.insert(name.clone(), (id, shape.clone()));
111            if i == 0 {
112                value = Some(ctx.wrap(id, shape.clone()));
113            }
114        }
115        for stage in &self.stages {
116            value = stage.emit(&mut ctx, value)?;
117        }
118
119        let primary = value.ok_or_else(|| anyhow::anyhow!("ModelFlow produced no output"))?;
120        let mut outputs = vec![primary.id];
121        outputs.extend(self.extra_outputs);
122
123        ctx.module.set_outputs(outputs);
124        module = ctx.module;
125        if let Some(hir) = module.as_hir_mut() {
126            self.extension_plan.apply(hir);
127        }
128
129        Ok(BuiltModel {
130            module,
131            params,
132            profile: self.profile,
133            output_names: self.output_names,
134            primary_shape: primary.shape,
135        })
136    }
137}
138
139/// Result of assembling a model flow.
140#[derive(Debug, Clone)]
141pub struct BuiltModel {
142    pub module: GraphModule,
143    pub params: HashMap<String, Vec<f32>>,
144    pub profile: CompileProfile,
145    output_names: Vec<String>,
146    primary_shape: Shape,
147}
148
149impl BuiltModel {
150    /// Attach variant + execution preset (shader-component bundle).
151    pub fn with_execution_config(mut self, config: &ModelExecutionConfig) -> Self {
152        self.profile = config.compile_profile();
153        self
154    }
155
156    /// Wrap a legacy HIR builder product as tier-0 flow output (migration bridge).
157    pub fn from_hir(hir: HirModule, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
158        let primary = hir
159            .outputs
160            .first()
161            .copied()
162            .ok_or_else(|| anyhow::anyhow!("from_hir: module has no outputs"))?;
163        let primary_shape = hir.node(primary).shape.clone();
164        Ok(Self {
165            module: GraphModule::from_hir(hir),
166            params,
167            profile: CompileProfile::default(),
168            output_names: vec!["output".into()],
169            primary_shape,
170        })
171    }
172
173    /// Wrap a legacy MIR graph builder product as tier-0 flow output (migration bridge).
174    pub fn from_graph(graph: Graph, params: HashMap<String, Vec<f32>>) -> anyhow::Result<Self> {
175        let primary = graph
176            .outputs
177            .first()
178            .copied()
179            .ok_or_else(|| anyhow::anyhow!("from_graph: graph has no outputs"))?;
180        let primary_shape = graph.node(primary).shape.clone();
181        Ok(Self {
182            module: GraphModule::from_graph(graph),
183            params,
184            profile: CompileProfile::default(),
185            output_names: vec!["output".into()],
186            primary_shape,
187        })
188    }
189
190    pub fn profile(&self) -> &CompileProfile {
191        &self.profile
192    }
193
194    pub fn params(&self) -> &HashMap<String, Vec<f32>> {
195        &self.params
196    }
197
198    pub fn primary_shape(&self) -> &Shape {
199        &self.primary_shape
200    }
201
202    pub fn output_names(&self) -> &[String] {
203        &self.output_names
204    }
205
206    /// `(Graph, params)` for legacy compile paths.
207    pub fn into_graph_parts(self) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
208        let params = self.params.clone();
209        let graph = self.into_graph()?;
210        Ok((graph, params))
211    }
212
213    pub fn into_graph_module(self) -> GraphModule {
214        self.module
215    }
216
217    pub fn into_hir(self) -> Option<HirModule> {
218        self.module.into_hir()
219    }
220
221    pub fn into_graph(self) -> Result<Graph> {
222        if self.module.stage() == GraphStage::Hir {
223            let hir = self
224                .module
225                .into_hir()
226                .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
227            hir_to_graph(hir).map_err(Into::into)
228        } else {
229            self.module.into_graph().map_err(Into::into)
230        }
231    }
232
233    pub fn lower(self) -> Result<GraphModule> {
234        self.module.lower().map_err(Into::into)
235    }
236
237    /// Append side outputs after the primary output node.
238    pub fn with_extra_hir_outputs(mut self, extra: impl IntoIterator<Item = HirNodeId>) -> Self {
239        let primary = self.module.as_hir().expect("HIR stage").outputs[0];
240        let mut outputs = vec![primary];
241        outputs.extend(extra);
242        self.module.set_outputs(outputs);
243        self
244    }
245
246    /// Split into HIR module + param map (common compile path).
247    pub fn into_parts(self) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
248        let params = self.params.clone();
249        let hir = self
250            .into_hir()
251            .ok_or_else(|| anyhow::anyhow!("expected HIR stage"))?;
252        Ok((hir, params))
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layer::LayerStack;
    use crate::weight::MapWeights;
    use rlx_ir::{DType, Shape};

    /// Weight map with `embed.weight` set to values `[1, 0, 0, 1]` and dims `[2, 2]`.
    fn embed_weights() -> MapWeights {
        let mut weights = MapWeights::default();
        weights.insert("embed.weight", vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]);
        weights
    }

    /// Flow called `name` with a single `ids` input of shape `[1, 2]`.
    fn ids_flow(name: &str) -> ModelFlow {
        ModelFlow::new(name).input("ids", Shape::new(&[1, 2], DType::F32))
    }

    #[test]
    fn minimal_embed_flow() {
        let mut weights = embed_weights();

        let built = ids_flow("smoke")
            .embed("embed.weight")
            .build(&mut weights)
            .unwrap();

        let node_count = built.into_hir().unwrap().len();
        assert!(node_count >= 3);
    }

    #[test]
    fn custom_stage_passthrough() {
        let mut weights = embed_weights();

        // The custom stage hands its input straight back.
        let built = ids_flow("custom")
            .embed("embed.weight")
            .custom(|_ctx, value| Ok(value))
            .build(&mut weights)
            .unwrap();

        assert_eq!(built.primary_shape().rank(), 3);
    }

    #[test]
    fn layer_stack_builds_sequence() {
        let mut weights = MapWeights::default();
        weights.insert("ln.weight", vec![1.0; 4], vec![4]);

        let block = LayerStack::named("block")
            .rms_norm("ln.weight", 1e-5)
            .build();

        let built = ModelFlow::new("stack")
            .input("x", Shape::new(&[1, 2, 4], DType::F32))
            .zero_beta(4)
            .raw_stage(block)
            .build(&mut weights)
            .unwrap();

        assert!(built.into_hir().unwrap().len() >= 4);
    }

    #[test]
    fn when_conditional_embed() {
        // Condition true: the embed stage is part of the flow.
        let mut taken_weights = embed_weights();
        let taken = ids_flow("cond")
            .when(true, |flow| flow.embed("embed.weight"))
            .build(&mut taken_weights)
            .unwrap();
        assert!(taken.into_hir().unwrap().len() >= 3);

        // Condition false: embed is skipped, so the graph is an input-only
        // passthrough.
        let mut skipped_weights = embed_weights();
        let skipped = ids_flow("cond")
            .when(false, |flow| flow.embed("embed.weight"))
            .build(&mut skipped_weights)
            .unwrap();
        assert_eq!(skipped.into_hir().unwrap().len(), 1);
    }
}