Skip to main content

rlx_flux2/text_encoder/
hir_builder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Qwen3-shaped FLUX.2 text encoder HIR (causal LM trunk → multi-layer prompt embeds).
17
18use super::super::builder::Flux2GraphParams;
19use super::weights::{
20    Flux2TextEncoderAttnWeights, Flux2TextEncoderLayerWeights, Flux2TextEncoderMlpWeights,
21    Flux2TextEncoderWeights,
22};
23use crate::weights::{LinearWeights, RmsNormWeight};
24use anyhow::{Result, ensure};
25use rlx_ir::hir::{FusionPolicy, HirModule, HirNodeId};
26use rlx_ir::op::{Activation, BinaryOp, MaskKind};
27use rlx_ir::{DType, Op, Shape};
28use rlx_qwen3::Qwen3Config;
29use rlx_runtime::Device;
30
31pub struct Flux2TextEncoderGraph {
32    pub hir: HirModule,
33    pub params: Flux2GraphParams,
34    pub joint_dim: usize,
35}
36
37pub fn build_flux2_text_encoder_hir(
38    cfg: &Qwen3Config,
39    weights: &Flux2TextEncoderWeights,
40    batch: usize,
41    seq: usize,
42    hidden_state_layers: &[usize],
43) -> Result<Flux2TextEncoderGraph> {
44    ensure!(
45        cfg.num_attention_heads
46            .is_multiple_of(cfg.num_key_value_heads),
47        "num_attention_heads must divide num_key_value_heads"
48    );
49    let joint_dim = cfg.hidden_size * hidden_state_layers.len();
50    let f = DType::F32;
51    let mut hir = HirModule::new("flux2_text_encoder").with_fusion_policy(FusionPolicy::Direct);
52    let mut params = Flux2GraphParams::new();
53    let ids = hir.input("input_ids", Shape::new(&[batch, seq], f));
54    let mut b =
55        TextEncoderHirBuilder::from_emit_parts(&mut hir, &mut params, cfg, weights, batch, seq);
56    let mut hidden = b.emit_embed(ids)?;
57    let mut checkpoints = vec![hidden];
58    let (cos, sin) = b.rope_tables()?;
59    for (li, layer) in weights.layers.iter().enumerate() {
60        hidden = b.layer_forward(layer, li, hidden, cos, sin)?;
61        checkpoints.push(hidden);
62    }
63    let out = b.emit_joint_output(&checkpoints, hidden_state_layers, joint_dim)?;
64    hir.outputs = vec![out];
65    Ok(Flux2TextEncoderGraph {
66        hir,
67        params,
68        joint_dim,
69    })
70}
71
72pub fn compile_flux2_text_encoder_hir(
73    cfg: &Qwen3Config,
74    weights: &Flux2TextEncoderWeights,
75    batch: usize,
76    seq: usize,
77    hidden_state_layers: &[usize],
78    device: Device,
79    aot: Option<&rlx_runtime::AotCache>,
80) -> Result<(rlx_runtime::CompiledGraph, Flux2GraphParams)> {
81    use crate::compile_util::{compile_hir_cached, flux2_text_encoder_aot_key};
82
83    crate::device::assert_flux2_device_available(device)?;
84    let g = build_flux2_text_encoder_hir(cfg, weights, batch, seq, hidden_state_layers)?;
85    let key = flux2_text_encoder_aot_key(device, batch, seq);
86    let mut compiled = compile_hir_cached(
87        device,
88        aot,
89        &key,
90        g.hir,
91        &crate::compile_util::flux2_compile_profile(),
92    )?;
93    for (name, data) in &g.params {
94        compiled.set_param(name, data);
95    }
96    Ok((compiled, g.params))
97}
98
99pub(crate) struct TextEncoderHirBuilder<'a> {
100    hir: &'a mut HirModule,
101    params: &'a mut Flux2GraphParams,
102    cfg: &'a Qwen3Config,
103    weights: &'a Flux2TextEncoderWeights,
104    batch: usize,
105    seq: usize,
106    f: DType,
107    eps: f32,
108}
109
110impl<'a> TextEncoderHirBuilder<'a> {
111    pub(crate) fn from_emit_parts(
112        hir: &'a mut HirModule,
113        params: &'a mut Flux2GraphParams,
114        cfg: &'a Qwen3Config,
115        weights: &'a Flux2TextEncoderWeights,
116        batch: usize,
117        seq: usize,
118    ) -> Self {
119        Self {
120            hir,
121            params,
122            cfg,
123            weights,
124            batch,
125            seq,
126            f: DType::F32,
127            eps: cfg.rms_norm_eps as f32,
128        }
129    }
130
131    pub(crate) fn emit_embed(&mut self, ids: HirNodeId) -> Result<HirNodeId> {
132        let h = self.cfg.hidden_size;
133        let (embed_data, vocab, _) = &self.weights.embed_tokens;
134        let embed = self.register_param(
135            "embed_tokens.weight",
136            embed_data.clone(),
137            Shape::new(&[*vocab, h], self.f),
138        );
139        Ok(self
140            .hir
141            .mir(Op::Gather { axis: 0 }, vec![embed, ids], self.bsh()))
142    }
143
144    pub(crate) fn emit_joint_output(
145        &mut self,
146        checkpoints: &[HirNodeId],
147        hidden_state_layers: &[usize],
148        joint_dim: usize,
149    ) -> Result<HirNodeId> {
150        let h = self.cfg.hidden_size;
151        let mut out_pieces: Vec<HirNodeId> = Vec::with_capacity(hidden_state_layers.len());
152        for (i, &layer_idx) in hidden_state_layers.iter().enumerate() {
153            ensure!(
154                layer_idx < checkpoints.len(),
155                "hidden_state_layers[{i}]={layer_idx} out of range (len={})",
156                checkpoints.len()
157            );
158            out_pieces.push(checkpoints[layer_idx]);
159        }
160        let rows = self.batch * self.seq;
161        let mut flat_parts: Vec<HirNodeId> = Vec::with_capacity(out_pieces.len());
162        for p in &out_pieces {
163            flat_parts.push(self.reshape(*p, vec![rows as i64, h as i64]));
164        }
165        let flat = if flat_parts.len() == 1 {
166            flat_parts[0]
167        } else {
168            self.concat(flat_parts, 1, Shape::new(&[rows, joint_dim], self.f))
169        };
170        Ok(self.reshape(
171            flat,
172            vec![self.batch as i64, self.seq as i64, joint_dim as i64],
173        ))
174    }
175
176    fn bsh(&self) -> Shape {
177        Shape::new(&[self.batch, self.seq, self.cfg.hidden_size], self.f)
178    }
179
180    fn bsh_heads(&self, heads: usize) -> Shape {
181        Shape::new(&[self.batch, self.seq, heads * self.cfg.head_dim], self.f)
182    }
183
184    fn register_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
185        let id = self.hir.param(name, shape);
186        self.params.insert(name.to_string(), data);
187        id
188    }
189
190    fn linear(
191        &mut self,
192        x: HirNodeId,
193        lw: &LinearWeights,
194        name: &str,
195        out_shape: Shape,
196    ) -> Result<HirNodeId> {
197        let w = self.register_param(
198            &format!("{name}.weight"),
199            lw.w_t.clone(),
200            Shape::new(&[lw.in_dim, lw.out_dim], self.f),
201        );
202        let bias = if lw.bias.iter().all(|&v| v == 0.0) {
203            None
204        } else {
205            let b = self.register_param(
206                &format!("{name}.bias"),
207                lw.bias.clone(),
208                Shape::new(&[lw.out_dim], self.f),
209            );
210            Some(b)
211        };
212        Ok(self.hir.linear(x, w, bias, None, out_shape))
213    }
214
215    fn rms_norm(
216        &mut self,
217        x: HirNodeId,
218        gamma: &RmsNormWeight,
219        name: &str,
220        shape: Shape,
221    ) -> HirNodeId {
222        let g = self.register_param(
223            &format!("{name}.weight"),
224            gamma.scale.clone(),
225            Shape::new(&[gamma.scale.len()], self.f),
226        );
227        let beta = self.register_param(
228            &format!("{name}.beta"),
229            vec![0.0f32; gamma.scale.len()],
230            Shape::new(&[gamma.scale.len()], self.f),
231        );
232        self.hir.mir(
233            Op::RmsNorm {
234                axis: -1,
235                eps: self.eps,
236            },
237            vec![x, g, beta],
238            shape,
239        )
240    }
241
242    fn per_head_rms(
243        &mut self,
244        x: HirNodeId,
245        gamma: &RmsNormWeight,
246        name: &str,
247        heads: usize,
248    ) -> HirNodeId {
249        let hd = self.cfg.head_dim;
250        let flat = self.reshape(x, vec![(self.batch * self.seq * heads) as i64, hd as i64]);
251        let n = self.rms_norm(
252            flat,
253            gamma,
254            name,
255            Shape::new(&[self.batch * self.seq * heads, hd], self.f),
256        );
257        self.reshape(
258            n,
259            vec![self.batch as i64, self.seq as i64, (heads * hd) as i64],
260        )
261    }
262
263    pub(crate) fn layer_forward(
264        &mut self,
265        layer: &Flux2TextEncoderLayerWeights,
266        li: usize,
267        x: HirNodeId,
268        cos: HirNodeId,
269        sin: HirNodeId,
270    ) -> Result<HirNodeId> {
271        let lp = format!("layers.{li}");
272        let shape = self.bsh();
273        let normed = self.rms_norm(
274            x,
275            &layer.input_layernorm,
276            &format!("{lp}.in_ln"),
277            shape.clone(),
278        );
279        let attn_out = self.attn_forward(&layer.attn, &format!("{lp}.attn"), normed, cos, sin)?;
280        let post_attn = self.add(x, attn_out, shape.clone());
281        let mlp_out = self.mlp_forward(
282            &layer.mlp,
283            &layer.post_attention_layernorm,
284            &format!("{lp}.mlp"),
285            post_attn,
286        )?;
287        Ok(self.add(post_attn, mlp_out, shape))
288    }
289
290    fn attn_forward(
291        &mut self,
292        attn: &Flux2TextEncoderAttnWeights,
293        tag: &str,
294        x: HirNodeId,
295        cos: HirNodeId,
296        sin: HirNodeId,
297    ) -> Result<HirNodeId> {
298        let nh = self.cfg.num_attention_heads;
299        let nkv = self.cfg.num_key_value_heads;
300        let hd = self.cfg.head_dim;
301        let group = nh / nkv;
302        let shape = self.bsh();
303
304        let q = self.linear(x, &attn.q, &format!("{tag}.q"), self.bsh_heads(nh))?;
305        let k = self.linear(x, &attn.k, &format!("{tag}.k"), self.bsh_heads(nkv))?;
306        let v = self.linear(x, &attn.v, &format!("{tag}.v"), self.bsh_heads(nkv))?;
307
308        let q = self.per_head_rms(q, &attn.q_norm, &format!("{tag}.nq"), nh);
309        let k = self.per_head_rms(k, &attn.k_norm, &format!("{tag}.nk"), nkv);
310
311        let qh = self.bsh_heads(nh);
312        let q = self.rope(q, cos, sin, qh.clone());
313        let k = self.rope(k, cos, sin, self.bsh_heads(nkv));
314        let k_rep = self.repeat_kv(k, nkv, hd, group);
315        let v_rep = self.repeat_kv(v, nkv, hd, group);
316
317        let attn_out =
318            self.hir
319                .attention(q, k_rep, v_rep, None, nh, hd, MaskKind::Causal, qh.clone());
320        self.linear(attn_out, &attn.o, &format!("{tag}.o"), shape)
321    }
322
323    fn mlp_forward(
324        &mut self,
325        mlp: &Flux2TextEncoderMlpWeights,
326        post_ln: &RmsNormWeight,
327        tag: &str,
328        x: HirNodeId,
329    ) -> Result<HirNodeId> {
330        let rows = self.batch * self.seq;
331        let h = self.cfg.hidden_size;
332        let ff = self.cfg.intermediate_size;
333        let flat = self.reshape(x, vec![rows as i64, h as i64]);
334        let flat = self.rms_norm(
335            flat,
336            post_ln,
337            &format!("{tag}.post_ln"),
338            Shape::new(&[rows, h], self.f),
339        );
340        let gate = self.linear(
341            flat,
342            &mlp.gate,
343            &format!("{tag}.gate"),
344            Shape::new(&[rows, ff], self.f),
345        )?;
346        let up = self.linear(
347            flat,
348            &mlp.up,
349            &format!("{tag}.up"),
350            Shape::new(&[rows, ff], self.f),
351        )?;
352        let gate3 = self.reshape(gate, vec![self.batch as i64, self.seq as i64, ff as i64]);
353        let up3 = self.reshape(up, vec![self.batch as i64, self.seq as i64, ff as i64]);
354        let silu = self.hir.mir(
355            Op::Activation(Activation::Silu),
356            vec![gate3],
357            Shape::new(&[self.batch, self.seq, ff], self.f),
358        );
359        let prod = self.mul(silu, up3, Shape::new(&[self.batch, self.seq, ff], self.f));
360        let prod_flat = self.reshape(prod, vec![rows as i64, ff as i64]);
361        self.linear(
362            prod_flat,
363            &mlp.down,
364            &format!("{tag}.down"),
365            Shape::new(&[rows, h], self.f),
366        )
367        .map(|o| self.reshape(o, vec![self.batch as i64, self.seq as i64, h as i64]))
368    }
369
370    fn mul(&mut self, a: HirNodeId, b: HirNodeId, shape: Shape) -> HirNodeId {
371        self.hir.mir(Op::Binary(BinaryOp::Mul), vec![a, b], shape)
372    }
373
374    fn repeat_kv(&mut self, x: HirNodeId, nkv: usize, hd: usize, group: usize) -> HirNodeId {
375        if group == 1 {
376            return x;
377        }
378        let last = 2;
379        let slice_shape = Shape::new(&[self.batch, self.seq, hd], self.f);
380        let out_shape = Shape::new(&[self.batch, self.seq, nkv * group * hd], self.f);
381        let mut pieces = Vec::with_capacity(nkv * group);
382        for h in 0..nkv {
383            let slice = self.narrow(x, last, h * hd, hd, slice_shape.clone());
384            for _ in 0..group {
385                pieces.push(slice);
386            }
387        }
388        self.concat(pieces, last, out_shape)
389    }
390
391    pub(crate) fn rope_tables(&mut self) -> Result<(HirNodeId, HirNodeId)> {
392        let dh = self.cfg.head_dim;
393        let half = dh / 2;
394        let max_pos = self.cfg.max_position_embeddings;
395        let mut cos_data = vec![0f32; max_pos * dh];
396        let mut sin_data = vec![0f32; max_pos * dh];
397        for pos in 0..max_pos {
398            for i in 0..half {
399                let freq = 1.0 / self.cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
400                let angle = pos as f64 * freq;
401                let (s, c) = angle.sin_cos();
402                cos_data[pos * dh + 2 * i] = c as f32;
403                cos_data[pos * dh + 2 * i + 1] = c as f32;
404                sin_data[pos * dh + 2 * i] = s as f32;
405                sin_data[pos * dh + 2 * i + 1] = s as f32;
406            }
407        }
408        let cos = self.register_param("rope.cos", cos_data, Shape::new(&[max_pos, dh], self.f));
409        let sin = self.register_param("rope.sin", sin_data, Shape::new(&[max_pos, dh], self.f));
410        Ok((cos, sin))
411    }
412
413    fn rope(&mut self, x: HirNodeId, cos: HirNodeId, sin: HirNodeId, shape: Shape) -> HirNodeId {
414        self.hir.mir(
415            Op::Rope {
416                head_dim: self.cfg.head_dim,
417                n_rot: self.cfg.head_dim,
418            },
419            vec![x, cos, sin],
420            shape,
421        )
422    }
423
424    fn reshape(&mut self, x: HirNodeId, new_shape: Vec<i64>) -> HirNodeId {
425        let shape = Shape::new(
426            &new_shape.iter().map(|&d| d as usize).collect::<Vec<_>>(),
427            self.f,
428        );
429        self.hir.mir(Op::Reshape { new_shape }, vec![x], shape)
430    }
431
432    fn narrow(
433        &mut self,
434        x: HirNodeId,
435        axis: usize,
436        start: usize,
437        len: usize,
438        shape: Shape,
439    ) -> HirNodeId {
440        self.hir
441            .mir(Op::Narrow { axis, start, len }, vec![x], shape)
442    }
443
444    fn concat(&mut self, inputs: Vec<HirNodeId>, axis: usize, shape: Shape) -> HirNodeId {
445        self.hir.mir(Op::Concat { axis }, inputs, shape)
446    }
447
448    fn add(&mut self, a: HirNodeId, b: HirNodeId, shape: Shape) -> HirNodeId {
449        self.hir.mir(Op::Binary(BinaryOp::Add), vec![a, b], shape)
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::text_encoder::{
        TINY_TEXT_ENCODER_LAYERS, encode_prompt_embeds, synthetic_text_encoder_weights,
        tiny_text_encoder_config,
    };
    use rlx_runtime::Device;

    /// Maximum allowed element-wise difference between compiled and native outputs.
    const MAX_ABS_TOL: f32 = 2e-2;

    #[test]
    fn text_encoder_hir_lowers() {
        let cfg = tiny_text_encoder_config();
        let weights = synthetic_text_encoder_weights(&cfg);
        let graph =
            build_flux2_text_encoder_hir(&cfg, &weights, 1, 4, TINY_TEXT_ENCODER_LAYERS).unwrap();
        graph.hir.lower_to_mir().expect("lower");
    }

    /// Runs the tiny synthetic encoder both natively and through the compiled CPU
    /// graph (batch 1, token ids `0..4`, selecting `layers`) and asserts that the
    /// outputs have equal length and agree within `MAX_ABS_TOL`. `label` prefixes
    /// the failure message.
    fn assert_compiled_matches_native(layers: &[usize], label: &str) {
        let cfg = tiny_text_encoder_config();
        let weights = synthetic_text_encoder_weights(&cfg);
        let (batch, seq) = (1usize, 4usize);
        let ids: Vec<u32> = (0..seq as u32).collect();
        let ids_f32: Vec<f32> = ids.iter().map(|&id| id as f32).collect();

        let native = encode_prompt_embeds(&weights, &cfg, &ids, batch, seq, layers).unwrap();
        let (mut compiled, _) =
            compile_flux2_text_encoder_hir(&cfg, &weights, batch, seq, layers, Device::Cpu, None)
                .unwrap();
        let out = compiled.run(&[("input_ids", ids_f32.as_slice())]).remove(0);

        assert_eq!(out.len(), native.prompt_embeds.len());
        let max = out
            .iter()
            .zip(&native.prompt_embeds)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max < MAX_ABS_TOL, "{label} max_abs_diff={max}");
    }

    #[test]
    fn compiled_single_layer_hidden_matches_native() {
        assert_compiled_matches_native(&[1], "single layer");
    }

    #[test]
    fn compiled_text_encoder_matches_native() {
        assert_compiled_matches_native(TINY_TEXT_ENCODER_LAYERS, "HIR vs native");
    }
}