Skip to main content

rlx_flux2/text_encoder/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Native FLUX.2 text encoder flow — Qwen3-shaped causal trunk → joint prompt embeds.
17
18use std::sync::{Arc, Mutex};
19
20use anyhow::{Result, ensure};
21use rlx_flow::{BuiltModel, MapWeights, ModelFlow, plugin_named};
22use rlx_ir::{DType, HirNodeId, Shape};
23
24use super::hir_builder::TextEncoderHirBuilder;
25use super::weights::Flux2TextEncoderWeights;
26use rlx_qwen3::Qwen3Config;
27
28const ROPE_COS: &str = "flux2_te.rope_cos";
29const ROPE_SIN: &str = "flux2_te.rope_sin";
30
31/// Tier-0 FLUX.2 text encoder flow (Qwen3-shaped causal LM trunk).
32#[derive(Clone)]
33pub struct Flux2TextEncoderFlow<'a> {
34    cfg: &'a Qwen3Config,
35    weights: &'a Flux2TextEncoderWeights,
36    batch: usize,
37    seq: usize,
38    hidden_state_layers: Vec<usize>,
39}
40
41impl<'a> Flux2TextEncoderFlow<'a> {
42    pub fn new(
43        cfg: &'a Qwen3Config,
44        weights: &'a Flux2TextEncoderWeights,
45        batch: usize,
46        seq: usize,
47        hidden_state_layers: &[usize],
48    ) -> Self {
49        Self {
50            cfg,
51            weights,
52            batch,
53            seq,
54            hidden_state_layers: hidden_state_layers.to_vec(),
55        }
56    }
57
58    pub fn build(self) -> Result<Flux2TextEncoderBuilt> {
59        build_flux2_text_encoder_built(
60            self.cfg,
61            self.weights,
62            self.batch,
63            self.seq,
64            &self.hidden_state_layers,
65        )
66    }
67}
68
69pub struct Flux2TextEncoderBuilt {
70    pub model: BuiltModel,
71    pub joint_dim: usize,
72}
73
74pub fn build_flux2_text_encoder_built(
75    cfg: &Qwen3Config,
76    weights: &Flux2TextEncoderWeights,
77    batch: usize,
78    seq: usize,
79    hidden_state_layers: &[usize],
80) -> Result<Flux2TextEncoderBuilt> {
81    ensure!(
82        cfg.num_attention_heads
83            .is_multiple_of(cfg.num_key_value_heads),
84        "num_attention_heads must divide num_key_value_heads"
85    );
86    let joint_dim = cfg.hidden_size * hidden_state_layers.len();
87    let h = cfg.hidden_size;
88    let f = DType::F32;
89    let hidden_shape = Shape::new(&[batch, seq, h], f);
90    let out_shape = Shape::new(&[batch, seq, joint_dim], f);
91
92    let cfg = cfg.clone();
93    let weights = weights.clone();
94    let hidden_state_layers = hidden_state_layers.to_vec();
95    let checkpoints: Arc<Mutex<Vec<HirNodeId>>> = Arc::new(Mutex::new(Vec::new()));
96    let embed_hidden_shape = hidden_shape.clone();
97    let layer_hidden_shape = hidden_shape.clone();
98
99    let mut flow = ModelFlow::new("flux2_text_encoder")
100        .input("input_ids", Shape::new(&[batch, seq], f))
101        .plugin_named("flux2_te.embed", {
102            let cfg = cfg.clone();
103            let weights = weights.clone();
104            let checkpoints = checkpoints.clone();
105            move |emit, _| {
106                let ids = emit.flow_input("input_ids")?.hir_id();
107                let (hir, params) = emit.hir_and_params();
108                let mut b =
109                    TextEncoderHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, seq);
110                let hidden = b.emit_embed(ids)?;
111                checkpoints.lock().unwrap().push(hidden);
112                Ok(Some(emit.wrap(hidden, embed_hidden_shape.clone())))
113            }
114        })
115        .plugin_named("flux2_te.rope", {
116            let cfg = cfg.clone();
117            let weights = weights.clone();
118            move |emit, primary| {
119                let (hir, params) = emit.hir_and_params();
120                let mut b =
121                    TextEncoderHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, seq);
122                let (cos, sin) = b.rope_tables()?;
123                emit.set_named(ROPE_COS, cos);
124                emit.set_named(ROPE_SIN, sin);
125                Ok(primary)
126            }
127        });
128
129    for (li, layer) in weights.layers.iter().enumerate() {
130        let layer = layer.clone();
131        let cfg = cfg.clone();
132        let weights = weights.clone();
133        let checkpoints = checkpoints.clone();
134        let layer_shape = layer_hidden_shape.clone();
135        flow = flow.raw_stage(plugin_named(
136            format!("flux2_te.layer{li}"),
137            move |emit, input| {
138                let hidden =
139                    input.ok_or_else(|| anyhow::anyhow!("text encoder layer requires hidden"))?;
140                let cos = emit.named(ROPE_COS)?;
141                let sin = emit.named(ROPE_SIN)?;
142                let (hir, params) = emit.hir_and_params();
143                let mut b =
144                    TextEncoderHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, seq);
145                let out = b.layer_forward(&layer, li, hidden.hir_id(), cos, sin)?;
146                checkpoints.lock().unwrap().push(out);
147                Ok(Some(emit.wrap(out, layer_shape.clone())))
148            },
149        ));
150    }
151
152    let built = flow
153        .plugin_named("flux2_te.joint", {
154            let cfg = cfg.clone();
155            let weights = weights.clone();
156            let checkpoints = checkpoints.clone();
157            let hidden_state_layers = hidden_state_layers.clone();
158            move |emit, primary| {
159                let hidden = primary
160                    .ok_or_else(|| anyhow::anyhow!("joint output requires hidden"))?
161                    .hir_id();
162                let ckpts = checkpoints.lock().unwrap().clone();
163                let (hir, params) = emit.hir_and_params();
164                let mut b =
165                    TextEncoderHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, seq);
166                let out = b.emit_joint_output(&ckpts, &hidden_state_layers, joint_dim)?;
167                let _ = hidden;
168                Ok(Some(emit.wrap(out, out_shape.clone())))
169            }
170        })
171        .output("prompt_embeds")
172        .build(&mut MapWeights::default())?;
173
174    Ok(Flux2TextEncoderBuilt {
175        model: built,
176        joint_dim,
177    })
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::text_encoder::{
        TINY_TEXT_ENCODER_LAYERS, build_flux2_text_encoder_hir, synthetic_text_encoder_weights,
        tiny_text_encoder_config,
    };

    /// The staged flow must produce as many HIR nodes as the reference
    /// builder for the same tiny configuration.
    #[test]
    fn text_encoder_flow_matches_hir_node_count() {
        let config = tiny_text_encoder_config();
        let synthetic = synthetic_text_encoder_weights(&config);
        let (batch, seq) = (1, 4);
        let selected = TINY_TEXT_ENCODER_LAYERS;

        // Reference graph built directly by the HIR builder.
        let reference = build_flux2_text_encoder_hir(&config, &synthetic, batch, seq, selected)
            .unwrap()
            .hir;

        // Same model assembled through the staged flow.
        let staged = Flux2TextEncoderFlow::new(&config, &synthetic, batch, seq, selected)
            .build()
            .unwrap()
            .model
            .into_hir()
            .unwrap();

        let flow_nodes = staged.len();
        let builder_nodes = reference.len();
        assert_eq!(
            flow_nodes, builder_nodes,
            "text encoder flow should match hir_builder node count (flow={}, builder={})",
            flow_nodes, builder_nodes
        );
    }
}