rlx_flux2/text_encoder/
prompt.rs1use super::forward::{Flux2PromptOutput, encode_prompt_embeds};
19use super::tokenizer::encode_prompt_padded;
20use super::weights::Flux2TextEncoderWeights;
21use anyhow::Result;
22use rlx_qwen3::Qwen3Config;
23use std::path::Path;
24
/// Default hidden-state layer indices sampled from the text encoder.
///
/// NOTE(review): presumably passed as `hidden_state_layers` to `encode_flux2_prompt`
/// for full-size checkpoints; confirm the indexing convention (e.g. whether index 0
/// is the embedding output) in `encode_prompt_embeds`.
pub const DEFAULT_TEXT_ENCODER_LAYERS: &[usize] = &[9, 18, 27];
/// Hidden-state layer indices meant to pair with the two-layer `tiny_text_encoder_config`.
///
/// NOTE(review): index `2` only fits a 2-layer model if index 0 denotes the embedding
/// output rather than the first decoder layer — confirm against `encode_prompt_embeds`.
pub const TINY_TEXT_ENCODER_LAYERS: &[usize] = &[1, 2];
/// Builds per-token text id coordinates for `batch` sequences of `seq` tokens each.
///
/// The result is a flat, row-major `[batch, seq, 4]` buffer holding `batch * seq * 4`
/// values. Every token gets four coordinates: the first three are `0.0` and the last is
/// the token's index within its own sequence, restarting at `0` for each batch element.
///
/// NOTE(review): presumably consumed as the text-side position ids by the Flux2
/// transformer (e.g. for rotary embeddings) — confirm at the call site.
pub fn prepare_text_ids(batch: usize, seq: usize) -> Vec<f32> {
    let mut coords = vec![0.0f32; batch * seq * 4];
    // Visit one 4-wide slot per token. The flat token index modulo `seq` is the
    // position inside the current sequence. When `seq == 0` the buffer is empty,
    // so the loop body (and the `%`) never executes.
    for (token, slot) in coords.chunks_exact_mut(4).enumerate() {
        slot[3] = (token % seq) as f32;
    }
    coords
}
43pub fn tiny_text_encoder_config() -> Qwen3Config {
45 Qwen3Config {
46 vocab_size: 32,
47 hidden_size: 8,
48 intermediate_size: 32,
49 num_hidden_layers: 2,
50 num_attention_heads: 2,
51 num_key_value_heads: 1,
52 head_dim: 4,
53 max_position_embeddings: 128,
54 rms_norm_eps: 1e-6,
55 rope_theta: 1_000_000.0,
56 hidden_act: "silu".into(),
57 tie_word_embeddings: true,
58 attention_bias: true,
59 qk_norm: true,
60 sliding_window: None,
61 max_window_layers: usize::MAX,
62 use_sliding_window: false,
63 num_experts: 0,
64 num_experts_used: 0,
65 expert_ffn_size: 0,
66 shared_expert_ffn_size: 0,
67 expert_weights_scale: 1.0,
68 }
69}
70
71pub fn encode_flux2_prompt(
73 te_weights: &Flux2TextEncoderWeights,
74 te_cfg: &Qwen3Config,
75 input_ids: &[u32],
76 batch: usize,
77 seq: usize,
78 hidden_state_layers: &[usize],
79) -> Result<(Flux2PromptOutput, Vec<f32>)> {
80 let out = encode_prompt_embeds(
81 te_weights,
82 te_cfg,
83 input_ids,
84 batch,
85 seq,
86 hidden_state_layers,
87 )?;
88 let txt_ids = prepare_text_ids(batch, seq);
89 Ok((out, txt_ids))
90}
91
92pub fn resolve_text_encoder_dir(model_path: &Path) -> Option<std::path::PathBuf> {
94 crate::paths::find_component_dir(model_path, "text_encoder")
95}
96
97pub fn tokenize_flux2_prompt(
99 tokenizer_path: &Path,
100 prompt: &str,
101 seq_len: usize,
102) -> Result<Vec<u32>> {
103 encode_prompt_padded(tokenizer_path, prompt, seq_len)
104}