use std::collections::HashMap;
use std::path::Path;
use mlx_rs::module::{Module, Param};
use mlx_rs::nn;
use mlx_rs::ops;
use mlx_rs::ops::indexing::{take_axis, IndexOp};
use mlx_rs::Array;
use tracing::info;
use super::mlx::{
build_qembedding, build_qlinear, load_all_tensors, QEmbedding, QLinear, QuantConfig,
};
use crate::tasks::generate_image::{GenerateImageRequest, GenerateImageResult};
use crate::InferenceError;
/// Hyper-parameters for the FLUX image-generation pipeline: transformer
/// trunk, CLIP and T5 text encoders, patch/latent geometry, and optional
/// weight quantization.
#[derive(Debug, Clone)]
pub struct FluxConfig {
    /// Transformer hidden width.
    pub hidden_dim: usize,
    /// Per-head attention dimension.
    pub head_dim: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Transformer feed-forward inner width.
    pub mlp_dim: usize,
    /// Count of dual-stream (image + text) transformer blocks.
    pub num_double_blocks: usize,
    /// Count of single-stream transformer blocks.
    pub num_single_blocks: usize,
    /// CLIP encoder hidden width.
    pub clip_hidden: usize,
    /// CLIP encoder layer count.
    pub clip_layers: usize,
    /// CLIP vocabulary size.
    pub clip_vocab: usize,
    /// CLIP maximum sequence length (size of the position table).
    pub clip_max_seq: usize,
    /// T5 encoder hidden width.
    pub t5_hidden: usize,
    /// T5 encoder layer count.
    pub t5_layers: usize,
    /// T5 vocabulary size.
    pub t5_vocab: usize,
    /// T5 feed-forward inner width.
    pub t5_mlp_dim: usize,
    /// Flattened per-patch dimension.
    pub patch_dim: usize,
    /// Channel count of the VAE latent space.
    pub vae_latent_channels: usize,
    /// Weight quantization settings; `None` loads dense weights.
    pub quant: Option<QuantConfig>,
}
impl Default for FluxConfig {
    /// Default configuration: FLUX-scale dimensions with 4-bit,
    /// group-size-64 quantization enabled.
    fn default() -> Self {
        Self {
            // Transformer trunk.
            hidden_dim: 3072,
            head_dim: 128,
            num_heads: 24,
            mlp_dim: 12288,
            num_double_blocks: 8,
            num_single_blocks: 38,
            // CLIP text encoder.
            clip_hidden: 768,
            clip_layers: 12,
            clip_vocab: 49408,
            clip_max_seq: 77,
            // T5 text encoder.
            t5_hidden: 4096,
            t5_layers: 24,
            t5_vocab: 32128,
            t5_mlp_dim: 10240,
            // Patch / latent geometry.
            patch_dim: 64,
            vae_latent_channels: 16,
            // Quantized weights by default.
            quant: Some(QuantConfig {
                group_size: 64,
                bits: 4,
            }),
        }
    }
}
/// Fetch a tensor by exact key, cloning it out of the map.
///
/// Returns `InferenceError::InferenceFailed` naming the missing key.
fn get_tensor(tensors: &HashMap<String, Array>, key: &str) -> Result<Array, InferenceError> {
    match tensors.get(key) {
        Some(arr) => Ok(arr.clone()),
        None => Err(InferenceError::InferenceFailed(format!(
            "missing tensor: {key}"
        ))),
    }
}
/// Build an unquantized `nn::Linear` from `{prefix}.weight` / `{prefix}.bias`.
/// The weight is mandatory; the bias is optional.
fn build_dense_linear(
    tensors: &HashMap<String, Array>,
    prefix: &str,
) -> Result<nn::Linear, InferenceError> {
    let weight = Param::new(get_tensor(tensors, &format!("{prefix}.weight"))?);
    let bias = Param::new(tensors.get(&format!("{prefix}.bias")).cloned());
    Ok(nn::Linear { weight, bias })
}
/// Build a [`LayerNorm`] from `{prefix}.weight` / `{prefix}.bias`.
/// The weight is mandatory; the bias is optional.
fn build_layer_norm(
    tensors: &HashMap<String, Array>,
    prefix: &str,
    eps: f32,
) -> Result<LayerNorm, InferenceError> {
    Ok(LayerNorm {
        weight: get_tensor(tensors, &format!("{prefix}.weight"))?,
        bias: tensors.get(&format!("{prefix}.bias")).cloned(),
        eps,
    })
}
/// Build a [`GroupNorm`] from `{prefix}.weight` / `{prefix}.bias`.
/// The weight is mandatory; the bias is optional.
fn build_group_norm(
    tensors: &HashMap<String, Array>,
    prefix: &str,
    num_groups: usize,
    eps: f32,
) -> Result<GroupNorm, InferenceError> {
    Ok(GroupNorm {
        weight: get_tensor(tensors, &format!("{prefix}.weight"))?,
        bias: tensors.get(&format!("{prefix}.bias")).cloned(),
        num_groups,
        eps,
    })
}
/// Dump a stage tensor only the first time `name` is seen in this process.
/// Subsequent calls with the same name are no-ops.
fn dump_flux_stage_first_call(name: &str, t: &Array) {
    use std::sync::Mutex;
    // Process-wide set of stage names that were already dumped.
    static SEEN: Mutex<Option<std::collections::HashSet<String>>> = Mutex::new(None);
    let first_time = {
        let mut guard = SEEN.lock().unwrap();
        guard
            .get_or_insert_with(std::collections::HashSet::new)
            .insert(name.to_string())
    }; // lock released here, before any file I/O
    if first_time {
        dump_flux_stage(name, t);
    }
}
/// Best-effort debug dump of a tensor to `$CAR_DUMP_FLUX_STAGE/{name}.bin`
/// (little-endian f32 payload) plus `{name}.meta` (the shape, Debug-printed).
///
/// Disabled (a no-op) unless the CAR_DUMP_FLUX_STAGE env var is set; the
/// directory is resolved once per process. All I/O errors are ignored — this
/// is diagnostics, never part of the inference path.
fn dump_flux_stage(name: &str, t: &Array) {
    use std::sync::OnceLock;
    static DIR: OnceLock<Option<String>> = OnceLock::new();
    let dir = DIR.get_or_init(|| std::env::var("CAR_DUMP_FLUX_STAGE").ok());
    let Some(dir) = dir else {
        return;
    };
    let _ = std::fs::create_dir_all(dir);
    let Ok(t_f32) = t.as_dtype(mlx_rs::Dtype::Float32) else {
        return;
    };
    // Materialize the lazy graph so as_slice sees real data.
    let _ = mlx_rs::transforms::eval([&t_f32]);
    let shape = t_f32.shape().to_vec();
    let data: &[f32] = t_f32.as_slice();
    // Preallocate 4 bytes per f32: flat_map's empty size hint would otherwise
    // force repeated reallocation while collecting the byte buffer.
    let mut bytes: Vec<u8> = Vec::with_capacity(data.len() * 4);
    for v in data {
        bytes.extend_from_slice(&v.to_le_bytes());
    }
    let _ = std::fs::write(format!("{dir}/{name}.bin"), &bytes);
    let _ = std::fs::write(format!("{dir}/{name}.meta"), format!("{shape:?}\n"));
}
/// Rotary dims contributed by each position axis. Axis 0 is constant for all
/// tokens; axes 1/2 carry the patch row/column (see `flux_rope_build`).
/// 16 + 56 + 56 = 128, matching the default `head_dim`.
const FLUX_ROPE_AXES_DIM: [usize; 3] = [16, 56, 56];
/// Base frequency for the rotary angle schedule.
const FLUX_ROPE_THETA: f32 = 10000.0;
/// Precomputed rotary cos/sin tables, shaped with a broadcastable heads axis
/// (see `flux_rope_build`) for use against (batch, heads, seq, dim) tensors.
#[derive(Clone)]
struct FluxRope {
    cos: Array,
    sin: Array,
}
/// Build the cos/sin rotary tables for one position axis.
///
/// Each axis contributes `axis_dim / 2` frequencies:
/// `omega_k = theta^(-2k/axis_dim)`, and `angles[p][k] = positions[p] * omega_k`.
/// Returns `(cos, sin)` shaped (1, len(positions), axis_dim/2).
fn flux_rope_axis(
    positions: &[f32],
    axis_dim: usize,
    theta: f32,
) -> Result<(Array, Array), mlx_rs::error::Exception> {
    let half = axis_dim / 2;
    let omega: Vec<f32> = (0..half)
        .map(|k| 1.0 / theta.powf(2.0 * k as f32 / axis_dim as f32))
        .collect();
    // Row-major (position, frequency) outer product.
    let angles: Vec<f32> = positions
        .iter()
        .flat_map(|&p| omega.iter().map(move |&om| p * om))
        .collect();
    let n = positions.len();
    let angles = Array::from_slice(&angles, &[1, n as i32, half as i32]);
    Ok((ops::cos(&angles)?, ops::sin(&angles)?))
}
/// Precompute the joint text+image rotary tables for one resolution.
///
/// Token layout along the sequence axis is `[text tokens | image patches]`.
/// Axis 0 positions are zero for every token; axes 1/2 carry the patch
/// row/column for image tokens and zero for text tokens. The per-axis tables
/// are concatenated along the frequency dim, then given a broadcastable heads
/// axis: (1, seq, freq) -> (1, 1, seq, freq).
fn flux_rope_build(
    text_seq_len: usize,
    h_patches: i32,
    w_patches: i32,
) -> Result<FluxRope, mlx_rs::error::Exception> {
    let img_seq_len = (h_patches * w_patches) as usize;
    let total_seq = text_seq_len + img_seq_len;
    // Axis 0 stays all-zero, so it needs no `mut` (fixes an unused_mut lint).
    let pos_axis0 = vec![0.0f32; total_seq];
    let mut pos_axis1 = vec![0.0f32; total_seq];
    let mut pos_axis2 = vec![0.0f32; total_seq];
    for hi in 0..h_patches {
        for wi in 0..w_patches {
            let token = text_seq_len + (hi * w_patches + wi) as usize;
            pos_axis1[token] = hi as f32;
            pos_axis2[token] = wi as f32;
        }
    }
    let (cos0, sin0) = flux_rope_axis(&pos_axis0, FLUX_ROPE_AXES_DIM[0], FLUX_ROPE_THETA)?;
    let (cos1, sin1) = flux_rope_axis(&pos_axis1, FLUX_ROPE_AXES_DIM[1], FLUX_ROPE_THETA)?;
    let (cos2, sin2) = flux_rope_axis(&pos_axis2, FLUX_ROPE_AXES_DIM[2], FLUX_ROPE_THETA)?;
    let cos = ops::concatenate_axis(&[&cos0, &cos1, &cos2], -1)?;
    let sin = ops::concatenate_axis(&[&sin0, &sin1, &sin2], -1)?;
    Ok(FluxRope {
        cos: ops::expand_dims(&cos, 1)?,
        sin: ops::expand_dims(&sin, 1)?,
    })
}
/// Apply rotary position embedding to `x` using precomputed cos/sin tables.
///
/// `x` is interpreted as (..., pairs*2): consecutive value pairs (x0, x1)
/// are rotated by the per-position angle. All math runs in Float32.
///
/// NOTE(review): the result is returned in Float32 even when `x` came in at
/// lower precision — the original dtype is inspected but never restored.
/// Downstream ops appear to tolerate this; confirm it is intentional.
fn flux_apply_rope(x: &Array, rope: &FluxRope) -> Result<Array, mlx_rs::error::Exception> {
    let x_dtype = x.dtype();
    let x_f32 = if x_dtype == mlx_rs::Dtype::Float32 {
        x.clone()
    } else {
        x.as_dtype(mlx_rs::Dtype::Float32)?
    };
    let shape = x_f32.shape();
    let last = shape[shape.len() - 1];
    let half = last / 2;
    // Reshape (..., d) -> (..., d/2, 2) so each rotary pair sits on the last axis.
    let mut paired_shape: Vec<i32> = shape.to_vec();
    let ln = paired_shape.len();
    paired_shape[ln - 1] = half;
    paired_shape.push(2);
    let x_pairs = ops::reshape(&x_f32, &paired_shape)?;
    let x0 = x_pairs.index((.., .., .., .., 0..1));
    let x1 = x_pairs.index((.., .., .., .., 1..2));
    // Drop the trailing singleton axis left by the pair slicing.
    let squeeze_last = |t: Array| -> Result<Array, mlx_rs::error::Exception> {
        let s = t.shape();
        let ln = s.len();
        let new_shape: Vec<i32> = s[..ln - 1].to_vec();
        ops::reshape(&t, &new_shape)
    };
    let x0 = squeeze_last(x0)?;
    let x1 = squeeze_last(x1)?;
    let cos = &rope.cos;
    let sin = &rope.sin;
    // Standard 2-D rotation: (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos).
    let out0 = ops::subtract(&ops::multiply(&x0, cos)?, &ops::multiply(&x1, sin)?)?;
    let out1 = ops::add(&ops::multiply(&x0, sin)?, &ops::multiply(&x1, cos)?)?;
    // Re-interleave the rotated halves and restore the original shape.
    let out0_u = ops::expand_dims(&out0, -1)?;
    let out1_u = ops::expand_dims(&out1, -1)?;
    let paired = ops::concatenate_axis(&[&out0_u, &out1_u], -1)?;
    // (Removed a dead `let _ = x_dtype;` — the dtype is already read above.)
    ops::reshape(&paired, shape)
}
/// LayerNorm without learned scale/shift: zero mean, unit variance over the
/// last axis. Statistics run in Float32; the result is cast back to the
/// input dtype.
fn flux_layer_norm_parameterless(x: &Array, eps: f32) -> Result<Array, mlx_rs::error::Exception> {
    let x_dtype = x.dtype();
    let needs_cast = x_dtype != mlx_rs::Dtype::Float32;
    let x_f32 = if needs_cast {
        x.as_dtype(mlx_rs::Dtype::Float32)?
    } else {
        x.clone()
    };
    let mean = x_f32.mean_axes(&[-1], true)?;
    let centered = ops::subtract(&x_f32, &mean)?;
    let var = ops::multiply(&centered, &centered)?.mean_axes(&[-1], true)?;
    let inv = ops::rsqrt(&ops::add(&var, &Array::from_f32(eps))?)?;
    let normed = ops::multiply(&centered, &inv)?;
    if needs_cast {
        normed.as_dtype(x_dtype)
    } else {
        Ok(normed)
    }
}
/// Affine layer normalization over the last axis.
struct LayerNorm {
    // Learned per-channel scale.
    weight: Array,
    // Optional learned per-channel shift.
    bias: Option<Array>,
    // Numerical-stability epsilon added to the variance.
    eps: f32,
}
impl LayerNorm {
    /// Normalize over the last axis, then apply scale (and shift, if any).
    /// Statistics run in Float32; the normalized value is cast back to the
    /// input dtype before the affine transform.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let x_dtype = x.dtype();
        let needs_cast = x_dtype != mlx_rs::Dtype::Float32;
        let x_f32 = if needs_cast {
            x.as_dtype(mlx_rs::Dtype::Float32)?
        } else {
            x.clone()
        };
        let mean = x_f32.mean_axes(&[-1], true)?;
        let centered = ops::subtract(&x_f32, &mean)?;
        let var = centered.multiply(&centered)?.mean_axes(&[-1], true)?;
        let inv_std = ops::rsqrt(&ops::add(&var, &Array::from_f32(self.eps))?)?;
        let mut normed = ops::multiply(&centered, &inv_std)?;
        if needs_cast {
            normed = normed.as_dtype(x_dtype)?;
        }
        let scaled = ops::multiply(&normed, &self.weight)?;
        match &self.bias {
            Some(bias) => ops::add(&scaled, bias),
            None => Ok(scaled),
        }
    }
}
/// Group normalization over the trailing channel axis.
struct GroupNorm {
    // Learned per-channel scale.
    weight: Array,
    // Optional learned per-channel shift.
    bias: Option<Array>,
    // Channel groups; must divide the channel count evenly.
    num_groups: usize,
    // Numerical-stability epsilon added to the variance.
    eps: f32,
}
impl GroupNorm {
    /// Normalize per channel-group, then apply scale (and shift, if any).
    ///
    /// NOTE(review): mean/variance are reduced only over the channels-per-group
    /// slice of the reshaped last axis — leading (spatial) axes are not pooled
    /// into the statistics. Confirm this matches the reference GroupNorm for
    /// the VAE weights being loaded.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let shape = x.shape();
        let ndim = shape.len();
        let channels = shape[ndim - 1] as usize;
        // Split channels: (..., C) -> (..., groups, C/groups).
        let mut grouped_shape: Vec<i32> = shape[..ndim - 1].to_vec();
        grouped_shape.push(self.num_groups as i32);
        grouped_shape.push((channels / self.num_groups) as i32);
        let x_grouped = ops::reshape(x, &grouped_shape)?;
        let mean = x_grouped.mean_axes(&[-1], true)?;
        let centered = ops::subtract(&x_grouped, &mean)?;
        let var = centered.multiply(&centered)?.mean_axes(&[-1], true)?;
        let inv_std = ops::rsqrt(&ops::add(&var, &Array::from_f32(self.eps))?)?;
        let normed = ops::multiply(&centered, &inv_std)?;
        let restored = ops::reshape(&normed, shape)?;
        let scaled = ops::multiply(&restored, &self.weight)?;
        match &self.bias {
            Some(bias) => ops::add(&scaled, bias),
            None => Ok(scaled),
        }
    }
}
/// RMS normalization with a learned scale, applied over the last axis
/// (used on per-head Q/K tensors in joint attention).
struct RmsNormPerHead {
    // Learned scale over the last axis.
    weight: Array,
}
impl RmsNormPerHead {
    /// RMS-normalize over the last axis (fixed eps 1e-6) and scale.
    /// Statistics run in Float32; the result is cast back to the input dtype
    /// before the learned scale is applied.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let x_dtype = x.dtype();
        let needs_cast = x_dtype != mlx_rs::Dtype::Float32;
        let x_f32 = if needs_cast {
            x.as_dtype(mlx_rs::Dtype::Float32)?
        } else {
            x.clone()
        };
        let mean_sq = ops::multiply(&x_f32, &x_f32)?.mean_axes(&[-1], true)?;
        let inv_rms = ops::rsqrt(&ops::add(&mean_sq, &Array::from_f32(1e-6))?)?;
        let normed_f32 = ops::multiply(&x_f32, &inv_rms)?;
        let normed = if needs_cast {
            normed_f32.as_dtype(x_dtype)?
        } else {
            normed_f32
        };
        ops::multiply(&normed, &self.weight)
    }
}
/// Multi-head self-attention for the CLIP text encoder (quantized projections).
struct ClipAttention {
    q_proj: QLinear,
    k_proj: QLinear,
    v_proj: QLinear,
    out_proj: QLinear,
    // num_heads * head_dim equals the projection width.
    num_heads: usize,
    head_dim: usize,
}
impl ClipAttention {
    /// Multi-head self-attention over `x` (batch, seq, hidden).
    ///
    /// `mask`, when given, is added to the raw scores (expected broadcastable
    /// to (batch, heads, seq, seq) — here it is the causal mask). Scores and
    /// softmax run in Float32; the context is cast back to the projection
    /// dtype before the output projection.
    fn forward(
        &mut self,
        x: &Array,
        mask: Option<&Array>,
    ) -> Result<Array, mlx_rs::error::Exception> {
        let shape = x.shape();
        let (batch, seq_len, _hidden) = (shape[0] as usize, shape[1] as usize, shape[2] as usize);
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;
        // First-call debug dumps (no-ops unless CAR_DUMP_FLUX_STAGE is set).
        dump_flux_stage_first_call("clip_l0_q", &q);
        dump_flux_stage_first_call("clip_l0_k", &k);
        dump_flux_stage_first_call("clip_l0_v", &v);
        // (batch, seq, hidden) -> (batch, heads, seq, head_dim).
        let reshape_head = |t: Array| -> Result<Array, mlx_rs::error::Exception> {
            let r = ops::reshape(
                &t,
                &[
                    batch as i32,
                    seq_len as i32,
                    self.num_heads as i32,
                    self.head_dim as i32,
                ],
            )?;
            ops::transpose_axes(&r, &[0, 2, 1, 3])
        };
        let q = reshape_head(q)?;
        let k = reshape_head(k)?;
        let v = reshape_head(v)?;
        let out_dtype = q.dtype();
        let to_f32 = |t: Array| -> Result<Array, mlx_rs::error::Exception> {
            if t.dtype() == mlx_rs::Dtype::Float32 {
                Ok(t)
            } else {
                t.as_dtype(mlx_rs::Dtype::Float32)
            }
        };
        let q32 = to_f32(q)?;
        let k32 = to_f32(k)?;
        let v32 = to_f32(v)?;
        // Scaled dot-product: (q @ k^T) / sqrt(head_dim).
        let scale = Array::from_f32(1.0 / (self.head_dim as f32).sqrt());
        let scores = ops::multiply(
            &ops::matmul(&q32, &ops::transpose_axes(&k32, &[0, 1, 3, 2])?)?,
            &scale,
        )?;
        let scores = if let Some(m) = mask {
            let m32 = if m.dtype() == mlx_rs::Dtype::Float32 {
                m.clone()
            } else {
                m.as_dtype(mlx_rs::Dtype::Float32)?
            };
            ops::add(&scores, &m32)?
        } else {
            scores
        };
        let attn = ops::softmax_axis(&scores, -1, None)?;
        let out_f32 = ops::matmul(&attn, &v32)?;
        let out = if out_dtype == mlx_rs::Dtype::Float32 {
            out_f32
        } else {
            out_f32.as_dtype(out_dtype)?
        };
        // (batch, heads, seq, head_dim) -> (batch, seq, hidden).
        let out = ops::transpose_axes(&out, &[0, 2, 1, 3])?;
        let out = ops::reshape(
            &out,
            &[
                batch as i32,
                seq_len as i32,
                (self.num_heads * self.head_dim) as i32,
            ],
        )?;
        dump_flux_stage_first_call("clip_l0_attn_pre_out", &out);
        self.out_proj.forward(&out)
    }
}
/// Two-layer CLIP MLP with a quick-GELU activation between fc1 and fc2.
struct ClipMlp {
    fc1: QLinear,
    fc2: QLinear,
}
impl ClipMlp {
    /// fc2(quick_gelu(fc1(x))), where quick_gelu(h) = h * sigmoid(1.702 * h).
    fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let h = self.fc1.forward(x)?;
        // quick-GELU gate.
        let gate = ops::sigmoid(&ops::multiply(&h, &Array::from_f32(1.702))?)?;
        let activated = ops::multiply(&h, &gate)?;
        self.fc2.forward(&activated)
    }
}
/// One pre-norm CLIP transformer layer: LN -> attention -> residual,
/// then LN -> MLP -> residual.
struct ClipEncoderLayer {
    self_attn: ClipAttention,
    layer_norm1: LayerNorm,
    mlp: ClipMlp,
    layer_norm2: LayerNorm,
}
impl ClipEncoderLayer {
    /// Pre-norm residual layer: x + attn(LN1(x)), then that + mlp(LN2(·)).
    fn forward(
        &mut self,
        x: &Array,
        mask: Option<&Array>,
    ) -> Result<Array, mlx_rs::error::Exception> {
        let n1 = self.layer_norm1.forward(x)?;
        let attn_out = self.self_attn.forward(&n1, mask)?;
        let x_post_attn = ops::add(x, &attn_out)?;
        let n2 = self.layer_norm2.forward(&x_post_attn)?;
        let mlp_out = self.mlp.forward(&n2)?;
        let out = ops::add(&x_post_attn, &mlp_out)?;
        // First-call debug dumps (no-ops unless CAR_DUMP_FLUX_STAGE is set).
        dump_flux_stage_first_call("clip_l0_norm1", &n1);
        dump_flux_stage_first_call("clip_l0_attn", &attn_out);
        dump_flux_stage_first_call("clip_l0_post_attn", &x_post_attn);
        dump_flux_stage_first_call("clip_l0_norm2", &n2);
        dump_flux_stage_first_call("clip_l0_mlp", &mlp_out);
        Ok(out)
    }
}
/// Assemble a HuggingFace `tokenizer.json` for the CLIP BPE tokenizer from a
/// separate `vocab.json` and `merges.txt`, writing the result to `out_path`.
///
/// Errors are returned as human-readable strings naming the failing path.
fn build_clip_tokenizer_json(
    vocab_path: &std::path::Path,
    merges_path: &std::path::Path,
    out_path: &std::path::Path,
) -> Result<(), String> {
    let vocab_raw = std::fs::read_to_string(vocab_path)
        .map_err(|e| format!("read {}: {e}", vocab_path.display()))?;
    let vocab: serde_json::Value = serde_json::from_str(&vocab_raw)
        .map_err(|e| format!("parse {}: {e}", vocab_path.display()))?;
    let merges_raw = std::fs::read_to_string(merges_path)
        .map_err(|e| format!("read {}: {e}", merges_path.display()))?;
    // Each merges.txt line is "left right"; blank lines and '#' headers are
    // skipped, and malformed single-token lines are silently dropped.
    let merges: Vec<[String; 2]> = merges_raw
        .lines()
        .filter(|l| !l.is_empty() && !l.starts_with('#'))
        .filter_map(|l| {
            let mut parts = l.splitn(2, ' ');
            match (parts.next(), parts.next()) {
                (Some(a), Some(b)) => Some([a.to_string(), b.to_string()]),
                _ => None,
            }
        })
        .collect();
    // The fixed scaffolding below (normalizer, pre-tokenizer regex, BOS/EOS
    // ids 49406/49407, "</w>" end-of-word suffix) mirrors the CLIP tokenizer
    // layout; only `vocab` and `merges` vary per checkpoint.
    let tokenizer_json = serde_json::json!({
        "version": "1.0",
        "truncation": null,
        "padding": null,
        "added_tokens": [
            {
                "id": 49406,
                "content": "<|startoftext|>",
                "single_word": false,
                "lstrip": false,
                "rstrip": false,
                "normalized": true,
                "special": true,
            },
            {
                "id": 49407,
                "content": "<|endoftext|>",
                "single_word": false,
                "lstrip": false,
                "rstrip": false,
                "normalized": false,
                "special": true,
            },
        ],
        "normalizer": {
            "type": "Sequence",
            "normalizers": [
                { "type": "NFC" },
                { "type": "Replace", "pattern": { "Regex": "\\s+" }, "content": " " },
                { "type": "Lowercase" },
            ],
        },
        "pre_tokenizer": {
            "type": "Sequence",
            "pretokenizers": [
                {
                    "type": "Split",
                    "pattern": {
                        "Regex": "<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+"
                    },
                    "behavior": "Removed",
                    "invert": true,
                },
                {
                    "type": "ByteLevel",
                    "add_prefix_space": false,
                    "trim_offsets": true,
                    "use_regex": true,
                },
            ],
        },
        "post_processor": {
            "type": "RobertaProcessing",
            "sep": ["<|endoftext|>", 49407],
            "cls": ["<|startoftext|>", 49406],
            "trim_offsets": false,
            "add_prefix_space": false,
        },
        "decoder": {
            "type": "ByteLevel",
            "add_prefix_space": true,
            "trim_offsets": true,
            "use_regex": true,
        },
        "model": {
            "type": "BPE",
            "dropout": null,
            "unk_token": "<|endoftext|>",
            "continuing_subword_prefix": "",
            "end_of_word_suffix": "</w>",
            "fuse_unk": false,
            "byte_fallback": false,
            "ignore_merges": false,
            "vocab": vocab,
            "merges": merges,
        },
    });
    let pretty = serde_json::to_string_pretty(&tokenizer_json)
        .map_err(|e| format!("serialize tokenizer.json: {e}"))?;
    std::fs::write(out_path, pretty).map_err(|e| format!("write {}: {e}", out_path.display()))?;
    Ok(())
}
/// CLIP text encoder: token + position embeddings, a stack of causal
/// transformer layers, and a final layer norm.
struct ClipTextEncoder {
    token_embedding: QEmbedding,
    position_embedding: QEmbedding,
    layers: Vec<ClipEncoderLayer>,
    final_layer_norm: LayerNorm,
    // Size of the position-embedding table; position ids are clamped to it.
    max_seq_len: usize,
}
impl ClipTextEncoder {
    /// Build the encoder from tensors under
    /// `text_encoders.clip.transformer.text_model`.
    fn load(tensors: &HashMap<String, Array>, config: &FluxConfig) -> Result<Self, InferenceError> {
        let quant = config.quant.as_ref();
        let pfx = "text_encoders.clip.transformer";
        let token_embedding = build_qembedding(
            tensors,
            &format!("{pfx}.text_model.embeddings.token_embedding"),
            quant,
        )?;
        let position_embedding = build_qembedding(
            tensors,
            &format!("{pfx}.text_model.embeddings.position_embedding"),
            quant,
        )?;
        // Fixed 64-wide heads; head count follows the hidden size.
        let clip_heads = config.clip_hidden / 64;
        let clip_head_dim = 64;
        let mut layers = Vec::with_capacity(config.clip_layers);
        for i in 0..config.clip_layers {
            let lpfx = format!("{pfx}.text_model.encoder.layers.{i}");
            let layer = ClipEncoderLayer {
                self_attn: ClipAttention {
                    q_proj: build_qlinear(tensors, &format!("{lpfx}.self_attn.q_proj"), quant)?,
                    k_proj: build_qlinear(tensors, &format!("{lpfx}.self_attn.k_proj"), quant)?,
                    v_proj: build_qlinear(tensors, &format!("{lpfx}.self_attn.v_proj"), quant)?,
                    out_proj: build_qlinear(tensors, &format!("{lpfx}.self_attn.out_proj"), quant)?,
                    num_heads: clip_heads,
                    head_dim: clip_head_dim,
                },
                layer_norm1: build_layer_norm(tensors, &format!("{lpfx}.layer_norm1"), 1e-5)?,
                mlp: ClipMlp {
                    fc1: build_qlinear(tensors, &format!("{lpfx}.mlp.fc1"), quant)?,
                    fc2: build_qlinear(tensors, &format!("{lpfx}.mlp.fc2"), quant)?,
                },
                layer_norm2: build_layer_norm(tensors, &format!("{lpfx}.layer_norm2"), 1e-5)?,
            };
            layers.push(layer);
        }
        let final_layer_norm =
            build_layer_norm(tensors, &format!("{pfx}.text_model.final_layer_norm"), 1e-5)?;
        Ok(Self {
            token_embedding,
            position_embedding,
            layers,
            final_layer_norm,
            max_seq_len: config.clip_max_seq,
        })
    }
    /// Encode `token_ids` (batch, seq) through embeddings and causal layers.
    ///
    /// NOTE(review): position ids are clamped to `max_seq_len` but the token
    /// embedding is computed over the full input — a sequence longer than the
    /// position table would make the add below shape-mismatch. Confirm callers
    /// pre-truncate to `clip_max_seq`.
    fn forward(&mut self, token_ids: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let seq_len = token_ids.shape()[1] as usize;
        let clamped_len = seq_len.min(self.max_seq_len);
        let tok_emb = self.token_embedding.forward(token_ids)?;
        let pos_ids = Array::from_slice(
            &(0..clamped_len as i32).collect::<Vec<_>>(),
            &[1, clamped_len as i32],
        );
        let pos_emb = self.position_embedding.forward(&pos_ids)?;
        let mut h = ops::add(&tok_emb, &pos_emb)?;
        dump_flux_stage("clip_embed", &h);
        let mask = Self::causal_mask(clamped_len)?;
        for (i, layer) in self.layers.iter_mut().enumerate() {
            h = layer.forward(&h, Some(&mask))?;
            // Dump only a few representative layers to keep output small.
            if i < 3 || i == 11 {
                dump_flux_stage(&format!("clip_layer{i:02}"), &h);
            }
        }
        let out = self.final_layer_norm.forward(&h)?;
        dump_flux_stage("clip_final_norm", &out);
        Ok(out)
    }
    /// Additive causal mask: 0 on/below the diagonal, -inf above, shaped
    /// (1, 1, seq, seq) to broadcast over (batch, heads).
    fn causal_mask(seq_len: usize) -> Result<Array, mlx_rs::error::Exception> {
        let mut data = vec![0.0f32; seq_len * seq_len];
        for i in 0..seq_len {
            for j in 0..seq_len {
                if j > i {
                    data[i * seq_len + j] = f32::NEG_INFINITY;
                }
            }
        }
        let mask = Array::from_slice(&data, &[seq_len as i32, seq_len as i32]);
        ops::reshape(&mask, &[1, 1, seq_len as i32, seq_len as i32])
    }
}
/// Bucketed relative positions for the T5 attention bias, as a flat
/// (seq_len * seq_len,) i32 index array.
///
/// Follows the T5 bucketing scheme: the upper half of the buckets is reserved
/// for positive (future) offsets, offsets below `max_exact` get their own
/// bucket, and larger distances are log-spaced up to MAX_DISTANCE.
///
/// Results are cached per thread and intentionally leaked (`Box::leak`) so a
/// `&'static Array` can be handed out; memory is bounded by the number of
/// distinct sequence lengths seen on each thread.
fn t5_relative_position_bucket(seq_len: usize) -> &'static Array {
    use std::cell::RefCell;
    use std::collections::HashMap;
    thread_local! {
        static CACHE: RefCell<HashMap<usize, &'static Array>> =
            RefCell::new(HashMap::new());
    }
    CACHE.with(|cell| {
        if let Some(existing) = cell.borrow().get(&seq_len) {
            return *existing;
        }
        const NUM_BUCKETS: i32 = 32;
        const MAX_DISTANCE: i32 = 128;
        let half_buckets = NUM_BUCKETS / 2;
        let max_exact = half_buckets / 2;
        let log_denom = ((MAX_DISTANCE as f32) / (max_exact as f32)).ln();
        let mut data = vec![0i32; seq_len * seq_len];
        for i in 0..seq_len {
            for j in 0..seq_len {
                let rel_pos = (j as i32) - (i as i32);
                // Positive offsets land in the upper half of the bucket range.
                let mut bucket = if rel_pos > 0 { half_buckets } else { 0 };
                let abs_pos = rel_pos.unsigned_abs() as i32;
                if abs_pos < max_exact {
                    bucket += abs_pos;
                } else {
                    // Logarithmic bucketing beyond max_exact, capped at the
                    // last bucket of the half-range.
                    let log_ratio = ((abs_pos as f32) / (max_exact as f32)).ln() / log_denom;
                    let b = max_exact + (log_ratio * (half_buckets - max_exact) as f32) as i32;
                    bucket += b.min(half_buckets - 1);
                }
                data[i * seq_len + j] = bucket;
            }
        }
        let arr = Array::from_slice(&data, &[seq_len as i32 * seq_len as i32]);
        let leaked: &'static Array = Box::leak(Box::new(arr));
        cell.borrow_mut().insert(seq_len, leaked);
        leaked
    })
}
/// T5 encoder self-attention with optional learned relative-position bias.
pub(crate) struct T5Attention {
    q_proj: QLinear,
    k_proj: QLinear,
    v_proj: QLinear,
    o_proj: QLinear,
    num_heads: usize,
    head_dim: usize,
    // Present only on layers whose checkpoint ships a bias table.
    relative_attention_bias: Option<QEmbedding>,
}
impl T5Attention {
    /// Multi-head self-attention over `x` (batch, seq, hidden).
    ///
    /// Note: no 1/sqrt(head_dim) scaling is applied to the scores — this
    /// matches T5-style attention, which omits the explicit scale.
    ///
    /// The layer's own relative-position bias, if present, takes precedence;
    /// otherwise a caller-supplied `_position_bias` is added; with neither,
    /// the raw scores stand.
    fn forward(
        &mut self,
        x: &Array,
        _position_bias: Option<&Array>,
    ) -> Result<Array, mlx_rs::error::Exception> {
        let shape = x.shape();
        let (batch, seq_len, _) = (shape[0] as usize, shape[1] as usize, shape[2] as usize);
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;
        // (batch, seq, hidden) -> (batch, heads, seq, head_dim).
        let reshape_head = |t: Array| -> Result<Array, mlx_rs::error::Exception> {
            let r = ops::reshape(
                &t,
                &[
                    batch as i32,
                    seq_len as i32,
                    self.num_heads as i32,
                    self.head_dim as i32,
                ],
            )?;
            ops::transpose_axes(&r, &[0, 2, 1, 3])
        };
        let q = reshape_head(q)?;
        let k = reshape_head(k)?;
        let v = reshape_head(v)?;
        let scores = ops::matmul(&q, &ops::transpose_axes(&k, &[0, 1, 3, 2])?)?;
        let scores = if let Some(bias_emb) = self.relative_attention_bias.as_mut() {
            // Bucket lookup: (seq*seq,) -> (seq, seq, heads), then heads-first
            // and a leading batch axis for broadcasting.
            let bucket_indices = t5_relative_position_bucket(seq_len);
            let bias_values = bias_emb.forward(bucket_indices)?;
            let bias_2d = ops::reshape(
                &bias_values,
                &[seq_len as i32, seq_len as i32, self.num_heads as i32],
            )?;
            let bias_perm = ops::transpose_axes(&bias_2d, &[2, 0, 1])?;
            let bias_4d = ops::reshape(
                &bias_perm,
                &[1, self.num_heads as i32, seq_len as i32, seq_len as i32],
            )?;
            ops::add(&scores, &bias_4d)?
        } else if let Some(pos_bias) = _position_bias {
            ops::add(&scores, pos_bias)?
        } else {
            scores
        };
        let attn = ops::softmax_axis(&scores, -1, None)?;
        let out = ops::matmul(&attn, &v)?;
        // (batch, heads, seq, head_dim) -> (batch, seq, hidden).
        let out = ops::transpose_axes(&out, &[0, 2, 1, 3])?;
        let out = ops::reshape(
            &out,
            &[
                batch as i32,
                seq_len as i32,
                (self.num_heads * self.head_dim) as i32,
            ],
        )?;
        self.o_proj.forward(&out)
    }
}
/// T5-style RMS normalization (no mean subtraction) with a learned scale.
pub(crate) struct T5RmsNorm {
    // Learned scale over the last axis.
    weight: Array,
    // Numerical-stability epsilon added to the mean square.
    eps: f32,
}
impl T5RmsNorm {
    /// Scale `x` by 1/rms (computed in Float32 over the last axis), then by
    /// the learned weight.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let x_f32 = if x.dtype() == mlx_rs::Dtype::Float32 {
            x.clone()
        } else {
            x.as_dtype(mlx_rs::Dtype::Float32)?
        };
        let mean_sq = ops::multiply(&x_f32, &x_f32)?.mean_axes(&[-1], true)?;
        let inv_rms = ops::rsqrt(&ops::add(&mean_sq, &Array::from_f32(self.eps))?)?;
        // Note: the *original* x (not the f32 copy) is multiplied by the
        // f32 norm factor, preserving the original promotion behavior.
        let normed = ops::multiply(x, &inv_rms)?;
        ops::multiply(&normed, &self.weight)
    }
}
/// Construct a [`T5RmsNorm`] from `{prefix}.weight`; the weight is required.
pub(crate) fn build_t5_rms_norm(
    tensors: &HashMap<String, Array>,
    prefix: &str,
    eps: f32,
) -> Result<T5RmsNorm, InferenceError> {
    get_tensor(tensors, &format!("{prefix}.weight")).map(|weight| T5RmsNorm { weight, eps })
}
/// T5 gated feed-forward (GEGLU): `wo(gelu(wi_0(x)) * wi_1(x))`.
pub(crate) struct T5FeedForward {
    wi_0: QLinear,
    wi_1: QLinear,
    wo: QLinear,
}
impl T5FeedForward {
fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let profile = std::env::var("CAR_T5_FFN_PROFILE").is_ok();
let mut t = std::time::Instant::now();
let mark = |label: &str,
t: &mut std::time::Instant,
arr: &Array|
-> Result<(), mlx_rs::error::Exception> {
mlx_rs::transforms::eval([arr])?;
tracing::info!(
label,
elapsed_ms = t.elapsed().as_millis() as u64,
"t5 ffn sub"
);
*t = std::time::Instant::now();
Ok(())
};
let wi0 = self.wi_0.forward(x)?;
if profile {
mark("wi_0", &mut t, &wi0)?;
}
let gate = nn::gelu_approximate(&wi0)?;
if profile {
mark("gelu", &mut t, &gate)?;
}
let up = self.wi_1.forward(x)?;
if profile {
mark("wi_1", &mut t, &up)?;
}
let activated = ops::multiply(&gate, &up)?;
if profile {
mark("mult", &mut t, &activated)?;
}
let out = self.wo.forward(&activated)?;
if profile {
mark("wo", &mut t, &out)?;
}
Ok(out)
}
}
/// One pre-norm T5 encoder block: RMS-norm -> self-attention -> residual,
/// then RMS-norm -> gated FFN -> residual.
pub(crate) struct T5Block {
    self_attn: T5Attention,
    norm1: T5RmsNorm,
    ffn: T5FeedForward,
    norm2: T5RmsNorm,
}
impl T5Block {
fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let profile = std::env::var("CAR_T5_SUBPROFILE").is_ok();
let mut t = std::time::Instant::now();
let mark = |label: &str,
t: &mut std::time::Instant,
arr: &Array|
-> Result<(), mlx_rs::error::Exception> {
mlx_rs::transforms::eval([arr])?;
tracing::info!(label, elapsed_ms = t.elapsed().as_millis() as u64, "t5 sub");
*t = std::time::Instant::now();
Ok(())
};
let residual = x.clone();
let h = self.norm1.forward(x)?;
if profile {
mark("norm1", &mut t, &h)?;
}
let h = self.self_attn.forward(&h, None)?;
if profile {
mark("self_attn", &mut t, &h)?;
}
let x = ops::add(&residual, &h)?;
if profile {
mark("residual1", &mut t, &x)?;
}
let residual = x.clone();
let h = self.norm2.forward(&x)?;
if profile {
mark("norm2", &mut t, &h)?;
}
let h = self.ffn.forward(&h)?;
if profile {
mark("ffn", &mut t, &h)?;
}
let out = ops::add(&residual, &h)?;
if profile {
mark("residual2", &mut t, &out)?;
}
Ok(out)
}
}
/// T5 encoder: shared token embedding, a stack of blocks, and a final
/// RMS norm.
pub(crate) struct T5TextEncoder {
    shared_embedding: QEmbedding,
    blocks: Vec<T5Block>,
    final_norm: T5RmsNorm,
}
impl T5TextEncoder {
    /// Build the encoder from tensors under `text_encoders.t5.transformer`.
    ///
    /// Only blocks whose checkpoint ships a `relative_attention_bias` table
    /// get a learned bias; the rest run with raw scores.
    pub(crate) fn load(
        tensors: &HashMap<String, Array>,
        config: &FluxConfig,
    ) -> Result<Self, InferenceError> {
        let quant = config.quant.as_ref();
        let pfx = "text_encoders.t5.transformer";
        let shared_embedding = build_qembedding(tensors, &format!("{pfx}.shared"), quant)?;
        // Fixed 64-wide heads; head count follows the hidden size.
        let t5_head_dim: usize = 64;
        let t5_heads = config.t5_hidden / t5_head_dim;
        let mut blocks = Vec::with_capacity(config.t5_layers);
        for i in 0..config.t5_layers {
            let bpfx = format!("{pfx}.t5_blocks.{i}");
            let has_rel_bias =
                tensors.contains_key(&format!("{bpfx}.self_attn.relative_attention_bias.weight"));
            let rel_bias = if has_rel_bias {
                Some(build_qembedding(
                    tensors,
                    &format!("{bpfx}.self_attn.relative_attention_bias"),
                    quant,
                )?)
            } else {
                None
            };
            let block = T5Block {
                self_attn: T5Attention {
                    q_proj: build_qlinear(tensors, &format!("{bpfx}.self_attn.q"), quant)?,
                    k_proj: build_qlinear(tensors, &format!("{bpfx}.self_attn.k"), quant)?,
                    v_proj: build_qlinear(tensors, &format!("{bpfx}.self_attn.v"), quant)?,
                    o_proj: build_qlinear(tensors, &format!("{bpfx}.self_attn.o"), quant)?,
                    num_heads: t5_heads,
                    head_dim: t5_head_dim,
                    relative_attention_bias: rel_bias,
                },
                norm1: build_t5_rms_norm(tensors, &format!("{bpfx}.norm1"), 1e-6)?,
                ffn: T5FeedForward {
                    wi_0: build_qlinear(tensors, &format!("{bpfx}.ff.wi_0"), quant)?,
                    wi_1: build_qlinear(tensors, &format!("{bpfx}.ff.wi_1"), quant)?,
                    wo: build_qlinear(tensors, &format!("{bpfx}.ff.wo"), quant)?,
                },
                norm2: build_t5_rms_norm(tensors, &format!("{bpfx}.norm2"), 1e-6)?,
            };
            blocks.push(block);
        }
        let final_norm = build_t5_rms_norm(tensors, &format!("{pfx}.final_layer_norm"), 1e-6)?;
        Ok(Self {
            shared_embedding,
            blocks,
            final_norm,
        })
    }
    /// Encode token ids through the embedding, all blocks, and the final
    /// RMS norm. CAR_FLUX_PROFILE logs per-block timings; CAR_DUMP_FLUX_STAGE
    /// dumps a few representative activations.
    pub(crate) fn forward(&mut self, token_ids: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let profile = std::env::var("CAR_FLUX_PROFILE").is_ok();
        let mut h = self.shared_embedding.forward(token_ids)?;
        dump_flux_stage("t5_embed", &h);
        for (i, block) in self.blocks.iter_mut().enumerate() {
            let t0 = std::time::Instant::now();
            h = block.forward(&h)?;
            if profile {
                // Force evaluation so the timing covers the real work.
                mlx_rs::transforms::eval([&h])?;
                tracing::info!(
                    block = i,
                    elapsed_ms = t0.elapsed().as_millis() as u64,
                    "t5 block timing"
                );
            }
            if i < 4 || i == 11 || i == 23 {
                dump_flux_stage(&format!("t5_block{i:02}"), &h);
            }
        }
        let out = self.final_norm.forward(&h)?;
        dump_flux_stage("t5_final_norm", &out);
        Ok(out)
    }
}
/// Timestep embedding head: linear -> SiLU -> linear. Weights live at
/// `{prefix}.0` and `{prefix}.2` (nn.Sequential-style checkpoint layout).
struct TimestepEmbedder {
    linear1: QLinear,
    linear2: QLinear,
}
impl TimestepEmbedder {
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
    ) -> Result<Self, InferenceError> {
        let linear1 = build_qlinear(tensors, &format!("{prefix}.0"), quant)?;
        let linear2 = build_qlinear(tensors, &format!("{prefix}.2"), quant)?;
        Ok(Self { linear1, linear2 })
    }
    fn forward(&mut self, t: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let hidden = nn::silu(&self.linear1.forward(t)?)?;
        self.linear2.forward(&hidden)
    }
}
/// Pooled-text embedding head: linear -> SiLU -> linear. Same two-layer
/// layout as [`TimestepEmbedder`]; weights at `{prefix}.0` / `{prefix}.2`.
struct TextEmbedder {
    linear1: QLinear,
    linear2: QLinear,
}
impl TextEmbedder {
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
    ) -> Result<Self, InferenceError> {
        let linear1 = build_qlinear(tensors, &format!("{prefix}.0"), quant)?;
        let linear2 = build_qlinear(tensors, &format!("{prefix}.2"), quant)?;
        Ok(Self { linear1, linear2 })
    }
    fn forward(&mut self, t: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let hidden = nn::silu(&self.linear1.forward(t)?)?;
        self.linear2.forward(&hidden)
    }
}
/// Combined conditioning vector: sum of timestep, pooled-CLIP text, and
/// guidance embeddings.
struct TimeTextEmbed {
    timestep_embedder: TimestepEmbedder,
    text_embedder: TextEmbedder,
    guidance_embedder: TimestepEmbedder,
}
impl TimeTextEmbed {
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
    ) -> Result<Self, InferenceError> {
        let timestep_embedder =
            TimestepEmbedder::load(tensors, &format!("{prefix}.timestep_embedder"), quant)?;
        let text_embedder =
            TextEmbedder::load(tensors, &format!("{prefix}.text_embedder"), quant)?;
        // Guidance shares the TimestepEmbedder architecture.
        let guidance_embedder =
            TimestepEmbedder::load(tensors, &format!("{prefix}.guidance_embedder"), quant)?;
        Ok(Self {
            timestep_embedder,
            text_embedder,
            guidance_embedder,
        })
    }
    /// Embed each input and sum the three results.
    fn forward(
        &mut self,
        timestep: &Array,
        pooled_clip: &Array,
        guidance: &Array,
    ) -> Result<Array, mlx_rs::error::Exception> {
        let t_emb = self.timestep_embedder.forward(timestep)?;
        let txt_emb = self.text_embedder.forward(pooled_clip)?;
        let g_emb = self.guidance_embedder.forward(guidance)?;
        ops::add(&ops::add(&t_emb, &txt_emb)?, &g_emb)
    }
}
/// Adaptive-LayerNorm modulation: SiLU + linear over the conditioning vector,
/// with the output split evenly into `num_outputs` shift/scale/gate chunks.
struct AdaLNModulation {
    linear: QLinear,
    num_outputs: usize,
}
impl AdaLNModulation {
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
        num_outputs: usize,
    ) -> Result<Self, InferenceError> {
        let linear = build_qlinear(tensors, &format!("{prefix}.linear"), quant)?;
        Ok(Self { linear, num_outputs })
    }
    /// Project the conditioning vector and slice it into equal chunks along
    /// the last axis.
    fn forward(&mut self, conditioning: &Array) -> Result<Vec<Array>, mlx_rs::error::Exception> {
        let out = self.linear.forward(&nn::silu(conditioning)?)?;
        let total_dim = out.shape().last().copied().unwrap_or(0) as usize;
        let chunk_dim = (total_dim / self.num_outputs) as i32;
        let chunks: Vec<Array> = (0..self.num_outputs as i32)
            .map(|i| out.index((.., i * chunk_dim..(i + 1) * chunk_dim)))
            .collect();
        Ok(chunks)
    }
}
/// Joint image/text attention for FLUX double blocks: separate Q/K/V and
/// output projections for the image stream (`to_*`) and the context/text
/// stream (`add_*` / `to_add_out`), with per-head RMS norms on Q and K.
struct JointAttention {
    to_q: QLinear,
    to_k: QLinear,
    to_v: QLinear,
    to_out_0: QLinear,
    add_q_proj: QLinear,
    add_k_proj: QLinear,
    add_v_proj: QLinear,
    to_add_out: QLinear,
    norm_q: RmsNormPerHead,
    norm_k: RmsNormPerHead,
    norm_added_q: RmsNormPerHead,
    norm_added_k: RmsNormPerHead,
    num_heads: usize,
    head_dim: usize,
}
impl JointAttention {
    /// Build from tensors under `{prefix}.attn`; Q/K norms are dense weights,
    /// all projections honor the configured quantization.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        config: &FluxConfig,
    ) -> Result<Self, InferenceError> {
        let quant = config.quant.as_ref();
        let apfx = &format!("{prefix}.attn");
        Ok(Self {
            to_q: build_qlinear(tensors, &format!("{apfx}.to_q"), quant)?,
            to_k: build_qlinear(tensors, &format!("{apfx}.to_k"), quant)?,
            to_v: build_qlinear(tensors, &format!("{apfx}.to_v"), quant)?,
            to_out_0: build_qlinear(tensors, &format!("{apfx}.to_out.0"), quant)?,
            add_q_proj: build_qlinear(tensors, &format!("{apfx}.add_q_proj"), quant)?,
            add_k_proj: build_qlinear(tensors, &format!("{apfx}.add_k_proj"), quant)?,
            add_v_proj: build_qlinear(tensors, &format!("{apfx}.add_v_proj"), quant)?,
            to_add_out: build_qlinear(tensors, &format!("{apfx}.to_add_out"), quant)?,
            norm_q: RmsNormPerHead {
                weight: get_tensor(tensors, &format!("{apfx}.norm_q.weight"))?,
            },
            norm_k: RmsNormPerHead {
                weight: get_tensor(tensors, &format!("{apfx}.norm_k.weight"))?,
            },
            norm_added_q: RmsNormPerHead {
                weight: get_tensor(tensors, &format!("{apfx}.norm_added_q.weight"))?,
            },
            norm_added_k: RmsNormPerHead {
                weight: get_tensor(tensors, &format!("{apfx}.norm_added_k.weight"))?,
            },
            num_heads: config.num_heads,
            head_dim: config.head_dim,
        })
    }
    /// Joint attention over the concatenated [context | image] sequence.
    ///
    /// Both streams are projected and head-split separately, RMS-normalized,
    /// concatenated along the sequence axis (context first — matching the
    /// RoPE table layout from `flux_rope_build`), rotated, run through fused
    /// scaled-dot-product attention, and finally split back and projected per
    /// stream. Returns `(image_out, context_out)`.
    fn forward(
        &mut self,
        x: &Array,
        context: &Array,
        rope: &FluxRope,
    ) -> Result<(Array, Array), mlx_rs::error::Exception> {
        let x_shape = x.shape();
        let ctx_shape = context.shape();
        let batch = x_shape[0] as usize;
        let x_seq = x_shape[1] as usize;
        let ctx_seq = ctx_shape[1] as usize;
        let q = self.to_q.forward(x)?;
        let k = self.to_k.forward(x)?;
        let v = self.to_v.forward(x)?;
        // First-call debug dumps (no-ops unless CAR_DUMP_FLUX_STAGE is set).
        dump_flux_stage_first_call("block0_attn_img_q_raw", &q);
        dump_flux_stage_first_call("block0_attn_img_k_raw", &k);
        let cq = self.add_q_proj.forward(context)?;
        let ck = self.add_k_proj.forward(context)?;
        let cv = self.add_v_proj.forward(context)?;
        dump_flux_stage_first_call("block0_attn_ctx_q_raw", &cq);
        dump_flux_stage_first_call("block0_attn_ctx_k_raw", &ck);
        let nh = self.num_heads as i32;
        let hd = self.head_dim as i32;
        // (batch, seq, hidden) -> (batch, heads, seq, head_dim).
        let reshape_head = |t: Array, seq: usize| -> Result<Array, mlx_rs::error::Exception> {
            let r = ops::reshape(&t, &[batch as i32, seq as i32, nh, hd])?;
            ops::transpose_axes(&r, &[0, 2, 1, 3])
        };
        let q = reshape_head(q, x_seq)?;
        let k = reshape_head(k, x_seq)?;
        let v = reshape_head(v, x_seq)?;
        let cq = reshape_head(cq, ctx_seq)?;
        let ck = reshape_head(ck, ctx_seq)?;
        let cv = reshape_head(cv, ctx_seq)?;
        // Per-head RMS norm on Q/K only (V is left unnormalized).
        let q = self.norm_q.forward(&q)?;
        let k = self.norm_k.forward(&k)?;
        let cq = self.norm_added_q.forward(&cq)?;
        let ck = self.norm_added_k.forward(&ck)?;
        // Context tokens first, then image tokens, along the seq axis (2).
        let q_joint = ops::concatenate_axis(&[&cq, &q], 2)?;
        let k_joint = ops::concatenate_axis(&[&ck, &k], 2)?;
        let v_joint = ops::concatenate_axis(&[&cv, &v], 2)?;
        // Rotary embedding applies to Q and K over the joint sequence.
        let q_joint = flux_apply_rope(&q_joint, rope)?;
        let k_joint = flux_apply_rope(&k_joint, rope)?;
        let q = q_joint;
        let k_full = k_joint;
        let v_full = v_joint;
        let scale = 1.0_f32 / (self.head_dim as f32).sqrt();
        let out_joint =
            mlx_rs::fast::scaled_dot_product_attention(&q, &k_full, &v_full, scale, None)?;
        // Split back: first ctx_seq tokens are context, the rest are image.
        let out_ctx = out_joint.index((.., .., ..ctx_seq as i32, ..));
        let out_img = out_joint.index((.., .., ctx_seq as i32.., ..));
        let hidden = (self.num_heads * self.head_dim) as i32;
        // (batch, heads, seq, head_dim) -> (batch, seq, hidden), per stream.
        let out_img = ops::transpose_axes(&out_img, &[0, 2, 1, 3])?;
        let out_img = ops::reshape(&out_img, &[batch as i32, x_seq as i32, hidden])?;
        let out_ctx = ops::transpose_axes(&out_ctx, &[0, 2, 1, 3])?;
        let out_ctx = ops::reshape(&out_ctx, &[batch as i32, ctx_seq as i32, hidden])?;
        dump_flux_stage_first_call("block0_attn_ctx_pre_out", &out_ctx);
        dump_flux_stage_first_call("block0_attn_img_pre_out", &out_img);
        let img_out = self.to_out_0.forward(&out_img)?;
        let ctx_out = self.to_add_out.forward(&out_ctx)?;
        dump_flux_stage_first_call("block0_attn_ctx_post_out", &ctx_out);
        dump_flux_stage_first_call("block0_attn_img_post_out", &img_out);
        Ok((img_out, ctx_out))
    }
}
/// Two-layer feed-forward network used inside Flux transformer blocks:
/// `linear1` -> GELU (variant selected by `activation`) -> `linear2`.
struct FluxFfn {
linear1: QLinear,
linear2: QLinear,
activation: FluxGelu,
}
/// GELU flavour for [`FluxFfn`]: `Precise` maps to `nn::gelu`, `Approx` to
/// `nn::gelu_approximate` (tanh approximation).
#[derive(Clone, Copy)]
enum FluxGelu {
Precise,
Approx,
}
impl FluxFfn {
/// Builds the two quantized projections of the FFN from `{prefix}.linear1`
/// and `{prefix}.linear2`, remembering which GELU variant to apply.
fn load(
tensors: &HashMap<String, Array>,
prefix: &str,
quant: Option<&QuantConfig>,
activation: FluxGelu,
) -> Result<Self, InferenceError> {
let linear1 = build_qlinear(tensors, &format!("{prefix}.linear1"), quant)?;
let linear2 = build_qlinear(tensors, &format!("{prefix}.linear2"), quant)?;
Ok(Self {
linear1,
linear2,
activation,
})
}
/// Up-project, apply the configured GELU, down-project.
fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let up = self.linear1.forward(x)?;
let activated = if matches!(self.activation, FluxGelu::Precise) {
nn::gelu(&up)?
} else {
nn::gelu_approximate(&up)?
};
self.linear2.forward(&activated)
}
}
/// Flux "double" (MMDiT) block: joint image/text attention plus separate
/// FFNs and separate 6-parameter AdaLN modulations for the image and
/// context streams.
struct DoubleTransformerBlock {
attn: JointAttention,
ff: FluxFfn,
ff_context: FluxFfn,
norm1: AdaLNModulation,
norm1_context: AdaLNModulation,
}
impl DoubleTransformerBlock {
/// Loads joint attention, both FFNs (image FFN uses precise GELU, context
/// FFN uses the approximation) and the two 6-way AdaLN modulations.
fn load(
tensors: &HashMap<String, Array>,
prefix: &str,
config: &FluxConfig,
) -> Result<Self, InferenceError> {
let quant = config.quant.as_ref();
Ok(Self {
attn: JointAttention::load(tensors, prefix, config)?,
ff: FluxFfn::load(tensors, &format!("{prefix}.ff"), quant, FluxGelu::Precise)?,
ff_context: FluxFfn::load(
tensors,
&format!("{prefix}.ff_context"),
quant,
FluxGelu::Approx,
)?,
norm1: AdaLNModulation::load(tensors, &format!("{prefix}.norm1"), quant, 6)?,
norm1_context: AdaLNModulation::load(
tensors,
&format!("{prefix}.norm1_context"),
quant,
6,
)?,
})
}
/// One double-block step: AdaLN-modulated joint attention followed by an
/// AdaLN-modulated FFN, with gated residual adds on both the image (`x`)
/// and text (`context`) streams. Returns the updated `(x, context)`.
fn forward(
&mut self,
x: &Array,
context: &Array,
conditioning: &Array,
rope: &FluxRope,
) -> Result<(Array, Array), mlx_rs::error::Exception> {
// Six modulation tensors per stream: (shift, scale, gate) for the
// attention sub-layer, then (shift, scale, gate) for the FFN.
let mods_x = self.norm1.forward(conditioning)?;
let (shift1_x, scale1_x, gate1_x) = (&mods_x[0], &mods_x[1], &mods_x[2]);
let (shift2_x, scale2_x, gate2_x) = (&mods_x[3], &mods_x[4], &mods_x[5]);
let mods_ctx = self.norm1_context.forward(conditioning)?;
let (shift1_c, scale1_c, gate1_c) = (&mods_ctx[0], &mods_ctx[1], &mods_ctx[2]);
let (shift2_c, scale2_c, gate2_c) = (&mods_ctx[3], &mods_ctx[4], &mods_ctx[5]);
// One shared 1.0 constant; previously re-created at each of the four
// modulation sites.
let one = Array::from_f32(1.0);
let x_norm = flux_layer_norm_parameterless(x, 1e-6)?;
let ctx_norm = flux_layer_norm_parameterless(context, 1e-6)?;
// AdaLN: norm(x) * (1 + scale) + shift.
let x_mod = ops::add(
&ops::multiply(&x_norm, &ops::add(&one, scale1_x)?)?,
shift1_x,
)?;
let ctx_mod = ops::add(
&ops::multiply(&ctx_norm, &ops::add(&one, scale1_c)?)?,
shift1_c,
)?;
dump_flux_stage_first_call("block0_x_mod", &x_mod);
dump_flux_stage_first_call("block0_ctx_mod", &ctx_mod);
let (attn_x, attn_ctx) = self.attn.forward(&x_mod, &ctx_mod, rope)?;
// Gated residual connection around attention.
let x = ops::add(x, &ops::multiply(&attn_x, gate1_x)?)?;
let context = ops::add(context, &ops::multiply(&attn_ctx, gate1_c)?)?;
let x_norm2 = flux_layer_norm_parameterless(&x, 1e-6)?;
let ctx_norm2 = flux_layer_norm_parameterless(&context, 1e-6)?;
let x_ff_mod = ops::add(
&ops::multiply(&x_norm2, &ops::add(&one, scale2_x)?)?,
shift2_x,
)?;
let ctx_ff_mod = ops::add(
&ops::multiply(&ctx_norm2, &ops::add(&one, scale2_c)?)?,
shift2_c,
)?;
let x_ff = self.ff.forward(&x_ff_mod)?;
let ctx_ff = self.ff_context.forward(&ctx_ff_mod)?;
// Gated residual connection around the FFN.
let x = ops::add(&x, &ops::multiply(&x_ff, gate2_x)?)?;
let context = ops::add(&context, &ops::multiply(&ctx_ff, gate2_c)?)?;
Ok((x, context))
}
}
/// Flux "single" block: self-attention and an MLP run in parallel on the
/// same AdaLN-modulated input (3 modulation params), with their outputs
/// concatenated along the feature axis before `proj_out`.
struct SingleTransformerBlock {
attn_to_q: QLinear,
attn_to_k: QLinear,
attn_to_v: QLinear,
attn_norm_q: RmsNormPerHead,
attn_norm_k: RmsNormPerHead,
proj_mlp: QLinear,
proj_out: QLinear,
norm: AdaLNModulation,
num_heads: usize,
head_dim: usize,
}
impl SingleTransformerBlock {
/// Loads attention projections, QK per-head RMS norms, the parallel MLP
/// projections and the 3-way AdaLN modulation for one single block.
fn load(
tensors: &HashMap<String, Array>,
prefix: &str,
config: &FluxConfig,
) -> Result<Self, InferenceError> {
let quant = config.quant.as_ref();
Ok(Self {
attn_to_q: build_qlinear(tensors, &format!("{prefix}.attn.to_q"), quant)?,
attn_to_k: build_qlinear(tensors, &format!("{prefix}.attn.to_k"), quant)?,
attn_to_v: build_qlinear(tensors, &format!("{prefix}.attn.to_v"), quant)?,
attn_norm_q: RmsNormPerHead {
weight: get_tensor(tensors, &format!("{prefix}.attn.norm_q.weight"))?,
},
attn_norm_k: RmsNormPerHead {
weight: get_tensor(tensors, &format!("{prefix}.attn.norm_k.weight"))?,
},
proj_mlp: build_qlinear(tensors, &format!("{prefix}.proj_mlp"), quant)?,
proj_out: build_qlinear(tensors, &format!("{prefix}.proj_out"), quant)?,
norm: AdaLNModulation::load(tensors, &format!("{prefix}.norm"), quant, 3)?,
num_heads: config.num_heads,
head_dim: config.head_dim,
})
}
/// One single-block step over the joint [context; image] sequence:
/// AdaLN-modulate, run self-attention (per-head RMS norm + RoPE on q/k)
/// and a GELU MLP on the same modulated input, concatenate the two,
/// project, and add the gated residual.
fn forward(
&mut self,
x: &Array,
conditioning: &Array,
rope: &FluxRope,
) -> Result<Array, mlx_rs::error::Exception> {
// AdaLN modulation: (shift, scale, gate).
let mods = self.norm.forward(conditioning)?;
let (shift, scale, gate) = (&mods[0], &mods[1], &mods[2]);
let one = Array::from_f32(1.0);
let x_norm = flux_layer_norm_parameterless(x, 1e-6)?;
// norm(x) * (1 + scale) + shift.
let x_mod = ops::add(&ops::multiply(&x_norm, &ops::add(&one, scale)?)?, shift)?;
let x_shape = x_mod.shape();
let batch = x_shape[0] as usize;
let seq_len = x_shape[1] as usize;
let nh = self.num_heads as i32;
let hd = self.head_dim as i32;
let q = self.attn_to_q.forward(&x_mod)?;
let k = self.attn_to_k.forward(&x_mod)?;
let v = self.attn_to_v.forward(&x_mod)?;
// [B, S, H*D] -> [B, H, S, D].
let reshape_head = |t: Array| -> Result<Array, mlx_rs::error::Exception> {
let r = ops::reshape(&t, &[batch as i32, seq_len as i32, nh, hd])?;
ops::transpose_axes(&r, &[0, 2, 1, 3])
};
let q = reshape_head(q)?;
let k = reshape_head(k)?;
let v = reshape_head(v)?;
// QK RMS norm before RoPE, matching the double-block ordering.
let q = self.attn_norm_q.forward(&q)?;
let k = self.attn_norm_k.forward(&k)?;
let q = flux_apply_rope(&q, rope)?;
let k = flux_apply_rope(&k, rope)?;
let attn_scale = 1.0_f32 / (self.head_dim as f32).sqrt();
let attn_out = mlx_rs::fast::scaled_dot_product_attention(&q, &k, &v, attn_scale, None)?;
let hidden = (self.num_heads * self.head_dim) as i32;
// Back to [B, S, H*D].
let attn_out = ops::transpose_axes(&attn_out, &[0, 2, 1, 3])?;
let attn_out = ops::reshape(&attn_out, &[batch as i32, seq_len as i32, hidden])?;
// Parallel MLP branch on the same modulated input.
let mlp_out = nn::gelu(&self.proj_mlp.forward(&x_mod)?)?;
// Fused output: attention and MLP features concatenated, then projected.
let combined = ops::concatenate_axis(&[&attn_out, &mlp_out], -1)?;
let out = self.proj_out.forward(&combined)?;
// Gated residual.
ops::add(x, &ops::multiply(&out, gate)?)
}
}
/// The Flux diffusion transformer: embeds packed latents and T5 hidden
/// states, runs the double (dual-stream) then single (fused-stream) blocks,
/// and projects back to the packed latent dimension.
struct FluxTransformer {
x_embedder: QLinear,
context_embedder: QLinear,
time_text_embed: TimeTextEmbed,
double_blocks: Vec<DoubleTransformerBlock>,
single_blocks: Vec<SingleTransformerBlock>,
norm_out: AdaLNModulation,
proj_out: QLinear,
}
impl FluxTransformer {
/// Loads the full transformer from tensors under the "transformer" prefix;
/// block counts come from `config`, not from the checkpoint keys.
fn load(tensors: &HashMap<String, Array>, config: &FluxConfig) -> Result<Self, InferenceError> {
let quant = config.quant.as_ref();
let pfx = "transformer";
let x_embedder = build_qlinear(tensors, &format!("{pfx}.x_embedder"), quant)?;
let context_embedder = build_qlinear(tensors, &format!("{pfx}.context_embedder"), quant)?;
let time_text_embed =
TimeTextEmbed::load(tensors, &format!("{pfx}.time_text_embed"), quant)?;
let mut double_blocks = Vec::with_capacity(config.num_double_blocks);
for i in 0..config.num_double_blocks {
double_blocks.push(DoubleTransformerBlock::load(
tensors,
&format!("{pfx}.transformer_blocks.{i}"),
config,
)?);
}
let mut single_blocks = Vec::with_capacity(config.num_single_blocks);
for i in 0..config.num_single_blocks {
single_blocks.push(SingleTransformerBlock::load(
tensors,
&format!("{pfx}.single_transformer_blocks.{i}"),
config,
)?);
}
// Final AdaLN has only (scale, shift) — no gate.
let norm_out = AdaLNModulation::load(tensors, &format!("{pfx}.norm_out"), quant, 2)?;
let proj_out = build_qlinear(tensors, &format!("{pfx}.proj_out"), quant)?;
Ok(Self {
x_embedder,
context_embedder,
time_text_embed,
double_blocks,
single_blocks,
norm_out,
proj_out,
})
}
/// Predicts the velocity/noise for one denoising step.
///
/// `latents` are packed patches [B, Hp*Wp, C*4]; `t5_hidden` is the text
/// context; `clip_pooled`, `timestep` and `guidance` feed the conditioning
/// vector. Per-block timing is logged when CAR_FLUX_PROFILE is set, and
/// intermediate activations are dumped via the `dump_flux_stage*` hooks.
fn forward(
&mut self,
latents: &Array,
t5_hidden: &Array,
clip_pooled: &Array,
timestep: &Array,
guidance: &Array,
rope: &FluxRope,
) -> Result<Array, mlx_rs::error::Exception> {
let mut x = self.x_embedder.forward(latents)?;
let mut context = self.context_embedder.forward(t5_hidden)?;
let cond = self
.time_text_embed
.forward(timestep, clip_pooled, guidance)?;
dump_flux_stage("x_embed", &x);
dump_flux_stage("context_embed", &context);
dump_flux_stage("text_emb", &cond);
let profile = std::env::var("CAR_FLUX_PROFILE").is_ok();
// Dual-stream phase: image and context evolve separately.
for (index, block) in self.double_blocks.iter_mut().enumerate() {
let t0 = std::time::Instant::now();
let (x_new, ctx_new) = block.forward(&x, &context, &cond, rope)?;
x = x_new;
context = ctx_new;
if profile {
// Force evaluation so the timing reflects actual compute,
// not lazy graph construction.
mlx_rs::transforms::eval([&x, &context])?;
tracing::info!(
block = index,
elapsed_ms = t0.elapsed().as_millis() as u64,
"flux double block timing"
);
}
if index == 0 {
dump_flux_stage("block0_hidden", &x);
dump_flux_stage("block0_ctx", &context);
}
if index == 2 || index == 4 || index == 7 {
dump_flux_stage(&format!("block{index:02}_hidden"), &x);
dump_flux_stage(&format!("block{index:02}_ctx"), &context);
}
}
// Fused phase: single blocks operate on [context; image] concatenated
// along the sequence axis.
let mut h = ops::concatenate_axis(&[&context, &x], 1)?;
for (index, block) in self.single_blocks.iter_mut().enumerate() {
let t0 = std::time::Instant::now();
h = block.forward(&h, &cond, rope)?;
if profile {
mlx_rs::transforms::eval([&h])?;
tracing::info!(
block = index,
elapsed_ms = t0.elapsed().as_millis() as u64,
"flux single block timing"
);
}
if index == 0 {
dump_flux_stage("single0_hidden", &h);
}
if index == 9 || index == 18 || index == 27 || index == 37 {
dump_flux_stage(&format!("single{index:02}_hidden"), &h);
}
}
// Drop the text positions, keep only the image part of the sequence.
let context_len = t5_hidden.shape()[1];
let h = h.index((.., context_len.., ..));
// Final AdaLN (scale, shift broadcast over the sequence axis) + proj.
let mods = self.norm_out.forward(&cond)?;
let (scale, shift) = (&mods[0], &mods[1]);
let one = Array::from_f32(1.0);
let h_norm = flux_layer_norm_parameterless(&h, 1e-6)?;
let scale_b = ops::expand_dims(scale, 1)?;
let shift_b = ops::expand_dims(shift, 1)?;
let h = ops::add(
&ops::multiply(&h_norm, &ops::add(&one, &scale_b)?)?,
&shift_b,
)?;
self.proj_out.forward(&h)
}
}
/// Thin wrapper over `ops::conv2d` (default dilation/groups) that also adds
/// an optional bias to the result.
fn conv2d_forward(
input: &Array,
weight: &Array,
bias: Option<&Array>,
stride: (i32, i32),
padding: (i32, i32),
) -> Result<Array, mlx_rs::error::Exception> {
let convolved = ops::conv2d(
input,
weight,
stride,
padding,
None::<(i32, i32)>,
None::<i32>,
)?;
match bias {
Some(b) => ops::add(&convolved, b),
None => Ok(convolved),
}
}
/// Nearest-neighbour 2x spatial upsample in NHWC layout, implemented as
/// duplicate-along-a-new-axis + reshape (no interpolation).
fn upsample_2x(x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let dims = x.shape();
let (b, h, w, c) = (dims[0], dims[1], dims[2], dims[3]);
// Duplicate every row: [B, H, W, C] -> [B, 2H, W, C].
let rows = ops::reshape(x, &[b, h, 1, w, c])?;
let rows = ops::concatenate_axis(&[&rows, &rows], 2)?;
let rows = ops::reshape(&rows, &[b, h * 2, w, c])?;
// Duplicate every column: [B, 2H, W, C] -> [B, 2H, 2W, C].
let cols = ops::reshape(&rows, &[b, h * 2, w, 1, c])?;
let cols = ops::concatenate_axis(&[&cols, &cols], 3)?;
ops::reshape(&cols, &[b, h * 2, w * 2, c])
}
/// One VAE decoder resnet block: two GroupNorm+SiLU+3x3-conv stages plus a
/// residual connection; `skip_weight` is an optional 1x1 shortcut conv
/// (present in the checkpoint only for some blocks).
struct VaeResnetBlock {
norm1: GroupNorm,
conv1_weight: Array,
conv1_bias: Option<Array>,
norm2: GroupNorm,
conv2_weight: Array,
conv2_bias: Option<Array>,
skip_weight: Option<Array>,
skip_bias: Option<Array>,
}
impl VaeResnetBlock {
/// Loads norm and conv weights for one resnet block; the shortcut conv
/// keys are optional and simply absent for blocks without one.
fn load(
tensors: &HashMap<String, Array>,
prefix: &str,
num_groups: usize,
) -> Result<Self, InferenceError> {
Ok(Self {
norm1: build_group_norm(tensors, &format!("{prefix}.norm1"), num_groups, 1e-6)?,
conv1_weight: get_tensor(tensors, &format!("{prefix}.conv1.weight"))?,
conv1_bias: tensors.get(&format!("{prefix}.conv1.bias")).cloned(),
norm2: build_group_norm(tensors, &format!("{prefix}.norm2"), num_groups, 1e-6)?,
conv2_weight: get_tensor(tensors, &format!("{prefix}.conv2.weight"))?,
conv2_bias: tensors.get(&format!("{prefix}.conv2.bias")).cloned(),
skip_weight: tensors
.get(&format!("{prefix}.conv_shortcut.weight"))
.cloned(),
skip_bias: tensors
.get(&format!("{prefix}.conv_shortcut.bias"))
.cloned(),
})
}
/// norm -> SiLU -> 3x3 conv, twice; then adds the (optionally projected)
/// input back as the residual.
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let mut out = nn::silu(&self.norm1.forward(x)?)?;
out = conv2d_forward(
&out,
&self.conv1_weight,
self.conv1_bias.as_ref(),
(1, 1),
(1, 1),
)?;
out = nn::silu(&self.norm2.forward(&out)?)?;
out = conv2d_forward(
&out,
&self.conv2_weight,
self.conv2_bias.as_ref(),
(1, 1),
(1, 1),
)?;
// Shortcut path: 1x1 conv when weights exist, identity otherwise.
let residual = match self.skip_weight {
Some(ref sw) => conv2d_forward(x, sw, self.skip_bias.as_ref(), (1, 1), (0, 0))?,
None => x.clone(),
};
ops::add(&residual, &out)
}
}
/// Single-head spatial self-attention used in the VAE decoder mid block;
/// attends over the flattened H*W positions with channel-dim features.
struct VaeMidAttention {
norm: GroupNorm,
q_proj: QLinear,
k_proj: QLinear,
v_proj: QLinear,
out_proj: QLinear,
}
impl VaeMidAttention {
/// Loads the mid-block attention weights.
fn load(
tensors: &HashMap<String, Array>,
prefix: &str,
num_groups: usize,
) -> Result<Self, InferenceError> {
let norm = build_group_norm(tensors, &format!("{prefix}.group_norm"), num_groups, 1e-6)?;
// NOTE(review): quantization params are hard-coded here instead of
// coming from FluxConfig.quant like the transformer — confirm the VAE
// checkpoint is always quantized with group_size=64 / 4 bits.
let qc = QuantConfig {
group_size: 64,
bits: 4,
};
let q_proj = build_qlinear(tensors, &format!("{prefix}.to_q"), Some(&qc))?;
let k_proj = build_qlinear(tensors, &format!("{prefix}.to_k"), Some(&qc))?;
let v_proj = build_qlinear(tensors, &format!("{prefix}.to_v"), Some(&qc))?;
let out_proj = build_qlinear(tensors, &format!("{prefix}.to_out.0"), Some(&qc))?;
Ok(Self {
norm,
q_proj,
k_proj,
v_proj,
out_proj,
})
}
/// Self-attention over spatial positions with a residual add.
/// Input/output layout is NHWC; attention runs on [B, H*W, C].
fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let shape = x.shape();
let (b, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
let seq_len = h * w;
let normed = self.norm.forward(x)?;
let flat = ops::reshape(&normed, &[b, seq_len, c])?;
let q = self.q_proj.forward(&flat)?;
let k = self.k_proj.forward(&flat)?;
let v = self.v_proj.forward(&flat)?;
// Single head, so the scale uses the full channel dim.
let scale = Array::from_f32(1.0 / (c as f32).sqrt());
let scores = ops::multiply(
&ops::matmul(&q, &ops::transpose_axes(&k, &[0, 2, 1])?)?,
&scale,
)?;
let attn = ops::softmax_axis(&scores, -1, None)?;
let out = ops::matmul(&attn, &v)?;
let out = self.out_proj.forward(&out)?;
let out = ops::reshape(&out, &[b, h, w, c])?;
// Residual: attention output is added to the un-normalized input.
ops::add(x, &out)
}
}
/// One VAE decoder up block: a stack of resnets, optionally followed by a
/// 2x upsample + 3x3 conv (the last block has no upsampler weights).
struct VaeUpBlock {
resnets: Vec<VaeResnetBlock>,
upsample_conv_weight: Option<Array>,
upsample_conv_bias: Option<Array>,
}
impl VaeUpBlock {
/// Runs the resnet stack, then — when upsampler weights are present —
/// nearest 2x upsample followed by a 3x3 conv.
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let mut out = self
.resnets
.iter()
.try_fold(x.clone(), |acc, resnet| resnet.forward(&acc))?;
if let Some(ref weight) = self.upsample_conv_weight {
let upsampled = upsample_2x(&out)?;
out = conv2d_forward(
&upsampled,
weight,
self.upsample_conv_bias.as_ref(),
(1, 1),
(1, 1),
)?;
}
Ok(out)
}
}
/// The VAE decoder: conv_in, mid block (resnet / attention / resnet), a
/// sequence of up blocks, then norm + conv_out producing RGB in [0, 1].
struct VaeDecoder {
conv_in_weight: Array,
conv_in_bias: Option<Array>,
mid_resnet0: VaeResnetBlock,
mid_attn: VaeMidAttention,
mid_resnet1: VaeResnetBlock,
up_blocks: Vec<VaeUpBlock>,
norm_out: GroupNorm,
conv_out_weight: Array,
conv_out_bias: Option<Array>,
}
impl VaeDecoder {
/// Highest child index + 1 among tensor keys of the form
/// `{prefix}{idx}.…`; `prefix` must include its trailing dot.
/// Returns 0 when no key matches.
fn count_indexed(tensors: &HashMap<String, Array>, prefix: &str) -> usize {
let mut count = 0usize;
for key in tensors.keys() {
if let Some(rest) = key.strip_prefix(prefix) {
if let Some(idx_str) = rest.split('.').next() {
if let Ok(idx) = idx_str.parse::<usize>() {
count = count.max(idx + 1);
}
}
}
}
count
}
/// Loads the decoder under the "vae.decoder" prefix. Up-block and resnet
/// counts are discovered from the checkpoint keys (via `count_indexed`)
/// rather than hard-coded, so different VAE depths load unchanged.
fn load(
tensors: &HashMap<String, Array>,
_config: &FluxConfig,
) -> Result<Self, InferenceError> {
let pfx = "vae.decoder";
let conv_in_weight = get_tensor(tensors, &format!("{pfx}.conv_in.weight"))?;
let conv_in_bias = tensors.get(&format!("{pfx}.conv_in.bias")).cloned();
let conv_out_weight = get_tensor(tensors, &format!("{pfx}.conv_out.weight"))?;
let conv_out_bias = tensors.get(&format!("{pfx}.conv_out.bias")).cloned();
let norm_out = build_group_norm(tensors, &format!("{pfx}.conv_norm_out"), 32, 1e-6)?;
let mid_resnet0 = VaeResnetBlock::load(tensors, &format!("{pfx}.mid_block.resnets.0"), 32)?;
let mid_attn =
VaeMidAttention::load(tensors, &format!("{pfx}.mid_block.attentions.0"), 32)?;
let mid_resnet1 = VaeResnetBlock::load(tensors, &format!("{pfx}.mid_block.resnets.1"), 32)?;
let num_up_blocks = Self::count_indexed(tensors, &format!("{pfx}.up_blocks."));
let mut up_blocks = Vec::with_capacity(num_up_blocks);
for i in 0..num_up_blocks {
let bpfx = format!("{pfx}.up_blocks.{i}");
let num_resnets = Self::count_indexed(tensors, &format!("{bpfx}.resnets."));
let mut resnets = Vec::with_capacity(num_resnets);
for r in 0..num_resnets {
resnets.push(VaeResnetBlock::load(
tensors,
&format!("{bpfx}.resnets.{r}"),
32,
)?);
}
// The final up block has no upsampler; these stay None there.
let upsample_conv_weight = tensors
.get(&format!("{bpfx}.upsamplers.0.conv.weight"))
.cloned();
let upsample_conv_bias = tensors
.get(&format!("{bpfx}.upsamplers.0.conv.bias"))
.cloned();
up_blocks.push(VaeUpBlock {
resnets,
upsample_conv_weight,
upsample_conv_bias,
});
}
Ok(Self {
conv_in_weight,
conv_in_bias,
mid_resnet0,
mid_attn,
mid_resnet1,
up_blocks,
norm_out,
conv_out_weight,
conv_out_bias,
})
}
/// Decodes latents (NCHW) to pixels: transpose to NHWC, conv_in, mid
/// (resnet/attention/resnet), up blocks, norm + SiLU + conv_out, then map
/// the [-1, 1] output to [0, 1] and clamp.
fn decode(&mut self, latents: &Array) -> Result<Array, mlx_rs::error::Exception> {
// mlx convs operate in NHWC; latents arrive NCHW.
let x = ops::transpose_axes(latents, &[0, 2, 3, 1])?;
let mut h = conv2d_forward(
&x,
&self.conv_in_weight,
self.conv_in_bias.as_ref(),
(1, 1),
(1, 1),
)?;
h = self.mid_resnet0.forward(&h)?;
h = self.mid_attn.forward(&h)?;
h = self.mid_resnet1.forward(&h)?;
for block in &self.up_blocks {
h = block.forward(&h)?;
}
h = self.norm_out.forward(&h)?;
h = nn::silu(&h)?;
h = conv2d_forward(
&h,
&self.conv_out_weight,
self.conv_out_bias.as_ref(),
(1, 1),
(1, 1),
)?;
// [-1, 1] -> [0, 1], clamped.
let half = Array::from_f32(0.5);
let h = ops::add(&ops::multiply(&h, &half)?, &half)?;
let zero = Array::from_f32(0.0);
let one = Array::from_f32(1.0);
ops::clip(&h, (&zero, &one))
}
}
/// Flow-match Euler scheduler. `sigmas` holds one noise level per inference
/// step plus a trailing 0.0, so `step` can always read `step_index + 1`.
pub struct EulerDiscreteScheduler {
pub num_inference_steps: usize,
pub sigmas: Vec<f32>,
}
impl EulerDiscreteScheduler {
/// Builds the sigma schedule: a linear grid on [sigma_max, sigma_min],
/// time-shifted by `MU`, then rescaled so the final non-zero sigma lands
/// on `SHIFT_TERMINAL`. A trailing 0.0 is appended so `step` can index
/// `sigmas[step_index + 1]` on the last step.
pub fn new(num_inference_steps: usize) -> Self {
const NUM_TRAIN_TIMESTEPS: f32 = 1000.0;
const SHIFT_TERMINAL: f32 = 0.02;
const MU: f32 = 1.0;
let n = num_inference_steps;
let sigma_min = 1.0 / NUM_TRAIN_TIMESTEPS;
let sigma_max = 1.0_f32;
let sigmas_linear: Vec<f32> = (0..n)
.map(|i| {
if n == 1 {
sigma_max
} else {
sigma_max - i as f32 * (sigma_max - sigma_min) / (n - 1) as f32
}
})
.collect();
let exp_mu = MU.exp();
let sigmas_shifted: Vec<f32> = sigmas_linear
.iter()
.map(|&s| exp_mu / (exp_mu + (1.0 / s - 1.0)))
.collect();
// `.last()` instead of indexing: the old `[len() - 1]` panicked when
// `num_inference_steps == 0` (steps is caller-controlled). An empty
// schedule falls back to a no-op rescale.
let one_minus_last = sigmas_shifted.last().map_or(1.0, |&s| 1.0 - s);
let scale_factor = if one_minus_last > 0.0 {
one_minus_last / (1.0 - SHIFT_TERMINAL)
} else {
1.0
};
let mut sigmas: Vec<f32> = sigmas_shifted
.iter()
.map(|&s| 1.0 - (1.0 - s) / scale_factor)
.collect();
sigmas.push(0.0);
Self {
num_inference_steps,
sigmas,
}
}
/// Euler update: `sample + model_output * (sigma_next - sigma)`.
/// `step_index` must be less than `num_inference_steps`.
pub fn step(
&self,
model_output: &Array,
step_index: usize,
sample: &Array,
) -> Result<Array, mlx_rs::error::Exception> {
let sigma = self.sigmas[step_index];
let sigma_next = self.sigmas[step_index + 1];
let dt = Array::from_f32(sigma_next - sigma);
ops::add(sample, &ops::multiply(model_output, &dt)?)
}
/// Seeded standard-normal noise of the given shape.
pub fn init_noise(&self, shape: &[i32], seed: u64) -> Result<Array, mlx_rs::error::Exception> {
let key = mlx_rs::random::key(seed)?;
mlx_rs::random::normal::<f32>(shape, None, None, Some(&key))
}
}
/// Sinusoidal timestep embedding of shape [1, dim]: cosines in the first
/// half, sines in the second, with log-spaced frequencies (max period 1e4).
/// When `dim` is odd the final element stays 0.
fn timestep_embedding(timestep: f32, dim: usize) -> Result<Array, mlx_rs::error::Exception> {
let half = dim / 2;
let ln_max_period = 10_000_f32.ln();
let mut values = vec![0.0f32; dim];
for i in 0..half {
let freq = (-(i as f32) / half as f32 * ln_max_period).exp();
let angle = timestep * freq;
values[i] = angle.cos();
values[half + i] = angle.sin();
}
Ok(Array::from_slice(&values, &[1, dim as i32]))
}
/// Packs NCHW latents into 2x2 patches: [B, C, H, W] -> [B, (H/2)*(W/2), C*4].
///
/// Generalized to use the actual batch dimension — the previous version read
/// the batch size but hard-coded `1` in the reshapes, so any batch other
/// than 1 would fail. Behavior for B == 1 is unchanged.
fn patchify(latents: &Array) -> Result<Array, mlx_rs::error::Exception> {
let shape = latents.shape();
let (b, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
let ph = h / 2;
let pw = w / 2;
let reshaped = ops::reshape(latents, &[b, c, ph, 2, pw, 2])?;
let permuted = ops::transpose_axes(&reshaped, &[0, 2, 4, 1, 3, 5])?;
ops::reshape(&permuted, &[b, ph * pw, c * 4])
}
/// Inverse of `patchify`: [B, Hp*Wp, C*4] -> [B, C, Hp*2, Wp*2].
///
/// Generalized to derive the batch size from `patches` instead of the
/// previous hard-coded `1`; behavior for B == 1 is unchanged.
fn unpatchify(
patches: &Array,
channels: i32,
h_patches: i32,
w_patches: i32,
) -> Result<Array, mlx_rs::error::Exception> {
let batch = patches.shape()[0];
let reshaped = ops::reshape(patches, &[batch, h_patches, w_patches, channels, 2, 2])?;
let permuted = ops::transpose_axes(&reshaped, &[0, 3, 1, 4, 2, 5])?;
ops::reshape(&permuted, &[batch, channels, h_patches * 2, w_patches * 2])
}
/// End-to-end Flux image-generation pipeline: CLIP and T5 text encoders,
/// the diffusion transformer, the VAE decoder, and both tokenizers.
pub struct FluxBackend {
clip: ClipTextEncoder,
t5: T5TextEncoder,
transformer: FluxTransformer,
vae: VaeDecoder,
config: FluxConfig,
clip_tokenizer: tokenizers::Tokenizer,
t5_tokenizer: tokenizers::Tokenizer,
}
// SAFETY(review): these assert Send/Sync for types (mlx arrays, tokenizers)
// that are not automatically thread-safe. This is sound only if every access
// to a FluxBackend is externally serialized (e.g. behind a mutex or a
// single-threaded task) — TODO confirm at the call sites.
unsafe impl Send for FluxBackend {}
unsafe impl Sync for FluxBackend {}
impl FluxBackend {
/// Loads all model components and tokenizers from `model_dir`.
///
/// Weights are loaded on CPU first, then the default device is re-selected
/// (GPU when the `mlx-metal` feature is on, overridable via
/// `CAR_MLX_DEVICE=cpu|gpu`). The CLIP tokenizer.json is synthesized from
/// vocab.json + merges.txt when absent; the T5 tokenizer must already be
/// in tokenizer.json format.
///
/// # Errors
/// Returns `InferenceError::InferenceFailed` for missing tensors or
/// missing/unconvertible tokenizer files.
pub fn load(model_dir: &Path) -> Result<Self, InferenceError> {
let config = FluxConfig::default();
info!(
hidden = config.hidden_dim,
double_blocks = config.num_double_blocks,
single_blocks = config.num_single_blocks,
"loading Flux model via MLX"
);
// Force CPU while reading safetensors; the real device is chosen below.
mlx_rs::Device::set_default(&mlx_rs::Device::cpu());
info!("loading safetensors weights for Flux");
let tensors = load_all_tensors(model_dir)?;
info!(tensors = tensors.len(), "Flux tensors loaded");
#[cfg(feature = "mlx-metal")]
let default_device = mlx_rs::Device::gpu();
#[cfg(not(feature = "mlx-metal"))]
let default_device = mlx_rs::Device::cpu();
// CAR_MLX_DEVICE overrides the compile-time default.
match std::env::var("CAR_MLX_DEVICE").ok().as_deref() {
Some("cpu") => mlx_rs::Device::set_default(&mlx_rs::Device::cpu()),
#[cfg(feature = "mlx-metal")]
Some("gpu") => mlx_rs::Device::set_default(&mlx_rs::Device::gpu()),
_ => mlx_rs::Device::set_default(&default_device),
}
let clip = ClipTextEncoder::load(&tensors, &config)?;
info!("CLIP text encoder loaded");
let t5 = T5TextEncoder::load(&tensors, &config)?;
info!("T5 text encoder loaded");
let transformer = FluxTransformer::load(&tensors, &config)?;
info!("Flux transformer loaded");
let vae = VaeDecoder::load(&tensors, &config)?;
info!("VAE decoder loaded");
let clip_tok_dir = model_dir.join("tokenizer");
let clip_tok_path = clip_tok_dir.join("tokenizer.json");
// Fallback path: build tokenizer.json from the legacy BPE file pair.
if !clip_tok_path.exists() {
let vocab_path = clip_tok_dir.join("vocab.json");
let merges_path = clip_tok_dir.join("merges.txt");
if !vocab_path.exists() || !merges_path.exists() {
return Err(InferenceError::InferenceFailed(format!(
"CLIP tokenizer missing — neither {} nor (vocab.json + merges.txt) present",
clip_tok_path.display()
)));
}
info!(
dir = %clip_tok_dir.display(),
"CLIP tokenizer.json missing — synthesizing from vocab.json + merges.txt"
);
build_clip_tokenizer_json(&vocab_path, &merges_path, &clip_tok_path).map_err(|e| {
InferenceError::InferenceFailed(format!("CLIP tokenizer bootstrap failed: {e}"))
})?;
// Guard against a bootstrap that returned Ok but wrote nothing.
if !clip_tok_path.exists() {
return Err(InferenceError::InferenceFailed(format!(
"CLIP tokenizer bootstrap ran but {} was not produced",
clip_tok_path.display()
)));
}
}
let clip_tokenizer = tokenizers::Tokenizer::from_file(&clip_tok_path)
.map_err(|e| InferenceError::InferenceFailed(format!("CLIP tokenizer: {e}")))?;
info!("CLIP tokenizer loaded");
let t5_tok_path = model_dir.join("tokenizer_2/tokenizer.json");
let t5_tokenizer = if t5_tok_path.exists() {
tokenizers::Tokenizer::from_file(&t5_tok_path)
.map_err(|e| InferenceError::InferenceFailed(format!("T5 tokenizer: {e}")))?
} else {
// Distinguish "needs conversion" from "missing entirely" in the error.
let sp_path = model_dir.join("tokenizer_2/spiece.model");
if sp_path.exists() {
return Err(InferenceError::InferenceFailed(
"T5 tokenizer: only tokenizer.json format is supported; spiece.model requires conversion to tokenizer.json first".into(),
));
} else {
return Err(InferenceError::InferenceFailed(
"T5 tokenizer not found: expected tokenizer_2/tokenizer.json".into(),
));
}
};
info!("T5 tokenizer loaded");
info!("Flux model loaded successfully");
Ok(Self {
clip,
t5,
transformer,
vae,
config,
clip_tokenizer,
t5_tokenizer,
})
}
/// Runs the full text-to-image pipeline: tokenize, encode (CLIP pooled +
/// T5 sequence), Euler-denoise packed latents with the transformer,
/// VAE-decode, and save a PNG at `req.output_path` (default "output.png").
///
/// Note: `width`/`height` are divided by 8 and then halved into patches,
/// so they are presumably expected to be multiples of 16 — not validated
/// here; TODO confirm.
///
/// # Errors
/// `InferenceError::InferenceFailed` on tokenizer, MLX, or image-save
/// failure.
pub fn generate(
&mut self,
req: &GenerateImageRequest,
) -> Result<GenerateImageResult, InferenceError> {
let width = req.width.unwrap_or(512);
let height = req.height.unwrap_or(512);
let steps = req.steps.unwrap_or(20) as usize;
let guidance_scale = req.guidance.unwrap_or(3.5);
let seed = req.seed.unwrap_or(42);
let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
info!(
prompt = %req.prompt,
width,
height,
steps,
guidance = guidance_scale,
seed,
"generating image with Flux"
);
// CLIP: truncate/pad to the fixed 77-token window using the EOS id.
let clip_encoding = self
.clip_tokenizer
.encode(req.prompt.as_str(), true)
.map_err(|e| InferenceError::InferenceFailed(format!("CLIP tokenize: {e}")))?;
let mut clip_ids: Vec<i32> = clip_encoding
.get_ids()
.iter()
.map(|&id| id as i32)
.collect();
const CLIP_EOS: i32 = 49407;
clip_ids.truncate(77);
while clip_ids.len() < 77 {
clip_ids.push(CLIP_EOS);
}
let clip_tokens = Array::from_slice(&clip_ids, &[1, 77]);
// T5: truncate/pad to 512 with id 0.
let t5_encoding = self
.t5_tokenizer
.encode(req.prompt.as_str(), true)
.map_err(|e| InferenceError::InferenceFailed(format!("T5 tokenize: {e}")))?;
let mut t5_ids: Vec<i32> = t5_encoding.get_ids().iter().map(|&id| id as i32).collect();
let t5_max_len: usize = 512;
t5_ids.truncate(t5_max_len);
while t5_ids.len() < t5_max_len {
t5_ids.push(0); }
let t5_seq_len = t5_ids.len() as i32;
let t5_tokens = Array::from_slice(&t5_ids, &[1, t5_seq_len]);
info!(
clip_token_count = clip_ids.len(),
t5_token_count = t5_ids.len(),
"tokenization complete"
);
if std::env::var("CAR_DUMP_FLUX_STAGE").is_ok() {
let clip_head: Vec<_> = clip_ids.iter().take(15).copied().collect();
let t5_head: Vec<_> = t5_ids.iter().take(15).copied().collect();
tracing::warn!(?clip_head, ?t5_head, "CLIP/T5 first-15 token IDs");
}
let perf_start = std::time::Instant::now();
info!("starting CLIP text encode");
let clip_out = self.clip.forward(&clip_tokens).map_err(map_err)?;
// Pooled embedding is taken at the first EOS position, located as the
// first occurrence of the maximum token id (EOS 49407 is the padding
// id above, so it is the max in any padded sequence).
let eos_pos = {
let max_val = clip_ids.iter().copied().max().unwrap_or(0);
let pos = clip_ids.iter().position(|&v| v == max_val).unwrap_or(0);
pos as i32
};
let clip_pooled = clip_out.index((.., eos_pos, ..));
let clip_pooled =
ops::reshape(&clip_pooled, &[1, self.config.clip_hidden as i32]).map_err(map_err)?;
mlx_rs::transforms::eval([&clip_pooled]).map_err(map_err)?;
info!(
elapsed_ms = perf_start.elapsed().as_millis() as u64,
"CLIP text encode complete"
);
let t_start = std::time::Instant::now();
info!("starting T5 text encode");
let t5_hidden = self.t5.forward(&t5_tokens).map_err(map_err)?;
mlx_rs::transforms::eval([&t5_hidden]).map_err(map_err)?;
info!(
elapsed_ms = t_start.elapsed().as_millis() as u64,
"T5 text encode complete"
);
dump_flux_stage("clip_pooled", &clip_pooled);
dump_flux_stage("t5_hidden", &t5_hidden);
// Latents live in packed-patch form [1, (H/16)*(W/16), C*4] throughout
// the denoising loop; unpatchify happens only before the VAE.
let latent_h = (height / 8) as i32;
let latent_w = (width / 8) as i32;
let latent_channels = self.config.vae_latent_channels as i32;
let scheduler = EulerDiscreteScheduler::new(steps);
let h_patches = latent_h / 2;
let w_patches = latent_w / 2;
let packed_dim = (latent_channels * 4) as i32; let mut latents = scheduler
.init_noise(&[1, h_patches * w_patches, packed_dim], seed)
.map_err(map_err)?;
info!(latents_shape = ?latents.shape(), "noise latents initialized (packed)");
let text_seq_len = t5_hidden.shape()[1] as usize;
let rope = flux_rope_build(text_seq_len, h_patches, w_patches).map_err(map_err)?;
info!(
text_seq_len,
h_patches,
w_patches,
cos_shape = ?rope.cos.shape(),
"built Flux RoPE"
);
dump_flux_stage("rope_cos", &rope.cos);
dump_flux_stage("rope_sin", &rope.sin);
dump_flux_stage("latents_init", &latents);
for step in 0..steps {
let step_start = std::time::Instant::now();
let sigma = scheduler.sigmas[step];
info!(
step = step + 1,
total_steps = steps,
sigma,
"starting denoising step"
);
// Embeddings expect timesteps scaled by 1000 (train-timestep units).
let t_emb = timestep_embedding(sigma * 1000.0, 256).map_err(map_err)?;
let g_emb = timestep_embedding(guidance_scale * 1000.0, 256).map_err(map_err)?;
if step == 0 {
dump_flux_stage("patches_step0", &latents);
dump_flux_stage("t_emb_step0", &t_emb);
dump_flux_stage("g_emb_step0", &g_emb);
}
let noise_pred = self
.transformer
.forward(&latents, &t5_hidden, &clip_pooled, &t_emb, &g_emb, &rope)
.map_err(map_err)?;
if step == 0 {
dump_flux_stage("noise_pred_step0", &noise_pred);
}
latents = scheduler
.step(&noise_pred, step, &latents)
.map_err(map_err)?;
// Materialize per step so work doesn't pile up in the lazy graph.
mlx_rs::transforms::eval([&latents]).map_err(map_err)?;
info!(
elapsed_ms = step_start.elapsed().as_millis() as u64,
step = step + 1,
"denoising step complete"
);
}
let vae_start = std::time::Instant::now();
let unpacked =
unpatchify(&latents, latent_channels, h_patches, w_patches).map_err(map_err)?;
let pixels = self.vae.decode(&unpacked).map_err(map_err)?;
mlx_rs::transforms::eval([&pixels]).map_err(map_err)?;
info!(
elapsed_ms = vae_start.elapsed().as_millis() as u64,
"VAE decode complete"
);
let output_path = req
.output_path
.clone()
.unwrap_or_else(|| "output.png".to_string());
let pix_shape = pixels.shape();
let img_h = pix_shape[1] as u32;
let img_w = pix_shape[2] as u32;
let scale_255 = Array::from_f32(255.0);
info!("converting decoded pixels to u8");
// NOTE(review): despite the name, `pixels_u8` is still f32 (scaled to
// 0..255); the u8 conversion happens in the clamp/cast loop below.
let pixels_u8 = ops::multiply(&pixels, &scale_255).map_err(map_err)?;
mlx_rs::transforms::eval([&pixels_u8]).map_err(map_err)?;
let pixel_data: Vec<f32> = pixels_u8.as_slice::<f32>().to_vec();
info!(pixel_count = pixel_data.len(), "pixel buffer materialized");
// Interleaved RGB assumed (base + 0/1/2); out-of-range reads fall back
// to 0.0 via `.get()` rather than panicking.
let mut img_buf = image::RgbImage::new(img_w, img_h);
for y in 0..img_h {
for x_px in 0..img_w {
let base = ((y * img_w + x_px) * 3) as usize;
let r = pixel_data
.get(base)
.copied()
.unwrap_or(0.0)
.clamp(0.0, 255.0) as u8;
let g = pixel_data
.get(base + 1)
.copied()
.unwrap_or(0.0)
.clamp(0.0, 255.0) as u8;
let b_val = pixel_data
.get(base + 2)
.copied()
.unwrap_or(0.0)
.clamp(0.0, 255.0) as u8;
img_buf.put_pixel(x_px, y, image::Rgb([r, g, b_val]));
}
}
info!(path = %output_path, width = img_w, height = img_h, "saving PNG");
img_buf
.save(&output_path)
.map_err(|e| InferenceError::InferenceFailed(format!("save PNG: {e}")))?;
info!(path = %output_path, "image saved");
Ok(GenerateImageResult {
image_path: output_path,
media_type: "image/png".to_string(),
model_used: Some("mlx-community/Flux-1.lite-8B-MLX-Q4".to_string()),
})
}
}