1use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::{Arc, Mutex};
9use std::time::Instant;
10
11use anyhow::{bail, Context, Result};
12use candle_core::{DType, Device, IndexOp, Tensor};
13use candle_nn::VarBuilder;
14use candle_transformers::models::ltx_video::{
15 sampling::{
16 FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerConfig, TimeShiftType,
17 },
18 transformer::{LtxVideoTransformer3DModel, LtxVideoTransformer3DModelConfig},
19 vae::{AutoencoderKLLtxVideo, AutoencoderKLLtxVideoConfig},
20};
21
22use mold_core::{GenerateRequest, GenerateResponse, ModelPaths, OutputFormat, VideoData};
23
24use crate::device::{fmt_gb, usable_free_vram_bytes};
25use crate::engine::{gpu_dtype, rand_seed, seeded_randn, LoadStrategy};
26use crate::engine_base::EngineBase;
27use crate::progress::{ProgressCallback, ProgressEvent};
28use crate::shared_pool::SharedPool;
29
30use super::{latent_upsampler::LatentUpsampler, video_enc};
31
/// Spatial downsampling factor of the LTX VAE: each latent cell covers a 32x32 pixel tile, so the
/// requested width and height must be multiples of this value.
const VAE_SPATIAL_COMPRESSION: usize = 32;
/// Temporal downsampling factor of the LTX VAE: a clip of `8n + 1` frames maps to `n + 1` latent
/// frames (`latent_f = (num_frames - 1) / 8 + 1`).
const VAE_TEMPORAL_COMPRESSION: usize = 8;
/// Channel count of the initial noise latent handed to the transformer.
const LATENT_CHANNELS: usize = 128;
/// Spatial patch size passed to `pack_latents` when packing the noise latent.
const PATCH_SIZE: usize = 1;
/// Temporal patch size passed to `pack_latents` when packing the noise latent.
const PATCH_SIZE_T: usize = 1;

/// Explicit sigma schedule for the base/first pass of the 0.9.8 distilled presets. It is supplied
/// as `custom_sigmas`, so the pass's step count is not used to build the schedule.
const LTX_098_DISTILLED_FIRST_PASS_SIGMAS: &[f32] =
    &[1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250];
/// Explicit sigma schedule for the second pass of the 0.9.8 distilled presets.
const LTX_098_DISTILLED_SECOND_PASS_SIGMAS: &[f32] = &[0.9094, 0.7250, 0.4219];
// The four `LTX_098_DEV_FIRST_PASS_*` tables are parallel. Entry `i` of each applies to the
// denoising steps whose sigma `resolve_guidance_index` maps to interval `i` of
// `LTX_098_DEV_FIRST_PASS_GUIDANCE_TIMESTEPS`. They are used by the 13B dev preset.
/// Classifier-free guidance scale per guidance interval (13B dev, first pass).
const LTX_098_DEV_FIRST_PASS_GUIDANCE_SCALE: &[f32] = &[1.0, 1.0, 6.0, 8.0, 6.0, 1.0, 1.0];
/// Skip-block (STG) guidance scale per guidance interval (13B dev, first pass).
const LTX_098_DEV_FIRST_PASS_STG_SCALE: &[f32] = &[0.0, 0.0, 4.0, 4.0, 4.0, 2.0, 1.0];
/// Guidance rescale factor per guidance interval (13B dev, first pass).
const LTX_098_DEV_FIRST_PASS_RESCALING_SCALE: &[f32] = &[1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 1.0];
/// Descending sigma thresholds. A step uses the first interval whose threshold is `<=` its sigma,
/// or the last interval when none is.
const LTX_098_DEV_FIRST_PASS_GUIDANCE_TIMESTEPS: &[f32] =
    &[1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180];
/// Transformer block indices used for skip-block (STG) guidance by the 0.9.6 dev preset.
const LTX_096_DEV_SKIP_BLOCKS: &[usize] = &[19];
/// Skip-block list for the 0.9.8 2B distilled preset (intentionally empty).
const LTX_098_2B_DISTILLED_SKIP_BLOCKS: &[usize] = &[];
/// Skip-block list for the 0.9.8 13B distilled preset.
const LTX_098_13B_DISTILLED_SKIP_BLOCKS: &[usize] = &[42];
/// Skip-block lists for the 13B dev first pass: one list per guidance interval, parallel to the
/// `LTX_098_DEV_FIRST_PASS_*` tables.
const LTX_098_13B_DEV_FIRST_PASS_SKIP_BLOCKS: &[&[usize]] = &[
    &[],
    &[11, 25, 35, 39],
    &[22, 35, 39],
    &[28],
    &[28],
    &[28],
    &[28],
];
/// Skip-block list for the 13B dev second pass.
const LTX_098_13B_DEV_SECOND_PASS_SKIP_BLOCKS: &[usize] = &[27];
/// VRAM (2 GB, decimal units) that the legacy 13B residency guard requires on top of the summed
/// transformer file size.
const LTX_VIDEO_FULL_RESIDENT_RUNTIME_HEADROOM: u64 = 2_000_000_000;
68
/// Heuristically decides whether `path` names an official single-file LTX transformer checkpoint.
///
/// The rule: the file name is valid UTF-8, starts with `ltx`, and ends with `.safetensors`.
/// Diffusers-layout shards (`diffusion_pytorch_model*`) are excluded. That exclusion is already
/// implied by the `ltx` prefix but is kept explicit to document intent.
/// Matching is case-sensitive.
fn is_official_ltx_transformer_checkpoint(path: &std::path::Path) -> bool {
    let Some(file_name) = path.file_name().and_then(std::ffi::OsStr::to_str) else {
        return false;
    };
    file_name.starts_with("ltx")
        && file_name.ends_with(".safetensors")
        && !file_name.starts_with("diffusion_pytorch_model")
}
78
/// Maps a transformer parameter name, as requested by the candle model, to the key used in
/// official LTX single-file checkpoints.
///
/// The substring renames are applied in order, and the result is prefixed with
/// `model.diffusion_model.`.
fn remap_official_ltx_transformer_key(key: &str) -> String {
    const RENAMES: [(&str, &str); 4] = [
        ("proj_in", "patchify_proj"),
        ("time_embed", "adaln_single"),
        ("norm_q", "q_norm"),
        ("norm_k", "k_norm"),
    ];
    let renamed = RENAMES
        .iter()
        .fold(key.to_owned(), |acc, &(from, to)| acc.replace(from, to));
    format!("model.diffusion_model.{renamed}")
}
87
/// Denoising layout selected by a preset.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum LtxPipelineMode {
    /// Presets whose `multiscale` config is `None`; presumably only `base_pass` runs (confirm in
    /// `generate_sequential`).
    Base,
    /// Presets that carry an `LtxMultiscaleConfig` with separate first and second passes.
    Multiscale,
}
93
/// Guidance settings for one denoising pass.
///
/// Each per-step list holds either a single value, which applies to every step, or one value per
/// guidance interval. Intervals are selected through `guidance_timesteps` in
/// `resolve_step_schedule`.
#[derive(Clone, Debug)]
struct LtxGuidanceConfig {
    /// Classifier-free guidance scale(s). CFG only runs for a step whose resolved scale is > 1.0
    /// and when unconditional embeddings are available.
    guidance_scale: Vec<f32>,
    /// Skip-block (STG) guidance scale(s).
    stg_scale: Vec<f32>,
    /// Guidance rescale factor(s).
    rescaling_scale: Vec<f32>,
    /// Sigma thresholds that map each step to an interval index. `None` means index 0 is used for
    /// every step.
    guidance_timesteps: Option<Vec<f32>>,
    /// Transformer block indices to skip: either one list per interval or a single shared list.
    skip_block_list: Vec<Vec<usize>>,
    /// Whether to apply CFG* rescaling of the unconditional prediction. This is presumably done
    /// via `cfg_star_rescale_uncond`; confirm in `denoise_pass`.
    cfg_star_rescale: bool,
}
103
104impl LtxGuidanceConfig {
105 fn constant(
106 guidance_scale: f32,
107 stg_scale: f32,
108 rescaling_scale: f32,
109 skip_block_list: &[usize],
110 ) -> Self {
111 Self {
112 guidance_scale: vec![guidance_scale],
113 stg_scale: vec![stg_scale],
114 rescaling_scale: vec![rescaling_scale],
115 guidance_timesteps: None,
116 skip_block_list: vec![skip_block_list.to_vec()],
117 cfg_star_rescale: false,
118 }
119 }
120}
121
/// Schedule and guidance for a single denoising pass.
#[derive(Clone, Debug)]
struct LtxPassConfig {
    /// Step count used to build the scheduler's timesteps when `custom_sigmas` is `None`.
    num_inference_steps: u32,
    /// Explicit sigma schedule. When set, it replaces the step-count-based schedule.
    custom_sigmas: Option<Vec<f32>>,
    /// Number of leading schedule steps to skip; the scheduler's begin index is set to this.
    skip_initial_inference_steps: usize,
    /// Number of trailing schedule steps to drop.
    skip_final_inference_steps: usize,
    /// Per-step guidance settings for this pass.
    guidance: LtxGuidanceConfig,
    /// Tone-map compression in `[0, 1]` (0 disables it). Presumably passed to `tone_map_latents`
    /// after the pass; confirm at the call site.
    tone_map_compression_ratio: f32,
}
131
/// Two-pass configuration carried by `LtxPipelineMode::Multiscale` presets.
#[derive(Clone, Debug)]
struct LtxMultiscaleConfig {
    /// Resolution scale for the multiscale pipeline (0.6666666 in every preset here).
    /// NOTE(review): presumably applied to the first pass's width/height; confirm in
    /// `generate_sequential`.
    downscale_factor: f32,
    /// First denoising pass.
    first_pass: LtxPassConfig,
    /// Second denoising pass.
    second_pass: LtxPassConfig,
}
138
/// Everything needed to run one LTX-Video checkpoint, resolved from the model name by
/// `LtxModelPreset::for_model`.
#[derive(Clone, Debug)]
struct LtxModelPreset {
    transformer_config: LtxVideoTransformer3DModelConfig,
    vae_config: AutoencoderKLLtxVideoConfig,
    scheduler_config: FlowMatchEulerDiscreteSchedulerConfig,
    /// Single-pass configuration. Multiscale presets define one too; how it is used there is not
    /// visible in this section.
    base_pass: LtxPassConfig,
    /// Timestep used at VAE decode time. NOTE(review): confirm semantics at the decode site.
    decode_timestep: f32,
    /// Noise scale used at VAE decode time. NOTE(review): confirm semantics at the decode site.
    decode_noise_scale: f32,
    mode: LtxPipelineMode,
    /// `Some` exactly for the `Multiscale` presets built by `for_model`.
    multiscale: Option<LtxMultiscaleConfig>,
}
150
151impl LtxModelPreset {
152 fn for_model(model_name: &str) -> Result<Self> {
153 if model_name.contains("ltx-video-0.9.6-distilled") {
154 Ok(Self {
155 transformer_config: transformer_2b_config(),
156 vae_config: improved_vae_config(),
157 scheduler_config: scheduler_config(true),
158 base_pass: LtxPassConfig {
159 num_inference_steps: 8,
160 custom_sigmas: None,
161 skip_initial_inference_steps: 0,
162 skip_final_inference_steps: 0,
163 guidance: LtxGuidanceConfig::constant(1.0, 0.0, 1.0, &[]),
164 tone_map_compression_ratio: 0.0,
165 },
166 decode_timestep: 0.05,
167 decode_noise_scale: 0.025,
168 mode: LtxPipelineMode::Base,
169 multiscale: None,
170 })
171 } else if model_name.contains("ltx-video-0.9.6") {
172 Ok(Self {
173 transformer_config: transformer_2b_config(),
174 vae_config: improved_vae_config(),
175 scheduler_config: scheduler_config(false),
176 base_pass: LtxPassConfig {
177 num_inference_steps: 40,
178 custom_sigmas: None,
179 skip_initial_inference_steps: 0,
180 skip_final_inference_steps: 0,
181 guidance: LtxGuidanceConfig::constant(3.0, 1.0, 0.7, LTX_096_DEV_SKIP_BLOCKS),
182 tone_map_compression_ratio: 0.0,
183 },
184 decode_timestep: 0.05,
185 decode_noise_scale: 0.025,
186 mode: LtxPipelineMode::Base,
187 multiscale: None,
188 })
189 } else if model_name.contains("ltx-video-0.9.8-2b-distilled") {
190 Ok(Self {
191 transformer_config: transformer_2b_config(),
192 vae_config: improved_vae_config(),
193 scheduler_config: scheduler_config(false),
194 base_pass: LtxPassConfig {
195 num_inference_steps: 7,
196 custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
197 skip_initial_inference_steps: 0,
198 skip_final_inference_steps: 0,
199 guidance: LtxGuidanceConfig::constant(
200 1.0,
201 0.0,
202 1.0,
203 LTX_098_2B_DISTILLED_SKIP_BLOCKS,
204 ),
205 tone_map_compression_ratio: 0.0,
206 },
207 decode_timestep: 0.05,
208 decode_noise_scale: 0.025,
209 mode: LtxPipelineMode::Multiscale,
210 multiscale: Some(LtxMultiscaleConfig {
211 downscale_factor: 0.6666666,
212 first_pass: LtxPassConfig {
213 num_inference_steps: 7,
214 custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
215 skip_initial_inference_steps: 0,
216 skip_final_inference_steps: 0,
217 guidance: LtxGuidanceConfig::constant(
218 1.0,
219 0.0,
220 1.0,
221 LTX_098_2B_DISTILLED_SKIP_BLOCKS,
222 ),
223 tone_map_compression_ratio: 0.0,
224 },
225 second_pass: LtxPassConfig {
226 num_inference_steps: 3,
227 custom_sigmas: Some(LTX_098_DISTILLED_SECOND_PASS_SIGMAS.to_vec()),
228 skip_initial_inference_steps: 0,
229 skip_final_inference_steps: 0,
230 guidance: LtxGuidanceConfig::constant(
231 1.0,
232 0.0,
233 1.0,
234 LTX_098_2B_DISTILLED_SKIP_BLOCKS,
235 ),
236 tone_map_compression_ratio: 0.0,
237 },
238 }),
239 })
240 } else if model_name.contains("ltx-video-0.9.8-13b-distilled") {
241 Ok(Self {
242 transformer_config: transformer_13b_config(),
243 vae_config: improved_vae_config(),
244 scheduler_config: scheduler_config(false),
245 base_pass: LtxPassConfig {
246 num_inference_steps: 7,
247 custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
248 skip_initial_inference_steps: 0,
249 skip_final_inference_steps: 0,
250 guidance: LtxGuidanceConfig::constant(
251 1.0,
252 0.0,
253 1.0,
254 LTX_098_13B_DISTILLED_SKIP_BLOCKS,
255 ),
256 tone_map_compression_ratio: 0.0,
257 },
258 decode_timestep: 0.05,
259 decode_noise_scale: 0.025,
260 mode: LtxPipelineMode::Multiscale,
261 multiscale: Some(LtxMultiscaleConfig {
262 downscale_factor: 0.6666666,
263 first_pass: LtxPassConfig {
264 num_inference_steps: 7,
265 custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
266 skip_initial_inference_steps: 0,
267 skip_final_inference_steps: 0,
268 guidance: LtxGuidanceConfig::constant(
269 1.0,
270 0.0,
271 1.0,
272 LTX_098_13B_DISTILLED_SKIP_BLOCKS,
273 ),
274 tone_map_compression_ratio: 0.0,
275 },
276 second_pass: LtxPassConfig {
277 num_inference_steps: 3,
278 custom_sigmas: Some(LTX_098_DISTILLED_SECOND_PASS_SIGMAS.to_vec()),
279 skip_initial_inference_steps: 0,
280 skip_final_inference_steps: 0,
281 guidance: LtxGuidanceConfig::constant(
282 1.0,
283 0.0,
284 1.0,
285 LTX_098_13B_DISTILLED_SKIP_BLOCKS,
286 ),
287 tone_map_compression_ratio: 0.6,
288 },
289 }),
290 })
291 } else if model_name.contains("ltx-video-0.9.8-13b-dev") {
292 Ok(Self {
293 transformer_config: transformer_13b_config(),
294 vae_config: improved_vae_config(),
295 scheduler_config: scheduler_config(false),
296 base_pass: LtxPassConfig {
297 num_inference_steps: 30,
298 custom_sigmas: None,
299 skip_initial_inference_steps: 0,
300 skip_final_inference_steps: 0,
301 guidance: LtxGuidanceConfig::constant(8.0, 4.0, 0.5, &[28]),
302 tone_map_compression_ratio: 0.0,
303 },
304 decode_timestep: 0.05,
305 decode_noise_scale: 0.025,
306 mode: LtxPipelineMode::Multiscale,
307 multiscale: Some(LtxMultiscaleConfig {
308 downscale_factor: 0.6666666,
309 first_pass: LtxPassConfig {
310 num_inference_steps: 30,
311 custom_sigmas: None,
312 skip_initial_inference_steps: 0,
313 skip_final_inference_steps: 3,
314 guidance: LtxGuidanceConfig {
315 guidance_scale: LTX_098_DEV_FIRST_PASS_GUIDANCE_SCALE.to_vec(),
316 stg_scale: LTX_098_DEV_FIRST_PASS_STG_SCALE.to_vec(),
317 rescaling_scale: LTX_098_DEV_FIRST_PASS_RESCALING_SCALE.to_vec(),
318 guidance_timesteps: Some(
319 LTX_098_DEV_FIRST_PASS_GUIDANCE_TIMESTEPS.to_vec(),
320 ),
321 skip_block_list: LTX_098_13B_DEV_FIRST_PASS_SKIP_BLOCKS
322 .iter()
323 .map(|blocks| blocks.to_vec())
324 .collect(),
325 cfg_star_rescale: true,
326 },
327 tone_map_compression_ratio: 0.0,
328 },
329 second_pass: LtxPassConfig {
330 num_inference_steps: 30,
331 custom_sigmas: None,
332 skip_initial_inference_steps: 17,
333 skip_final_inference_steps: 0,
334 guidance: LtxGuidanceConfig {
335 guidance_scale: vec![1.0],
336 stg_scale: vec![1.0],
337 rescaling_scale: vec![1.0],
338 guidance_timesteps: Some(vec![1.0]),
339 skip_block_list: vec![LTX_098_13B_DEV_SECOND_PASS_SKIP_BLOCKS.to_vec()],
340 cfg_star_rescale: true,
341 },
342 tone_map_compression_ratio: 0.0,
343 },
344 }),
345 })
346 } else {
347 bail!("unsupported LTX model preset for {}", model_name);
348 }
349 }
350}
351
352fn transformer_2b_config() -> LtxVideoTransformer3DModelConfig {
353 LtxVideoTransformer3DModelConfig {
354 num_layers: 28,
355 num_attention_heads: 32,
356 attention_head_dim: 64,
357 cross_attention_dim: 2048,
358 caption_channels: 4096,
359 ..Default::default()
360 }
361}
362
363fn transformer_13b_config() -> LtxVideoTransformer3DModelConfig {
364 LtxVideoTransformer3DModelConfig {
365 num_layers: 48,
366 num_attention_heads: 32,
367 attention_head_dim: 128,
368 cross_attention_dim: 4096,
369 caption_channels: 4096,
370 ..Default::default()
371 }
372}
373
374fn is_legacy_ltx_video_13b(model_name: &str, preset: &LtxModelPreset) -> bool {
375 model_name.contains("13b")
376 || (preset.transformer_config.num_layers >= 48
377 && preset.transformer_config.attention_head_dim >= 128)
378}
379
380fn ltx_video_transformer_residency_guard(
381 model_name: &str,
382 preset: &LtxModelPreset,
383 transformer_bytes: u64,
384 usable_vram_bytes: Option<u64>,
385 is_cuda: bool,
386) -> Result<()> {
387 if !is_cuda || !is_legacy_ltx_video_13b(model_name, preset) {
388 return Ok(());
389 }
390 let Some(usable_vram_bytes) = usable_vram_bytes.filter(|bytes| *bytes > 0) else {
391 return Ok(());
392 };
393 let required = transformer_bytes.saturating_add(LTX_VIDEO_FULL_RESIDENT_RUNTIME_HEADROOM);
394 if required <= usable_vram_bytes {
395 return Ok(());
396 }
397
398 bail!(
399 "legacy LTX-Video 13B BF16 requires full transformer residency ({} weights + {} runtime headroom) but only {} usable VRAM is available. MOLD_OFFLOAD is not implemented for this legacy transformer yet; use ltx-video-0.9.8-2b-distilled, lower --width/--height/--frames, or use an LTX-2 FP8 model with adaptive offload.",
400 fmt_gb(transformer_bytes),
401 fmt_gb(LTX_VIDEO_FULL_RESIDENT_RUNTIME_HEADROOM),
402 fmt_gb(usable_vram_bytes),
403 )
404}
405
406fn improved_vae_config() -> AutoencoderKLLtxVideoConfig {
407 AutoencoderKLLtxVideoConfig {
408 block_out_channels: vec![128, 256, 512, 1024, 2048],
409 decoder_block_out_channels: vec![256, 512, 1024],
410 spatiotemporal_scaling: vec![true, true, true, true],
411 decoder_spatiotemporal_scaling: vec![true, true, true],
412 layers_per_block: vec![4, 6, 6, 2, 2],
413 decoder_layers_per_block: vec![5, 5, 5, 5],
414 decoder_inject_noise: vec![false, false, false, false],
415 decoder_upsample_residual: vec![true, true, true],
416 decoder_upsample_factor: vec![2, 2, 2],
417 timestep_conditioning: true,
418 ..Default::default()
419 }
420}
421
422fn scheduler_config(stochastic_sampling: bool) -> FlowMatchEulerDiscreteSchedulerConfig {
423 FlowMatchEulerDiscreteSchedulerConfig {
424 num_train_timesteps: 1000,
425 shift: 1.0,
426 use_dynamic_shifting: false,
427 base_shift: Some(0.5),
428 max_shift: Some(1.15),
429 base_image_seq_len: Some(256),
430 max_image_seq_len: Some(4096),
431 invert_sigmas: false,
432 shift_terminal: None,
433 use_karras_sigmas: false,
434 use_exponential_sigmas: false,
435 use_beta_sigmas: false,
436 time_shift_type: TimeShiftType::Exponential,
437 stochastic_sampling,
438 }
439}
440
/// Guidance values resolved for one denoising step by `resolve_step_schedule`.
#[derive(Clone, Debug)]
struct LtxResolvedStep {
    guidance_scale: f32,
    stg_scale: f32,
    rescaling_scale: f32,
    /// Block indices to skip, already filtered to `< num_layers` by `clamp_skip_blocks`.
    skip_blocks: Vec<usize>,
}
448
/// Drops block indices that do not exist in a transformer with `num_layers` blocks.
///
/// Surviving indices keep their original order, including any duplicates. This lets one skip
/// list serve both the 28-layer and 48-layer transformers.
fn clamp_skip_blocks(skip_blocks: &[usize], num_layers: usize) -> Vec<usize> {
    skip_blocks
        .iter()
        .filter_map(|&block| (block < num_layers).then_some(block))
        .collect()
}
456
/// Maps `sigma` to a guidance interval index.
///
/// Returns the index of the first threshold that is `<= sigma`. When no threshold qualifies
/// (including a NaN `sigma`), returns the last index. An empty table returns 0.
fn resolve_guidance_index(guidance_timesteps: &[f32], sigma: f32) -> usize {
    for (index, &threshold) in guidance_timesteps.iter().enumerate() {
        if threshold <= sigma {
            return index;
        }
    }
    guidance_timesteps.len().saturating_sub(1)
}
463
464fn resolve_step_schedule(
465 pass: &LtxPassConfig,
466 sigmas: &[f32],
467 num_layers: usize,
468) -> Vec<LtxResolvedStep> {
469 sigmas
470 .iter()
471 .map(|sigma| {
472 let mapped = pass
473 .guidance
474 .guidance_timesteps
475 .as_ref()
476 .map(|timesteps| resolve_guidance_index(timesteps, *sigma))
477 .unwrap_or(0);
478 let value_at = |values: &[f32]| -> f32 {
479 if values.len() == 1 {
480 values[0]
481 } else {
482 values[mapped.min(values.len() - 1)]
483 }
484 };
485 let skip_blocks = if pass.guidance.skip_block_list.is_empty() {
486 Vec::new()
487 } else if pass.guidance.skip_block_list.len() == 1 {
488 clamp_skip_blocks(&pass.guidance.skip_block_list[0], num_layers)
489 } else {
490 clamp_skip_blocks(
491 &pass.guidance.skip_block_list
492 [mapped.min(pass.guidance.skip_block_list.len() - 1)],
493 num_layers,
494 )
495 };
496 LtxResolvedStep {
497 guidance_scale: value_at(&pass.guidance.guidance_scale),
498 stg_scale: value_at(&pass.guidance.stg_scale),
499 rescaling_scale: value_at(&pass.guidance.rescaling_scale),
500 skip_blocks,
501 }
502 })
503 .collect()
504}
505
506fn std_over_dims_except0_keepdim(x: &Tensor) -> Result<Tensor> {
507 let rank = x.rank();
508 if rank < 2 {
509 bail!("std_over_dims_except0_keepdim expects rank >= 2, got {rank}");
510 }
511 let b = x.dim(0)?;
512 let flat = x.flatten_from(1)?;
513 let var = flat.var_keepdim(1)?;
514 let std = var.sqrt()?;
515 let mut shape = Vec::with_capacity(rank);
516 shape.push(b);
517 shape.extend(std::iter::repeat_n(1usize, rank - 1));
518 Ok(std.reshape(shape)?)
519}
520
521fn rescale_noise_cfg(
522 noise_cfg: &Tensor,
523 noise_pred_text: &Tensor,
524 guidance_rescale: f32,
525) -> Result<Tensor> {
526 let std_text = std_over_dims_except0_keepdim(noise_pred_text)?;
527 let std_cfg = std_over_dims_except0_keepdim(noise_cfg)?;
528 let ratio = std_text.broadcast_div(&std_cfg)?;
529 let noise_pred_rescaled = noise_cfg.broadcast_mul(&ratio)?;
530 let a = noise_pred_rescaled.affine(guidance_rescale as f64, 0.0)?;
531 let b = noise_cfg.affine((1.0 - guidance_rescale) as f64, 0.0)?;
532 Ok(a.broadcast_add(&b)?)
533}
534
535fn cfg_star_rescale_uncond(noise_pred_uncond: &Tensor, noise_pred_text: &Tensor) -> Result<Tensor> {
536 let batch = noise_pred_text.dim(0)?;
537 let positive_flat = noise_pred_text.flatten_from(1)?;
538 let negative_flat = noise_pred_uncond.flatten_from(1)?;
539 let dot = positive_flat
540 .broadcast_mul(&negative_flat)?
541 .sum_keepdim(1)?;
542 let squared = negative_flat.sqr()?.sum_keepdim(1)?.affine(1.0, 1e-8)?;
543 let alpha = dot.broadcast_div(&squared)?;
544 let alpha = alpha.reshape((batch, 1, 1))?;
545 Ok(noise_pred_uncond.broadcast_mul(&alpha.broadcast_as(noise_pred_uncond.shape())?)?)
546}
547
548fn create_skip_layer_mask(
549 num_layers: usize,
550 batch_size: usize,
551 layers_to_skip: &[usize],
552 device: &Device,
553) -> Result<Option<Tensor>> {
554 let layers_to_skip = clamp_skip_blocks(layers_to_skip, num_layers);
555 if layers_to_skip.is_empty() {
556 return Ok(None);
557 }
558
559 let mut mask_data = vec![0.0f32; num_layers * batch_size];
560 for &layer_idx in &layers_to_skip {
561 for batch_idx in 0..batch_size {
562 mask_data[layer_idx * batch_size + batch_idx] = 1.0;
563 }
564 }
565 Ok(Some(Tensor::from_vec(
566 mask_data,
567 (num_layers, batch_size),
568 device,
569 )?))
570}
571
572fn tone_map_latents(latents: &Tensor, compression: f32) -> Result<Tensor> {
573 if compression == 0.0 {
574 return Ok(latents.clone());
575 }
576 if !(0.0..=1.0).contains(&compression) {
577 bail!("tone map compression must be in [0, 1], got {compression}");
578 }
579 let scale_factor = compression * 0.75;
580 let abs_latents = latents.abs()?;
581 let sigmoid_term = abs_latents
582 .affine(1.0, -1.0)?
583 .affine((4.0 * scale_factor) as f64, 0.0)?;
584 let sigmoid_term = candle_nn::ops::sigmoid(&sigmoid_term)?;
585 let scales = sigmoid_term.affine((-0.8 * scale_factor) as f64, 1.0)?;
586 Ok(latents.broadcast_mul(&scales)?)
587}
588
589fn normalize_latents_with_vae(latents: &Tensor, vae: &AutoencoderKLLtxVideo) -> Result<Tensor> {
590 let c = latents.dim(1)?;
591 let mean = vae
592 .latents_mean()
593 .reshape((1, c, 1, 1, 1))?
594 .to_device(latents.device())?
595 .to_dtype(latents.dtype())?;
596 let std = vae
597 .latents_std()
598 .reshape((1, c, 1, 1, 1))?
599 .to_device(latents.device())?
600 .to_dtype(latents.dtype())?;
601 Ok(latents.broadcast_sub(&mean)?.broadcast_div(&std)?)
602}
603
604fn denormalize_latents_with_vae(latents: &Tensor, vae: &AutoencoderKLLtxVideo) -> Result<Tensor> {
605 let c = latents.dim(1)?;
606 let mean = vae
607 .latents_mean()
608 .reshape((1, c, 1, 1, 1))?
609 .to_device(latents.device())?
610 .to_dtype(latents.dtype())?;
611 let std = vae
612 .latents_std()
613 .reshape((1, c, 1, 1, 1))?
614 .to_device(latents.device())?
615 .to_dtype(latents.dtype())?;
616 Ok(latents.broadcast_mul(&std)?.broadcast_add(&mean)?)
617}
618
619fn adain_filter_latents(latents: &Tensor, reference_latents: &Tensor) -> Result<Tensor> {
620 let latents_f32 = latents.to_dtype(DType::F32)?;
621 let reference_f32 = reference_latents.to_dtype(DType::F32)?;
622
623 let latents_flat = latents_f32.flatten_from(2)?;
624 let reference_flat = reference_f32.flatten_from(2)?;
625
626 let lat_mean = latents_flat.mean_keepdim(2)?;
627 let lat_std = latents_flat.var_keepdim(2)?.affine(1.0, 1e-6)?.sqrt()?;
628 let ref_mean = reference_flat.mean_keepdim(2)?;
629 let ref_std = reference_flat.var_keepdim(2)?.affine(1.0, 1e-6)?.sqrt()?;
630
631 let filtered = latents_flat
632 .broadcast_sub(&lat_mean)?
633 .broadcast_div(&lat_std)?
634 .broadcast_mul(&ref_std)?
635 .broadcast_add(&ref_mean)?
636 .reshape(latents.shape())?;
637
638 Ok(filtered.to_dtype(latents.dtype())?)
639}
640
/// Loaded model components plus the device and dtype they were loaded with.
///
/// NOTE(review): both components are `Option`, presumably so they can be loaded or dropped
/// independently. The reason for `allow(dead_code)` is not visible here; document it or remove it.
#[allow(dead_code)]
struct LoadedLtxVideo {
    transformer: Option<LtxVideoTransformer3DModel>,
    vae: Option<AutoencoderKLLtxVideo>,
    device: Device,
    dtype: DType,
}
652
/// Inference engine for the legacy LTX-Video (0.9.x) checkpoints.
///
/// NOTE(review): the reason for `allow(dead_code)` is not visible here (for example,
/// `t5_variant` is never read in this section); confirm against the rest of the file.
#[allow(dead_code)]
pub struct LtxVideoEngine {
    /// Shared engine state: model name, paths, load strategy, and GPU ordinal.
    base: EngineBase<LoadedLtxVideo>,
    /// Requested T5 encoder variant, if any.
    t5_variant: Option<String>,
    /// Optional cross-engine pool; used to fetch standalone VAE tensors.
    shared_pool: Option<Arc<Mutex<SharedPool>>>,
    /// Device placement from the in-flight request; set by `generate` and cleared when it returns.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// `Some(is_native)` when built from a single-file checkpoint. It then overrides the
    /// filename heuristic that decides whether transformer keys are remapped.
    single_file_native_format: Option<bool>,
    /// The VAE weights live inside the single-file checkpoint under the `vae` prefix. In that
    /// case the shared pool is bypassed.
    vae_in_checkpoint: bool,
}
674
675impl LtxVideoEngine {
676 pub fn new(
677 model_name: String,
678 paths: ModelPaths,
679 t5_variant: Option<String>,
680 load_strategy: LoadStrategy,
681 gpu_ordinal: usize,
682 shared_pool: Option<Arc<Mutex<SharedPool>>>,
683 ) -> Self {
684 Self {
685 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
686 t5_variant,
687 shared_pool,
688 pending_placement: None,
689 single_file_native_format: None,
690 vae_in_checkpoint: false,
691 }
692 }
693
694 #[allow(clippy::too_many_arguments)]
708 pub fn from_single_file(
709 model_name: String,
710 checkpoint: PathBuf,
711 vae_path: Option<PathBuf>,
712 t5_encoder_path: Option<PathBuf>,
713 t5_tokenizer_path: Option<PathBuf>,
714 t5_variant: Option<String>,
715 load_strategy: LoadStrategy,
716 gpu_ordinal: usize,
717 shared_pool: Option<Arc<Mutex<SharedPool>>>,
718 ) -> anyhow::Result<Self> {
719 if !checkpoint.exists() {
720 anyhow::bail!(
721 "single-file LTX-Video checkpoint not found: {}",
722 checkpoint.display()
723 );
724 }
725
726 let bundle = super::single_file::load(&checkpoint).map_err(|e| {
727 anyhow::anyhow!(
728 "failed to parse single-file LTX-Video checkpoint {}: {e}",
729 checkpoint.display()
730 )
731 })?;
732
733 let is_native = bundle.format == super::single_file::LtxKeyFormat::Native;
734
735 let (resolved_vae, vae_in_checkpoint) = if bundle.has_vae {
737 (checkpoint.clone(), true)
738 } else {
739 let vae = vae_path.ok_or_else(|| {
740 anyhow::anyhow!(
741 "LTX-Video checkpoint {} contains no VAE weights (`vae.*` keys). \
742 Pull the `ltx-video-vae` companion first: `mold pull ltx-video-vae`",
743 checkpoint.display()
744 )
745 })?;
746 if !vae.exists() {
747 anyhow::bail!(
748 "ltx-video-vae companion not on disk: {}. \
749 Run `mold pull ltx-video-vae` to download it.",
750 vae.display()
751 );
752 }
753 (vae, false)
754 };
755
756 let paths = ModelPaths {
757 transformer: checkpoint.clone(),
758 transformer_shards: Vec::new(),
759 vae: resolved_vae,
760 spatial_upscaler: None,
761 temporal_upscaler: None,
762 distilled_lora: None,
763 t5_encoder: t5_encoder_path,
764 clip_encoder: None,
765 t5_tokenizer: t5_tokenizer_path,
766 clip_tokenizer: None,
767 clip_encoder_2: None,
768 clip_tokenizer_2: None,
769 text_encoder_files: Vec::new(),
770 text_tokenizer: None,
771 decoder: None,
772 };
773
774 Ok(Self {
775 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
776 t5_variant,
777 shared_pool,
778 pending_placement: None,
779 single_file_native_format: Some(is_native),
780 vae_in_checkpoint,
781 })
782 }
783}
784
785impl LtxVideoEngine {
790 fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
791 let start = Instant::now();
792 let preset = LtxModelPreset::for_model(&self.base.model_name)?;
793
794 let num_frames = req.frames.unwrap_or(25);
796 let fps = req.fps.unwrap_or(24);
797 let steps = req.steps;
798 let guidance = req.guidance;
799
800 if !(num_frames.wrapping_sub(1)).is_multiple_of(8) {
802 bail!(
803 "frame count must be 8n+1 (9, 17, 25, 33, ...), got {}",
804 num_frames
805 );
806 }
807
808 let seed = req.seed.unwrap_or_else(rand_seed);
809 let width = req.width;
810 let height = req.height;
811
812 if !width.is_multiple_of(VAE_SPATIAL_COMPRESSION as u32)
814 || !height.is_multiple_of(VAE_SPATIAL_COMPRESSION as u32)
815 {
816 bail!(
817 "LTX Video requires width and height to be multiples of {}, got {}x{}",
818 VAE_SPATIAL_COMPRESSION,
819 width,
820 height
821 );
822 }
823
824 let latent_h = height as usize / VAE_SPATIAL_COMPRESSION;
826 let latent_w = width as usize / VAE_SPATIAL_COMPRESSION;
827 let latent_f = (num_frames as usize - 1) / VAE_TEMPORAL_COMPRESSION + 1;
828 self.generate_sequential(
830 req, &preset, seed, num_frames, fps, steps, guidance, width, height, latent_h,
831 latent_w, latent_f, start,
832 )
833 }
834}
835
impl crate::engine::InferenceEngine for LtxVideoEngine {
    /// Runs one generation with `req.placement` exposed as `pending_placement` for the duration
    /// of the call. The placement is cleared again whether generation succeeds or fails.
    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
        self.pending_placement = req.placement.clone();
        let result = self.generate_inner(req);
        // Reset on both success and error so a stale placement never leaks into the next call.
        self.pending_placement = None;
        result
    }

    fn model_name(&self) -> &str {
        &self.base.model_name
    }

    fn is_loaded(&self) -> bool {
        self.base.is_loaded()
    }

    /// Intentionally a no-op: this method loads nothing up front and always returns `Ok(())`.
    fn load(&mut self) -> Result<()> {
        Ok(())
    }

    fn unload(&mut self) {
        self.base.unload();
    }

    fn set_on_progress(&mut self, callback: ProgressCallback) {
        self.base.set_on_progress(callback);
    }

    fn clear_on_progress(&mut self) {
        self.base.clear_on_progress();
    }

    fn model_paths(&self) -> Option<&ModelPaths> {
        Some(&self.base.paths)
    }

    /// This engine implements `ChainStageRenderer` itself.
    fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> {
        Some(self)
    }
}
877
878impl crate::ltx2::ChainStageRenderer for LtxVideoEngine {
891 fn render_stage(
892 &mut self,
893 stage_req: &GenerateRequest,
894 _carry: Option<&crate::ltx2::ChainTail>,
895 motion_tail_pixel_frames: u32,
896 _stage_progress: Option<&mut dyn FnMut(crate::ltx2::StageProgressEvent)>,
897 ) -> Result<crate::ltx2::StageOutcome> {
898 let start = Instant::now();
899 let frames = self.render_chain_frames_internal(stage_req)?;
900 let generation_time_ms = start.elapsed().as_millis() as u64;
901 if frames.is_empty() {
902 bail!("LtxVideoEngine.render_stage: pipeline produced zero frames");
903 }
904
905 let tail_count = (motion_tail_pixel_frames as usize).clamp(1, frames.len());
911 let tail_frames: Vec<image::RgbImage> = frames
912 .iter()
913 .skip(frames.len() - tail_count)
914 .cloned()
915 .collect();
916
917 Ok(crate::ltx2::StageOutcome {
918 frames,
919 tail: crate::ltx2::ChainTail {
920 frames: tail_frames.len() as u32,
921 tail_rgb_frames: tail_frames,
922 },
923 audio: None,
926 generation_time_ms,
927 })
928 }
929}
930
931impl LtxVideoEngine {
932 fn render_chain_frames_internal(
938 &mut self,
939 req: &GenerateRequest,
940 ) -> Result<Vec<image::RgbImage>> {
941 let mut apng_req = req.clone();
942 apng_req.output_format = Some(OutputFormat::Apng);
943 apng_req.gif_preview = false;
946 let response = self.generate_inner(&apng_req)?;
947 let video = response
948 .video
949 .ok_or_else(|| anyhow::anyhow!("LtxVideoEngine.generate returned no video data"))?;
950 decode_apng_to_rgb_frames(&video.data)
951 }
952}
953
954fn decode_apng_to_rgb_frames(apng_bytes: &[u8]) -> Result<Vec<image::RgbImage>> {
955 use image::AnimationDecoder;
956 let cursor = std::io::Cursor::new(apng_bytes);
957 let decoder = image::codecs::png::PngDecoder::new(cursor)
958 .map_err(|e| anyhow::anyhow!("failed to open APNG bytes: {e}"))?;
959 let apng = decoder
960 .apng()
961 .map_err(|e| anyhow::anyhow!("decoded PNG is not animated: {e}"))?;
962 let mut out = Vec::new();
963 for frame in apng.into_frames() {
964 let frame = frame.map_err(|e| anyhow::anyhow!("APNG frame decode failed: {e}"))?;
965 let rgba = frame.into_buffer();
966 let (w, h) = rgba.dimensions();
967 let mut rgb_data = Vec::with_capacity((w as usize) * (h as usize) * 3);
968 for px in rgba.pixels() {
969 rgb_data.extend_from_slice(&px.0[..3]);
970 }
971 let rgb = image::RgbImage::from_raw(w, h, rgb_data)
972 .ok_or_else(|| anyhow::anyhow!("failed to construct RgbImage from APNG frame"))?;
973 out.push(rgb);
974 }
975 Ok(out)
976}
977
978impl LtxVideoEngine {
983 fn load_transformer(
984 &self,
985 preset: &LtxModelPreset,
986 device: &Device,
987 dtype: DType,
988 ) -> Result<LtxVideoTransformer3DModel> {
989 let transformer_files: Vec<std::path::PathBuf> =
990 if !self.base.paths.transformer_shards.is_empty() {
991 self.base.paths.transformer_shards.clone()
992 } else {
993 vec![self.base.paths.transformer.clone()]
994 };
995
996 let is_gguf = transformer_files
997 .first()
998 .and_then(|p| p.extension())
999 .is_some_and(|e| e == "gguf");
1000 if is_gguf {
1001 bail!("GGUF quantized LTX Video transformer is not yet supported — use :bf16 variant");
1002 }
1003 let transformer_bytes = transformer_files.iter().try_fold(0u64, |acc, path| {
1004 let metadata = std::fs::metadata(path).with_context(|| {
1005 format!("failed to stat LTX Video transformer {}", path.display())
1006 })?;
1007 Ok::<_, anyhow::Error>(acc.saturating_add(metadata.len()))
1008 })?;
1009 ltx_video_transformer_residency_guard(
1010 &self.base.model_name,
1011 preset,
1012 transformer_bytes,
1013 usable_free_vram_bytes(self.base.gpu_ordinal),
1014 device.is_cuda(),
1015 )?;
1016
1017 let vb = unsafe { VarBuilder::from_mmaped_safetensors(&transformer_files, dtype, device)? };
1019 let use_remap = match self.single_file_native_format {
1023 Some(is_native) => is_native,
1024 None => {
1025 transformer_files.len() == 1
1026 && is_official_ltx_transformer_checkpoint(&transformer_files[0])
1027 }
1028 };
1029 let vb = if use_remap {
1030 vb.rename_f(remap_official_ltx_transformer_key)
1031 } else {
1032 vb
1033 };
1034 Ok(LtxVideoTransformer3DModel::new(
1035 &preset.transformer_config,
1036 vb,
1037 )?)
1038 }
1039
1040 fn load_vae(
1041 &self,
1042 preset: &LtxModelPreset,
1043 device: &Device,
1044 dtype: DType,
1045 ) -> Result<AutoencoderKLLtxVideo> {
1046 let vb = self.load_vae_var_builder(dtype, device)?;
1047 Ok(AutoencoderKLLtxVideo::new(preset.vae_config.clone(), vb)?)
1048 }
1049
1050 fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
1051 if self.vae_in_checkpoint {
1052 return Ok(None);
1053 }
1054 let Some(shared_pool) = &self.shared_pool else {
1055 return Ok(None);
1056 };
1057 shared_pool
1058 .lock()
1059 .unwrap()
1060 .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
1061 }
1062
1063 fn load_vae_var_builder<'a>(&self, dtype: DType, device: &Device) -> Result<VarBuilder<'a>> {
1064 if let Some(tensors) = self.load_vae_cpu_tensors()? {
1065 return Ok(crate::encoders::park::varbuilder_from_parked(
1066 tensors.as_ref(),
1067 dtype,
1068 device,
1069 ));
1070 }
1071
1072 let vb = unsafe {
1074 VarBuilder::from_mmaped_safetensors(
1075 std::slice::from_ref(&self.base.paths.vae),
1076 dtype,
1077 device,
1078 )?
1079 };
1080 let vb = if self.vae_in_checkpoint {
1083 vb.pp("vae")
1084 } else {
1085 vb
1086 };
1087 Ok(vb)
1088 }
1089
1090 #[allow(clippy::too_many_arguments)]
1091 fn denoise_pass(
1092 &self,
1093 transformer: &mut LtxVideoTransformer3DModel,
1094 prompt_embeds: &Tensor,
1095 attention_mask: &Tensor,
1096 uncond_embeds: Option<&Tensor>,
1097 uncond_mask: Option<&Tensor>,
1098 pass: &LtxPassConfig,
1099 scheduler_cfg: &FlowMatchEulerDiscreteSchedulerConfig,
1100 seed: u64,
1101 width: u32,
1102 height: u32,
1103 num_frames: u32,
1104 fps: u32,
1105 device: &Device,
1106 dtype: DType,
1107 progress: &crate::progress::ProgressReporter,
1108 stage_name: &str,
1109 ltx_debug: bool,
1110 initial_latents: Option<Tensor>,
1111 ) -> Result<Tensor> {
1112 let latent_h = height as usize / VAE_SPATIAL_COMPRESSION;
1113 let latent_w = width as usize / VAE_SPATIAL_COMPRESSION;
1114 let latent_f = (num_frames as usize - 1) / VAE_TEMPORAL_COMPRESSION + 1;
1115 let mut scheduler = FlowMatchEulerDiscreteScheduler::new(scheduler_cfg.clone())?;
1116
1117 scheduler.set_timesteps(
1118 if pass.custom_sigmas.is_some() {
1119 None
1120 } else {
1121 Some(pass.num_inference_steps as usize)
1122 },
1123 device,
1124 pass.custom_sigmas.as_deref(),
1125 None,
1126 None,
1127 )?;
1128
1129 let schedule_sigmas = scheduler
1130 .sigmas()
1131 .to_device(&Device::Cpu)?
1132 .to_vec1::<f32>()?;
1133 let total_steps = schedule_sigmas.len() - 1;
1134 if pass.skip_initial_inference_steps + pass.skip_final_inference_steps >= total_steps {
1135 bail!(
1136 "invalid LTX pass schedule: skip_initial={} + skip_final={} >= total_steps={}",
1137 pass.skip_initial_inference_steps,
1138 pass.skip_final_inference_steps,
1139 total_steps
1140 );
1141 }
1142 let start_step = pass.skip_initial_inference_steps;
1143 let end_step = total_steps - pass.skip_final_inference_steps;
1144 let run_sigmas = schedule_sigmas[start_step..end_step].to_vec();
1145 scheduler.set_begin_index(start_step);
1146
1147 let step_schedule =
1148 resolve_step_schedule(pass, &run_sigmas, transformer.config().num_layers);
1149 let video_coords = build_video_coords(1, latent_f, latent_h, latent_w, fps, device)?;
1150
1151 let mut latents = match initial_latents {
1152 Some(latents) => pack_initial_latents_for_second_pass(&latents)?,
1153 None => {
1154 let noise = seeded_randn(
1155 seed,
1156 &[1, LATENT_CHANNELS, latent_f, latent_h, latent_w],
1157 device,
1158 DType::F32,
1159 )?;
1160 pack_latents(&noise, PATCH_SIZE, PATCH_SIZE_T)?
1161 }
1162 };
1163
1164 progress.stage_start(stage_name);
1165 let denoise_start = Instant::now();
1166
1167 for (step, sigma) in run_sigmas.iter().copied().enumerate() {
1168 let step_start = Instant::now();
1169 let resolved = &step_schedule[step];
1170 let batch = latents.dim(0)?;
1171 let timestep_t = Tensor::full(sigma, (batch,), device)?.to_dtype(dtype)?;
1172 let latents_input = latents.to_dtype(dtype)?;
1173
1174 let do_cfg = resolved.guidance_scale > 1.0 && uncond_embeds.is_some();
1175 let do_stg = resolved.stg_scale > 0.0 && !resolved.skip_blocks.is_empty();
1176
1177 if do_stg {
1178 transformer.set_skip_block_list(vec![]);
1179 } else {
1180 transformer.set_skip_block_list(resolved.skip_blocks.clone());
1181 }
1182
1183 let cond_pred = transformer.forward(
1184 &latents_input,
1185 prompt_embeds,
1186 ×tep_t,
1187 Some(attention_mask),
1188 latent_f,
1189 latent_h,
1190 latent_w,
1191 None,
1192 Some(&video_coords),
1193 None,
1194 )?;
1195 let cond_f32 = cond_pred.to_dtype(DType::F32)?;
1196 let mut combined = cond_f32.clone();
1197
1198 if do_cfg {
1199 let uncond_pred = transformer.forward(
1200 &latents_input,
1201 uncond_embeds.expect("checked above"),
1202 ×tep_t,
1203 uncond_mask,
1204 latent_f,
1205 latent_h,
1206 latent_w,
1207 None,
1208 Some(&video_coords),
1209 None,
1210 )?;
1211 let mut uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
1212 if pass.guidance.cfg_star_rescale {
1213 uncond_f32 = cfg_star_rescale_uncond(&uncond_f32, &cond_f32)?;
1214 }
1215 let diff = cond_f32.broadcast_sub(&uncond_f32)?;
1216 combined =
1217 uncond_f32.broadcast_add(&diff.affine(resolved.guidance_scale as f64, 0.0)?)?;
1218 }
1219
1220 if do_stg {
1221 let skip_layer_mask = create_skip_layer_mask(
1222 transformer.config().num_layers,
1223 batch,
1224 &resolved.skip_blocks,
1225 device,
1226 )?;
1227 let perturbed = transformer.forward(
1228 &latents_input,
1229 prompt_embeds,
1230 ×tep_t,
1231 Some(attention_mask),
1232 latent_f,
1233 latent_h,
1234 latent_w,
1235 None,
1236 Some(&video_coords),
1237 skip_layer_mask.as_ref(),
1238 )?;
1239 let perturbed_f32 = perturbed.to_dtype(DType::F32)?;
1240 let diff_stg = cond_f32.broadcast_sub(&perturbed_f32)?;
1241 combined =
1242 combined.broadcast_add(&diff_stg.affine(resolved.stg_scale as f64, 0.0)?)?;
1243 if pass.guidance.cfg_star_rescale && resolved.rescaling_scale > 0.0 {
1244 combined = rescale_noise_cfg(&combined, &cond_f32, resolved.rescaling_scale)?;
1245 }
1246 }
1247
1248 let model_output =
1249 if transformer.config().out_channels / 2 == transformer.config().in_channels {
1250 combined
1251 .chunk(2, 2)?
1252 .into_iter()
1253 .next()
1254 .expect("out_channels / 2 == in_channels implies chunk(2) succeeds")
1255 } else {
1256 combined
1257 };
1258
1259 if ltx_debug {
1260 let out_rms = model_output.sqr()?.mean_all()?.to_scalar::<f32>()?.sqrt();
1261 let lat_rms = latents.sqr()?.mean_all()?.to_scalar::<f32>()?.sqrt();
1262 if step < 3 || step == run_sigmas.len() - 1 {
1263 progress.info(&format!(
1264 "Pass {stage_name} step {}: sigma={:.4}, guidance={:.2}, stg={:.2}, lat_rms={:.4}, out_rms={:.4}",
1265 step,
1266 sigma,
1267 resolved.guidance_scale,
1268 resolved.stg_scale,
1269 lat_rms,
1270 out_rms
1271 ));
1272 }
1273 }
1274
1275 latents = scheduler
1276 .step(&model_output, sigma, &latents, None)?
1277 .prev_sample;
1278
1279 progress.emit(ProgressEvent::DenoiseStep {
1280 step: step + 1,
1281 total: run_sigmas.len(),
1282 elapsed: step_start.elapsed(),
1283 });
1284 }
1285
1286 progress.stage_done(stage_name, denoise_start.elapsed());
1287 unpack_latents(
1288 &latents,
1289 latent_f,
1290 latent_h,
1291 latent_w,
1292 PATCH_SIZE,
1293 PATCH_SIZE_T,
1294 )
1295 }
1296
1297 #[allow(clippy::too_many_arguments)]
1298 fn generate_sequential(
1299 &mut self,
1300 req: &GenerateRequest,
1301 preset: &LtxModelPreset,
1302 seed: u64,
1303 num_frames: u32,
1304 fps: u32,
1305 steps: u32,
1306 guidance: f64,
1307 width: u32,
1308 height: u32,
1309 _latent_h: usize,
1310 _latent_w: usize,
1311 _latent_f: usize,
1312 start: Instant,
1313 ) -> Result<GenerateResponse> {
1314 let progress = &self.base.progress;
1315 let paths = &self.base.paths;
1316 let ltx_debug = std::env::var("MOLD_LTX_DEBUG").is_ok_and(|v| v == "1");
1317
1318 if preset.mode == LtxPipelineMode::Multiscale && paths.spatial_upscaler.is_none() {
1319 bail!("LTX 0.9.8 requires a spatial upscaler asset in the pulled model files");
1320 }
1321
1322 let device = crate::device::create_device(self.base.gpu_ordinal, progress)?;
1324 let dtype = gpu_dtype(&device);
1325
1326 progress.info(&format!(
1327 "LTX Video: {}×{} × {} frames, {} steps, seed {}",
1328 width, height, num_frames, steps, seed
1329 ));
1330 if preset.mode == LtxPipelineMode::Multiscale {
1331 progress.info("Using the full 0.9.8 multiscale refinement path.");
1332 if steps != preset.base_pass.num_inference_steps {
1333 progress.info(&format!(
1334 "Ignoring --steps={} for multiscale LTX preset {}; using the preset schedule instead.",
1335 steps, self.base.model_name
1336 ));
1337 }
1338 if guidance != preset.base_pass.guidance.guidance_scale[0] as f64 {
1339 progress.info(&format!(
1340 "Ignoring --guidance={guidance:.2} for multiscale LTX preset {}; using the preset guidance schedule instead.",
1341 self.base.model_name
1342 ));
1343 }
1344 }
1345
1346 progress.stage_start("Loading T5-XXL encoder");
1350 let t5_start = Instant::now();
1351
1352 let t5_encoder_path = paths
1353 .t5_encoder
1354 .as_ref()
1355 .ok_or_else(|| anyhow::anyhow!("T5 encoder path not configured"))?;
1356 let t5_tokenizer_path = paths
1357 .t5_tokenizer
1358 .as_ref()
1359 .ok_or_else(|| anyhow::anyhow!("T5 tokenizer path not configured"))?;
1360
1361 let tier1 = self
1362 .pending_placement
1363 .as_ref()
1364 .map(|p| p.text_encoders)
1365 .unwrap_or_default();
1366 let t5_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
1367 let cached_t5_tokenizer = self
1368 .shared_pool
1369 .as_ref()
1370 .map(|pool| pool.lock().unwrap().load_tokenizer(t5_tokenizer_path))
1371 .transpose()?;
1372 let mut t5 = crate::encoders::t5::T5Encoder::load_with_tokenizer(
1373 t5_encoder_path,
1374 t5_tokenizer_path,
1375 &t5_device,
1376 dtype,
1377 progress,
1378 cached_t5_tokenizer,
1379 )?;
1380 progress.stage_done("Loading T5-XXL encoder", t5_start.elapsed());
1381
1382 progress.stage_start("Encoding prompt");
1383 let encode_start = Instant::now();
1384 let prompt_embeds = t5.encode(&req.prompt, &t5_device, dtype)?;
1385 let prompt_embeds = prompt_embeds.to_device(&device)?;
1386 progress.stage_done("Encoding prompt", encode_start.elapsed());
1388
1389 let prompt_seq_len = prompt_embeds.dim(1)?;
1391 let attention_mask =
1392 Tensor::ones((1, prompt_seq_len), DType::F32, &device)?.to_dtype(dtype)?;
1393
1394 let needs_uncond = match preset.mode {
1395 LtxPipelineMode::Base => guidance > 1.0,
1396 LtxPipelineMode::Multiscale => preset
1397 .multiscale
1398 .as_ref()
1399 .into_iter()
1400 .flat_map(|cfg| [&cfg.first_pass, &cfg.second_pass])
1401 .any(|pass| {
1402 pass.guidance
1403 .guidance_scale
1404 .iter()
1405 .any(|scale| *scale > 1.0)
1406 }),
1407 };
1408
1409 let (uncond_embeds, uncond_mask) = if needs_uncond {
1410 progress.stage_start("Encoding negative prompt (CFG)");
1411 let ue = t5.encode("", &t5_device, dtype)?;
1412 let ue = ue.to_device(&device)?;
1413 let ue_seq = ue.dim(1)?;
1414 let um = Tensor::ones((1, ue_seq), DType::F32, &device)?.to_dtype(dtype)?;
1415 progress.stage_done("Encoding negative prompt (CFG)", encode_start.elapsed());
1416 (Some(ue), Some(um))
1417 } else {
1418 (None, None)
1419 };
1420
1421 drop(t5);
1423 device.synchronize()?;
1424 progress.info("T5 encoder dropped, VRAM freed");
1425
1426 let (mut latents, decode_width, decode_height, tone_map_compression_ratio) = match preset
1427 .mode
1428 {
1429 LtxPipelineMode::Base => {
1430 progress.stage_start("Loading LTX Video transformer");
1431 let transformer_start = Instant::now();
1432 let mut transformer = self.load_transformer(preset, &device, dtype)?;
1433 progress.stage_done("Loading LTX Video transformer", transformer_start.elapsed());
1434
1435 let mut pass = preset.base_pass.clone();
1436 pass.num_inference_steps = steps;
1437 pass.guidance.guidance_scale = vec![guidance as f32];
1438
1439 let latents = self.denoise_pass(
1440 &mut transformer,
1441 &prompt_embeds,
1442 &attention_mask,
1443 uncond_embeds.as_ref(),
1444 uncond_mask.as_ref().map(|m| m as &Tensor),
1445 &pass,
1446 &preset.scheduler_config,
1447 seed,
1448 width,
1449 height,
1450 num_frames,
1451 fps,
1452 &device,
1453 dtype,
1454 progress,
1455 "Denoising",
1456 ltx_debug,
1457 None,
1458 )?;
1459 drop(transformer);
1460 device.synchronize()?;
1461 (latents, width, height, pass.tone_map_compression_ratio)
1462 }
1463 LtxPipelineMode::Multiscale => {
1464 let multiscale = preset.multiscale.as_ref().expect("multiscale preset");
1465 let first_width = ((width as f32 * multiscale.downscale_factor) as u32)
1466 / VAE_SPATIAL_COMPRESSION as u32
1467 * VAE_SPATIAL_COMPRESSION as u32;
1468 let first_height = ((height as f32 * multiscale.downscale_factor) as u32)
1469 / VAE_SPATIAL_COMPRESSION as u32
1470 * VAE_SPATIAL_COMPRESSION as u32;
1471
1472 progress.stage_start("Loading LTX Video transformer");
1473 let transformer_start = Instant::now();
1474 let mut first_transformer = self.load_transformer(preset, &device, dtype)?;
1475 progress.stage_done("Loading LTX Video transformer", transformer_start.elapsed());
1476
1477 let first_pass_latents = self.denoise_pass(
1478 &mut first_transformer,
1479 &prompt_embeds,
1480 &attention_mask,
1481 uncond_embeds.as_ref(),
1482 uncond_mask.as_ref().map(|m| m as &Tensor),
1483 &multiscale.first_pass,
1484 &preset.scheduler_config,
1485 seed,
1486 first_width,
1487 first_height,
1488 num_frames,
1489 fps,
1490 &device,
1491 dtype,
1492 progress,
1493 "Denoising First Pass",
1494 ltx_debug,
1495 None,
1496 )?;
1497 drop(first_transformer);
1498 device.synchronize()?;
1499
1500 progress.stage_start("Loading spatial upscaler");
1501 let spatial_start = Instant::now();
1502 let vae = self.load_vae(preset, &device, dtype)?;
1503 let upsampler = LatentUpsampler::load(
1504 paths.spatial_upscaler.as_ref().expect("checked above"),
1505 dtype,
1506 &device,
1507 )?;
1508 progress.stage_done("Loading spatial upscaler", spatial_start.elapsed());
1509
1510 progress.stage_start("Refining multiscale pass");
1511 let refine_start = Instant::now();
1512 let first_pass_denorm = cast_latents_for_multiscale_upsampler(
1513 &denormalize_latents_with_vae(&first_pass_latents, &vae)?,
1514 dtype,
1515 )?;
1516 let upsampled_latents =
1517 normalize_latents_with_vae(&upsampler.forward(&first_pass_denorm)?, &vae)?;
1518 let upsampled_latents =
1519 adain_filter_latents(&upsampled_latents, &first_pass_latents)?;
1520 progress.stage_done("Refining multiscale pass", refine_start.elapsed());
1521 drop(upsampler);
1522 drop(vae);
1523 device.synchronize()?;
1524
1525 let second_width = first_width * 2;
1526 let second_height = first_height * 2;
1527 progress.stage_start("Loading LTX Video transformer");
1528 let transformer_start = Instant::now();
1529 let mut second_transformer = self.load_transformer(preset, &device, dtype)?;
1530 progress.stage_done("Loading LTX Video transformer", transformer_start.elapsed());
1531
1532 let latents = self.denoise_pass(
1533 &mut second_transformer,
1534 &prompt_embeds,
1535 &attention_mask,
1536 uncond_embeds.as_ref(),
1537 uncond_mask.as_ref().map(|m| m as &Tensor),
1538 &multiscale.second_pass,
1539 &preset.scheduler_config,
1540 seed,
1541 second_width,
1542 second_height,
1543 num_frames,
1544 fps,
1545 &device,
1546 dtype,
1547 progress,
1548 "Denoising Second Pass",
1549 ltx_debug,
1550 Some(upsampled_latents),
1551 )?;
1552 drop(second_transformer);
1553 device.synchronize()?;
1554 (
1555 latents,
1556 second_width,
1557 second_height,
1558 multiscale.second_pass.tone_map_compression_ratio,
1559 )
1560 }
1561 };
1562
1563 progress.stage_start("Loading VAE decoder");
1567 let vae_start = Instant::now();
1568 let vae = self.load_vae(preset, &device, dtype)?;
1569 progress.stage_done("Loading VAE decoder", vae_start.elapsed());
1570
1571 progress.stage_start("Decoding video frames");
1572 let decode_start = Instant::now();
1573
1574 let decode_timestep = if vae.config().timestep_conditioning {
1575 if preset.decode_noise_scale > 0.0 {
1576 let noise =
1577 seeded_randn(seed ^ 0xdec0de, latents.shape().dims(), &device, DType::F32)?;
1578 latents = (&latents * (1.0 - preset.decode_noise_scale as f64))?
1579 .broadcast_add(&(noise * preset.decode_noise_scale as f64)?)?;
1580 }
1581 Some(Tensor::full(preset.decode_timestep, (1,), &device)?.to_dtype(dtype)?)
1582 } else {
1583 None
1584 };
1585
1586 latents = tone_map_latents(&latents, tone_map_compression_ratio)?;
1587
1588 latents = denormalize_latents_with_vae(&latents, &vae)?;
1590
1591 if ltx_debug {
1592 let l_f32 = latents.to_dtype(DType::F32)?;
1593 progress.info(&format!(
1594 "Latents pre-VAE (un-normalized): mean={:.4}, std={:.4}",
1595 l_f32.mean_all()?.to_scalar::<f32>()?,
1596 l_f32.flatten_all()?.var(0)?.to_scalar::<f32>()?.sqrt()
1597 ));
1598 }
1599
1600 latents = latents.to_dtype(dtype)?;
1601 let (_dec_output, video) = vae.decode(&latents, decode_timestep.as_ref(), false, false)?;
1602 if ltx_debug {
1604 let v_f32 = video.to_dtype(DType::F32)?;
1605 progress.info(&format!(
1606 "VAE output: shape={:?}, mean={:.4}, min={:.4}, max={:.4}",
1607 v_f32.shape(),
1608 v_f32.mean_all()?.to_scalar::<f32>()?,
1609 v_f32.flatten_all()?.min(0)?.to_scalar::<f32>()?,
1610 v_f32.flatten_all()?.max(0)?.to_scalar::<f32>()?
1611 ));
1612 }
1613
1614 progress.stage_done("Decoding video frames", decode_start.elapsed());
1615
1616 drop(vae);
1618 device.synchronize()?;
1619
1620 let output_format = if req.resolved_output_format().is_video() {
1625 req.resolved_output_format()
1626 } else {
1627 OutputFormat::Apng
1628 };
1629 let format_name = output_format.extension().to_uppercase();
1630 progress.stage_start(&format!("Encoding {format_name}"));
1631 let encode_start = Instant::now();
1632
1633 let video = video.to_dtype(DType::F32)?;
1635 let video = ((video.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
1636 let video = video.i(0)?; let num_output_frames = video.dim(1)?;
1640 let mut frames = Vec::with_capacity(num_output_frames);
1641 for f in 0..num_output_frames {
1642 let frame = video.i((.., f, .., ..))?.contiguous()?; let frame = frame.permute((1, 2, 0))?; let frame_data: Vec<u8> = frame.flatten_all()?.to_vec1()?;
1645 let mut rgb = image::RgbImage::from_raw(decode_width, decode_height, frame_data)
1646 .ok_or_else(|| anyhow::anyhow!("failed to create frame image"))?;
1647 if decode_width != width || decode_height != height {
1648 rgb = image::imageops::resize(
1649 &rgb,
1650 width,
1651 height,
1652 image::imageops::FilterType::Triangle,
1653 );
1654 }
1655 frames.push(rgb);
1656 }
1657
1658 let video_bytes = match output_format {
1659 OutputFormat::Apng => {
1660 let metadata = video_enc::VideoMetadata {
1661 prompt: req.prompt.clone(),
1662 model: self.base.model_name.clone(),
1663 seed,
1664 steps,
1665 guidance: req.guidance,
1666 width,
1667 height,
1668 frames: num_output_frames as u32,
1669 fps,
1670 };
1671 video_enc::encode_apng(&frames, fps, Some(&metadata))?
1672 }
1673 OutputFormat::Gif => video_enc::encode_gif(&frames, fps)?,
1674 #[cfg(feature = "webp")]
1675 OutputFormat::Webp => video_enc::encode_webp(&frames, fps)?,
1676 #[cfg(feature = "mp4")]
1677 OutputFormat::Mp4 => video_enc::encode_mp4(&frames, fps)?,
1678 #[cfg(not(feature = "webp"))]
1679 OutputFormat::Webp => {
1680 bail!("WebP output requires the 'webp' feature — rebuild with --features webp")
1681 }
1682 #[cfg(not(feature = "mp4"))]
1683 OutputFormat::Mp4 => {
1684 bail!("MP4 output requires the 'mp4' feature — rebuild with --features mp4")
1685 }
1686 _ => bail!("{format_name} is not a supported video output format"),
1687 };
1688 let thumbnail_bytes = video_enc::first_frame_png(&frames)?;
1689 let gif_preview = if req.gif_preview {
1692 if output_format == OutputFormat::Gif {
1693 video_bytes.clone()
1694 } else {
1695 video_enc::encode_gif(&frames, fps)?
1696 }
1697 } else {
1698 Vec::new()
1699 };
1700
1701 progress.stage_done(&format!("Encoding {format_name}"), encode_start.elapsed());
1702
1703 let generation_time_ms = start.elapsed().as_millis() as u64;
1704 progress.info(&format!(
1705 "Done: {} frames, {:.1}s total",
1706 num_output_frames,
1707 generation_time_ms as f64 / 1000.0
1708 ));
1709
1710 Ok(GenerateResponse {
1711 images: vec![],
1712 video: Some(VideoData {
1713 data: video_bytes,
1714 format: output_format,
1715 width,
1716 height,
1717 frames: num_output_frames as u32,
1718 fps,
1719 thumbnail: thumbnail_bytes,
1720 gif_preview,
1721 has_audio: false,
1722 duration_ms: None,
1723 audio_sample_rate: None,
1724 audio_channels: None,
1725 }),
1726 generation_time_ms,
1727 model: self.base.model_name.clone(),
1728 seed_used: seed,
1729 gpu: None,
1730 })
1731 }
1732}
1733
1734fn pack_latents(latents: &Tensor, patch_size: usize, patch_size_t: usize) -> Result<Tensor> {
1740 let (b, c, f, h, w) = latents.dims5()?;
1741 if f % patch_size_t != 0 || h % patch_size != 0 || w % patch_size != 0 {
1742 bail!("latent dims not divisible by patch sizes");
1743 }
1744 let f2 = f / patch_size_t;
1745 let h2 = h / patch_size;
1746 let w2 = w / patch_size;
1747
1748 let x = latents.reshape(&[b, c, f2, patch_size_t, h2, patch_size, w2, patch_size])?;
1750 let x = x.permute([0, 2, 4, 6, 1, 3, 5, 7])?;
1752 let x = x.flatten_from(4)?;
1754 let d = x.dim(4)?;
1755 let s = f2 * h2 * w2;
1756 Ok(x.reshape((b, s, d))?)
1757}
1758
1759fn unpack_latents(
1761 latents: &Tensor,
1762 num_frames: usize,
1763 height: usize,
1764 width: usize,
1765 patch_size: usize,
1766 patch_size_t: usize,
1767) -> Result<Tensor> {
1768 let (b, _s, d) = latents.dims3()?;
1769 let denom = patch_size_t * patch_size * patch_size;
1770 if d % denom != 0 {
1771 bail!("D={d} not divisible by patch product {denom}");
1772 }
1773 let c = d / denom;
1774
1775 let x = latents.reshape(&[
1776 b,
1777 num_frames,
1778 height,
1779 width,
1780 c,
1781 patch_size_t,
1782 patch_size,
1783 patch_size,
1784 ])?;
1785 let x = x.permute([0, 4, 1, 5, 2, 6, 3, 7])?.contiguous()?;
1787 Ok(x.reshape((
1788 b,
1789 c,
1790 num_frames * patch_size_t,
1791 height * patch_size,
1792 width * patch_size,
1793 ))?)
1794}
1795
1796fn pack_initial_latents_for_second_pass(latents: &Tensor) -> Result<Tensor> {
1797 pack_latents(&latents.to_dtype(DType::F32)?, PATCH_SIZE, PATCH_SIZE_T)
1798}
1799
1800fn cast_latents_for_multiscale_upsampler(latents: &Tensor, dtype: DType) -> Result<Tensor> {
1801 Ok(latents.to_dtype(dtype)?)
1802}
1803
1804fn build_video_coords(
1806 batch_size: usize,
1807 latent_f: usize,
1808 latent_h: usize,
1809 latent_w: usize,
1810 fps: u32,
1811 device: &Device,
1812) -> Result<Tensor> {
1813 let grid_f = Tensor::arange(0u32, latent_f as u32, device)?.to_dtype(DType::F32)?;
1814 let grid_h = Tensor::arange(0u32, latent_h as u32, device)?.to_dtype(DType::F32)?;
1815 let grid_w = Tensor::arange(0u32, latent_w as u32, device)?.to_dtype(DType::F32)?;
1816
1817 let f = grid_f
1818 .reshape((latent_f, 1, 1))?
1819 .broadcast_as((latent_f, latent_h, latent_w))?;
1820 let h = grid_h
1821 .reshape((1, latent_h, 1))?
1822 .broadcast_as((latent_f, latent_h, latent_w))?;
1823 let w = grid_w
1824 .reshape((1, 1, latent_w))?
1825 .broadcast_as((latent_f, latent_h, latent_w))?;
1826
1827 let grid = Tensor::stack(&[f, h, w], 0)?; let seq = latent_f * latent_h * latent_w;
1829 let grid = grid.flatten_from(1)?.transpose(0, 1)?.unsqueeze(0)?; let vf = grid.i((.., .., 0))?;
1833 let vh = grid.i((.., .., 1))?;
1834 let vw = grid.i((.., .., 2))?;
1835
1836 let ts_ratio = VAE_TEMPORAL_COMPRESSION as f64;
1838 let vf = vf
1839 .affine(ts_ratio, 1.0 - ts_ratio)?
1840 .clamp(0.0f32, 10000.0f32)?
1841 .affine(1.0 / fps as f64, 0.0)?;
1842 let sp_ratio = VAE_SPATIAL_COMPRESSION as f64;
1844 let vh = vh.affine(sp_ratio, 0.0)?;
1845 let vw = vw.affine(sp_ratio, 0.0)?;
1846
1847 let coords = Tensor::stack(&[vf, vh, vw], candle_core::D::Minus1)?;
1848 if batch_size > 1 {
1849 Ok(coords.broadcast_as((batch_size, seq, 3))?)
1850 } else {
1851 Ok(coords)
1852 }
1853}
1854
// Unit tests for the LTX video engine: APNG round-tripping, official checkpoint detection
// and key remapping, shared-pool VAE loading, preset invariants, the VRAM residency guard,
// and the multiscale latent dtype hand-off.
#[cfg(test)]
mod tests {
    use super::{
        cast_latents_for_multiscale_upsampler, is_official_ltx_transformer_checkpoint,
        pack_initial_latents_for_second_pass, remap_official_ltx_transformer_key, unpack_latents,
        LtxModelPreset, LtxPipelineMode, LtxVideoEngine, LATENT_CHANNELS,
        LTX_098_DISTILLED_SECOND_PASS_SIGMAS, PATCH_SIZE, PATCH_SIZE_T,
    };
    use crate::engine::LoadStrategy;
    use crate::shared_pool::SharedPool;
    use candle_core::{DType, Device, Tensor};
    use mold_core::ModelPaths;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Create a fresh temp directory named `{prefix}-{pid}-{nanos}` and return its path.
    fn temp_test_dir(prefix: &str) -> PathBuf {
        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Minimal `ModelPaths` rooted at `dir` with the given VAE path; only the transformer,
    /// VAE and T5 encoder/tokenizer entries are populated.
    fn ltx_video_model_paths(dir: &Path, vae: PathBuf) -> ModelPaths {
        ModelPaths {
            transformer: dir.join("transformer.safetensors"),
            transformer_shards: vec![],
            vae,
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: Some(dir.join("t5.safetensors")),
            clip_encoder: None,
            t5_tokenizer: Some(dir.join("tokenizer.json")),
            clip_tokenizer: None,
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files: vec![],
            text_tokenizer: None,
            decoder: None,
        }
    }

    // Encoding solid-colour frames to APNG and decoding them back must preserve frame
    // count, dimensions and pixel values.
    #[test]
    fn decode_apng_round_trips_rgb_frames() {
        use crate::ltx_video::video_enc::encode_apng;
        use image::Rgb;

        let make = |r: u8, g: u8, b: u8| {
            let mut img = image::RgbImage::new(4, 4);
            for px in img.pixels_mut() {
                *px = Rgb([r, g, b]);
            }
            img
        };
        let inputs = vec![make(255, 0, 0), make(0, 255, 0), make(0, 0, 255)];

        let bytes = encode_apng(&inputs, 12, None).expect("encode");
        let decoded = super::decode_apng_to_rgb_frames(&bytes).expect("decode");

        assert_eq!(decoded.len(), inputs.len());
        for (i, (a, b)) in inputs.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(a.dimensions(), b.dimensions(), "frame {i} size");
            // Frames are solid colour, so one pixel is representative.
            assert_eq!(
                a.get_pixel(0, 0),
                b.get_pixel(0, 0),
                "frame {i} pixel mismatch",
            );
        }
    }

    // Official single-file `ltxv-*` checkpoints are detected; sharded diffusers files and
    // GGUF files are not.
    #[test]
    fn detects_official_ltx_single_file_checkpoints() {
        assert!(is_official_ltx_transformer_checkpoint(Path::new(
            "ltxv-2b-0.9.6-distilled-04-25.safetensors"
        )));
        assert!(is_official_ltx_transformer_checkpoint(Path::new(
            "ltxv-13b-0.9.8-dev.safetensors"
        )));
        assert!(!is_official_ltx_transformer_checkpoint(Path::new(
            "diffusion_pytorch_model-00001-of-00002.safetensors"
        )));
        assert!(!is_official_ltx_transformer_checkpoint(Path::new(
            "transformer.gguf"
        )));
    }

    // Spot-checks of the diffusers-name -> official-checkpoint-name key mapping.
    #[test]
    fn remaps_official_transformer_keys_to_upstream_checkpoint_names() {
        assert_eq!(
            remap_official_ltx_transformer_key("proj_in.weight"),
            "model.diffusion_model.patchify_proj.weight"
        );
        assert_eq!(
            remap_official_ltx_transformer_key("time_embed.emb.timestep_embedder.linear_1.weight"),
            "model.diffusion_model.adaln_single.emb.timestep_embedder.linear_1.weight"
        );
        assert_eq!(
            remap_official_ltx_transformer_key("transformer_blocks.0.attn1.norm_q.weight"),
            "model.diffusion_model.transformer_blocks.0.attn1.q_norm.weight"
        );
        assert_eq!(
            remap_official_ltx_transformer_key("caption_projection.linear_2.bias"),
            "model.diffusion_model.caption_projection.linear_2.bias"
        );
    }

    // The engine must reuse the exact `Arc` already cached in the shared pool for a
    // standalone VAE file rather than loading a second copy.
    #[test]
    fn ltx_video_loads_standalone_vae_tensors_through_shared_pool() {
        let dir = temp_test_dir("mold-ltx-video-vae-pool");
        let vae_path = dir.join("vae.safetensors");
        let weight = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "encoder.conv_in.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
        );
        serialize_to_file(&tensors, &None, &vae_path).unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
            .unwrap()
            .unwrap();

        let engine = LtxVideoEngine::new(
            "ltx-video-0.9.6:bf16".to_string(),
            ltx_video_model_paths(&dir, vae_path),
            None,
            LoadStrategy::Sequential,
            0,
            Some(shared_pool),
        );

        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();

        assert!(Arc::ptr_eq(&pooled, &loaded));
        fs::remove_dir_all(dir).ok();
    }

    // Every skip block referenced by any preset pass must be a valid layer index for that
    // preset's transformer.
    #[test]
    fn ltx_presets_only_reference_in_range_skip_blocks() {
        for model_name in [
            "ltx-video-0.9.6:bf16",
            "ltx-video-0.9.6-distilled:bf16",
            "ltx-video-0.9.8-2b-distilled:bf16",
            "ltx-video-0.9.8-13b-dev:bf16",
            "ltx-video-0.9.8-13b-distilled:bf16",
        ] {
            let preset = LtxModelPreset::for_model(model_name).expect("preset should exist");
            let mut all_skip_lists = vec![preset.base_pass.guidance.skip_block_list.clone()];
            if let Some(multiscale) = &preset.multiscale {
                all_skip_lists.push(multiscale.first_pass.guidance.skip_block_list.clone());
                all_skip_lists.push(multiscale.second_pass.guidance.skip_block_list.clone());
            }
            for skip_list_group in all_skip_lists {
                for skip_list in skip_list_group {
                    for skip_block in skip_list {
                        assert!(
                            skip_block < preset.transformer_config.num_layers,
                            "{model_name} skip block {skip_block} is out of range for {} layers",
                            preset.transformer_config.num_layers
                        );
                    }
                }
            }
        }
    }

    // With 26 GB required vs. 24 GB available, the 13B bf16 preset must be rejected with an
    // actionable message (offload hint, smaller model, smaller output).
    #[test]
    fn ltx_video_13b_bf16_fails_before_allocation_when_vram_cannot_hold_residency() {
        let preset = LtxModelPreset::for_model("ltx-video-0.9.8-13b-dev:bf16").unwrap();
        let err = super::ltx_video_transformer_residency_guard(
            "ltx-video-0.9.8-13b-dev:bf16",
            &preset,
            26_000_000_000,
            Some(24_000_000_000),
            true,
        )
        .unwrap_err()
        .to_string();

        assert!(err.contains("MOLD_OFFLOAD"));
        assert!(err.contains("ltx-video-0.9.8-2b-distilled"));
        assert!(err.contains("--width/--height/--frames"));
    }

    // The same VRAM numbers must not reject the 2B distilled preset.
    #[test]
    fn ltx_video_2b_bf16_residency_guard_does_not_reject() {
        let preset = LtxModelPreset::for_model("ltx-video-0.9.8-2b-distilled:bf16").unwrap();
        super::ltx_video_transformer_residency_guard(
            "ltx-video-0.9.8-2b-distilled:bf16",
            &preset,
            26_000_000_000,
            Some(24_000_000_000),
            true,
        )
        .unwrap();
    }

    // All 0.9.8 presets run the multiscale pipeline and carry its configuration.
    #[test]
    fn ltx_098_presets_use_multiscale_mode() {
        for model_name in [
            "ltx-video-0.9.8-2b-distilled:bf16",
            "ltx-video-0.9.8-13b-dev:bf16",
            "ltx-video-0.9.8-13b-distilled:bf16",
        ] {
            let preset = LtxModelPreset::for_model(model_name).expect("preset should exist");
            assert_eq!(preset.mode, LtxPipelineMode::Multiscale, "{model_name}");
            assert!(preset.multiscale.is_some(), "{model_name}");
        }
    }

    // Distilled 0.9.8 presets pin the second pass to `LTX_098_DISTILLED_SECOND_PASS_SIGMAS`.
    #[test]
    fn ltx_098_distilled_second_pass_uses_upstream_sigmas() {
        for model_name in [
            "ltx-video-0.9.8-2b-distilled:bf16",
            "ltx-video-0.9.8-13b-distilled:bf16",
        ] {
            let preset = LtxModelPreset::for_model(model_name).expect("preset should exist");
            let multiscale = preset.multiscale.as_ref().expect("multiscale preset");
            assert_eq!(
                multiscale.second_pass.custom_sigmas.as_deref(),
                Some(LTX_098_DISTILLED_SECOND_PASS_SIGMAS),
                "{model_name}"
            );
        }
    }

    // BF16 latents repacked for the second pass become F32, unpack back to the original
    // shape and values, and can be cast back to BF16 for the upsampler.
    #[test]
    fn multiscale_handoff_normalizes_dtypes_for_upsampler_and_second_pass() {
        let device = Device::Cpu;
        let second_pass_latents =
            Tensor::arange(0f32, (LATENT_CHANNELS * 2 * 4 * 6) as f32, &device)
                .expect("tensor")
                .reshape((1, LATENT_CHANNELS, 2, 4, 6))
                .expect("reshape")
                .to_dtype(DType::BF16)
                .expect("bf16");

        let packed = pack_initial_latents_for_second_pass(&second_pass_latents)
            .expect("second-pass repack should succeed");
        assert_eq!(packed.dtype(), DType::F32);

        let unpacked = unpack_latents(&packed, 2, 4, 6, PATCH_SIZE, PATCH_SIZE_T)
            .expect("unpack should round-trip");
        assert_eq!(unpacked.dtype(), DType::F32);
        assert_eq!(
            unpacked.dims5().expect("dims"),
            (1, LATENT_CHANNELS, 2, 4, 6)
        );
        assert_eq!(
            unpacked
                .flatten_all()
                .expect("flatten")
                .to_vec1::<f32>()
                .expect("vec"),
            second_pass_latents
                .to_dtype(DType::F32)
                .expect("f32")
                .flatten_all()
                .expect("flatten")
                .to_vec1::<f32>()
                .expect("vec")
        );

        let upsampler_input =
            cast_latents_for_multiscale_upsampler(&unpacked, DType::BF16).expect("cast");
        assert_eq!(upsampler_input.dtype(), DType::BF16);
        assert_eq!(
            upsampler_input.dims5().expect("dims"),
            (1, LATENT_CHANNELS, 2, 4, 6)
        );
    }
}