Skip to main content

rlx_flow/
dsl.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Fluent builder methods on [`ModelFlow`] — sugar over [`FlowStage`].

18use std::path::Path;
19use std::sync::Arc;
20
21use crate::blocks::{
22    AttnMaskStage, BertEncoderLayerSpec, BertEncoderLayerStage, BertQkvStyle,
23    BindDecodeInputsStage, ClsTokenPoolStage, CustomStage, EmbedStage, GatherAddStage,
24    GatherDecodeRopeStage, GatherFromInputStage, GatherLastTokenStage, GeluFfnStage,
25    LayerNormStage, LinearStage, LlamaDecodeLayerStage, LlamaDecoderSpec, LlamaDecoderStage,
26    LlamaKvTapStage, LmHeadStage, NomicEncoderLayerSpec, NomicEncoderLayerStage, RepeatStage,
27    ResidualAddStage, ResidualSaveStage, RmsNormStage, RopeTablesStage, SelfAttnPrefillSpec,
28    SelfAttnPrefillStage, SwiGluStage, dinov2_layer_fused, llama_prefill_layer_composed,
29    llama_prefill_layer_fused, nomic_vision_layer_fused,
30};
31use crate::escape::Emit;
32use crate::flow::ModelFlow;
33use crate::layer::LayerStack;
34use crate::profile::CompileProfile;
35use crate::side::SideOutputs;
36use crate::stage::FlowStage;
37use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
38use crate::value::FlowValue;
39
40impl ModelFlow {
41    /// Load tier-1 profile from a `*.rlx.toml` file (falls back to default on error).
42    pub fn profile_file(mut self, path: impl AsRef<Path>, default: fn() -> CompileProfile) -> Self {
43        self.profile = CompileProfile::from_toml_path(path.as_ref()).unwrap_or_else(|_| default());
44        self
45    }
46
47    /// Encoder / embedding model defaults (Direct lowering, no KV fusion).
48    pub fn profile_encoder(mut self) -> Self {
49        self.profile = CompileProfile::encoder();
50        self
51    }
52
53    /// Gather rows from a side input into the primary flow (starts embedding stack).
54    pub fn gather_from_input(
55        mut self,
56        input_name: impl Into<String>,
57        weight_key: impl Into<String>,
58    ) -> Self {
59        self.stages
60            .push(FlowStage::GatherFromInput(GatherFromInputStage::new(
61                input_name, weight_key, 0,
62            )));
63        self
64    }
65
66    /// Add an embedding looked up from a side input.
67    pub fn gather_add(
68        mut self,
69        input_name: impl Into<String>,
70        weight_key: impl Into<String>,
71    ) -> Self {
72        self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
73            input_name, weight_key, 0,
74        )));
75        self
76    }
77
78    /// LayerNorm with separate gamma/beta weights.
79    pub fn layer_norm(
80        mut self,
81        gamma_key: impl Into<String>,
82        beta_key: impl Into<String>,
83        eps: f32,
84    ) -> Self {
85        self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
86            gamma_key, beta_key, eps,
87        )));
88        self
89    }
90
91    /// BERT-style GELU FFN under a layer prefix.
92    pub fn gelu_ffn(mut self, layer_prefix: impl Into<String>) -> Self {
93        self.stages
94            .push(FlowStage::GeluFfn(GeluFfnStage::hf_bert(layer_prefix)));
95        self
96    }
97
98    /// Repeat NomicBERT encoder layers.
99    pub fn repeat_nomic_layers(
100        self,
101        count: usize,
102        hidden_size: usize,
103        num_heads: usize,
104        head_dim: usize,
105        eps: f32,
106    ) -> Self {
107        self.repeat_layers(count, move |i| FlowStage::Named {
108            name: format!("layer{i}"),
109            inner: std::sync::Arc::new(FlowStage::NomicEncoderLayer(NomicEncoderLayerStage::new(
110                NomicEncoderLayerSpec::hf(
111                    format!("encoder.layers.{i}"),
112                    hidden_size,
113                    num_heads,
114                    head_dim,
115                    eps,
116                ),
117            ))),
118        })
119    }
120
121    /// BERT-style encoder layer (fused QKV + padding-mask attention + GELU FFN).
122    pub fn bert_encoder_layer(mut self, spec: BertEncoderLayerSpec) -> Self {
123        self.stages
124            .push(FlowStage::BertEncoderLayer(BertEncoderLayerStage::new(
125                spec,
126            )));
127        self
128    }
129
130    /// Repeat BERT encoder layers with auto-named prefixes.
131    pub fn repeat_bert_layers(
132        self,
133        count: usize,
134        prefix: impl Into<String>,
135        qkv_style: BertQkvStyle,
136        hidden_size: usize,
137        num_heads: usize,
138        eps: f32,
139    ) -> Self {
140        let prefix = prefix.into();
141        self.repeat_layers(count, move |i| {
142            let lp = if prefix.is_empty() {
143                format!("encoder.layer.{i}")
144            } else {
145                format!("{prefix}.encoder.layer.{i}")
146            };
147            FlowStage::Named {
148                name: format!("layer{i}"),
149                inner: std::sync::Arc::new(FlowStage::BertEncoderLayer(
150                    BertEncoderLayerStage::new(BertEncoderLayerSpec::hf(
151                        lp,
152                        qkv_style,
153                        hidden_size,
154                        num_heads,
155                        eps,
156                    )),
157                )),
158            }
159        })
160    }
161
162    /// Synthesize an all-ones attention mask for vision encoders (no padding).
163    pub fn attn_mask_ones(mut self, batch: usize, seq: usize) -> Self {
164        self.stages
165            .push(FlowStage::AttnMask(AttnMaskStage::ones(batch, seq)));
166        self
167    }
168
169    /// Repeat DINOv2 ViT encoder blocks.
170    pub fn repeat_dinov2_layers(
171        self,
172        count: usize,
173        hidden_size: usize,
174        num_heads: usize,
175        eps: f32,
176    ) -> Self {
177        self.repeat_layers(count, move |i| {
178            dinov2_layer_fused(i, hidden_size, num_heads, eps)
179        })
180    }
181
182    /// Repeat NomicVision encoder blocks.
183    pub fn repeat_vision_layers(
184        self,
185        count: usize,
186        hidden_size: usize,
187        num_heads: usize,
188        eps: f32,
189    ) -> Self {
190        self.repeat_layers(count, move |i| {
191            nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
192        })
193    }
194
195    /// Compatibility shim: repeat SigLIP-style vision layers.
196    pub fn repeat_siglip_layers(
197        self,
198        count: usize,
199        hidden_size: usize,
200        num_heads: usize,
201        eps: f32,
202    ) -> Self {
203        self.repeat_layers(count, move |i| {
204            nomic_vision_layer_fused(i, hidden_size, num_heads, eps)
205        })
206    }
207
208    /// Pool CLS token: `[batch, seq, hidden]` → `[batch, hidden]`.
209    pub fn cls_token_pool(mut self, batch: usize, hidden: usize) -> Self {
210        self.stages
211            .push(FlowStage::ClsTokenPool(ClsTokenPoolStage::new(
212                batch, hidden,
213            )));
214        self
215    }
216
217    /// Fusion-first prefill defaults.
218    pub fn profile_prefill(mut self) -> Self {
219        self.profile = CompileProfile::llama32_prefill();
220        self
221    }
222
223    /// Decode / KV-cache defaults (`Fusable` lowering).
224    pub fn profile_decode(mut self) -> Self {
225        self.profile = CompileProfile::llama32_decode();
226        self
227    }
228
229    /// Token embedding (`model.embed_tokens.weight` by default).
230    pub fn embed(mut self, weight_key: impl Into<String>) -> Self {
231        self.stages
232            .push(FlowStage::Embed(EmbedStage::token(weight_key)));
233        self
234    }
235
236    /// HuggingFace-style token embedding table.
237    pub fn token_embed(self) -> Self {
238        self.embed("model.embed_tokens.weight")
239    }
240
241    /// Precomputed RoPE sin/cos tables stored as params.
242    pub fn rope_tables(mut self, tables: RopeTablesStage) -> Self {
243        self.stages.push(FlowStage::RopeTables(tables));
244        self
245    }
246
247    /// Gather one decode RoPE row from [`Self::rope_tables`] using `position` input.
248    pub fn gather_decode_rope(mut self, half_dim: usize) -> Self {
249        self.stages
250            .push(FlowStage::GatherDecodeRope(GatherDecodeRopeStage::new(
251                half_dim,
252            )));
253        self
254    }
255
256    /// Rank-1 zero vector for RMSNorm beta slots (LLaMA has no beta).
257    pub fn zero_beta(self, len: usize) -> Self {
258        self.zero_beta_named("zero_beta", len)
259    }
260
261    pub fn zero_beta_named(mut self, name: impl Into<String>, len: usize) -> Self {
262        self.stages.push(FlowStage::ZeroBeta {
263            name: name.into(),
264            len,
265        });
266        self
267    }
268
269    /// Bind decode inputs (call after declaring `rope_cos`, `past_k_*`, …).
270    pub fn bind_decode_inputs(
271        mut self,
272        num_layers: usize,
273        custom_mask: bool,
274        need_past_kv: bool,
275    ) -> Self {
276        self.stages
277            .push(FlowStage::BindDecodeInputs(BindDecodeInputsStage {
278                num_layers,
279                use_custom_mask: custom_mask,
280                need_past_kv,
281            }));
282        self
283    }
284
285    /// Repeat a per-layer stage `count` times (layer index passed to closure).
286    pub fn repeat_layers(
287        mut self,
288        count: usize,
289        stage_for_layer: impl Fn(usize) -> FlowStage + Send + Sync + 'static,
290    ) -> Self {
291        self.stages
292            .push(FlowStage::Repeat(RepeatStage::new(count, stage_for_layer)));
293        self
294    }
295
296    /// Named decoder layer (shows up in fusion / inspect dumps).
297    pub fn named_layer(mut self, name: impl Into<String>, inner: FlowStage) -> Self {
298        self.stages.push(FlowStage::Named {
299            name: name.into(),
300            inner: Arc::new(inner),
301        });
302        self
303    }
304
305    /// Build a named layer from a [`LayerStack`] closure.
306    pub fn layer(
307        self,
308        name: impl Into<String>,
309        build: impl FnOnce(LayerStack) -> LayerStack,
310    ) -> Self {
311        self.raw_stage(build(LayerStack::named(name)).build())
312    }
313
314    /// Fused LLaMA prefill layer (default fast path).
315    pub fn llama_prefill_layer(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
316        self.raw_stage(llama_prefill_layer_fused(layer_idx, spec))
317    }
318
319    /// Composed LLaMA prefill layer (small blocks — customize via [`LayerStack`]).
320    pub fn llama_prefill_layer_composed(self, layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
321        self.raw_stage(llama_prefill_layer_composed(layer_idx, spec))
322    }
323
324    pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
325        self.stages
326            .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
327        self
328    }
329
330    pub fn residual_save(mut self) -> Self {
331        self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
332        self
333    }
334
335    pub fn residual_add(mut self) -> Self {
336        self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
337        self
338    }
339
340    pub fn swiglu(
341        mut self,
342        gate_key: impl Into<String>,
343        up_key: impl Into<String>,
344        down_key: impl Into<String>,
345    ) -> Self {
346        self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
347            gate_key, up_key, down_key,
348        )));
349        self
350    }
351
352    pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
353        self.stages
354            .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
355        self
356    }
357
358    pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
359        self.stages
360            .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
361        self
362    }
363
364    pub fn gdn_scan(mut self, stage: crate::blocks::GdnScanStage) -> Self {
365        self.stages.push(FlowStage::GdnScan(stage));
366        self
367    }
368
369    pub fn store_stream(mut self, name: impl Into<String>) -> Self {
370        self.stages
371            .push(FlowStage::StoreStream(StoreStreamStage::new(name)));
372        self
373    }
374
375    pub fn load_stream(mut self, name: impl Into<String>) -> Self {
376        self.stages
377            .push(FlowStage::LoadStream(LoadStreamStage::new(name)));
378        self
379    }
380
381    /// Bind declared graph inputs into named streams (multi-input models).
382    ///
383    /// Example: FLUX `.bind_inputs_to_streams(&[("hidden", "img"), ("encoder", "txt")])`.
384    pub fn bind_inputs_to_streams(
385        mut self,
386        pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
387    ) -> Self {
388        let pairs: Vec<(String, String)> = pairs
389            .into_iter()
390            .map(|(input, stream)| (input.into(), stream.into()))
391            .collect();
392        self.stages.push(FlowStage::Custom(CustomStage::named(
393            "bind_inputs_to_streams",
394            move |emit, primary| {
395                let primary = primary.ok_or_else(|| {
396                    anyhow::anyhow!("bind_inputs_to_streams requires primary input")
397                })?;
398                for (input_name, stream_name) in &pairs {
399                    let value = emit.flow_input(input_name)?;
400                    emit.state.streams.insert(stream_name.clone(), value);
401                }
402                Ok(Some(primary))
403            },
404        )));
405        self
406    }
407
408    pub fn dual_stream<F>(
409        mut self,
410        name: impl Into<String>,
411        stream_a: impl Into<String>,
412        stream_b: impl Into<String>,
413        f: F,
414    ) -> Self
415    where
416        F: Fn(&mut Emit<'_>, FlowValue, FlowValue) -> anyhow::Result<(FlowValue, FlowValue)>
417            + Send
418            + Sync
419            + 'static,
420    {
421        self.stages.push(FlowStage::DualStream(DualStreamStage::new(
422            name, stream_a, stream_b, f,
423        )));
424        self
425    }
426
427    pub fn plugin<F>(mut self, f: F) -> Self
428    where
429        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
430            + Send
431            + Sync
432            + 'static,
433    {
434        self.stages.push(crate::plugin::plugin(f));
435        self
436    }
437
438    pub fn plugin_named<F>(mut self, name: impl Into<String>, f: F) -> Self
439    where
440        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
441            + Send
442            + Sync
443            + 'static,
444    {
445        self.stages.push(crate::plugin::plugin_named(name, f));
446        self
447    }
448
449    /// Hidden states output (no LM head).
450    pub fn hidden_states(self) -> Self {
451        self.output("hidden")
452    }
453
454    /// LLaMA prefill decoder block at `layer_idx`.
455    pub fn llama_decoder_layer(
456        self,
457        layer_idx: usize,
458        spec: crate::blocks::LlamaDecoderSpec,
459    ) -> Self {
460        self.named_layer(
461            format!("layer{layer_idx}"),
462            FlowStage::LlamaDecoder(LlamaDecoderStage::layer(layer_idx, spec)),
463        )
464    }
465
466    /// LLaMA decode block with KV-cache concat.
467    pub fn llama_decode_layer(
468        self,
469        layer_idx: usize,
470        spec: crate::blocks::LlamaDecodeLayerSpec,
471        kv_out: SideOutputs,
472    ) -> Self {
473        self.named_layer(
474            format!("layer{layer_idx}"),
475            FlowStage::LlamaDecodeLayer(LlamaDecodeLayerStage::layer(
476                layer_idx,
477                spec,
478                kv_out.inner(),
479            )),
480        )
481    }
482
483    /// Side-effect K/V tap before a prefill layer (exports cache tensors).
484    pub fn llama_kv_tap(
485        mut self,
486        layer_idx: usize,
487        head_dim: usize,
488        eps: f32,
489        sink: &SideOutputs,
490    ) -> Self {
491        self.stages
492            .push(FlowStage::LlamaKvTap(LlamaKvTapStage::layer(
493                layer_idx,
494                head_dim,
495                eps,
496                sink.inner(),
497            )));
498        self
499    }
500
501    /// Final RMSNorm before LM head (`model.norm.weight` by default).
502    pub fn final_norm(self, eps: f32) -> Self {
503        self.rms_norm("model.norm.weight", eps)
504    }
505
506    pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
507        self.stages
508            .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
509        self
510    }
511
512    /// Gather last token (dynamic `last_token_idx` input).
513    pub fn gather_last_token_dynamic(mut self, batch: usize) -> Self {
514        self.stages
515            .push(FlowStage::GatherLastToken(GatherLastTokenStage::dynamic(
516                batch,
517            )));
518        self
519    }
520
521    /// Gather last token at fixed sequence length.
522    pub fn gather_last_token_at(mut self, batch: usize, seq: usize) -> Self {
523        self.stages.push(FlowStage::GatherLastToken(
524            GatherLastTokenStage::static_last(batch, seq),
525        ));
526        self
527    }
528
529    /// Causal LM head — tied or separate weights.
530    pub fn lm_head(
531        mut self,
532        vocab_size: usize,
533        hidden_size: usize,
534        tie_word_embeddings: bool,
535    ) -> Self {
536        let stage = if tie_word_embeddings {
537            LmHeadStage::tied(vocab_size, hidden_size)
538        } else {
539            LmHeadStage::separate("lm_head.weight", vocab_size, hidden_size)
540        };
541        self.stages.push(FlowStage::LmHead(stage));
542        self.output("logits")
543    }
544
545    /// Tier-2 escape hatch — append a raw stage.
546    pub fn raw_stage(mut self, stage: FlowStage) -> Self {
547        self.stages.push(stage);
548        self
549    }
550
551    /// Append multiple raw stages in order.
552    pub fn raw_stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
553        self.stages.extend(stages);
554        self
555    }
556
557    /// Run a list of stages as one nested sequence (side-effect stages allowed).
558    pub fn sequence(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
559        self.stages
560            .push(FlowStage::Sequence(stages.into_iter().collect()));
561        self
562    }
563
564    /// Conditionally transform the builder (e.g. optional vision tower).
565    pub fn when(self, cond: bool, f: impl FnOnce(Self) -> Self) -> Self {
566        if cond { f(self) } else { self }
567    }
568
569    /// Tier-2 custom subgraph — prefer promoting repeated patterns to blocks.
570    pub fn custom<F>(mut self, f: F) -> Self
571    where
572        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
573            + Send
574            + Sync
575            + 'static,
576    {
577        self.stages.push(FlowStage::Custom(CustomStage::new(f)));
578        self
579    }
580
581    /// Named custom subgraph (shows up in fusion / inspect dumps).
582    pub fn custom_named<F>(mut self, name: impl Into<String>, f: F) -> Self
583    where
584        F: Fn(&mut Emit<'_>, Option<FlowValue>) -> anyhow::Result<Option<FlowValue>>
585            + Send
586            + Sync
587            + 'static,
588    {
589        self.stages
590            .push(FlowStage::Custom(CustomStage::named(name, f)));
591        self
592    }
593
594    /// Patch the builder after preset assembly (arch recipes, Llama32Flow hooks).
595    pub fn patch(self, f: impl FnOnce(Self) -> Self) -> Self {
596        f(self)
597    }
598}