1use std::sync::{Arc, Mutex};
19
20use anyhow::{Result, ensure};
21use rlx_flow::{BuiltModel, MapWeights, ModelFlow, plugin_named};
22use rlx_ir::{DType, HirNodeId, Shape};
23
24use super::hir_builder::TextEncoderHirBuilder;
25use super::weights::Flux2TextEncoderWeights;
26use rlx_qwen3::Qwen3Config;
27
28const ROPE_COS: &str = "flux2_te.rope_cos";
29const ROPE_SIN: &str = "flux2_te.rope_sin";
30
31#[derive(Clone)]
33pub struct Flux2TextEncoderFlow<'a> {
34 cfg: &'a Qwen3Config,
35 weights: &'a Flux2TextEncoderWeights,
36 batch: usize,
37 seq: usize,
38 hidden_state_layers: Vec<usize>,
39}
40
41impl<'a> Flux2TextEncoderFlow<'a> {
42 pub fn new(
43 cfg: &'a Qwen3Config,
44 weights: &'a Flux2TextEncoderWeights,
45 batch: usize,
46 seq: usize,
47 hidden_state_layers: &[usize],
48 ) -> Self {
49 Self {
50 cfg,
51 weights,
52 batch,
53 seq,
54 hidden_state_layers: hidden_state_layers.to_vec(),
55 }
56 }
57
58 pub fn build(self) -> Result<Flux2TextEncoderBuilt> {
59 build_flux2_text_encoder_built(
60 self.cfg,
61 self.weights,
62 self.batch,
63 self.seq,
64 &self.hidden_state_layers,
65 )
66 }
67}
68
69pub struct Flux2TextEncoderBuilt {
70 pub model: BuiltModel,
71 pub joint_dim: usize,
72}
73
74pub fn build_flux2_text_encoder_built(
75 cfg: &Qwen3Config,
76 weights: &Flux2TextEncoderWeights,
77 batch: usize,
78 seq: usize,
79 hidden_state_layers: &[usize],
80) -> Result<Flux2TextEncoderBuilt> {
81 ensure!(
82 cfg.num_attention_heads
83 .is_multiple_of(cfg.num_key_value_heads),
84 "num_attention_heads must divide num_key_value_heads"
85 );
86 let joint_dim = cfg.hidden_size * hidden_state_layers.len();
87 let h = cfg.hidden_size;
88 let f = DType::F32;
89 let hidden_shape = Shape::new(&[batch, seq, h], f);
90 let out_shape = Shape::new(&[batch, seq, joint_dim], f);
91
92 let cfg = cfg.clone();
93 let weights = weights.clone();
94 let hidden_state_layers = hidden_state_layers.to_vec();
95 let checkpoints: Arc<Mutex<Vec<HirNodeId>>> = Arc::new(Mutex::new(Vec::new()));
96 let embed_hidden_shape = hidden_shape.clone();
97 let layer_hidden_shape = hidden_shape.clone();
98
99 let mut flow = ModelFlow::new("flux2_text_encoder")
100 .input("input_ids", Shape::new(&[batch, seq], f))
101 .plugin_named("flux2_te.embed", {
102 let cfg = cfg.clone();
103 let weights = weights.clone();
104 let checkpoints = checkpoints.clone();
105 move |emit, _| {
106 let ids = emit.flow_input("input_ids")?.hir_id();
107 let (hir, params) = emit.hir_and_params();
108 let mut b =
109 TextEncoderHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, seq);
110 let hidden = b.emit_embed(ids)?;
111 checkpoints.lock().unwrap().push(hidden);
112 Ok(Some(emit.wrap(hidden, embed_hidden_shape.clone())))
113 }
114 })
115 .plugin_named("flux2_te.rope", {
116 let cfg = cfg.clone();
117 let weights = weights.clone();
118 move |emit, primary| {
119 let (hir, params) = emit.hir_and_params();
120 let mut b =
121 TextEncoderHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, seq);
122 let (cos, sin) = b.rope_tables()?;
123 emit.set_named(ROPE_COS, cos);
124 emit.set_named(ROPE_SIN, sin);
125 Ok(primary)
126 }
127 });
128
129 for (li, layer) in weights.layers.iter().enumerate() {
130 let layer = layer.clone();
131 let cfg = cfg.clone();
132 let weights = weights.clone();
133 let checkpoints = checkpoints.clone();
134 let layer_shape = layer_hidden_shape.clone();
135 flow = flow.raw_stage(plugin_named(
136 format!("flux2_te.layer{li}"),
137 move |emit, input| {
138 let hidden =
139 input.ok_or_else(|| anyhow::anyhow!("text encoder layer requires hidden"))?;
140 let cos = emit.named(ROPE_COS)?;
141 let sin = emit.named(ROPE_SIN)?;
142 let (hir, params) = emit.hir_and_params();
143 let mut b =
144 TextEncoderHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, seq);
145 let out = b.layer_forward(&layer, li, hidden.hir_id(), cos, sin)?;
146 checkpoints.lock().unwrap().push(out);
147 Ok(Some(emit.wrap(out, layer_shape.clone())))
148 },
149 ));
150 }
151
152 let built = flow
153 .plugin_named("flux2_te.joint", {
154 let cfg = cfg.clone();
155 let weights = weights.clone();
156 let checkpoints = checkpoints.clone();
157 let hidden_state_layers = hidden_state_layers.clone();
158 move |emit, primary| {
159 let hidden = primary
160 .ok_or_else(|| anyhow::anyhow!("joint output requires hidden"))?
161 .hir_id();
162 let ckpts = checkpoints.lock().unwrap().clone();
163 let (hir, params) = emit.hir_and_params();
164 let mut b =
165 TextEncoderHirBuilder::from_emit_parts(hir, params, &cfg, &weights, batch, seq);
166 let out = b.emit_joint_output(&ckpts, &hidden_state_layers, joint_dim)?;
167 let _ = hidden;
168 Ok(Some(emit.wrap(out, out_shape.clone())))
169 }
170 })
171 .output("prompt_embeds")
172 .build(&mut MapWeights::default())?;
173
174 Ok(Flux2TextEncoderBuilt {
175 model: built,
176 joint_dim,
177 })
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::text_encoder::{
        TINY_TEXT_ENCODER_LAYERS, build_flux2_text_encoder_hir, synthetic_text_encoder_weights,
        tiny_text_encoder_config,
    };

    /// The staged flow and the direct HIR builder must emit the same number of nodes
    /// for the tiny synthetic configuration.
    #[test]
    fn text_encoder_flow_matches_hir_node_count() {
        let config = tiny_text_encoder_config();
        let weights = synthetic_text_encoder_weights(&config);
        let (batch, seq) = (1, 4);
        let selected = TINY_TEXT_ENCODER_LAYERS;

        // Reference graph from the direct builder.
        let expected = build_flux2_text_encoder_hir(&config, &weights, batch, seq, selected)
            .unwrap()
            .hir;

        // Same model assembled through the staged flow.
        let actual = Flux2TextEncoderFlow::new(&config, &weights, batch, seq, selected)
            .build()
            .unwrap()
            .model
            .into_hir()
            .unwrap();

        let flow_nodes = actual.len();
        let builder_nodes = expected.len();
        assert_eq!(
            flow_nodes, builder_nodes,
            "text encoder flow should match hir_builder node count (flow={}, builder={})",
            flow_nodes, builder_nodes
        );
    }
}