1use super::super::builder::Flux2GraphParams;
19use super::weights::{
20 Flux2TextEncoderAttnWeights, Flux2TextEncoderLayerWeights, Flux2TextEncoderMlpWeights,
21 Flux2TextEncoderWeights,
22};
23use crate::weights::{LinearWeights, RmsNormWeight};
24use anyhow::{Result, ensure};
25use rlx_ir::hir::{FusionPolicy, HirModule, HirNodeId};
26use rlx_ir::op::{Activation, BinaryOp, MaskKind};
27use rlx_ir::{DType, Op, Shape};
28use rlx_qwen3::Qwen3Config;
29use rlx_runtime::Device;
30
31pub struct Flux2TextEncoderGraph {
32 pub hir: HirModule,
33 pub params: Flux2GraphParams,
34 pub joint_dim: usize,
35}
36
37pub fn build_flux2_text_encoder_hir(
38 cfg: &Qwen3Config,
39 weights: &Flux2TextEncoderWeights,
40 batch: usize,
41 seq: usize,
42 hidden_state_layers: &[usize],
43) -> Result<Flux2TextEncoderGraph> {
44 ensure!(
45 cfg.num_attention_heads
46 .is_multiple_of(cfg.num_key_value_heads),
47 "num_attention_heads must divide num_key_value_heads"
48 );
49 let joint_dim = cfg.hidden_size * hidden_state_layers.len();
50 let f = DType::F32;
51 let mut hir = HirModule::new("flux2_text_encoder").with_fusion_policy(FusionPolicy::Direct);
52 let mut params = Flux2GraphParams::new();
53 let ids = hir.input("input_ids", Shape::new(&[batch, seq], f));
54 let mut b =
55 TextEncoderHirBuilder::from_emit_parts(&mut hir, &mut params, cfg, weights, batch, seq);
56 let mut hidden = b.emit_embed(ids)?;
57 let mut checkpoints = vec![hidden];
58 let (cos, sin) = b.rope_tables()?;
59 for (li, layer) in weights.layers.iter().enumerate() {
60 hidden = b.layer_forward(layer, li, hidden, cos, sin)?;
61 checkpoints.push(hidden);
62 }
63 let out = b.emit_joint_output(&checkpoints, hidden_state_layers, joint_dim)?;
64 hir.outputs = vec![out];
65 Ok(Flux2TextEncoderGraph {
66 hir,
67 params,
68 joint_dim,
69 })
70}
71
72pub fn compile_flux2_text_encoder_hir(
73 cfg: &Qwen3Config,
74 weights: &Flux2TextEncoderWeights,
75 batch: usize,
76 seq: usize,
77 hidden_state_layers: &[usize],
78 device: Device,
79 aot: Option<&rlx_runtime::AotCache>,
80) -> Result<(rlx_runtime::CompiledGraph, Flux2GraphParams)> {
81 use crate::compile_util::{compile_hir_cached, flux2_text_encoder_aot_key};
82
83 crate::device::assert_flux2_device_available(device)?;
84 let g = build_flux2_text_encoder_hir(cfg, weights, batch, seq, hidden_state_layers)?;
85 let key = flux2_text_encoder_aot_key(device, batch, seq);
86 let mut compiled = compile_hir_cached(
87 device,
88 aot,
89 &key,
90 g.hir,
91 &crate::compile_util::flux2_compile_profile(),
92 )?;
93 for (name, data) in &g.params {
94 compiled.set_param(name, data);
95 }
96 Ok((compiled, g.params))
97}
98
99pub(crate) struct TextEncoderHirBuilder<'a> {
100 hir: &'a mut HirModule,
101 params: &'a mut Flux2GraphParams,
102 cfg: &'a Qwen3Config,
103 weights: &'a Flux2TextEncoderWeights,
104 batch: usize,
105 seq: usize,
106 f: DType,
107 eps: f32,
108}
109
110impl<'a> TextEncoderHirBuilder<'a> {
111 pub(crate) fn from_emit_parts(
112 hir: &'a mut HirModule,
113 params: &'a mut Flux2GraphParams,
114 cfg: &'a Qwen3Config,
115 weights: &'a Flux2TextEncoderWeights,
116 batch: usize,
117 seq: usize,
118 ) -> Self {
119 Self {
120 hir,
121 params,
122 cfg,
123 weights,
124 batch,
125 seq,
126 f: DType::F32,
127 eps: cfg.rms_norm_eps as f32,
128 }
129 }
130
131 pub(crate) fn emit_embed(&mut self, ids: HirNodeId) -> Result<HirNodeId> {
132 let h = self.cfg.hidden_size;
133 let (embed_data, vocab, _) = &self.weights.embed_tokens;
134 let embed = self.register_param(
135 "embed_tokens.weight",
136 embed_data.clone(),
137 Shape::new(&[*vocab, h], self.f),
138 );
139 Ok(self
140 .hir
141 .mir(Op::Gather { axis: 0 }, vec![embed, ids], self.bsh()))
142 }
143
144 pub(crate) fn emit_joint_output(
145 &mut self,
146 checkpoints: &[HirNodeId],
147 hidden_state_layers: &[usize],
148 joint_dim: usize,
149 ) -> Result<HirNodeId> {
150 let h = self.cfg.hidden_size;
151 let mut out_pieces: Vec<HirNodeId> = Vec::with_capacity(hidden_state_layers.len());
152 for (i, &layer_idx) in hidden_state_layers.iter().enumerate() {
153 ensure!(
154 layer_idx < checkpoints.len(),
155 "hidden_state_layers[{i}]={layer_idx} out of range (len={})",
156 checkpoints.len()
157 );
158 out_pieces.push(checkpoints[layer_idx]);
159 }
160 let rows = self.batch * self.seq;
161 let mut flat_parts: Vec<HirNodeId> = Vec::with_capacity(out_pieces.len());
162 for p in &out_pieces {
163 flat_parts.push(self.reshape(*p, vec![rows as i64, h as i64]));
164 }
165 let flat = if flat_parts.len() == 1 {
166 flat_parts[0]
167 } else {
168 self.concat(flat_parts, 1, Shape::new(&[rows, joint_dim], self.f))
169 };
170 Ok(self.reshape(
171 flat,
172 vec![self.batch as i64, self.seq as i64, joint_dim as i64],
173 ))
174 }
175
176 fn bsh(&self) -> Shape {
177 Shape::new(&[self.batch, self.seq, self.cfg.hidden_size], self.f)
178 }
179
180 fn bsh_heads(&self, heads: usize) -> Shape {
181 Shape::new(&[self.batch, self.seq, heads * self.cfg.head_dim], self.f)
182 }
183
184 fn register_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
185 let id = self.hir.param(name, shape);
186 self.params.insert(name.to_string(), data);
187 id
188 }
189
190 fn linear(
191 &mut self,
192 x: HirNodeId,
193 lw: &LinearWeights,
194 name: &str,
195 out_shape: Shape,
196 ) -> Result<HirNodeId> {
197 let w = self.register_param(
198 &format!("{name}.weight"),
199 lw.w_t.clone(),
200 Shape::new(&[lw.in_dim, lw.out_dim], self.f),
201 );
202 let bias = if lw.bias.iter().all(|&v| v == 0.0) {
203 None
204 } else {
205 let b = self.register_param(
206 &format!("{name}.bias"),
207 lw.bias.clone(),
208 Shape::new(&[lw.out_dim], self.f),
209 );
210 Some(b)
211 };
212 Ok(self.hir.linear(x, w, bias, None, out_shape))
213 }
214
215 fn rms_norm(
216 &mut self,
217 x: HirNodeId,
218 gamma: &RmsNormWeight,
219 name: &str,
220 shape: Shape,
221 ) -> HirNodeId {
222 let g = self.register_param(
223 &format!("{name}.weight"),
224 gamma.scale.clone(),
225 Shape::new(&[gamma.scale.len()], self.f),
226 );
227 let beta = self.register_param(
228 &format!("{name}.beta"),
229 vec![0.0f32; gamma.scale.len()],
230 Shape::new(&[gamma.scale.len()], self.f),
231 );
232 self.hir.mir(
233 Op::RmsNorm {
234 axis: -1,
235 eps: self.eps,
236 },
237 vec![x, g, beta],
238 shape,
239 )
240 }
241
242 fn per_head_rms(
243 &mut self,
244 x: HirNodeId,
245 gamma: &RmsNormWeight,
246 name: &str,
247 heads: usize,
248 ) -> HirNodeId {
249 let hd = self.cfg.head_dim;
250 let flat = self.reshape(x, vec![(self.batch * self.seq * heads) as i64, hd as i64]);
251 let n = self.rms_norm(
252 flat,
253 gamma,
254 name,
255 Shape::new(&[self.batch * self.seq * heads, hd], self.f),
256 );
257 self.reshape(
258 n,
259 vec![self.batch as i64, self.seq as i64, (heads * hd) as i64],
260 )
261 }
262
263 pub(crate) fn layer_forward(
264 &mut self,
265 layer: &Flux2TextEncoderLayerWeights,
266 li: usize,
267 x: HirNodeId,
268 cos: HirNodeId,
269 sin: HirNodeId,
270 ) -> Result<HirNodeId> {
271 let lp = format!("layers.{li}");
272 let shape = self.bsh();
273 let normed = self.rms_norm(
274 x,
275 &layer.input_layernorm,
276 &format!("{lp}.in_ln"),
277 shape.clone(),
278 );
279 let attn_out = self.attn_forward(&layer.attn, &format!("{lp}.attn"), normed, cos, sin)?;
280 let post_attn = self.add(x, attn_out, shape.clone());
281 let mlp_out = self.mlp_forward(
282 &layer.mlp,
283 &layer.post_attention_layernorm,
284 &format!("{lp}.mlp"),
285 post_attn,
286 )?;
287 Ok(self.add(post_attn, mlp_out, shape))
288 }
289
290 fn attn_forward(
291 &mut self,
292 attn: &Flux2TextEncoderAttnWeights,
293 tag: &str,
294 x: HirNodeId,
295 cos: HirNodeId,
296 sin: HirNodeId,
297 ) -> Result<HirNodeId> {
298 let nh = self.cfg.num_attention_heads;
299 let nkv = self.cfg.num_key_value_heads;
300 let hd = self.cfg.head_dim;
301 let group = nh / nkv;
302 let shape = self.bsh();
303
304 let q = self.linear(x, &attn.q, &format!("{tag}.q"), self.bsh_heads(nh))?;
305 let k = self.linear(x, &attn.k, &format!("{tag}.k"), self.bsh_heads(nkv))?;
306 let v = self.linear(x, &attn.v, &format!("{tag}.v"), self.bsh_heads(nkv))?;
307
308 let q = self.per_head_rms(q, &attn.q_norm, &format!("{tag}.nq"), nh);
309 let k = self.per_head_rms(k, &attn.k_norm, &format!("{tag}.nk"), nkv);
310
311 let qh = self.bsh_heads(nh);
312 let q = self.rope(q, cos, sin, qh.clone());
313 let k = self.rope(k, cos, sin, self.bsh_heads(nkv));
314 let k_rep = self.repeat_kv(k, nkv, hd, group);
315 let v_rep = self.repeat_kv(v, nkv, hd, group);
316
317 let attn_out =
318 self.hir
319 .attention(q, k_rep, v_rep, None, nh, hd, MaskKind::Causal, qh.clone());
320 self.linear(attn_out, &attn.o, &format!("{tag}.o"), shape)
321 }
322
323 fn mlp_forward(
324 &mut self,
325 mlp: &Flux2TextEncoderMlpWeights,
326 post_ln: &RmsNormWeight,
327 tag: &str,
328 x: HirNodeId,
329 ) -> Result<HirNodeId> {
330 let rows = self.batch * self.seq;
331 let h = self.cfg.hidden_size;
332 let ff = self.cfg.intermediate_size;
333 let flat = self.reshape(x, vec![rows as i64, h as i64]);
334 let flat = self.rms_norm(
335 flat,
336 post_ln,
337 &format!("{tag}.post_ln"),
338 Shape::new(&[rows, h], self.f),
339 );
340 let gate = self.linear(
341 flat,
342 &mlp.gate,
343 &format!("{tag}.gate"),
344 Shape::new(&[rows, ff], self.f),
345 )?;
346 let up = self.linear(
347 flat,
348 &mlp.up,
349 &format!("{tag}.up"),
350 Shape::new(&[rows, ff], self.f),
351 )?;
352 let gate3 = self.reshape(gate, vec![self.batch as i64, self.seq as i64, ff as i64]);
353 let up3 = self.reshape(up, vec![self.batch as i64, self.seq as i64, ff as i64]);
354 let silu = self.hir.mir(
355 Op::Activation(Activation::Silu),
356 vec![gate3],
357 Shape::new(&[self.batch, self.seq, ff], self.f),
358 );
359 let prod = self.mul(silu, up3, Shape::new(&[self.batch, self.seq, ff], self.f));
360 let prod_flat = self.reshape(prod, vec![rows as i64, ff as i64]);
361 self.linear(
362 prod_flat,
363 &mlp.down,
364 &format!("{tag}.down"),
365 Shape::new(&[rows, h], self.f),
366 )
367 .map(|o| self.reshape(o, vec![self.batch as i64, self.seq as i64, h as i64]))
368 }
369
370 fn mul(&mut self, a: HirNodeId, b: HirNodeId, shape: Shape) -> HirNodeId {
371 self.hir.mir(Op::Binary(BinaryOp::Mul), vec![a, b], shape)
372 }
373
374 fn repeat_kv(&mut self, x: HirNodeId, nkv: usize, hd: usize, group: usize) -> HirNodeId {
375 if group == 1 {
376 return x;
377 }
378 let last = 2;
379 let slice_shape = Shape::new(&[self.batch, self.seq, hd], self.f);
380 let out_shape = Shape::new(&[self.batch, self.seq, nkv * group * hd], self.f);
381 let mut pieces = Vec::with_capacity(nkv * group);
382 for h in 0..nkv {
383 let slice = self.narrow(x, last, h * hd, hd, slice_shape.clone());
384 for _ in 0..group {
385 pieces.push(slice);
386 }
387 }
388 self.concat(pieces, last, out_shape)
389 }
390
391 pub(crate) fn rope_tables(&mut self) -> Result<(HirNodeId, HirNodeId)> {
392 let dh = self.cfg.head_dim;
393 let half = dh / 2;
394 let max_pos = self.cfg.max_position_embeddings;
395 let mut cos_data = vec![0f32; max_pos * dh];
396 let mut sin_data = vec![0f32; max_pos * dh];
397 for pos in 0..max_pos {
398 for i in 0..half {
399 let freq = 1.0 / self.cfg.rope_theta.powf((2 * i) as f64 / dh as f64);
400 let angle = pos as f64 * freq;
401 let (s, c) = angle.sin_cos();
402 cos_data[pos * dh + 2 * i] = c as f32;
403 cos_data[pos * dh + 2 * i + 1] = c as f32;
404 sin_data[pos * dh + 2 * i] = s as f32;
405 sin_data[pos * dh + 2 * i + 1] = s as f32;
406 }
407 }
408 let cos = self.register_param("rope.cos", cos_data, Shape::new(&[max_pos, dh], self.f));
409 let sin = self.register_param("rope.sin", sin_data, Shape::new(&[max_pos, dh], self.f));
410 Ok((cos, sin))
411 }
412
413 fn rope(&mut self, x: HirNodeId, cos: HirNodeId, sin: HirNodeId, shape: Shape) -> HirNodeId {
414 self.hir.mir(
415 Op::Rope {
416 head_dim: self.cfg.head_dim,
417 n_rot: self.cfg.head_dim,
418 },
419 vec![x, cos, sin],
420 shape,
421 )
422 }
423
424 fn reshape(&mut self, x: HirNodeId, new_shape: Vec<i64>) -> HirNodeId {
425 let shape = Shape::new(
426 &new_shape.iter().map(|&d| d as usize).collect::<Vec<_>>(),
427 self.f,
428 );
429 self.hir.mir(Op::Reshape { new_shape }, vec![x], shape)
430 }
431
432 fn narrow(
433 &mut self,
434 x: HirNodeId,
435 axis: usize,
436 start: usize,
437 len: usize,
438 shape: Shape,
439 ) -> HirNodeId {
440 self.hir
441 .mir(Op::Narrow { axis, start, len }, vec![x], shape)
442 }
443
444 fn concat(&mut self, inputs: Vec<HirNodeId>, axis: usize, shape: Shape) -> HirNodeId {
445 self.hir.mir(Op::Concat { axis }, inputs, shape)
446 }
447
448 fn add(&mut self, a: HirNodeId, b: HirNodeId, shape: Shape) -> HirNodeId {
449 self.hir.mir(Op::Binary(BinaryOp::Add), vec![a, b], shape)
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::text_encoder::{
        TINY_TEXT_ENCODER_LAYERS, encode_prompt_embeds, synthetic_text_encoder_weights,
        tiny_text_encoder_config,
    };
    use rlx_runtime::Device;

    /// Largest element-wise absolute difference between `got` and `want` (over their zip).
    fn max_abs_diff(got: &[f32], want: &[f32]) -> f32 {
        got.iter()
            .zip(want)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    }

    /// Token ids `0..seq`, as `u32` for the native path and as `f32` for the graph input.
    fn ramp_ids(seq: usize) -> (Vec<u32>, Vec<f32>) {
        let ids: Vec<u32> = (0..seq as u32).collect();
        let ids_f32 = ids.iter().map(|&id| id as f32).collect();
        (ids, ids_f32)
    }

    #[test]
    fn text_encoder_hir_lowers() {
        let cfg = tiny_text_encoder_config();
        let weights = synthetic_text_encoder_weights(&cfg);
        let graph =
            build_flux2_text_encoder_hir(&cfg, &weights, 1, 4, TINY_TEXT_ENCODER_LAYERS).unwrap();
        graph.hir.lower_to_mir().expect("lower");
    }

    #[test]
    fn compiled_single_layer_hidden_matches_native() {
        let cfg = tiny_text_encoder_config();
        let weights = synthetic_text_encoder_weights(&cfg);
        let layers = [1usize];
        let (batch, seq) = (1usize, 4usize);
        let (ids, ids_f32) = ramp_ids(seq);

        let native = encode_prompt_embeds(&weights, &cfg, &ids, batch, seq, &layers).unwrap();
        let (mut compiled, _) =
            compile_flux2_text_encoder_hir(&cfg, &weights, batch, seq, &layers, Device::Cpu, None)
                .unwrap();
        let out = compiled.run(&[("input_ids", ids_f32.as_slice())]).remove(0);

        assert_eq!(out.len(), native.prompt_embeds.len());
        let max = max_abs_diff(&out, &native.prompt_embeds);
        assert!(max < 2e-2, "single layer max_abs_diff={max}");
    }

    #[test]
    fn compiled_text_encoder_matches_native() {
        let cfg = tiny_text_encoder_config();
        let weights = synthetic_text_encoder_weights(&cfg);
        let (batch, seq) = (1usize, 4usize);
        let (ids, ids_f32) = ramp_ids(seq);
        let layers = TINY_TEXT_ENCODER_LAYERS;

        let native = encode_prompt_embeds(&weights, &cfg, &ids, batch, seq, layers).unwrap();
        let (mut compiled, _) =
            compile_flux2_text_encoder_hir(&cfg, &weights, batch, seq, layers, Device::Cpu, None)
                .unwrap();
        let out = compiled.run(&[("input_ids", ids_f32.as_slice())]).remove(0);

        assert_eq!(out.len(), native.prompt_embeds.len());
        let max = max_abs_diff(&out, &native.prompt_embeds);
        assert!(max < 2e-2, "HIR vs native max_abs_diff={max}");
    }
}