Skip to main content

rlx_flow/
dsl.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Fluent builder methods on [`ModelFlow`] — sugar over [`FlowStage`].
17
18use std::path::Path;
19use std::sync::Arc;
20
21use crate::blocks::{
22    AttnMaskStage, BertEncoderLayerSpec, BertEncoderLayerStage, BertQkvStyle,
23    BindDecodeInputsStage, ClsTokenPoolStage, CustomStage, EmbedStage, GatherAddStage,
24    GatherFromInputStage, GatherLastTokenStage, GeluFfnStage, LayerNormStage, LinearStage,
25    LlamaDecodeLayerStage, LlamaDecoderSpec, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage,
26    NomicEncoderLayerSpec, NomicEncoderLayerStage, RepeatStage, ResidualAddStage,
27    ResidualSaveStage, RmsNormStage, RopeTablesStage, SelfAttnPrefillSpec, SelfAttnPrefillStage,
28    SwiGluStage, dinov2_layer_fused, llama_prefill_layer_composed, llama_prefill_layer_fused,
29    nomic_vision_layer_fused,
30};
31use crate::escape::Emit;
32use crate::flow::ModelFlow;
33use crate::layer::LayerStack;
34use crate::profile::CompileProfile;
35use crate::side::SideOutputs;
36use crate::stage::FlowStage;
37use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
38use crate::value::FlowValue;
39
40impl ModelFlow {
41    /// Load tier-1 profile from a `*.rlx.toml` file (falls back to default on error).
42    pub fn profile_file(mut self, path: impl AsRef<Path>, default: fn() -> CompileProfile) -> Self {
43        self.profile = CompileProfile::from_toml_path(path.as_ref()).unwrap_or_else(|_| default());
44        self
45    }
46
47    /// Encoder / embedding model defaults (Direct lowering, no KV fusion).
48    pub fn profile_encoder(mut self) -> Self {
49        self.profile = CompileProfile::encoder();
50        self
51    }
52
53    /// Gather rows from a side input into the primary flow (starts embedding stack).
54    pub fn gather_from_input(
55        mut self,
56        input_name: impl Into<String>,
57        weight_key: impl Into<String>,
58    ) -> Self {
59        self.stages
60            .push(FlowStage::GatherFromInput(GatherFromInputStage::new(
61                input_name, weight_key, 0,
62            )));
63        self
64    }
65
66    /// Add an embedding looked up from a side input.
67    pub fn gather_add(
68        mut self,
69        input_name: impl Into<String>,
70        weight_key: impl Into<String>,
71    ) -> Self {
72        self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
73            input_name, weight_key, 0,
74        )));
75        self
76    }
77
78    /// LayerNorm with separate gamma/beta weights.
79    pub fn layer_norm(
80        mut self,
81        gamma_key: impl Into<String>,
82        beta_key: impl Into<String>,
83        eps: f32,
84    ) -> Self {
85        self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
86            gamma_key, beta_key, eps,
87        )));
88        self
89    }
90
91    /// BERT-style GELU FFN under a layer prefix.
92    pub fn gelu_ffn(mut self, layer_prefix: impl Into<String>) -> Self {
93        self.stages
94            .push(FlowStage::GeluFfn(GeluFfnStage::hf_bert(layer_prefix)));
95        self
96    }
97
98    /// Repeat NomicBERT encoder layers.
99    pub fn repeat_nomic_layers(
100        self,
101        count: usize,
102        hidden_size: usize,
103        num_heads: usize,
104        head_dim: usize,
105        eps: f32,
106    ) -> Self {
107        self.repeat_layers(count, move |i| FlowStage::Named {
108            name: format!("layer{i}"),
109            inner: std::sync::Arc::new(FlowStage::NomicEncoderLayer(NomicEncoderLayerStage::new(
110                NomicEncoderLayerSpec::hf(
111                    format!("encoder.layers.{i}"),
112                    hidden_size,
113                    num_heads,
114                    head_dim,
115                    eps,
116                ),
117            ))),
118        })
119    }
120
121    /// BERT-style encoder layer (fused QKV + padding-mask attention + GELU FFN).
122    pub fn bert_encoder_layer(mut self, spec: BertEncoderLayerSpec) -> Self {
123        self.stages
124            .push(FlowStage::BertEncoderLayer(BertEncoderLayerStage::new(
125                spec,
126            )));
127        self
128    }
129
130    /// Repeat BERT encoder layers with auto-named prefixes.
131    pub fn repeat_bert_layers(
132        self,
133        count: usize,
134        prefix: impl Into<String>,
135        qkv_style: BertQkvStyle,
136        hidden_size: usize,
137        num_heads: usize,
138        eps: f32,
139    ) -> Self {
140        let prefix = prefix.into();
141        self.repeat_layers(count, move |i| {
142            let lp = if prefix.is_empty() {
143                format!("encoder.layer.{i}")
144            } else {
145                format!("{prefix}.encoder.layer.{i}")
146            };
147            FlowStage::Named {
148                name: format!("layer{i}"),
149                inner: std::sync::Arc::new(FlowStage::BertEncoderLayer(
150                    BertEncoderLayerStage::new(BertEncoderLayerSpec::hf(
151                        lp,
152                        qkv_style,
153                        hidden_size,
154                        num_heads,
155                        eps,
156                    )),
157                )),
158            }
159        })
160    }
161
162    /// Synthesize an all-ones attention mask for vision encoders (no padding).
163    pub fn attn_mask_ones(mut self, batch: usize, seq: usize) -> Self {
164        self.stages
165            .push(FlowStage::AttnMask(AttnMaskStage::ones(batch, seq)));
166        self
167    }
168
169    /// Repeat DINOv2 ViT encoder blocks.
170    pub fn repeat_dinov2_layers(
171        self,
172        count: usize,
173        hidden_size: usize,
174        num_heads: usize,
175        eps: f32,
176    ) -> Self {
177        self.repeat_layers(count, move |i| {
178            dinov2_layer_fused(i, hidden_size, num_heads, eps)
179        })
180    }
181
182    /// Repeat NomicVision encoder blocks.
183    pub fn repeat_vision_layers(
184        self,
185        count: usize,
186        hidden_size: usize,
187        num_heads: usize,
188        eps: f32,
189    ) -> Self {
190        self.repeat_layers(count, move |i| {
191            nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
192        })
193    }
194
195    /// Compatibility shim: repeat SigLIP-style vision layers.
196    pub fn repeat_siglip_layers(
197        self,
198        count: usize,
199        hidden_size: usize,
200        num_heads: usize,
201        eps: f32,
202    ) -> Self {
203        self.repeat_layers(count, move |i| {
204            nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
205        })
206    }
207
208    /// Pool CLS token: `[batch, seq, hidden]` → `[batch, hidden]`.
209    pub fn cls_token_pool(mut self, batch: usize, hidden: usize) -> Self {
210        self.stages
211            .push(FlowStage::ClsTokenPool(ClsTokenPoolStage::new(
212                batch, hidden,
213            )));
214        self
215    }
216
217    /// Fusion-first prefill defaults.
218    pub fn profile_prefill(mut self) -> Self {
219        self.profile = CompileProfile::llama32_prefill();
220        self
221    }
222
223    /// Decode / KV-cache defaults (`Fusable` lowering).
224    pub fn profile_decode(mut self) -> Self {
225        self.profile = CompileProfile::llama32_decode();
226        self
227    }
228
229    /// Token embedding (`model.embed_tokens.weight` by default).
230    pub fn embed(mut self, weight_key: impl Into<String>) -> Self {
231        self.stages
232            .push(FlowStage::Embed(EmbedStage::token(weight_key)));
233        self
234    }
235
236    /// HuggingFace-style token embedding table.
237    pub fn token_embed(self) -> Self {
238        self.embed("model.embed_tokens.weight")
239    }
240
241    /// Precomputed RoPE sin/cos tables stored as params.
242    pub fn rope_tables(mut self, tables: RopeTablesStage) -> Self {
243        self.stages.push(FlowStage::RopeTables(tables));
244        self
245    }
246
247    /// Rank-1 zero vector for RMSNorm beta slots (LLaMA has no beta).
248    pub fn zero_beta(self, len: usize) -> Self {
249        self.zero_beta_named("zero_beta", len)
250    }
251
252    pub fn zero_beta_named(mut self, name: impl Into<String>, len: usize) -> Self {
253        self.stages.push(FlowStage::ZeroBeta {
254            name: name.into(),
255            len,
256        });
257        self
258    }
259
260    /// Bind decode inputs (call after declaring `rope_cos`, `past_k_*`, …).
261    pub fn bind_decode_inputs(mut self, num_layers: usize, custom_mask: bool) -> Self {
262        self.stages
263            .push(FlowStage::BindDecodeInputs(BindDecodeInputsStage {
264                num_layers,
265                use_custom_mask: custom_mask,
266            }));
267        self
268    }
269
270    /// Repeat a per-layer stage `count` times (layer index passed to closure).
271    pub fn repeat_layers(
272        mut self,
273        count: usize,
274        stage_for_layer: impl Fn(usize) -> FlowStage + Send + Sync + 'static,
275    ) -> Self {
276        self.stages
277            .push(FlowStage::Repeat(RepeatStage::new(count, stage_for_layer)));
278        self
279    }
280
281    /// Named decoder layer (shows up in fusion / inspect dumps).
282    pub fn named_layer(mut self, name: impl Into<String>, inner: FlowStage) -> Self {
283        self.stages.push(FlowStage::Named {
284            name: name.into(),
285            inner: Arc::new(inner),
286        });
287        self
288    }
289
290    /// Build a named layer from a [`LayerStack`] closure.
291    pub fn layer(
292        self,
293        name: impl Into<String>,
294        build: impl FnOnce(LayerStack) -> LayerStack,
295    ) -> Self {
296        self.raw_stage(build(LayerStack::named(name)).build())
297    }
298
299    /// Fused LLaMA prefill layer (default fast path).
300    pub fn llama_prefill_layer(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
301        self.raw_stage(llama_prefill_layer_fused(layer_idx, spec))
302    }
303
304    /// Composed LLaMA prefill layer (small blocks — customize via [`LayerStack`]).
305    pub fn llama_prefill_layer_composed(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
306        self.raw_stage(llama_prefill_layer_composed(layer_idx, spec))
307    }
308
309    pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
310        self.stages
311            .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
312        self
313    }
314
315    pub fn residual_save(mut self) -> Self {
316        self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
317        self
318    }
319
320    pub fn residual_add(mut self) -> Self {
321        self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
322        self
323    }
324
325    pub fn swiglu(
326        mut self,
327        gate_key: impl Into<String>,
328        up_key: impl Into<String>,
329        down_key: impl Into<String>,
330    ) -> Self {
331        self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
332            gate_key, up_key, down_key,
333        )));
334        self
335    }
336
337    pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
338        self.stages
339            .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
340        self
341    }
342
343    pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
344        self.stages
345            .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
346        self
347    }
348
349    pub fn gdn_scan(mut self, stage: crate::blocks::GdnScanStage) -> Self {
350        self.stages.push(FlowStage::GdnScan(stage));
351        self
352    }
353
354    pub fn store_stream(mut self, name: impl Into<String>) -> Self {
355        self.stages
356            .push(FlowStage::StoreStream(StoreStreamStage::new(name)));
357        self
358    }
359
360    pub fn load_stream(mut self, name: impl Into<String>) -> Self {
361        self.stages
362            .push(FlowStage::LoadStream(LoadStreamStage::new(name)));
363        self
364    }
365
366    /// Bind declared graph inputs into named streams (multi-input models).
367    ///
368    /// Example: FLUX `.bind_inputs_to_streams(&[("hidden", "img"), ("encoder", "txt")])`.
369    pub fn bind_inputs_to_streams(
370        mut self,
371        pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
372    ) -> Self {
373        let pairs: Vec<(String, String)> = pairs
374            .into_iter()
375            .map(|(input, stream)| (input.into(), stream.into()))
376            .collect();
377        self.stages.push(FlowStage::Custom(CustomStage::named(
378            "bind_inputs_to_streams",
379            move |emit, primary| {
380                let primary = primary.ok_or_else(|| {
381                    anyhow::anyhow!("bind_inputs_to_streams requires primary input")
382                })?;
383                for (input_name, stream_name) in &pairs {
384                    let value = emit.flow_input(input_name)?;
385                    emit.state.streams.insert(stream_name.clone(), value);
386                }
387                Ok(Some(primary))
388            },
389        )));
390        self
391    }
392
393    pub fn dual_stream<F>(
394        mut self,
395        name: impl Into<String>,
396        stream_a: impl Into<String>,
397        stream_b: impl Into<String>,
398        f: F,
399    ) -> Self
400    where
401        F: Fn(&mut Emit<'_>, FlowValue, FlowValue) -> anyhow::Result<(FlowValue, FlowValue)>
402            + Send
403            + Sync
404            + 'static,
405    {
406        self.stages.push(FlowStage::DualStream(DualStreamStage::new(
407            name, stream_a, stream_b, f,
408        )));
409        self
410    }
411
412    pub fn plugin<F>(mut self, f: F) -> Self
413    where
414        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
415            + Send
416            + Sync
417            + 'static,
418    {
419        self.stages.push(crate::plugin::plugin(f));
420        self
421    }
422
423    pub fn plugin_named<F>(mut self, name: impl Into<String>, f: F) -> Self
424    where
425        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
426            + Send
427            + Sync
428            + 'static,
429    {
430        self.stages.push(crate::plugin::plugin_named(name, f));
431        self
432    }
433
434    /// Hidden states output (no LM head).
435    pub fn hidden_states(self) -> Self {
436        self.output("hidden")
437    }
438
439    /// LLaMA prefill decoder block at `layer_idx`.
440    pub fn llama_decoder_layer(
441        self,
442        layer_idx: usize,
443        spec: crate::blocks::LlamaDecoderSpec,
444    ) -> Self {
445        self.named_layer(
446            format!("layer{layer_idx}"),
447            FlowStage::LlamaDecoder(LlamaDecoderStage::layer(layer_idx, spec)),
448        )
449    }
450
451    /// LLaMA decode block with KV-cache concat.
452    pub fn llama_decode_layer(
453        self,
454        layer_idx: usize,
455        spec: crate::blocks::LlamaDecodeLayerSpec,
456        kv_out: SideOutputs,
457    ) -> Self {
458        self.named_layer(
459            format!("layer{layer_idx}"),
460            FlowStage::LlamaDecodeLayer(LlamaDecodeLayerStage::layer(
461                layer_idx,
462                spec,
463                kv_out.inner(),
464            )),
465        )
466    }
467
468    /// Side-effect K/V tap before a prefill layer (exports cache tensors).
469    pub fn llama_kv_tap(
470        mut self,
471        layer_idx: usize,
472        head_dim: usize,
473        eps: f32,
474        sink: &SideOutputs,
475    ) -> Self {
476        self.stages
477            .push(FlowStage::LlamaKvTap(LlamaKvTapStage::layer(
478                layer_idx,
479                head_dim,
480                eps,
481                sink.inner(),
482            )));
483        self
484    }
485
486    /// Final RMSNorm before LM head (`model.norm.weight` by default).
487    pub fn final_norm(self, eps: f32) -> Self {
488        self.rms_norm("model.norm.weight", eps)
489    }
490
491    pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
492        self.stages
493            .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
494        self
495    }
496
497    /// Gather last token (dynamic `last_token_idx` input).
498    pub fn gather_last_token_dynamic(mut self, batch: usize) -> Self {
499        self.stages
500            .push(FlowStage::GatherLastToken(GatherLastTokenStage::dynamic(
501                batch,
502            )));
503        self
504    }
505
506    /// Gather last token at fixed sequence length.
507    pub fn gather_last_token_at(mut self, batch: usize, seq: usize) -> Self {
508        self.stages.push(FlowStage::GatherLastToken(
509            GatherLastTokenStage::static_last(batch, seq),
510        ));
511        self
512    }
513
514    /// Causal LM head — tied or separate weights.
515    pub fn lm_head(
516        mut self,
517        vocab_size: usize,
518        hidden_size: usize,
519        tie_word_embeddings: bool,
520    ) -> Self {
521        let stage = if tie_word_embeddings {
522            LmHeadStage::tied(vocab_size, hidden_size)
523        } else {
524            LmHeadStage::separate("lm_head.weight", vocab_size, hidden_size)
525        };
526        self.stages.push(FlowStage::LmHead(stage));
527        self.output("logits")
528    }
529
530    /// Tier-2 escape hatch — append a raw stage.
531    pub fn raw_stage(mut self, stage: FlowStage) -> Self {
532        self.stages.push(stage);
533        self
534    }
535
536    /// Append multiple raw stages in order.
537    pub fn raw_stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
538        self.stages.extend(stages);
539        self
540    }
541
542    /// Run a list of stages as one nested sequence (side-effect stages allowed).
543    pub fn sequence(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
544        self.stages
545            .push(FlowStage::Sequence(stages.into_iter().collect()));
546        self
547    }
548
549    /// Conditionally transform the builder (e.g. optional vision tower).
550    pub fn when(self, cond: bool, f: impl FnOnce(Self) -> Self) -> Self {
551        if cond { f(self) } else { self }
552    }
553
554    /// Tier-2 custom subgraph — prefer promoting repeated patterns to blocks.
555    pub fn custom<F>(mut self, f: F) -> Self
556    where
557        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
558            + Send
559            + Sync
560            + 'static,
561    {
562        self.stages.push(FlowStage::Custom(CustomStage::new(f)));
563        self
564    }
565
566    /// Named custom subgraph (shows up in fusion / inspect dumps).
567    pub fn custom_named<F>(mut self, name: impl Into<String>, f: F) -> Self
568    where
569        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
570            + Send
571            + Sync
572            + 'static,
573    {
574        self.stages
575            .push(FlowStage::Custom(CustomStage::named(name, f)));
576        self
577    }
578
579    /// Patch the builder after preset assembly (arch recipes, Llama32Flow hooks).
580    pub fn patch(self, f: impl FnOnce(Self) -> Self) -> Self {
581        f(self)
582    }
583}