Skip to main content

mold_inference/ltx_video/
pipeline.rs

1//! LTX Video inference engine — text-to-video generation.
2//!
3//! Architecture: T5-XXL text encoder → LTXVideoTransformer3DModel → 3D Causal VAE → APNG/GIF/WebP/MP4
4//! Follows the same patterns as Flux2Engine (drop-and-reload, VRAM management, progress).
5
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::{Arc, Mutex};
9use std::time::Instant;
10
11use anyhow::{bail, Context, Result};
12use candle_core::{DType, Device, IndexOp, Tensor};
13use candle_nn::VarBuilder;
14use candle_transformers::models::ltx_video::{
15    sampling::{
16        FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerConfig, TimeShiftType,
17    },
18    transformer::{LtxVideoTransformer3DModel, LtxVideoTransformer3DModelConfig},
19    vae::{AutoencoderKLLtxVideo, AutoencoderKLLtxVideoConfig},
20};
21
22use mold_core::{GenerateRequest, GenerateResponse, ModelPaths, OutputFormat, VideoData};
23
24use crate::device::{fmt_gb, usable_free_vram_bytes};
25use crate::engine::{gpu_dtype, rand_seed, seeded_randn, LoadStrategy};
26use crate::engine_base::EngineBase;
27use crate::progress::{ProgressCallback, ProgressEvent};
28use crate::shared_pool::SharedPool;
29
30use super::{latent_upsampler::LatentUpsampler, video_enc};
31
32// ---------------------------------------------------------------------------
33// Constants
34// ---------------------------------------------------------------------------
35
/// Spatial compression ratio of the LTX Video VAE.
///
/// Requested width and height must be multiples of this value; the latent
/// height/width are the pixel sizes divided by it (see `generate_inner`).
const VAE_SPATIAL_COMPRESSION: usize = 32;
/// Temporal compression ratio of the LTX Video VAE.
///
/// `generate_inner` derives the latent frame count as
/// `(frames - 1) / VAE_TEMPORAL_COMPRESSION + 1`.
const VAE_TEMPORAL_COMPRESSION: usize = 8;
/// Latent channels in the VAE.
const LATENT_CHANNELS: usize = 128;
/// Patch sizes (both 1 for current LTX Video checkpoints).
const PATCH_SIZE: usize = 1;
const PATCH_SIZE_T: usize = 1;

/// Explicit sigma table used by the 0.9.8 distilled presets for their base
/// pass and for the first multiscale pass (7 entries, matching 7 steps).
const LTX_098_DISTILLED_FIRST_PASS_SIGMAS: &[f32] =
    &[1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250];
/// Explicit sigma table used by the 0.9.8 distilled presets for the second
/// multiscale pass (3 entries, matching 3 steps).
const LTX_098_DISTILLED_SECOND_PASS_SIGMAS: &[f32] = &[0.9094, 0.7250, 0.4219];
// The four tables below belong to the first multiscale pass of the
// `ltx-video-0.9.8-13b-dev` preset. Each has 7 entries; the entry used for a
// given step is chosen through `LTX_098_DEV_FIRST_PASS_GUIDANCE_TIMESTEPS`
// (see `resolve_guidance_index` / `resolve_step_schedule`).
const LTX_098_DEV_FIRST_PASS_GUIDANCE_SCALE: &[f32] = &[1.0, 1.0, 6.0, 8.0, 6.0, 1.0, 1.0];
const LTX_098_DEV_FIRST_PASS_STG_SCALE: &[f32] = &[0.0, 0.0, 4.0, 4.0, 4.0, 2.0, 1.0];
const LTX_098_DEV_FIRST_PASS_RESCALING_SCALE: &[f32] = &[1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 1.0];
/// Descending thresholds compared against the current sigma to pick the
/// active entry of the per-stage tables above.
const LTX_098_DEV_FIRST_PASS_GUIDANCE_TIMESTEPS: &[f32] =
    &[1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180];
// Transformer block indices handed to the guidance configuration as
// `skip_block_list` for the corresponding preset. Indices that are not below
// the transformer's layer count are dropped by `clamp_skip_blocks`.
const LTX_096_DEV_SKIP_BLOCKS: &[usize] = &[19];
const LTX_098_2B_DISTILLED_SKIP_BLOCKS: &[usize] = &[];
const LTX_098_13B_DISTILLED_SKIP_BLOCKS: &[usize] = &[42];
/// One skip list per guidance stage (7 lists) for the first multiscale pass
/// of the 13B dev preset.
const LTX_098_13B_DEV_FIRST_PASS_SKIP_BLOCKS: &[&[usize]] = &[
    &[],
    &[11, 25, 35, 39],
    &[22, 35, 39],
    &[28],
    &[28],
    &[28],
    &[28],
];
const LTX_098_13B_DEV_SECOND_PASS_SKIP_BLOCKS: &[usize] = &[27];
/// Bytes added on top of the transformer weight size by
/// `ltx_video_transformer_residency_guard` when deciding whether the 13B
/// transformer fits in usable VRAM (2 * 10^9 bytes).
const LTX_VIDEO_FULL_RESIDENT_RUNTIME_HEADROOM: u64 = 2_000_000_000;
68
/// Reports whether the final component of `path` is named like an official
/// LTX transformer checkpoint: it starts with `ltx`, ends with
/// `.safetensors`, and is not a `diffusion_pytorch_model*` file.
///
/// Only the file name is inspected; the file is never opened. Paths with no
/// final component or a non-UTF-8 file name yield `false`. The match is
/// case-sensitive.
fn is_official_ltx_transformer_checkpoint(path: &std::path::Path) -> bool {
    let Some(file_name) = path.file_name().and_then(|os_name| os_name.to_str()) else {
        return false;
    };

    let is_safetensors = file_name.ends_with(".safetensors");
    let has_ltx_prefix = file_name.starts_with("ltx");
    let is_diffusers_export = file_name.starts_with("diffusion_pytorch_model");

    is_safetensors && has_ltx_prefix && !is_diffusers_export
}
78
/// Rewrites a transformer tensor key into the naming used by official LTX
/// checkpoints.
///
/// Four substring renames are applied in order, everywhere they occur in
/// `key`, and the result is prefixed with `model.diffusion_model.`.
fn remap_official_ltx_transformer_key(key: &str) -> String {
    // (needle, replacement) pairs, applied top to bottom.
    const RENAMES: [(&str, &str); 4] = [
        ("proj_in", "patchify_proj"),
        ("time_embed", "adaln_single"),
        ("norm_q", "q_norm"),
        ("norm_k", "k_norm"),
    ];

    let renamed = RENAMES
        .iter()
        .fold(key.to_owned(), |name, &(from, to)| name.replace(from, to));
    format!("model.diffusion_model.{renamed}")
}
87
/// Generation flow selected by a model preset.
///
/// In the presets built by `LtxModelPreset::for_model`, `Base` is always
/// paired with `multiscale: None` and `Multiscale` with `multiscale: Some(..)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum LtxPipelineMode {
    /// Used by the 0.9.6 presets, which carry no multiscale configuration.
    Base,
    /// Used by the 0.9.8 presets, which carry first/second pass settings.
    Multiscale,
}
93
/// Guidance settings for one denoising pass.
///
/// The three scale tables and `skip_block_list` are indexed per guidance
/// stage; `resolve_step_schedule` picks one entry of each for every sigma.
#[derive(Clone, Debug)]
struct LtxGuidanceConfig {
    /// Guidance scale per stage.
    guidance_scale: Vec<f32>,
    /// STG scale per stage.
    stg_scale: Vec<f32>,
    /// Rescaling scale per stage.
    rescaling_scale: Vec<f32>,
    /// Thresholds compared against sigma to choose the stage; with `None`
    /// stage 0 is always used.
    guidance_timesteps: Option<Vec<f32>>,
    /// Transformer block indices to skip, one list per stage.
    skip_block_list: Vec<Vec<usize>>,
    /// NOTE(review): presumably enables `cfg_star_rescale_uncond` during
    /// sampling; the consumer is not visible in this part of the file.
    cfg_star_rescale: bool,
}

impl LtxGuidanceConfig {
    /// Builds a single-stage configuration: every table holds exactly one
    /// entry, no guidance timesteps are set, and `cfg_star_rescale` is off.
    ///
    /// `skip_block_list` becomes the one and only stage list, so an empty
    /// slice yields a table containing a single empty list.
    fn constant(
        guidance_scale: f32,
        stg_scale: f32,
        rescaling_scale: f32,
        skip_block_list: &[usize],
    ) -> Self {
        let single_stage_blocks = Vec::from(skip_block_list);
        Self {
            cfg_star_rescale: false,
            guidance_timesteps: None,
            skip_block_list: Vec::from([single_stage_blocks]),
            rescaling_scale: Vec::from([rescaling_scale]),
            stg_scale: Vec::from([stg_scale]),
            guidance_scale: Vec::from([guidance_scale]),
        }
    }
}
121
/// Settings for a single denoising pass (the base pass, or one of the two
/// multiscale passes).
#[derive(Clone, Debug)]
struct LtxPassConfig {
    /// Number of inference steps configured for the pass.
    num_inference_steps: u32,
    /// Explicit sigma table; the distilled 0.9.8 presets set this, the other
    /// presets leave it `None`.
    custom_sigmas: Option<Vec<f32>>,
    /// NOTE(review): presumably the number of leading steps dropped from the
    /// schedule (17 for the 13B dev second pass); consumer not visible here.
    skip_initial_inference_steps: usize,
    /// NOTE(review): presumably the number of trailing steps dropped from the
    /// schedule (3 for the 13B dev first pass); consumer not visible here.
    skip_final_inference_steps: usize,
    /// Per-stage guidance tables resolved by `resolve_step_schedule`.
    guidance: LtxGuidanceConfig,
    /// Compression strength; `0.0` in every preset except the 13B distilled
    /// second pass (`0.6`). NOTE(review): presumably forwarded to
    /// `tone_map_latents` — confirm at the call site.
    tone_map_compression_ratio: f32,
}
131
/// Two-pass configuration attached to presets whose mode is
/// `LtxPipelineMode::Multiscale`.
#[derive(Clone, Debug)]
struct LtxMultiscaleConfig {
    /// Every preset in this file sets `0.6666666`. NOTE(review): presumably
    /// the resolution scale applied to the first pass — confirm at the use site.
    downscale_factor: f32,
    /// Settings for the first pass.
    first_pass: LtxPassConfig,
    /// Settings for the second pass.
    second_pass: LtxPassConfig,
}
138
/// Everything that varies between supported LTX Video checkpoints: model
/// architecture configs, scheduler config and sampling pass settings.
///
/// Instances are produced by `LtxModelPreset::for_model`.
#[derive(Clone, Debug)]
struct LtxModelPreset {
    /// Transformer architecture (2B or 13B layout).
    transformer_config: LtxVideoTransformer3DModelConfig,
    /// VAE architecture.
    vae_config: AutoencoderKLLtxVideoConfig,
    /// Flow-match Euler scheduler settings.
    scheduler_config: FlowMatchEulerDiscreteSchedulerConfig,
    /// Pass settings for single-scale generation.
    base_pass: LtxPassConfig,
    /// `0.05` in every preset. NOTE(review): presumably the timestep fed to
    /// the timestep-conditioned VAE decoder — confirm at the decode call.
    decode_timestep: f32,
    /// `0.025` in every preset. NOTE(review): presumably the noise mixed into
    /// latents before decoding — confirm at the decode call.
    decode_noise_scale: f32,
    /// Generation flow for this checkpoint.
    mode: LtxPipelineMode,
    /// First/second pass settings; `Some` exactly when `mode` is `Multiscale`
    /// for the presets built in this file.
    multiscale: Option<LtxMultiscaleConfig>,
}
150
151impl LtxModelPreset {
152    fn for_model(model_name: &str) -> Result<Self> {
153        if model_name.contains("ltx-video-0.9.6-distilled") {
154            Ok(Self {
155                transformer_config: transformer_2b_config(),
156                vae_config: improved_vae_config(),
157                scheduler_config: scheduler_config(true),
158                base_pass: LtxPassConfig {
159                    num_inference_steps: 8,
160                    custom_sigmas: None,
161                    skip_initial_inference_steps: 0,
162                    skip_final_inference_steps: 0,
163                    guidance: LtxGuidanceConfig::constant(1.0, 0.0, 1.0, &[]),
164                    tone_map_compression_ratio: 0.0,
165                },
166                decode_timestep: 0.05,
167                decode_noise_scale: 0.025,
168                mode: LtxPipelineMode::Base,
169                multiscale: None,
170            })
171        } else if model_name.contains("ltx-video-0.9.6") {
172            Ok(Self {
173                transformer_config: transformer_2b_config(),
174                vae_config: improved_vae_config(),
175                scheduler_config: scheduler_config(false),
176                base_pass: LtxPassConfig {
177                    num_inference_steps: 40,
178                    custom_sigmas: None,
179                    skip_initial_inference_steps: 0,
180                    skip_final_inference_steps: 0,
181                    guidance: LtxGuidanceConfig::constant(3.0, 1.0, 0.7, LTX_096_DEV_SKIP_BLOCKS),
182                    tone_map_compression_ratio: 0.0,
183                },
184                decode_timestep: 0.05,
185                decode_noise_scale: 0.025,
186                mode: LtxPipelineMode::Base,
187                multiscale: None,
188            })
189        } else if model_name.contains("ltx-video-0.9.8-2b-distilled") {
190            Ok(Self {
191                transformer_config: transformer_2b_config(),
192                vae_config: improved_vae_config(),
193                scheduler_config: scheduler_config(false),
194                base_pass: LtxPassConfig {
195                    num_inference_steps: 7,
196                    custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
197                    skip_initial_inference_steps: 0,
198                    skip_final_inference_steps: 0,
199                    guidance: LtxGuidanceConfig::constant(
200                        1.0,
201                        0.0,
202                        1.0,
203                        LTX_098_2B_DISTILLED_SKIP_BLOCKS,
204                    ),
205                    tone_map_compression_ratio: 0.0,
206                },
207                decode_timestep: 0.05,
208                decode_noise_scale: 0.025,
209                mode: LtxPipelineMode::Multiscale,
210                multiscale: Some(LtxMultiscaleConfig {
211                    downscale_factor: 0.6666666,
212                    first_pass: LtxPassConfig {
213                        num_inference_steps: 7,
214                        custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
215                        skip_initial_inference_steps: 0,
216                        skip_final_inference_steps: 0,
217                        guidance: LtxGuidanceConfig::constant(
218                            1.0,
219                            0.0,
220                            1.0,
221                            LTX_098_2B_DISTILLED_SKIP_BLOCKS,
222                        ),
223                        tone_map_compression_ratio: 0.0,
224                    },
225                    second_pass: LtxPassConfig {
226                        num_inference_steps: 3,
227                        custom_sigmas: Some(LTX_098_DISTILLED_SECOND_PASS_SIGMAS.to_vec()),
228                        skip_initial_inference_steps: 0,
229                        skip_final_inference_steps: 0,
230                        guidance: LtxGuidanceConfig::constant(
231                            1.0,
232                            0.0,
233                            1.0,
234                            LTX_098_2B_DISTILLED_SKIP_BLOCKS,
235                        ),
236                        tone_map_compression_ratio: 0.0,
237                    },
238                }),
239            })
240        } else if model_name.contains("ltx-video-0.9.8-13b-distilled") {
241            Ok(Self {
242                transformer_config: transformer_13b_config(),
243                vae_config: improved_vae_config(),
244                scheduler_config: scheduler_config(false),
245                base_pass: LtxPassConfig {
246                    num_inference_steps: 7,
247                    custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
248                    skip_initial_inference_steps: 0,
249                    skip_final_inference_steps: 0,
250                    guidance: LtxGuidanceConfig::constant(
251                        1.0,
252                        0.0,
253                        1.0,
254                        LTX_098_13B_DISTILLED_SKIP_BLOCKS,
255                    ),
256                    tone_map_compression_ratio: 0.0,
257                },
258                decode_timestep: 0.05,
259                decode_noise_scale: 0.025,
260                mode: LtxPipelineMode::Multiscale,
261                multiscale: Some(LtxMultiscaleConfig {
262                    downscale_factor: 0.6666666,
263                    first_pass: LtxPassConfig {
264                        num_inference_steps: 7,
265                        custom_sigmas: Some(LTX_098_DISTILLED_FIRST_PASS_SIGMAS.to_vec()),
266                        skip_initial_inference_steps: 0,
267                        skip_final_inference_steps: 0,
268                        guidance: LtxGuidanceConfig::constant(
269                            1.0,
270                            0.0,
271                            1.0,
272                            LTX_098_13B_DISTILLED_SKIP_BLOCKS,
273                        ),
274                        tone_map_compression_ratio: 0.0,
275                    },
276                    second_pass: LtxPassConfig {
277                        num_inference_steps: 3,
278                        custom_sigmas: Some(LTX_098_DISTILLED_SECOND_PASS_SIGMAS.to_vec()),
279                        skip_initial_inference_steps: 0,
280                        skip_final_inference_steps: 0,
281                        guidance: LtxGuidanceConfig::constant(
282                            1.0,
283                            0.0,
284                            1.0,
285                            LTX_098_13B_DISTILLED_SKIP_BLOCKS,
286                        ),
287                        tone_map_compression_ratio: 0.6,
288                    },
289                }),
290            })
291        } else if model_name.contains("ltx-video-0.9.8-13b-dev") {
292            Ok(Self {
293                transformer_config: transformer_13b_config(),
294                vae_config: improved_vae_config(),
295                scheduler_config: scheduler_config(false),
296                base_pass: LtxPassConfig {
297                    num_inference_steps: 30,
298                    custom_sigmas: None,
299                    skip_initial_inference_steps: 0,
300                    skip_final_inference_steps: 0,
301                    guidance: LtxGuidanceConfig::constant(8.0, 4.0, 0.5, &[28]),
302                    tone_map_compression_ratio: 0.0,
303                },
304                decode_timestep: 0.05,
305                decode_noise_scale: 0.025,
306                mode: LtxPipelineMode::Multiscale,
307                multiscale: Some(LtxMultiscaleConfig {
308                    downscale_factor: 0.6666666,
309                    first_pass: LtxPassConfig {
310                        num_inference_steps: 30,
311                        custom_sigmas: None,
312                        skip_initial_inference_steps: 0,
313                        skip_final_inference_steps: 3,
314                        guidance: LtxGuidanceConfig {
315                            guidance_scale: LTX_098_DEV_FIRST_PASS_GUIDANCE_SCALE.to_vec(),
316                            stg_scale: LTX_098_DEV_FIRST_PASS_STG_SCALE.to_vec(),
317                            rescaling_scale: LTX_098_DEV_FIRST_PASS_RESCALING_SCALE.to_vec(),
318                            guidance_timesteps: Some(
319                                LTX_098_DEV_FIRST_PASS_GUIDANCE_TIMESTEPS.to_vec(),
320                            ),
321                            skip_block_list: LTX_098_13B_DEV_FIRST_PASS_SKIP_BLOCKS
322                                .iter()
323                                .map(|blocks| blocks.to_vec())
324                                .collect(),
325                            cfg_star_rescale: true,
326                        },
327                        tone_map_compression_ratio: 0.0,
328                    },
329                    second_pass: LtxPassConfig {
330                        num_inference_steps: 30,
331                        custom_sigmas: None,
332                        skip_initial_inference_steps: 17,
333                        skip_final_inference_steps: 0,
334                        guidance: LtxGuidanceConfig {
335                            guidance_scale: vec![1.0],
336                            stg_scale: vec![1.0],
337                            rescaling_scale: vec![1.0],
338                            guidance_timesteps: Some(vec![1.0]),
339                            skip_block_list: vec![LTX_098_13B_DEV_SECOND_PASS_SKIP_BLOCKS.to_vec()],
340                            cfg_star_rescale: true,
341                        },
342                        tone_map_compression_ratio: 0.0,
343                    },
344                }),
345            })
346        } else {
347            bail!("unsupported LTX model preset for {}", model_name);
348        }
349    }
350}
351
352fn transformer_2b_config() -> LtxVideoTransformer3DModelConfig {
353    LtxVideoTransformer3DModelConfig {
354        num_layers: 28,
355        num_attention_heads: 32,
356        attention_head_dim: 64,
357        cross_attention_dim: 2048,
358        caption_channels: 4096,
359        ..Default::default()
360    }
361}
362
363fn transformer_13b_config() -> LtxVideoTransformer3DModelConfig {
364    LtxVideoTransformer3DModelConfig {
365        num_layers: 48,
366        num_attention_heads: 32,
367        attention_head_dim: 128,
368        cross_attention_dim: 4096,
369        caption_channels: 4096,
370        ..Default::default()
371    }
372}
373
374fn is_legacy_ltx_video_13b(model_name: &str, preset: &LtxModelPreset) -> bool {
375    model_name.contains("13b")
376        || (preset.transformer_config.num_layers >= 48
377            && preset.transformer_config.attention_head_dim >= 128)
378}
379
380fn ltx_video_transformer_residency_guard(
381    model_name: &str,
382    preset: &LtxModelPreset,
383    transformer_bytes: u64,
384    usable_vram_bytes: Option<u64>,
385    is_cuda: bool,
386) -> Result<()> {
387    if !is_cuda || !is_legacy_ltx_video_13b(model_name, preset) {
388        return Ok(());
389    }
390    let Some(usable_vram_bytes) = usable_vram_bytes.filter(|bytes| *bytes > 0) else {
391        return Ok(());
392    };
393    let required = transformer_bytes.saturating_add(LTX_VIDEO_FULL_RESIDENT_RUNTIME_HEADROOM);
394    if required <= usable_vram_bytes {
395        return Ok(());
396    }
397
398    bail!(
399        "legacy LTX-Video 13B BF16 requires full transformer residency ({} weights + {} runtime headroom) but only {} usable VRAM is available. MOLD_OFFLOAD is not implemented for this legacy transformer yet; use ltx-video-0.9.8-2b-distilled, lower --width/--height/--frames, or use an LTX-2 FP8 model with adaptive offload.",
400        fmt_gb(transformer_bytes),
401        fmt_gb(LTX_VIDEO_FULL_RESIDENT_RUNTIME_HEADROOM),
402        fmt_gb(usable_vram_bytes),
403    )
404}
405
406fn improved_vae_config() -> AutoencoderKLLtxVideoConfig {
407    AutoencoderKLLtxVideoConfig {
408        block_out_channels: vec![128, 256, 512, 1024, 2048],
409        decoder_block_out_channels: vec![256, 512, 1024],
410        spatiotemporal_scaling: vec![true, true, true, true],
411        decoder_spatiotemporal_scaling: vec![true, true, true],
412        layers_per_block: vec![4, 6, 6, 2, 2],
413        decoder_layers_per_block: vec![5, 5, 5, 5],
414        decoder_inject_noise: vec![false, false, false, false],
415        decoder_upsample_residual: vec![true, true, true],
416        decoder_upsample_factor: vec![2, 2, 2],
417        timestep_conditioning: true,
418        ..Default::default()
419    }
420}
421
422fn scheduler_config(stochastic_sampling: bool) -> FlowMatchEulerDiscreteSchedulerConfig {
423    FlowMatchEulerDiscreteSchedulerConfig {
424        num_train_timesteps: 1000,
425        shift: 1.0,
426        use_dynamic_shifting: false,
427        base_shift: Some(0.5),
428        max_shift: Some(1.15),
429        base_image_seq_len: Some(256),
430        max_image_seq_len: Some(4096),
431        invert_sigmas: false,
432        shift_terminal: None,
433        use_karras_sigmas: false,
434        use_exponential_sigmas: false,
435        use_beta_sigmas: false,
436        time_shift_type: TimeShiftType::Exponential,
437        stochastic_sampling,
438    }
439}
440
/// Guidance values resolved for a single sigma by `resolve_step_schedule`.
#[derive(Clone, Debug)]
struct LtxResolvedStep {
    /// Entry selected from the pass's `guidance_scale` table.
    guidance_scale: f32,
    /// Entry selected from the pass's `stg_scale` table.
    stg_scale: f32,
    /// Entry selected from the pass's `rescaling_scale` table.
    rescaling_scale: f32,
    /// Selected skip list, already restricted to indices below the
    /// transformer's layer count.
    skip_blocks: Vec<usize>,
}
448
/// Returns the entries of `skip_blocks` that are valid layer indices, i.e.
/// strictly below `num_layers`, preserving order and duplicates.
fn clamp_skip_blocks(skip_blocks: &[usize], num_layers: usize) -> Vec<usize> {
    let mut in_range = skip_blocks.to_vec();
    in_range.retain(|&block| block < num_layers);
    in_range
}
456
/// Finds the guidance stage for `sigma`: the index of the first threshold in
/// `guidance_timesteps` that is less than or equal to `sigma`.
///
/// When no threshold qualifies (including a NaN `sigma`), the last index is
/// returned; for an empty table that is `0`.
fn resolve_guidance_index(guidance_timesteps: &[f32], sigma: f32) -> usize {
    for (index, &threshold) in guidance_timesteps.iter().enumerate() {
        if threshold <= sigma {
            return index;
        }
    }
    guidance_timesteps.len().saturating_sub(1)
}
463
464fn resolve_step_schedule(
465    pass: &LtxPassConfig,
466    sigmas: &[f32],
467    num_layers: usize,
468) -> Vec<LtxResolvedStep> {
469    sigmas
470        .iter()
471        .map(|sigma| {
472            let mapped = pass
473                .guidance
474                .guidance_timesteps
475                .as_ref()
476                .map(|timesteps| resolve_guidance_index(timesteps, *sigma))
477                .unwrap_or(0);
478            let value_at = |values: &[f32]| -> f32 {
479                if values.len() == 1 {
480                    values[0]
481                } else {
482                    values[mapped.min(values.len() - 1)]
483                }
484            };
485            let skip_blocks = if pass.guidance.skip_block_list.is_empty() {
486                Vec::new()
487            } else if pass.guidance.skip_block_list.len() == 1 {
488                clamp_skip_blocks(&pass.guidance.skip_block_list[0], num_layers)
489            } else {
490                clamp_skip_blocks(
491                    &pass.guidance.skip_block_list
492                        [mapped.min(pass.guidance.skip_block_list.len() - 1)],
493                    num_layers,
494                )
495            };
496            LtxResolvedStep {
497                guidance_scale: value_at(&pass.guidance.guidance_scale),
498                stg_scale: value_at(&pass.guidance.stg_scale),
499                rescaling_scale: value_at(&pass.guidance.rescaling_scale),
500                skip_blocks,
501            }
502        })
503        .collect()
504}
505
506fn std_over_dims_except0_keepdim(x: &Tensor) -> Result<Tensor> {
507    let rank = x.rank();
508    if rank < 2 {
509        bail!("std_over_dims_except0_keepdim expects rank >= 2, got {rank}");
510    }
511    let b = x.dim(0)?;
512    let flat = x.flatten_from(1)?;
513    let var = flat.var_keepdim(1)?;
514    let std = var.sqrt()?;
515    let mut shape = Vec::with_capacity(rank);
516    shape.push(b);
517    shape.extend(std::iter::repeat_n(1usize, rank - 1));
518    Ok(std.reshape(shape)?)
519}
520
521fn rescale_noise_cfg(
522    noise_cfg: &Tensor,
523    noise_pred_text: &Tensor,
524    guidance_rescale: f32,
525) -> Result<Tensor> {
526    let std_text = std_over_dims_except0_keepdim(noise_pred_text)?;
527    let std_cfg = std_over_dims_except0_keepdim(noise_cfg)?;
528    let ratio = std_text.broadcast_div(&std_cfg)?;
529    let noise_pred_rescaled = noise_cfg.broadcast_mul(&ratio)?;
530    let a = noise_pred_rescaled.affine(guidance_rescale as f64, 0.0)?;
531    let b = noise_cfg.affine((1.0 - guidance_rescale) as f64, 0.0)?;
532    Ok(a.broadcast_add(&b)?)
533}
534
535fn cfg_star_rescale_uncond(noise_pred_uncond: &Tensor, noise_pred_text: &Tensor) -> Result<Tensor> {
536    let batch = noise_pred_text.dim(0)?;
537    let positive_flat = noise_pred_text.flatten_from(1)?;
538    let negative_flat = noise_pred_uncond.flatten_from(1)?;
539    let dot = positive_flat
540        .broadcast_mul(&negative_flat)?
541        .sum_keepdim(1)?;
542    let squared = negative_flat.sqr()?.sum_keepdim(1)?.affine(1.0, 1e-8)?;
543    let alpha = dot.broadcast_div(&squared)?;
544    let alpha = alpha.reshape((batch, 1, 1))?;
545    Ok(noise_pred_uncond.broadcast_mul(&alpha.broadcast_as(noise_pred_uncond.shape())?)?)
546}
547
548fn create_skip_layer_mask(
549    num_layers: usize,
550    batch_size: usize,
551    layers_to_skip: &[usize],
552    device: &Device,
553) -> Result<Option<Tensor>> {
554    let layers_to_skip = clamp_skip_blocks(layers_to_skip, num_layers);
555    if layers_to_skip.is_empty() {
556        return Ok(None);
557    }
558
559    let mut mask_data = vec![0.0f32; num_layers * batch_size];
560    for &layer_idx in &layers_to_skip {
561        for batch_idx in 0..batch_size {
562            mask_data[layer_idx * batch_size + batch_idx] = 1.0;
563        }
564    }
565    Ok(Some(Tensor::from_vec(
566        mask_data,
567        (num_layers, batch_size),
568        device,
569    )?))
570}
571
572fn tone_map_latents(latents: &Tensor, compression: f32) -> Result<Tensor> {
573    if compression == 0.0 {
574        return Ok(latents.clone());
575    }
576    if !(0.0..=1.0).contains(&compression) {
577        bail!("tone map compression must be in [0, 1], got {compression}");
578    }
579    let scale_factor = compression * 0.75;
580    let abs_latents = latents.abs()?;
581    let sigmoid_term = abs_latents
582        .affine(1.0, -1.0)?
583        .affine((4.0 * scale_factor) as f64, 0.0)?;
584    let sigmoid_term = candle_nn::ops::sigmoid(&sigmoid_term)?;
585    let scales = sigmoid_term.affine((-0.8 * scale_factor) as f64, 1.0)?;
586    Ok(latents.broadcast_mul(&scales)?)
587}
588
589fn normalize_latents_with_vae(latents: &Tensor, vae: &AutoencoderKLLtxVideo) -> Result<Tensor> {
590    let c = latents.dim(1)?;
591    let mean = vae
592        .latents_mean()
593        .reshape((1, c, 1, 1, 1))?
594        .to_device(latents.device())?
595        .to_dtype(latents.dtype())?;
596    let std = vae
597        .latents_std()
598        .reshape((1, c, 1, 1, 1))?
599        .to_device(latents.device())?
600        .to_dtype(latents.dtype())?;
601    Ok(latents.broadcast_sub(&mean)?.broadcast_div(&std)?)
602}
603
604fn denormalize_latents_with_vae(latents: &Tensor, vae: &AutoencoderKLLtxVideo) -> Result<Tensor> {
605    let c = latents.dim(1)?;
606    let mean = vae
607        .latents_mean()
608        .reshape((1, c, 1, 1, 1))?
609        .to_device(latents.device())?
610        .to_dtype(latents.dtype())?;
611    let std = vae
612        .latents_std()
613        .reshape((1, c, 1, 1, 1))?
614        .to_device(latents.device())?
615        .to_dtype(latents.dtype())?;
616    Ok(latents.broadcast_mul(&std)?.broadcast_add(&mean)?)
617}
618
619fn adain_filter_latents(latents: &Tensor, reference_latents: &Tensor) -> Result<Tensor> {
620    let latents_f32 = latents.to_dtype(DType::F32)?;
621    let reference_f32 = reference_latents.to_dtype(DType::F32)?;
622
623    let latents_flat = latents_f32.flatten_from(2)?;
624    let reference_flat = reference_f32.flatten_from(2)?;
625
626    let lat_mean = latents_flat.mean_keepdim(2)?;
627    let lat_std = latents_flat.var_keepdim(2)?.affine(1.0, 1e-6)?.sqrt()?;
628    let ref_mean = reference_flat.mean_keepdim(2)?;
629    let ref_std = reference_flat.var_keepdim(2)?.affine(1.0, 1e-6)?.sqrt()?;
630
631    let filtered = latents_flat
632        .broadcast_sub(&lat_mean)?
633        .broadcast_div(&lat_std)?
634        .broadcast_mul(&ref_std)?
635        .broadcast_add(&ref_mean)?
636        .reshape(latents.shape())?;
637
638    Ok(filtered.to_dtype(latents.dtype())?)
639}
640
641// ---------------------------------------------------------------------------
642// Loaded state
643// ---------------------------------------------------------------------------
644
/// Loaded model components held by the engine's `EngineBase`.
///
/// Both models are optional. NOTE(review): presumably so each can be dropped
/// and reloaded independently (the module docs mention drop-and-reload) —
/// confirm against the load/unload code.
#[allow(dead_code)]
struct LoadedLtxVideo {
    /// Denoising transformer, when currently loaded.
    transformer: Option<LtxVideoTransformer3DModel>,
    /// Video VAE, when currently loaded.
    vae: Option<AutoencoderKLLtxVideo>,
    /// Device associated with the loaded components.
    device: Device,
    /// Dtype associated with the loaded components.
    dtype: DType,
}
652
653// ---------------------------------------------------------------------------
654// Engine
655// ---------------------------------------------------------------------------
656
/// LTX Video text-to-video inference engine.
///
/// Built either from resolved `ModelPaths` (`new`) or from a single-file
/// safetensors checkpoint (`from_single_file`).
#[allow(dead_code)]
pub struct LtxVideoEngine {
    /// Common engine state created by `EngineBase::new` from the model name,
    /// paths, load strategy and GPU ordinal.
    base: EngineBase<LoadedLtxVideo>,
    /// Caller-supplied T5 variant selector; its interpretation is not visible
    /// in this part of the file.
    t5_variant: Option<String>,
    /// Optional pool shared with other engines behind a mutex.
    shared_pool: Option<Arc<Mutex<SharedPool>>>,
    /// Device placement to apply later; both constructors start with `None`.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// `Some(is_native_format)` when built via `from_single_file`.
    /// - `true` → native LTX key format; `remap_official_ltx_transformer_key`
    ///   is applied at load time regardless of the transformer file name.
    /// - `false` → diffusers format; no remap.
    ///
    /// `None` for HF diffusers-layout models (filename-based detection).
    single_file_native_format: Option<bool>,
    /// `true` when the single-file checkpoint contains `vae.*` keys.
    /// `load_vae` opens `paths.vae` under `vb.pp("vae")` when set.
    vae_in_checkpoint: bool,
}
674
675impl LtxVideoEngine {
676    pub fn new(
677        model_name: String,
678        paths: ModelPaths,
679        t5_variant: Option<String>,
680        load_strategy: LoadStrategy,
681        gpu_ordinal: usize,
682        shared_pool: Option<Arc<Mutex<SharedPool>>>,
683    ) -> Self {
684        Self {
685            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
686            t5_variant,
687            shared_pool,
688            pending_placement: None,
689            single_file_native_format: None,
690            vae_in_checkpoint: false,
691        }
692    }
693
694    /// Construct an LTX-Video engine from a Civitai single-file safetensors
695    /// checkpoint (phase 5).
696    ///
697    /// Header-parses `checkpoint` via `ltx_video::single_file::load` to
698    /// detect the key format (native vs. diffusers) and VAE presence.
699    ///
700    /// - **Combined checkpoints** (transformer + `vae.*` keys): both
701    ///   components are loaded from `checkpoint`; `vae_path` is ignored.
702    /// - **Transformer-only checkpoints**: the `ltx-video-vae` companion
703    ///   must have been resolved and passed as `vae_path`.
704    ///
705    /// `t5_tokenizer_path` is the path to a companion-pulled T5 tokenizer
706    /// JSON (set via `populate_companion_paths` in the catalog bridge).
707    #[allow(clippy::too_many_arguments)]
708    pub fn from_single_file(
709        model_name: String,
710        checkpoint: PathBuf,
711        vae_path: Option<PathBuf>,
712        t5_encoder_path: Option<PathBuf>,
713        t5_tokenizer_path: Option<PathBuf>,
714        t5_variant: Option<String>,
715        load_strategy: LoadStrategy,
716        gpu_ordinal: usize,
717        shared_pool: Option<Arc<Mutex<SharedPool>>>,
718    ) -> anyhow::Result<Self> {
719        if !checkpoint.exists() {
720            anyhow::bail!(
721                "single-file LTX-Video checkpoint not found: {}",
722                checkpoint.display()
723            );
724        }
725
726        let bundle = super::single_file::load(&checkpoint).map_err(|e| {
727            anyhow::anyhow!(
728                "failed to parse single-file LTX-Video checkpoint {}: {e}",
729                checkpoint.display()
730            )
731        })?;
732
733        let is_native = bundle.format == super::single_file::LtxKeyFormat::Native;
734
735        // Resolve VAE: embedded in the combined checkpoint, or external companion.
736        let (resolved_vae, vae_in_checkpoint) = if bundle.has_vae {
737            (checkpoint.clone(), true)
738        } else {
739            let vae = vae_path.ok_or_else(|| {
740                anyhow::anyhow!(
741                    "LTX-Video checkpoint {} contains no VAE weights (`vae.*` keys). \
742                     Pull the `ltx-video-vae` companion first: `mold pull ltx-video-vae`",
743                    checkpoint.display()
744                )
745            })?;
746            if !vae.exists() {
747                anyhow::bail!(
748                    "ltx-video-vae companion not on disk: {}. \
749                     Run `mold pull ltx-video-vae` to download it.",
750                    vae.display()
751                );
752            }
753            (vae, false)
754        };
755
756        let paths = ModelPaths {
757            transformer: checkpoint.clone(),
758            transformer_shards: Vec::new(),
759            vae: resolved_vae,
760            spatial_upscaler: None,
761            temporal_upscaler: None,
762            distilled_lora: None,
763            t5_encoder: t5_encoder_path,
764            clip_encoder: None,
765            t5_tokenizer: t5_tokenizer_path,
766            clip_tokenizer: None,
767            clip_encoder_2: None,
768            clip_tokenizer_2: None,
769            text_encoder_files: Vec::new(),
770            text_tokenizer: None,
771            decoder: None,
772        };
773
774        Ok(Self {
775            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
776            t5_variant,
777            shared_pool,
778            pending_placement: None,
779            single_file_native_format: Some(is_native),
780            vae_in_checkpoint,
781        })
782    }
783}
784
785// ---------------------------------------------------------------------------
786// InferenceEngine trait
787// ---------------------------------------------------------------------------
788
789impl LtxVideoEngine {
790    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
791        let start = Instant::now();
792        let preset = LtxModelPreset::for_model(&self.base.model_name)?;
793
794        // Video parameters with defaults
795        let num_frames = req.frames.unwrap_or(25);
796        let fps = req.fps.unwrap_or(24);
797        let steps = req.steps;
798        let guidance = req.guidance;
799
800        // Validate frame count: must be 8n+1
801        if !(num_frames.wrapping_sub(1)).is_multiple_of(8) {
802            bail!(
803                "frame count must be 8n+1 (9, 17, 25, 33, ...), got {}",
804                num_frames
805            );
806        }
807
808        let seed = req.seed.unwrap_or_else(rand_seed);
809        let width = req.width;
810        let height = req.height;
811
812        // Validate dimensions are multiples of 32 (VAE spatial compression)
813        if !width.is_multiple_of(VAE_SPATIAL_COMPRESSION as u32)
814            || !height.is_multiple_of(VAE_SPATIAL_COMPRESSION as u32)
815        {
816            bail!(
817                "LTX Video requires width and height to be multiples of {}, got {}x{}",
818                VAE_SPATIAL_COMPRESSION,
819                width,
820                height
821            );
822        }
823
824        // Latent dimensions
825        let latent_h = height as usize / VAE_SPATIAL_COMPRESSION;
826        let latent_w = width as usize / VAE_SPATIAL_COMPRESSION;
827        let latent_f = (num_frames as usize - 1) / VAE_TEMPORAL_COMPRESSION + 1;
828        // Always use sequential mode for video (high VRAM usage)
829        self.generate_sequential(
830            req, &preset, seed, num_frames, fps, steps, guidance, width, height, latent_h,
831            latent_w, latent_f, start,
832        )
833    }
834}
835
836impl crate::engine::InferenceEngine for LtxVideoEngine {
837    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
838        self.pending_placement = req.placement.clone();
839        let result = self.generate_inner(req);
840        self.pending_placement = None;
841        result
842    }
843
844    fn model_name(&self) -> &str {
845        &self.base.model_name
846    }
847
848    fn is_loaded(&self) -> bool {
849        self.base.is_loaded()
850    }
851
852    fn load(&mut self) -> Result<()> {
853        // Video engine always uses sequential mode — components loaded per-generate
854        Ok(())
855    }
856
857    fn unload(&mut self) {
858        self.base.unload();
859    }
860
861    fn set_on_progress(&mut self, callback: ProgressCallback) {
862        self.base.set_on_progress(callback);
863    }
864
865    fn clear_on_progress(&mut self) {
866        self.base.clear_on_progress();
867    }
868
869    fn model_paths(&self) -> Option<&ModelPaths> {
870        Some(&self.base.paths)
871    }
872
873    fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> {
874        Some(self)
875    }
876}
877
878// ---------------------------------------------------------------------------
879// Chain rendering (fallback for non-Ltx2 video engines)
880// ---------------------------------------------------------------------------
881// LTX-Video has no img2vid path, so it can't anchor the denoise on a previous
882// stage's last frames the way LTX-2 does. The fallback renders each stage
883// independently and relies on the stitch layer to glue clips together (Cut
884// is the natural seam; Smooth is forced to motion_tail=0 server-side which
885// makes it equivalent to Cut). Trade-off: longer videos with no temporal
886// context handoff. Subjects can drift between clips. Future work: add img2vid
887// to LtxVideoEngine and feed the carry tail's last frame as `source_image`
888// with strength~0.7.
889
890impl crate::ltx2::ChainStageRenderer for LtxVideoEngine {
891    fn render_stage(
892        &mut self,
893        stage_req: &GenerateRequest,
894        _carry: Option<&crate::ltx2::ChainTail>,
895        motion_tail_pixel_frames: u32,
896        _stage_progress: Option<&mut dyn FnMut(crate::ltx2::StageProgressEvent)>,
897    ) -> Result<crate::ltx2::StageOutcome> {
898        let start = Instant::now();
899        let frames = self.render_chain_frames_internal(stage_req)?;
900        let generation_time_ms = start.elapsed().as_millis() as u64;
901        if frames.is_empty() {
902            bail!("LtxVideoEngine.render_stage: pipeline produced zero frames");
903        }
904
905        // Tail data is unused by the next ltx-video stage (we ignore carry),
906        // but the orchestrator still threads `StageOutcome.tail` through and
907        // some `ChainTail` consumers assert `frames > 0`. Hand back the
908        // trailing frames so the type invariant holds and any future
909        // img2vid-aware fallback can use them without an interface change.
910        let tail_count = (motion_tail_pixel_frames as usize).clamp(1, frames.len());
911        let tail_frames: Vec<image::RgbImage> = frames
912            .iter()
913            .skip(frames.len() - tail_count)
914            .cloned()
915            .collect();
916
917        Ok(crate::ltx2::StageOutcome {
918            frames,
919            tail: crate::ltx2::ChainTail {
920                frames: tail_frames.len() as u32,
921                tail_rgb_frames: tail_frames,
922            },
923            // ltx-video has no audio path. Chain orchestrator validation
924            // already rejects enable_audio=true for this family.
925            audio: None,
926            generation_time_ms,
927        })
928    }
929}
930
931impl LtxVideoEngine {
932    /// Render one chain stage by piping through the standard generation
933    /// pipeline with APNG output forced, then decoding bytes back to raw
934    /// `RgbImage` frames. APNG is lossless so the round-trip preserves every
935    /// pixel; the encode/decode cost is ~tens of ms vs multi-second
936    /// inference, so it's effectively free.
937    fn render_chain_frames_internal(
938        &mut self,
939        req: &GenerateRequest,
940    ) -> Result<Vec<image::RgbImage>> {
941        let mut apng_req = req.clone();
942        apng_req.output_format = Some(OutputFormat::Apng);
943        // gif_preview encoding doubles the encode cost; chain mode never
944        // surfaces per-stage previews so skip it.
945        apng_req.gif_preview = false;
946        let response = self.generate_inner(&apng_req)?;
947        let video = response
948            .video
949            .ok_or_else(|| anyhow::anyhow!("LtxVideoEngine.generate returned no video data"))?;
950        decode_apng_to_rgb_frames(&video.data)
951    }
952}
953
954fn decode_apng_to_rgb_frames(apng_bytes: &[u8]) -> Result<Vec<image::RgbImage>> {
955    use image::AnimationDecoder;
956    let cursor = std::io::Cursor::new(apng_bytes);
957    let decoder = image::codecs::png::PngDecoder::new(cursor)
958        .map_err(|e| anyhow::anyhow!("failed to open APNG bytes: {e}"))?;
959    let apng = decoder
960        .apng()
961        .map_err(|e| anyhow::anyhow!("decoded PNG is not animated: {e}"))?;
962    let mut out = Vec::new();
963    for frame in apng.into_frames() {
964        let frame = frame.map_err(|e| anyhow::anyhow!("APNG frame decode failed: {e}"))?;
965        let rgba = frame.into_buffer();
966        let (w, h) = rgba.dimensions();
967        let mut rgb_data = Vec::with_capacity((w as usize) * (h as usize) * 3);
968        for px in rgba.pixels() {
969            rgb_data.extend_from_slice(&px.0[..3]);
970        }
971        let rgb = image::RgbImage::from_raw(w, h, rgb_data)
972            .ok_or_else(|| anyhow::anyhow!("failed to construct RgbImage from APNG frame"))?;
973        out.push(rgb);
974    }
975    Ok(out)
976}
977
978// ---------------------------------------------------------------------------
979// Sequential generation (load-use-drop each component)
980// ---------------------------------------------------------------------------
981
982impl LtxVideoEngine {
983    fn load_transformer(
984        &self,
985        preset: &LtxModelPreset,
986        device: &Device,
987        dtype: DType,
988    ) -> Result<LtxVideoTransformer3DModel> {
989        let transformer_files: Vec<std::path::PathBuf> =
990            if !self.base.paths.transformer_shards.is_empty() {
991                self.base.paths.transformer_shards.clone()
992            } else {
993                vec![self.base.paths.transformer.clone()]
994            };
995
996        let is_gguf = transformer_files
997            .first()
998            .and_then(|p| p.extension())
999            .is_some_and(|e| e == "gguf");
1000        if is_gguf {
1001            bail!("GGUF quantized LTX Video transformer is not yet supported — use :bf16 variant");
1002        }
1003        let transformer_bytes = transformer_files.iter().try_fold(0u64, |acc, path| {
1004            let metadata = std::fs::metadata(path).with_context(|| {
1005                format!("failed to stat LTX Video transformer {}", path.display())
1006            })?;
1007            Ok::<_, anyhow::Error>(acc.saturating_add(metadata.len()))
1008        })?;
1009        ltx_video_transformer_residency_guard(
1010            &self.base.model_name,
1011            preset,
1012            transformer_bytes,
1013            usable_free_vram_bytes(self.base.gpu_ordinal),
1014            device.is_cuda(),
1015        )?;
1016
1017        // SAFETY: mmap'd safetensors files are immutable model weights.
1018        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&transformer_files, dtype, device)? };
1019        // For single-file catalog entries `single_file_native_format` carries
1020        // the key-based detection result so an arbitrary Civitai filename
1021        // (not starting with "ltx") still gets the correct remap.
1022        let use_remap = match self.single_file_native_format {
1023            Some(is_native) => is_native,
1024            None => {
1025                transformer_files.len() == 1
1026                    && is_official_ltx_transformer_checkpoint(&transformer_files[0])
1027            }
1028        };
1029        let vb = if use_remap {
1030            vb.rename_f(remap_official_ltx_transformer_key)
1031        } else {
1032            vb
1033        };
1034        Ok(LtxVideoTransformer3DModel::new(
1035            &preset.transformer_config,
1036            vb,
1037        )?)
1038    }
1039
1040    fn load_vae(
1041        &self,
1042        preset: &LtxModelPreset,
1043        device: &Device,
1044        dtype: DType,
1045    ) -> Result<AutoencoderKLLtxVideo> {
1046        let vb = self.load_vae_var_builder(dtype, device)?;
1047        Ok(AutoencoderKLLtxVideo::new(preset.vae_config.clone(), vb)?)
1048    }
1049
1050    fn load_vae_cpu_tensors(&self) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
1051        if self.vae_in_checkpoint {
1052            return Ok(None);
1053        }
1054        let Some(shared_pool) = &self.shared_pool else {
1055            return Ok(None);
1056        };
1057        shared_pool
1058            .lock()
1059            .unwrap()
1060            .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.vae))
1061    }
1062
1063    fn load_vae_var_builder<'a>(&self, dtype: DType, device: &Device) -> Result<VarBuilder<'a>> {
1064        if let Some(tensors) = self.load_vae_cpu_tensors()? {
1065            return Ok(crate::encoders::park::varbuilder_from_parked(
1066                tensors.as_ref(),
1067                dtype,
1068                device,
1069            ));
1070        }
1071
1072        // SAFETY: mmap'd safetensors file is immutable model data.
1073        let vb = unsafe {
1074            VarBuilder::from_mmaped_safetensors(
1075                std::slice::from_ref(&self.base.paths.vae),
1076                dtype,
1077                device,
1078            )?
1079        };
1080        // Combined single-file checkpoints store VAE weights under the
1081        // `vae.*` namespace; standalone VAE files use root-level keys.
1082        let vb = if self.vae_in_checkpoint {
1083            vb.pp("vae")
1084        } else {
1085            vb
1086        };
1087        Ok(vb)
1088    }
1089
    /// Runs one denoising pass of the transformer and returns the resulting
    /// latents, unpacked via `unpack_latents`.
    ///
    /// The pass builds a fresh scheduler from `scheduler_cfg`, takes its sigma
    /// schedule (from `pass.custom_sigmas` when present, otherwise from
    /// `pass.num_inference_steps`), trims it by the pass's
    /// `skip_initial_inference_steps` / `skip_final_inference_steps`, and then
    /// iterates over the remaining sigmas. Each iteration performs a
    /// conditional forward pass, optionally an unconditional pass (CFG) and a
    /// perturbed pass (STG), combines them, and advances the scheduler.
    ///
    /// Starting latents are `initial_latents` (packed with
    /// `pack_initial_latents_for_second_pass`) when supplied, otherwise seeded
    /// noise of shape `[1, LATENT_CHANNELS, latent_f, latent_h, latent_w]`.
    ///
    /// Progress is reported under `stage_name`; with `ltx_debug` set, RMS
    /// statistics are logged for the first three steps and the last one.
    ///
    /// # Errors
    ///
    /// Bails when the skipped steps cover the whole schedule; otherwise
    /// propagates scheduler, transformer and tensor errors.
    ///
    /// NOTE(review): `schedule_sigmas.len() - 1` and `step_schedule[step]`
    /// assume the scheduler returns at least one sigma and that
    /// `resolve_step_schedule` yields one entry per run sigma — confirm
    /// against those helpers.
1090    #[allow(clippy::too_many_arguments)]
1091    fn denoise_pass(
1092        &self,
1093        transformer: &mut LtxVideoTransformer3DModel,
1094        prompt_embeds: &Tensor,
1095        attention_mask: &Tensor,
1096        uncond_embeds: Option<&Tensor>,
1097        uncond_mask: Option<&Tensor>,
1098        pass: &LtxPassConfig,
1099        scheduler_cfg: &FlowMatchEulerDiscreteSchedulerConfig,
1100        seed: u64,
1101        width: u32,
1102        height: u32,
1103        num_frames: u32,
1104        fps: u32,
1105        device: &Device,
1106        dtype: DType,
1107        progress: &crate::progress::ProgressReporter,
1108        stage_name: &str,
1109        ltx_debug: bool,
1110        initial_latents: Option<Tensor>,
1111    ) -> Result<Tensor> {
        // Latent grid size, derived from the pixel dimensions and frame count
        // through the VAE compression ratios.
1112        let latent_h = height as usize / VAE_SPATIAL_COMPRESSION;
1113        let latent_w = width as usize / VAE_SPATIAL_COMPRESSION;
1114        let latent_f = (num_frames as usize - 1) / VAE_TEMPORAL_COMPRESSION + 1;
1115        let mut scheduler = FlowMatchEulerDiscreteScheduler::new(scheduler_cfg.clone())?;
1116
        // Custom sigmas, when present, are passed instead of a step count.
1117        scheduler.set_timesteps(
1118            if pass.custom_sigmas.is_some() {
1119                None
1120            } else {
1121                Some(pass.num_inference_steps as usize)
1122            },
1123            device,
1124            pass.custom_sigmas.as_deref(),
1125            None,
1126            None,
1127        )?;
1128
        // Pull the schedule to the CPU so it can be sliced and iterated as
        // plain floats.
1129        let schedule_sigmas = scheduler
1130            .sigmas()
1131            .to_device(&Device::Cpu)?
1132            .to_vec1::<f32>()?;
        // The code treats the schedule as holding one more sigma than there
        // are steps.
1133        let total_steps = schedule_sigmas.len() - 1;
1134        if pass.skip_initial_inference_steps + pass.skip_final_inference_steps >= total_steps {
1135            bail!(
1136                "invalid LTX pass schedule: skip_initial={} + skip_final={} >= total_steps={}",
1137                pass.skip_initial_inference_steps,
1138                pass.skip_initial_inference_steps,
1139                total_steps
1140            );
1141        }
        // Only the sigmas in [start_step, end_step) are actually run; the
        // scheduler is told where in its schedule this run begins.
1142        let start_step = pass.skip_initial_inference_steps;
1143        let end_step = total_steps - pass.skip_final_inference_steps;
1144        let run_sigmas = schedule_sigmas[start_step..end_step].to_vec();
1145        scheduler.set_begin_index(start_step);
1146
        // Per-step guidance settings, indexed by position in `run_sigmas`.
1147        let step_schedule =
1148            resolve_step_schedule(pass, &run_sigmas, transformer.config().num_layers);
1149        let video_coords = build_video_coords(1, latent_f, latent_h, latent_w, fps, device)?;
1150
        // Start either from the latents handed in by the caller or from
        // seeded F32 noise, in both cases in packed form.
1151        let mut latents = match initial_latents {
1152            Some(latents) => pack_initial_latents_for_second_pass(&latents)?,
1153            None => {
1154                let noise = seeded_randn(
1155                    seed,
1156                    &[1, LATENT_CHANNELS, latent_f, latent_h, latent_w],
1157                    device,
1158                    DType::F32,
1159                )?;
1160                pack_latents(&noise, PATCH_SIZE, PATCH_SIZE_T)?
1161            }
1162        };
1163
1164        progress.stage_start(stage_name);
1165        let denoise_start = Instant::now();
1166
1167        for (step, sigma) in run_sigmas.iter().copied().enumerate() {
1168            let step_start = Instant::now();
1169            let resolved = &step_schedule[step];
1170            let batch = latents.dim(0)?;
            // The current sigma is fed to the transformer as the timestep,
            // one value per batch element, in the compute dtype.
1171            let timestep_t = Tensor::full(sigma, (batch,), device)?.to_dtype(dtype)?;
1172            let latents_input = latents.to_dtype(dtype)?;
1173
            // CFG needs both a scale above 1 and unconditional embeddings;
            // STG needs a positive scale and at least one block to skip.
1174            let do_cfg = resolved.guidance_scale > 1.0 && uncond_embeds.is_some();
1175            let do_stg = resolved.stg_scale > 0.0 && !resolved.skip_blocks.is_empty();
1176
            // With STG active the skip list on the transformer is cleared and
            // the skipped blocks are instead passed to the perturbed forward
            // pass below as a mask. Without STG the resolved skip list is
            // installed on the transformer directly.
1177            if do_stg {
1178                transformer.set_skip_block_list(vec![]);
1179            } else {
1180                transformer.set_skip_block_list(resolved.skip_blocks.clone());
1181            }
1182
            // Conditional prediction (prompt embeddings, no skip-layer mask).
1183            let cond_pred = transformer.forward(
1184                &latents_input,
1185                prompt_embeds,
1186                &timestep_t,
1187                Some(attention_mask),
1188                latent_f,
1189                latent_h,
1190                latent_w,
1191                None,
1192                Some(&video_coords),
1193                None,
1194            )?;
            // Guidance arithmetic is carried out in F32.
1195            let cond_f32 = cond_pred.to_dtype(DType::F32)?;
1196            let mut combined = cond_f32.clone();
1197
1198            if do_cfg {
1199                let uncond_pred = transformer.forward(
1200                    &latents_input,
1201                    uncond_embeds.expect("checked above"),
1202                    &timestep_t,
1203                    uncond_mask,
1204                    latent_f,
1205                    latent_h,
1206                    latent_w,
1207                    None,
1208                    Some(&video_coords),
1209                    None,
1210                )?;
1211                let mut uncond_f32 = uncond_pred.to_dtype(DType::F32)?;
1212                if pass.guidance.cfg_star_rescale {
1213                    uncond_f32 = cfg_star_rescale_uncond(&uncond_f32, &cond_f32)?;
1214                }
                // combined = uncond + guidance_scale * (cond - uncond)
1215                let diff = cond_f32.broadcast_sub(&uncond_f32)?;
1216                combined =
1217                    uncond_f32.broadcast_add(&diff.affine(resolved.guidance_scale as f64, 0.0)?)?;
1218            }
1219
1220            if do_stg {
1221                let skip_layer_mask = create_skip_layer_mask(
1222                    transformer.config().num_layers,
1223                    batch,
1224                    &resolved.skip_blocks,
1225                    device,
1226                )?;
                // Perturbed prediction: same conditioning as the conditional
                // pass, but with the skip-layer mask applied.
1227                let perturbed = transformer.forward(
1228                    &latents_input,
1229                    prompt_embeds,
1230                    &timestep_t,
1231                    Some(attention_mask),
1232                    latent_f,
1233                    latent_h,
1234                    latent_w,
1235                    None,
1236                    Some(&video_coords),
1237                    skip_layer_mask.as_ref(),
1238                )?;
1239                let perturbed_f32 = perturbed.to_dtype(DType::F32)?;
                // combined += stg_scale * (cond - perturbed)
1240                let diff_stg = cond_f32.broadcast_sub(&perturbed_f32)?;
1241                combined =
1242                    combined.broadcast_add(&diff_stg.affine(resolved.stg_scale as f64, 0.0)?)?;
                // Rescaling is only applied on the STG path, and only when
                // both the pass flag and a positive rescaling scale are set.
1243                if pass.guidance.cfg_star_rescale && resolved.rescaling_scale > 0.0 {
1244                    combined = rescale_noise_cfg(&combined, &cond_f32, resolved.rescaling_scale)?;
1245                }
1246            }
1247
            // When the model emits twice as many channels as it takes in,
            // keep only the first half along dim 2 and discard the rest.
1248            let model_output =
1249                if transformer.config().out_channels / 2 == transformer.config().in_channels {
1250                    combined
1251                        .chunk(2, 2)?
1252                        .into_iter()
1253                        .next()
1254                        .expect("out_channels / 2 == in_channels implies chunk(2) succeeds")
1255                } else {
1256                    combined
1257                };
1258
            // Debug output: the RMS values are computed on every step, but
            // only logged for the first three steps and the final one.
1259            if ltx_debug {
1260                let out_rms = model_output.sqr()?.mean_all()?.to_scalar::<f32>()?.sqrt();
1261                let lat_rms = latents.sqr()?.mean_all()?.to_scalar::<f32>()?.sqrt();
1262                if step < 3 || step == run_sigmas.len() - 1 {
1263                    progress.info(&format!(
1264                        "Pass {stage_name} step {}: sigma={:.4}, guidance={:.2}, stg={:.2}, lat_rms={:.4}, out_rms={:.4}",
1265                        step,
1266                        sigma,
1267                        resolved.guidance_scale,
1268                        resolved.stg_scale,
1269                        lat_rms,
1270                        out_rms
1271                    ));
1272                }
1273            }
1274
            // Advance the latents by one scheduler step.
1275            latents = scheduler
1276                .step(&model_output, sigma, &latents, None)?
1277                .prev_sample;
1278
            // Progress is reported 1-based, relative to the steps actually run.
1279            progress.emit(ProgressEvent::DenoiseStep {
1280                step: step + 1,
1281                total: run_sigmas.len(),
1282                elapsed: step_start.elapsed(),
1283            });
1284        }
1285
1286        progress.stage_done(stage_name, denoise_start.elapsed());
1287        unpack_latents(
1288            &latents,
1289            latent_f,
1290            latent_h,
1291            latent_w,
1292            PATCH_SIZE,
1293            PATCH_SIZE_T,
1294        )
1295    }
1296
1297    #[allow(clippy::too_many_arguments)]
1298    fn generate_sequential(
1299        &mut self,
1300        req: &GenerateRequest,
1301        preset: &LtxModelPreset,
1302        seed: u64,
1303        num_frames: u32,
1304        fps: u32,
1305        steps: u32,
1306        guidance: f64,
1307        width: u32,
1308        height: u32,
1309        _latent_h: usize,
1310        _latent_w: usize,
1311        _latent_f: usize,
1312        start: Instant,
1313    ) -> Result<GenerateResponse> {
1314        let progress = &self.base.progress;
1315        let paths = &self.base.paths;
1316        let ltx_debug = std::env::var("MOLD_LTX_DEBUG").is_ok_and(|v| v == "1");
1317
1318        if preset.mode == LtxPipelineMode::Multiscale && paths.spatial_upscaler.is_none() {
1319            bail!("LTX 0.9.8 requires a spatial upscaler asset in the pulled model files");
1320        }
1321
1322        // Select device
1323        let device = crate::device::create_device(self.base.gpu_ordinal, progress)?;
1324        let dtype = gpu_dtype(&device);
1325
1326        progress.info(&format!(
1327            "LTX Video: {}×{} × {} frames, {} steps, seed {}",
1328            width, height, num_frames, steps, seed
1329        ));
1330        if preset.mode == LtxPipelineMode::Multiscale {
1331            progress.info("Using the full 0.9.8 multiscale refinement path.");
1332            if steps != preset.base_pass.num_inference_steps {
1333                progress.info(&format!(
1334                    "Ignoring --steps={} for multiscale LTX preset {}; using the preset schedule instead.",
1335                    steps, self.base.model_name
1336                ));
1337            }
1338            if guidance != preset.base_pass.guidance.guidance_scale[0] as f64 {
1339                progress.info(&format!(
1340                    "Ignoring --guidance={guidance:.2} for multiscale LTX preset {}; using the preset guidance schedule instead.",
1341                    self.base.model_name
1342                ));
1343            }
1344        }
1345
1346        // ---------------------------------------------------------------
1347        // Step 1: Encode prompt with T5-XXL
1348        // ---------------------------------------------------------------
1349        progress.stage_start("Loading T5-XXL encoder");
1350        let t5_start = Instant::now();
1351
1352        let t5_encoder_path = paths
1353            .t5_encoder
1354            .as_ref()
1355            .ok_or_else(|| anyhow::anyhow!("T5 encoder path not configured"))?;
1356        let t5_tokenizer_path = paths
1357            .t5_tokenizer
1358            .as_ref()
1359            .ok_or_else(|| anyhow::anyhow!("T5 tokenizer path not configured"))?;
1360
1361        let tier1 = self
1362            .pending_placement
1363            .as_ref()
1364            .map(|p| p.text_encoders)
1365            .unwrap_or_default();
1366        let t5_device = crate::device::resolve_device(Some(tier1), || Ok(device.clone()))?;
1367        let cached_t5_tokenizer = self
1368            .shared_pool
1369            .as_ref()
1370            .map(|pool| pool.lock().unwrap().load_tokenizer(t5_tokenizer_path))
1371            .transpose()?;
1372        let mut t5 = crate::encoders::t5::T5Encoder::load_with_tokenizer(
1373            t5_encoder_path,
1374            t5_tokenizer_path,
1375            &t5_device,
1376            dtype,
1377            progress,
1378            cached_t5_tokenizer,
1379        )?;
1380        progress.stage_done("Loading T5-XXL encoder", t5_start.elapsed());
1381
1382        progress.stage_start("Encoding prompt");
1383        let encode_start = Instant::now();
1384        let prompt_embeds = t5.encode(&req.prompt, &t5_device, dtype)?;
1385        let prompt_embeds = prompt_embeds.to_device(&device)?;
1386        // prompt_embeds: [1, seq_len, 4096] (T5 encoder already adds batch dim)
1387        progress.stage_done("Encoding prompt", encode_start.elapsed());
1388
1389        // Build attention mask (all ones — no padding for single prompt)
1390        let prompt_seq_len = prompt_embeds.dim(1)?;
1391        let attention_mask =
1392            Tensor::ones((1, prompt_seq_len), DType::F32, &device)?.to_dtype(dtype)?;
1393
1394        let needs_uncond = match preset.mode {
1395            LtxPipelineMode::Base => guidance > 1.0,
1396            LtxPipelineMode::Multiscale => preset
1397                .multiscale
1398                .as_ref()
1399                .into_iter()
1400                .flat_map(|cfg| [&cfg.first_pass, &cfg.second_pass])
1401                .any(|pass| {
1402                    pass.guidance
1403                        .guidance_scale
1404                        .iter()
1405                        .any(|scale| *scale > 1.0)
1406                }),
1407        };
1408
1409        let (uncond_embeds, uncond_mask) = if needs_uncond {
1410            progress.stage_start("Encoding negative prompt (CFG)");
1411            let ue = t5.encode("", &t5_device, dtype)?;
1412            let ue = ue.to_device(&device)?;
1413            let ue_seq = ue.dim(1)?;
1414            let um = Tensor::ones((1, ue_seq), DType::F32, &device)?.to_dtype(dtype)?;
1415            progress.stage_done("Encoding negative prompt (CFG)", encode_start.elapsed());
1416            (Some(ue), Some(um))
1417        } else {
1418            (None, None)
1419        };
1420
1421        // Drop T5 to free VRAM
1422        drop(t5);
1423        device.synchronize()?;
1424        progress.info("T5 encoder dropped, VRAM freed");
1425
1426        let (mut latents, decode_width, decode_height, tone_map_compression_ratio) = match preset
1427            .mode
1428        {
1429            LtxPipelineMode::Base => {
1430                progress.stage_start("Loading LTX Video transformer");
1431                let transformer_start = Instant::now();
1432                let mut transformer = self.load_transformer(preset, &device, dtype)?;
1433                progress.stage_done("Loading LTX Video transformer", transformer_start.elapsed());
1434
1435                let mut pass = preset.base_pass.clone();
1436                pass.num_inference_steps = steps;
1437                pass.guidance.guidance_scale = vec![guidance as f32];
1438
1439                let latents = self.denoise_pass(
1440                    &mut transformer,
1441                    &prompt_embeds,
1442                    &attention_mask,
1443                    uncond_embeds.as_ref(),
1444                    uncond_mask.as_ref().map(|m| m as &Tensor),
1445                    &pass,
1446                    &preset.scheduler_config,
1447                    seed,
1448                    width,
1449                    height,
1450                    num_frames,
1451                    fps,
1452                    &device,
1453                    dtype,
1454                    progress,
1455                    "Denoising",
1456                    ltx_debug,
1457                    None,
1458                )?;
1459                drop(transformer);
1460                device.synchronize()?;
1461                (latents, width, height, pass.tone_map_compression_ratio)
1462            }
1463            LtxPipelineMode::Multiscale => {
1464                let multiscale = preset.multiscale.as_ref().expect("multiscale preset");
1465                let first_width = ((width as f32 * multiscale.downscale_factor) as u32)
1466                    / VAE_SPATIAL_COMPRESSION as u32
1467                    * VAE_SPATIAL_COMPRESSION as u32;
1468                let first_height = ((height as f32 * multiscale.downscale_factor) as u32)
1469                    / VAE_SPATIAL_COMPRESSION as u32
1470                    * VAE_SPATIAL_COMPRESSION as u32;
1471
1472                progress.stage_start("Loading LTX Video transformer");
1473                let transformer_start = Instant::now();
1474                let mut first_transformer = self.load_transformer(preset, &device, dtype)?;
1475                progress.stage_done("Loading LTX Video transformer", transformer_start.elapsed());
1476
1477                let first_pass_latents = self.denoise_pass(
1478                    &mut first_transformer,
1479                    &prompt_embeds,
1480                    &attention_mask,
1481                    uncond_embeds.as_ref(),
1482                    uncond_mask.as_ref().map(|m| m as &Tensor),
1483                    &multiscale.first_pass,
1484                    &preset.scheduler_config,
1485                    seed,
1486                    first_width,
1487                    first_height,
1488                    num_frames,
1489                    fps,
1490                    &device,
1491                    dtype,
1492                    progress,
1493                    "Denoising First Pass",
1494                    ltx_debug,
1495                    None,
1496                )?;
1497                drop(first_transformer);
1498                device.synchronize()?;
1499
1500                progress.stage_start("Loading spatial upscaler");
1501                let spatial_start = Instant::now();
1502                let vae = self.load_vae(preset, &device, dtype)?;
1503                let upsampler = LatentUpsampler::load(
1504                    paths.spatial_upscaler.as_ref().expect("checked above"),
1505                    dtype,
1506                    &device,
1507                )?;
1508                progress.stage_done("Loading spatial upscaler", spatial_start.elapsed());
1509
1510                progress.stage_start("Refining multiscale pass");
1511                let refine_start = Instant::now();
1512                let first_pass_denorm = cast_latents_for_multiscale_upsampler(
1513                    &denormalize_latents_with_vae(&first_pass_latents, &vae)?,
1514                    dtype,
1515                )?;
1516                let upsampled_latents =
1517                    normalize_latents_with_vae(&upsampler.forward(&first_pass_denorm)?, &vae)?;
1518                let upsampled_latents =
1519                    adain_filter_latents(&upsampled_latents, &first_pass_latents)?;
1520                progress.stage_done("Refining multiscale pass", refine_start.elapsed());
1521                drop(upsampler);
1522                drop(vae);
1523                device.synchronize()?;
1524
1525                let second_width = first_width * 2;
1526                let second_height = first_height * 2;
1527                progress.stage_start("Loading LTX Video transformer");
1528                let transformer_start = Instant::now();
1529                let mut second_transformer = self.load_transformer(preset, &device, dtype)?;
1530                progress.stage_done("Loading LTX Video transformer", transformer_start.elapsed());
1531
1532                let latents = self.denoise_pass(
1533                    &mut second_transformer,
1534                    &prompt_embeds,
1535                    &attention_mask,
1536                    uncond_embeds.as_ref(),
1537                    uncond_mask.as_ref().map(|m| m as &Tensor),
1538                    &multiscale.second_pass,
1539                    &preset.scheduler_config,
1540                    seed,
1541                    second_width,
1542                    second_height,
1543                    num_frames,
1544                    fps,
1545                    &device,
1546                    dtype,
1547                    progress,
1548                    "Denoising Second Pass",
1549                    ltx_debug,
1550                    Some(upsampled_latents),
1551                )?;
1552                drop(second_transformer);
1553                device.synchronize()?;
1554                (
1555                    latents,
1556                    second_width,
1557                    second_height,
1558                    multiscale.second_pass.tone_map_compression_ratio,
1559                )
1560            }
1561        };
1562
1563        // ---------------------------------------------------------------
1564        // Step 5: Load VAE and decode
1565        // ---------------------------------------------------------------
1566        progress.stage_start("Loading VAE decoder");
1567        let vae_start = Instant::now();
1568        let vae = self.load_vae(preset, &device, dtype)?;
1569        progress.stage_done("Loading VAE decoder", vae_start.elapsed());
1570
1571        progress.stage_start("Decoding video frames");
1572        let decode_start = Instant::now();
1573
1574        let decode_timestep = if vae.config().timestep_conditioning {
1575            if preset.decode_noise_scale > 0.0 {
1576                let noise =
1577                    seeded_randn(seed ^ 0xdec0de, latents.shape().dims(), &device, DType::F32)?;
1578                latents = (&latents * (1.0 - preset.decode_noise_scale as f64))?
1579                    .broadcast_add(&(noise * preset.decode_noise_scale as f64)?)?;
1580            }
1581            Some(Tensor::full(preset.decode_timestep, (1,), &device)?.to_dtype(dtype)?)
1582        } else {
1583            None
1584        };
1585
1586        latents = tone_map_latents(&latents, tone_map_compression_ratio)?;
1587
1588        // Un-normalize latents immediately before VAE decode.
1589        latents = denormalize_latents_with_vae(&latents, &vae)?;
1590
1591        if ltx_debug {
1592            let l_f32 = latents.to_dtype(DType::F32)?;
1593            progress.info(&format!(
1594                "Latents pre-VAE (un-normalized): mean={:.4}, std={:.4}",
1595                l_f32.mean_all()?.to_scalar::<f32>()?,
1596                l_f32.flatten_all()?.var(0)?.to_scalar::<f32>()?.sqrt()
1597            ));
1598        }
1599
1600        latents = latents.to_dtype(dtype)?;
1601        let (_dec_output, video) = vae.decode(&latents, decode_timestep.as_ref(), false, false)?;
1602        // video: [B, 3, F, H, W] in model dtype
1603        if ltx_debug {
1604            let v_f32 = video.to_dtype(DType::F32)?;
1605            progress.info(&format!(
1606                "VAE output: shape={:?}, mean={:.4}, min={:.4}, max={:.4}",
1607                v_f32.shape(),
1608                v_f32.mean_all()?.to_scalar::<f32>()?,
1609                v_f32.flatten_all()?.min(0)?.to_scalar::<f32>()?,
1610                v_f32.flatten_all()?.max(0)?.to_scalar::<f32>()?
1611            ));
1612        }
1613
1614        progress.stage_done("Decoding video frames", decode_start.elapsed());
1615
1616        // Drop VAE
1617        drop(vae);
1618        device.synchronize()?;
1619
1620        // ---------------------------------------------------------------
1621        // Step 6: Post-process and encode video
1622        // ---------------------------------------------------------------
1623        // Default to APNG for video output (lossless, metadata-rich)
1624        let output_format = if req.resolved_output_format().is_video() {
1625            req.resolved_output_format()
1626        } else {
1627            OutputFormat::Apng
1628        };
1629        let format_name = output_format.extension().to_uppercase();
1630        progress.stage_start(&format!("Encoding {format_name}"));
1631        let encode_start = Instant::now();
1632
1633        // Convert to [0, 255] u8
1634        let video = video.to_dtype(DType::F32)?;
1635        let video = ((video.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
1636        let video = video.i(0)?; // Remove batch dim: [3, F, H, W]
1637
1638        // Extract individual frames as RgbImage
1639        let num_output_frames = video.dim(1)?;
1640        let mut frames = Vec::with_capacity(num_output_frames);
1641        for f in 0..num_output_frames {
1642            let frame = video.i((.., f, .., ..))?.contiguous()?; // [3, H, W]
1643            let frame = frame.permute((1, 2, 0))?; // [H, W, 3]
1644            let frame_data: Vec<u8> = frame.flatten_all()?.to_vec1()?;
1645            let mut rgb = image::RgbImage::from_raw(decode_width, decode_height, frame_data)
1646                .ok_or_else(|| anyhow::anyhow!("failed to create frame image"))?;
1647            if decode_width != width || decode_height != height {
1648                rgb = image::imageops::resize(
1649                    &rgb,
1650                    width,
1651                    height,
1652                    image::imageops::FilterType::Triangle,
1653                );
1654            }
1655            frames.push(rgb);
1656        }
1657
1658        let video_bytes = match output_format {
1659            OutputFormat::Apng => {
1660                let metadata = video_enc::VideoMetadata {
1661                    prompt: req.prompt.clone(),
1662                    model: self.base.model_name.clone(),
1663                    seed,
1664                    steps,
1665                    guidance: req.guidance,
1666                    width,
1667                    height,
1668                    frames: num_output_frames as u32,
1669                    fps,
1670                };
1671                video_enc::encode_apng(&frames, fps, Some(&metadata))?
1672            }
1673            OutputFormat::Gif => video_enc::encode_gif(&frames, fps)?,
1674            #[cfg(feature = "webp")]
1675            OutputFormat::Webp => video_enc::encode_webp(&frames, fps)?,
1676            #[cfg(feature = "mp4")]
1677            OutputFormat::Mp4 => video_enc::encode_mp4(&frames, fps)?,
1678            #[cfg(not(feature = "webp"))]
1679            OutputFormat::Webp => {
1680                bail!("WebP output requires the 'webp' feature — rebuild with --features webp")
1681            }
1682            #[cfg(not(feature = "mp4"))]
1683            OutputFormat::Mp4 => {
1684                bail!("MP4 output requires the 'mp4' feature — rebuild with --features mp4")
1685            }
1686            _ => bail!("{format_name} is not a supported video output format"),
1687        };
1688        let thumbnail_bytes = video_enc::first_frame_png(&frames)?;
1689        // Generate a GIF preview only when the caller will use it (TUI gallery or --preview).
1690        // If the primary format is already GIF, reuse the data; otherwise encode on demand.
1691        let gif_preview = if req.gif_preview {
1692            if output_format == OutputFormat::Gif {
1693                video_bytes.clone()
1694            } else {
1695                video_enc::encode_gif(&frames, fps)?
1696            }
1697        } else {
1698            Vec::new()
1699        };
1700
1701        progress.stage_done(&format!("Encoding {format_name}"), encode_start.elapsed());
1702
1703        let generation_time_ms = start.elapsed().as_millis() as u64;
1704        progress.info(&format!(
1705            "Done: {} frames, {:.1}s total",
1706            num_output_frames,
1707            generation_time_ms as f64 / 1000.0
1708        ));
1709
1710        Ok(GenerateResponse {
1711            images: vec![],
1712            video: Some(VideoData {
1713                data: video_bytes,
1714                format: output_format,
1715                width,
1716                height,
1717                frames: num_output_frames as u32,
1718                fps,
1719                thumbnail: thumbnail_bytes,
1720                gif_preview,
1721                has_audio: false,
1722                duration_ms: None,
1723                audio_sample_rate: None,
1724                audio_channels: None,
1725            }),
1726            generation_time_ms,
1727            model: self.base.model_name.clone(),
1728            seed_used: seed,
1729            gpu: None,
1730        })
1731    }
1732}
1733
1734// ---------------------------------------------------------------------------
1735// Latent packing/unpacking (matches LTX Video pipeline)
1736// ---------------------------------------------------------------------------
1737
1738/// Pack latents from [B,C,F,H,W] → [B,S,D] where S = F*H*W, D = C*pt*p*p.
1739fn pack_latents(latents: &Tensor, patch_size: usize, patch_size_t: usize) -> Result<Tensor> {
1740    let (b, c, f, h, w) = latents.dims5()?;
1741    if f % patch_size_t != 0 || h % patch_size != 0 || w % patch_size != 0 {
1742        bail!("latent dims not divisible by patch sizes");
1743    }
1744    let f2 = f / patch_size_t;
1745    let h2 = h / patch_size;
1746    let w2 = w / patch_size;
1747
1748    // [B, C, F2, pt, H2, p, W2, p]
1749    let x = latents.reshape(&[b, c, f2, patch_size_t, h2, patch_size, w2, patch_size])?;
1750    // permute → [B, F2, H2, W2, C, pt, p, p]
1751    let x = x.permute([0, 2, 4, 6, 1, 3, 5, 7])?;
1752    // flatten last 4 → [B, F2, H2, W2, D]
1753    let x = x.flatten_from(4)?;
1754    let d = x.dim(4)?;
1755    let s = f2 * h2 * w2;
1756    Ok(x.reshape((b, s, d))?)
1757}
1758
1759/// Unpack latents from [B,S,D] → [B,C,F,H,W].
1760fn unpack_latents(
1761    latents: &Tensor,
1762    num_frames: usize,
1763    height: usize,
1764    width: usize,
1765    patch_size: usize,
1766    patch_size_t: usize,
1767) -> Result<Tensor> {
1768    let (b, _s, d) = latents.dims3()?;
1769    let denom = patch_size_t * patch_size * patch_size;
1770    if d % denom != 0 {
1771        bail!("D={d} not divisible by patch product {denom}");
1772    }
1773    let c = d / denom;
1774
1775    let x = latents.reshape(&[
1776        b,
1777        num_frames,
1778        height,
1779        width,
1780        c,
1781        patch_size_t,
1782        patch_size,
1783        patch_size,
1784    ])?;
1785    // [B, C, F2, pt, H2, p, W2, p]
1786    let x = x.permute([0, 4, 1, 5, 2, 6, 3, 7])?.contiguous()?;
1787    Ok(x.reshape((
1788        b,
1789        c,
1790        num_frames * patch_size_t,
1791        height * patch_size,
1792        width * patch_size,
1793    ))?)
1794}
1795
1796fn pack_initial_latents_for_second_pass(latents: &Tensor) -> Result<Tensor> {
1797    pack_latents(&latents.to_dtype(DType::F32)?, PATCH_SIZE, PATCH_SIZE_T)
1798}
1799
1800fn cast_latents_for_multiscale_upsampler(latents: &Tensor, dtype: DType) -> Result<Tensor> {
1801    Ok(latents.to_dtype(dtype)?)
1802}
1803
1804/// Build video coordinates for 3D RoPE: [B, seq, 3] with (frame, height, width).
1805fn build_video_coords(
1806    batch_size: usize,
1807    latent_f: usize,
1808    latent_h: usize,
1809    latent_w: usize,
1810    fps: u32,
1811    device: &Device,
1812) -> Result<Tensor> {
1813    let grid_f = Tensor::arange(0u32, latent_f as u32, device)?.to_dtype(DType::F32)?;
1814    let grid_h = Tensor::arange(0u32, latent_h as u32, device)?.to_dtype(DType::F32)?;
1815    let grid_w = Tensor::arange(0u32, latent_w as u32, device)?.to_dtype(DType::F32)?;
1816
1817    let f = grid_f
1818        .reshape((latent_f, 1, 1))?
1819        .broadcast_as((latent_f, latent_h, latent_w))?;
1820    let h = grid_h
1821        .reshape((1, latent_h, 1))?
1822        .broadcast_as((latent_f, latent_h, latent_w))?;
1823    let w = grid_w
1824        .reshape((1, 1, latent_w))?
1825        .broadcast_as((latent_f, latent_h, latent_w))?;
1826
1827    let grid = Tensor::stack(&[f, h, w], 0)?; // [3, F, H, W]
1828    let seq = latent_f * latent_h * latent_w;
1829    let grid = grid.flatten_from(1)?.transpose(0, 1)?.unsqueeze(0)?; // [1, seq, 3]
1830
1831    // Apply compression ratios to coordinates
1832    let vf = grid.i((.., .., 0))?;
1833    let vh = grid.i((.., .., 1))?;
1834    let vw = grid.i((.., .., 2))?;
1835
1836    // Temporal: (L * 8 + 1 - 8).clamp(0) / fps
1837    let ts_ratio = VAE_TEMPORAL_COMPRESSION as f64;
1838    let vf = vf
1839        .affine(ts_ratio, 1.0 - ts_ratio)?
1840        .clamp(0.0f32, 10000.0f32)?
1841        .affine(1.0 / fps as f64, 0.0)?;
1842    // Spatial: L * 32
1843    let sp_ratio = VAE_SPATIAL_COMPRESSION as f64;
1844    let vh = vh.affine(sp_ratio, 0.0)?;
1845    let vw = vw.affine(sp_ratio, 0.0)?;
1846
1847    let coords = Tensor::stack(&[vf, vh, vw], candle_core::D::Minus1)?;
1848    if batch_size > 1 {
1849        Ok(coords.broadcast_as((batch_size, seq, 3))?)
1850    } else {
1851        Ok(coords)
1852    }
1853}
1854
#[cfg(test)]
mod tests {
    use super::{
        cast_latents_for_multiscale_upsampler, is_official_ltx_transformer_checkpoint,
        pack_initial_latents_for_second_pass, remap_official_ltx_transformer_key, unpack_latents,
        LtxModelPreset, LtxPipelineMode, LtxVideoEngine, LATENT_CHANNELS,
        LTX_098_DISTILLED_SECOND_PASS_SIGMAS, PATCH_SIZE, PATCH_SIZE_T,
    };
    use crate::engine::LoadStrategy;
    use crate::shared_pool::SharedPool;
    use candle_core::{DType, Device, Tensor};
    use mold_core::ModelPaths;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Deletes the wrapped directory on drop, so a test that panics part-way
    /// through does not leave its scratch directory behind.
    struct TempDirGuard(PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure to delete must not mask the
            // test's own outcome.
            fs::remove_dir_all(&self.0).ok();
        }
    }

    /// Create a fresh scratch directory under the system temp dir. The name
    /// combines `prefix`, the process id and a nanosecond timestamp.
    fn temp_test_dir(prefix: &str) -> PathBuf {
        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Build a `ModelPaths` for an LTX Video model rooted at `dir`, using the
    /// given `vae` file and leaving every optional component unset except the
    /// T5 encoder and tokenizer.
    fn ltx_video_model_paths(dir: &Path, vae: PathBuf) -> ModelPaths {
        ModelPaths {
            transformer: dir.join("transformer.safetensors"),
            transformer_shards: vec![],
            vae,
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder: Some(dir.join("t5.safetensors")),
            clip_encoder: None,
            t5_tokenizer: Some(dir.join("tokenizer.json")),
            clip_tokenizer: None,
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files: vec![],
            text_tokenizer: None,
            decoder: None,
        }
    }

    #[test]
    fn decode_apng_round_trips_rgb_frames() {
        use crate::ltx_video::video_enc::encode_apng;
        use image::Rgb;

        // Build three 4x4 frames with distinct solid colors so any axis swap
        // or RGB ↔ BGR mistake shows up immediately.
        let make = |r: u8, g: u8, b: u8| {
            let mut img = image::RgbImage::new(4, 4);
            for px in img.pixels_mut() {
                *px = Rgb([r, g, b]);
            }
            img
        };
        let inputs = vec![make(255, 0, 0), make(0, 255, 0), make(0, 0, 255)];

        let bytes = encode_apng(&inputs, 12, None).expect("encode");
        let decoded = super::decode_apng_to_rgb_frames(&bytes).expect("decode");

        assert_eq!(decoded.len(), inputs.len());
        for (i, (a, b)) in inputs.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(a.dimensions(), b.dimensions(), "frame {i} size");
            // APNG is lossless, so every pixel must match exactly. Checking
            // the whole frame also catches corruption away from the origin.
            for (x, y, expected) in a.enumerate_pixels() {
                assert_eq!(
                    expected,
                    b.get_pixel(x, y),
                    "frame {i} pixel mismatch at ({x}, {y})",
                );
            }
        }
    }

    #[test]
    fn detects_official_ltx_single_file_checkpoints() {
        assert!(is_official_ltx_transformer_checkpoint(Path::new(
            "ltxv-2b-0.9.6-distilled-04-25.safetensors"
        )));
        assert!(is_official_ltx_transformer_checkpoint(Path::new(
            "ltxv-13b-0.9.8-dev.safetensors"
        )));
        assert!(!is_official_ltx_transformer_checkpoint(Path::new(
            "diffusion_pytorch_model-00001-of-00002.safetensors"
        )));
        assert!(!is_official_ltx_transformer_checkpoint(Path::new(
            "transformer.gguf"
        )));
    }

    #[test]
    fn remaps_official_transformer_keys_to_upstream_checkpoint_names() {
        assert_eq!(
            remap_official_ltx_transformer_key("proj_in.weight"),
            "model.diffusion_model.patchify_proj.weight"
        );
        assert_eq!(
            remap_official_ltx_transformer_key("time_embed.emb.timestep_embedder.linear_1.weight"),
            "model.diffusion_model.adaln_single.emb.timestep_embedder.linear_1.weight"
        );
        assert_eq!(
            remap_official_ltx_transformer_key("transformer_blocks.0.attn1.norm_q.weight"),
            "model.diffusion_model.transformer_blocks.0.attn1.q_norm.weight"
        );
        assert_eq!(
            remap_official_ltx_transformer_key("caption_projection.linear_2.bias"),
            "model.diffusion_model.caption_projection.linear_2.bias"
        );
    }

    #[test]
    fn ltx_video_loads_standalone_vae_tensors_through_shared_pool() {
        // Declared first so it is dropped last, after everything that may
        // still reference files inside the directory.
        let dir = TempDirGuard(temp_test_dir("mold-ltx-video-vae-pool"));
        let vae_path = dir.0.join("vae.safetensors");
        let weight = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "encoder.conv_in.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
        );
        serialize_to_file(&tensors, &None, &vae_path).unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
            .unwrap()
            .unwrap();

        let engine = LtxVideoEngine::new(
            "ltx-video-0.9.6:bf16".to_string(),
            ltx_video_model_paths(&dir.0, vae_path),
            None,
            LoadStrategy::Sequential,
            0,
            Some(shared_pool),
        );

        let loaded = engine.load_vae_cpu_tensors().unwrap().unwrap();

        // The engine must hand back the very same allocation the pool cached.
        assert!(Arc::ptr_eq(&pooled, &loaded));
    }

    #[test]
    fn ltx_presets_only_reference_in_range_skip_blocks() {
        for model_name in [
            "ltx-video-0.9.6:bf16",
            "ltx-video-0.9.6-distilled:bf16",
            "ltx-video-0.9.8-2b-distilled:bf16",
            "ltx-video-0.9.8-13b-dev:bf16",
            "ltx-video-0.9.8-13b-distilled:bf16",
        ] {
            let preset = LtxModelPreset::for_model(model_name).expect("preset should exist");
            // Collect the skip lists of every pass the preset defines.
            let mut all_skip_lists = vec![preset.base_pass.guidance.skip_block_list.clone()];
            if let Some(multiscale) = &preset.multiscale {
                all_skip_lists.push(multiscale.first_pass.guidance.skip_block_list.clone());
                all_skip_lists.push(multiscale.second_pass.guidance.skip_block_list.clone());
            }
            for skip_list_group in all_skip_lists {
                for skip_list in skip_list_group {
                    for skip_block in skip_list {
                        assert!(
                            skip_block < preset.transformer_config.num_layers,
                            "{model_name} skip block {skip_block} is out of range for {} layers",
                            preset.transformer_config.num_layers
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn ltx_video_13b_bf16_fails_before_allocation_when_vram_cannot_hold_residency() {
        let preset = LtxModelPreset::for_model("ltx-video-0.9.8-13b-dev:bf16").unwrap();
        let err = super::ltx_video_transformer_residency_guard(
            "ltx-video-0.9.8-13b-dev:bf16",
            &preset,
            26_000_000_000,
            Some(24_000_000_000),
            true,
        )
        .unwrap_err()
        .to_string();

        // The error must point the user at each available way out.
        assert!(err.contains("MOLD_OFFLOAD"));
        assert!(err.contains("ltx-video-0.9.8-2b-distilled"));
        assert!(err.contains("--width/--height/--frames"));
    }

    #[test]
    fn ltx_video_2b_bf16_residency_guard_does_not_reject() {
        let preset = LtxModelPreset::for_model("ltx-video-0.9.8-2b-distilled:bf16").unwrap();
        super::ltx_video_transformer_residency_guard(
            "ltx-video-0.9.8-2b-distilled:bf16",
            &preset,
            26_000_000_000,
            Some(24_000_000_000),
            true,
        )
        .unwrap();
    }

    #[test]
    fn ltx_098_presets_use_multiscale_mode() {
        for model_name in [
            "ltx-video-0.9.8-2b-distilled:bf16",
            "ltx-video-0.9.8-13b-dev:bf16",
            "ltx-video-0.9.8-13b-distilled:bf16",
        ] {
            let preset = LtxModelPreset::for_model(model_name).expect("preset should exist");
            assert_eq!(preset.mode, LtxPipelineMode::Multiscale, "{model_name}");
            assert!(preset.multiscale.is_some(), "{model_name}");
        }
    }

    #[test]
    fn ltx_098_distilled_second_pass_uses_upstream_sigmas() {
        for model_name in [
            "ltx-video-0.9.8-2b-distilled:bf16",
            "ltx-video-0.9.8-13b-distilled:bf16",
        ] {
            let preset = LtxModelPreset::for_model(model_name).expect("preset should exist");
            let multiscale = preset.multiscale.as_ref().expect("multiscale preset");
            assert_eq!(
                multiscale.second_pass.custom_sigmas.as_deref(),
                Some(LTX_098_DISTILLED_SECOND_PASS_SIGMAS),
                "{model_name}"
            );
        }
    }

    #[test]
    fn multiscale_handoff_normalizes_dtypes_for_upsampler_and_second_pass() {
        let device = Device::Cpu;
        // BF16 latents stand in for what the model dtype would produce.
        let second_pass_latents =
            Tensor::arange(0f32, (LATENT_CHANNELS * 2 * 4 * 6) as f32, &device)
                .expect("tensor")
                .reshape((1, LATENT_CHANNELS, 2, 4, 6))
                .expect("reshape")
                .to_dtype(DType::BF16)
                .expect("bf16");

        // Packing for the second pass must promote to F32.
        let packed = pack_initial_latents_for_second_pass(&second_pass_latents)
            .expect("second-pass repack should succeed");
        assert_eq!(packed.dtype(), DType::F32);

        // Unpacking must restore the original layout and values exactly.
        let unpacked = unpack_latents(&packed, 2, 4, 6, PATCH_SIZE, PATCH_SIZE_T)
            .expect("unpack should round-trip");
        assert_eq!(unpacked.dtype(), DType::F32);
        assert_eq!(
            unpacked.dims5().expect("dims"),
            (1, LATENT_CHANNELS, 2, 4, 6)
        );
        assert_eq!(
            unpacked
                .flatten_all()
                .expect("flatten")
                .to_vec1::<f32>()
                .expect("vec"),
            second_pass_latents
                .to_dtype(DType::F32)
                .expect("f32")
                .flatten_all()
                .expect("flatten")
                .to_vec1::<f32>()
                .expect("vec")
        );

        // The upsampler cast changes only the dtype, never the shape.
        let upsampler_input =
            cast_latents_for_multiscale_upsampler(&unpacked, DType::BF16).expect("cast");
        assert_eq!(upsampler_input.dtype(), DType::BF16);
        assert_eq!(
            upsampler_input.dims5().expect("dims"),
            (1, LATENT_CHANNELS, 2, 4, 6)
        );
    }
}