use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use anyhow::{bail, Context, Result};
use candle_core::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::ltx_video::{
sampling::{
FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerConfig, TimeShiftType,
},
transformer::{LtxVideoTransformer3DModel, LtxVideoTransformer3DModelConfig},
vae::{AutoencoderKLLtxVideo, AutoencoderKLLtxVideoConfig},
};
use mold_core::{GenerateRequest, GenerateResponse, ModelPaths, OutputFormat, VideoData};
use crate::device::{fmt_gb, usable_free_vram_bytes};
use crate::engine::{gpu_dtype, rand_seed, seeded_randn, LoadStrategy};
use crate::engine_base::EngineBase;
use crate::progress::{ProgressCallback, ProgressEvent};
use crate::shared_pool::SharedPool;
use super::{latent_upsampler::LatentUpsampler, video_enc};
/// Spatial downsampling factor of the VAE: pixel width/height must be multiples of this and are
/// divided by it to obtain the latent width/height.
const VAE_SPATIAL_COMPRESSION: usize = 32;
/// Temporal downsampling factor of the VAE: the latent frame count is
/// `(num_frames - 1) / VAE_TEMPORAL_COMPRESSION + 1`.
const VAE_TEMPORAL_COMPRESSION: usize = 8;
/// Channel count of the noise latents sampled at the start of a pass.
const LATENT_CHANNELS: usize = 128;
/// Spatial patch size passed to `pack_latents` / `unpack_latents`.
const PATCH_SIZE: usize = 1;
/// Temporal patch size passed to `pack_latents` / `unpack_latents`.
const PATCH_SIZE_T: usize = 1;
/// Custom sigma schedule for the base pass and the multiscale first pass of the 0.9.8
/// distilled presets (7 values, paired with `num_inference_steps: 7`).
const LTX_098_DISTILLED_FIRST_PASS_SIGMAS: &[f32] =
    &[1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250];
/// Custom sigma schedule for the multiscale second (refinement) pass of the 0.9.8 distilled
/// presets (3 values, paired with `num_inference_steps: 3`).
const LTX_098_DISTILLED_SECOND_PASS_SIGMAS: &[f32] = &[0.9094, 0.7250, 0.4219];
// The per-interval lists below configure the 0.9.8 13B dev multiscale first pass. Entry `i` of
// each list applies to the steps whose sigma `resolve_guidance_index` maps to index `i` of
// `LTX_098_DEV_FIRST_PASS_GUIDANCE_TIMESTEPS`.
/// Classifier-free guidance scale per guidance interval.
const LTX_098_DEV_FIRST_PASS_GUIDANCE_SCALE: &[f32] = &[1.0, 1.0, 6.0, 8.0, 6.0, 1.0, 1.0];
/// STG scale per guidance interval.
const LTX_098_DEV_FIRST_PASS_STG_SCALE: &[f32] = &[0.0, 0.0, 4.0, 4.0, 4.0, 2.0, 1.0];
/// `rescale_noise_cfg` factor per guidance interval.
const LTX_098_DEV_FIRST_PASS_RESCALING_SCALE: &[f32] = &[1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 1.0];
/// Descending sigma thresholds that delimit the guidance intervals.
const LTX_098_DEV_FIRST_PASS_GUIDANCE_TIMESTEPS: &[f32] =
    &[1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180];
/// Transformer block indices used as the skip-block list of the 0.9.6 dev preset.
const LTX_096_DEV_SKIP_BLOCKS: &[usize] = &[19];
/// Skip-block list of the 0.9.8 2B distilled preset (empty).
const LTX_098_2B_DISTILLED_SKIP_BLOCKS: &[usize] = &[];
/// Skip-block list of the 0.9.8 13B distilled preset.
const LTX_098_13B_DISTILLED_SKIP_BLOCKS: &[usize] = &[42];
/// Per-interval skip-block lists of the 0.9.8 13B dev first pass, indexed like the
/// per-interval guidance lists above.
const LTX_098_13B_DEV_FIRST_PASS_SKIP_BLOCKS: &[&[usize]] = &[
    &[],
    &[11, 25, 35, 39],
    &[22, 35, 39],
    &[28],
    &[28],
    &[28],
    &[28],
];
/// Skip-block list of the 0.9.8 13B dev second pass.
const LTX_098_13B_DEV_SECOND_PASS_SKIP_BLOCKS: &[usize] = &[27];
/// Bytes of usable VRAM required on top of the transformer weight size before the legacy 13B
/// transformer may load fully on CUDA (2 GB), enforced by
/// `ltx_video_transformer_residency_guard`.
const LTX_VIDEO_FULL_RESIDENT_RUNTIME_HEADROOM: u64 = 2_000_000_000;
/// Reports whether `path` names an official single-file LTX transformer checkpoint.
///
/// A match is a UTF-8 file name that starts with `ltx` and ends with `.safetensors`; the
/// diffusers shard name (`diffusion_pytorch_model*`) is never treated as official. Paths
/// without a file name, or with a non-UTF-8 one, yield `false`.
fn is_official_ltx_transformer_checkpoint(path: &std::path::Path) -> bool {
    let Some(file_name) = path.file_name().and_then(|raw| raw.to_str()) else {
        return false;
    };
    let looks_official = file_name.starts_with("ltx") && file_name.ends_with(".safetensors");
    looks_official && !file_name.starts_with("diffusion_pytorch_model")
}
/// Rewrites a tensor name for lookup in an official LTX checkpoint.
///
/// Used with `VarBuilder::rename_f` in `load_transformer`. The four substring renames are
/// applied in the listed order, then the result is prefixed with `model.diffusion_model.`.
fn remap_official_ltx_transformer_key(key: &str) -> String {
    const RENAMES: [(&str, &str); 4] = [
        ("proj_in", "patchify_proj"),
        ("time_embed", "adaln_single"),
        ("norm_q", "q_norm"),
        ("norm_k", "k_norm"),
    ];
    let renamed = RENAMES
        .iter()
        .fold(key.to_owned(), |acc, &(from, to)| acc.replace(from, to));
    format!("model.diffusion_model.{renamed}")
}
/// How many denoising passes a preset runs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum LtxPipelineMode {
    /// A single pass configured by `LtxModelPreset::base_pass`.
    Base,
    /// The two-pass pipeline configured by `LtxModelPreset::multiscale`; `generate_sequential`
    /// requires a spatial upscaler asset for this mode.
    Multiscale,
}
/// Guidance settings for one denoising pass.
///
/// Each list is either a single value (applied to every step) or one value per guidance
/// interval, selected per step through `guidance_timesteps` by `resolve_step_schedule`.
#[derive(Clone, Debug)]
struct LtxGuidanceConfig {
    /// Classifier-free guidance scale(s); CFG runs only when the resolved value exceeds 1.0.
    guidance_scale: Vec<f32>,
    /// Spatio-temporal guidance (STG) scale(s).
    stg_scale: Vec<f32>,
    /// Factor(s) for `rescale_noise_cfg`.
    rescaling_scale: Vec<f32>,
    /// Descending sigma thresholds mapping each step to an interval index; `None` means index 0.
    guidance_timesteps: Option<Vec<f32>>,
    /// Transformer block indices to skip, one list per interval (or a single shared list).
    skip_block_list: Vec<Vec<usize>>,
    /// Whether the unconditional prediction is rescaled with `cfg_star_rescale_uncond`.
    cfg_star_rescale: bool,
}

impl LtxGuidanceConfig {
    /// Builds a schedule-free configuration: every list holds exactly one value, there are no
    /// guidance timesteps, and CFG-star rescaling is off.
    fn constant(
        guidance_scale: f32,
        stg_scale: f32,
        rescaling_scale: f32,
        skip_block_list: &[usize],
    ) -> Self {
        let shared_skip_blocks = Vec::from(skip_block_list);
        Self {
            cfg_star_rescale: false,
            guidance_timesteps: None,
            skip_block_list: vec![shared_skip_blocks],
            rescaling_scale: vec![rescaling_scale],
            stg_scale: vec![stg_scale],
            guidance_scale: vec![guidance_scale],
        }
    }
}
/// Settings for a single denoising pass (see `denoise_pass`).
#[derive(Clone, Debug)]
struct LtxPassConfig {
    /// Step count handed to the scheduler; ignored when `custom_sigmas` is set.
    num_inference_steps: u32,
    /// Explicit sigma schedule; takes precedence over `num_inference_steps`.
    custom_sigmas: Option<Vec<f32>>,
    /// Number of steps dropped from the start of the scheduler's schedule.
    skip_initial_inference_steps: usize,
    /// Number of steps dropped from the end of the scheduler's schedule.
    skip_final_inference_steps: usize,
    /// Per-step guidance configuration.
    guidance: LtxGuidanceConfig,
    /// NOTE(review): presumably the `compression` argument for `tone_map_latents`
    /// (0.0 leaves latents unchanged); its use is outside this view — confirm.
    tone_map_compression_ratio: f32,
}
/// Two-pass (multiscale) settings used by `LtxPipelineMode::Multiscale` presets.
#[derive(Clone, Debug)]
struct LtxMultiscaleConfig {
    /// NOTE(review): presumably the resolution scale of the first pass relative to the
    /// requested size; its use is outside this view — confirm.
    downscale_factor: f32,
    /// Settings of the first denoising pass.
    first_pass: LtxPassConfig,
    /// Settings of the second (refinement) pass.
    second_pass: LtxPassConfig,
}
/// Everything needed to run one LTX model variant, as resolved by `LtxModelPreset::for_model`.
#[derive(Clone, Debug)]
struct LtxModelPreset {
    /// Transformer architecture (2B or 13B layout).
    transformer_config: LtxVideoTransformer3DModelConfig,
    /// VAE architecture (all presets in `for_model` use `improved_vae_config()`).
    vae_config: AutoencoderKLLtxVideoConfig,
    /// Flow-matching scheduler settings.
    scheduler_config: FlowMatchEulerDiscreteSchedulerConfig,
    /// Single-pass settings. In multiscale mode `generate_sequential` compares the user's
    /// --steps/--guidance against this pass to report that they are ignored.
    base_pass: LtxPassConfig,
    /// NOTE(review): presumably the timestep used when decoding with the timestep-conditioned
    /// VAE; its use is outside this view — confirm.
    decode_timestep: f32,
    /// NOTE(review): presumably the noise scale injected before VAE decoding; its use is
    /// outside this view — confirm.
    decode_noise_scale: f32,
    /// Single-pass or two-pass pipeline.
    mode: LtxPipelineMode,
    /// Two-pass settings; `Some` exactly for the `Multiscale` presets in `for_model`.
    multiscale: Option<LtxMultiscaleConfig>,
}
impl LtxModelPreset {
fn for_model(model_name: &str) -> Result<Self> {
if model_name.contains("ltx-video-0.9.6-distilled") {
Ok(Self {
transformer_config: transformer_2b_config(),
vae_config: improved_vae_config(),
scheduler_config: scheduler_config(true),
base_pass: LtxPassConfig {
num_inference_steps: 8,
custom_sigmas: None,
skip_initial_inference_steps: 0,
skip_final_inference_steps: 0,
guidance: LtxGuidanceConfig::constant(1.0, 0.0, 1.0, &[]),
tone_map_compression_ratio: 0.0,
},
decode_timestep: 0.05,
decode_noise_scale: 0.025,
mode: LtxPipelineMode::Base,
multiscale: None,
})
} else if model_name.contains("ltx-video-0.9.6") {
Ok(Self {
transformer_config: transformer_2b_config(),
vae_config: improved_vae_config(),
scheduler_config: scheduler_config(false),
base_pass: LtxPassConfig {
num_inference_steps: 40,
custom_sigmas: None,
skip_initial_inference_steps: 0,
skip_final_inference_steps: 0,
guidance: LtxGuidanceConfig::constant(3.0, 1.0, 0.7, LTX_096_DEV_SKIP_BLOCKS),
tone_map_compression_ratio: 0.0,
},
decode_timestep: 0.05,
decode_noise_scale: 0.025,
mode: LtxPipelineMode::Base,
multiscale: None,
})
} else if model_name.contains("ltx-video-0.9.8-2b-distilled") {
Ok(Self {
transformer_config: transformer_2b_config(),
vae_config: improved_vae_config(),
scheduler_config: scheduler_config(false),
base_pass: LtxPassConfig {
num_inference_steps: 7,
custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
skip_initial_inference_steps: 0,
skip_final_inference_steps: 0,
guidance: LtxGuidanceConfig::constant(
1.0,
0.0,
1.0,
LTX_098_2B_DISTILLED_SKIP_BLOCKS,
),
tone_map_compression_ratio: 0.0,
},
decode_timestep: 0.05,
decode_noise_scale: 0.025,
mode: LtxPipelineMode::Multiscale,
multiscale: Some(LtxMultiscaleConfig {
downscale_factor: 0.6666666,
first_pass: LtxPassConfig {
num_inference_steps: 7,
custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
skip_initial_inference_steps: 0,
skip_final_inference_steps: 0,
guidance: LtxGuidanceConfig::constant(
1.0,
0.0,
1.0,
LTX_098_2B_DISTILLED_SKIP_BLOCKS,
),
tone_map_compression_ratio: 0.0,
},
second_pass: LtxPassConfig {
num_inference_steps: 3,
custom_sigmas: Some(LTX_098_DISTILLED_SECOND_PASS_SIGMAS.to_vec()),
skip_initial_inference_steps: 0,
skip_final_inference_steps: 0,
guidance: LtxGuidanceConfig::constant(
1.0,
0.0,
1.0,
LTX_098_2B_DISTILLED_SKIP_BLOCKS,
),
tone_map_compression_ratio: 0.0,
},
}),
})
} else if model_name.contains("ltx-video-0.9.8-13b-distilled") {
Ok(Self {
transformer_config: transformer_13b_config(),
vae_config: improved_vae_config(),
scheduler_config: scheduler_config(false),
base_pass: LtxPassConfig {
num_inference_steps: 7,
custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
skip_initial_inference_steps: 0,
skip_final_inference_steps: 0,
guidance: LtxGuidanceConfig::constant(
1.0,
0.0,
1.0,
LTX_098_13B_DISTILLED_SKIP_BLOCKS,
),
tone_map_compression_ratio: 0.0,
},
decode_timestep: 0.05,
decode_noise_scale: 0.025,
mode: LtxPipelineMode::Multiscale,
multiscale: Some(LtxMultiscaleConfig {
downscale_factor: 0.6666666,
first_pass: LtxPassConfig {
num_inference_steps: 7,
custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
skip_initial_inference_steps: 0,
skip_final_inference_steps: 0,
guidance: LtxGuidanceConfig::constant(
1.0,
0.0,
1.0,
LTX_098_13B_DISTILLED_SKIP_BLOCKS,
),
tone_map_compression_ratio: 0.0,
},
second_pass: LtxPassConfig {
num_inference_steps: 3,
custom_sigmas: Some(LTX_098_DISTILLED_SECOND_PASS_SIGMAS.to_vec()),
skip_initial_inference_steps: 0,
skip_final_inference_steps: 0,
guidance: LtxGuidanceConfig::constant(
1.0,
0.0,
1.0,
LTX_098_13B_DISTILLED_SKIP_BLOCKS,
),
tone_map_compression_ratio: 0.6,
},
}),
})
} else if model_name.contains("ltx-video-0.9.8-13b-dev") {
Ok(Self {
transformer_config: transformer_13b_config(),
vae_config: improved_vae_config(),
scheduler_config: scheduler_config(false),
base_pass: LtxPassConfig {
num_inference_steps: 30,
custom_sigmas: None,
skip_initial_inference_steps: 0,
skip_final_inference_steps: 0,
guidance: LtxGuidanceConfig::constant(8.0, 4.0, 0.5, &[28]),
tone_map_compression_ratio: 0.0,
},
decode_timestep: 0.05,
decode_noise_scale: 0.025,
mode: LtxPipelineMode::Multiscale,
multiscale: Some(LtxMultiscaleConfig {
downscale_factor: 0.6666666,
first_pass: LtxPassConfig {
num_inference_steps: 30,
custom_sigmas: None,
skip_initial_inference_steps: 0,
skip_final_inference_steps: 3,
guidance: LtxGuidanceConfig {
guidance_scale: LTX_098_DEV_FIRST_PASS_GUIDANCE_SCALE.to_vec(),
stg_scale: LTX_098_DEV_FIRST_PASS_STG_SCALE.to_vec(),
rescaling_scale: LTX_098_DEV_FIRST_PASS_RESCALING_SCALE.to_vec(),
guidance_timesteps: Some(
LTX_098_DEV_FIRST_PASS_GUIDANCE_TIMESTEPS.to_vec(),
),
skip_block_list: LTX_098_13B_DEV_FIRST_PASS_SKIP_BLOCKS
.iter()
.map(|blocks| blocks.to_vec())
.collect(),
cfg_star_rescale: true,
},
tone_map_compression_ratio: 0.0,
},
second_pass: LtxPassConfig {
num_inference_steps: 30,
custom_sigmas: None,
skip_initial_inference_steps: 17,
skip_final_inference_steps: 0,
guidance: LtxGuidanceConfig {
guidance_scale: vec![1.0],
stg_scale: vec![1.0],
rescaling_scale: vec![1.0],
guidance_timesteps: Some(vec![1.0]),
skip_block_list: vec![LTX_098_13B_DEV_SECOND_PASS_SKIP_BLOCKS.to_vec()],
cfg_star_rescale: true,
},
tone_map_compression_ratio: 0.0,
},
}),
})
} else {
bail!("unsupported LTX model preset for {}", model_name);
}
}
}
/// Transformer layout of the 2B LTX checkpoints: 28 layers of 32 heads x 64 dims with a
/// 2048-wide cross-attention and 4096-wide caption input; all other fields use the defaults.
fn transformer_2b_config() -> LtxVideoTransformer3DModelConfig {
    LtxVideoTransformer3DModelConfig {
        caption_channels: 4096,
        cross_attention_dim: 2048,
        num_attention_heads: 32,
        attention_head_dim: 64,
        num_layers: 28,
        ..Default::default()
    }
}
fn transformer_13b_config() -> LtxVideoTransformer3DModelConfig {
LtxVideoTransformer3DModelConfig {
num_layers: 48,
num_attention_heads: 32,
attention_head_dim: 128,
cross_attention_dim: 4096,
caption_channels: 4096,
..Default::default()
}
}
/// True for the legacy 13B LTX-Video transformer: either the model name mentions `13b`, or the
/// preset's transformer has at least 48 layers with head dimension of at least 128.
fn is_legacy_ltx_video_13b(model_name: &str, preset: &LtxModelPreset) -> bool {
    if model_name.contains("13b") {
        return true;
    }
    let layout = &preset.transformer_config;
    layout.num_layers >= 48 && layout.attention_head_dim >= 128
}
/// Refuses to load the legacy 13B transformer on CUDA when it cannot stay fully resident.
///
/// The check only applies on CUDA, only for legacy 13B models, and only when a positive
/// usable-VRAM figure is known. It then requires `transformer_bytes` plus
/// `LTX_VIDEO_FULL_RESIDENT_RUNTIME_HEADROOM` (saturating) to fit into `usable_vram_bytes`.
///
/// # Errors
/// Returns an explanatory error, with sizes formatted by `fmt_gb`, when the requirement exceeds
/// the usable VRAM.
fn ltx_video_transformer_residency_guard(
    model_name: &str,
    preset: &LtxModelPreset,
    transformer_bytes: u64,
    usable_vram_bytes: Option<u64>,
    is_cuda: bool,
) -> Result<()> {
    let guard_applies = is_cuda && is_legacy_ltx_video_13b(model_name, preset);
    let available = match usable_vram_bytes {
        Some(bytes) if guard_applies && bytes > 0 => bytes,
        _ => return Ok(()),
    };
    let required = transformer_bytes.saturating_add(LTX_VIDEO_FULL_RESIDENT_RUNTIME_HEADROOM);
    if required > available {
        bail!(
            "legacy LTX-Video 13B BF16 requires full transformer residency ({} weights + {} runtime headroom) but only {} usable VRAM is available. MOLD_OFFLOAD is not implemented for this legacy transformer yet; use ltx-video-0.9.8-2b-distilled, lower --width/--height/--frames, or use an LTX-2 FP8 model with adaptive offload.",
            fmt_gb(transformer_bytes),
            fmt_gb(LTX_VIDEO_FULL_RESIDENT_RUNTIME_HEADROOM),
            fmt_gb(available),
        );
    }
    Ok(())
}
/// VAE layout shared by every LTX 0.9.6/0.9.8 preset in this file.
///
/// Timestep conditioning is enabled; fields not listed keep their defaults.
fn improved_vae_config() -> AutoencoderKLLtxVideoConfig {
    AutoencoderKLLtxVideoConfig {
        timestep_conditioning: true,
        // Presumably the encoder-side layout (the decoder has its own `decoder_*` fields).
        block_out_channels: vec![128, 256, 512, 1024, 2048],
        layers_per_block: vec![4, 6, 6, 2, 2],
        spatiotemporal_scaling: vec![true, true, true, true],
        // Decoder layout.
        decoder_block_out_channels: vec![256, 512, 1024],
        decoder_layers_per_block: vec![5, 5, 5, 5],
        decoder_spatiotemporal_scaling: vec![true, true, true],
        decoder_inject_noise: vec![false, false, false, false],
        decoder_upsample_residual: vec![true, true, true],
        decoder_upsample_factor: vec![2, 2, 2],
        ..Default::default()
    }
}
/// Flow-matching Euler scheduler configuration shared by every LTX preset.
///
/// Only `stochastic_sampling` varies between presets. Dynamic shifting is disabled; the shift
/// bounds are still filled in, presumably unused while it is off — confirm against the
/// scheduler implementation.
fn scheduler_config(stochastic_sampling: bool) -> FlowMatchEulerDiscreteSchedulerConfig {
    FlowMatchEulerDiscreteSchedulerConfig {
        stochastic_sampling,
        time_shift_type: TimeShiftType::Exponential,
        num_train_timesteps: 1000,
        shift: 1.0,
        shift_terminal: None,
        invert_sigmas: false,
        // Dynamic-shift parameters.
        use_dynamic_shifting: false,
        base_shift: Some(0.5),
        max_shift: Some(1.15),
        base_image_seq_len: Some(256),
        max_image_seq_len: Some(4096),
        // No alternative sigma spacing.
        use_karras_sigmas: false,
        use_exponential_sigmas: false,
        use_beta_sigmas: false,
    }
}
/// Guidance values resolved for one denoising step by `resolve_step_schedule`.
#[derive(Clone, Debug)]
struct LtxResolvedStep {
    /// Classifier-free guidance scale for this step (CFG runs only when > 1.0).
    guidance_scale: f32,
    /// STG scale for this step (STG runs only when > 0.0 and `skip_blocks` is non-empty).
    stg_scale: f32,
    /// Factor for `rescale_noise_cfg` at this step.
    rescaling_scale: f32,
    /// Transformer block indices for this step, already filtered to `< num_layers`.
    skip_blocks: Vec<usize>,
}
/// Returns the block indices from `skip_blocks` that exist in a model with `num_layers` layers,
/// keeping their original order (out-of-range indices are dropped, not clamped).
fn clamp_skip_blocks(skip_blocks: &[usize], num_layers: usize) -> Vec<usize> {
    let mut in_range = skip_blocks.to_vec();
    in_range.retain(|&block| block < num_layers);
    in_range
}
/// Maps `sigma` to a guidance interval index.
///
/// Returns the index of the first threshold in `guidance_timesteps` that is `<= sigma`. When
/// no threshold qualifies (including a NaN sigma), the last index is returned; an empty list
/// yields 0.
fn resolve_guidance_index(guidance_timesteps: &[f32], sigma: f32) -> usize {
    for (index, &threshold) in guidance_timesteps.iter().enumerate() {
        if threshold <= sigma {
            return index;
        }
    }
    guidance_timesteps.len().saturating_sub(1)
}
/// Resolves `pass.guidance` into one `LtxResolvedStep` per entry of `sigmas`.
///
/// Each sigma is mapped to an interval index through `guidance_timesteps`, or to 0 when none
/// are configured. Single-entry lists apply to every step. Longer lists are indexed with the
/// interval index, clamped to their last entry. Skip blocks are filtered to `< num_layers`.
///
/// # Panics
/// Panics if one of the scale lists is empty (the presets in this file never are).
fn resolve_step_schedule(
    pass: &LtxPassConfig,
    sigmas: &[f32],
    num_layers: usize,
) -> Vec<LtxResolvedStep> {
    /// Single-entry lists are constant; otherwise pick `index`, clamped to the last entry.
    fn pick(values: &[f32], index: usize) -> f32 {
        if values.len() == 1 {
            values[0]
        } else {
            values[index.min(values.len() - 1)]
        }
    }

    let guidance = &pass.guidance;
    let mut schedule = Vec::with_capacity(sigmas.len());
    for &sigma in sigmas {
        let index = match guidance.guidance_timesteps.as_deref() {
            Some(thresholds) => resolve_guidance_index(thresholds, sigma),
            None => 0,
        };
        let skip_blocks = match guidance.skip_block_list.as_slice() {
            [] => Vec::new(),
            [shared] => clamp_skip_blocks(shared, num_layers),
            per_interval => clamp_skip_blocks(
                &per_interval[index.min(per_interval.len() - 1)],
                num_layers,
            ),
        };
        schedule.push(LtxResolvedStep {
            guidance_scale: pick(&guidance.guidance_scale, index),
            stg_scale: pick(&guidance.stg_scale, index),
            rescaling_scale: pick(&guidance.rescaling_scale, index),
            skip_blocks,
        });
    }
    schedule
}
/// Per-sample standard deviation over every dimension except dim 0, with those dims kept as 1.
///
/// The result has shape `(batch, 1, ..., 1)` with the same rank as `x`. The variance comes
/// from candle's `var_keepdim`.
///
/// # Errors
/// Fails for rank < 2 or on any tensor error.
fn std_over_dims_except0_keepdim(x: &Tensor) -> Result<Tensor> {
    let rank = x.rank();
    if rank < 2 {
        bail!("std_over_dims_except0_keepdim expects rank >= 2, got {rank}");
    }
    let batch = x.dim(0)?;
    let per_sample_std = x.flatten_from(1)?.var_keepdim(1)?.sqrt()?;
    let keepdim_shape: Vec<usize> = std::iter::once(batch)
        .chain(std::iter::repeat_n(1usize, rank - 1))
        .collect();
    Ok(per_sample_std.reshape(keepdim_shape)?)
}
/// Rescales a guided prediction towards the per-sample std of the conditional prediction.
///
/// Computes `r * noise_cfg * (std(text) / std(cfg)) + (1 - r) * noise_cfg`, where
/// `r = guidance_rescale` and the stds are per sample over all non-batch dims.
fn rescale_noise_cfg(
    noise_cfg: &Tensor,
    noise_pred_text: &Tensor,
    guidance_rescale: f32,
) -> Result<Tensor> {
    let std_ratio = std_over_dims_except0_keepdim(noise_pred_text)?
        .broadcast_div(&std_over_dims_except0_keepdim(noise_cfg)?)?;
    let rescaled_part = noise_cfg
        .broadcast_mul(&std_ratio)?
        .affine(f64::from(guidance_rescale), 0.0)?;
    let original_part = noise_cfg.affine(f64::from(1.0 - guidance_rescale), 0.0)?;
    Ok(rescaled_part.broadcast_add(&original_part)?)
}
fn cfg_star_rescale_uncond(noise_pred_uncond: &Tensor, noise_pred_text: &Tensor) -> Result<Tensor> {
let batch = noise_pred_text.dim(0)?;
let positive_flat = noise_pred_text.flatten_from(1)?;
let negative_flat = noise_pred_uncond.flatten_from(1)?;
let dot = positive_flat
.broadcast_mul(&negative_flat)?
.sum_keepdim(1)?;
let squared = negative_flat.sqr()?.sum_keepdim(1)?.affine(1.0, 1e-8)?;
let alpha = dot.broadcast_div(&squared)?;
let alpha = alpha.reshape((batch, 1, 1))?;
Ok(noise_pred_uncond.broadcast_mul(&alpha.broadcast_as(noise_pred_uncond.shape())?)?)
}
/// Builds a `(num_layers, batch_size)` f32 mask with 1.0 on every column of the rows listed in
/// `layers_to_skip` and 0.0 elsewhere.
///
/// Out-of-range layer indices are ignored. Returns `None` when no valid layer remains.
fn create_skip_layer_mask(
    num_layers: usize,
    batch_size: usize,
    layers_to_skip: &[usize],
    device: &Device,
) -> Result<Option<Tensor>> {
    let skipped = clamp_skip_blocks(layers_to_skip, num_layers);
    if skipped.is_empty() {
        return Ok(None);
    }
    let mask_data: Vec<f32> = (0..num_layers)
        .flat_map(|layer| {
            let fill = if skipped.contains(&layer) { 1.0 } else { 0.0 };
            std::iter::repeat_n(fill, batch_size)
        })
        .collect();
    let mask = Tensor::from_vec(mask_data, (num_layers, batch_size), device)?;
    Ok(Some(mask))
}
/// Softly compresses large-magnitude latents.
///
/// With `s = 0.75 * compression`, every element is multiplied by
/// `1 - 0.8 * s * sigmoid(4 * s * (|x| - 1))`. A compression of exactly 0.0 returns a clone of
/// the input.
///
/// # Errors
/// Fails when `compression` lies outside `[0, 1]` (NaN included) or on any tensor error.
fn tone_map_latents(latents: &Tensor, compression: f32) -> Result<Tensor> {
    if compression == 0.0 {
        return Ok(latents.clone());
    }
    if !(0.0..=1.0).contains(&compression) {
        bail!("tone map compression must be in [0, 1], got {compression}");
    }
    let strength = compression * 0.75;
    let offset_magnitude = latents.abs()?.affine(1.0, -1.0)?;
    let gate_input = offset_magnitude.affine(f64::from(4.0 * strength), 0.0)?;
    let gate = candle_nn::ops::sigmoid(&gate_input)?;
    let scales = gate.affine(f64::from(-0.8 * strength), 1.0)?;
    Ok(latents.broadcast_mul(&scales)?)
}
/// Reshapes a per-channel VAE latent statistic (mean or std) to `(1, C, 1, 1, 1)`, with
/// `C = latents.dim(1)`, and moves it to the device and dtype of `latents`.
fn vae_latent_stat_like(stat: impl AsRef<Tensor>, latents: &Tensor) -> Result<Tensor> {
    let channels = latents.dim(1)?;
    Ok(stat
        .as_ref()
        .reshape((1, channels, 1, 1, 1))?
        .to_device(latents.device())?
        .to_dtype(latents.dtype())?)
}

/// Standardizes latents with the VAE's per-channel statistics: `(latents - mean) / std`.
///
/// Channels are read from dim 1 of the assumed 5-D latents; the statistics are broadcast over
/// the remaining dims.
fn normalize_latents_with_vae(latents: &Tensor, vae: &AutoencoderKLLtxVideo) -> Result<Tensor> {
    let mean = vae_latent_stat_like(vae.latents_mean(), latents)?;
    let std = vae_latent_stat_like(vae.latents_std(), latents)?;
    let centered = latents.broadcast_sub(&mean)?;
    Ok(centered.broadcast_div(&std)?)
}

/// Inverse of `normalize_latents_with_vae`: `latents * std + mean`.
fn denormalize_latents_with_vae(latents: &Tensor, vae: &AutoencoderKLLtxVideo) -> Result<Tensor> {
    let mean = vae_latent_stat_like(vae.latents_mean(), latents)?;
    let std = vae_latent_stat_like(vae.latents_std(), latents)?;
    let scaled = latents.broadcast_mul(&std)?;
    Ok(scaled.broadcast_add(&mean)?)
}
/// AdaIN-style statistics transfer: shifts each `(sample, channel)` slice of `latents` to the
/// mean and std of the matching slice of `reference_latents`.
///
/// Stats are computed in f32 over all dims from 2 onward, with a 1e-6 variance floor. The
/// result is cast back to the dtype and shape of `latents`.
fn adain_filter_latents(latents: &Tensor, reference_latents: &Tensor) -> Result<Tensor> {
    /// Mean and epsilon-stabilized std along dim 2 (kept as size 1).
    fn mean_and_std(flat: &Tensor) -> Result<(Tensor, Tensor)> {
        let mean = flat.mean_keepdim(2)?;
        let std = flat.var_keepdim(2)?.affine(1.0, 1e-6)?.sqrt()?;
        Ok((mean, std))
    }

    let source_f32 = latents.to_dtype(DType::F32)?;
    let target_f32 = reference_latents.to_dtype(DType::F32)?;
    let source = source_f32.flatten_from(2)?;
    let target = target_f32.flatten_from(2)?;
    let (source_mean, source_std) = mean_and_std(&source)?;
    let (target_mean, target_std) = mean_and_std(&target)?;
    let whitened = source
        .broadcast_sub(&source_mean)?
        .broadcast_div(&source_std)?;
    let restyled = whitened
        .broadcast_mul(&target_std)?
        .broadcast_add(&target_mean)?
        .reshape(latents.shape())?;
    Ok(restyled.to_dtype(latents.dtype())?)
}
/// Loaded-model slot type for `EngineBase` in `LtxVideoEngine`.
///
/// NOTE(review): none of these fields are read in the visible code (hence
/// `#[allow(dead_code)]`); presumably a placeholder for cached weights — confirm.
#[allow(dead_code)]
struct LoadedLtxVideo {
    transformer: Option<LtxVideoTransformer3DModel>,
    vae: Option<AutoencoderKLLtxVideo>,
    device: Device,
    dtype: DType,
}
/// LTX-Video 0.9.x text-to-video engine.
///
/// NOTE(review): `#[allow(dead_code)]` presumably covers fields that are not read in every
/// build (e.g. `t5_variant` is not read in the visible code) — confirm.
#[allow(dead_code)]
pub struct LtxVideoEngine {
    /// Shared engine state (model name, resolved paths, load strategy, GPU ordinal, progress).
    base: EngineBase<LoadedLtxVideo>,
    /// Optional T5 encoder variant, stored as given to the constructor.
    t5_variant: Option<String>,
    /// Optional shared pool; used here for cached tokenizers and CPU-parked VAE tensors.
    shared_pool: Option<Arc<Mutex<SharedPool>>>,
    /// Placement of the in-flight request: set from `req.placement` in `generate` and reset to
    /// `None` when the call returns.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// `Some(is_native)` for engines built by `from_single_file`. When set, it decides whether
    /// transformer keys are remapped in `load_transformer`, overriding the file-name heuristic.
    single_file_native_format: Option<bool>,
    /// True when the VAE weights live in the transformer checkpoint under the `vae.` prefix.
    vae_in_checkpoint: bool,
}
impl LtxVideoEngine {
    /// Creates an engine over already-resolved model paths.
    pub fn new(
        model_name: String,
        paths: ModelPaths,
        t5_variant: Option<String>,
        load_strategy: LoadStrategy,
        gpu_ordinal: usize,
        shared_pool: Option<Arc<Mutex<SharedPool>>>,
    ) -> Self {
        let base = EngineBase::new(model_name, paths, load_strategy, gpu_ordinal);
        Self {
            base,
            t5_variant,
            shared_pool,
            pending_placement: None,
            single_file_native_format: None,
            vae_in_checkpoint: false,
        }
    }

    /// Creates an engine from a single LTX-Video checkpoint file.
    ///
    /// The checkpoint is parsed with `single_file::load` to detect its key format and whether
    /// it bundles VAE weights. A bundled VAE is read from the checkpoint itself; otherwise
    /// `vae_path` must point to an existing `ltx-video-vae` companion file.
    ///
    /// # Errors
    /// Fails when the checkpoint is missing or unparsable, or when a required VAE companion
    /// is not given or not on disk.
    #[allow(clippy::too_many_arguments)]
    pub fn from_single_file(
        model_name: String,
        checkpoint: PathBuf,
        vae_path: Option<PathBuf>,
        t5_encoder_path: Option<PathBuf>,
        t5_tokenizer_path: Option<PathBuf>,
        t5_variant: Option<String>,
        load_strategy: LoadStrategy,
        gpu_ordinal: usize,
        shared_pool: Option<Arc<Mutex<SharedPool>>>,
    ) -> anyhow::Result<Self> {
        if !checkpoint.exists() {
            anyhow::bail!(
                "single-file LTX-Video checkpoint not found: {}",
                checkpoint.display()
            );
        }
        let bundle = super::single_file::load(&checkpoint).map_err(|e| {
            anyhow::anyhow!(
                "failed to parse single-file LTX-Video checkpoint {}: {e}",
                checkpoint.display()
            )
        })?;
        let native_keys = bundle.format == super::single_file::LtxKeyFormat::Native;

        let (vae_file, vae_in_checkpoint) = match (bundle.has_vae, vae_path) {
            (true, _) => (checkpoint.clone(), true),
            (false, Some(companion)) if companion.exists() => (companion, false),
            (false, Some(companion)) => anyhow::bail!(
                "ltx-video-vae companion not on disk: {}. \
                 Run `mold pull ltx-video-vae` to download it.",
                companion.display()
            ),
            (false, None) => anyhow::bail!(
                "LTX-Video checkpoint {} contains no VAE weights (`vae.*` keys). \
                 Pull the `ltx-video-vae` companion first: `mold pull ltx-video-vae`",
                checkpoint.display()
            ),
        };

        let paths = ModelPaths {
            transformer: checkpoint.clone(),
            transformer_shards: Vec::new(),
            vae: vae_file,
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: t5_encoder_path,
            clip_encoder: None,
            t5_tokenizer: t5_tokenizer_path,
            clip_tokenizer: None,
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files: Vec::new(),
            text_tokenizer: None,
            decoder: None,
        };
        let base = EngineBase::new(model_name, paths, load_strategy, gpu_ordinal);
        Ok(Self {
            base,
            t5_variant,
            shared_pool,
            pending_placement: None,
            single_file_native_format: Some(native_keys),
            vae_in_checkpoint,
        })
    }
}
impl LtxVideoEngine {
    /// Validates the request geometry, resolves the preset and seed, and runs
    /// `generate_sequential`.
    ///
    /// Defaults: 25 frames at 24 fps when the request leaves them unset.
    ///
    /// # Errors
    /// - the model name has no preset (`LtxModelPreset::for_model`);
    /// - the frame count is not `8n + 1` (zero is rejected);
    /// - width or height is zero or not a multiple of `VAE_SPATIAL_COMPRESSION`;
    /// - anything `generate_sequential` reports.
    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
        let start = Instant::now();
        let preset = LtxModelPreset::for_model(&self.base.model_name)?;
        let num_frames = req.frames.unwrap_or(25);
        let fps = req.fps.unwrap_or(24);
        let steps = req.steps;
        let guidance = req.guidance;
        // The VAE compresses time by 8 but keeps the first frame, so valid counts are 8n + 1.
        let temporal = VAE_TEMPORAL_COMPRESSION as u32;
        if num_frames == 0 || !(num_frames - 1).is_multiple_of(temporal) {
            bail!(
                "frame count must be 8n+1 (9, 17, 25, 33, ...), got {}",
                num_frames
            );
        }
        let seed = req.seed.unwrap_or_else(rand_seed);
        let width = req.width;
        let height = req.height;
        // Zero passes the multiple-of check below but would produce empty latents.
        if width == 0 || height == 0 {
            bail!(
                "LTX Video requires non-zero width and height, got {}x{}",
                width,
                height
            );
        }
        let spatial = VAE_SPATIAL_COMPRESSION as u32;
        if !width.is_multiple_of(spatial) || !height.is_multiple_of(spatial) {
            bail!(
                "LTX Video requires width and height to be multiples of {}, got {}x{}",
                VAE_SPATIAL_COMPRESSION,
                width,
                height
            );
        }
        let latent_h = height as usize / VAE_SPATIAL_COMPRESSION;
        let latent_w = width as usize / VAE_SPATIAL_COMPRESSION;
        let latent_f = (num_frames as usize - 1) / VAE_TEMPORAL_COMPRESSION + 1;
        self.generate_sequential(
            req, &preset, seed, num_frames, fps, steps, guidance, width, height, latent_h,
            latent_w, latent_f, start,
        )
    }
}
/// `InferenceEngine` implementation: generation goes through `generate_inner`; the remaining
/// methods delegate to `EngineBase` or expose engine state.
impl crate::engine::InferenceEngine for LtxVideoEngine {
    /// Generates with `req.placement` stored in `pending_placement` for the duration of the
    /// call; it is reset to `None` whether generation succeeds or fails.
    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
        self.pending_placement = req.placement.clone();
        let result = self.generate_inner(req);
        self.pending_placement = None;
        result
    }
    fn model_name(&self) -> &str {
        &self.base.model_name
    }
    fn is_loaded(&self) -> bool {
        self.base.is_loaded()
    }
    /// No-op: this method loads nothing.
    /// NOTE(review): weights appear to be loaded inside the generation path — confirm.
    fn load(&mut self) -> Result<()> {
        Ok(())
    }
    fn unload(&mut self) {
        self.base.unload();
    }
    fn set_on_progress(&mut self, callback: ProgressCallback) {
        self.base.set_on_progress(callback);
    }
    fn clear_on_progress(&mut self) {
        self.base.clear_on_progress();
    }
    fn model_paths(&self) -> Option<&ModelPaths> {
        Some(&self.base.paths)
    }
    /// Exposes this engine as a chain-stage renderer (`ChainStageRenderer` is implemented for
    /// `LtxVideoEngine` in this file).
    fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> {
        Some(self)
    }
}
impl crate::ltx2::ChainStageRenderer for LtxVideoEngine {
    /// Renders one chain stage as RGB frames.
    ///
    /// The incoming carry and the progress callback are not used by this engine. The returned
    /// tail holds the last `motion_tail_pixel_frames` frames, clamped to `1..=frames.len()`.
    ///
    /// # Errors
    /// Fails when generation fails or produces no frames.
    fn render_stage(
        &mut self,
        stage_req: &GenerateRequest,
        _carry: Option<&crate::ltx2::ChainTail>,
        motion_tail_pixel_frames: u32,
        _stage_progress: Option<&mut dyn FnMut(crate::ltx2::StageProgressEvent)>,
    ) -> Result<crate::ltx2::StageOutcome> {
        let timer = Instant::now();
        let frames = self.render_chain_frames_internal(stage_req)?;
        let generation_time_ms = timer.elapsed().as_millis() as u64;
        if frames.is_empty() {
            bail!("LtxVideoEngine.render_stage: pipeline produced zero frames");
        }
        let keep = (motion_tail_pixel_frames as usize).clamp(1, frames.len());
        let tail_rgb_frames = frames[frames.len() - keep..].to_vec();
        let tail = crate::ltx2::ChainTail {
            frames: tail_rgb_frames.len() as u32,
            tail_rgb_frames,
        };
        Ok(crate::ltx2::StageOutcome {
            frames,
            tail,
            audio: None,
            generation_time_ms,
        })
    }
}
impl LtxVideoEngine {
fn render_chain_frames_internal(
&mut self,
req: &GenerateRequest,
) -> Result<Vec<image::RgbImage>> {
let mut apng_req = req.clone();
apng_req.output_format = Some(OutputFormat::Apng);
apng_req.gif_preview = false;
let response = self.generate_inner(&apng_req)?;
let video = response
.video
.ok_or_else(|| anyhow::anyhow!("LtxVideoEngine.generate returned no video data"))?;
decode_apng_to_rgb_frames(&video.data)
}
}
/// Decodes APNG bytes into one RGB image per animation frame.
///
/// Frames come from the `image` crate's APNG frame iterator. The alpha channel of each RGBA
/// frame is dropped (no compositing).
///
/// # Errors
/// Fails when the bytes are not a readable PNG, when the decoder refuses to treat it as
/// animated, or when any frame fails to decode.
fn decode_apng_to_rgb_frames(apng_bytes: &[u8]) -> Result<Vec<image::RgbImage>> {
    use image::AnimationDecoder;

    let png = image::codecs::png::PngDecoder::new(std::io::Cursor::new(apng_bytes))
        .map_err(|e| anyhow::anyhow!("failed to open APNG bytes: {e}"))?;
    let animation = png
        .apng()
        .map_err(|e| anyhow::anyhow!("decoded PNG is not animated: {e}"))?;
    animation
        .into_frames()
        .map(|frame| -> Result<image::RgbImage> {
            let rgba = frame
                .map_err(|e| anyhow::anyhow!("APNG frame decode failed: {e}"))?
                .into_buffer();
            let (width, height) = rgba.dimensions();
            let mut rgb_bytes = Vec::with_capacity((width as usize) * (height as usize) * 3);
            rgb_bytes.extend(rgba.pixels().flat_map(|px| [px.0[0], px.0[1], px.0[2]]));
            image::RgbImage::from_raw(width, height, rgb_bytes)
                .ok_or_else(|| anyhow::anyhow!("failed to construct RgbImage from APNG frame"))
        })
        .collect()
}
impl LtxVideoEngine {
fn load_transformer(
&self,
preset: &LtxModelPreset,
device: &Device,
dtype: DType,
) -> Result<LtxVideoTransformer3DModel> {
let transformer_files: Vec<std::path::PathBuf> =
if !self.base.paths.transformer_shards.is_empty() {
self.base.paths.transformer_shards.clone()
} else {
vec![self.base.paths.transformer.clone()]
};
let is_gguf = transformer_files
.first()
.and_then(|p| p.extension())
.is_some_and(|e| e == "gguf");
if is_gguf {
bail!("GGUF quantized LTX Video transformer is not yet supported — use :bf16 variant");
}
let transformer_bytes = transformer_files.iter().try_fold(0u64, |acc, path| {
let metadata = std::fs::metadata(path).with_context(|| {
format!("failed to stat LTX Video transformer {}", path.display())
})?;
Ok::<_, anyhow::Error>(acc.saturating_add(metadata.len()))
})?;
ltx_video_transformer_residency_guard(
&self.base.model_name,
preset,
transformer_bytes,
usable_free_vram_bytes(self.base.gpu_ordinal),
device.is_cuda(),
)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&transformer_files, dtype, device)? };
let use_remap = match self.single_file_native_format {
Some(is_native) => is_native,
None => {
transformer_files.len() == 1
&& is_official_ltx_transformer_checkpoint(&transformer_files[0])
}
};
let vb = if use_remap {
vb.rename_f(remap_official_ltx_transformer_key)
} else {
vb
};
Ok(LtxVideoTransformer3DModel::new(
&preset.transformer_config,
vb,
)?)
}
fn load_vae(
    &self,
    preset: &LtxModelPreset,
    device: &Device,
    dtype: DType,
) -> Result<AutoencoderKLLtxVideo> {
    // Builds the VAE described by the preset from the var builder chosen by
    // `load_vae_var_builder` (parked CPU tensors or a memory-mapped file).
    let weights = self.load_vae_var_builder(dtype, device)?;
    let vae = AutoencoderKLLtxVideo::new(preset.vae_config.clone(), weights)?;
    Ok(vae)
}
fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
    // Returns the VAE tensors parked on the CPU by the shared pool. Yields `None` when there
    // is no shared pool or when the VAE lives inside the transformer checkpoint.
    match (&self.shared_pool, self.vae_in_checkpoint) {
        (Some(pool), false) => pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae)),
        _ => Ok(None),
    }
}
fn load_vae_var_builder<'a>(&self, dtype: DType, device: &Device) -> Result<VarBuilder<'a>> {
    // Prefers tensors parked by the shared pool. Otherwise memory-maps `paths.vae` and, when
    // the VAE is bundled in the checkpoint, scopes the builder to the `vae` prefix.
    match self.load_vae_cpu_tensors()? {
        Some(parked) => Ok(crate::encoders::park::varbuilder_from_parked(
            parked.as_ref(),
            dtype,
            device,
        )),
        None => {
            let vae_files = std::slice::from_ref(&self.base.paths.vae);
            // SAFETY: the file is memory-mapped; this assumes it is not modified or truncated
            // while the mapping is alive.
            let vb = unsafe { VarBuilder::from_mmaped_safetensors(vae_files, dtype, device)? };
            Ok(if self.vae_in_checkpoint { vb.pp("vae") } else { vb })
        }
    }
}
#[allow(clippy::too_many_arguments)]
fn denoise_pass(
&self,
transformer: &mut LtxVideoTransformer3DModel,
prompt_embeds: &Tensor,
attention_mask: &Tensor,
uncond_embeds: Option<&Tensor>,
uncond_mask: Option<&Tensor>,
pass: &LtxPassConfig,
scheduler_cfg: &FlowMatchEulerDiscreteSchedulerConfig,
seed: u64,
width: u32,
height: u32,
num_frames: u32,
fps: u32,
device: &Device,
dtype: DType,
progress: &crate::progress::ProgressReporter,
stage_name: &str,
ltx_debug: bool,
initial_latents: Option<Tensor>,
) -> Result<Tensor> {
let latent_h = height as usize / VAE_SPATIAL_COMPRESSION;
let latent_w = width as usize / VAE_SPATIAL_COMPRESSION;
let latent_f = (num_frames as usize - 1) / VAE_TEMPORAL_COMPRESSION + 1;
let mut scheduler = FlowMatchEulerDiscreteScheduler::new(scheduler_cfg.clone())?;
scheduler.set_timesteps(
if pass.custom_sigmas.is_some() {
None
} else {
Some(pass.num_inference_steps as usize)
},
device,
pass.custom_sigmas.as_deref(),
None,
None,
)?;
let schedule_sigmas = scheduler
.sigmas()
.to_device(&Device::Cpu)?
.to_vec1::<f32>()?;
let total_steps = schedule_sigmas.len() - 1;
if pass.skip_initial_inference_steps + pass.skip_final_inference_steps >= total_steps {
bail!(
"invalid LTX pass schedule: skip_initial={} + skip_final={} >= total_steps={}",
pass.skip_initial_inference_steps,
pass.skip_final_inference_steps,
total_steps
);
}
let start_step = pass.skip_initial_inference_steps;
let end_step = total_steps - pass.skip_final_inference_steps;
let run_sigmas = schedule_sigmas[start_step..end_step].to_vec();
scheduler.set_begin_index(start_step);
let step_schedule =
resolve_step_schedule(pass, &run_sigmas, transformer.config().num_layers);
let video_coords = build_video_coords(1, latent_f, latent_h, latent_w, fps, device)?;
let mut latents = match initial_latents {
Some(latents) => pack_initial_latents_for_second_pass(&latents)?,
None => {
let noise = seeded_randn(
seed,
&[1, LATENT_CHANNELS, latent_f, latent_h, latent_w],
device,
DType::F32,
)?;
pack_latents(&noise, PATCH_SIZE, PATCH_SIZE_T)?
}
};
progress.stage_start(stage_name);
let denoise_start = Instant::now();
for (step, sigma) in run_sigmas.iter().copied().enumerate() {
let step_start = Instant::now();
let resolved = &step_schedule[step];
let batch = latents.dim(0)?;
let timestep_t = Tensor::full(sigma, (batch,), device)?.to_dtype(dtype)?;
let latents_input = latents.to_dtype(dtype)?;
let do_cfg = resolved.guidance_scale > 1.0 && uncond_embeds.is_some();
let do_stg = resolved.stg_scale > 0.0 && !resolved.skip_blocks.is_empty();
if do_stg {
transformer.set_skip_block_list(vec![]);
} else {
transformer.set_skip_block_list(resolved.skip_blocks.clone());
}
let cond_pred = transformer.forward(
&latents_input,
prompt_embeds,
×tep_t,
Some(attention_mask),
latent_f,
latent_h,
latent_w,
None,
Some(&video_coords),
None,
)?;
let cond_f32 = cond_pred.to_dtype(DType::F32)?;
let mut combined = cond_f32.clone();
if do_cfg {
let uncond_pred = transformer.forward(
&latents_input,
uncond_embeds.expect("checked above"),
×tep_t,
uncond_mask,
latent_f,
latent_h,
latent_w,
None,
Some(&video_coords),
None,
)?;
let mut uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
if pass.guidance.cfg_star_rescale {
uncond_f32 = cfg_star_rescale_uncond(&uncond_f32, &cond_f32)?;
}
let diff = cond_f32.broadcast_sub(&uncond_f32)?;
combined =
uncond_f32.broadcast_add(&diff.affine(resolved.guidance_scale as f64, 0.0)?)?;
}
if do_stg {
let skip_layer_mask = create_skip_layer_mask(
transformer.config().num_layers,
batch,
&resolved.skip_blocks,
device,
)?;
let perturbed = transformer.forward(
&latents_input,
prompt_embeds,
×tep_t,
Some(attention_mask),
latent_f,
latent_h,
latent_w,
None,
Some(&video_coords),
skip_layer_mask.as_ref(),
)?;
let perturbed_f32 = perturbed.to_dtype(DType::F32)?;
let diff_stg = cond_f32.broadcast_sub(&perturbed_f32)?;
combined =
combined.broadcast_add(&diff_stg.affine(resolved.stg_scale as f64, 0.0)?)?;
if pass.guidance.cfg_star_rescale && resolved.rescaling_scale > 0.0 {
combined = rescale_noise_cfg(&combined, &cond_f32, resolved.rescaling_scale)?;
}
}
let model_output =
if transformer.config().out_channels / 2 == transformer.config().in_channels {
combined
.chunk(2, 2)?
.into_iter()
.next()
.expect("out_channels / 2 == in_channels implies chunk(2) succeeds")
} else {
combined
};
if ltx_debug {
let out_rms = model_output.sqr()?.mean_all()?.to_scalar::<f32>()?.sqrt();
let lat_rms = latents.sqr()?.mean_all()?.to_scalar::<f32>()?.sqrt();
if step < 3 || step == run_sigmas.len() - 1 {
progress.info(&format!(
"Pass {stage_name} step {}: sigma={:.4}, guidance={:.2}, stg={:.2}, lat_rms={:.4}, out_rms={:.4}",
step,
sigma,
resolved.guidance_scale,
resolved.stg_scale,
lat_rms,
out_rms
));
}
}
latents = scheduler
.step(&model_output, sigma, &latents, None)?
.prev_sample;
progress.emit(ProgressEvent::DenoiseStep {
step: step + 1,
total: run_sigmas.len(),
elapsed: step_start.elapsed(),
});
}
progress.stage_done(stage_name, denoise_start.elapsed());
unpack_latents(
&latents,
latent_f,
latent_h,
latent_w,
PATCH_SIZE,
PATCH_SIZE_T,
)
}
/// Generates a video by running every pipeline stage one after another on a single device.
///
/// Each stage's model is loaded, used, and dropped (followed by `device.synchronize()`)
/// before the next stage's model is loaded. The only overlap is the multiscale hand-off,
/// where a VAE and the latent upsampler are alive together.
///
/// Stages:
/// 1. Load the T5 encoder and encode `req.prompt`. An empty prompt is also encoded when
///    any guidance scale in use is greater than 1.0 (classifier-free guidance).
/// 2. Denoise. [`LtxPipelineMode::Base`] runs a single pass. [`LtxPipelineMode::Multiscale`]
///    runs a first pass at `downscale_factor` resolution (rounded down to a multiple of
///    the VAE spatial compression), then upsamples the latents in VAE space, applies
///    `adain_filter_latents`, and runs a second pass at twice the first-pass size.
/// 3. Tone-map, denormalize, and VAE-decode the latents. When the VAE is
///    timestep-conditioned and `decode_noise_scale > 0`, seeded noise is blended in first.
/// 4. Convert frames to RGB, resizing to `width`×`height` if the decoded size differs, and
///    encode them in the requested video format. If the requested output format is not a
///    video format, APNG is used instead.
///
/// `steps` and `guidance` override the preset only in base mode. Multiscale presets keep
/// their own schedules and log that the CLI values are ignored. `_latent_h`, `_latent_w`
/// and `_latent_f` are not used on this path.
///
/// # Errors
///
/// Returns an error in these cases:
/// - a multiscale preset has no spatial upscaler asset;
/// - the T5 encoder or tokenizer path is missing;
/// - any model fails to load or run;
/// - the output format is unsupported in this build.
#[allow(clippy::too_many_arguments)]
fn generate_sequential(
    &mut self,
    req: &GenerateRequest,
    preset: &LtxModelPreset,
    seed: u64,
    num_frames: u32,
    fps: u32,
    steps: u32,
    guidance: f64,
    width: u32,
    height: u32,
    _latent_h: usize,
    _latent_w: usize,
    _latent_f: usize,
    start: Instant,
) -> Result<GenerateResponse> {
    let progress = &self.base.progress;
    let paths = &self.base.paths;
    let ltx_debug = std::env::var("MOLD_LTX_DEBUG").is_ok_and(|v| v == "1");
    if preset.mode == LtxPipelineMode::Multiscale && paths.spatial_upscaler.is_none() {
        bail!("LTX 0.9.8 requires a spatial upscaler asset in the pulled model files");
    }
    let device = crate::device::create_device(self.base.gpu_ordinal, progress)?;
    let dtype = gpu_dtype(&device);
    progress.info(&format!(
        "LTX Video: {}×{} × {} frames, {} steps, seed {}",
        width, height, num_frames, steps, seed
    ));
    if preset.mode == LtxPipelineMode::Multiscale {
        progress.info("Using the full 0.9.8 multiscale refinement path.");
        if steps != preset.base_pass.num_inference_steps {
            progress.info(&format!(
                "Ignoring --steps={} for multiscale LTX preset {}; using the preset schedule instead.",
                steps, self.base.model_name
            ));
        }
        if guidance != preset.base_pass.guidance.guidance_scale[0] as f64 {
            progress.info(&format!(
                "Ignoring --guidance={guidance:.2} for multiscale LTX preset {}; using the preset guidance schedule instead.",
                self.base.model_name
            ));
        }
    }

    // ---- Stage 1: text encoding -------------------------------------------------------
    progress.stage_start("Loading T5-XXL encoder");
    let t5_start = Instant::now();
    let t5_encoder_path = paths
        .t5_encoder
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("T5 encoder path not configured"))?;
    let t5_tokenizer_path = paths
        .t5_tokenizer
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("T5 tokenizer path not configured"))?;
    // Text-encoder placement comes from any pending placement request (default tier
    // otherwise). The closure presumably supplies the main device as the fallback —
    // confirm against `resolve_device`.
    let tier1 = self
        .pending_placement
        .as_ref()
        .map(|p| p.text_encoders)
        .unwrap_or_default();
    let t5_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
    // Reuse a tokenizer cached in the shared pool when a pool is configured.
    let cached_t5_tokenizer = self
        .shared_pool
        .as_ref()
        .map(|pool| pool.lock().unwrap().load_tokenizer(t5_tokenizer_path))
        .transpose()?;
    let mut t5 = crate::encoders::t5::T5Encoder::load_with_tokenizer(
        t5_encoder_path,
        t5_tokenizer_path,
        &t5_device,
        dtype,
        progress,
        cached_t5_tokenizer,
    )?;
    progress.stage_done("Loading T5-XXL encoder", t5_start.elapsed());
    progress.stage_start("Encoding prompt");
    let encode_start = Instant::now();
    let prompt_embeds = t5.encode(&req.prompt, &t5_device, dtype)?;
    let prompt_embeds = prompt_embeds.to_device(&device)?;
    progress.stage_done("Encoding prompt", encode_start.elapsed());
    let prompt_seq_len = prompt_embeds.dim(1)?;
    let attention_mask =
        Tensor::ones((1, prompt_seq_len), DType::F32, &device)?.to_dtype(dtype)?;
    // An unconditional (empty-prompt) embedding is only needed when some pass actually
    // uses a guidance scale above 1.0.
    let needs_uncond = match preset.mode {
        LtxPipelineMode::Base => guidance > 1.0,
        LtxPipelineMode::Multiscale => preset
            .multiscale
            .as_ref()
            .into_iter()
            .flat_map(|cfg| [&cfg.first_pass, &cfg.second_pass])
            .any(|pass| {
                pass.guidance
                    .guidance_scale
                    .iter()
                    .any(|scale| *scale > 1.0)
            }),
    };
    let (uncond_embeds, uncond_mask) = if needs_uncond {
        progress.stage_start("Encoding negative prompt (CFG)");
        // Time this stage on its own clock. Reusing `encode_start` would fold the
        // positive-prompt encoding time into the reported duration.
        let negative_start = Instant::now();
        let ue = t5.encode("", &t5_device, dtype)?;
        let ue = ue.to_device(&device)?;
        let ue_seq = ue.dim(1)?;
        let um = Tensor::ones((1, ue_seq), DType::F32, &device)?.to_dtype(dtype)?;
        progress.stage_done("Encoding negative prompt (CFG)", negative_start.elapsed());
        (Some(ue), Some(um))
    } else {
        (None, None)
    };
    // Release the text encoder before loading the transformer.
    drop(t5);
    device.synchronize()?;
    progress.info("T5 encoder dropped, VRAM freed");

    // ---- Stage 2: denoising -----------------------------------------------------------
    let (mut latents, decode_width, decode_height, tone_map_compression_ratio) = match preset
        .mode
    {
        LtxPipelineMode::Base => {
            progress.stage_start("Loading LTX Video transformer");
            let transformer_start = Instant::now();
            let mut transformer = self.load_transformer(preset, &device, dtype)?;
            progress.stage_done("Loading LTX Video transformer", transformer_start.elapsed());
            // Base mode honours the caller's step count and a single constant guidance scale.
            let mut pass = preset.base_pass.clone();
            pass.num_inference_steps = steps;
            pass.guidance.guidance_scale = vec![guidance as f32];
            let latents = self.denoise_pass(
                &mut transformer,
                &prompt_embeds,
                &attention_mask,
                uncond_embeds.as_ref(),
                uncond_mask.as_ref().map(|m| m as &Tensor),
                &pass,
                &preset.scheduler_config,
                seed,
                width,
                height,
                num_frames,
                fps,
                &device,
                dtype,
                progress,
                "Denoising",
                ltx_debug,
                None,
            )?;
            drop(transformer);
            device.synchronize()?;
            (latents, width, height, pass.tone_map_compression_ratio)
        }
        LtxPipelineMode::Multiscale => {
            let multiscale = preset.multiscale.as_ref().expect("multiscale preset");
            // First-pass resolution: scaled by `downscale_factor`, then rounded down to a
            // multiple of the VAE spatial compression.
            let first_width = ((width as f32 * multiscale.downscale_factor) as u32)
                / VAE_SPATIAL_COMPRESSION as u32
                * VAE_SPATIAL_COMPRESSION as u32;
            let first_height = ((height as f32 * multiscale.downscale_factor) as u32)
                / VAE_SPATIAL_COMPRESSION as u32
                * VAE_SPATIAL_COMPRESSION as u32;
            progress.stage_start("Loading LTX Video transformer");
            let transformer_start = Instant::now();
            let mut first_transformer = self.load_transformer(preset, &device, dtype)?;
            progress.stage_done("Loading LTX Video transformer", transformer_start.elapsed());
            let first_pass_latents = self.denoise_pass(
                &mut first_transformer,
                &prompt_embeds,
                &attention_mask,
                uncond_embeds.as_ref(),
                uncond_mask.as_ref().map(|m| m as &Tensor),
                &multiscale.first_pass,
                &preset.scheduler_config,
                seed,
                first_width,
                first_height,
                num_frames,
                fps,
                &device,
                dtype,
                progress,
                "Denoising First Pass",
                ltx_debug,
                None,
            )?;
            drop(first_transformer);
            device.synchronize()?;
            // The VAE is only needed here for its latent (de)normalization statistics.
            progress.stage_start("Loading spatial upscaler");
            let spatial_start = Instant::now();
            let vae = self.load_vae(preset, &device, dtype)?;
            let upsampler = LatentUpsampler::load(
                paths.spatial_upscaler.as_ref().expect("checked above"),
                dtype,
                &device,
            )?;
            progress.stage_done("Loading spatial upscaler", spatial_start.elapsed());
            progress.stage_start("Refining multiscale pass");
            let refine_start = Instant::now();
            // Upsample in denormalized VAE space, renormalize, then match the first-pass
            // latent statistics via `adain_filter_latents`.
            let first_pass_denorm = cast_latents_for_multiscale_upsampler(
                &denormalize_latents_with_vae(&first_pass_latents, &vae)?,
                dtype,
            )?;
            let upsampled_latents =
                normalize_latents_with_vae(&upsampler.forward(&first_pass_denorm)?, &vae)?;
            let upsampled_latents =
                adain_filter_latents(&upsampled_latents, &first_pass_latents)?;
            progress.stage_done("Refining multiscale pass", refine_start.elapsed());
            drop(upsampler);
            drop(vae);
            device.synchronize()?;
            let second_width = first_width * 2;
            let second_height = first_height * 2;
            progress.stage_start("Loading LTX Video transformer");
            let transformer_start = Instant::now();
            let mut second_transformer = self.load_transformer(preset, &device, dtype)?;
            progress.stage_done("Loading LTX Video transformer", transformer_start.elapsed());
            let latents = self.denoise_pass(
                &mut second_transformer,
                &prompt_embeds,
                &attention_mask,
                uncond_embeds.as_ref(),
                uncond_mask.as_ref().map(|m| m as &Tensor),
                &multiscale.second_pass,
                &preset.scheduler_config,
                seed,
                second_width,
                second_height,
                num_frames,
                fps,
                &device,
                dtype,
                progress,
                "Denoising Second Pass",
                ltx_debug,
                Some(upsampled_latents),
            )?;
            drop(second_transformer);
            device.synchronize()?;
            (
                latents,
                second_width,
                second_height,
                multiscale.second_pass.tone_map_compression_ratio,
            )
        }
    };

    // ---- Stage 3: VAE decode ----------------------------------------------------------
    progress.stage_start("Loading VAE decoder");
    let vae_start = Instant::now();
    let vae = self.load_vae(preset, &device, dtype)?;
    progress.stage_done("Loading VAE decoder", vae_start.elapsed());
    progress.stage_start("Decoding video frames");
    let decode_start = Instant::now();
    let decode_timestep = if vae.config().timestep_conditioning {
        if preset.decode_noise_scale > 0.0 {
            // Blend: latents * (1 - s) + noise * s, with a seed derived from the request seed.
            let noise =
                seeded_randn(seed ^ 0xdec0de, latents.shape().dims(), &device, DType::F32)?;
            latents = (&latents * (1.0 - preset.decode_noise_scale as f64))?
                .broadcast_add(&(noise * preset.decode_noise_scale as f64)?)?;
        }
        Some(Tensor::full(preset.decode_timestep, (1,), &device)?.to_dtype(dtype)?)
    } else {
        None
    };
    latents = tone_map_latents(&latents, tone_map_compression_ratio)?;
    latents = denormalize_latents_with_vae(&latents, &vae)?;
    if ltx_debug {
        let l_f32 = latents.to_dtype(DType::F32)?;
        progress.info(&format!(
            "Latents pre-VAE (un-normalized): mean={:.4}, std={:.4}",
            l_f32.mean_all()?.to_scalar::<f32>()?,
            l_f32.flatten_all()?.var(0)?.to_scalar::<f32>()?.sqrt()
        ));
    }
    latents = latents.to_dtype(dtype)?;
    let (_dec_output, video) = vae.decode(&latents, decode_timestep.as_ref(), false, false)?;
    if ltx_debug {
        let v_f32 = video.to_dtype(DType::F32)?;
        progress.info(&format!(
            "VAE output: shape={:?}, mean={:.4}, min={:.4}, max={:.4}",
            v_f32.shape(),
            v_f32.mean_all()?.to_scalar::<f32>()?,
            v_f32.flatten_all()?.min(0)?.to_scalar::<f32>()?,
            v_f32.flatten_all()?.max(0)?.to_scalar::<f32>()?
        ));
    }
    progress.stage_done("Decoding video frames", decode_start.elapsed());
    drop(vae);
    device.synchronize()?;

    // ---- Stage 4: frame conversion and container encoding -----------------------------
    let output_format = if req.resolved_output_format().is_video() {
        req.resolved_output_format()
    } else {
        OutputFormat::Apng
    };
    let format_name = output_format.extension().to_uppercase();
    progress.stage_start(&format!("Encoding {format_name}"));
    let encode_start = Instant::now();
    // Clamp to [-1, 1] and map linearly to [0, 255] bytes.
    let video = video.to_dtype(DType::F32)?;
    let video = ((video.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
    // Take batch element 0; dimension 1 of the remainder indexes frames.
    let video = video.i(0)?;
    let num_output_frames = video.dim(1)?;
    let mut frames = Vec::with_capacity(num_output_frames);
    for f in 0..num_output_frames {
        let frame = video.i((.., f, .., ..))?.contiguous()?;
        // Move the leading (presumably channel) axis last to get interleaved pixel bytes
        // for `RgbImage::from_raw`.
        let frame = frame.permute((1, 2, 0))?;
        let frame_data: Vec<u8> = frame.flatten_all()?.to_vec1()?;
        let mut rgb = image::RgbImage::from_raw(decode_width, decode_height, frame_data)
            .ok_or_else(|| anyhow::anyhow!("failed to create frame image"))?;
        if decode_width != width || decode_height != height {
            rgb = image::imageops::resize(
                &rgb,
                width,
                height,
                image::imageops::FilterType::Triangle,
            );
        }
        frames.push(rgb);
    }
    let video_bytes = match output_format {
        OutputFormat::Apng => {
            let metadata = video_enc::VideoMetadata {
                prompt: req.prompt.clone(),
                model: self.base.model_name.clone(),
                seed,
                steps,
                guidance: req.guidance,
                width,
                height,
                frames: num_output_frames as u32,
                fps,
            };
            video_enc::encode_apng(&frames, fps, Some(&metadata))?
        }
        OutputFormat::Gif => video_enc::encode_gif(&frames, fps)?,
        #[cfg(feature = "webp")]
        OutputFormat::Webp => video_enc::encode_webp(&frames, fps)?,
        #[cfg(feature = "mp4")]
        OutputFormat::Mp4 => video_enc::encode_mp4(&frames, fps)?,
        #[cfg(not(feature = "webp"))]
        OutputFormat::Webp => {
            bail!("WebP output requires the 'webp' feature — rebuild with --features webp")
        }
        #[cfg(not(feature = "mp4"))]
        OutputFormat::Mp4 => {
            bail!("MP4 output requires the 'mp4' feature — rebuild with --features mp4")
        }
        _ => bail!("{format_name} is not a supported video output format"),
    };
    let thumbnail_bytes = video_enc::first_frame_png(&frames)?;
    // Reuse the main payload as the GIF preview when it already is a GIF.
    let gif_preview = if req.gif_preview {
        if output_format == OutputFormat::Gif {
            video_bytes.clone()
        } else {
            video_enc::encode_gif(&frames, fps)?
        }
    } else {
        Vec::new()
    };
    progress.stage_done(&format!("Encoding {format_name}"), encode_start.elapsed());
    let generation_time_ms = start.elapsed().as_millis() as u64;
    progress.info(&format!(
        "Done: {} frames, {:.1}s total",
        num_output_frames,
        generation_time_ms as f64 / 1000.0
    ));
    Ok(GenerateResponse {
        images: vec![],
        video: Some(VideoData {
            data: video_bytes,
            format: output_format,
            width,
            height,
            frames: num_output_frames as u32,
            fps,
            thumbnail: thumbnail_bytes,
            gif_preview,
            has_audio: false,
            duration_ms: None,
            audio_sample_rate: None,
            audio_channels: None,
        }),
        generation_time_ms,
        model: self.base.model_name.clone(),
        seed_used: seed,
        gpu: None,
    })
}
}
/// Flattens a rank-5 latent tensor `(b, c, f, h, w)` into a token sequence.
///
/// The output has shape `(b, (f/pt)·(h/p)·(w/p), c·pt·p·p)`, where `pt = patch_size_t` and
/// `p = patch_size`. Each token holds one `pt × p × p` patch across all channels. Tokens are
/// ordered by patch-frame, then patch-row, then patch-column.
///
/// # Errors
///
/// Fails if the input is not rank 5, or if `f` is not divisible by `patch_size_t` or `h`/`w`
/// is not divisible by `patch_size`.
fn pack_latents(latents: &Tensor, patch_size: usize, patch_size_t: usize) -> Result<Tensor> {
    let (batch, channels, frames, rows, cols) = latents.dims5()?;
    let evenly_divisible =
        frames % patch_size_t == 0 && rows % patch_size == 0 && cols % patch_size == 0;
    if !evenly_divisible {
        bail!("latent dims not divisible by patch sizes");
    }
    let (grid_f, grid_h, grid_w) = (frames / patch_size_t, rows / patch_size, cols / patch_size);
    // Split every spatial/temporal axis into (grid, patch), then bring the grid axes forward
    // so the trailing (channel, patch) axes can be folded into one feature dimension.
    let tokens = latents
        .reshape(&[batch, channels, grid_f, patch_size_t, grid_h, patch_size, grid_w, patch_size])?
        .permute([0, 2, 4, 6, 1, 3, 5, 7])?
        .flatten_from(4)?;
    let token_dim = tokens.dim(4)?;
    Ok(tokens.reshape((batch, grid_f * grid_h * grid_w, token_dim))?)
}
/// Inverse of `pack_latents`: turns a `(b, seq, d)` token sequence back into a
/// `(b, c, num_frames·pt, height·p, width·p)` latent tensor.
///
/// `num_frames`, `height` and `width` are the patch-grid sizes (they equal the latent sizes
/// when both patch sizes are 1). `c` is recovered as `d / (pt·p·p)`.
///
/// # Errors
///
/// Fails if the input is not rank 3, if `d` is not divisible by the patch volume, or if
/// the element count does not match the requested grid.
fn unpack_latents(
    latents: &Tensor,
    num_frames: usize,
    height: usize,
    width: usize,
    patch_size: usize,
    patch_size_t: usize,
) -> Result<Tensor> {
    let (batch, _seq_len, token_dim) = latents.dims3()?;
    let patch_volume = patch_size_t * patch_size * patch_size;
    if token_dim % patch_volume != 0 {
        bail!("D={} not divisible by patch product {}", token_dim, patch_volume);
    }
    let channels = token_dim / patch_volume;
    let split_shape = [
        batch,
        num_frames,
        height,
        width,
        channels,
        patch_size_t,
        patch_size,
        patch_size,
    ];
    // Interleave each grid axis with its patch axis so neighbouring axes can be merged.
    let interleaved = latents
        .reshape(&split_shape)?
        .permute([0, 4, 1, 5, 2, 6, 3, 7])?
        .contiguous()?;
    Ok(interleaved.reshape((
        batch,
        channels,
        num_frames * patch_size_t,
        height * patch_size,
        width * patch_size,
    ))?)
}
/// Converts latents to `F32` and packs them using the module's `PATCH_SIZE` / `PATCH_SIZE_T`.
fn pack_initial_latents_for_second_pass(latents: &Tensor) -> Result<Tensor> {
    let latents_f32 = latents.to_dtype(DType::F32)?;
    pack_latents(&latents_f32, PATCH_SIZE, PATCH_SIZE_T)
}
fn cast_latents_for_multiscale_upsampler(latents: &Tensor, dtype: DType) -> Result<Tensor> {
Ok(latents.to_dtype(dtype)?)
}
/// Builds per-token pixel-space coordinates for a latent grid of `latent_f × latent_h × latent_w`.
///
/// Returns shape `(1, seq, 3)`, or `(batch_size, seq, 3)` broadcast when `batch_size > 1`,
/// with `seq = latent_f · latent_h · latent_w` and tokens in frame/row/column order.
/// The three coordinate columns are:
/// - `0`: `clamp(f·T + 1 − T, 0, 10000) / fps`, with `T = VAE_TEMPORAL_COMPRESSION`;
/// - `1`: `h · VAE_SPATIAL_COMPRESSION`;
/// - `2`: `w · VAE_SPATIAL_COMPRESSION`.
fn build_video_coords(
    batch_size: usize,
    latent_f: usize,
    latent_h: usize,
    latent_w: usize,
    fps: u32,
    device: &Device,
) -> Result<Tensor> {
    let full_shape = (latent_f, latent_h, latent_w);
    // An integer ramp of length `len`, laid out along one axis and broadcast over the grid.
    let axis_ramp = |len: usize, layout: (usize, usize, usize)| -> Result<Tensor> {
        Ok(Tensor::arange(0u32, len as u32, device)?
            .to_dtype(DType::F32)?
            .reshape(layout)?
            .broadcast_as(full_shape)?)
    };
    let frame_idx = axis_ramp(latent_f, (latent_f, 1, 1))?;
    let row_idx = axis_ramp(latent_h, (1, latent_h, 1))?;
    let col_idx = axis_ramp(latent_w, (1, 1, latent_w))?;
    let seq_len = latent_f * latent_h * latent_w;
    // (3, f, h, w) -> (3, seq) -> (seq, 3) -> (1, seq, 3)
    let grid = Tensor::stack(&[frame_idx, row_idx, col_idx], 0)?
        .flatten_from(1)?
        .transpose(0, 1)?
        .unsqueeze(0)?;

    let temporal_scale = VAE_TEMPORAL_COMPRESSION as f64;
    let spatial_scale = VAE_SPATIAL_COMPRESSION as f64;
    let time_coords = grid
        .i((.., .., 0))?
        .affine(temporal_scale, 1.0 - temporal_scale)?
        .clamp(0.0f32, 10000.0f32)?
        .affine(1.0 / fps as f64, 0.0)?;
    let row_coords = grid.i((.., .., 1))?.affine(spatial_scale, 0.0)?;
    let col_coords = grid.i((.., .., 2))?.affine(spatial_scale, 0.0)?;
    let coords = Tensor::stack(&[time_coords, row_coords, col_coords], candle_core::D::Minus1)?;

    if batch_size <= 1 {
        return Ok(coords);
    }
    Ok(coords.broadcast_as((batch_size, seq_len, 3))?)
}
#[cfg(test)]
mod tests {
    use super::{
        cast_latents_for_multiscale_upsampler, is_official_ltx_transformer_checkpoint,
        pack_initial_latents_for_second_pass, remap_official_ltx_transformer_key, unpack_latents,
        LtxModelPreset, LtxPipelineMode, LtxVideoEngine, LATENT_CHANNELS,
        LTX_098_DISTILLED_SECOND_PASS_SIGMAS, PATCH_SIZE, PATCH_SIZE_T,
    };
    use crate::engine::LoadStrategy;
    use crate::shared_pool::SharedPool;
    use candle_core::{DType, Device, Tensor};
    use mold_core::ModelPaths;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Creates a new directory under the system temp dir, named from `prefix`, the process
    /// id, and the current time in nanoseconds.
    fn temp_test_dir(prefix: &str) -> PathBuf {
        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Builds `ModelPaths` that set only the transformer, VAE, T5 encoder and T5 tokenizer
    /// paths (rooted in `dir`, except `vae`). Everything else is empty/`None`.
    fn ltx_video_model_paths(dir: &Path, vae: PathBuf) -> ModelPaths {
        ModelPaths {
            transformer: dir.join("transformer.safetensors"),
            transformer_shards: vec![],
            vae,
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: Some(dir.join("t5.safetensors")),
            clip_encoder: None,
            t5_tokenizer: Some(dir.join("tokenizer.json")),
            clip_tokenizer: None,
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files: vec![],
            text_tokenizer: None,
            decoder: None,
        }
    }

    /// APNG encode -> decode must reproduce every pixel of every frame.
    ///
    /// A non-uniform gradient frame is included so that axis swaps, channel swaps, or
    /// partially decoded frames are caught. Checking only pixel (0,0) of solid frames
    /// would miss them.
    #[test]
    fn decode_apng_round_trips_rgb_frames() {
        use crate::ltx_video::video_enc::encode_apng;
        use image::Rgb;
        let make = |r: u8, g: u8, b: u8| {
            let mut img = image::RgbImage::new(4, 4);
            for px in img.pixels_mut() {
                *px = Rgb([r, g, b]);
            }
            img
        };
        let gradient =
            image::RgbImage::from_fn(4, 4, |x, y| Rgb([(x * 60) as u8, (y * 60) as u8, 128]));
        let inputs = vec![make(255, 0, 0), make(0, 255, 0), make(0, 0, 255), gradient];
        let bytes = encode_apng(&inputs, 12, None).expect("encode");
        let decoded = super::decode_apng_to_rgb_frames(&bytes).expect("decode");
        assert_eq!(decoded.len(), inputs.len());
        for (i, (expected, actual)) in inputs.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(expected.dimensions(), actual.dimensions(), "frame {i} size");
            for (x, y, px) in expected.enumerate_pixels() {
                assert_eq!(
                    px,
                    actual.get_pixel(x, y),
                    "frame {i} pixel ({x}, {y}) mismatch",
                );
            }
        }
    }

    /// Official single-file `ltxv-*` checkpoints are recognised. Diffusers shards and GGUF
    /// files are not.
    #[test]
    fn detects_official_ltx_single_file_checkpoints() {
        assert!(is_official_ltx_transformer_checkpoint(Path::new(
            "ltxv-2b-0.9.6-distilled-04-25.safetensors"
        )));
        assert!(is_official_ltx_transformer_checkpoint(Path::new(
            "ltxv-13b-0.9.8-dev.safetensors"
        )));
        assert!(!is_official_ltx_transformer_checkpoint(Path::new(
            "diffusion_pytorch_model-00001-of-00002.safetensors"
        )));
        assert!(!is_official_ltx_transformer_checkpoint(Path::new(
            "transformer.gguf"
        )));
    }

    /// Diffusers-style transformer keys map onto the `model.diffusion_model.*` names used by
    /// the official checkpoints.
    #[test]
    fn remaps_official_transformer_keys_to_upstream_checkpoint_names() {
        assert_eq!(
            remap_official_ltx_transformer_key("proj_in.weight"),
            "model.diffusion_model.patchify_proj.weight"
        );
        assert_eq!(
            remap_official_ltx_transformer_key("time_embed.emb.timestep_embedder.linear_1.weight"),
            "model.diffusion_model.adaln_single.emb.timestep_embedder.linear_1.weight"
        );
        assert_eq!(
            remap_official_ltx_transformer_key("transformer_blocks.0.attn1.norm_q.weight"),
            "model.diffusion_model.transformer_blocks.0.attn1.q_norm.weight"
        );
        assert_eq!(
            remap_official_ltx_transformer_key("caption_projection.linear_2.bias"),
            "model.diffusion_model.caption_projection.linear_2.bias"
        );
    }

    /// A standalone VAE file loaded through the engine returns the same `Arc` the shared
    /// pool already cached, rather than a fresh load.
    #[test]
    fn ltx_video_loads_standalone_vae_tensors_through_shared_pool() {
        let dir = temp_test_dir("mold-ltx-video-vae-pool");
        let vae_path = dir.join("vae.safetensors");
        let weight = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "encoder.conv_in.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
        );
        serialize_to_file(&tensors, &None, &vae_path).unwrap();
        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
            .unwrap()
            .unwrap();
        let engine = LtxVideoEngine::new(
            "ltx-video-0.9.6:bf16".to_string(),
            ltx_video_model_paths(&dir, vae_path),
            None,
            LoadStrategy::Sequential,
            0,
            Some(shared_pool),
        );
        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();
        assert!(Arc::ptr_eq(&pooled, &loaded));
        fs::remove_dir_all(dir).ok();
    }

    /// Every skip-block index in every preset pass is a valid transformer layer index.
    #[test]
    fn ltx_presets_only_reference_in_range_skip_blocks() {
        for model_name in [
            "ltx-video-0.9.6:bf16",
            "ltx-video-0.9.6-distilled:bf16",
            "ltx-video-0.9.8-2b-distilled:bf16",
            "ltx-video-0.9.8-13b-dev:bf16",
            "ltx-video-0.9.8-13b-distilled:bf16",
        ] {
            let preset = LtxModelPreset::for_model(model_name).expect("preset should exist");
            let mut all_skip_lists = vec![preset.base_pass.guidance.skip_block_list.clone()];
            if let Some(multiscale) = &preset.multiscale {
                all_skip_lists.push(multiscale.first_pass.guidance.skip_block_list.clone());
                all_skip_lists.push(multiscale.second_pass.guidance.skip_block_list.clone());
            }
            for skip_list_group in all_skip_lists {
                for skip_list in skip_list_group {
                    for skip_block in skip_list {
                        assert!(
                            skip_block < preset.transformer_config.num_layers,
                            "{model_name} skip block {skip_block} is out of range for {} layers",
                            preset.transformer_config.num_layers
                        );
                    }
                }
            }
        }
    }

    /// The 13B bf16 preset is rejected up front when the reported free VRAM cannot hold it,
    /// and the error suggests offload, a smaller model, and smaller dimensions.
    #[test]
    fn ltx_video_13b_bf16_fails_before_allocation_when_vram_cannot_hold_residency() {
        let preset = LtxModelPreset::for_model("ltx-video-0.9.8-13b-dev:bf16").unwrap();
        let err = super::ltx_video_transformer_residency_guard(
            "ltx-video-0.9.8-13b-dev:bf16",
            &preset,
            26_000_000_000,
            Some(24_000_000_000),
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(err.contains("MOLD_OFFLOAD"));
        assert!(err.contains("ltx-video-0.9.8-2b-distilled"));
        assert!(err.contains("--width/--height/--frames"));
    }

    /// The same guard inputs do not reject the 2B distilled preset.
    #[test]
    fn ltx_video_2b_bf16_residency_guard_does_not_reject() {
        let preset = LtxModelPreset::for_model("ltx-video-0.9.8-2b-distilled:bf16").unwrap();
        super::ltx_video_transformer_residency_guard(
            "ltx-video-0.9.8-2b-distilled:bf16",
            &preset,
            26_000_000_000,
            Some(24_000_000_000),
            true,
        )
        .unwrap();
    }

    /// All 0.9.8 presets run the multiscale pipeline and carry a multiscale config.
    #[test]
    fn ltx_098_presets_use_multiscale_mode() {
        for model_name in [
            "ltx-video-0.9.8-2b-distilled:bf16",
            "ltx-video-0.9.8-13b-dev:bf16",
            "ltx-video-0.9.8-13b-distilled:bf16",
        ] {
            let preset = LtxModelPreset::for_model(model_name).expect("preset should exist");
            assert_eq!(preset.mode, LtxPipelineMode::Multiscale, "{model_name}");
            assert!(preset.multiscale.is_some(), "{model_name}");
        }
    }

    /// Distilled 0.9.8 presets use `LTX_098_DISTILLED_SECOND_PASS_SIGMAS` for the second pass.
    #[test]
    fn ltx_098_distilled_second_pass_uses_upstream_sigmas() {
        for model_name in [
            "ltx-video-0.9.8-2b-distilled:bf16",
            "ltx-video-0.9.8-13b-distilled:bf16",
        ] {
            let preset = LtxModelPreset::for_model(model_name).expect("preset should exist");
            let multiscale = preset.multiscale.as_ref().expect("multiscale preset");
            assert_eq!(
                multiscale.second_pass.custom_sigmas.as_deref(),
                Some(LTX_098_DISTILLED_SECOND_PASS_SIGMAS),
                "{model_name}"
            );
        }
    }

    /// Second-pass packing yields F32 and round-trips losslessly through `unpack_latents`.
    /// The upsampler cast then restores the requested dtype without changing the shape.
    #[test]
    fn multiscale_handoff_normalizes_dtypes_for_upsampler_and_second_pass() {
        let device = Device::Cpu;
        let second_pass_latents =
            Tensor::arange(0f32, (LATENT_CHANNELS * 2 * 4 * 6) as f32, &device)
                .expect("tensor")
                .reshape((1, LATENT_CHANNELS, 2, 4, 6))
                .expect("reshape")
                .to_dtype(DType::BF16)
                .expect("bf16");
        let packed = pack_initial_latents_for_second_pass(&second_pass_latents)
            .expect("second-pass repack should succeed");
        assert_eq!(packed.dtype(), DType::F32);
        let unpacked = unpack_latents(&packed, 2, 4, 6, PATCH_SIZE, PATCH_SIZE_T)
            .expect("unpack should round-trip");
        assert_eq!(unpacked.dtype(), DType::F32);
        assert_eq!(
            unpacked.dims5().expect("dims"),
            (1, LATENT_CHANNELS, 2, 4, 6)
        );
        assert_eq!(
            unpacked
                .flatten_all()
                .expect("flatten")
                .to_vec1::<f32>()
                .expect("vec"),
            second_pass_latents
                .to_dtype(DType::F32)
                .expect("f32")
                .flatten_all()
                .expect("flatten")
                .to_vec1::<f32>()
                .expect("vec")
        );
        let upsampler_input =
            cast_latents_for_multiscale_upsampler(&unpacked, DType::BF16).expect("cast");
        assert_eq!(upsampler_input.dtype(), DType::BF16);
        assert_eq!(
            upsampler_input.dims5().expect("dims"),
            (1, LATENT_CHANNELS, 2, 4, 6)
        );
    }
}