use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::SystemTime;
use mlx_rs::module::{Module, Param};
use mlx_rs::nn;
use mlx_rs::ops;
use mlx_rs::ops::indexing::IndexOp;
use mlx_rs::Array;
use tracing::{info, warn};
use super::image_io::load_rgb_image;
use super::mlx::{build_qlinear, QLinear, QuantConfig};
use crate::tasks::generate_video::{GenerateVideoRequest, GenerateVideoResult, VideoMode};
use crate::InferenceError;
/// Hyper-parameters for the LTX audio-video diffusion transformer.
#[derive(Debug, Clone)]
pub struct LtxConfig {
    // Width of the video token stream.
    pub hidden_dim: usize,
    // Video attention heads and per-head width.
    pub num_heads: usize,
    pub head_dim: usize,
    // Number of stacked transformer blocks.
    pub num_layers: usize,
    // Width of the text/context stream fed to cross-attention.
    pub cross_attention_dim: usize,
    // Latent channel counts in/out of the transformer.
    pub in_channels: usize,
    pub out_channels: usize,
    // Epsilon used by the RMS norms.
    pub norm_eps: f32,
    // Quantization settings for QLinear layers; `None` keeps dense weights.
    pub quant: Option<QuantConfig>,
    // Audio-stream counterparts of the video attention sizes.
    pub audio_hidden_dim: usize,
    pub audio_heads: usize,
    pub audio_head_dim: usize,
}
impl Default for LtxConfig {
fn default() -> Self {
Self {
hidden_dim: 4096,
num_heads: 32,
head_dim: 128,
num_layers: 48,
cross_attention_dim: 4096,
in_channels: 128,
out_channels: 128,
norm_eps: 1e-6,
quant: Some(QuantConfig {
group_size: 64,
bits: 4,
}),
audio_hidden_dim: 2048,
audio_heads: 32,
audio_head_dim: 64,
}
}
}
/// Debug helper: writes `t` as little-endian f32 bytes (`{name}.bin`) plus a
/// shape sidecar (`{name}.meta`) into the directory named by the
/// `CAR_DUMP_LTX_STAGE` environment variable. No-op when the variable is
/// unset; all I/O errors are deliberately ignored (best-effort dumping).
fn dump_ltx_stage(name: &str, t: &Array) {
    let Ok(dir) = std::env::var("CAR_DUMP_LTX_STAGE") else {
        return;
    };
    let _ = std::fs::create_dir_all(&dir);
    let Ok(t_f32) = t.as_dtype(mlx_rs::Dtype::Float32) else {
        return;
    };
    // Force materialization before reading the host-side slice.
    let _ = mlx_rs::transforms::eval([&t_f32]);
    let shape = t_f32.shape().to_vec();
    let data: &[f32] = t_f32.as_slice();
    // Preallocate 4 bytes per element to avoid repeated grow-and-copy.
    let mut bytes: Vec<u8> = Vec::with_capacity(data.len() * 4);
    for v in data {
        bytes.extend_from_slice(&v.to_le_bytes());
    }
    let bin_path = format!("{dir}/{name}.bin");
    let meta_path = format!("{dir}/{name}.meta");
    let _ = std::fs::write(&bin_path, &bytes);
    let _ = std::fs::write(&meta_path, format!("{shape:?}\n"));
}
/// Like `dump_ltx_stage`, but each distinct `name` is dumped at most once per
/// process; subsequent calls with the same name are no-ops.
fn dump_ltx_stage_first_call(name: &str, t: &Array) {
    use std::sync::Mutex;
    // Lazily-initialized set of names already dumped.
    static SEEN: Mutex<Option<std::collections::HashSet<String>>> = Mutex::new(None);
    {
        let mut guard = SEEN.lock().unwrap();
        let seen = guard.get_or_insert_with(std::collections::HashSet::new);
        if seen.contains(name) {
            return;
        }
        seen.insert(name.to_string());
        // Guard dropped here so the dump itself runs outside the lock.
    }
    dump_ltx_stage(name, t);
}
/// Fetches a tensor by name, cloning it out of the map; errors with the
/// missing key's name when absent.
fn get_tensor(tensors: &HashMap<String, Array>, key: &str) -> Result<Array, InferenceError> {
    match tensors.get(key) {
        Some(array) => Ok(array.clone()),
        None => Err(InferenceError::InferenceFailed(format!(
            "missing tensor: {key}"
        ))),
    }
}
/// Builds an unquantized `nn::Linear` from `{prefix}.weight` and an optional
/// `{prefix}.bias`.
fn build_dense_linear(
    tensors: &HashMap<String, Array>,
    prefix: &str,
) -> Result<nn::Linear, InferenceError> {
    // Weight is mandatory; bias may be absent in the checkpoint.
    let weight = Param::new(get_tensor(tensors, &format!("{prefix}.weight"))?);
    let bias = Param::new(tensors.get(&format!("{prefix}.bias")).cloned());
    Ok(nn::Linear { weight, bias })
}
/// Loads every LTX safetensors file from `model_dir` and merges them into a
/// single name -> tensor map.
///
/// Required files must all exist (error names the missing path); optional
/// files are merged only when present. Tensors loaded later overwrite earlier
/// ones with the same name.
fn load_ltx_tensors(model_dir: &Path) -> Result<HashMap<String, Array>, InferenceError> {
    let required_files = [
        "transformer-distilled.safetensors",
        "connector.safetensors",
        "vae_decoder.safetensors",
        "audio_vae.safetensors",
        "vocoder.safetensors",
    ];
    let optional_files = ["vae_encoder.safetensors"];
    let mut all_tensors = HashMap::new();
    for filename in &required_files {
        let path = model_dir.join(filename);
        if !path.exists() {
            return Err(InferenceError::InferenceFailed(format!(
                "missing weight file: {}",
                path.display()
            )));
        }
        // Name the failing file in the error instead of "(unknown)".
        let tensors = Array::load_safetensors(&path).map_err(|e| {
            InferenceError::InferenceFailed(format!("load {}: {e}", path.display()))
        })?;
        all_tensors.extend(tensors);
    }
    for filename in &optional_files {
        let path = model_dir.join(filename);
        if !path.exists() {
            continue;
        }
        let tensors = Array::load_safetensors(&path).map_err(|e| {
            InferenceError::InferenceFailed(format!("load {}: {e}", path.display()))
        })?;
        all_tensors.extend(tensors);
    }
    Ok(all_tensors)
}
/// RMS normalization with a learned per-channel scale.
struct RmsNorm {
    // Learned scale applied after normalization.
    weight: Array,
    // Stabilizer added to the mean of squares before rsqrt.
    eps: f32,
}
impl RmsNorm {
    /// x / sqrt(mean(x^2, last axis) + eps) * weight.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let eps = Array::from_f32(self.eps);
        // Mean of squares over the last axis, kept for broadcasting.
        let mean_sq = ops::multiply(x, x)?.mean_axes(&[-1], true)?;
        let inv_rms = ops::rsqrt(&ops::add(&mean_sq, &eps)?)?;
        ops::multiply(&ops::multiply(x, &inv_rms)?, &self.weight)
    }
}
/// Builds an `RmsNorm` from the `{prefix}.weight` tensor.
fn build_rms_norm(
    tensors: &HashMap<String, Array>,
    prefix: &str,
    eps: f32,
) -> Result<RmsNorm, InferenceError> {
    // The norm only carries a learned scale; eps comes from the config.
    get_tensor(tensors, &format!("{prefix}.weight")).map(|weight| RmsNorm { weight, eps })
}
/// Multi-head self-attention whose per-head output is scaled by a learned
/// sigmoid gate (see `forward`: gate = 2 * sigmoid(to_gate_logits(x))).
struct GatedSelfAttention {
    // Q/K/V and output projections (possibly quantized).
    to_q: QLinear,
    to_k: QLinear,
    to_v: QLinear,
    to_out: QLinear,
    // RMS norms applied to Q and K before the head split.
    norm_q: RmsNorm,
    norm_k: RmsNorm,
    // Produces one gate logit per head from the input stream.
    to_gate_logits: QLinear,
    num_heads: usize,
    head_dim: usize,
}
impl GatedSelfAttention {
    /// Loads all projections and norms from checkpoint tensors under `prefix`.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
        num_heads: usize,
        head_dim: usize,
        eps: f32,
    ) -> Result<Self, InferenceError> {
        Ok(Self {
            to_q: build_qlinear(tensors, &format!("{prefix}.to_q"), quant)?,
            to_k: build_qlinear(tensors, &format!("{prefix}.to_k"), quant)?,
            to_v: build_qlinear(tensors, &format!("{prefix}.to_v"), quant)?,
            to_out: build_qlinear(tensors, &format!("{prefix}.to_out"), quant)?,
            norm_q: build_rms_norm(tensors, &format!("{prefix}.q_norm"), eps)?,
            norm_k: build_rms_norm(tensors, &format!("{prefix}.k_norm"), eps)?,
            to_gate_logits: build_qlinear(tensors, &format!("{prefix}.to_gate_logits"), quant)?,
            num_heads,
            head_dim,
        })
    }

    /// Gated self-attention over `x`; the reshapes below assume `x` is
    /// [batch, seq, num_heads * head_dim]. Optional rotary embeddings
    /// (cos, sin) are applied to Q and K after the head split.
    fn forward(
        &mut self,
        x: &Array,
        rope_freqs: Option<(&Array, &Array)>,
    ) -> Result<Array, mlx_rs::error::Exception> {
        let shape = x.shape();
        let (batch, seq_len) = (shape[0] as usize, shape[1] as usize);
        let nh = self.num_heads as i32;
        let hd = self.head_dim as i32;
        let hidden = (self.num_heads * self.head_dim) as i32;
        // Project, then RMS-normalize Q and K over the full hidden dim
        // (before the head split).
        let q = self.to_q.forward(x)?;
        let k = self.to_k.forward(x)?;
        let v = self.to_v.forward(x)?;
        let q = self.norm_q.forward(&q)?;
        let k = self.norm_k.forward(&k)?;
        // [B, S, nh*hd] -> [B, nh, S, hd]
        let reshape_head = |t: Array| -> Result<Array, mlx_rs::error::Exception> {
            let r = ops::reshape(&t, &[batch as i32, seq_len as i32, nh, hd])?;
            ops::transpose_axes(&r, &[0, 2, 1, 3])
        };
        let mut q = reshape_head(q)?;
        let mut k = reshape_head(k)?;
        let v = reshape_head(v)?;
        if let Some((cos_f, sin_f)) = rope_freqs {
            q = apply_split_rope(&q, cos_f, sin_f)?;
            k = apply_split_rope(&k, cos_f, sin_f)?;
        }
        let scale = 1.0_f32 / (self.head_dim as f32).sqrt();
        let attn_out = mlx_rs::fast::scaled_dot_product_attention(&q, &k, &v, scale, None)?;
        // Back to [B, S, nh, hd] so the per-head gate can broadcast over hd.
        let attn_out = ops::transpose_axes(&attn_out, &[0, 2, 1, 3])?;
        // Gate in (0, 2): 2 * sigmoid(logits), one value per head.
        let gate_logits = self.to_gate_logits.forward(x)?;
        let two = Array::from_f32(2.0);
        let gate = ops::multiply(&ops::sigmoid(&gate_logits)?, &two)?;
        let gate = ops::reshape(&gate, &[batch as i32, seq_len as i32, nh, 1])?;
        let gated = ops::multiply(&attn_out, &gate)?;
        // Merge heads back to [B, S, hidden] for the output projection.
        let gated = ops::reshape(&gated, &[batch as i32, seq_len as i32, hidden])?;
        self.to_out.forward(&gated)
    }
}
/// Multi-head cross-attention (queries from one stream, keys/values from a
/// context stream) with the same per-head sigmoid output gate as
/// `GatedSelfAttention`.
struct GatedCrossAttention {
    // Q projects the query stream; K/V project the context stream.
    to_q: QLinear,
    to_k: QLinear,
    to_v: QLinear,
    to_out: QLinear,
    // RMS norms for Q and K before the head split.
    norm_q: RmsNorm,
    norm_k: RmsNorm,
    // Per-head gate logits computed from the query stream.
    to_gate_logits: QLinear,
    num_heads: usize,
    head_dim: usize,
}
impl GatedCrossAttention {
    /// Loads all projections and norms from checkpoint tensors under `prefix`.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
        num_heads: usize,
        head_dim: usize,
        eps: f32,
    ) -> Result<Self, InferenceError> {
        Ok(Self {
            to_q: build_qlinear(tensors, &format!("{prefix}.to_q"), quant)?,
            to_k: build_qlinear(tensors, &format!("{prefix}.to_k"), quant)?,
            to_v: build_qlinear(tensors, &format!("{prefix}.to_v"), quant)?,
            to_out: build_qlinear(tensors, &format!("{prefix}.to_out"), quant)?,
            norm_q: build_rms_norm(tensors, &format!("{prefix}.q_norm"), eps)?,
            norm_k: build_rms_norm(tensors, &format!("{prefix}.k_norm"), eps)?,
            to_gate_logits: build_qlinear(tensors, &format!("{prefix}.to_gate_logits"), quant)?,
            num_heads,
            head_dim,
        })
    }

    /// Gated cross-attention: queries come from `x`, keys/values from
    /// `context`. Rotary embeddings may be applied independently to Q
    /// (`rope_q`) and K (`rope_k`) — the two streams can have different
    /// sequence lengths.
    fn forward(
        &mut self,
        x: &Array,
        context: &Array,
        rope_q: Option<(&Array, &Array)>,
        rope_k: Option<(&Array, &Array)>,
    ) -> Result<Array, mlx_rs::error::Exception> {
        let x_shape = x.shape();
        let (batch, x_seq) = (x_shape[0] as usize, x_shape[1] as usize);
        let ctx_seq = context.shape()[1] as usize;
        let nh = self.num_heads as i32;
        let hd = self.head_dim as i32;
        let hidden = (self.num_heads * self.head_dim) as i32;
        let q = self.to_q.forward(x)?;
        let k = self.to_k.forward(context)?;
        let v = self.to_v.forward(context)?;
        // RMS-normalize Q and K over the full hidden dim before splitting heads.
        let q = self.norm_q.forward(&q)?;
        let k = self.norm_k.forward(&k)?;
        // Split heads: [B, S, nh*hd] -> [B, nh, S, hd] for each stream.
        let mut q = ops::transpose_axes(
            &ops::reshape(&q, &[batch as i32, x_seq as i32, nh, hd])?,
            &[0, 2, 1, 3],
        )?;
        let mut k = ops::transpose_axes(
            &ops::reshape(&k, &[batch as i32, ctx_seq as i32, nh, hd])?,
            &[0, 2, 1, 3],
        )?;
        let v = ops::transpose_axes(
            &ops::reshape(&v, &[batch as i32, ctx_seq as i32, nh, hd])?,
            &[0, 2, 1, 3],
        )?;
        if let Some((cos_f, sin_f)) = rope_q {
            q = apply_split_rope(&q, cos_f, sin_f)?;
        }
        if let Some((cos_f, sin_f)) = rope_k {
            k = apply_split_rope(&k, cos_f, sin_f)?;
        }
        let scale = 1.0_f32 / (self.head_dim as f32).sqrt();
        let attn_out = mlx_rs::fast::scaled_dot_product_attention(&q, &k, &v, scale, None)?;
        // Back to [B, S, nh, hd] so the per-head gate broadcasts over hd.
        let attn_out = ops::transpose_axes(&attn_out, &[0, 2, 1, 3])?;
        // Gate in (0, 2), computed from the query stream.
        let gate_logits = self.to_gate_logits.forward(x)?;
        let two = Array::from_f32(2.0);
        let gate = ops::multiply(&ops::sigmoid(&gate_logits)?, &two)?;
        let gate = ops::reshape(&gate, &[batch as i32, x_seq as i32, nh, 1])?;
        let gated = ops::multiply(&attn_out, &gate)?;
        let gated = ops::reshape(&gated, &[batch as i32, x_seq as i32, hidden])?;
        self.to_out.forward(&gated)
    }
}
/// Feed-forward block: up-projection, GELU, down-projection.
/// NOTE(review): despite the GeGLU name, `forward` applies a plain
/// approximate GELU with no channel split/gating visible here — confirm
/// against the checkpoint layout.
struct GeGluFfn {
    proj_in: QLinear,
    proj_out: QLinear,
}
impl GeGluFfn {
    /// Loads the two projection layers from `tensors` under `prefix`.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
    ) -> Result<Self, InferenceError> {
        let proj_in = build_qlinear(tensors, &format!("{prefix}.proj_in"), quant)?;
        let proj_out = build_qlinear(tensors, &format!("{prefix}.proj_out"), quant)?;
        Ok(Self { proj_in, proj_out })
    }

    /// proj_in -> approximate GELU -> proj_out.
    fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let hidden = self.proj_in.forward(x)?;
        self.proj_out.forward(&nn::gelu_approximate(&hidden)?)
    }
}
/// One LTX transformer layer: a video branch (self-attn, text cross-attn,
/// FFN), an optional audio branch with the same structure, and bidirectional
/// audio<->video cross-attention tying the two streams together.
struct LtxTransformerBlock {
    // Video branch.
    attn1: GatedSelfAttention,
    attn2: GatedCrossAttention,
    ff: GeGluFfn,
    // Ada-LN tables: rows are indexed by `ada_chunk` in `forward`.
    scale_shift_table: Array,
    prompt_scale_shift_table: Array,
    // Audio branch (used only when audio inputs are provided).
    audio_attn1: GatedSelfAttention,
    audio_attn2: GatedCrossAttention,
    audio_ff: GeGluFfn,
    audio_scale_shift_table: Array,
    audio_prompt_scale_shift_table: Array,
    // Cross-modal attention: audio conditions video and vice versa.
    audio_to_video_attn: GatedCrossAttention,
    video_to_audio_attn: GatedCrossAttention,
    // Modulation tables for the cross-modal attention (4 scale/shift rows
    // plus a gate row, per `forward`).
    scale_shift_table_a2v_ca_audio: Array,
    scale_shift_table_a2v_ca_video: Array,
}
impl LtxTransformerBlock {
fn load(
tensors: &HashMap<String, Array>,
prefix: &str,
config: &LtxConfig,
) -> Result<Self, InferenceError> {
let quant = config.quant.as_ref();
let eps = config.norm_eps;
Ok(Self {
attn1: GatedSelfAttention::load(
tensors,
&format!("{prefix}.attn1"),
quant,
config.num_heads,
config.head_dim,
eps,
)?,
attn2: GatedCrossAttention::load(
tensors,
&format!("{prefix}.attn2"),
quant,
config.num_heads,
config.head_dim,
eps,
)?,
ff: GeGluFfn::load(tensors, &format!("{prefix}.ff"), quant)?,
scale_shift_table: get_tensor(tensors, &format!("{prefix}.scale_shift_table"))?,
prompt_scale_shift_table: get_tensor(
tensors,
&format!("{prefix}.prompt_scale_shift_table"),
)?,
audio_attn1: GatedSelfAttention::load(
tensors,
&format!("{prefix}.audio_attn1"),
quant,
config.audio_heads,
config.audio_head_dim,
eps,
)?,
audio_attn2: GatedCrossAttention::load(
tensors,
&format!("{prefix}.audio_attn2"),
quant,
config.audio_heads,
config.audio_head_dim,
eps,
)?,
audio_ff: GeGluFfn::load(tensors, &format!("{prefix}.audio_ff"), quant)?,
audio_scale_shift_table: get_tensor(
tensors,
&format!("{prefix}.audio_scale_shift_table"),
)?,
audio_prompt_scale_shift_table: get_tensor(
tensors,
&format!("{prefix}.audio_prompt_scale_shift_table"),
)?,
audio_to_video_attn: GatedCrossAttention::load(
tensors,
&format!("{prefix}.audio_to_video_attn"),
quant,
config.audio_heads,
config.audio_head_dim,
eps,
)?,
video_to_audio_attn: GatedCrossAttention::load(
tensors,
&format!("{prefix}.video_to_audio_attn"),
quant,
config.audio_heads,
config.audio_head_dim,
eps,
)?,
scale_shift_table_a2v_ca_audio: get_tensor(
tensors,
&format!("{prefix}.scale_shift_table_a2v_ca_audio"),
)?,
scale_shift_table_a2v_ca_video: get_tensor(
tensors,
&format!("{prefix}.scale_shift_table_a2v_ca_video"),
)?,
})
}
fn forward(
&mut self,
x: &Array,
context: &Array,
timestep_emb: &Array,
audio: Option<&Array>,
audio_context: Option<&Array>,
audio_timestep_emb: Option<&Array>,
video_rope_freqs: Option<(&Array, &Array)>,
audio_rope_freqs: Option<(&Array, &Array)>,
video_cross_rope_freqs: Option<(&Array, &Array)>,
audio_cross_rope_freqs: Option<(&Array, &Array)>,
video_prompt_emb: &Array, audio_prompt_emb: Option<&Array>, av_ca_video_emb: &Array, av_ca_audio_emb: Option<&Array>, av_ca_a2v_gate_emb: &Array, av_ca_v2a_gate_emb: Option<&Array>, ) -> Result<(Array, Option<Array>), mlx_rs::error::Exception> {
let adaln_modulate =
|x: &Array, scale: &Array, shift: &Array| -> Result<Array, mlx_rs::error::Exception> {
let eps = Array::from_f32(1e-6);
let var = ops::multiply(x, x)?.mean_axes(&[-1], true)?;
let inv = ops::rsqrt(&ops::add(&var, &eps)?)?;
let normed = ops::multiply(x, &inv)?;
let one = Array::from_f32(1.0);
let scaled = ops::multiply(&normed, &ops::add(&one, scale)?)?;
ops::add(&scaled, shift)
};
let adaln_modulate_raw =
|x: &Array, scale: &Array, shift: &Array| -> Result<Array, mlx_rs::error::Exception> {
let one = Array::from_f32(1.0);
let scaled = ops::multiply(x, &ops::add(&one, scale)?)?;
ops::add(&scaled, shift)
};
let extract_mod = |table: &Array, idx: i32| -> Result<Array, mlx_rs::error::Exception> {
let row = table.index((idx, ..)); ops::reshape(&row, &[1, 1, row.shape()[0]])
};
let unsqueeze = |a: &Array| -> Result<Array, mlx_rs::error::Exception> {
let s = a.shape();
if s.len() == 1 {
ops::reshape(a, &[1, 1, s[0]])
} else if s.len() == 2 {
ops::reshape(a, &[s[0], 1, s[1]])
} else {
Ok(a.clone())
}
};
let ada_chunk_n = |table: &Array,
t_emb: &Array,
idx: i32,
num_params: i32|
-> Result<Array, mlx_rs::error::Exception> {
let hidden = table.shape()[1];
let t_shape = t_emb.shape();
let row = table.index((idx, ..)); match t_shape.len() {
2 => {
let b = t_shape[0];
let reshaped = ops::reshape(t_emb, &[b, num_params, hidden])?;
let slice_i = reshaped.index((.., idx..idx + 1, ..)); let row_bcast = ops::reshape(&row, &[1, 1, hidden])?;
ops::add(&slice_i, &row_bcast)
}
3 => {
let (b, t) = (t_shape[0], t_shape[1]);
let reshaped = ops::reshape(t_emb, &[b, t, num_params, hidden])?;
let slice_i = reshaped.index((.., .., idx..idx + 1, ..)); let collapsed = ops::reshape(&slice_i, &[b, t, hidden])?; let row_bcast = ops::reshape(&row, &[1, 1, hidden])?;
ops::add(&collapsed, &row_bcast)
}
other => Err(mlx_rs::error::Exception::custom(format!(
"ada_chunk: unsupported timestep embedding rank {other}"
))),
}
};
let ada_chunk =
|table: &Array, t_emb: &Array, idx: i32| -> Result<Array, mlx_rs::error::Exception> {
ada_chunk_n(table, t_emb, idx, table.shape()[0])
};
let shift_sa = ada_chunk(&self.scale_shift_table, timestep_emb, 0)?;
let scale_sa = ada_chunk(&self.scale_shift_table, timestep_emb, 1)?;
let gate_sa = ada_chunk(&self.scale_shift_table, timestep_emb, 2)?;
let shift_ff = ada_chunk(&self.scale_shift_table, timestep_emb, 3)?;
let scale_ff = ada_chunk(&self.scale_shift_table, timestep_emb, 4)?;
let gate_ff = ada_chunk(&self.scale_shift_table, timestep_emb, 5)?;
let shift_ca = ada_chunk(&self.scale_shift_table, timestep_emb, 6)?;
let scale_ca = ada_chunk(&self.scale_shift_table, timestep_emb, 7)?;
let gate_ca = ada_chunk(&self.scale_shift_table, timestep_emb, 8)?;
let x_mod = adaln_modulate(x, &scale_sa, &shift_sa)?;
let attn_out = self.attn1.forward(&x_mod, video_rope_freqs)?;
let mut x_vid = ops::add(x, &ops::multiply(&gate_sa, &attn_out)?)?;
let prompt_shift_v = ada_chunk(&self.prompt_scale_shift_table, video_prompt_emb, 0)?;
let prompt_scale_v = ada_chunk(&self.prompt_scale_shift_table, video_prompt_emb, 1)?;
let text_scaled_v = adaln_modulate_raw(context, &prompt_scale_v, &prompt_shift_v)?;
let x_mod = adaln_modulate(&x_vid, &scale_ca, &shift_ca)?;
let ca_out = self.attn2.forward(&x_mod, &text_scaled_v, None, None)?;
x_vid = ops::add(&x_vid, &ops::multiply(&gate_ca, &ca_out)?)?;
let x_audio_out = if let (Some(aud), Some(aud_ctx), Some(aud_temb)) =
(audio, audio_context, audio_timestep_emb)
{
let a_shift_sa = ada_chunk(&self.audio_scale_shift_table, aud_temb, 0)?;
let a_scale_sa = ada_chunk(&self.audio_scale_shift_table, aud_temb, 1)?;
let a_gate_sa = ada_chunk(&self.audio_scale_shift_table, aud_temb, 2)?;
let a_shift_ff = ada_chunk(&self.audio_scale_shift_table, aud_temb, 3)?;
let a_scale_ff = ada_chunk(&self.audio_scale_shift_table, aud_temb, 4)?;
let a_gate_ff = ada_chunk(&self.audio_scale_shift_table, aud_temb, 5)?;
let a_shift_ca = ada_chunk(&self.audio_scale_shift_table, aud_temb, 6)?;
let a_scale_ca = ada_chunk(&self.audio_scale_shift_table, aud_temb, 7)?;
let a_gate_ca = ada_chunk(&self.audio_scale_shift_table, aud_temb, 8)?;
let a_mod = adaln_modulate(aud, &a_scale_sa, &a_shift_sa)?;
let a_attn_out = self.audio_attn1.forward(&a_mod, audio_rope_freqs)?;
let mut x_aud = ops::add(aud, &ops::multiply(&a_gate_sa, &a_attn_out)?)?;
let a_prompt_emb =
audio_prompt_emb.expect("audio_prompt_emb required when audio is active");
let a_prompt_shift = ada_chunk(&self.audio_prompt_scale_shift_table, a_prompt_emb, 0)?;
let a_prompt_scale = ada_chunk(&self.audio_prompt_scale_shift_table, a_prompt_emb, 1)?;
let text_scaled_a = adaln_modulate_raw(aud_ctx, &a_prompt_scale, &a_prompt_shift)?;
let a_mod = adaln_modulate(&x_aud, &a_scale_ca, &a_shift_ca)?;
let a_ca_out = self
.audio_attn2
.forward(&a_mod, &text_scaled_a, None, None)?;
x_aud = ops::add(&x_aud, &ops::multiply(&a_gate_ca, &a_ca_out)?)?;
let video_norm3 = rms_norm_parameterless(&x_vid, 1e-6)?;
let audio_norm3 = rms_norm_parameterless(&x_aud, 1e-6)?;
let av_v_scale_a2v =
ada_chunk_n(&self.scale_shift_table_a2v_ca_video, av_ca_video_emb, 0, 4)?;
let av_v_shift_a2v =
ada_chunk_n(&self.scale_shift_table_a2v_ca_video, av_ca_video_emb, 1, 4)?;
let av_v_scale_v2a =
ada_chunk_n(&self.scale_shift_table_a2v_ca_video, av_ca_video_emb, 2, 4)?;
let av_v_shift_v2a =
ada_chunk_n(&self.scale_shift_table_a2v_ca_video, av_ca_video_emb, 3, 4)?;
let av_v_gate_a2v = {
let row4 = self.scale_shift_table_a2v_ca_video.index((4, ..));
let row4 = ops::reshape(&row4, &[1, 1, row4.shape()[0]])?;
let b_emb = if av_ca_a2v_gate_emb.shape().len() == 2 {
let s = av_ca_a2v_gate_emb.shape();
ops::reshape(av_ca_a2v_gate_emb, &[s[0], 1, s[1]])?
} else {
av_ca_a2v_gate_emb.clone()
};
ops::add(&b_emb, &row4)?
};
let a_gate_emb =
av_ca_v2a_gate_emb.expect("av_ca_v2a_gate_emb required when audio is active");
let av_a_scale_a2v = ada_chunk_n(
&self.scale_shift_table_a2v_ca_audio,
av_ca_audio_emb.expect("av_ca_audio_emb required when audio is active"),
0,
4,
)?;
let av_a_shift_a2v = ada_chunk_n(
&self.scale_shift_table_a2v_ca_audio,
av_ca_audio_emb.unwrap(),
1,
4,
)?;
let av_a_scale_v2a = ada_chunk_n(
&self.scale_shift_table_a2v_ca_audio,
av_ca_audio_emb.unwrap(),
2,
4,
)?;
let av_a_shift_v2a = ada_chunk_n(
&self.scale_shift_table_a2v_ca_audio,
av_ca_audio_emb.unwrap(),
3,
4,
)?;
let av_a_gate_v2a = {
let row4 = self.scale_shift_table_a2v_ca_audio.index((4, ..));
let row4 = ops::reshape(&row4, &[1, 1, row4.shape()[0]])?;
let b_emb = if a_gate_emb.shape().len() == 2 {
let s = a_gate_emb.shape();
ops::reshape(a_gate_emb, &[s[0], 1, s[1]])?
} else {
a_gate_emb.clone()
};
ops::add(&b_emb, &row4)?
};
let video_q_a2v = adaln_modulate_raw(&video_norm3, &av_v_scale_a2v, &av_v_shift_a2v)?;
let audio_kv_a2v = adaln_modulate_raw(&audio_norm3, &av_a_scale_a2v, &av_a_shift_a2v)?;
let a2v_out = self.audio_to_video_attn.forward(
&video_q_a2v,
&audio_kv_a2v,
video_cross_rope_freqs,
audio_cross_rope_freqs,
)?;
x_vid = ops::add(&x_vid, &ops::multiply(&av_v_gate_a2v, &a2v_out)?)?;
let audio_q_v2a = adaln_modulate_raw(&audio_norm3, &av_a_scale_v2a, &av_a_shift_v2a)?;
let video_kv_v2a = adaln_modulate_raw(&video_norm3, &av_v_scale_v2a, &av_v_shift_v2a)?;
let v2a_out = self.video_to_audio_attn.forward(
&audio_q_v2a,
&video_kv_v2a,
audio_cross_rope_freqs,
video_cross_rope_freqs,
)?;
x_aud = ops::add(&x_aud, &ops::multiply(&av_a_gate_v2a, &v2a_out)?)?;
let a_mod = adaln_modulate(&x_aud, &a_scale_ff, &a_shift_ff)?;
let a_ff_out = self.audio_ff.forward(&a_mod)?;
x_aud = ops::add(&x_aud, &ops::multiply(&a_gate_ff, &a_ff_out)?)?;
Some(x_aud)
} else {
None
};
let x_mod = adaln_modulate(&x_vid, &scale_ff, &shift_ff)?;
let ff_out = self.ff.forward(&x_mod)?;
x_vid = ops::add(&x_vid, &ops::multiply(&gate_ff, &ff_out)?)?;
Ok((x_vid, x_audio_out))
}
}
/// Ada-LN-single conditioning head: sinusoidal timestep embedding fed through
/// a two-layer MLP, then a final projection producing modulation parameters.
/// `forward` returns both the projected parameters and the intermediate
/// embedding (before the final silu + linear).
struct AdaLnSingle {
    emb_timestep_linear1: QLinear,
    emb_timestep_linear2: QLinear,
    linear: QLinear,
}
impl AdaLnSingle {
    /// Loads the embedder MLP and final projection from `tensors` under `prefix`.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
    ) -> Result<Self, InferenceError> {
        let emb_timestep_linear1 = build_qlinear(
            tensors,
            &format!("{prefix}.emb.timestep_embedder.linear1"),
            quant,
        )?;
        let emb_timestep_linear2 = build_qlinear(
            tensors,
            &format!("{prefix}.emb.timestep_embedder.linear2"),
            quant,
        )?;
        let linear = build_qlinear(tensors, &format!("{prefix}.linear"), quant)?;
        Ok(Self {
            emb_timestep_linear1,
            emb_timestep_linear2,
            linear,
        })
    }

    /// Returns `(params, embedded)`: `embedded` is the linear2 output (before
    /// the final silu + projection), `params` the projected modulation vector.
    fn forward(&mut self, timestep: &Array) -> Result<(Array, Array), mlx_rs::error::Exception> {
        // Timesteps are scaled by 1000 before the sinusoidal embedding.
        let t_scaled = ops::multiply(timestep, &Array::from_f32(1000.0))?;
        let sinusoid = timestep_embedding_tensor(&t_scaled, 256)?;
        let h1 = nn::silu(&self.emb_timestep_linear1.forward(&sinusoid)?)?;
        let embedded = self.emb_timestep_linear2.forward(&h1)?;
        let params = self.linear.forward(&nn::silu(&embedded)?)?;
        Ok((params, embedded))
    }
}
/// Same layout as `AdaLnSingle`, but `forward` returns only the projected
/// parameters (no intermediate embedding).
struct AdaLnSingleWide {
    emb_timestep_linear1: QLinear,
    emb_timestep_linear2: QLinear,
    linear: QLinear,
}
impl AdaLnSingleWide {
    /// Loads the embedder MLP and final projection from `tensors` under `prefix`.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
    ) -> Result<Self, InferenceError> {
        let emb_timestep_linear1 = build_qlinear(
            tensors,
            &format!("{prefix}.emb.timestep_embedder.linear1"),
            quant,
        )?;
        let emb_timestep_linear2 = build_qlinear(
            tensors,
            &format!("{prefix}.emb.timestep_embedder.linear2"),
            quant,
        )?;
        let linear = build_qlinear(tensors, &format!("{prefix}.linear"), quant)?;
        Ok(Self {
            emb_timestep_linear1,
            emb_timestep_linear2,
            linear,
        })
    }

    /// sinusoid(1000 * t) -> linear1 -> silu -> linear2 -> silu -> linear.
    fn forward(&mut self, timestep: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let t_scaled = ops::multiply(timestep, &Array::from_f32(1000.0))?;
        let sinusoid = timestep_embedding_tensor(&t_scaled, 256)?;
        let h1 = nn::silu(&self.emb_timestep_linear1.forward(&sinusoid)?)?;
        let h2 = nn::silu(&self.emb_timestep_linear2.forward(&h1)?)?;
        self.linear.forward(&h2)
    }
}
/// Prompt-conditioning variant with the same layout and forward computation
/// as `AdaLnSingleWide`; kept as a separate type so its checkpoint prefixes
/// stay distinct.
struct PromptAdaLnSingle {
    emb_timestep_linear1: QLinear,
    emb_timestep_linear2: QLinear,
    linear: QLinear,
}
impl PromptAdaLnSingle {
    /// Loads the embedder MLP and final projection from `tensors` under `prefix`.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
    ) -> Result<Self, InferenceError> {
        let embedder = format!("{prefix}.emb.timestep_embedder");
        Ok(Self {
            emb_timestep_linear1: build_qlinear(tensors, &format!("{embedder}.linear1"), quant)?,
            emb_timestep_linear2: build_qlinear(tensors, &format!("{embedder}.linear2"), quant)?,
            linear: build_qlinear(tensors, &format!("{prefix}.linear"), quant)?,
        })
    }

    /// Same pipeline as `AdaLnSingleWide::forward`:
    /// sinusoid(1000 * t) -> linear1 -> silu -> linear2 -> silu -> linear.
    fn forward(&mut self, timestep: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let t_scaled = ops::multiply(timestep, &Array::from_f32(1000.0))?;
        let t_emb = timestep_embedding_tensor(&t_scaled, 256)?;
        let mut h = self.emb_timestep_linear1.forward(&t_emb)?;
        h = nn::silu(&h)?;
        h = self.emb_timestep_linear2.forward(&h)?;
        h = nn::silu(&h)?;
        self.linear.forward(&h)
    }
}
/// The full LTX denoising transformer: patchify projections in, a stack of
/// `LtxTransformerBlock`s, and final ada-LN + output projections for both the
/// video and (optional) audio streams.
struct LtxTransformer {
    // Video input/output projections (possibly quantized).
    patchify_proj: QLinear,
    proj_out: QLinear,
    // Per-block timestep conditioning and prompt conditioning heads.
    adaln_single: AdaLnSingle,
    prompt_adaln_single: PromptAdaLnSingle,
    // Final-layer [shift, scale] table for the video stream.
    scale_shift_table: Array,
    // Audio projections are kept dense (unquantized).
    audio_patchify_proj: nn::Linear,
    audio_proj_out: nn::Linear,
    audio_adaln_single: AdaLnSingle,
    audio_prompt_adaln_single: PromptAdaLnSingle,
    audio_scale_shift_table: Array,
    // Conditioning heads for the audio<->video cross-attention gates/params.
    av_ca_a2v_gate_adaln_single: AdaLnSingle,
    av_ca_v2a_gate_adaln_single: AdaLnSingle,
    av_ca_video_scale_shift_adaln_single: AdaLnSingleWide,
    av_ca_audio_scale_shift_adaln_single: AdaLnSingleWide,
    blocks: Vec<LtxTransformerBlock>,
}
impl LtxTransformer {
fn load(tensors: &HashMap<String, Array>, config: &LtxConfig) -> Result<Self, InferenceError> {
let quant = config.quant.as_ref();
let pfx = "transformer";
let patchify_proj = build_qlinear(tensors, &format!("{pfx}.patchify_proj"), quant)?;
let proj_out = build_qlinear(tensors, &format!("{pfx}.proj_out"), quant)?;
let adaln_single = AdaLnSingle::load(tensors, &format!("{pfx}.adaln_single"), quant)?;
let prompt_adaln_single =
PromptAdaLnSingle::load(tensors, &format!("{pfx}.prompt_adaln_single"), quant)?;
let scale_shift_table = get_tensor(tensors, &format!("{pfx}.scale_shift_table"))?;
let audio_patchify_proj =
build_dense_linear(tensors, &format!("{pfx}.audio_patchify_proj"))?;
let audio_proj_out = build_dense_linear(tensors, &format!("{pfx}.audio_proj_out"))?;
let audio_adaln_single =
AdaLnSingle::load(tensors, &format!("{pfx}.audio_adaln_single"), quant)?;
let audio_prompt_adaln_single =
PromptAdaLnSingle::load(tensors, &format!("{pfx}.audio_prompt_adaln_single"), quant)?;
let audio_scale_shift_table =
get_tensor(tensors, &format!("{pfx}.audio_scale_shift_table"))?;
let av_ca_a2v_gate_adaln_single = AdaLnSingle::load(
tensors,
&format!("{pfx}.av_ca_a2v_gate_adaln_single"),
quant,
)?;
let av_ca_v2a_gate_adaln_single = AdaLnSingle::load(
tensors,
&format!("{pfx}.av_ca_v2a_gate_adaln_single"),
quant,
)?;
let av_ca_video_scale_shift_adaln_single = AdaLnSingleWide::load(
tensors,
&format!("{pfx}.av_ca_video_scale_shift_adaln_single"),
quant,
)?;
let av_ca_audio_scale_shift_adaln_single = AdaLnSingleWide::load(
tensors,
&format!("{pfx}.av_ca_audio_scale_shift_adaln_single"),
quant,
)?;
let mut blocks = Vec::with_capacity(config.num_layers);
for i in 0..config.num_layers {
blocks.push(LtxTransformerBlock::load(
tensors,
&format!("{pfx}.transformer_blocks.{i}"),
config,
)?);
}
Ok(Self {
patchify_proj,
proj_out,
adaln_single,
prompt_adaln_single,
scale_shift_table,
audio_patchify_proj,
audio_proj_out,
audio_adaln_single,
audio_prompt_adaln_single,
audio_scale_shift_table,
av_ca_a2v_gate_adaln_single,
av_ca_v2a_gate_adaln_single,
av_ca_video_scale_shift_adaln_single,
av_ca_audio_scale_shift_adaln_single,
blocks,
})
}
fn forward(
&mut self,
latents: &Array,
text_embed: &Array,
timestep: &Array,
global_timestep: &Array,
audio_latents: Option<&Array>,
audio_text_embed: Option<&Array>,
rope: &RopeBundle,
) -> Result<(Array, Option<Array>), mlx_rs::error::Exception> {
let video_rope_freqs = rope.video_pair();
let audio_rope_freqs = rope.audio_pair();
let video_cross_rope_freqs = rope.video_cross_pair();
let audio_cross_rope_freqs = rope.audio_cross_pair();
let unsqueeze_final = |a: &Array| -> Result<Array, mlx_rs::error::Exception> {
let s = a.shape();
if s.len() == 1 {
ops::reshape(a, &[1, 1, s[0]])
} else if s.len() == 2 {
ops::reshape(a, &[s[0], 1, s[1]])
} else {
Ok(a.clone())
}
};
dump_ltx_stage_first_call("latents_input", latents);
let latents = latents.as_dtype(mlx_rs::Dtype::Bfloat16)?;
dump_ltx_stage_first_call("latents_bf16", &latents);
let text_embed = text_embed.as_dtype(mlx_rs::Dtype::Bfloat16)?;
let timestep = timestep.as_dtype(mlx_rs::Dtype::Bfloat16)?;
let global_timestep = global_timestep.as_dtype(mlx_rs::Dtype::Bfloat16)?;
let mut hidden = self.patchify_proj.forward(&latents)?;
dump_ltx_stage_first_call("patchify_out", &hidden);
let (timestep_emb, video_embedded_ts) = self.adaln_single.forward(×tep)?;
dump_ltx_stage_first_call("timestep_adaln_params", ×tep_emb);
dump_ltx_stage_first_call("timestep_embedded", &video_embedded_ts);
let audio_latents_bf = match audio_latents {
Some(a) => Some(a.as_dtype(mlx_rs::Dtype::Bfloat16)?),
None => None,
};
let audio_text_embed_bf = match audio_text_embed {
Some(a) => Some(a.as_dtype(mlx_rs::Dtype::Bfloat16)?),
None => None,
};
let mut audio_hidden = if let Some(aud_lat) = audio_latents_bf.as_ref() {
let ah = self.audio_patchify_proj.forward(aud_lat)?;
Some(ah)
} else {
None
};
let (audio_timestep_emb, audio_embedded_ts) = if audio_latents.is_some() {
let (p, e) = self.audio_adaln_single.forward(&global_timestep)?;
(Some(p), Some(e))
} else {
(None, None)
};
let video_prompt_emb = self.prompt_adaln_single.forward(&global_timestep)?;
let av_ca_video_emb = self
.av_ca_video_scale_shift_adaln_single
.forward(&global_timestep)?;
let (av_ca_a2v_gate_emb, _) = self.av_ca_a2v_gate_adaln_single.forward(&global_timestep)?;
let (audio_prompt_emb, av_ca_audio_emb, av_ca_v2a_gate_emb): (
Option<Array>,
Option<Array>,
Option<Array>,
) = if audio_latents.is_some() {
let ap = self.audio_prompt_adaln_single.forward(&global_timestep)?;
let av_aa = self
.av_ca_audio_scale_shift_adaln_single
.forward(&global_timestep)?;
let (av_v2a, _) = self.av_ca_v2a_gate_adaln_single.forward(&global_timestep)?;
(Some(ap), Some(av_aa), Some(av_v2a))
} else {
(None, None, None)
};
let total_blocks = self.blocks.len();
let mid_idx = total_blocks / 2; let last_idx = total_blocks - 1; for (i, block) in self.blocks.iter_mut().enumerate() {
let (vid_out, aud_out) = block.forward(
&hidden,
&text_embed,
×tep_emb,
audio_hidden.as_ref(),
audio_text_embed_bf.as_ref(),
audio_timestep_emb.as_ref(),
video_rope_freqs,
audio_rope_freqs,
video_cross_rope_freqs,
audio_cross_rope_freqs,
&video_prompt_emb,
audio_prompt_emb.as_ref(),
&av_ca_video_emb,
av_ca_audio_emb.as_ref(),
&av_ca_a2v_gate_emb,
av_ca_v2a_gate_emb.as_ref(),
)?;
hidden = vid_out;
audio_hidden = aud_out;
if i == 0 || i == 2 || i == 5 || i == 10 || i == 15 || i == 20 || i == 30 || i == 40 {
dump_ltx_stage_first_call(&format!("block{:02}_out", i), &hidden);
}
if i == 0 {
dump_ltx_stage_first_call("block00_out", &hidden);
} else if i == mid_idx {
dump_ltx_stage_first_call(&format!("block_mid{mid_idx:02}_out"), &hidden);
} else if i == last_idx {
dump_ltx_stage_first_call(&format!("block_last{last_idx:02}_out"), &hidden);
}
}
let compute_final_mod =
|table: &Array, embedded_ts: &Array| -> Result<Array, mlx_rs::error::Exception> {
let t_shape = embedded_ts.shape();
let dim = table.shape()[1];
let (b, n_or_1) = match t_shape.len() {
2 => (t_shape[0], 1),
3 => (t_shape[0], t_shape[1]),
other => {
return Err(mlx_rs::error::Exception::custom(format!(
"embedded_ts unexpected rank {other}"
)))
}
};
let embedded_exp = ops::reshape(embedded_ts, &[b, n_or_1, 1, dim])?;
let table_exp = ops::reshape(table, &[1, 1, 2, dim])?;
ops::add(&table_exp, &embedded_exp) };
let v_scale_shift = compute_final_mod(&self.scale_shift_table, &video_embedded_ts)?;
let v_shift = v_scale_shift.index((.., .., 0, ..));
let v_scale = v_scale_shift.index((.., .., 1, ..));
let v_ln = ltx_layer_norm_parameterless(&hidden, 1e-6)?;
let one = Array::from_f32(1.0);
let v_mod = ops::add(&ops::multiply(&v_ln, &ops::add(&one, &v_scale)?)?, &v_shift)?;
let video_out = self.proj_out.forward(&v_mod)?;
dump_ltx_stage_first_call("final_output_video", &video_out);
let audio_out = if let Some(aud_h) = audio_hidden {
let aud_emb = audio_embedded_ts.as_ref().ok_or_else(|| {
mlx_rs::error::Exception::custom(
"audio_embedded_ts must be set when audio_hidden is".to_string(),
)
})?;
let a_scale_shift = compute_final_mod(&self.audio_scale_shift_table, aud_emb)?;
let a_shift = a_scale_shift.index((.., .., 0, ..));
let a_scale = a_scale_shift.index((.., .., 1, ..));
let a_ln = ltx_layer_norm_parameterless(&aud_h, 1e-6)?;
let a_mod = ops::add(&ops::multiply(&a_ln, &ops::add(&one, &a_scale)?)?, &a_shift)?;
Some(self.audio_proj_out.forward(&a_mod)?)
} else {
None
};
Ok((video_out, audio_out))
}
}
/// One block of the connector: dense (unquantized) self-attention plus a
/// two-layer feed-forward, with optional Q/K RMS weights and an optional
/// per-head output gate (older checkpoints may lack both).
struct ConnectorBlock {
    num_heads: usize,
    head_dim: usize,
    attn_to_q: nn::Linear,
    attn_to_k: nn::Linear,
    attn_to_v: nn::Linear,
    attn_to_out: nn::Linear,
    // Optional learned RMS-norm weights for Q and K.
    attn_q_norm: Option<Array>,
    attn_k_norm: Option<Array>,
    // Optional gate projection; present only when the checkpoint has it.
    attn_to_gate_logits: Option<nn::Linear>,
    ff_proj_in: nn::Linear,
    ff_proj_out: nn::Linear,
}
impl ConnectorBlock {
/// Loads a connector block from checkpoint tensors under `prefix`.
/// The gate projection and the Q/K norm weights are treated as optional.
fn load(
    tensors: &HashMap<String, Array>,
    prefix: &str,
    num_heads: usize,
    head_dim: usize,
) -> Result<Self, InferenceError> {
    // Only build the gate projection when its weight exists in the checkpoint.
    let attn_to_gate_logits = match tensors.contains_key(&format!("{prefix}.attn1.to_gate_logits.weight")) {
        true => Some(build_dense_linear(
            tensors,
            &format!("{prefix}.attn1.to_gate_logits"),
        )?),
        false => None,
    };
    let attn_q_norm = tensors
        .get(&format!("{prefix}.attn1.q_norm.weight"))
        .cloned();
    let attn_k_norm = tensors
        .get(&format!("{prefix}.attn1.k_norm.weight"))
        .cloned();
    Ok(Self {
        num_heads,
        head_dim,
        attn_to_q: build_dense_linear(tensors, &format!("{prefix}.attn1.to_q"))?,
        attn_to_k: build_dense_linear(tensors, &format!("{prefix}.attn1.to_k"))?,
        attn_to_v: build_dense_linear(tensors, &format!("{prefix}.attn1.to_v"))?,
        attn_to_out: build_dense_linear(tensors, &format!("{prefix}.attn1.to_out.0"))?,
        attn_q_norm,
        attn_k_norm,
        attn_to_gate_logits,
        ff_proj_in: build_dense_linear(tensors, &format!("{prefix}.ff.net.0.proj"))?,
        ff_proj_out: build_dense_linear(tensors, &format!("{prefix}.ff.net.2"))?,
    })
}
fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let shape = x.shape();
let batch = shape[0];
let seq_len = shape[1];
let hdim = shape[2];
let nh = self.num_heads as i32;
let hd = self.head_dim as i32;
let normed = rms_norm_parameterless(x, 1e-6)?;
let mut q = self.attn_to_q.forward(&normed)?;
let mut k = self.attn_to_k.forward(&normed)?;
let v = self.attn_to_v.forward(&normed)?;
if let Some(ref w) = self.attn_q_norm {
q = apply_weighted_rms(&q, w, 1e-6)?;
}
if let Some(ref w) = self.attn_k_norm {
k = apply_weighted_rms(&k, w, 1e-6)?;
}
let reshape_head = |t: &Array| -> Result<Array, mlx_rs::error::Exception> {
let r = ops::reshape(t, &[batch, seq_len, nh, hd])?;
ops::transpose_axes(&r, &[0, 2, 1, 3])
};
let mut qh = reshape_head(&q)?;
let mut kh = reshape_head(&k)?;
let vh = reshape_head(&v)?;
let inner_dim = self.num_heads * self.head_dim;
let (cos_f, sin_f) =
precompute_split_rope(seq_len, self.num_heads, inner_dim, 10_000.0, 4096.0)?;
qh = apply_split_rope(&qh, &cos_f, &sin_f)?;
kh = apply_split_rope(&kh, &cos_f, &sin_f)?;
let scale = Array::from_f32(1.0 / (self.head_dim as f32).sqrt());
let scores = ops::multiply(
&ops::matmul(&qh, &ops::transpose_axes(&kh, &[0, 1, 3, 2])?)?,
&scale,
)?;
let attn = ops::softmax_axis(&scores, -1, None)?;
let attn_out = ops::matmul(&attn, &vh)?;
let attn_out = ops::transpose_axes(&attn_out, &[0, 2, 1, 3])?;
let gated = if let Some(ref mut gate) = self.attn_to_gate_logits {
let gate_logits = gate.forward(&normed)?; let two = Array::from_f32(2.0);
let gates = ops::multiply(&ops::sigmoid(&gate_logits)?, &two)?;
let gates = ops::expand_dims(&gates, -1)?; ops::multiply(&attn_out, &gates)?
} else {
attn_out
};
let merged = ops::reshape(&gated, &[batch, seq_len, hdim])?;
let attn_out = self.attn_to_out.forward(&merged)?;
let x_post_attn = ops::add(x, &attn_out)?;
let normed = rms_norm_parameterless(&x_post_attn, 1e-6)?;
let h = self.ff_proj_in.forward(&normed)?;
let activated = nn::gelu_approximate(&h)?;
let ff_out = self.ff_proj_out.forward(&activated)?;
ops::add(&x_post_attn, &ff_out)
}
}
/// Builds (cos, sin) tables for split-halves RoPE over a single position
/// axis, shaped `[1, num_heads, seq_len, inner_dim / (2 * num_heads)]`.
///
/// Positions `0..seq_len` are normalized to `p / max_pos * 2 - 1`;
/// frequency `i` is `theta^(i / (num_freqs - 1)) * pi/2`.
fn precompute_split_rope(
    seq_len: i32,
    num_heads: usize,
    inner_dim: usize,
    theta: f32,
    max_pos: f32,
) -> Result<(Array, Array), mlx_rs::error::Exception> {
    // One frequency per rotated pair: inner_dim / 2 in total.
    let num_freqs = inner_dim / 2;
    let freq_indices: Vec<f32> = (0..num_freqs)
        .map(|i| {
            // saturating_sub(1).max(1) guards num_freqs <= 1 against div-by-zero.
            let t = i as f32 / (num_freqs.saturating_sub(1).max(1) as f32);
            theta.powf(t) * std::f32::consts::FRAC_PI_2
        })
        .collect();
    let freq_indices = Array::from_slice(&freq_indices, &[num_freqs as i32]);
    let positions: Vec<f32> = (0..seq_len)
        .map(|p| (p as f32 / max_pos) * 2.0 - 1.0)
        .collect();
    let positions = Array::from_slice(&positions, &[1, seq_len, 1]);
    // Outer product: [1, seq, 1] * [num_freqs] -> [1, seq, num_freqs].
    let freqs = ops::multiply(&positions, &freq_indices)?;
    let cos_f = ops::cos(&freqs)?;
    let sin_f = ops::sin(&freqs)?;
    // Split the frequency axis across heads, then move heads before seq to
    // match the [b, heads, seq, hd] attention layout.
    let head_dim_half = (inner_dim / (2 * num_heads)) as i32;
    let cos_f = ops::reshape(&cos_f, &[1, seq_len, num_heads as i32, head_dim_half])?;
    let sin_f = ops::reshape(&sin_f, &[1, seq_len, num_heads as i32, head_dim_half])?;
    let cos_f = ops::transpose_axes(&cos_f, &[0, 2, 1, 3])?;
    let sin_f = ops::transpose_axes(&sin_f, &[0, 2, 1, 3])?;
    Ok((cos_f, sin_f))
}
/// Multi-axis split-RoPE tables: each token carries one position per axis
/// (`positions[d][tok]`, e.g. t/y/x for video), and frequencies for the
/// axes are interleaved along the feature dimension.
///
/// When `inner_dim / 2` is not evenly divisible by the axis count, the
/// first `pad` feature slots of each token keep frequency 0 (cos = 1,
/// sin = 0), i.e. an identity rotation. Output tables are shaped
/// `[1, num_heads, num_tokens, head_dim / 2]`.
fn precompute_split_rope_multiaxis(
    positions: &[Vec<f32>],
    num_heads: usize,
    head_dim: usize,
    theta: f32,
    max_pos: &[f32],
) -> Result<(Array, Array), mlx_rs::error::Exception> {
    let num_pos_dims = max_pos.len();
    assert!(
        positions.len() == num_pos_dims,
        "positions axis count mismatch"
    );
    let inner_dim = num_heads * head_dim;
    // Each axis consumes two features (a cos/sin pair) per frequency.
    let n_elem = 2 * num_pos_dims;
    let num_freqs = inner_dim / n_elem;
    let expected = inner_dim / 2;
    let covered = num_freqs * num_pos_dims;
    let pad = expected - covered;
    let freq_indices: Vec<f32> = (0..num_freqs)
        .map(|i| {
            // saturating_sub(1).max(1) guards num_freqs <= 1 against div-by-zero.
            let t = i as f32 / (num_freqs.saturating_sub(1).max(1) as f32);
            theta.powf(t) * std::f32::consts::FRAC_PI_2
        })
        .collect();
    let num_tokens = positions[0].len();
    // Per-token layout: [pad zeros | f0*axis0..axisD | f1*axis0..axisD | ...].
    let mut freqs_flat = vec![0.0f32; num_tokens * expected];
    for tok in 0..num_tokens {
        let base = tok * expected + pad;
        for f in 0..num_freqs {
            for d in 0..num_pos_dims {
                let p = positions[d][tok];
                // Normalize the position into [-1, 1) over its axis range.
                let frac = (p / max_pos[d]) * 2.0 - 1.0;
                freqs_flat[base + f * num_pos_dims + d] = frac * freq_indices[f];
            }
        }
    }
    let freqs = Array::from_slice(&freqs_flat, &[1, num_tokens as i32, expected as i32]);
    let cos_f = ops::cos(&freqs)?;
    let sin_f = ops::sin(&freqs)?;
    // Split features across heads, then move heads before tokens.
    let head_dim_half = (head_dim / 2) as i32;
    let cos_f = ops::reshape(
        &cos_f,
        &[1, num_tokens as i32, num_heads as i32, head_dim_half],
    )?;
    let sin_f = ops::reshape(
        &sin_f,
        &[1, num_tokens as i32, num_heads as i32, head_dim_half],
    )?;
    let cos_f = ops::transpose_axes(&cos_f, &[0, 2, 1, 3])?;
    let sin_f = ops::transpose_axes(&sin_f, &[0, 2, 1, 3])?;
    Ok((cos_f, sin_f))
}
/// Interleaved-pair RoPE tables (adjacent features form rotation pairs)
/// — unlike the split-halves layout used elsewhere in this file.
/// Currently unused (hence `#[allow(dead_code)]`).
///
/// Uncovered leading features get cos = 1 / sin = 0 (identity rotation).
/// Output shape is `[1, num_heads, num_tokens, head_dim]` — full width,
/// since these tables are per-feature rather than per-pair.
#[allow(dead_code)]
fn precompute_interleaved_rope(
    positions: &[Vec<f32>],
    num_heads: usize,
    head_dim: usize,
    theta: f32,
    max_pos: &[f32],
) -> Result<(Array, Array), mlx_rs::error::Exception> {
    let num_pos_dims = max_pos.len();
    assert!(
        positions.len() == num_pos_dims,
        "positions axis count mismatch"
    );
    let n_elem = 2 * num_pos_dims;
    let inner_dim = num_heads * head_dim;
    let num_freqs = inner_dim / n_elem;
    let freq_indices: Vec<f32> = (0..num_freqs)
        .map(|i| {
            // saturating_sub(1).max(1) guards num_freqs <= 1 against div-by-zero.
            let t = i as f32 / (num_freqs.saturating_sub(1).max(1) as f32);
            theta.powf(t) * std::f32::consts::FRAC_PI_2
        })
        .collect();
    let num_tokens = positions[0].len();
    // Per-token layout: [f0*axis0..axisD | f1*axis0..axisD | ...].
    let mut freqs_flat = vec![0.0f32; num_tokens * num_freqs * num_pos_dims];
    for tok in 0..num_tokens {
        for d in 0..num_pos_dims {
            let p = positions[d][tok];
            // Normalize the position into [-1, 1) over its axis range.
            let frac = (p / max_pos[d]) * 2.0 - 1.0;
            for f in 0..num_freqs {
                let out_idx = tok * (num_freqs * num_pos_dims) + f * num_pos_dims + d;
                freqs_flat[out_idx] = frac * freq_indices[f];
            }
        }
    }
    let freqs = Array::from_slice(
        &freqs_flat,
        &[1, num_tokens as i32, (num_freqs * num_pos_dims) as i32],
    );
    let cos_f = ops::cos(&freqs)?;
    let sin_f = ops::sin(&freqs)?;
    // Duplicate each frequency entry for the two members of its pair
    // (element-wise repeat along the last axis).
    let cos_f = ops::repeat_axis::<f32>(cos_f, 2, -1)?;
    let sin_f = ops::repeat_axis::<f32>(sin_f, 2, -1)?;
    // Front-pad uncovered features with the identity rotation.
    let covered = 2 * num_freqs * num_pos_dims;
    let pad = (inner_dim - covered) as i32;
    let cos_f = if pad > 0 {
        let ones = Array::ones::<f32>(&[1, num_tokens as i32, pad])?;
        ops::concatenate_axis(&[&ones, &cos_f], -1)?
    } else {
        cos_f
    };
    let sin_f = if pad > 0 {
        let zeros = Array::zeros::<f32>(&[1, num_tokens as i32, pad])?;
        ops::concatenate_axis(&[&zeros, &sin_f], -1)?
    } else {
        sin_f
    };
    // Split features across heads, then move heads before tokens.
    let cos_f = ops::reshape(
        &cos_f,
        &[1, num_tokens as i32, num_heads as i32, head_dim as i32],
    )?;
    let sin_f = ops::reshape(
        &sin_f,
        &[1, num_tokens as i32, num_heads as i32, head_dim as i32],
    )?;
    let cos_f = ops::transpose_axes(&cos_f, &[0, 2, 1, 3])?;
    let sin_f = ops::transpose_axes(&sin_f, &[0, 2, 1, 3])?;
    Ok((cos_f, sin_f))
}
/// Applies interleaved-pair RoPE: adjacent features (2i, 2i+1) form each
/// rotation pair. Expects full-width `cos_f`/`sin_f` tables such as those
/// built by `precompute_interleaved_rope`. Currently unused.
#[allow(dead_code)]
fn apply_interleaved_rope(
    x: &Array,
    cos_f: &Array,
    sin_f: &Array,
) -> Result<Array, mlx_rs::error::Exception> {
    let dims = x.shape();
    let feature_dim = dims[dims.len() - 1];
    // View the feature axis as (pairs, 2).
    let mut pair_dims: Vec<i32> = dims.to_vec();
    let last_slot = pair_dims.len() - 1;
    pair_dims[last_slot] = feature_dim / 2;
    pair_dims.push(2);
    let pairs = ops::reshape(x, &pair_dims)?;
    let even = pairs.index((.., .., .., .., 0..1));
    let odd = pairs.index((.., .., .., .., 1..2));
    // rotate_half on pairs: (a, b) -> (-b, a), flattened back to input layout.
    let negated_odd = ops::negative(&odd)?;
    let swapped = ops::concatenate_axis(&[&negated_odd, &even], -1)?;
    let rotated = ops::reshape(&swapped, dims)?;
    // x * cos + rotate_half(x) * sin.
    let direct = ops::multiply(x, cos_f)?;
    let quadrature = ops::multiply(&rotated, sin_f)?;
    ops::add(&direct, &quadrature)
}
/// Applies split-halves RoPE: the last axis is cut into low/high halves and
/// rotated as `(lo, hi) -> (lo*cos - hi*sin, lo*sin + hi*cos)`.
fn apply_split_rope(
    x: &Array,
    cos_f: &Array,
    sin_f: &Array,
) -> Result<Array, mlx_rs::error::Exception> {
    let dims = x.shape();
    let split_at = dims[dims.len() - 1] / 2;
    let lo = x.index((.., .., .., ..split_at));
    let hi = x.index((.., .., .., split_at..));
    let lo_cos = ops::multiply(&lo, cos_f)?;
    let hi_sin = ops::multiply(&hi, sin_f)?;
    let lo_sin = ops::multiply(&lo, sin_f)?;
    let hi_cos = ops::multiply(&hi, cos_f)?;
    let rotated_lo = ops::subtract(&lo_cos, &hi_sin)?;
    let rotated_hi = ops::add(&lo_sin, &hi_cos)?;
    ops::concatenate_axis(&[&rotated_lo, &rotated_hi], -1)
}
/// Latent frames cover this many pixel frames each (VAE temporal stride).
const VIDEO_TEMPORAL_SCALE: f32 = 8.0;
/// Latent cells cover this many pixels along each spatial axis.
const VIDEO_SPATIAL_SCALE: f32 = 32.0;
/// Audio latent tokens per mel frame (decoder downsample factor).
const AUDIO_DOWNSAMPLE_FACTOR: f32 = 4.0;
/// STFT hop length in samples.
const AUDIO_HOP_LENGTH: f32 = 160.0;
/// Audio sample rate in Hz.
const AUDIO_SAMPLE_RATE: f32 = 16000.0;

/// Per-token (time, row, column) centers for a video latent grid of
/// `latent_f x latent_h x latent_w` tokens, flattened in (f, h, w) order.
/// Time is in seconds (midpoint of the pixel-frame span, causally clamped
/// at 0); rows/columns are in pixels (cell centers).
fn build_video_positions_3d(
    latent_f: i32,
    latent_h: i32,
    latent_w: i32,
    fps: f32,
) -> [Vec<f32>; 3] {
    let (nf, nh, nw) = (latent_f as usize, latent_h as usize, latent_w as usize);
    // Midpoint (seconds) of the pixel-frame interval covered by latent frame i.
    let t_mid = |i: usize| -> f32 {
        let lo = (i as f32 * VIDEO_TEMPORAL_SCALE + 1.0 - VIDEO_TEMPORAL_SCALE).max(0.0);
        let hi = ((i as f32 + 1.0) * VIDEO_TEMPORAL_SCALE + 1.0 - VIDEO_TEMPORAL_SCALE).max(0.0);
        (lo + hi) / 2.0 / fps
    };
    // Pixel-space center of latent cell i along a spatial axis.
    let s_mid = |i: usize| -> f32 { i as f32 * VIDEO_SPATIAL_SCALE + VIDEO_SPATIAL_SCALE / 2.0 };
    let total = nf * nh * nw;
    let mut t_axis = Vec::with_capacity(total);
    let mut y_axis = Vec::with_capacity(total);
    let mut x_axis = Vec::with_capacity(total);
    for fi in 0..nf {
        let tv = t_mid(fi);
        for hi in 0..nh {
            let yv = s_mid(hi);
            for wi in 0..nw {
                t_axis.push(tv);
                y_axis.push(yv);
                x_axis.push(s_mid(wi));
            }
        }
    }
    [t_axis, y_axis, x_axis]
}
/// Center time (seconds) of each audio latent token: the midpoint of the
/// mel-frame span the token covers (causally clamped at 0), converted to
/// seconds via hop_length / sample_rate.
fn build_audio_positions_1d(num_tokens: i32) -> [Vec<f32>; 1] {
    // Mel-frame index -> seconds, clamping negative (pre-roll) frames to 0.
    let to_seconds = |frame: f32| frame.max(0.0) * AUDIO_HOP_LENGTH / AUDIO_SAMPLE_RATE;
    let centers: Vec<f32> = (0..num_tokens as usize)
        .map(|i| {
            let i_f = i as f32;
            let start = to_seconds(i_f * AUDIO_DOWNSAMPLE_FACTOR + 1.0 - AUDIO_DOWNSAMPLE_FACTOR);
            let end =
                to_seconds((i_f + 1.0) * AUDIO_DOWNSAMPLE_FACTOR + 1.0 - AUDIO_DOWNSAMPLE_FACTOR);
            (start + end) / 2.0
        })
        .collect();
    [centers]
}
/// Precomputed (cos, sin) RoPE tables for each attention flavour used by
/// the transformer. `None` entries mean the corresponding stream is absent
/// (e.g. no audio tokens this run).
#[derive(Default, Clone)]
pub struct RopeBundle {
    pub video: Option<(Array, Array)>,
    pub audio: Option<(Array, Array)>,
    pub video_cross: Option<(Array, Array)>,
    pub audio_cross: Option<(Array, Array)>,
}
impl RopeBundle {
    /// Precomputes every rotary table needed for one generation: video
    /// self-attention (3-D t/y/x positions), audio self-attention (1-D time
    /// positions, only when `audio_seq_len > 0`), and the audio/video
    /// cross-attention tables, which use a fixed 32-head x 64-dim layout
    /// over the temporal axis only.
    pub fn build(
        latent_f: i32,
        latent_h: i32,
        latent_w: i32,
        audio_seq_len: i32,
        fps: f32,
        video_num_heads: usize,
        video_head_dim: usize,
        audio_num_heads: usize,
        audio_head_dim: usize,
    ) -> Result<Self, mlx_rs::error::Exception> {
        const THETA: f32 = 10000.0;
        // The a/v cross-attention head layout is fixed, independent of the
        // main transformer configuration.
        const AV_CROSS_HEADS: usize = 32;
        const AV_CROSS_HEAD_DIM: usize = 64;
        const MAX_POS_3D: [f32; 3] = [20.0, 2048.0, 2048.0];
        const MAX_POS_1D: [f32; 1] = [20.0];
        let vid_positions = build_video_positions_3d(latent_f, latent_h, latent_w, fps);
        let video_tables = precompute_split_rope_multiaxis(
            &vid_positions,
            video_num_heads,
            video_head_dim,
            THETA,
            &MAX_POS_3D,
        )?;
        // Cross-attention rotates only along the video's temporal axis.
        let temporal_only = [vid_positions[0].clone()];
        let video_cross_tables = precompute_split_rope_multiaxis(
            &temporal_only,
            AV_CROSS_HEADS,
            AV_CROSS_HEAD_DIM,
            THETA,
            &MAX_POS_1D,
        )?;
        let mut audio_tables = None;
        let mut audio_cross_tables = None;
        if audio_seq_len > 0 {
            let aud_positions = build_audio_positions_1d(audio_seq_len);
            audio_tables = Some(precompute_split_rope_multiaxis(
                &aud_positions,
                audio_num_heads,
                audio_head_dim,
                THETA,
                &MAX_POS_1D,
            )?);
            audio_cross_tables = Some(precompute_split_rope_multiaxis(
                &aud_positions,
                AV_CROSS_HEADS,
                AV_CROSS_HEAD_DIM,
                THETA,
                &MAX_POS_1D,
            )?);
        }
        Ok(Self {
            video: Some(video_tables),
            audio: audio_tables,
            video_cross: Some(video_cross_tables),
            audio_cross: audio_cross_tables,
        })
    }

    /// Borrowed (cos, sin) for video self-attention, if present.
    fn video_pair(&self) -> Option<(&Array, &Array)> {
        match &self.video {
            Some((c, s)) => Some((c, s)),
            None => None,
        }
    }

    /// Borrowed (cos, sin) for audio self-attention, if present.
    fn audio_pair(&self) -> Option<(&Array, &Array)> {
        match &self.audio {
            Some((c, s)) => Some((c, s)),
            None => None,
        }
    }

    /// Borrowed (cos, sin) for video-side cross-attention, if present.
    fn video_cross_pair(&self) -> Option<(&Array, &Array)> {
        match &self.video_cross {
            Some((c, s)) => Some((c, s)),
            None => None,
        }
    }

    /// Borrowed (cos, sin) for audio-side cross-attention, if present.
    fn audio_cross_pair(&self) -> Option<(&Array, &Array)> {
        match &self.audio_cross {
            Some((c, s)) => Some((c, s)),
            None => None,
        }
    }
}
/// LayerNorm over the last axis with no learned weight/bias (gamma = 1,
/// beta = 0), via the fused MLX kernel.
fn ltx_layer_norm_parameterless(x: &Array, eps: f32) -> Result<Array, mlx_rs::error::Exception> {
    mlx_rs::fast::layer_norm(x, None, None, eps)
}
/// RMSNorm over the last axis with an all-ones (identity) weight.
///
/// `mlx_rs::fast::rms_norm` requires a weight array, so a ones vector per
/// distinct feature width is built once and cached per thread. Entries are
/// `Box::leak`ed to obtain the `'static` borrow the thread-local map hands
/// out; the leak is deliberate and bounded by the number of distinct
/// feature widths seen on each thread.
fn rms_norm_parameterless(x: &Array, eps: f32) -> Result<Array, mlx_rs::error::Exception> {
    let last = x.shape()[x.shape().len() - 1];
    use std::cell::RefCell;
    use std::collections::HashMap;
    thread_local! {
        static ONES_CACHE: RefCell<HashMap<i32, &'static Array>> = RefCell::new(HashMap::new());
    }
    let weight = ONES_CACHE.with(|cell| {
        // Fast path: reuse the cached ones vector for this width.
        if let Some(w) = cell.borrow().get(&last) {
            return *w;
        }
        let ones = Array::ones::<f32>(&[last]).expect("ones");
        let leaked: &'static Array = Box::leak(Box::new(ones));
        cell.borrow_mut().insert(last, leaked);
        leaked
    });
    mlx_rs::fast::rms_norm(x, weight, eps)
}
/// RMS-normalizes `x` (identity weight) and then scales the result
/// element-wise by `weight` (broadcast over the last axis).
fn apply_weighted_rms(
    x: &Array,
    weight: &Array,
    eps: f32,
) -> Result<Array, mlx_rs::error::Exception> {
    let normalized = rms_norm_parameterless(x, eps)?;
    ops::multiply(&normalized, weight)
}
/// Projects text-encoder embeddings into the video and audio conditioning
/// spaces: each branch has an aggregate linear projection, a stack of
/// `ConnectorBlock`s, and optional learnable register tokens.
struct TextEmbeddingConnector {
    video_aggregate_embed: nn::Linear,
    video_blocks: Vec<ConnectorBlock>,
    // Learned fallback tokens, tiled over positions where text is padding
    // (or over the whole sequence for an all-zero prompt).
    video_registers: Option<Array>,
    audio_aggregate_embed: nn::Linear,
    audio_blocks: Vec<ConnectorBlock>,
    audio_registers: Option<Array>,
}
impl TextEmbeddingConnector {
    /// Loads both connector branches from `tensors`: per-branch aggregate
    /// projection, eight transformer blocks, and optional learnable
    /// registers. Video blocks use 32 heads x 128 dims; audio blocks use
    /// 32 heads x 64 dims.
    fn load(tensors: &HashMap<String, Array>) -> Result<Self, InferenceError> {
        let pfx = "connector";
        let video_aggregate_embed = build_dense_linear(
            tensors,
            &format!("{pfx}.text_embedding_projection.video_aggregate_embed"),
        )?;
        let mut video_blocks = Vec::with_capacity(8);
        for i in 0..8 {
            video_blocks.push(ConnectorBlock::load(
                tensors,
                &format!("{pfx}.video_embeddings_connector.transformer_1d_blocks.{i}"),
                32,
                128,
            )?);
        }
        let video_registers = tensors
            .get(&format!(
                "{pfx}.video_embeddings_connector.learnable_registers"
            ))
            .cloned();
        let audio_aggregate_embed = build_dense_linear(
            tensors,
            &format!("{pfx}.text_embedding_projection.audio_aggregate_embed"),
        )?;
        let mut audio_blocks = Vec::with_capacity(8);
        for i in 0..8 {
            audio_blocks.push(ConnectorBlock::load(
                tensors,
                &format!("{pfx}.audio_embeddings_connector.transformer_1d_blocks.{i}"),
                32,
                64,
            )?);
        }
        let audio_registers = tensors
            .get(&format!(
                "{pfx}.audio_embeddings_connector.learnable_registers"
            ))
            .cloned();
        Ok(Self {
            video_aggregate_embed,
            video_blocks,
            video_registers,
            audio_aggregate_embed,
            audio_blocks,
            audio_registers,
        })
    }

    /// Produces the (video, audio) conditioning sequences for a batch of
    /// text embeddings. The source width 3840 is hard-coded below via
    /// `embedding_dim`.
    ///
    /// `n_valid` is the number of non-padding text tokens. The masking
    /// below assumes padding sits at the FRONT of the sequence (valid
    /// tokens are read from the tail of the projection) — TODO confirm
    /// against the tokenizer's padding side.
    fn forward(
        &mut self,
        text_embeddings: &Array,
        n_valid: usize,
    ) -> Result<(Array, Array), mlx_rs::error::Exception> {
        let batch = text_embeddings.shape()[0];
        // All-zero embeddings (empty prompt): use register tokens only.
        let use_registers_only = is_all_zero(text_embeddings);
        // Rescale so projected magnitude matches each branch's width.
        let embedding_dim = 3840_f32;
        let video_scale = Array::from_f32((4096.0_f32 / embedding_dim).sqrt());
        let audio_scale = Array::from_f32((2048.0_f32 / embedding_dim).sqrt());
        // Shared pass for both branches; only weights/registers/scale differ.
        let mut run_pass_through = |aggregate: &mut nn::Linear,
                                    blocks: &mut [ConnectorBlock],
                                    regs: Option<&Array>,
                                    scale: &Array,
                                    dump_tag: &str|
         -> Result<Array, mlx_rs::error::Exception> {
            let projected = if use_registers_only {
                match regs {
                    Some(regs) => {
                        // Tile the registers up to the full sequence length.
                        let rs = regs.shape();
                        let num_registers = rs[0];
                        let dim = rs[1];
                        let r3 = ops::reshape(regs, &[1, num_registers, dim])?;
                        let seq_len = text_embeddings.shape()[1];
                        let tiles = (seq_len + num_registers - 1) / num_registers;
                        let tiled = ops::tile(&r3, &[batch, tiles, 1])?;
                        tiled.index((.., 0..seq_len, ..))
                    }
                    None => {
                        // No registers to fall back on: project anyway.
                        let rescaled = ops::multiply(text_embeddings, scale)?;
                        aggregate.forward(&rescaled)?
                    }
                }
            } else {
                let rescaled = ops::multiply(text_embeddings, scale)?;
                let proj = aggregate.forward(&rescaled)?;
                dump_ltx_stage_first_call(&format!("connector_proj_{dump_tag}"), &proj);
                match regs {
                    Some(regs) => {
                        let proj_shape = proj.shape();
                        let seq_len = proj_shape[1];
                        let dim = proj_shape[2];
                        let rs = regs.shape();
                        let num_registers = rs[0];
                        let r3 = ops::reshape(regs, &[1, num_registers, dim])?;
                        let tiles = (seq_len + num_registers - 1) / num_registers;
                        let tiled = ops::tile(&r3, &[batch, tiles, 1])?;
                        let tiled = tiled.index((.., 0..seq_len, ..));
                        let n_valid_i32 = n_valid.min(seq_len as usize) as i32;
                        let pad_count = seq_len - n_valid_i32;
                        // Move the valid (tail) tokens to the front and pad
                        // the remainder with zeros (masked out below).
                        let adjusted = if n_valid_i32 == 0 {
                            tiled.clone()
                        } else {
                            let valid = proj.index((.., pad_count..seq_len, ..));
                            if n_valid_i32 == seq_len {
                                valid
                            } else {
                                let zeros =
                                    Array::zeros::<f32>(&[batch, seq_len - n_valid_i32, dim])?;
                                ops::concatenate_axis(&[&valid, &zeros], 1)?
                            }
                        };
                        // Blend: text tokens where i < n_valid, registers
                        // elsewhere.
                        let mask_vals: Vec<f32> = (0..seq_len)
                            .map(|i| if i < n_valid_i32 { 1.0 } else { 0.0 })
                            .collect();
                        let flipped = Array::from_slice(&mask_vals, &[1, seq_len, 1]);
                        let inv_flipped = ops::subtract(&Array::from_f32(1.0), &flipped)?;
                        let a_part = ops::multiply(&flipped, &adjusted)?;
                        let b_part = ops::multiply(&inv_flipped, &tiled)?;
                        ops::add(&a_part, &b_part)?
                    }
                    None => proj,
                }
            };
            // Transformer stack followed by a final parameterless RMS norm.
            let mut h = projected;
            for block in blocks.iter_mut() {
                h = block.forward(&h)?;
            }
            let out = rms_norm_parameterless(&h, 1e-6)?;
            dump_ltx_stage_first_call(&format!("connector_out_{dump_tag}"), &out);
            Ok(out)
        };
        let video_h = run_pass_through(
            &mut self.video_aggregate_embed,
            &mut self.video_blocks,
            self.video_registers.as_ref(),
            &video_scale,
            "video",
        )?;
        let audio_h = run_pass_through(
            &mut self.audio_aggregate_embed,
            &mut self.audio_blocks,
            self.audio_registers.as_ref(),
            &audio_scale,
            "audio",
        )?;
        Ok((video_h, audio_h))
    }
}
/// Best-effort check that every element of `x` is exactly zero: the mean of
/// `|x|` is zero iff all elements are. Any op or eval failure along the way
/// conservatively reports `false`.
fn is_all_zero(x: &Array) -> bool {
    let Ok(magnitudes) = ops::abs(x) else {
        return false;
    };
    let rank = magnitudes.shape().len();
    // Collapse one leading axis at a time until a scalar remains.
    let mut acc = magnitudes;
    for _ in 0..rank {
        match acc.mean_axes(&[0], false) {
            Ok(next) => acc = next,
            Err(_) => return false,
        }
    }
    if mlx_rs::transforms::eval([&acc]).is_err() {
        return false;
    }
    let vals: &[f32] = acc.as_slice();
    matches!(vals.first(), Some(&v) if v == 0.0)
}
/// Temporally-causal 3-D convolution, implemented as a sum of 2-D
/// convolutions over the kernel's temporal taps.
///
/// Expects channels-last input `[B, T, H, W, C_in]` and weights shaped
/// `[C_out, KT, KH, KW, C_in]`. Causality: the input is left-padded with
/// `KT - 1` copies of the first frame, so output frame `t` only depends on
/// input frames `<= t`. Spatially uses "same" padding (`KH/2`, `KW/2`)
/// with stride 1.
fn conv3d_causal(
    x: &Array,
    weight: &Array,
    bias: Option<&Array>,
) -> Result<Array, mlx_rs::error::Exception> {
    let x_shape = x.shape();
    let (batch, t_in, h_in, w_in, _c_in) =
        (x_shape[0], x_shape[1], x_shape[2], x_shape[3], x_shape[4]);
    let w_shape = weight.shape();
    let (c_out, kt, kh, kw, c_in_w) = (w_shape[0], w_shape[1], w_shape[2], w_shape[3], w_shape[4]);
    // Replicate the first frame kt-1 times on the left (causal padding).
    let pad_t = kt - 1;
    let x_padded = if pad_t > 0 {
        let first = x.index((.., 0..1, .., .., ..));
        let first_rep = ops::repeat_axis::<f32>(first, pad_t, 1)?;
        ops::concatenate_axis(&[&first_rep, x], 1)?
    } else {
        x.clone()
    };
    let pad_h = kh / 2;
    let pad_w = kw / 2;
    // For each temporal tap: convolve the corresponding shifted frame
    // window with that tap's 2-D kernel slice and accumulate.
    let mut accum: Option<Array> = None;
    for tk in 0..kt {
        let w_slice = weight.index((.., tk, .., .., ..));
        let frames_slice = x_padded.index((.., tk..(tk + t_in), .., .., ..));
        // Fold (batch, time) together so each tap is one conv2d call.
        let bt = batch * t_in;
        let flat = ops::reshape(&frames_slice, &[bt, h_in, w_in, c_in_w])?;
        let conv_out = ops::conv2d(
            &flat,
            &w_slice,
            (1, 1),
            (pad_h, pad_w),
            None::<(i32, i32)>,
            None::<i32>,
        )?;
        accum = Some(match accum {
            Some(a) => ops::add(&a, &conv_out)?,
            None => conv_out,
        });
    }
    // accum is always Some here: the loop runs at least once for kt >= 1.
    let mut result = accum.unwrap();
    if let Some(b) = bias {
        // Broadcast the per-output-channel bias over batch*time, H, W.
        let b_reshaped = ops::reshape(b, &[1, 1, 1, c_out])?;
        result = ops::add(&result, &b_reshaped)?;
    }
    // Unfold (batch, time) back into separate axes.
    let h_out = result.shape()[1];
    let w_out = result.shape()[2];
    ops::reshape(&result, &[batch, t_in, h_out, w_out, c_out])
}
/// Per-position channel normalization used by the video VAE: parameterless
/// RMS norm over the last (channel) axis with a tight epsilon.
fn pixel_norm(x: &Array) -> Result<Array, mlx_rs::error::Exception> {
    rms_norm_parameterless(x, 1e-8)
}
/// 3-D depth-to-space: moves `ft*fh*fw` channel groups into the T/H/W axes.
/// Input `[B, T, H, W, C]` -> output `[B, T*ft, H*fh, W*fw, C/(ft*fh*fw)]`.
fn pixel_shuffle_3d(
    x: &Array,
    factor_t: i32,
    factor_h: i32,
    factor_w: i32,
) -> Result<Array, mlx_rs::error::Exception> {
    let s = x.shape();
    let (b, t, h, w, c_total) = (s[0], s[1], s[2], s[3], s[4]);
    let factor = factor_t * factor_h * factor_w;
    let c_out = c_total / factor;
    // Split channels into (c_out, ft, fh, fw), then interleave each factor
    // next to its spatial/temporal axis before flattening.
    let x = ops::reshape(x, &[b, t, h, w, c_out, factor_t, factor_h, factor_w])?;
    let x = ops::transpose_axes(&x, &[0, 1, 5, 2, 6, 3, 7, 4])?;
    ops::reshape(&x, &[b, t * factor_t, h * factor_h, w * factor_w, c_out])
}
/// 3-D space-to-depth, the inverse of `pixel_shuffle_3d`: folds `ft*fh*fw`
/// blocks of T/H/W into channels. Input `[B, T, H, W, C]` -> output
/// `[B, T/ft, H/fh, W/fw, C*ft*fh*fw]`.
fn pixel_unshuffle_3d(
    x: &Array,
    factor_t: i32,
    factor_h: i32,
    factor_w: i32,
) -> Result<Array, mlx_rs::error::Exception> {
    let s = x.shape();
    let (b, t_big, h_big, w_big, c) = (s[0], s[1], s[2], s[3], s[4]);
    let t = t_big / factor_t;
    let h = h_big / factor_h;
    let w = w_big / factor_w;
    // Split each axis into (coarse, factor), then gather factors at the end
    // so they flatten into the channel axis.
    let x = ops::reshape(x, &[b, t, factor_t, h, factor_h, w, factor_w, c])?;
    let x = ops::transpose_axes(&x, &[0, 1, 3, 5, 7, 2, 4, 6])?;
    ops::reshape(&x, &[b, t, h, w, c * factor_t * factor_h * factor_w])
}
/// Residual block of the video VAE: two causal 3-D convs with
/// pixel-norm + SiLU pre-activations and an identity skip connection.
struct VaeResNet3dBlock {
    conv1_weight: Array,
    conv1_bias: Option<Array>,
    conv2_weight: Array,
    conv2_bias: Option<Array>,
}
impl VaeResNet3dBlock {
    /// Reads the block's two causal-conv weights (biases optional) from
    /// `tensors` under `prefix`.
    fn load(tensors: &HashMap<String, Array>, prefix: &str) -> Result<Self, InferenceError> {
        let key = |suffix: &str| format!("{prefix}.{suffix}");
        Ok(Self {
            conv1_weight: get_tensor(tensors, &key("conv1.conv.weight"))?,
            conv1_bias: tensors.get(&key("conv1.conv.bias")).cloned(),
            conv2_weight: get_tensor(tensors, &key("conv2.conv.weight"))?,
            conv2_bias: tensors.get(&key("conv2.conv.bias")).cloned(),
        })
    }

    /// Residual unit: `x + conv2(silu(norm(conv1(silu(norm(x))))))`.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let mut branch = pixel_norm(x)?;
        branch = nn::silu(&branch)?;
        branch = conv3d_causal(&branch, &self.conv1_weight, self.conv1_bias.as_ref())?;
        branch = pixel_norm(&branch)?;
        branch = nn::silu(&branch)?;
        branch = conv3d_causal(&branch, &self.conv2_weight, self.conv2_bias.as_ref())?;
        ops::add(x, &branch)
    }
}
/// VAE decoder upsample stage: a causal conv followed by a 3-D
/// pixel-shuffle with the given (t, h, w) expansion factors.
struct VaeUpsampleBlock {
    conv_weight: Array,
    conv_bias: Option<Array>,
    shuffle_factors: (i32, i32, i32),
}
impl VaeUpsampleBlock {
    /// Reads the projection conv under `prefix`; `shuffle_factors` is the
    /// (t, h, w) pixel-shuffle expansion applied after it.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        shuffle_factors: (i32, i32, i32),
    ) -> Result<Self, InferenceError> {
        let conv_weight = get_tensor(tensors, &format!("{prefix}.conv.conv.weight"))?;
        let conv_bias = tensors.get(&format!("{prefix}.conv.conv.bias")).cloned();
        Ok(Self {
            conv_weight,
            conv_bias,
            shuffle_factors,
        })
    }

    /// Conv then 3-D pixel-shuffle. When upsampling time (`ft > 1`) the
    /// first output frame is dropped to keep the causal frame alignment.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let projected = conv3d_causal(x, &self.conv_weight, self.conv_bias.as_ref())?;
        let (ft, fh, fw) = self.shuffle_factors;
        let expanded = pixel_shuffle_3d(&projected, ft, fh, fw)?;
        if ft <= 1 {
            return Ok(expanded);
        }
        let dims = expanded.shape();
        // `as_dtype` with the unchanged dtype only wraps the indexed view in
        // a Result; it performs no conversion.
        expanded
            .index((.., 1..dims[1], .., .., ..))
            .as_dtype(expanded.dtype())
    }
}
/// Causal 3-D video VAE decoder: conv_in, alternating resnet stacks and
/// pixel-shuffle upsample stages (`up_blocks` 0..=8), conv_out, and a
/// final 4x spatial unpatchify to RGB (see `decode`).
struct VaeDecoder3D {
    conv_in_weight: Array,
    conv_in_bias: Option<Array>,
    conv_out_weight: Array,
    conv_out_bias: Option<Array>,
    // Per-channel latent statistics used to de-normalize incoming latents.
    per_channel_mean: Option<Array>,
    per_channel_std: Option<Array>,
    res_blocks_0: Vec<VaeResNet3dBlock>,
    upsample_1: VaeUpsampleBlock,
    res_blocks_2: Vec<VaeResNet3dBlock>,
    upsample_3: VaeUpsampleBlock,
    res_blocks_4: Vec<VaeResNet3dBlock>,
    upsample_5: VaeUpsampleBlock,
    res_blocks_6: Vec<VaeResNet3dBlock>,
    upsample_7: VaeUpsampleBlock,
    res_blocks_8: Vec<VaeResNet3dBlock>,
}
impl VaeDecoder3D {
    /// Loads the decoder weights: conv_in/out, optional per-channel latent
    /// statistics, and the alternating resnet/upsample trunk. Upsample
    /// factors are (t, h, w): 2x2x2, 2x2x2, 2x1x1, 1x2x2.
    fn load(tensors: &HashMap<String, Array>) -> Result<Self, InferenceError> {
        let pfx = "vae_decoder";
        let conv_in_weight = get_tensor(tensors, &format!("{pfx}.conv_in.conv.weight"))?;
        let conv_in_bias = tensors.get(&format!("{pfx}.conv_in.conv.bias")).cloned();
        let conv_out_weight = get_tensor(tensors, &format!("{pfx}.conv_out.conv.weight"))?;
        let conv_out_bias = tensors.get(&format!("{pfx}.conv_out.conv.bias")).cloned();
        let per_channel_mean = tensors
            .get(&format!("{pfx}.per_channel_statistics.mean"))
            .cloned();
        let per_channel_std = tensors
            .get(&format!("{pfx}.per_channel_statistics.std"))
            .cloned();
        // Helper: read `count` resnet blocks under up_blocks.{block_idx}.
        let load_res_blocks =
            |block_idx: usize, count: usize| -> Result<Vec<VaeResNet3dBlock>, InferenceError> {
                let mut blocks = Vec::with_capacity(count);
                for i in 0..count {
                    blocks.push(VaeResNet3dBlock::load(
                        tensors,
                        &format!("{pfx}.up_blocks.{block_idx}.res_blocks.{i}"),
                    )?);
                }
                Ok(blocks)
            };
        let res_blocks_0 = load_res_blocks(0, 2)?;
        let upsample_1 = VaeUpsampleBlock::load(tensors, &format!("{pfx}.up_blocks.1"), (2, 2, 2))?;
        let res_blocks_2 = load_res_blocks(2, 2)?;
        let upsample_3 = VaeUpsampleBlock::load(tensors, &format!("{pfx}.up_blocks.3"), (2, 2, 2))?;
        let res_blocks_4 = load_res_blocks(4, 4)?;
        let upsample_5 = VaeUpsampleBlock::load(tensors, &format!("{pfx}.up_blocks.5"), (2, 1, 1))?;
        let res_blocks_6 = load_res_blocks(6, 6)?;
        let upsample_7 = VaeUpsampleBlock::load(tensors, &format!("{pfx}.up_blocks.7"), (1, 2, 2))?;
        let res_blocks_8 = load_res_blocks(8, 4)?;
        Ok(Self {
            conv_in_weight,
            conv_in_bias,
            conv_out_weight,
            conv_out_bias,
            per_channel_mean,
            per_channel_std,
            res_blocks_0,
            upsample_1,
            res_blocks_2,
            upsample_3,
            res_blocks_4,
            upsample_5,
            res_blocks_6,
            upsample_7,
            res_blocks_8,
        })
    }

    /// Decodes latents `[B, C, T, H, W]` to RGB frames in [0, 1], shaped
    /// `[B, T', H', W', 3]` channels-last.
    fn decode(&self, latents: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let shape = latents.shape();
        let batch = shape[0];
        // Channel-first -> channels-last for the conv stack.
        let x = ops::transpose_axes(latents, &[0, 2, 3, 4, 1])?;
        // De-normalize with the checkpoint's per-channel stats when present.
        let x = if let (Some(ref mean), Some(ref std_)) =
            (&self.per_channel_mean, &self.per_channel_std)
        {
            let mean = ops::reshape(mean, &[1, 1, 1, 1, mean.shape()[0]])?;
            let std_ = ops::reshape(std_, &[1, 1, 1, 1, std_.shape()[0]])?;
            ops::add(&ops::multiply(&x, &std_)?, &mean)?
        } else {
            x
        };
        // Trunk: conv_in, then resnet stacks alternating with upsamples.
        let mut h = conv3d_causal(&x, &self.conv_in_weight, self.conv_in_bias.as_ref())?;
        for block in &self.res_blocks_0 {
            h = block.forward(&h)?;
        }
        h = self.upsample_1.forward(&h)?;
        for block in &self.res_blocks_2 {
            h = block.forward(&h)?;
        }
        h = self.upsample_3.forward(&h)?;
        for block in &self.res_blocks_4 {
            h = block.forward(&h)?;
        }
        h = self.upsample_5.forward(&h)?;
        for block in &self.res_blocks_6 {
            h = block.forward(&h)?;
        }
        h = self.upsample_7.forward(&h)?;
        for block in &self.res_blocks_8 {
            h = block.forward(&h)?;
        }
        let h = pixel_norm(&h)?;
        let h = nn::silu(&h)?;
        let h = conv3d_causal(&h, &self.conv_out_weight, self.conv_out_bias.as_ref())?;
        let h_shape = h.shape();
        let (t_up, h_up, w_up) = (h_shape[1], h_shape[2], h_shape[3]);
        // Final unpatchify: conv_out emits 3 * 4 * 4 channels per cell;
        // expand each cell into a 4x4 RGB patch.
        let ps: i32 = 4;
        let out = ops::reshape(&h, &[batch, t_up, h_up, w_up, 3, ps, ps])?;
        let out = ops::transpose_axes(&out, &[0, 1, 2, 6, 3, 5, 4])?;
        let out = ops::reshape(&out, &[batch, t_up, h_up * ps, w_up * ps, 3])?;
        // Map from [-1, 1] to [0, 1] and clamp.
        let half = Array::from_f32(0.5);
        let one_a = Array::from_f32(1.0);
        let scaled = ops::multiply(&ops::add(&out, &one_a)?, &half)?;
        let zero = Array::from_f32(0.0);
        let out = ops::clip(&scaled, (&zero, &one_a))?;
        Ok(out)
    }
}
/// VAE encoder downsample stage: a causal conv followed by a 3-D
/// pixel-unshuffle with the given (t, h, w) reduction factors.
struct VaeDownsampleBlock {
    conv_weight: Array,
    conv_bias: Option<Array>,
    unshuffle_factors: (i32, i32, i32),
}
impl VaeDownsampleBlock {
    /// Reads the projection conv under `prefix`; `unshuffle_factors` is the
    /// (t, h, w) space-to-depth reduction applied after it.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        unshuffle_factors: (i32, i32, i32),
    ) -> Result<Self, InferenceError> {
        let weight_key = format!("{prefix}.conv.conv.weight");
        let bias_key = format!("{prefix}.conv.conv.bias");
        Ok(Self {
            conv_weight: get_tensor(tensors, &weight_key)?,
            conv_bias: tensors.get(&bias_key).cloned(),
            unshuffle_factors,
        })
    }

    /// Conv then 3-D pixel-unshuffle (time/space folded into channels).
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let (ft, fh, fw) = self.unshuffle_factors;
        let projected = conv3d_causal(x, &self.conv_weight, self.conv_bias.as_ref())?;
        pixel_unshuffle_3d(&projected, ft, fh, fw)
    }
}
/// Causal 3-D video VAE encoder: conv_in, alternating resnet stacks and
/// pixel-unshuffle downsample stages (`down_blocks` 0..=8), and conv_out
/// (see `encode`).
struct VaeEncoder3D {
    conv_in_weight: Array,
    conv_in_bias: Option<Array>,
    conv_out_weight: Array,
    conv_out_bias: Option<Array>,
    // Statistics used to normalize the produced latents.
    mean_of_means: Option<Array>,
    std_of_means: Option<Array>,
    res_blocks_0: Vec<VaeResNet3dBlock>,
    downsample_1: VaeDownsampleBlock,
    res_blocks_2: Vec<VaeResNet3dBlock>,
    downsample_3: VaeDownsampleBlock,
    res_blocks_4: Vec<VaeResNet3dBlock>,
    downsample_5: VaeDownsampleBlock,
    res_blocks_6: Vec<VaeResNet3dBlock>,
    downsample_7: VaeDownsampleBlock,
    res_blocks_8: Vec<VaeResNet3dBlock>,
}
impl VaeEncoder3D {
    /// Loads the encoder weights: conv_in/out, optional latent statistics,
    /// and the alternating resnet/downsample trunk. Downsample factors are
    /// (t, h, w): 1x2x2, 2x1x1, 2x2x2, 2x2x2 — i.e. 8x temporal and
    /// (with the 4x4 patchify in `encode`) 32x spatial reduction overall.
    fn load(tensors: &HashMap<String, Array>) -> Result<Self, InferenceError> {
        let pfx = "vae_encoder";
        let conv_in_weight = get_tensor(tensors, &format!("{pfx}.conv_in.conv.weight"))?;
        let conv_in_bias = tensors.get(&format!("{pfx}.conv_in.conv.bias")).cloned();
        let conv_out_weight = get_tensor(tensors, &format!("{pfx}.conv_out.conv.weight"))?;
        let conv_out_bias = tensors.get(&format!("{pfx}.conv_out.conv.bias")).cloned();
        let mean_of_means = tensors
            .get(&format!("{pfx}.per_channel_statistics._mean_of_means"))
            .cloned();
        let std_of_means = tensors
            .get(&format!("{pfx}.per_channel_statistics._std_of_means"))
            .cloned();
        // Helper: read `count` resnet blocks under down_blocks.{block_idx}.
        let load_res_blocks =
            |block_idx: usize, count: usize| -> Result<Vec<VaeResNet3dBlock>, InferenceError> {
                let mut blocks = Vec::with_capacity(count);
                for i in 0..count {
                    blocks.push(VaeResNet3dBlock::load(
                        tensors,
                        &format!("{pfx}.down_blocks.{block_idx}.res_blocks.{i}"),
                    )?);
                }
                Ok(blocks)
            };
        let res_blocks_0 = load_res_blocks(0, 4)?;
        let downsample_1 =
            VaeDownsampleBlock::load(tensors, &format!("{pfx}.down_blocks.1"), (1, 2, 2))?;
        let res_blocks_2 = load_res_blocks(2, 6)?;
        let downsample_3 =
            VaeDownsampleBlock::load(tensors, &format!("{pfx}.down_blocks.3"), (2, 1, 1))?;
        let res_blocks_4 = load_res_blocks(4, 4)?;
        let downsample_5 =
            VaeDownsampleBlock::load(tensors, &format!("{pfx}.down_blocks.5"), (2, 2, 2))?;
        let res_blocks_6 = load_res_blocks(6, 2)?;
        let downsample_7 =
            VaeDownsampleBlock::load(tensors, &format!("{pfx}.down_blocks.7"), (2, 2, 2))?;
        let res_blocks_8 = load_res_blocks(8, 2)?;
        Ok(Self {
            conv_in_weight,
            conv_in_bias,
            conv_out_weight,
            conv_out_bias,
            mean_of_means,
            std_of_means,
            res_blocks_0,
            downsample_1,
            res_blocks_2,
            downsample_3,
            res_blocks_4,
            downsample_5,
            res_blocks_6,
            downsample_7,
            res_blocks_8,
        })
    }

    /// Encodes channels-last RGB `[B, T, H, W, 3]` into normalized latents
    /// `[B, C, T', H', W']` (channel-first).
    fn encode(&self, rgb: &Array) -> Result<Array, mlx_rs::error::Exception> {
        // 4x4 spatial patchify: fold pixels into channels before conv_in.
        let packed = pixel_unshuffle_3d(rgb, 1, 4, 4)?;
        let mut h = conv3d_causal(&packed, &self.conv_in_weight, self.conv_in_bias.as_ref())?;
        for block in &self.res_blocks_0 {
            h = block.forward(&h)?;
        }
        h = self.downsample_1.forward(&h)?;
        for block in &self.res_blocks_2 {
            h = block.forward(&h)?;
        }
        h = self.downsample_3.forward(&h)?;
        for block in &self.res_blocks_4 {
            h = block.forward(&h)?;
        }
        h = self.downsample_5.forward(&h)?;
        for block in &self.res_blocks_6 {
            h = block.forward(&h)?;
        }
        h = self.downsample_7.forward(&h)?;
        for block in &self.res_blocks_8 {
            h = block.forward(&h)?;
        }
        let h = pixel_norm(&h)?;
        let h = nn::silu(&h)?;
        let h = conv3d_causal(&h, &self.conv_out_weight, self.conv_out_bias.as_ref())?;
        // conv_out emits more than 128 channels; only the first 128 form the
        // latent (the rest are presumably a variance head, unused here —
        // confirm against the reference implementation).
        let latent = h.index((.., .., .., .., ..128));
        // Normalize with the checkpoint statistics when available.
        let latent = if let (Some(m), Some(s)) = (&self.mean_of_means, &self.std_of_means) {
            let mean = ops::reshape(m, &[1, 1, 1, 1, -1])?;
            let std = ops::reshape(s, &[1, 1, 1, 1, -1])?;
            ops::divide(&ops::subtract(&latent, &mean)?, &std)?
        } else {
            latent
        };
        // Channels-last -> channel-first [B, C, T, H, W].
        ops::transpose_axes(&latent, &[0, 4, 1, 2, 3])
    }
}
/// 2-D convolution (channels-last) with an optional broadcast bias add.
fn audio_conv2d_forward(
    input: &Array,
    weight: &Array,
    bias: Option<&Array>,
    stride: (i32, i32),
    padding: (i32, i32),
) -> Result<Array, mlx_rs::error::Exception> {
    let out = ops::conv2d(
        input,
        weight,
        stride,
        padding,
        None::<(i32, i32)>,
        None::<i32>,
    )?;
    match bias {
        Some(b) => ops::add(&out, b),
        None => Ok(out),
    }
}
/// Nearest-neighbour 2x upsample of a channels-last `[B, H, W, C]` map,
/// done by duplicating every row and then every column via reshape+concat.
fn audio_upsample_2x(x: &Array) -> Result<Array, mlx_rs::error::Exception> {
    let dims = x.shape();
    let (batch, height, width, chans) = (dims[0], dims[1], dims[2], dims[3]);
    // Rows: [B,H,1,W,C] -> [B,H,2,W,C] -> [B,2H,W,C].
    let rows = ops::reshape(x, &[batch, height, 1, width, chans])?;
    let rows_doubled = ops::concatenate_axis(&[&rows, &rows], 2)?;
    let tall = ops::reshape(&rows_doubled, &[batch, height * 2, width, chans])?;
    // Columns: [B,2H,W,1,C] -> [B,2H,W,2,C] -> [B,2H,2W,C].
    let cols = ops::reshape(&tall, &[batch, height * 2, width, 1, chans])?;
    let cols_doubled = ops::concatenate_axis(&[&cols, &cols], 3)?;
    ops::reshape(&cols_doubled, &[batch, height * 2, width * 2, chans])
}
/// Residual block of the audio VAE decoder: two 3x3 convs with SiLU
/// pre-activations. The optional 1x1 `nin_shortcut` replaces the identity
/// skip when present in the checkpoint.
struct AudioResNetBlock {
    conv1_weight: Array,
    conv1_bias: Option<Array>,
    conv2_weight: Array,
    conv2_bias: Option<Array>,
    nin_shortcut_weight: Option<Array>,
    nin_shortcut_bias: Option<Array>,
}
impl AudioResNetBlock {
    /// Reads conv1/conv2 under `prefix`; the 1x1 shortcut conv is optional
    /// (its keys may be absent from the checkpoint).
    fn load(tensors: &HashMap<String, Array>, prefix: &str) -> Result<Self, InferenceError> {
        let key = |suffix: &str| format!("{prefix}.{suffix}");
        Ok(Self {
            conv1_weight: get_tensor(tensors, &key("conv1.conv.weight"))?,
            conv1_bias: tensors.get(&key("conv1.conv.bias")).cloned(),
            conv2_weight: get_tensor(tensors, &key("conv2.conv.weight"))?,
            conv2_bias: tensors.get(&key("conv2.conv.bias")).cloned(),
            nin_shortcut_weight: tensors.get(&key("nin_shortcut.conv.weight")).cloned(),
            nin_shortcut_bias: tensors.get(&key("nin_shortcut.conv.bias")).cloned(),
        })
    }

    /// Residual unit: `shortcut(x) + conv2(silu(conv1(silu(x))))`, where
    /// the shortcut is a 1x1 conv when present and identity otherwise.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let mut branch = nn::silu(x)?;
        branch = audio_conv2d_forward(
            &branch,
            &self.conv1_weight,
            self.conv1_bias.as_ref(),
            (1, 1),
            (1, 1),
        )?;
        branch = nn::silu(&branch)?;
        branch = audio_conv2d_forward(
            &branch,
            &self.conv2_weight,
            self.conv2_bias.as_ref(),
            (1, 1),
            (1, 1),
        )?;
        let skip = match self.nin_shortcut_weight {
            Some(ref sw) => {
                audio_conv2d_forward(x, sw, self.nin_shortcut_bias.as_ref(), (1, 1), (0, 0))?
            }
            None => x.clone(),
        };
        ops::add(&skip, &branch)
    }
}
/// One decoder up-stage: a run of residual blocks, then an optional
/// nearest-neighbour 2x upsample followed by a conv. The upsample conv is
/// absent on the stage that does not upsample (inferred from the checkpoint).
struct AudioUpBlock {
    blocks: Vec<AudioResNetBlock>,
    // When present, the stage doubles H and W before this conv.
    upsample_conv_weight: Option<Array>,
    upsample_conv_bias: Option<Array>,
}
impl AudioUpBlock {
    /// Runs every residual block in order, then (when this stage has an
    /// upsample conv) doubles the spatial resolution and convolves.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let mut h = self
            .blocks
            .iter()
            .try_fold(x.clone(), |acc, block| block.forward(&acc))?;
        if let Some(w) = self.upsample_conv_weight.as_ref() {
            let doubled = audio_upsample_2x(&h)?;
            h = audio_conv2d_forward(&doubled, w, self.upsample_conv_bias.as_ref(), (1, 1), (1, 1))?;
        }
        Ok(h)
    }
}
/// Audio VAE decoder: conv-in, two mid residual blocks, a stack of up-stages,
/// conv-out, plus the per-channel statistics needed to de-normalize latents
/// before decoding.
struct AudioVae {
    conv_in_weight: Array,
    conv_in_bias: Option<Array>,
    mid_block_1: AudioResNetBlock,
    mid_block_2: AudioResNetBlock,
    // Stored in checkpoint index order; `decode` iterates them in reverse.
    up_blocks: Vec<AudioUpBlock>,
    conv_out_weight: Array,
    conv_out_bias: Option<Array>,
    // `_mean_of_means` / `_std_of_means` from the checkpoint; latents are
    // mapped back via `x * std + mean` prior to decoding.
    denorm_mean: Array,
    denorm_std: Array,
}
impl AudioVae {
    /// Returns one past the highest integer index `i` appearing in tensor
    /// keys of the form `{prefix}{i}.…` (0 when nothing matches).
    ///
    /// The prefix is built once by the caller; the original inlined scans
    /// re-allocated the prefix `String` via `format!` for every key in the
    /// map on every pass.
    fn count_indexed_keys(tensors: &HashMap<String, Array>, prefix: &str) -> usize {
        let mut count = 0usize;
        for key in tensors.keys() {
            if let Some(rest) = key.strip_prefix(prefix) {
                if let Some(idx_str) = rest.split('.').next() {
                    if let Ok(idx) = idx_str.parse::<usize>() {
                        count = count.max(idx + 1);
                    }
                }
            }
        }
        count
    }
    /// Loads the audio VAE decoder from the flat tensor map, discovering the
    /// number of up-stages and per-stage residual blocks from the key names.
    fn load(tensors: &HashMap<String, Array>) -> Result<Self, InferenceError> {
        let pfx = "audio_vae.decoder";
        let conv_in_weight = get_tensor(tensors, &format!("{pfx}.conv_in.conv.weight"))?;
        let conv_in_bias = tensors.get(&format!("{pfx}.conv_in.conv.bias")).cloned();
        let conv_out_weight = get_tensor(tensors, &format!("{pfx}.conv_out.conv.weight"))?;
        let conv_out_bias = tensors.get(&format!("{pfx}.conv_out.conv.bias")).cloned();
        let mid_block_1 = AudioResNetBlock::load(tensors, &format!("{pfx}.mid.block_1"))?;
        let mid_block_2 = AudioResNetBlock::load(tensors, &format!("{pfx}.mid.block_2"))?;
        let num_up_blocks = Self::count_indexed_keys(tensors, &format!("{pfx}.up."));
        let mut up_blocks = Vec::with_capacity(num_up_blocks);
        for i in 0..num_up_blocks {
            let bpfx = format!("{pfx}.up.{i}");
            let num_blocks = Self::count_indexed_keys(tensors, &format!("{bpfx}.block."));
            let mut blocks = Vec::with_capacity(num_blocks);
            for b in 0..num_blocks {
                blocks.push(AudioResNetBlock::load(
                    tensors,
                    &format!("{bpfx}.block.{b}"),
                )?);
            }
            // The upsample conv is optional — absent on non-upsampling stages.
            let upsample_conv_weight = tensors
                .get(&format!("{bpfx}.upsample.conv.conv.weight"))
                .cloned();
            let upsample_conv_bias = tensors
                .get(&format!("{bpfx}.upsample.conv.conv.bias"))
                .cloned();
            up_blocks.push(AudioUpBlock {
                blocks,
                upsample_conv_weight,
                upsample_conv_bias,
            });
        }
        let denorm_mean = get_tensor(tensors, "audio_vae.per_channel_statistics._mean_of_means")?;
        let denorm_std = get_tensor(tensors, "audio_vae.per_channel_statistics._std_of_means")?;
        Ok(Self {
            conv_in_weight,
            conv_in_bias,
            mid_block_1,
            mid_block_2,
            up_blocks,
            conv_out_weight,
            conv_out_bias,
            denorm_mean,
            denorm_std,
        })
    }
    /// Decodes channels-first latents into a channels-first mel spectrogram.
    /// Internally works NHWC (MLX conv layout); the input channel axis is
    /// moved last on entry and restored on exit.
    fn decode(&self, latents: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let x = ops::transpose_axes(latents, &[0, 2, 3, 1])?;
        let mut h = audio_conv2d_forward(
            &x,
            &self.conv_in_weight,
            self.conv_in_bias.as_ref(),
            (1, 1),
            (1, 1),
        )?;
        h = self.mid_block_1.forward(&h)?;
        h = self.mid_block_2.forward(&h)?;
        // Checkpoint order is reversed relative to decode order.
        for block in self.up_blocks.iter().rev() {
            h = block.forward(&h)?;
        }
        h = nn::silu(&h)?;
        h = audio_conv2d_forward(
            &h,
            &self.conv_out_weight,
            self.conv_out_bias.as_ref(),
            (1, 1),
            (1, 1),
        )?;
        ops::transpose_axes(&h, &[0, 3, 1, 2])
    }
}
/// 1-D convolution with an optional broadcast bias add.
fn vocoder_conv1d(
    input: &Array,
    weight: &Array,
    bias: Option<&Array>,
    stride: i32,
    padding: i32,
    dilation: i32,
) -> Result<Array, mlx_rs::error::Exception> {
    let conv = ops::conv1d(input, weight, stride, padding, dilation, None::<i32>)?;
    match bias {
        Some(b) => ops::add(&conv, b),
        None => Ok(conv),
    }
}
/// Transposed 1-D convolution (upsampling) with an optional bias add.
fn vocoder_conv_transpose1d(
    input: &Array,
    weight: &Array,
    bias: Option<&Array>,
    stride: i32,
    padding: i32,
) -> Result<Array, mlx_rs::error::Exception> {
    let conv = ops::conv_transpose1d(
        input,
        weight,
        stride,
        padding,
        None::<i32>,
        None::<i32>,
        None::<i32>,
    )?;
    match bias {
        Some(b) => ops::add(&conv, b),
        None => Ok(conv),
    }
}
/// Snake activation: `x + scale * sin^2(alpha * x)`.
///
/// Without `beta` this is the classic Snake form `x + (1/alpha)*sin^2(alpha*x)`.
/// With `beta`, the sin^2 term is *multiplied* by `beta`.
/// NOTE(review): BigVGAN's snake-beta divides by beta (x + (1/beta)*sin^2(ax));
/// multiplying here assumes the checkpoint exports beta pre-reciprocated —
/// confirm against the weight exporter before touching this.
fn snake_activation(
    x: &Array,
    alpha: &Array,
    beta: Option<&Array>,
) -> Result<Array, mlx_rs::error::Exception> {
    let ax = ops::multiply(x, alpha)?;
    let sin_ax = ops::sin(&ax)?;
    let sin2 = ops::multiply(&sin_ax, &sin_ax)?;
    let scaled = if let Some(b) = beta {
        ops::multiply(&sin2, b)?
    } else {
        ops::divide(&sin2, alpha)?
    };
    ops::add(x, &scaled)
}
/// A stack of dilated residual layers (BigVGAN/HiFi-GAN-style block).
struct VocoderResBlock {
    layers: Vec<VocoderResLayer>,
}
/// One residual layer: snake -> dilated conv -> snake -> conv, with the
/// result added back onto the layer input.
struct VocoderResLayer {
    alpha1: Array,
    beta1: Option<Array>,
    conv1_weight: Array,
    conv1_bias: Option<Array>,
    // Dilation of the first conv; "same" padding is derived from it at
    // forward time. The second conv always uses dilation 1.
    conv1_dilation: i32,
    alpha2: Array,
    beta2: Option<Array>,
    conv2_weight: Array,
    conv2_bias: Option<Array>,
}
impl VocoderResBlock {
    /// Applies every residual layer in sequence. Each layer computes
    /// `x + conv2(snake(conv1(snake(x))))` with "same" padding derived from
    /// the kernel sizes (and the first conv's dilation).
    fn forward(&self, input: &Array) -> Result<Array, mlx_rs::error::Exception> {
        self.layers.iter().try_fold(input.clone(), |x, layer| {
            let h = snake_activation(&x, &layer.alpha1, layer.beta1.as_ref())?;
            // Dilated "same" padding: dilation * (kernel - 1) / 2.
            let k1 = layer.conv1_weight.shape()[1];
            let h = vocoder_conv1d(
                &h,
                &layer.conv1_weight,
                layer.conv1_bias.as_ref(),
                1,
                layer.conv1_dilation * (k1 - 1) / 2,
                layer.conv1_dilation,
            )?;
            let h = snake_activation(&h, &layer.alpha2, layer.beta2.as_ref())?;
            let k2 = layer.conv2_weight.shape()[1];
            let h = vocoder_conv1d(
                &h,
                &layer.conv2_weight,
                layer.conv2_bias.as_ref(),
                1,
                (k2 - 1) / 2,
                1,
            )?;
            // Residual connection back onto the layer input.
            ops::add(&h, &x)
        })
    }
}
/// BigVGAN-style generator: pre-conv, alternating transposed-conv upsampling
/// stages each followed by an averaged group of residual blocks, then a final
/// snake activation, post-conv, and tanh.
struct VocoderGenerator {
    conv_pre_weight: Array,
    conv_pre_bias: Option<Array>,
    // (weight, bias, stride) per upsampling stage; stride inferred at load
    // time from the kernel size.
    ups: Vec<(Array, Option<Array>, i32)>,
    // One group of residual blocks per upsampling stage; outputs within a
    // group are averaged.
    resblock_groups: Vec<Vec<VocoderResBlock>>,
    act_post_alpha: Array,
    act_post_beta: Option<Array>,
    conv_post_weight: Array,
    conv_post_bias: Option<Array>,
}
impl VocoderGenerator {
    /// Full generator pass: pre-conv, per-stage (upsample, averaged resblock
    /// group), post snake + conv, and a tanh to bound the waveform.
    fn forward(&self, input: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let pre_k = self.conv_pre_weight.shape()[1];
        let mut h = vocoder_conv1d(
            input,
            &self.conv_pre_weight,
            self.conv_pre_bias.as_ref(),
            1,
            (pre_k - 1) / 2,
            1,
        )?;
        for (stage, (weight, bias, stride)) in self.ups.iter().enumerate() {
            // Transposed-conv padding = (kernel - stride) / 2 keeps the
            // output length an exact multiple of the stride.
            let k = weight.shape()[1];
            h = vocoder_conv_transpose1d(&h, weight, bias.as_ref(), *stride, (k - *stride) / 2)?;
            let Some(group) = self.resblock_groups.get(stage) else {
                continue;
            };
            if group.is_empty() {
                continue;
            }
            // Average the parallel resblock outputs for this stage.
            let mut acc = group[0].forward(&h)?;
            for rb in group.iter().skip(1) {
                acc = ops::add(&acc, &rb.forward(&h)?)?;
            }
            h = ops::divide(&acc, &Array::from_f32(group.len() as f32))?;
        }
        h = snake_activation(&h, &self.act_post_alpha, self.act_post_beta.as_ref())?;
        let post_k = self.conv_post_weight.shape()[1];
        h = vocoder_conv1d(
            &h,
            &self.conv_post_weight,
            self.conv_post_bias.as_ref(),
            1,
            (post_k - 1) / 2,
            1,
        )?;
        ops::tanh(&h)
    }
}
/// Mel-spectrogram-to-waveform vocoder wrapping a single generator network.
struct Vocoder {
    generator: VocoderGenerator,
}
impl Vocoder {
fn load(tensors: &HashMap<String, Array>) -> Result<Self, InferenceError> {
let prefix = "vocoder";
let gen = Self::load_generator(tensors, prefix)?;
Ok(Self { generator: gen })
}
    /// Loads the generator weights under `prefix`, discovering the number of
    /// upsampling stages and residual blocks from consecutive tensor names
    /// and distributing the resblocks evenly across stages.
    fn load_generator(
        tensors: &HashMap<String, Array>,
        prefix: &str,
    ) -> Result<VocoderGenerator, InferenceError> {
        let conv_pre_weight = get_tensor(tensors, &format!("{prefix}.conv_pre.weight"))?;
        let conv_pre_bias = tensors.get(&format!("{prefix}.conv_pre.bias")).cloned();
        // Count consecutive `ups.N.weight` entries starting at 0.
        let mut num_ups = 0usize;
        loop {
            let key = format!("{prefix}.ups.{num_ups}.weight");
            if tensors.contains_key(&key) {
                num_ups += 1;
            } else {
                break;
            }
        }
        info!(num_ups, prefix, "vocoder: discovered upsample stages");
        let mut ups = Vec::with_capacity(num_ups);
        for i in 0..num_ups {
            let w = get_tensor(tensors, &format!("{prefix}.ups.{i}.weight"))?;
            let b = tensors.get(&format!("{prefix}.ups.{i}.bias")).cloned();
            let k = w.shape()[1];
            // Stride is not stored in the checkpoint; this assumes the
            // HiFi-GAN/BigVGAN convention kernel = 2 * stride.
            // NOTE(review): confirm this holds for every exported config.
            let stride = k / 2;
            if stride < 1 {
                return Err(InferenceError::InferenceFailed(format!(
                    "vocoder ups.{i} has invalid kernel size {k}"
                )));
            }
            ups.push((w, b, stride));
        }
        // Count resblocks the same way (probing the first conv of each).
        let mut num_resblocks = 0usize;
        loop {
            let key = format!("{prefix}.resblocks.{num_resblocks}.convs1.0.weight");
            if tensors.contains_key(&key) {
                num_resblocks += 1;
            } else {
                break;
            }
        }
        info!(num_resblocks, prefix, "vocoder: discovered resblocks");
        let blocks_per_stage = if num_ups > 0 && num_resblocks > 0 {
            num_resblocks / num_ups
        } else {
            num_resblocks
        };
        let mut resblock_groups: Vec<Vec<VocoderResBlock>> = Vec::new();
        let mut rb_idx = 0usize;
        for _stage in 0..num_ups {
            let mut stage_blocks = Vec::new();
            for _ in 0..blocks_per_stage {
                if rb_idx >= num_resblocks {
                    break;
                }
                let block = Self::load_resblock(tensors, prefix, rb_idx)?;
                stage_blocks.push(block);
                rb_idx += 1;
            }
            resblock_groups.push(stage_blocks);
        }
        // When num_resblocks is not an exact multiple of num_ups, attach the
        // remainder to the last stage instead of silently dropping weights.
        if rb_idx < num_resblocks && !resblock_groups.is_empty() {
            let last = resblock_groups.last_mut().unwrap();
            while rb_idx < num_resblocks {
                let block = Self::load_resblock(tensors, prefix, rb_idx)?;
                last.push(block);
                rb_idx += 1;
            }
        }
        let act_post_alpha = get_tensor(tensors, &format!("{prefix}.act_post.act.alpha"))?;
        let act_post_beta = tensors.get(&format!("{prefix}.act_post.act.beta")).cloned();
        let conv_post_weight = get_tensor(tensors, &format!("{prefix}.conv_post.weight"))?;
        let conv_post_bias = tensors.get(&format!("{prefix}.conv_post.bias")).cloned();
        Ok(VocoderGenerator {
            conv_pre_weight,
            conv_pre_bias,
            ups,
            resblock_groups,
            act_post_alpha,
            act_post_beta,
            conv_post_weight,
            conv_post_bias,
        })
    }
    /// Loads resblock `idx` under `prefix`, counting its layers from the
    /// consecutive `convs1.N.weight` entries. Missing snake alphas fall back
    /// to all-ones (identity-ish scaling); missing betas stay `None`.
    fn load_resblock(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        idx: usize,
    ) -> Result<VocoderResBlock, InferenceError> {
        let mut num_layers = 0usize;
        loop {
            let key = format!("{prefix}.resblocks.{idx}.convs1.{num_layers}.weight");
            if tensors.contains_key(&key) {
                num_layers += 1;
            } else {
                break;
            }
        }
        // Per-layer dilation schedule; layers beyond the table get dilation 1.
        // NOTE(review): BigVGAN typically uses (1, 3, 5) — this longer table
        // is a superset; confirm it matches the exported architecture.
        let dilations = [1, 3, 5, 7, 11, 13];
        let mut layers = Vec::with_capacity(num_layers);
        for l in 0..num_layers {
            let alpha1 = get_tensor(
                tensors,
                &format!("{prefix}.resblocks.{idx}.acts1.{l}.act.alpha"),
            )
            .or_else(|_| {
                Array::ones::<f32>(&[1]).map_err(|e| InferenceError::InferenceFailed(e.to_string()))
            })?;
            let beta1 = tensors
                .get(&format!("{prefix}.resblocks.{idx}.acts1.{l}.act.beta"))
                .cloned();
            let conv1_weight = get_tensor(
                tensors,
                &format!("{prefix}.resblocks.{idx}.convs1.{l}.weight"),
            )?;
            let conv1_bias = tensors
                .get(&format!("{prefix}.resblocks.{idx}.convs1.{l}.bias"))
                .cloned();
            let alpha2 = get_tensor(
                tensors,
                &format!("{prefix}.resblocks.{idx}.acts2.{l}.act.alpha"),
            )
            .or_else(|_| {
                Array::ones::<f32>(&[1]).map_err(|e| InferenceError::InferenceFailed(e.to_string()))
            })?;
            let beta2 = tensors
                .get(&format!("{prefix}.resblocks.{idx}.acts2.{l}.act.beta"))
                .cloned();
            let conv2_weight = get_tensor(
                tensors,
                &format!("{prefix}.resblocks.{idx}.convs2.{l}.weight"),
            )?;
            let conv2_bias = tensors
                .get(&format!("{prefix}.resblocks.{idx}.convs2.{l}.bias"))
                .cloned();
            let dilation = if l < dilations.len() { dilations[l] } else { 1 };
            layers.push(VocoderResLayer {
                alpha1,
                beta1,
                conv1_weight,
                conv1_bias,
                conv1_dilation: dilation,
                alpha2,
                beta2,
                conv2_weight,
                conv2_bias,
            });
        }
        Ok(VocoderResBlock { layers })
    }
    /// Converts a mel spectrogram into a waveform.
    ///
    /// 3-D input gets its last two axes swapped before the generator and
    /// swapped back after; 2-D input first gains a batch axis. Inputs with
    /// fewer than two dims return a tiny silent placeholder instead of
    /// erroring — deliberate best-effort so a degenerate audio tensor cannot
    /// sink an otherwise-successful video generation.
    /// NOTE(review): assumes generator convs consume the swapped layout
    /// (MLX conv1d is (N, L, C)) — confirm against the mel producer.
    fn forward(&self, audio: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let shape = audio.shape();
        if shape.len() < 2 {
            return Array::zeros::<f32>(&[1, 2, 1]);
        }
        let x = if shape.len() == 3 {
            ops::transpose_axes(audio, &[0, 2, 1])?
        } else {
            let expanded = ops::expand_dims(audio, 0)?;
            ops::transpose_axes(&expanded, &[0, 2, 1])?
        };
        let wav = self.generator.forward(&x)?;
        ops::transpose_axes(&wav, &[0, 2, 1])
    }
}
/// Rectified-flow Euler scheduler: holds a descending sigma schedule (from
/// 1.0 toward 0.0) and integrates samples along predicted velocities.
pub struct RectifiedFlowScheduler {
    pub num_inference_steps: usize,
    // `num_inference_steps + 1` sigma values; index i is the noise level
    // entering denoising step i.
    pub timesteps: Vec<f32>,
}
/// Distilled 8-step sigma schedule shipped with LTX-2 checkpoints.
const DISTILLED_SIGMAS_8: [f32; 9] = [
    1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0,
];
/// Token-count anchors for interpolating the resolution-dependent shift.
const BASE_SHIFT_ANCHOR: f32 = 1024.0;
const MAX_SHIFT_ANCHOR: f32 = 4096.0;
/// Builds the LTX-2 rectified-flow sigma schedule: a linear 1 -> 0 ramp of
/// `steps + 1` points, warped by a token-count-dependent exponential shift,
/// then rescaled so the last nonzero sigma lands on the terminal value (0.1).
pub fn ltx2_schedule(steps: usize, num_tokens: usize) -> Vec<f32> {
    const MAX_SHIFT: f32 = 2.05;
    const BASE_SHIFT: f32 = 0.95;
    const TERMINAL: f32 = 0.1;
    // Shift exponent interpolated linearly in the token count between the
    // two anchors, then exponentiated once.
    let slope = (MAX_SHIFT - BASE_SHIFT) / (MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR);
    let intercept = BASE_SHIFT - slope * BASE_SHIFT_ANCHOR;
    let shift = (num_tokens as f32 * slope + intercept).exp();
    // Linear ramp, warped; exact zeros are left untouched throughout.
    let mut sigmas: Vec<f32> = (0..=steps)
        .map(|i| 1.0 - (i as f32 / steps as f32))
        .map(|s| {
            if s == 0.0 {
                s
            } else {
                shift / (shift + (1.0 / s - 1.0))
            }
        })
        .collect();
    // Affine rescale so the final nonzero sigma equals TERMINAL while
    // sigma = 1 stays fixed.
    let last_nonzero = sigmas
        .iter()
        .rposition(|&s| s != 0.0)
        .unwrap_or(sigmas.len() - 1);
    let scale = (1.0 - sigmas[last_nonzero]) / (1.0 - TERMINAL);
    if scale != 0.0 {
        for s in sigmas.iter_mut().filter(|s| **s != 0.0) {
            *s = 1.0 - (1.0 - *s) / scale;
        }
    }
    sigmas
}
impl RectifiedFlowScheduler {
    /// Builds a scheduler whose sigma schedule is shifted for `num_tokens`
    /// video tokens. An 8-step request uses the distilled schedule verbatim.
    pub fn new_with_tokens(num_inference_steps: usize, num_tokens: usize) -> Self {
        let timesteps = if num_inference_steps == 8 {
            DISTILLED_SIGMAS_8.to_vec()
        } else {
            ltx2_schedule(num_inference_steps, num_tokens)
        };
        Self {
            num_inference_steps,
            timesteps,
        }
    }
    /// Token-count-agnostic constructor. Delegates to [`Self::new_with_tokens`]
    /// at the max-shift anchor — previously this duplicated that method's
    /// body verbatim.
    pub fn new(num_inference_steps: usize) -> Self {
        Self::new_with_tokens(num_inference_steps, MAX_SHIFT_ANCHOR as usize)
    }
    /// One Euler step of the rectified-flow ODE:
    /// `x_{i+1} = x_i + (sigma_{i+1} - sigma_i) * v`.
    ///
    /// # Panics
    /// Panics if `step_index >= self.timesteps.len()`.
    pub fn step(
        &self,
        velocity: &Array,
        step_index: usize,
        sample: &Array,
    ) -> Result<Array, mlx_rs::error::Exception> {
        let sigma = self.timesteps[step_index];
        // Past the final table entry the schedule terminates at sigma = 0.
        let sigma_next = self
            .timesteps
            .get(step_index + 1)
            .copied()
            .unwrap_or(0.0);
        let dt = Array::from_f32(sigma_next - sigma);
        ops::add(sample, &ops::multiply(velocity, &dt)?)
    }
    /// Draws the initial bf16 noise tensor for `shape`, seeding the global
    /// MLX RNG first so results are reproducible per seed.
    pub fn init_noise(&self, shape: &[i32], seed: u64) -> Result<Array, mlx_rs::error::Exception> {
        mlx_rs::random::seed(seed)?;
        let noise = mlx_rs::random::normal::<f32>(shape, None, None, None)?;
        noise.as_dtype(mlx_rs::Dtype::Bfloat16)
    }
}
/// Scalar sinusoidal timestep embedding of width `dim`: cosines fill the
/// first half, sines the second half (an odd `dim` leaves the final slot 0).
fn timestep_embedding(timestep: f32, dim: usize) -> Result<Array, mlx_rs::error::Exception> {
    let half = dim / 2;
    let log_max_period = 10_000_f32.ln();
    // Geometric frequency ladder: exp(-i/half * ln(10000)).
    let freq = |i: usize| (-(i as f32) / half as f32 * log_max_period).exp();
    let mut values = vec![0.0f32; dim];
    for i in 0..half {
        let phase = timestep * freq(i);
        values[i] = phase.cos();
        values[half + i] = phase.sin();
    }
    Ok(Array::from_slice(&values, &[1, dim as i32]))
}
/// Tensor variant of the sinusoidal embedding: broadcasts a frequency ladder
/// against `timestep` (with a trailing axis appended) and concatenates
/// cosines then sines along the last axis, yielding `2 * (dim / 2)` features.
fn timestep_embedding_tensor(
    timestep: &Array,
    dim: usize,
) -> Result<Array, mlx_rs::error::Exception> {
    let half = dim / 2;
    let log_max_period = 10_000_f32.ln();
    let ladder: Vec<f32> = (0..half)
        .map(|i| (-(i as f32) / half as f32 * log_max_period).exp())
        .collect();
    let freqs = Array::from_slice(&ladder, &[half as i32]);
    let phases = ops::multiply(&ops::expand_dims(timestep, -1)?, &freqs)?;
    let cosines = ops::cos(&phases)?;
    let sines = ops::sin(&phases)?;
    ops::concatenate_axis(&[&cosines, &sines], -1)
}
/// Full LTX-2.3 generation backend: transformer denoiser, text-conditioning
/// connector, video VAE decoder (plus optional encoder for i2v), audio VAE,
/// vocoder, and optional external text encoders.
pub struct LtxBackend {
    transformer: LtxTransformer,
    connector: TextEmbeddingConnector,
    vae: VaeDecoder3D,
    // Present only when the checkpoint ships vae_encoder.safetensors (i2v).
    vae_encoder: Option<VaeEncoder3D>,
    audio_vae: AudioVae,
    vocoder: Vocoder,
    config: LtxConfig,
    // Optional encoders; generation degrades gracefully (unconditional /
    // prompt ignored) when these are absent.
    t5_encoder: Option<super::mlx_flux::T5TextEncoder>,
    t5_tokenizer: Option<tokenizers::Tokenizer>,
    gemma_encoder: Option<super::mlx_gemma3::Gemma3TextEncoder>,
}
// SAFETY: asserted so the backend can be shared across worker threads.
// NOTE(review): this claims every contained handle (mlx_rs::Array and
// friends) is safe to move/share between threads — nothing in this file
// enforces that; confirm against mlx-rs's documented thread-safety before
// relying on concurrent access.
unsafe impl Send for LtxBackend {}
unsafe impl Sync for LtxBackend {}
impl LtxBackend {
    /// Loads the full LTX-2.3 pipeline from safetensors under `model_dir`:
    /// transformer, text connector, video VAE decoder, optional VAE encoder
    /// (i2v), audio VAE, and vocoder — plus best-effort external text
    /// encoders (T5 from a cached Flux repo, Gemma 3 from the HF cache).
    pub fn load(model_dir: &Path) -> Result<Self, InferenceError> {
        let config = LtxConfig::default();
        info!(
            hidden = config.hidden_dim,
            layers = config.num_layers,
            heads = config.num_heads,
            audio_hidden = config.audio_hidden_dim,
            audio_heads = config.audio_heads,
            "loading FULL LTX-2.3 model via MLX (video + audio)"
        );
        // Device selection: CAR_MLX_DEVICE env override wins; otherwise GPU
        // when built with mlx-metal, else CPU.
        #[cfg(feature = "mlx-metal")]
        let default_device = mlx_rs::Device::gpu();
        #[cfg(not(feature = "mlx-metal"))]
        let default_device = mlx_rs::Device::cpu();
        match std::env::var("CAR_MLX_DEVICE").ok().as_deref() {
            Some("cpu") => mlx_rs::Device::set_default(&mlx_rs::Device::cpu()),
            #[cfg(feature = "mlx-metal")]
            Some("gpu") => mlx_rs::Device::set_default(&mlx_rs::Device::gpu()),
            _ => mlx_rs::Device::set_default(&default_device),
        }
        info!("loading safetensors weights for LTX-2.3 (all modalities)");
        let tensors = load_ltx_tensors(model_dir)?;
        info!(tensors = tensors.len(), "LTX tensors loaded");
        let transformer = LtxTransformer::load(&tensors, &config)?;
        info!("LTX transformer loaded (48 blocks, video + audio pathways)");
        let connector = TextEmbeddingConnector::load(&tensors)?;
        info!("text embedding connector loaded (video + audio)");
        let vae = VaeDecoder3D::load(&tensors)?;
        info!("3D causal VAE decoder loaded");
        // Encoder availability is gated on the file existing, but the weights
        // are read from the already-loaded `tensors` map.
        // NOTE(review): assumes load_ltx_tensors ingested
        // vae_encoder.safetensors too — confirm, otherwise this load fails.
        let vae_encoder = if model_dir.join("vae_encoder.safetensors").exists() {
            let enc = VaeEncoder3D::load(&tensors)?;
            info!("3D causal VAE encoder loaded (i2v available)");
            Some(enc)
        } else {
            None
        };
        let audio_vae = AudioVae::load(&tensors)?;
        info!("audio VAE loaded");
        let vocoder = Vocoder::load(&tensors)?;
        info!("vocoder loaded");
        // Text encoders are strictly best-effort: failures downgrade to
        // unconditional generation rather than aborting the load.
        let (t5_encoder, t5_tokenizer) = Self::try_load_t5(model_dir);
        let gemma_encoder = match super::mlx_gemma3::Gemma3TextEncoder::try_load_default() {
            Ok(Some(enc)) => {
                info!("Gemma 3 12B text encoder loaded (text conditioning active)");
                Some(enc)
            }
            Ok(None) => {
                tracing::warn!(
                    "Gemma 3 12B cache not found — text prompt will be ignored. \
                     Run `huggingface-cli download mlx-community/gemma-3-12b-it-4bit` \
                     to enable text conditioning."
                );
                None
            }
            Err(e) => {
                tracing::warn!(error = %e, "Gemma 3 12B load failed — running unconditional");
                None
            }
        };
        info!("FULL LTX-2.3 model loaded successfully (all generation modes)");
        Ok(Self {
            transformer,
            connector,
            vae,
            vae_encoder,
            audio_vae,
            vocoder,
            config,
            t5_encoder,
            t5_tokenizer,
            gemma_encoder,
        })
    }
    /// Best-effort load of a T5-XXL encoder (and its tokenizer) out of a
    /// cached Flux repo snapshot. Every failure path logs and returns
    /// `(None, None)` / a partial pair instead of erroring — T5 is optional.
    fn try_load_t5(
        _model_dir: &Path,
    ) -> (
        Option<super::mlx_flux::T5TextEncoder>,
        Option<tokenizers::Tokenizer>,
    ) {
        let flux_dir = latest_huggingface_repo_snapshot("mlx-community/Flux-1.lite-8B-MLX-Q4");
        let Some(flux_dir) = flux_dir else {
            info!("Flux model not found in HF cache — T5 encoder unavailable for LTX");
            return (None, None);
        };
        info!(flux_dir = %flux_dir.display(), "loading T5-XXL from Flux model for LTX text conditioning");
        let t5_tensors = match super::mlx::load_all_tensors(&flux_dir) {
            Ok(t) => t,
            Err(e) => {
                tracing::warn!("failed to load Flux tensors for T5: {e}");
                return (None, None);
            }
        };
        let flux_config = super::mlx_flux::FluxConfig::default();
        let t5 = match super::mlx_flux::T5TextEncoder::load(&t5_tensors, &flux_config) {
            Ok(t5) => t5,
            Err(e) => {
                tracing::warn!("failed to load T5 encoder from Flux: {e}");
                return (None, None);
            }
        };
        info!("T5-XXL encoder loaded from Flux model");
        // The tokenizer lives beside the Flux weights; the encoder can still
        // be returned without it.
        let tok_path = flux_dir.join("tokenizer_2").join("tokenizer.json");
        let tokenizer = if tok_path.exists() {
            match tokenizers::Tokenizer::from_file(&tok_path) {
                Ok(t) => Some(t),
                Err(e) => {
                    tracing::warn!("failed to load T5 tokenizer: {e}");
                    None
                }
            }
        } else {
            info!("T5 tokenizer not found at {}", tok_path.display());
            None
        };
        (Some(t5), tokenizer)
    }
    /// End-to-end generation: prompt encoding (Gemma 3) -> conditioning
    /// connector -> rectified-flow denoising over video (and, always-on here,
    /// audio) latents with optional classifier-free guidance and optional
    /// i2v anchoring -> video VAE decode -> ffmpeg MP4 encode -> audio
    /// decode + mux (AudioVideo mode).
    ///
    /// Returns the MP4 path. Partially written output is removed when audio
    /// decoding or muxing fails. Requires `ffmpeg` on PATH.
    pub fn generate(
        &mut self,
        req: &GenerateVideoRequest,
    ) -> Result<GenerateVideoResult, InferenceError> {
        req.validate().map_err(InferenceError::InferenceFailed)?;
        // Request parameters with defaults.
        let width = req.width.unwrap_or(768);
        let height = req.height.unwrap_or(512);
        let num_frames = req.num_frames.unwrap_or(41);
        let steps = req.steps.unwrap_or(20) as usize;
        let output_fps = req.fps.unwrap_or(24);
        let guidance_scale = req.guidance.unwrap_or(3.0);
        // Audio guidance defaults to video guidance, capped at 3.0.
        let audio_guidance_scale = req
            .audio_guidance
            .unwrap_or_else(|| guidance_scale.min(3.0));
        let seed = req.seed.unwrap_or(42);
        // NOTE(review): `_map_err` is dead — the `map_err` bound below is the
        // one actually used.
        let _map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
        info!(
            prompt = %req.prompt,
            width,
            height,
            num_frames,
            steps,
            guidance = guidance_scale,
            seed,
            "generating video with LTX-2.3 (full model)"
        );
        let output_path = req
            .output_path
            .clone()
            .unwrap_or_else(|| "output.mp4".to_string());
        let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
        // Text conditioning: Gemma 3 produces a [1, seq, aggregate_dim]
        // embedding; without the encoder a zero tensor stands in.
        // NOTE(review): 188160 is the per-token aggregate width the connector
        // expects from the Gemma encoder — confirm against encode_for_ltx.
        let aggregate_dim: i32 = 188160;
        let gemma_seq_default: i32 = 1024;
        let (text_embed, n_valid) = if let Some(ref mut gemma) = self.gemma_encoder {
            info!(prompt = %req.prompt, "encoding prompt via Gemma 3 12B");
            gemma.encode_for_ltx(&req.prompt, gemma_seq_default as usize)?
        } else {
            if !req.prompt.is_empty() {
                tracing::warn!(
                    prompt = %req.prompt,
                    "Gemma 3 encoder unavailable — text prompt ignored"
                );
            }
            (
                Array::zeros::<f32>(&[1, gemma_seq_default, aggregate_dim]).map_err(map_err)?,
                0usize,
            )
        };
        // The connector maps raw text embeddings to separate video and audio
        // conditioning streams.
        let (video_cond, audio_cond) = self
            .connector
            .forward(&text_embed, n_valid)
            .map_err(map_err)?;
        info!(
            video_cond_shape = ?video_cond.shape(),
            audio_cond_shape = ?audio_cond.shape(),
            "text conditioning computed via connector"
        );
        // Latent geometry: H/W snapped down to multiples of 32 (the total
        // spatial downsample factor), frame count snapped to 1 + 8k.
        let spatial_downsample: i32 = 32;
        let effective_h = (height as i32 / spatial_downsample) * spatial_downsample;
        let effective_w = (width as i32 / spatial_downsample) * spatial_downsample;
        let latent_h = effective_h / spatial_downsample;
        let latent_w = effective_w / spatial_downsample;
        let nf = num_frames as i32;
        let latent_t = 1 + (nf - 1) / 8;
        let effective_t = 1 + (latent_t - 1) * 8;
        if effective_h != height as i32 || effective_w != width as i32 || effective_t != nf {
            tracing::warn!(
                requested_h = height, effective_h,
                requested_w = width, effective_w,
                requested_frames = nf, effective_frames = effective_t,
                "dimensions adjusted to upstream constraints (H,W divisible by 32; frames = 1 + 8k)"
            );
        }
        let latent_c = self.config.in_channels as i32;
        let num_patches = latent_t * latent_h * latent_w;
        info!(
            latent_t,
            latent_h, latent_w, latent_c, num_patches, "computed latent dimensions"
        );
        // i2v: VAE-encode the reference image into the first temporal slice.
        let anchor_latent = self.compute_anchor_latent(req, latent_h, latent_w)?;
        let anchor_patches = (latent_h * latent_w) as usize;
        // Audio latents are always generated; only AudioVideo mode decodes
        // and muxes them at the end.
        let audio_mode = true;
        // NOTE(review): result discarded here; `mode` is recomputed below
        // where it is actually consumed.
        let _ = req.effective_mode();
        let audio_fps = output_fps as f32;
        // Audio latent frames run at 25 per second of video.
        let audio_frames = if audio_mode {
            (num_frames as f32 / audio_fps * 25.0).round() as i32
        } else {
            0
        };
        let audio_latent_c: i32 = 128;
        let audio_seed = seed.wrapping_add(1);
        let scheduler = RectifiedFlowScheduler::new_with_tokens(steps, num_patches as usize);
        let mut latents = scheduler
            .init_noise(&[1, num_patches, latent_c], seed)
            .map_err(map_err)?;
        let mode = req.effective_mode();
        if mode == VideoMode::AudioRefVideo {
            let audio_path = req.audio_path.as_deref().ok_or_else(|| {
                InferenceError::InferenceFailed("audio_ref_video requires audio_path".to_string())
            })?;
            tracing::warn!(
                audio_path,
                "audio_ref_video: audio_path is passthrough-only on this backend — \
                 recorded on the request for downstream muxing, NOT used for \
                 video conditioning. Real conditioning is tracked at \
                 Parslee-ai/car#130; the CLI exposes this path via `--audio-mux` \
                 (#183)."
            );
        }
        let mut audio_latents = if audio_mode {
            Some(
                scheduler
                    .init_noise(&[1, audio_frames, audio_latent_c], audio_seed)
                    .map_err(map_err)?,
            )
        } else {
            None
        };
        if audio_mode {
            info!(
                audio_frames,
                audio_latent_c,
                mode = ?mode,
                "joint audio+video pathway enabled"
            );
        }
        // Pin the anchor into the noise before the first step so the model
        // denoises around it from the start.
        if let Some(ref anchor) = anchor_latent {
            latents =
                splice_first_temporal_slice(&latents, anchor, anchor_patches).map_err(map_err)?;
        }
        info!(steps, "starting rectified flow denoising loop");
        let rope = RopeBundle::build(
            latent_t,
            latent_h,
            latent_w,
            if audio_mode { audio_frames } else { 0 },
            audio_fps,
            self.config.num_heads,
            self.config.head_dim,
            self.config.audio_heads,
            self.config.audio_head_dim,
        )
        .map_err(map_err)?;
        // Classifier-free guidance needs a second (null-conditioned) pass.
        let cfg_enabled = guidance_scale > 1.0;
        let video_scale = Array::from_f32(guidance_scale);
        let audio_scale = Array::from_f32(audio_guidance_scale);
        let null_cond = if cfg_enabled {
            Some(Array::zeros::<f32>(video_cond.shape()).map_err(map_err)?)
        } else {
            None
        };
        let null_audio_cond = if cfg_enabled && audio_mode {
            Some(Array::zeros::<f32>(audio_cond.shape()).map_err(map_err)?)
        } else {
            None
        };
        // Per-patch denoise mask: anchor patches get timestep 0 (already
        // clean), everything else the current sigma.
        let denoise_mask: Option<Array> = if anchor_latent.is_some() {
            let mut mask_vals = vec![1.0f32; num_patches as usize];
            for i in 0..anchor_patches {
                mask_vals[i] = 0.0;
            }
            Some(Array::from_slice(&mask_vals, &[1, num_patches]))
        } else {
            None
        };
        for step_idx in 0..steps {
            let t = scheduler.timesteps[step_idx];
            let global_timestep = ops::reshape(&Array::from_f32(t), &[1]).map_err(map_err)?;
            let timestep = if let Some(ref mask) = denoise_mask {
                let t_scalar = Array::from_f32(t);
                ops::multiply(mask, &t_scalar).map_err(map_err)?
            } else {
                ops::reshape(&Array::from_f32(t), &[1]).map_err(map_err)?
            };
            // Conditioned velocity prediction (video + optional audio).
            let (velocity, audio_velocity) = self
                .transformer
                .forward(
                    &latents,
                    &video_cond,
                    &timestep,
                    &global_timestep,
                    audio_latents.as_ref(),
                    if audio_mode { Some(&audio_cond) } else { None },
                    &rope,
                )
                .map_err(map_err)?;
            // CFG: v = v_uncond + scale * (v_cond - v_uncond), per modality.
            let (velocity, audio_velocity) = if cfg_enabled {
                let (uncond_velocity, uncond_audio_velocity) = self
                    .transformer
                    .forward(
                        &latents,
                        null_cond.as_ref().unwrap(),
                        &timestep,
                        &global_timestep,
                        audio_latents.as_ref(),
                        null_audio_cond.as_ref(),
                        &rope,
                    )
                    .map_err(map_err)?;
                let diff = ops::subtract(&velocity, &uncond_velocity).map_err(map_err)?;
                let scaled_diff = ops::multiply(&diff, &video_scale).map_err(map_err)?;
                let v = ops::add(&uncond_velocity, &scaled_diff).map_err(map_err)?;
                let av = match (audio_velocity, uncond_audio_velocity) {
                    (Some(a_cond), Some(a_uncond)) => {
                        let a_diff = ops::subtract(&a_cond, &a_uncond).map_err(map_err)?;
                        let a_scaled = ops::multiply(&a_diff, &audio_scale).map_err(map_err)?;
                        Some(ops::add(&a_uncond, &a_scaled).map_err(map_err)?)
                    }
                    _ => None,
                };
                (v, av)
            } else {
                (velocity, audio_velocity)
            };
            latents = scheduler
                .step(&velocity, step_idx, &latents)
                .map_err(map_err)?;
            // Re-pin the anchor after every step so it never drifts.
            if let Some(ref anchor) = anchor_latent {
                latents = splice_first_temporal_slice(&latents, anchor, anchor_patches)
                    .map_err(map_err)?;
            }
            if let (Some(a_lat), Some(a_vel)) = (audio_latents.as_ref(), audio_velocity.as_ref()) {
                audio_latents = Some(scheduler.step(a_vel, step_idx, a_lat).map_err(map_err)?);
            }
            // Force evaluation each step to bound the lazy graph's size.
            let mut to_eval: Vec<&Array> = vec![&latents];
            if let Some(ref a) = audio_latents {
                to_eval.push(a);
            }
            mlx_rs::transforms::eval(to_eval).map_err(map_err)?;
            if step_idx % 5 == 0 || step_idx == steps - 1 {
                info!(step = step_idx + 1, total = steps, t, "denoising step");
            }
        }
        info!("denoising complete, decoding latents to video frames");
        // Debug hook: substitute externally produced latents for A/B testing.
        if let Ok(path) = std::env::var("CAR_LTX_REF_LATENT") {
            info!(%path, "overriding latents with reference safetensors");
            let loaded = Array::load_safetensors(&path).map_err(|e| {
                InferenceError::InferenceFailed(format!("load ref latent {path}: {e}"))
            })?;
            if let Some(r) = loaded.get("latent") {
                latents = r.as_dtype(latents.dtype()).map_err(map_err)?;
            } else {
                return Err(InferenceError::InferenceFailed(format!(
                    "ref latent safetensors missing 'latent' key: {path}"
                )));
            }
        }
        // Debug hook: dump the final latent to disk for inspection.
        if let Ok(dir) = std::env::var("CAR_DUMP_LTX_STAGE") {
            let _ = std::fs::create_dir_all(&dir);
            if let Ok(l_f32) = latents.as_dtype(mlx_rs::Dtype::Float32) {
                let _ = mlx_rs::transforms::eval([&l_f32]);
                let shape = l_f32.shape().to_vec();
                let data: &[f32] = l_f32.as_slice();
                let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
                let _ = std::fs::write(format!("{dir}/final_video_latent_patched.bin"), &bytes);
                let _ = std::fs::write(
                    format!("{dir}/final_video_latent_patched.meta"),
                    format!("{shape:?}\n"),
                );
            }
        }
        // Un-patch the token sequence back into a 5D channels-first latent
        // volume for the VAE decoder.
        let latents_5d = ops::reshape(&latents, &[1, latent_t, latent_h, latent_w, latent_c])
            .map_err(map_err)?;
        let latents_5d = ops::transpose_axes(&latents_5d, &[0, 4, 1, 2, 3]).map_err(map_err)?;
        let video_frames = self.vae.decode(&latents_5d).map_err(map_err)?;
        info!(frames_shape = ?video_frames.shape(), "video frames decoded");
        let frame_shape = video_frames.shape();
        let total_frames = frame_shape[1];
        let out_h = frame_shape[2];
        let out_w = frame_shape[3];
        // Scale decoder output to 0..255; clamping to u8 happens CPU-side
        // below. NOTE(review): assumes the decoder emits f32 in [0, 1] —
        // `as_slice::<f32>` below depends on the f32 dtype; confirm.
        let scale_255 = Array::from_f32(255.0);
        let pixels_u8 = ops::multiply(&video_frames, &scale_255).map_err(map_err)?;
        mlx_rs::transforms::eval([&pixels_u8]).map_err(map_err)?;
        // Write raw RGB24 frames to a temp file and hand them to ffmpeg.
        let tmp_dir = std::env::temp_dir().join(format!("ltx_frames_{seed}"));
        std::fs::create_dir_all(&tmp_dir)
            .map_err(|e| InferenceError::InferenceFailed(format!("mkdir: {e}")))?;
        let raw_path = tmp_dir.join("frames.raw");
        let pixel_data: Vec<f32> = pixels_u8.as_slice::<f32>().to_vec();
        let pixel_bytes: Vec<u8> = pixel_data
            .iter()
            .map(|&v| v.clamp(0.0, 255.0) as u8)
            .collect();
        std::fs::write(&raw_path, &pixel_bytes)
            .map_err(|e| InferenceError::InferenceFailed(format!("write frames: {e}")))?;
        let ffmpeg_status = std::process::Command::new("ffmpeg")
            .args([
                "-y",
                "-f",
                "rawvideo",
                "-pix_fmt",
                "rgb24",
                "-s",
                &format!("{out_w}x{out_h}"),
                "-r",
                &output_fps.to_string(),
                "-i",
                raw_path.to_str().unwrap_or("frames.raw"),
                "-c:v",
                "libx264",
                "-pix_fmt",
                "yuv420p",
                "-frames:v",
                &total_frames.to_string(),
                &output_path,
            ])
            .output();
        // Temp frames are no longer needed regardless of ffmpeg's outcome.
        let _ = std::fs::remove_dir_all(&tmp_dir);
        match ffmpeg_status {
            Ok(output) if output.status.success() => {
                info!(path = %output_path, frames = total_frames, "MP4 encoded successfully");
            }
            Ok(output) => {
                let stderr = String::from_utf8_lossy(&output.stderr);
                return Err(InferenceError::InferenceFailed(format!(
                    "ffmpeg failed: {stderr}"
                )));
            }
            Err(e) => {
                return Err(InferenceError::InferenceFailed(format!(
                    "ffmpeg not found or failed to execute: {e}. \
                     Install ffmpeg to encode video output."
                )));
            }
        }
        // AudioVideo mode: decode the audio latents and mux the waveform
        // into the MP4; any failure removes the partial output file.
        if mode == VideoMode::AudioVideo {
            if let Some(a_lat) = audio_latents {
                info!("decoding audio latents → mel spectrogram → waveform");
                let decode_result = self.decode_audio_latents(&a_lat).map_err(|e| {
                    InferenceError::InferenceFailed(format!("decode_audio_latents: {e}"))
                });
                let waveform = match decode_result {
                    Ok(w) => w,
                    Err(e) => {
                        let _ = std::fs::remove_file(&output_path);
                        return Err(e);
                    }
                };
                // NOTE(review): sample rate is hard-coded; confirm it matches
                // the vocoder's native output rate.
                let sample_rate = 16000u32;
                if let Err(e) = mux_audio_into_mp4(&output_path, &waveform, sample_rate) {
                    let _ = std::fs::remove_file(&output_path);
                    return Err(e);
                }
                info!(sample_rate, "audio track muxed into MP4");
            }
        }
        Ok(GenerateVideoResult {
            video_path: output_path,
            media_type: "video/mp4".to_string(),
            model_used: Some("ltx-2.3-mlx-q4".to_string()),
        })
    }
    /// Decodes audio latent tokens into a waveform: de-normalize with the
    /// per-channel statistics, reshape `[B, T, 128]` -> `[B, 8, T, 16]` VAE
    /// latents, decode to a stereo mel spectrogram, fold stereo into the
    /// mel-bin axis, then vocode.
    fn decode_audio_latents(&self, tokens: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let mean = &self.audio_vae.denorm_mean;
        let std_ = &self.audio_vae.denorm_std;
        // Broadcast the [C] statistics over [B, T, C].
        let mean = ops::reshape(mean, &[1, 1, mean.shape()[0]])?;
        let std_ = ops::reshape(std_, &[1, 1, std_.shape()[0]])?;
        let denormed = ops::add(&ops::multiply(tokens, &std_)?, &mean)?;
        let s = denormed.shape();
        let (b, t) = (s[0], s[1]);
        // Hard-coded latent layout: 128 channels = 8 VAE channels x 16 bins.
        // NOTE(review): must match audio_latent_c in `generate` — confirm.
        let c: i32 = 8;
        let f: i32 = 16;
        let reshaped = ops::reshape(&denormed, &[b, t, c, f])?;
        let vae_latent = ops::transpose_axes(&reshaped, &[0, 2, 1, 3])?;
        let mel = self.audio_vae.decode(&vae_latent)?;
        let ms = mel.shape();
        debug_assert_eq!(ms[1], 2, "audio VAE must output stereo mel");
        let (mb, mt, mf) = (ms[0], ms[2], ms[3]);
        // [B, 2, T, F] -> [B, 2, F, T] -> [B, 2F, T]: stereo folded into bins.
        let mel_bctf = ops::transpose_axes(&mel, &[0, 1, 3, 2])?; let mel_chan = ops::reshape(&mel_bctf, &[mb, 2 * mf, mt])?;
        self.vocoder.forward(&mel_chan)
    }
    /// For i2v mode, VAE-encodes the reference image into a `[1, H*W, C]`
    /// latent slice used to anchor the first temporal frame. Returns
    /// `Ok(None)` for modes without an anchor (t2v, audio modes) and
    /// `UnsupportedMode` for extend/retake, which are not wired up yet.
    fn compute_anchor_latent(
        &self,
        req: &GenerateVideoRequest,
        latent_h: i32,
        latent_w: i32,
    ) -> Result<Option<Array>, InferenceError> {
        let image_path = match req.effective_mode() {
            VideoMode::T2v | VideoMode::AudioVideo | VideoMode::AudioRefVideo => {
                return Ok(None);
            }
            VideoMode::I2v => req.image_path.as_deref().ok_or_else(|| {
                InferenceError::InferenceFailed("i2v mode requires image_path".to_string())
            })?,
            VideoMode::Extend => {
                return Err(InferenceError::UnsupportedMode {
                    mode: "extend",
                    backend: "native-mlx-ltx",
                    reason: "accepted on the request surface but not yet wired; \
                        video extension requires VAE-encoding the input clip and \
                        splicing its latents into the denoising schedule",
                });
            }
            VideoMode::Retake => {
                return Err(InferenceError::UnsupportedMode {
                    mode: "retake",
                    backend: "native-mlx-ltx",
                    reason: "accepted on the request surface but not yet wired; \
                        retake requires partial-frame-range masked diffusion on \
                        VAE-encoded input latents",
                });
            }
        };
        let encoder = self.vae_encoder.as_ref().ok_or_else(|| {
            InferenceError::InferenceFailed(
                "image-to-video requires the VAE encoder, which is not available \
                 in this checkpoint (missing vae_encoder.safetensors)"
                    .to_string(),
            )
        })?;
        // Pixel size derived from the latent slab (32x spatial downsample).
        let pixel_w = (latent_w * 32) as u32;
        let pixel_h = (latent_h * 32) as u32;
        info!(
            path = image_path,
            pixel_w, pixel_h, "encoding i2v reference frame"
        );
        let img = load_rgb_image(std::path::Path::new(image_path), pixel_w, pixel_h)?;
        // Rescale [0, 1] pixels to the encoder's [-1, 1] range.
        let img = {
            let two = Array::from_f32(2.0);
            let one = Array::from_f32(1.0);
            ops::subtract(
                &ops::multiply(&img, &two)
                    .map_err(|e| InferenceError::InferenceFailed(format!("scale img: {e}")))?,
                &one,
            )
            .map_err(|e| InferenceError::InferenceFailed(format!("shift img: {e}")))?
        };
        // Tile the still image to 8 identical frames (one temporal latent).
        let s = img.shape();
        let frames = ops::reshape(&img, &[s[0], 1, s[1], s[2], s[3]])
            .and_then(|x| ops::tile(&x, &[1, 8, 1, 1, 1]))
            .map_err(|e| InferenceError::InferenceFailed(format!("tile ref frame: {e}")))?;
        let latent = encoder
            .encode(&frames)
            .map_err(|e| InferenceError::InferenceFailed(format!("vae_encoder.encode: {e}")))?;
        let shape = latent.shape();
        let c = shape[1];
        let lh = shape[3];
        let lw = shape[4];
        debug_assert_eq!(
            (lh, lw),
            (latent_h, latent_w),
            "encoder latent spatial size must match computed latent slab \
             (internal invariant — inputs derived from latent_h/w above)"
        );
        // Keep only the first temporal slice and flatten it to [1, H*W, C]
        // patch-sequence order for splicing into the denoising latents.
        let first_t = latent.index((.., .., 0..1, .., ..));
        let nhwc = ops::transpose_axes(&first_t, &[0, 2, 3, 4, 1])
            .map_err(|e| InferenceError::InferenceFailed(format!("anchor transpose: {e}")))?;
        let seq = ops::reshape(&nhwc, &[1, lh * lw, c])
            .map_err(|e| InferenceError::InferenceFailed(format!("anchor reshape: {e}")))?;
        Ok(Some(seq))
    }
}
/// Root of the Hugging Face hub cache (`<cache>/hub`).
///
/// Resolution order mirrors `huggingface_hub`'s documented cache lookup:
/// `$HF_HOME`, then `$XDG_CACHE_HOME/huggingface`, then
/// `$HOME/.cache/huggingface`, finally `./.cache/huggingface` when no home
/// directory is known.
fn huggingface_cache_root() -> PathBuf {
    let base = std::env::var("HF_HOME").map(PathBuf::from).unwrap_or_else(|_| {
        // No HF_HOME: honor XDG_CACHE_HOME before falling back to ~/.cache,
        // matching the upstream huggingface_hub client.
        std::env::var("XDG_CACHE_HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|_| {
                std::env::var("HOME")
                    .map(PathBuf::from)
                    .unwrap_or_else(|_| PathBuf::from("."))
                    .join(".cache")
            })
            .join("huggingface")
    });
    base.join("hub")
}
fn huggingface_repo_dir(repo_id: &str) -> PathBuf {
huggingface_cache_root().join(format!("models--{}", repo_id.replace('/', "--")))
}
/// Resolves a hub ref (e.g. `main`) to its snapshot directory.
///
/// Reads `refs/<name>` inside `repo_dir` — the file holds the commit sha the
/// ref points at — and returns `snapshots/<sha>` if that directory exists.
/// Returns `None` when the ref file is missing, empty, or the snapshot
/// directory is absent.
fn resolve_huggingface_ref_snapshot(repo_dir: &Path, name: &str) -> Option<PathBuf> {
    let ref_file = repo_dir.join("refs").join(name);
    let contents = std::fs::read_to_string(ref_file).ok()?;
    let sha = contents.trim();
    if sha.is_empty() {
        return None;
    }
    let snapshot = repo_dir.join("snapshots").join(sha);
    snapshot.is_dir().then_some(snapshot)
}
/// Picks a snapshot directory for `repo_id` from the local hub cache.
///
/// Prefers whatever the `main` ref points at; when no usable ref exists,
/// falls back to the most recently modified directory under `snapshots/`.
/// Returns `None` when the cache holds nothing for this repo.
fn latest_huggingface_repo_snapshot(repo_id: &str) -> Option<PathBuf> {
    let repo_dir = huggingface_repo_dir(repo_id);
    if let Some(snapshot) = resolve_huggingface_ref_snapshot(&repo_dir, "main") {
        return Some(snapshot);
    }
    // Scan snapshot dirs lazily and keep the newest by mtime; unreadable
    // metadata sorts first (epoch) so it only wins when nothing else exists.
    std::fs::read_dir(repo_dir.join("snapshots"))
        .ok()?
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|path| path.is_dir())
        .map(|path| {
            let modified = path
                .metadata()
                .and_then(|metadata| metadata.modified())
                .unwrap_or(SystemTime::UNIX_EPOCH);
            (modified, path)
        })
        .max()
        .map(|(_, path)| path)
}
/// Overwrites the leading `anchor_patches` tokens of the patch sequence with
/// the anchor tokens, leaving the remainder of the sequence untouched.
///
/// `latents` and `anchor` are token sequences indexed `[batch, patch, chan]`;
/// the result is the anchor concatenated with `latents[:, anchor_patches:, :]`.
fn splice_first_temporal_slice(
    latents: &Array,
    anchor: &Array,
    anchor_patches: usize,
) -> Result<Array, mlx_rs::error::Exception> {
    debug_assert!(anchor_patches <= latents.shape()[1] as usize);
    let skip = anchor_patches as i32;
    let remainder = latents.index((.., skip.., ..));
    ops::concatenate_axis(&[anchor, &remainder], 1)
}
/// Replaces the audio track of `video_path` in place by piping the generated
/// waveform to `ffmpeg` as raw f32le PCM and encoding it to AAC.
///
/// `waveform` must be stereo: `[2, samples]`, or `[1, 2, samples]` with a
/// leading batch dimension of 1. Non-finite samples become silence and all
/// samples are hard-clipped to [-1, 1] before encoding. The video stream is
/// stream-copied; the result atomically replaces `video_path` on success.
fn mux_audio_into_mp4(
    video_path: &str,
    waveform: &Array,
    sample_rate: u32,
) -> Result<(), InferenceError> {
    let w_shape = waveform.shape();
    let stereo = match w_shape.len() {
        2 => waveform.clone(),
        // Squeeze an explicit batch dimension; reject batches > 1 up front so
        // the caller gets a clear message instead of an opaque reshape error.
        3 if w_shape[0] == 1 => ops::reshape(waveform, &[w_shape[1], w_shape[2]])
            .map_err(|e| InferenceError::InferenceFailed(format!("waveform squeeze: {e}")))?,
        3 => {
            return Err(InferenceError::InferenceFailed(format!(
                "expected waveform batch of 1, got {}",
                w_shape[0]
            )))
        }
        other => {
            return Err(InferenceError::InferenceFailed(format!(
                "expected waveform rank 2 or 3, got {other}"
            )))
        }
    };
    let s = stereo.shape();
    if s[0] != 2 {
        return Err(InferenceError::InferenceFailed(format!(
            "expected stereo waveform, got {} channels",
            s[0]
        )));
    }
    let n_samples = s[1];
    // [2, n] -> [n, 2]: ffmpeg's raw f32le demuxer expects interleaved L/R.
    let interleaved = ops::transpose_axes(&stereo, &[1, 0])
        .map_err(|e| InferenceError::InferenceFailed(format!("interleave: {e}")))?;
    mlx_rs::transforms::eval([&interleaved])
        .map_err(|e| InferenceError::InferenceFailed(format!("eval waveform: {e}")))?;
    let raw: &[f32] = interleaved.as_slice::<f32>();
    let mut bytes: Vec<u8> = Vec::with_capacity(raw.len() * 4);
    for &v in raw {
        // Sanitize: NaN/inf become silence, everything else is hard-clipped.
        let clean = if v.is_finite() {
            v.clamp(-1.0, 1.0)
        } else {
            0.0
        };
        bytes.extend_from_slice(&clean.to_le_bytes());
    }
    // Nanosecond-stamped scratch file so concurrent muxes don't collide.
    let pcm_path = std::env::temp_dir().join(format!(
        "ltx_audio_{}.pcm",
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0)
    ));
    std::fs::write(&pcm_path, &bytes)
        .map_err(|e| InferenceError::InferenceFailed(format!("write pcm: {e}")))?;
    // Write next to the target and rename over it only on success.
    let tmp_out = format!("{video_path}.with_audio.mp4");
    let result = std::process::Command::new("ffmpeg")
        .arg("-y")
        .args(["-i", video_path])
        // Raw-PCM input options must precede the second `-i`.
        .args(["-f", "f32le", "-ar"])
        .arg(sample_rate.to_string())
        .args(["-ac", "2", "-i"])
        // Pass the path as an OsStr: `to_str().unwrap_or("")` would silently
        // hand ffmpeg an empty input path on non-UTF-8 temp dirs.
        .arg(&pcm_path)
        .args(["-c:v", "copy", "-c:a", "aac", "-b:a", "192k", "-shortest"])
        .arg(&tmp_out)
        .output();
    // The PCM scratch file is no longer needed regardless of the outcome.
    let _ = std::fs::remove_file(&pcm_path);
    match result {
        Ok(o) if o.status.success() => {
            if let Err(e) = std::fs::rename(&tmp_out, video_path) {
                let _ = std::fs::remove_file(&tmp_out);
                return Err(InferenceError::InferenceFailed(format!("replace mp4: {e}")));
            }
            info!(samples = n_samples, sample_rate, "muxed audio into mp4");
            Ok(())
        }
        Ok(o) => {
            let _ = std::fs::remove_file(&tmp_out);
            Err(InferenceError::InferenceFailed(format!(
                "ffmpeg audio mux failed: {}",
                String::from_utf8_lossy(&o.stderr)
            )))
        }
        Err(e) => {
            let _ = std::fs::remove_file(&tmp_out);
            Err(InferenceError::InferenceFailed(format!(
                "ffmpeg audio mux spawn failed: {e}"
            )))
        }
    }
}