Skip to main content

rlx_flux2/text_encoder/
prompt.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! FLUX.2 text position ids and prompt encoding helpers.
17
18use super::forward::{Flux2PromptOutput, encode_prompt_embeds};
19use super::tokenizer::encode_prompt_padded;
20use super::weights::Flux2TextEncoderWeights;
21use anyhow::Result;
22use rlx_qwen3::Qwen3Config;
23use std::path::Path;
24
/// Hidden-state layer indices selected by default for the FLUX.2 Klein text
/// encoder.
///
/// The selection `[9, 18, 27]` is the same one mflux uses.
pub const DEFAULT_TEXT_ENCODER_LAYERS: &[usize] = &[9, 18, 27];
27
/// Hidden-state layer indices to use with the tiny test configuration.
///
/// That configuration has only 2 hidden layers, so the selection is `[1, 2]`.
pub const TINY_TEXT_ENCODER_LAYERS: &[usize] = &[1, 2];
30
/// Build FLUX.2-style text position ids.
///
/// The result is the row-major flattening of a `[batch, seq, 4]` tensor, so it
/// holds `batch * seq * 4` values. For every token the first three components
/// are `0.0` and the fourth is the token's index within its sequence
/// (`0..seq`), repeated identically for each batch entry.
///
/// Returns an empty vector when `batch` or `seq` is zero.
pub fn prepare_text_ids(batch: usize, seq: usize) -> Vec<f32> {
    let mut ids = vec![0.0f32; batch * seq * 4];
    // Each 4-wide chunk is one token; chunk number `slot` equals `b * seq + t`,
    // so the position inside the sequence is `slot % seq`. When `seq == 0` the
    // buffer is empty and the modulo is never evaluated.
    for (slot, quad) in ids.chunks_exact_mut(4).enumerate() {
        quad[3] = (slot % seq) as f32;
    }
    ids
}
42
43/// `Qwen3Config` sized for [`super::weights::synthetic_text_encoder_weights`] tests.
44pub fn tiny_text_encoder_config() -> Qwen3Config {
45    Qwen3Config {
46        vocab_size: 32,
47        hidden_size: 8,
48        intermediate_size: 32,
49        num_hidden_layers: 2,
50        num_attention_heads: 2,
51        num_key_value_heads: 1,
52        head_dim: 4,
53        max_position_embeddings: 128,
54        rms_norm_eps: 1e-6,
55        rope_theta: 1_000_000.0,
56        hidden_act: "silu".into(),
57        tie_word_embeddings: true,
58        attention_bias: true,
59        qk_norm: true,
60        sliding_window: None,
61        max_window_layers: usize::MAX,
62        use_sliding_window: false,
63        num_experts: 0,
64        num_experts_used: 0,
65        expert_ffn_size: 0,
66        shared_expert_ffn_size: 0,
67        expert_weights_scale: 1.0,
68    }
69}
70
71/// End-to-end: tokenize (optional) + text encoder → embeddings + text ids.
72pub fn encode_flux2_prompt(
73    te_weights: &Flux2TextEncoderWeights,
74    te_cfg: &Qwen3Config,
75    input_ids: &[u32],
76    batch: usize,
77    seq: usize,
78    hidden_state_layers: &[usize],
79) -> Result<(Flux2PromptOutput, Vec<f32>)> {
80    let out = encode_prompt_embeds(
81        te_weights,
82        te_cfg,
83        input_ids,
84        batch,
85        seq,
86        hidden_state_layers,
87    )?;
88    let txt_ids = prepare_text_ids(batch, seq);
89    Ok((out, txt_ids))
90}
91
92/// Resolve `text_encoder/` next to a transformer weights file or model root.
93pub fn resolve_text_encoder_dir(model_path: &Path) -> Option<std::path::PathBuf> {
94    crate::paths::find_component_dir(model_path, "text_encoder")
95}
96
97/// Load tokenizer + encode with right padding to `seq_len`.
98pub fn tokenize_flux2_prompt(
99    tokenizer_path: &Path,
100    prompt: &str,
101    seq_len: usize,
102) -> Result<Vec<u32>> {
103    encode_prompt_padded(tokenizer_path, prompt, seq_len)
104}