Skip to main content

mold_inference/ltx2/
pipeline.rs

1#![allow(clippy::type_complexity)]
2
3use anyhow::{bail, Context, Result};
4use candle_core::Device;
5use mold_core::{
6    GenerateRequest, GenerateResponse, Ltx2PipelineMode, ModelPaths, OutputFormat, VideoData,
7};
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::time::Instant;
11
12use super::assets;
13use super::backend::Ltx2Backend;
14use super::chain::{ChainStageRenderer, ChainTail, StageOutcome, StageProgressEvent};
15use super::conditioning::{self, StagedLatent};
16use super::execution;
17use super::lora;
18use super::media::{self, ProbeMetadata};
19use super::plan::{Ltx2GeneratePlan, PipelineKind};
20use super::preset;
21use super::runtime::{Ltx2RuntimeSession, NativeRenderedVideo};
22use super::text::gemma::GemmaAssets;
23use super::text::prompt_encoder::NativePromptEncoder;
24use crate::engine::{gpu_dtype, rand_seed, InferenceEngine, LoadStrategy};
25use crate::ltx_video::video_enc;
26use crate::progress::ProgressCallback;
27
/// Strength of the soft identity anchor injected on chain continuation stages.
///
/// On a continuation the motion-tail carryover owns frame 0, so a frame-0
/// source image is re-staged just past the tail at this strength. The denoise
/// mask at the anchor token then becomes `1.0 - 0.4 = 0.6`: every step keeps
/// roughly 60% generated content and pulls about 40% toward the reference.
/// That is a gentle nudge toward the source image, not a hard pin. Hard-pinning
/// a single pixel frame past the motion tail would make continuations feel
/// like cuts back to the opening shot.
const CHAIN_SOFT_ANCHOR_STRENGTH: f32 = 0.4_f32;
35
36pub struct Ltx2Engine {
37    model_name: String,
38    paths: ModelPaths,
39    loaded: bool,
40    native_runtime: Option<Ltx2RuntimeSession>,
41    on_progress: Option<ProgressCallback>,
42    pending_placement: Option<mold_core::types::DevicePlacement>,
43    /// GPU ordinal this engine is pinned to. Every `Device::new_cuda` and
44    /// `reclaim_gpu_memory` call must use this ordinal — hardcoding `0` here
45    /// is what took down the process on <gpu-host> when LTX-2 ran alongside
46    /// SD3.5 on a multi-GPU host.
47    gpu_ordinal: usize,
48    /// Optional preset hint used when the model name doesn't carry a
49    /// recognisable family substring (`ltx-2.3`, `ltx-2`). Populated by
50    /// `from_single_file` from the safetensors `__metadata__.model_version`
51    /// so catalog (`cv:*` / `hf:*`) IDs select the right preset without
52    /// requiring renames.
53    preset_hint: Option<String>,
54}
55
56impl Ltx2Engine {
57    fn debug_timings_enabled() -> bool {
58        std::env::var_os("MOLD_LTX2_DEBUG_TIMINGS").is_some()
59    }
60
61    fn log_timing(label: &str, start: Instant) {
62        if Self::debug_timings_enabled() {
63            eprintln!(
64                "[ltx2-timing] {label} {:.3}s",
65                start.elapsed().as_secs_f64()
66            );
67        }
68    }
69
70    pub fn new(
71        model_name: String,
72        paths: ModelPaths,
73        _load_strategy: LoadStrategy,
74        gpu_ordinal: usize,
75    ) -> Self {
76        Self {
77            model_name,
78            paths,
79            loaded: false,
80            native_runtime: None,
81            on_progress: None,
82            pending_placement: None,
83            gpu_ordinal,
84            preset_hint: None,
85        }
86    }
87
88    /// Construct an LTX-2 engine from a Civitai single-file safetensors
89    /// checkpoint (phase 5).
90    ///
91    /// LTX-2 combined checkpoints (the standard Lightricks format) bundle
92    /// both the video transformer (`transformer_blocks.*`) and the VAE
93    /// (`vae.*`) in a single file. The runtime always loads both from
94    /// `paths.transformer`, so on the checkpoint side this is structurally
95    /// identical to `new()`.
96    ///
97    /// Validates via `ltx_2::single_file::load` that the checkpoint has
98    /// detectable LTX-2 transformer keys and — critically — contains `vae.*`
99    /// keys. If the VAE is absent the call fails with an actionable error
100    /// message, since the LTX-2 runtime has no separate-VAE fallback.
101    ///
102    /// `paths` is the full resolved companion graph (text_encoder_files,
103    /// upscalers, distilled_lora, …). `transformer` and `vae` are
104    /// overridden to point at the single checkpoint; everything else is
105    /// preserved so the Gemma TE companion (Civitai catalog entries don't
106    /// bundle the encoder) and any other resolved companions reach the
107    /// runtime intact. Discarding `paths` here is what bit cv:* LTX-2
108    /// loads in phase 5 — the runtime then bailed with `LTX-2 requires
109    /// Gemma text encoder files to be available`.
110    pub fn from_single_file(
111        model_name: String,
112        checkpoint: PathBuf,
113        paths: ModelPaths,
114        load_strategy: LoadStrategy,
115        gpu_ordinal: usize,
116    ) -> anyhow::Result<Self> {
117        if !checkpoint.exists() {
118            anyhow::bail!(
119                "single-file LTX-2 checkpoint not found: {}",
120                checkpoint.display()
121            );
122        }
123
124        let bundle = super::single_file::load(&checkpoint).map_err(|e| {
125            anyhow::anyhow!(
126                "failed to parse single-file LTX-2 checkpoint {}: {e}",
127                checkpoint.display()
128            )
129        })?;
130
131        if !bundle.has_vae {
132            anyhow::bail!(
133                "LTX-2 checkpoint {} contains no VAE weights (`vae.*` keys). \
134                 This appears to be a transformer-only fine-tune. \
135                 The LTX-2 runtime requires a combined transformer+VAE checkpoint. \
136                 Phase-5 does not yet support separate-VAE loading for LTX-2.",
137                checkpoint.display()
138            );
139        }
140
141        // For LTX-2 the combined checkpoint path serves as both transformer
142        // and VAE source; the runtime reads VAE under `vb.pp("vae")` from
143        // the same file, so we leave `vae` empty. Every other companion
144        // path the catalog bridge populated (most importantly
145        // `text_encoder_files` for Gemma) flows through unchanged.
146        let paths = ModelPaths {
147            transformer: checkpoint,
148            transformer_shards: Vec::new(),
149            vae: PathBuf::default(),
150            ..paths
151        };
152
153        let mut engine = Self::new(model_name, paths, load_strategy, gpu_ordinal);
154        // Catalog (`cv:*`) IDs don't contain `ltx-2.3` / `ltx-2` substrings,
155        // so `preset_for_model` would bail. The bundled `model_version`
156        // from the safetensors `__metadata__` (e.g. `"2.3.0"`) is the
157        // authoritative source — record it as a hint that
158        // `materialize_request` consults via `preset_for_model_with_hint`.
159        engine.preset_hint = bundle.model_version;
160        Ok(engine)
161    }
162
163    #[cfg(test)]
164    fn with_runtime_session(
165        model_name: String,
166        paths: ModelPaths,
167        runtime: Ltx2RuntimeSession,
168    ) -> Self {
169        Self {
170            model_name,
171            paths,
172            loaded: false,
173            native_runtime: Some(runtime),
174            on_progress: None,
175            pending_placement: None,
176            gpu_ordinal: 0,
177            preset_hint: None,
178        }
179    }
180
181    fn emit(&self, stage: &str) {
182        if let Some(callback) = &self.on_progress {
183            callback(crate::ProgressEvent::StageStart {
184                name: stage.to_string(),
185            });
186        }
187    }
188
189    fn info(&self, message: &str) {
190        if let Some(callback) = &self.on_progress {
191            callback(crate::ProgressEvent::Info {
192                message: message.to_string(),
193            });
194        }
195    }
196
197    fn is_oom_error(err: &impl std::fmt::Display) -> bool {
198        let msg = err.to_string().to_ascii_lowercase();
199        msg.contains("out of memory")
200            || msg.contains("out_of_memory")
201            || msg.contains("cudaerrormemoryallocation")
202    }
203
204    fn unload_runtime_state(&mut self) -> Option<usize> {
205        self.loaded = false;
206        let should_reclaim = self
207            .native_runtime
208            .as_ref()
209            .is_some_and(Ltx2RuntimeSession::needs_cuda_reclaim_on_unload);
210        self.native_runtime = None;
211        should_reclaim.then_some(self.gpu_ordinal)
212    }
213
214    fn gemma_root(&self) -> Result<PathBuf> {
215        assets::gemma_root(&self.paths)
216    }
217
218    fn select_pipeline(&self, req: &GenerateRequest) -> Result<PipelineKind> {
219        if let Some(mode) = req.pipeline {
220            return Ok(match mode {
221                Ltx2PipelineMode::OneStage => PipelineKind::OneStage,
222                Ltx2PipelineMode::TwoStage => PipelineKind::TwoStage,
223                Ltx2PipelineMode::TwoStageHq => PipelineKind::TwoStageHq,
224                Ltx2PipelineMode::Distilled => PipelineKind::Distilled,
225                Ltx2PipelineMode::IcLora => PipelineKind::IcLora,
226                Ltx2PipelineMode::Keyframe => PipelineKind::Keyframe,
227                Ltx2PipelineMode::A2Vid => PipelineKind::A2Vid,
228                Ltx2PipelineMode::Retake => PipelineKind::Retake,
229            });
230        }
231
232        if req.retake_range.is_some() {
233            return Ok(PipelineKind::Retake);
234        }
235        if req.audio_file.is_some() || req.audio_file_path.is_some() {
236            return Ok(PipelineKind::A2Vid);
237        }
238        if req.keyframes.as_ref().is_some_and(|items| items.len() > 1) {
239            return Ok(PipelineKind::Keyframe);
240        }
241        if req.source_video.is_some() || req.source_video_path.is_some() {
242            return Ok(PipelineKind::IcLora);
243        }
244        if self.model_name.contains("distilled") {
245            // Distilled checkpoints also require a spatial upsampler (single
246            // upscale stage instead of two denoise passes); without one,
247            // fall back to a plain one-stage denoise that runs the
248            // transformer end-to-end on the requested resolution.
249            return Ok(if self.paths.spatial_upscaler.is_some() {
250                PipelineKind::Distilled
251            } else {
252                PipelineKind::OneStage
253            });
254        }
255        // TwoStage runs an upscale-and-refine pass after stage 1 and bails
256        // at runtime if `spatial_upscaler` isn't on disk. Single-file
257        // catalog (`cv:*`) checkpoints don't ship the upsampler asset, so
258        // fall back to OneStage when it's missing — the user gets a clean
259        // single-pass video instead of a 422 several stages in.
260        Ok(if self.paths.spatial_upscaler.is_some() {
261            PipelineKind::TwoStage
262        } else {
263            PipelineKind::OneStage
264        })
265    }
266
267    fn request_quantization(&self) -> Option<String> {
268        assets::request_quantization(&self.model_name)
269    }
270
271    #[allow(dead_code)]
272    fn camera_control_preset(name: &str) -> Option<lora::CameraControlPreset> {
273        lora::camera_control_preset(name)
274    }
275
276    pub(crate) fn materialize_request(
277        &self,
278        req: &GenerateRequest,
279        work_dir: &Path,
280        output_path: &Path,
281    ) -> Result<Ltx2GeneratePlan> {
282        let pipeline = self.select_pipeline(req)?;
283        let gemma_root = self.gemma_root()?;
284        let prompt_tokens = GemmaAssets::discover(&gemma_root)?
285            .encode_prompt_pair(&req.prompt, req.negative_prompt.as_deref())?;
286        let conditioning = conditioning::stage_conditioning(req, work_dir)?;
287        let loras = lora::resolve_loras(&self.model_name, req)?;
288        let preset =
289            preset::preset_for_model_with_hint(&self.model_name, self.preset_hint.as_deref())?;
290        let execution_graph =
291            execution::build_execution_graph(req, pipeline, &conditioning, &preset, loras.len());
292        let spatial_upsampler_path = assets::resolve_spatial_upscaler_path(
293            &self.model_name,
294            &self.paths,
295            req.spatial_upscale,
296        )?
297        .map(|path| path.to_string_lossy().to_string());
298        let temporal_upsampler_path =
299            assets::resolve_temporal_upscaler_path(&self.paths, req.temporal_upscale)?
300                .map(|path| path.to_string_lossy().to_string());
301
302        Ok(Ltx2GeneratePlan {
303            pipeline,
304            preset,
305            checkpoint_is_distilled: self.model_name.contains("distilled"),
306            execution_graph,
307            checkpoint_path: self.paths.transformer.to_string_lossy().to_string(),
308            distilled_checkpoint_path: pipeline
309                .requires_distilled_checkpoint()
310                .then(|| self.paths.transformer.to_string_lossy().to_string()),
311            distilled_lora_path: self
312                .paths
313                .distilled_lora
314                .as_ref()
315                .map(|path| path.to_string_lossy().to_string()),
316            spatial_upsampler_path,
317            temporal_upsampler_path,
318            gemma_root: gemma_root.to_string_lossy().to_string(),
319            output_path: output_path.to_string_lossy().to_string(),
320            prompt: req.prompt.clone(),
321            negative_prompt: req.negative_prompt.clone(),
322            prompt_tokens,
323            seed: req.seed.unwrap_or_else(rand_seed),
324            width: req.width,
325            height: req.height,
326            num_frames: req.frames.unwrap_or(97),
327            frame_rate: req.fps.unwrap_or(24),
328            num_inference_steps: req.steps,
329            guidance: req.guidance,
330            quantization: self.request_quantization(),
331            streaming_prefetch_count: Some(preset.streaming_prefetch_count),
332            conditioning,
333            loras,
334            retake_range: req.retake_range.clone(),
335            spatial_upscale: req.spatial_upscale,
336            temporal_upscale: req.temporal_upscale,
337        })
338    }
339
340    fn probe_video(&self, input_video: &Path) -> Result<ProbeMetadata> {
341        media::probe_video(input_video)
342    }
343
344    fn native_device_for_backend(&self, backend: Ltx2Backend) -> Result<Device> {
345        match backend {
346            Ltx2Backend::Cuda => {
347                self.info("CUDA detected, using native LTX-2 GPU path");
348                let device = Device::new_cuda(self.gpu_ordinal)?;
349                configure_native_ltx2_cuda_device(&device)?;
350                Ok(device)
351            }
352            Ltx2Backend::Cpu => {
353                let forced_cpu = std::env::var("MOLD_DEVICE")
354                    .map(|value| value.eq_ignore_ascii_case("cpu"))
355                    .unwrap_or(false);
356                if forced_cpu {
357                    self.info("CPU forced via MOLD_DEVICE=cpu for native LTX-2");
358                } else {
359                    self.info("No CUDA detected; using native LTX-2 CPU fallback");
360                }
361                Ok(Device::Cpu)
362            }
363            Ltx2Backend::Metal => unreachable!("unsupported Metal backend should have errored"),
364        }
365    }
366
367    fn load_runtime_session_on_device(
368        &self,
369        plan: &Ltx2GeneratePlan,
370        device: Device,
371    ) -> Result<Ltx2RuntimeSession> {
372        let load_start = Instant::now();
373        let prompt_device = resolve_prompt_encoder_device(&device, self.gpu_ordinal);
374        log_prompt_encoder_placement(&device, &prompt_device);
375        let dtype = gpu_dtype(&prompt_device);
376        self.emit("Loading native LTX-2 prompt encoder");
377        let prompt_encoder = NativePromptEncoder::load(
378            Path::new(&plan.gemma_root),
379            Path::new(&plan.checkpoint_path),
380            &plan.preset,
381            &prompt_device,
382            dtype,
383        )?;
384        Self::log_timing("pipeline.create_runtime.load_prompt_encoder", load_start);
385        // Cross-device case (transformer on CUDA, encoder on CPU/sibling GPU)
386        // can't use the deferred-cuda path because the prompt encoder doesn't
387        // need a CUDA stream sync at the transformer's ordinal. Fall back to
388        // the synchronous path; encode-time `move_prompt_encoding_to_device`
389        // handles the cross-device tensor copy.
390        let same_device = device.same_device(&prompt_device);
391        if prompt_device.is_cuda() && same_device {
392            Ok(Ltx2RuntimeSession::new_deferred_cuda(
393                prompt_encoder,
394                self.gpu_ordinal,
395            ))
396        } else {
397            Ok(Ltx2RuntimeSession::new(
398                device,
399                prompt_encoder,
400                self.gpu_ordinal,
401            ))
402        }
403    }
404
405    fn create_runtime_session(&self, plan: &Ltx2GeneratePlan) -> Result<Ltx2RuntimeSession> {
406        let backend = Ltx2Backend::detect();
407        backend.ensure_supported()?;
408
409        // Honor Tier 1 `text_encoders` override for the Gemma prompt encoder.
410        // Auto falls back to whatever `native_device_for_backend(backend)` picks
411        // (CUDA when available, else CPU). Explicit Cpu/Gpu skips that auto path.
412        let tier1 = self.pending_placement.as_ref().map(|p| p.text_encoders);
413        let device =
414            crate::device::resolve_device(tier1, || self.native_device_for_backend(backend))?;
415        if device.is_cuda() {
416            configure_native_ltx2_cuda_device(&device)?;
417        }
418        // Only auto CUDA placement should retry on OOM — if the user explicitly
419        // pinned the encoder to a GPU, surface the OOM rather than silently
420        // rewriting their request.
421        let override_is_auto = matches!(tier1, None | Some(mold_core::types::DeviceRef::Auto));
422        match self.load_runtime_session_on_device(plan, device) {
423            Ok(runtime) => Ok(runtime),
424            Err(err)
425                if matches!(backend, Ltx2Backend::Cuda)
426                    && override_is_auto
427                    && Self::is_oom_error(&err) =>
428            {
429                self.info(
430                    "Native LTX-2 prompt path ran out of CUDA memory; retrying with CPU fallback",
431                );
432                crate::device::reclaim_gpu_memory(self.gpu_ordinal);
433                self.load_runtime_session_on_device(plan, Device::Cpu)
434            }
435            Err(err) => Err(err),
436        }
437    }
438
439    fn encode_native_video(
440        &self,
441        req: &GenerateRequest,
442        plan: &Ltx2GeneratePlan,
443        rendered: &NativeRenderedVideo,
444        work_dir: &Path,
445    ) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>, Option<ProbeMetadata>)> {
446        if let Some(audio_track) = rendered.audio_track.as_ref() {
447            let wav_path = work_dir.join("native-audio.wav");
448            fs::write(
449                &wav_path,
450                media::encode_wav_f32_interleaved(
451                    &audio_track.interleaved_samples,
452                    audio_track.sample_rate,
453                    audio_track.channels,
454                )?,
455            )?;
456        }
457
458        let output_encode_start = Instant::now();
459        let output_bytes = match req.resolved_output_format() {
460            OutputFormat::Apng => {
461                let metadata = video_enc::VideoMetadata {
462                    prompt: req.prompt.clone(),
463                    model: self.model_name.clone(),
464                    seed: plan.seed,
465                    steps: req.steps,
466                    guidance: req.guidance,
467                    width: plan.width,
468                    height: plan.height,
469                    frames: plan.num_frames,
470                    fps: plan.frame_rate,
471                };
472                video_enc::encode_apng(&rendered.frames, plan.frame_rate, Some(&metadata))?
473            }
474            OutputFormat::Gif => video_enc::encode_gif(&rendered.frames, plan.frame_rate)?,
475            #[cfg(feature = "webp")]
476            OutputFormat::Webp => video_enc::encode_webp(&rendered.frames, plan.frame_rate)?,
477            #[cfg(not(feature = "webp"))]
478            OutputFormat::Webp => bail!("WebP output requires the 'webp' feature"),
479            OutputFormat::Mp4 => {
480                #[cfg(feature = "mp4")]
481                {
482                    let video_only = video_enc::encode_mp4(&rendered.frames, plan.frame_rate)?;
483                    let mp4_path = work_dir.join("native-video.mp4");
484                    fs::write(&mp4_path, &video_only)?;
485                    if let Some(audio_track) = rendered.audio_track.as_ref() {
486                        let muxed_path = work_dir.join("native-video-audio.mp4");
487                        media::attach_aac_track_from_f32_interleaved(
488                            &mp4_path,
489                            &muxed_path,
490                            &audio_track.interleaved_samples,
491                            audio_track.sample_rate,
492                            audio_track.channels,
493                        )?;
494                        fs::read(muxed_path)?
495                    } else {
496                        video_only
497                    }
498                }
499                #[cfg(not(feature = "mp4"))]
500                {
501                    bail!("MP4 output requires the 'mp4' feature")
502                }
503            }
504            other => bail!("{other:?} is not supported for LTX-2 video output"),
505        };
506        Self::log_timing("pipeline.encode_output", output_encode_start);
507
508        let thumbnail_start = Instant::now();
509        let thumbnail = video_enc::first_frame_png(&rendered.frames)?;
510        Self::log_timing("pipeline.encode_thumbnail", thumbnail_start);
511        let gif_preview_start = Instant::now();
512        let gif_preview = if req.gif_preview {
513            if req.resolved_output_format() == OutputFormat::Gif {
514                output_bytes.clone()
515            } else {
516                video_enc::encode_gif(&rendered.frames, plan.frame_rate)?
517            }
518        } else {
519            Vec::new()
520        };
521        Self::log_timing("pipeline.encode_gif_preview", gif_preview_start);
522
523        let probe_start = Instant::now();
524        let probe = if req.resolved_output_format() == OutputFormat::Mp4 {
525            let path = work_dir.join("probe.mp4");
526            fs::write(&path, &output_bytes)?;
527            Some(self.probe_video(&path)?)
528        } else {
529            None
530        };
531        Self::log_timing("pipeline.probe_output", probe_start);
532
533        Ok((output_bytes, thumbnail, gif_preview, probe))
534    }
535}
536
537#[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
538fn configure_native_ltx2_cuda_device(device: &Device) -> Result<()> {
539    #[cfg(feature = "cuda")]
540    if device.is_cuda() {
541        let cuda = device.as_cuda_device()?;
542        if cuda.is_event_tracking() {
543            // Native LTX-2 runs on a single dedicated stream. Disabling CUDA event
544            // tracking avoids teardown crashes in cudarc/candle when large native
545            // video runs drop many tensors at the end of the request.
546            unsafe {
547                cuda.disable_event_tracking();
548            }
549        }
550    }
551    Ok(())
552}
553
554impl Ltx2Engine {
555    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
556        if !self.loaded {
557            self.load()?;
558        }
559        let start = Instant::now();
560        self.emit("Preparing native LTX-2 request");
561
562        let work_dir = tempfile::tempdir().context("failed to create LTX-2 temp directory")?;
563        let native_output = work_dir.path().join("ltx2-native-output.mp4");
564        let materialize_start = Instant::now();
565        let plan = self.materialize_request(req, work_dir.path(), &native_output)?;
566        Self::log_timing("pipeline.materialize_request", materialize_start);
567        let planned_stage_count = plan.execution_graph.denoise_passes.len();
568        self.emit(&format!(
569            "Planned native LTX-2 graph: preset={}, denoise_stages={}, blocks={}, prompt_tokens={}/{}",
570            plan.preset.name,
571            planned_stage_count,
572            plan.execution_graph.blocks.len(),
573            plan.prompt_tokens.conditional.valid_len(),
574            plan.prompt_tokens.unconditional.valid_len()
575        ));
576        let create_runtime_start = Instant::now();
577        // Reuse a persisted runtime only if it can serve this plan. An LTX-2
578        // session consumes its prompt encoder on first `prepare()` (see
579        // runtime.rs `prepare()` — the take+drop frees VRAM for the
580        // transformer); a stale session left behind by a prior chain run
581        // survives intact for same-prompt continuations via the session-
582        // level encoding cache, but we must rebuild from scratch when the
583        // prompt changes so `prepare()` doesn't error on a consumed encoder.
584        let mut runtime = match self.native_runtime.take() {
585            Some(runtime) if runtime.can_reuse_for(&plan) => runtime,
586            _ => self.create_runtime_session(&plan)?,
587        };
588        Self::log_timing("pipeline.create_runtime", create_runtime_start);
589
590        self.emit("Encoding prompt and preparing native LTX-2 runtime state");
591        let prepare_start = Instant::now();
592        let prepared = runtime.prepare(&plan)?;
593        Self::log_timing("pipeline.prepare_runtime", prepare_start);
594        self.emit("Executing native LTX-2 runtime");
595        let render_start = Instant::now();
596        let rendered = runtime.render_native_video(&plan, &prepared, self.on_progress.as_ref())?;
597        Self::log_timing("pipeline.render_runtime", render_start);
598        let encode_start = Instant::now();
599        let (output_bytes, thumbnail_bytes, gif_preview, probe) =
600            self.encode_native_video(req, &plan, &rendered, work_dir.path())?;
601        Self::log_timing("pipeline.encode_native_video", encode_start);
602        let duration_ms =
603            Some((plan.num_frames as u64 * 1000).div_ceil(plan.frame_rate.max(1) as u64));
604        let width = probe
605            .as_ref()
606            .map(|probe| probe.width)
607            .unwrap_or(plan.width);
608        let height = probe
609            .as_ref()
610            .map(|probe| probe.height)
611            .unwrap_or(plan.height);
612        let frames = probe
613            .as_ref()
614            .and_then(|probe| probe.frames)
615            .unwrap_or(plan.num_frames);
616        let fps = probe
617            .as_ref()
618            .map(|probe| probe.fps)
619            .unwrap_or(plan.frame_rate);
620        let has_audio = if req.resolved_output_format() == OutputFormat::Mp4 {
621            probe
622                .as_ref()
623                .map(|probe| probe.has_audio)
624                .unwrap_or(rendered.has_audio)
625        } else {
626            false
627        };
628        let audio_sample_rate = if req.resolved_output_format() == OutputFormat::Mp4 {
629            probe
630                .as_ref()
631                .and_then(|probe| probe.audio_sample_rate)
632                .or(rendered.audio_sample_rate)
633        } else {
634            None
635        };
636        let audio_channels = if req.resolved_output_format() == OutputFormat::Mp4 {
637            probe
638                .as_ref()
639                .and_then(|probe| probe.audio_channels)
640                .or(rendered.audio_channels)
641        } else {
642            None
643        };
644
645        Ok(GenerateResponse {
646            images: vec![],
647            video: Some(VideoData {
648                data: output_bytes,
649                format: req.resolved_output_format(),
650                width,
651                height,
652                frames,
653                fps,
654                thumbnail: thumbnail_bytes,
655                gif_preview,
656                has_audio,
657                duration_ms: probe
658                    .as_ref()
659                    .and_then(|probe| probe.duration_ms)
660                    .or(duration_ms),
661                audio_sample_rate,
662                audio_channels,
663            }),
664            generation_time_ms: start.elapsed().as_millis() as u64,
665            model: self.model_name.clone(),
666            seed_used: plan.seed,
667            gpu: None,
668        })
669    }
670
671    /// Render a single chain stage, optionally conditioning on a carryover
672    /// tail from the prior stage.
673    ///
674    /// `motion_tail_pixel_frames` is the number of pixel frames to narrow
675    /// off the emitted latents for the *next* stage's carryover. `0`
676    /// returns an error (nonsensical — use the regular single-clip path
677    /// if no tail is wanted).
678    ///
679    /// Scope: distilled LTX-2 pipeline only. Other pipeline families
680    /// return an error up-front so the chain orchestrator fails fast.
681    pub(crate) fn render_chain_stage(
682        &mut self,
683        req: &GenerateRequest,
684        carry: Option<&ChainTail>,
685        motion_tail_pixel_frames: u32,
686    ) -> Result<StageOutcome> {
687        if motion_tail_pixel_frames == 0 {
688            bail!("render_chain_stage: motion_tail_pixel_frames must be > 0");
689        }
690        if !self.loaded {
691            self.load()?;
692        }
693        let start = Instant::now();
694        self.emit("Preparing native LTX-2 chain stage");
695
696        let pipeline = self.select_pipeline(req)?;
697        if !matches!(pipeline, PipelineKind::Distilled) {
698            bail!(
699                "render-chain v1 only supports the distilled LTX-2 pipeline, got {:?}",
700                pipeline,
701            );
702        }
703
704        let work_dir = tempfile::tempdir().context("failed to create LTX-2 temp directory")?;
705        let native_output = work_dir.path().join("ltx2-native-output.mp4");
706        let mut plan = self.materialize_request(req, work_dir.path(), &native_output)?;
707
708        // Inject carryover RGB frames as a StagedLatent at frame 0. The
709        // runtime VAE-encodes them fresh on the receiving side so every
710        // resulting latent slot has correct causal/continuation semantics
711        // in this clip's own time axis (see conditioning.rs StagedLatent
712        // docstring + runtime.rs maybe_load_stage_video_conditioning).
713        //
714        // When the chain request carries a starting image (i2v flow), the
715        // orchestrator passes it through on every stage. Stage 0 uses it
716        // as the frame-0 i2v replacement — great. On continuations the
717        // motion-tail pin owns frame 0, so we re-route any frame-0 staged
718        // image to a non-zero frame with reduced "soft anchor" strength:
719        // the image becomes a durable identity reference appended to the
720        // token sequence (via the `VideoTokenAppendCondition` path in
721        // `maybe_load_stage_video_conditioning`), giving the free-region
722        // denoise a persistent cross-attention anchor for subject / scene
723        // appearance without freezing any tokens. Without this anchor,
724        // identity drift compounds stage-over-stage because each clip's
725        // only long-range reference is its own drifted last-frame carry.
726        if let Some(tail) = carry {
727            if req.source_image.is_some() {
728                tracing::warn!(
729                    "smooth continuation received source_image; it will be repurposed as a soft \
730                     identity anchor. Use transition: cut|fade to seed the stage with a fresh i2v."
731                );
732            }
733            if tail.tail_rgb_frames.is_empty() {
734                bail!(
735                    "render_chain_stage: carry.tail_rgb_frames is empty; caller must provide at least one frame"
736                );
737            }
738
739            // Re-route any frame-0 staged image into the soft-anchor
740            // append slot. The anchor frame is the first pixel past the
741            // motion-tail pin, so the reference token's RoPE sits exactly
742            // where new content starts — cross-attention propagates
743            // identity into the free region most directly from there.
744            // `CHAIN_SOFT_ANCHOR_STRENGTH = 0.4` gives the denoise mask a
745            // value of `1 - 0.4 = 0.6` at the anchor token, so the
746            // denoiser blends ~60% generated / ~40% reference every step.
747            let anchor_frame = motion_tail_pixel_frames;
748            for image in plan.conditioning.images.iter_mut() {
749                if image.frame == 0 {
750                    image.frame = anchor_frame;
751                    image.strength = CHAIN_SOFT_ANCHOR_STRENGTH;
752                }
753            }
754
755            plan.conditioning.latents.push(StagedLatent {
756                tail_rgb_frames: tail.tail_rgb_frames.clone(),
757                frame: 0,
758                strength: 1.0,
759            });
760        }
761
762        // Reuse an existing runtime session if we have one AND it can
763        // serve this plan. Between stages of a same-prompt chain the
764        // session-level encoding cache handles the consumed-encoder
765        // invariant; if the prompt shifts (or a stale session leaked in
766        // from a prior run) we drop the runtime and rebuild so
767        // `prepare()` doesn't error on a missing encoder.
768        let mut runtime = match self.native_runtime.take() {
769            Some(runtime) if runtime.can_reuse_for(&plan) => runtime,
770            _ => self.create_runtime_session(&plan)?,
771        };
772
773        self.emit("Executing native LTX-2 chain stage runtime");
774        let prepared = match runtime.prepare(&plan) {
775            Ok(prepared) => prepared,
776            Err(err) => {
777                self.native_runtime = Some(runtime);
778                return Err(err);
779            }
780        };
781        let render_result =
782            runtime.render_native_video(&plan, &prepared, self.on_progress.as_ref());
783        self.native_runtime = Some(runtime);
784        let rendered = render_result?;
785
786        let frames = rendered.frames;
787        let audio = rendered.audio_track;
788        let tail_pixel_frames = motion_tail_pixel_frames as usize;
789        if frames.len() < tail_pixel_frames {
790            bail!(
791                "distilled render returned {} pixel frames but the chain caller requested a {}-frame tail; \
792                 this is a pipeline wiring bug",
793                frames.len(),
794                motion_tail_pixel_frames,
795            );
796        }
797        let tail_start = frames.len() - tail_pixel_frames;
798        let tail_rgb_frames = frames[tail_start..].to_vec();
799
800        let generation_time_ms = start.elapsed().as_millis() as u64;
801        Self::log_timing("pipeline.render_chain_stage", start);
802
803        Ok(StageOutcome {
804            frames,
805            tail: ChainTail {
806                frames: motion_tail_pixel_frames,
807                tail_rgb_frames,
808            },
809            audio,
810            generation_time_ms,
811        })
812    }
813}
814
815impl ChainStageRenderer for Ltx2Engine {
816    fn render_stage(
817        &mut self,
818        stage_req: &GenerateRequest,
819        carry: Option<&ChainTail>,
820        motion_tail_pixel_frames: u32,
821        _stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>,
822    ) -> Result<StageOutcome> {
823        // `_stage_progress` is intentionally unused in v1: per-stage
824        // denoise events flow through `self.on_progress` already. Phase 2's
825        // server route will install an on_progress callback that forwards
826        // those events onto the chain SSE stream with `stage_idx` tagged
827        // in. If the orchestrator later needs denoise-step events routed
828        // through its own channel, we can plumb `stage_progress` into a
829        // temporary ProgressCallback wrapper here.
830        self.render_chain_stage(stage_req, carry, motion_tail_pixel_frames)
831    }
832}
833
834impl InferenceEngine for Ltx2Engine {
835    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
836        self.pending_placement = req.placement.clone();
837        let result = self.generate_inner(req);
838        self.pending_placement = None;
839        result
840    }
841
842    fn model_name(&self) -> &str {
843        &self.model_name
844    }
845
846    fn is_loaded(&self) -> bool {
847        self.loaded
848    }
849
850    fn load(&mut self) -> Result<()> {
851        self.emit("Preparing native LTX-2 runtime");
852        if !self.paths.transformer.exists() {
853            bail!(
854                "missing LTX-2 checkpoint: {}",
855                self.paths.transformer.display()
856            );
857        }
858        let gemma_root = self.gemma_root()?;
859        if !gemma_root.join("tokenizer.json").exists() {
860            bail!(
861                "missing Gemma tokenizer assets for LTX-2: {}",
862                gemma_root.display()
863            );
864        }
865        Ltx2Backend::detect().ensure_supported()?;
866        self.loaded = true;
867        Ok(())
868    }
869
870    fn unload(&mut self) {
871        if let Some(ordinal) = self.unload_runtime_state() {
872            crate::reclaim_gpu_memory(ordinal);
873        }
874    }
875
876    fn set_on_progress(&mut self, callback: ProgressCallback) {
877        self.on_progress = Some(callback);
878    }
879
880    fn clear_on_progress(&mut self) {
881        self.on_progress = None;
882    }
883
884    fn model_paths(&self) -> Option<&ModelPaths> {
885        Some(&self.paths)
886    }
887
888    fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> {
889        Some(self)
890    }
891}
892
893/// Resolve the device for the LTX-2 Gemma 3 12B prompt encoder given the
894/// transformer's chosen device.
895///
896/// - Transformer on CPU/Metal: keep the encoder on the same device. CPU
897///   means the user opted out of GPU end-to-end and Metal LTX-2 isn't
898///   supported anyway (caller will have errored before this).
899/// - Transformer on CUDA: defer to the auto-resolver in
900///   [`crate::device::resolve_ltx2_gemma_placement`], which honors the
901///   `MOLD_LTX2_GEMMA_DEVICE` override and walks active GPU → siblings →
902///   CPU on a free-VRAM probe.
903pub(crate) fn resolve_prompt_encoder_device(
904    transformer_device: &Device,
905    gpu_ordinal: usize,
906) -> Device {
907    if !transformer_device.is_cuda() {
908        return transformer_device.clone();
909    }
910    crate::device::resolve_ltx2_gemma_placement(gpu_ordinal).into_device()
911}
912
913fn log_prompt_encoder_placement(transformer_device: &Device, prompt_device: &Device) {
914    if transformer_device.same_device(prompt_device) {
915        return;
916    }
917    let label = if prompt_device.is_cpu() {
918        "CPU".to_string()
919    } else if prompt_device.is_cuda() {
920        "GPU (sibling ordinal)".to_string()
921    } else {
922        "non-CUDA device".to_string()
923    };
924    tracing::info!(
925        prompt_encoder_device = %label,
926        "LTX-2 Gemma encoder placed off the transformer device — \
927         encode-time tensor copy will move conditioning back to the transformer GPU"
928    );
929}
930
931#[cfg(test)]
932mod tests {
933    use super::*;
934    use std::collections::HashMap;
935    use std::fs;
936    use std::path::Path;
937    use std::path::PathBuf;
938
939    use candle_core::{DType, Device, Tensor};
940    use candle_nn::VarBuilder;
941
942    use crate::ltx2::text::connectors::PaddingSide;
943    use crate::ltx2::text::encoder::{GemmaConfig, GemmaHiddenStateEncoder};
944    use crate::ltx2::text::prompt_encoder::{
945        build_embeddings_processor, ConnectorSpec, NativePromptEncoder,
946    };
947
948    fn dummy_paths() -> ModelPaths {
949        ModelPaths {
950            transformer: PathBuf::from("/tmp/ltx2.safetensors"),
951            transformer_shards: vec![],
952            vae: PathBuf::from("/tmp/unused"),
953            spatial_upscaler: Some(PathBuf::from("/tmp/spatial.safetensors")),
954            temporal_upscaler: Some(PathBuf::from("/tmp/temporal.safetensors")),
955            distilled_lora: Some(PathBuf::from("/tmp/distilled-lora.safetensors")),
956            t5_encoder: None,
957            clip_encoder: None,
958            t5_tokenizer: None,
959            clip_tokenizer: None,
960            clip_encoder_2: None,
961            clip_tokenizer_2: None,
962            text_encoder_files: vec![PathBuf::from("/tmp/gemma/tokenizer.json")],
963            text_tokenizer: None,
964            decoder: None,
965        }
966    }
967
968    fn dummy_paths_with_gemma_root(root: &std::path::Path) -> ModelPaths {
969        let mut paths = dummy_paths();
970        paths.text_encoder_files = vec![root.join("tokenizer.json")];
971        paths
972    }
973
974    fn dummy_paths_in(root: &Path, gemma_root: &Path) -> ModelPaths {
975        ModelPaths {
976            transformer: root.join("ltx2.safetensors"),
977            transformer_shards: vec![],
978            vae: root.join("unused"),
979            spatial_upscaler: Some(root.join("spatial.safetensors")),
980            temporal_upscaler: Some(root.join("temporal.safetensors")),
981            distilled_lora: Some(root.join("distilled-lora.safetensors")),
982            t5_encoder: None,
983            clip_encoder: None,
984            t5_tokenizer: None,
985            clip_tokenizer: None,
986            clip_encoder_2: None,
987            clip_tokenizer_2: None,
988            text_encoder_files: vec![gemma_root.join("tokenizer.json")],
989            text_tokenizer: None,
990            decoder: None,
991        }
992    }
993
994    fn write_test_gemma_assets(root: &std::path::Path) {
995        fs::write(
996            root.join("tokenizer.json"),
997            r#"{
998  "version": "1.0",
999  "truncation": null,
1000  "padding": null,
1001  "added_tokens": [],
1002  "normalizer": null,
1003  "pre_tokenizer": {
1004    "type": "WhitespaceSplit"
1005  },
1006  "post_processor": null,
1007  "decoder": null,
1008  "model": {
1009    "type": "WordLevel",
1010    "vocab": {
1011      "<eos>": 7,
1012      "test": 11
1013    },
1014    "unk_token": "<eos>"
1015  }
1016}"#,
1017        )
1018        .unwrap();
1019        fs::write(
1020            root.join("special_tokens_map.json"),
1021            r#"{"eos_token":"<eos>"}"#,
1022        )
1023        .unwrap();
1024    }
1025
1026    fn tiny_gemma_config() -> GemmaConfig {
1027        GemmaConfig {
1028            attention_bias: false,
1029            head_dim: 4,
1030            hidden_activation: candle_nn::Activation::GeluPytorchTanh,
1031            hidden_size: 8,
1032            intermediate_size: 16,
1033            num_attention_heads: 2,
1034            num_hidden_layers: 2,
1035            num_key_value_heads: 1,
1036            rms_norm_eps: 1e-6,
1037            rope_theta: 10_000.0,
1038            rope_local_base_freq: 10_000.0,
1039            vocab_size: 16,
1040            final_logit_softcapping: None,
1041            attn_logit_softcapping: None,
1042            query_pre_attn_scalar: 4,
1043            sliding_window: 4,
1044            sliding_window_pattern: 2,
1045            max_position_embeddings: 1024,
1046        }
1047    }
1048
1049    fn zero_gemma_var_builder(cfg: &GemmaConfig) -> VarBuilder<'static> {
1050        let mut tensors = HashMap::new();
1051        tensors.insert(
1052            "model.embed_tokens.weight".to_string(),
1053            Tensor::zeros((cfg.vocab_size, cfg.hidden_size), DType::F32, &Device::Cpu).unwrap(),
1054        );
1055        for layer in 0..cfg.num_hidden_layers {
1056            for name in [
1057                "self_attn.q_proj",
1058                "self_attn.k_proj",
1059                "self_attn.v_proj",
1060                "self_attn.o_proj",
1061                "mlp.gate_proj",
1062                "mlp.up_proj",
1063                "mlp.down_proj",
1064            ] {
1065                let (rows, cols) = match name {
1066                    "self_attn.q_proj" => (cfg.num_attention_heads * cfg.head_dim, cfg.hidden_size),
1067                    "self_attn.k_proj" | "self_attn.v_proj" => {
1068                        (cfg.num_key_value_heads * cfg.head_dim, cfg.hidden_size)
1069                    }
1070                    "self_attn.o_proj" => (cfg.hidden_size, cfg.num_attention_heads * cfg.head_dim),
1071                    "mlp.gate_proj" | "mlp.up_proj" => (cfg.intermediate_size, cfg.hidden_size),
1072                    "mlp.down_proj" => (cfg.hidden_size, cfg.intermediate_size),
1073                    _ => unreachable!(),
1074                };
1075                tensors.insert(
1076                    format!("model.layers.{layer}.{name}.weight"),
1077                    Tensor::zeros((rows, cols), DType::F32, &Device::Cpu).unwrap(),
1078                );
1079            }
1080            for name in [
1081                "self_attn.q_norm",
1082                "self_attn.k_norm",
1083                "input_layernorm",
1084                "pre_feedforward_layernorm",
1085                "post_feedforward_layernorm",
1086                "post_attention_layernorm",
1087            ] {
1088                let dim = if name.contains("q_norm") || name.contains("k_norm") {
1089                    cfg.head_dim
1090                } else {
1091                    cfg.hidden_size
1092                };
1093                tensors.insert(
1094                    format!("model.layers.{layer}.{name}.weight"),
1095                    Tensor::zeros(dim, DType::F32, &Device::Cpu).unwrap(),
1096                );
1097            }
1098        }
1099        tensors.insert(
1100            "model.norm.weight".to_string(),
1101            Tensor::zeros(cfg.hidden_size, DType::F32, &Device::Cpu).unwrap(),
1102        );
1103        VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu)
1104    }
1105
1106    fn zero_connector_source_var_builder() -> VarBuilder<'static> {
1107        let mut tensors = HashMap::new();
1108        tensors.insert(
1109            "text_embedding_projection.video_aggregate_embed.weight".to_string(),
1110            Tensor::zeros((8, 24), DType::F32, &Device::Cpu).unwrap(),
1111        );
1112        tensors.insert(
1113            "text_embedding_projection.video_aggregate_embed.bias".to_string(),
1114            Tensor::zeros(8, DType::F32, &Device::Cpu).unwrap(),
1115        );
1116        tensors.insert(
1117            "text_embedding_projection.audio_aggregate_embed.weight".to_string(),
1118            Tensor::zeros((4, 24), DType::F32, &Device::Cpu).unwrap(),
1119        );
1120        tensors.insert(
1121            "text_embedding_projection.audio_aggregate_embed.bias".to_string(),
1122            Tensor::zeros(4, DType::F32, &Device::Cpu).unwrap(),
1123        );
1124        for (prefix, dim) in [
1125            ("model.diffusion_model.video_embeddings_connector", 8usize),
1126            ("model.diffusion_model.audio_embeddings_connector", 4usize),
1127        ] {
1128            for linear_name in ["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0"] {
1129                tensors.insert(
1130                    format!("{prefix}.transformer_1d_blocks.0.{linear_name}.weight"),
1131                    Tensor::zeros((dim, dim), DType::F32, &Device::Cpu).unwrap(),
1132                );
1133                tensors.insert(
1134                    format!("{prefix}.transformer_1d_blocks.0.{linear_name}.bias"),
1135                    Tensor::zeros(dim, DType::F32, &Device::Cpu).unwrap(),
1136                );
1137            }
1138            for norm_name in ["attn1.q_norm", "attn1.k_norm"] {
1139                tensors.insert(
1140                    format!("{prefix}.transformer_1d_blocks.0.{norm_name}.weight"),
1141                    Tensor::ones(dim, DType::F32, &Device::Cpu).unwrap(),
1142                );
1143            }
1144            tensors.insert(
1145                format!("{prefix}.transformer_1d_blocks.0.ff.net.0.proj.weight"),
1146                Tensor::zeros((dim * 4, dim), DType::F32, &Device::Cpu).unwrap(),
1147            );
1148            tensors.insert(
1149                format!("{prefix}.transformer_1d_blocks.0.ff.net.0.proj.bias"),
1150                Tensor::zeros(dim * 4, DType::F32, &Device::Cpu).unwrap(),
1151            );
1152            tensors.insert(
1153                format!("{prefix}.transformer_1d_blocks.0.ff.net.2.weight"),
1154                Tensor::zeros((dim, dim * 4), DType::F32, &Device::Cpu).unwrap(),
1155            );
1156            tensors.insert(
1157                format!("{prefix}.transformer_1d_blocks.0.ff.net.2.bias"),
1158                Tensor::zeros(dim, DType::F32, &Device::Cpu).unwrap(),
1159            );
1160            tensors.insert(
1161                format!("{prefix}.learnable_registers"),
1162                Tensor::zeros((128, dim), DType::F32, &Device::Cpu).unwrap(),
1163            );
1164        }
1165        VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu)
1166    }
1167
1168    fn runtime_prompt_encoder() -> NativePromptEncoder {
1169        let cfg = tiny_gemma_config();
1170        let gemma = GemmaHiddenStateEncoder::new(&cfg, zero_gemma_var_builder(&cfg)).unwrap();
1171        NativePromptEncoder::new(
1172            gemma,
1173            build_embeddings_processor(
1174                zero_connector_source_var_builder(),
1175                crate::ltx2::preset::GemmaFeatureExtractorKind::V2DualAv,
1176                cfg.hidden_size,
1177                cfg.num_hidden_layers,
1178                8,
1179                Some(4),
1180                ConnectorSpec {
1181                    prefix: "model.diffusion_model.video_embeddings_connector.",
1182                    num_attention_heads: 2,
1183                    attention_head_dim: 4,
1184                    num_layers: 1,
1185                    apply_gated_attention: false,
1186                    positional_embedding_theta: 10_000.0,
1187                    positional_embedding_max_pos: &[32],
1188                    rope_type: crate::ltx2::model::LtxRopeType::Split,
1189                    double_precision_rope: true,
1190                    num_learnable_registers: Some(128),
1191                },
1192                Some(ConnectorSpec {
1193                    prefix: "model.diffusion_model.audio_embeddings_connector.",
1194                    num_attention_heads: 1,
1195                    attention_head_dim: 4,
1196                    num_layers: 1,
1197                    apply_gated_attention: false,
1198                    positional_embedding_theta: 10_000.0,
1199                    positional_embedding_max_pos: &[32],
1200                    rope_type: crate::ltx2::model::LtxRopeType::Split,
1201                    double_precision_rope: true,
1202                    num_learnable_registers: Some(128),
1203                }),
1204            )
1205            .unwrap(),
1206            PaddingSide::Left,
1207        )
1208    }
1209
1210    fn runtime_session() -> Ltx2RuntimeSession {
1211        let prompt_encoder = runtime_prompt_encoder();
1212        Ltx2RuntimeSession::new(Device::Cpu, prompt_encoder, 0)
1213    }
1214
1215    fn request(output_format: OutputFormat, enable_audio: Option<bool>) -> GenerateRequest {
1216        GenerateRequest {
1217            prompt: "test".to_string(),
1218            negative_prompt: None,
1219            model: "ltx-2-19b-distilled:fp8".to_string(),
1220            width: 960,
1221            height: 576,
1222            steps: 8,
1223            guidance: 3.0,
1224            seed: Some(42),
1225            batch_size: 1,
1226            output_format: Some(output_format),
1227            embed_metadata: None,
1228            scheduler: None,
1229            cfg_plus: None,
1230            source_image: None,
1231            edit_images: None,
1232            strength: 0.75,
1233            mask_image: None,
1234            control_image: None,
1235            control_model: None,
1236            control_scale: 1.0,
1237            expand: None,
1238            original_prompt: None,
1239            lora: None,
1240            frames: Some(17),
1241            fps: Some(12),
1242            upscale_model: None,
1243            gif_preview: true,
1244            enable_audio,
1245            audio_file: None,
1246            audio_file_path: None,
1247            source_video: None,
1248            source_video_path: None,
1249            keyframes: None,
1250            pipeline: None,
1251            loras: None,
1252            retake_range: None,
1253            spatial_upscale: None,
1254            temporal_upscale: None,
1255            placement: None,
1256        }
1257    }
1258
1259    #[test]
1260    fn pipeline_falls_back_to_one_stage_when_spatial_upscaler_missing() {
1261        // Catalog (`cv:*`) LTX-2 single-file checkpoints don't ship the
1262        // spatial upsampler asset (it's a separate Lightricks file the
1263        // companion list doesn't pull). The runtime's TwoStage / Distilled
1264        // paths require it and would bail mid-generation; the engine
1265        // should pick OneStage instead so the user gets a single-pass
1266        // video instead of a 500 several stages in.
1267        let gemma = tempfile::tempdir().unwrap();
1268        let mut paths = dummy_paths_with_gemma_root(gemma.path());
1269        paths.spatial_upscaler = None;
1270
1271        let engine_22b = Ltx2Engine::new(
1272            "cv:2752735".to_string(),
1273            paths.clone(),
1274            LoadStrategy::Sequential,
1275            0,
1276        );
1277        let req = bare_t2v_req("cv:2752735");
1278        assert_eq!(
1279            engine_22b.select_pipeline(&req).unwrap(),
1280            PipelineKind::OneStage,
1281            "no spatial upsampler → OneStage (catalog cv:* default)"
1282        );
1283
1284        let engine_distilled = Ltx2Engine::new(
1285            "ltx-2-19b-distilled:fp8".to_string(),
1286            paths,
1287            LoadStrategy::Sequential,
1288            0,
1289        );
1290        let req_distilled = bare_t2v_req("ltx-2-19b-distilled:fp8");
1291        assert_eq!(
1292            engine_distilled.select_pipeline(&req_distilled).unwrap(),
1293            PipelineKind::OneStage,
1294            "distilled name + missing spatial upsampler → OneStage fallback"
1295        );
1296    }
1297
1298    fn bare_t2v_req(model: &str) -> GenerateRequest {
1299        GenerateRequest {
1300            prompt: "test".to_string(),
1301            negative_prompt: None,
1302            model: model.to_string(),
1303            width: 768,
1304            height: 512,
1305            steps: 4,
1306            guidance: 3.5,
1307            seed: Some(42),
1308            batch_size: 1,
1309            output_format: Some(OutputFormat::Mp4),
1310            embed_metadata: None,
1311            scheduler: None,
1312            cfg_plus: None,
1313            source_image: None,
1314            edit_images: None,
1315            strength: 0.75,
1316            mask_image: None,
1317            control_image: None,
1318            control_model: None,
1319            control_scale: 1.0,
1320            expand: None,
1321            original_prompt: None,
1322            lora: None,
1323            frames: Some(25),
1324            fps: Some(24),
1325            upscale_model: None,
1326            gif_preview: false,
1327            enable_audio: None,
1328            audio_file: None,
1329            audio_file_path: None,
1330            source_video: None,
1331            source_video_path: None,
1332            keyframes: None,
1333            pipeline: None,
1334            loras: None,
1335            retake_range: None,
1336            spatial_upscale: None,
1337            temporal_upscale: None,
1338            placement: None,
1339        }
1340    }
1341
1342    #[test]
1343    fn pipeline_defaults_to_distilled_for_distilled_models() {
1344        let engine = Ltx2Engine::new(
1345            "ltx-2.3-22b-distilled:fp8".to_string(),
1346            dummy_paths(),
1347            LoadStrategy::Sequential,
1348            0,
1349        );
1350        let req = GenerateRequest {
1351            prompt: "test".to_string(),
1352            negative_prompt: None,
1353            model: "ltx-2.3-22b-distilled:fp8".to_string(),
1354            width: 1216,
1355            height: 704,
1356            steps: 8,
1357            guidance: 1.0,
1358            seed: Some(1),
1359            batch_size: 1,
1360            output_format: Some(OutputFormat::Mp4),
1361            embed_metadata: None,
1362            scheduler: None,
1363            cfg_plus: None,
1364            source_image: None,
1365            edit_images: None,
1366            strength: 0.75,
1367            mask_image: None,
1368            control_image: None,
1369            control_model: None,
1370            control_scale: 1.0,
1371            expand: None,
1372            original_prompt: None,
1373            lora: None,
1374            frames: Some(97),
1375            fps: Some(24),
1376            upscale_model: None,
1377            gif_preview: false,
1378            enable_audio: Some(true),
1379            audio_file: None,
1380            audio_file_path: None,
1381            source_video: None,
1382            source_video_path: None,
1383            keyframes: None,
1384            pipeline: None,
1385            loras: None,
1386            retake_range: None,
1387            spatial_upscale: None,
1388            temporal_upscale: None,
1389            placement: None,
1390        };
1391        assert_eq!(
1392            engine.select_pipeline(&req).unwrap(),
1393            PipelineKind::Distilled
1394        );
1395    }
1396
1397    #[test]
1398    fn from_single_file_preserves_companion_paths() {
1399        // Regression: phase-5 wired `cv:*` LTX-2 catalog entries into
1400        // `Ltx2Engine::from_single_file` but the constructor used to build
1401        // a fresh `ModelPaths` with `text_encoder_files: Vec::new()`,
1402        // discarding the Gemma TE companion the catalog bridge had
1403        // resolved. The runtime then bailed at `gemma_root` with
1404        // `LTX-2 requires Gemma text encoder files to be available`.
1405        // Pin the fix: companion fields (text_encoder_files,
1406        // spatial_upscaler, temporal_upscaler, distilled_lora) survive
1407        // the rebuild; only `transformer` and `vae` are overridden.
1408        let temp = tempfile::tempdir().unwrap();
1409        let checkpoint = temp.path().join("ltx2_combined.safetensors");
1410        // Build a minimal valid safetensors header with one transformer
1411        // key + one vae key so `single_file::load` returns has_vae=true.
1412        write_minimal_combined_ltx2_checkpoint(&checkpoint);
1413
1414        let mut input_paths = dummy_paths_with_gemma_root(&temp.path().join("gemma"));
1415        input_paths.transformer = PathBuf::from("/wrong/path-should-be-overridden");
1416        input_paths.vae = PathBuf::from("/wrong/vae-should-be-cleared");
1417        let gemma_files_in = input_paths.text_encoder_files.clone();
1418        let spatial_in = input_paths.spatial_upscaler.clone();
1419        let temporal_in = input_paths.temporal_upscaler.clone();
1420        let distilled_in = input_paths.distilled_lora.clone();
1421
1422        let engine = Ltx2Engine::from_single_file(
1423            "cv:2752735".to_string(),
1424            checkpoint.clone(),
1425            input_paths,
1426            LoadStrategy::Sequential,
1427            0,
1428        )
1429        .expect("from_single_file should succeed on a valid combined checkpoint");
1430
1431        assert_eq!(
1432            engine.paths.transformer, checkpoint,
1433            "transformer must point at the single-file checkpoint"
1434        );
1435        assert_eq!(
1436            engine.paths.vae,
1437            PathBuf::default(),
1438            "vae must be cleared — runtime reads it from the same checkpoint via vb.pp(\"vae\")"
1439        );
1440        assert_eq!(
1441            engine.paths.text_encoder_files, gemma_files_in,
1442            "text_encoder_files (Gemma TE) must survive the rebuild — \
1443             dropping it is the cv:* loading regression"
1444        );
1445        assert_eq!(engine.paths.spatial_upscaler, spatial_in);
1446        assert_eq!(engine.paths.temporal_upscaler, temporal_in);
1447        assert_eq!(engine.paths.distilled_lora, distilled_in);
1448    }
1449
1450    fn write_minimal_combined_ltx2_checkpoint(path: &std::path::Path) {
1451        use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
1452        use std::collections::HashMap;
1453        let zero = 0.0f32.to_le_bytes().to_vec();
1454        let mut tensors: HashMap<String, TensorView<'_>> = HashMap::new();
1455        tensors.insert(
1456            "transformer_blocks.0.attn1.to_q.weight".to_string(),
1457            TensorView::new(SafeDtype::F32, vec![1], &zero).unwrap(),
1458        );
1459        tensors.insert(
1460            "vae.encoder.conv_in.weight".to_string(),
1461            TensorView::new(SafeDtype::F32, vec![1], &zero).unwrap(),
1462        );
1463        serialize_to_file(&tensors, &None, path).unwrap();
1464    }
1465
1466    #[test]
1467    fn camera_control_preset_aliases_are_supported() {
1468        let preset = Ltx2Engine::camera_control_preset("dolly-in").unwrap();
1469        assert_eq!(
1470            preset.filename,
1471            "ltx-2-19b-lora-camera-control-dolly-in.safetensors"
1472        );
1473        assert!(Ltx2Engine::camera_control_preset("unknown").is_none());
1474    }
1475
1476    #[test]
1477    fn fp8_models_use_fp8_cast_quantization() {
1478        let engine = Ltx2Engine::new(
1479            "ltx-2-19b-distilled:fp8".to_string(),
1480            dummy_paths(),
1481            LoadStrategy::Sequential,
1482            0,
1483        );
1484        assert_eq!(engine.request_quantization(), Some("fp8-cast".to_string()));
1485    }
1486
1487    #[test]
1488    fn oom_error_detection_matches_cuda_allocator_strings() {
1489        assert!(Ltx2Engine::is_oom_error(&"CUDA out of memory"));
1490        assert!(Ltx2Engine::is_oom_error(&"cudaErrorMemoryAllocation"));
1491        assert!(!Ltx2Engine::is_oom_error(&"some other error"));
1492    }
1493
1494    #[test]
1495    fn materialized_request_uses_streaming_defaults_for_fp8_smoke_path() {
1496        let gemma_dir = tempfile::tempdir().unwrap();
1497        write_test_gemma_assets(gemma_dir.path());
1498        let engine = Ltx2Engine::new(
1499            "ltx-2-19b-distilled:fp8".to_string(),
1500            dummy_paths_with_gemma_root(gemma_dir.path()),
1501            LoadStrategy::Sequential,
1502            0,
1503        );
1504        let req = GenerateRequest {
1505            prompt: "test".to_string(),
1506            negative_prompt: None,
1507            model: "ltx-2-19b-distilled:fp8".to_string(),
1508            width: 960,
1509            height: 576,
1510            steps: 8,
1511            guidance: 3.0,
1512            seed: Some(42),
1513            batch_size: 1,
1514            output_format: Some(OutputFormat::Mp4),
1515            embed_metadata: None,
1516            scheduler: None,
1517            cfg_plus: None,
1518            source_image: None,
1519            edit_images: None,
1520            strength: 0.75,
1521            mask_image: None,
1522            control_image: None,
1523            control_model: None,
1524            control_scale: 1.0,
1525            expand: None,
1526            original_prompt: None,
1527            lora: None,
1528            frames: Some(17),
1529            fps: Some(12),
1530            upscale_model: None,
1531            gif_preview: false,
1532            enable_audio: Some(true),
1533            audio_file: None,
1534            audio_file_path: None,
1535            source_video: None,
1536            source_video_path: None,
1537            keyframes: None,
1538            pipeline: None,
1539            loras: None,
1540            retake_range: None,
1541            spatial_upscale: None,
1542            temporal_upscale: None,
1543            placement: None,
1544        };
1545        let temp_dir = tempfile::tempdir().unwrap();
1546        let bridge = engine
1547            .materialize_request(&req, temp_dir.path(), &temp_dir.path().join("out.mp4"))
1548            .unwrap();
1549        assert_eq!(bridge.quantization.as_deref(), Some("fp8-cast"));
1550        assert_eq!(bridge.streaming_prefetch_count, Some(2));
1551        assert_eq!(bridge.width, 960);
1552        assert_eq!(bridge.height, 576);
1553        assert_eq!(bridge.num_frames, 17);
1554        assert_eq!(bridge.frame_rate, 12);
1555        assert_eq!(bridge.prompt_tokens.conditional.len(), 256);
1556        assert_eq!(bridge.prompt_tokens.conditional.valid_len(), 1);
1557        assert_eq!(bridge.prompt_tokens.pad_token_id, 7);
1558    }
1559
1560    #[test]
1561    fn load_uses_native_asset_checks_without_upstream_checkout() {
1562        let temp_dir = tempfile::tempdir().unwrap();
1563        let gemma_dir = temp_dir.path().join("gemma");
1564        fs::create_dir_all(&gemma_dir).unwrap();
1565        write_test_gemma_assets(&gemma_dir);
1566        let paths = dummy_paths_in(temp_dir.path(), &gemma_dir);
1567        fs::write(&paths.transformer, []).unwrap();
1568
1569        let mut engine = Ltx2Engine::new(
1570            "ltx-2-19b-distilled:fp8".to_string(),
1571            paths,
1572            LoadStrategy::Sequential,
1573            0,
1574        );
1575
1576        engine.load().unwrap();
1577        assert!(engine.is_loaded());
1578    }
1579
1580    #[test]
1581    fn ltx2_unload_drops_runtime_and_requests_cuda_reclaim() {
1582        let mut engine = Ltx2Engine::with_runtime_session(
1583            "ltx-2-19b-distilled:fp8".to_string(),
1584            dummy_paths(),
1585            Ltx2RuntimeSession::new_deferred_cuda(runtime_prompt_encoder(), 3),
1586        );
1587        engine.loaded = true;
1588        engine.gpu_ordinal = 3;
1589
1590        assert_eq!(engine.unload_runtime_state(), Some(3));
1591        assert!(!engine.loaded);
1592        assert!(engine.native_runtime.is_none());
1593    }
1594
1595    #[test]
1596    fn ltx2_unload_cpu_runtime_skips_cuda_reclaim() {
1597        let mut engine = Ltx2Engine::with_runtime_session(
1598            "ltx-2-19b-distilled:fp8".to_string(),
1599            dummy_paths(),
1600            runtime_session(),
1601        );
1602        engine.loaded = true;
1603
1604        assert_eq!(engine.unload_runtime_state(), None);
1605        assert!(!engine.loaded);
1606        assert!(engine.native_runtime.is_none());
1607    }
1608
1609    #[test]
1610    fn generate_runs_native_runtime_without_bridge_process() {
1611        let temp_dir = tempfile::tempdir().unwrap();
1612        let gemma_dir = temp_dir.path().join("gemma");
1613        fs::create_dir_all(&gemma_dir).unwrap();
1614        write_test_gemma_assets(&gemma_dir);
1615        let paths = dummy_paths_in(temp_dir.path(), &gemma_dir);
1616        fs::write(&paths.transformer, []).unwrap();
1617
1618        let mut engine = Ltx2Engine::with_runtime_session(
1619            "ltx-2-19b-distilled:fp8".to_string(),
1620            paths,
1621            runtime_session(),
1622        );
1623        let response = engine
1624            .generate(&request(OutputFormat::Gif, Some(false)))
1625            .unwrap();
1626        let video = response.video.unwrap();
1627
1628        assert_eq!(&video.data[..6], b"GIF89a");
1629        assert_eq!(&video.thumbnail[..8], b"\x89PNG\r\n\x1a\n");
1630        assert_eq!(&video.gif_preview[..6], b"GIF89a");
1631        assert_eq!(video.width, 960);
1632        assert_eq!(video.height, 576);
1633        assert_eq!(video.frames, 17);
1634        assert_eq!(video.fps, 12);
1635        assert!(!video.has_audio);
1636        assert!(engine.native_runtime.is_none());
1637    }
1638
1639    #[test]
1640    fn render_chain_stage_rejects_non_distilled_pipeline() {
1641        // A model name without "distilled" in it selects `PipelineKind::TwoStage`
1642        // via `select_pipeline`, which must be rejected up-front by the chain
1643        // entry point before any runtime work happens.
1644        let mut engine = Ltx2Engine::with_runtime_session(
1645            "ltx-2-19b:fp8".to_string(),
1646            dummy_paths(),
1647            runtime_session(),
1648        );
1649        engine.loaded = true;
1650        let req = request(OutputFormat::Mp4, Some(false));
1651        let err = engine
1652            .render_chain_stage(&req, None, 4)
1653            .expect_err("must fail on non-distilled pipeline");
1654        let msg = format!("{err}");
1655        assert!(
1656            msg.contains("distilled"),
1657            "error must name the pipeline constraint, got: {msg}",
1658        );
1659    }
1660
1661    #[test]
1662    fn render_chain_stage_rejects_zero_motion_tail() {
1663        // Zero-frame motion tail is nonsensical — it would narrow nothing off
1664        // for the next stage. Fast-fail before any allocation.
1665        let mut engine = Ltx2Engine::with_runtime_session(
1666            "ltx-2-19b-distilled:fp8".to_string(),
1667            dummy_paths(),
1668            runtime_session(),
1669        );
1670        engine.loaded = true;
1671        let req = request(OutputFormat::Mp4, Some(false));
1672        let err = engine
1673            .render_chain_stage(&req, None, 0)
1674            .expect_err("must fail on zero motion tail");
1675        let msg = format!("{err}");
1676        assert!(
1677            msg.contains("motion_tail_pixel_frames"),
1678            "error must name the motion_tail constraint, got: {msg}",
1679        );
1680    }
1681
1682    /// CPU transformer → encoder pinned to the same device. The auto resolver
1683    /// must short-circuit before probing GPUs (which on a CUDA-less host
1684    /// would still pick CPU, but on a CUDA host must not place a 23 GB
1685    /// encoder on a card the transformer chose to skip).
1686    #[test]
1687    fn resolve_prompt_encoder_device_keeps_cpu_when_transformer_is_cpu() {
1688        let prior_main = std::env::var_os("MOLD_LTX2_GEMMA_DEVICE");
1689        let prior_legacy = std::env::var_os("MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER");
1690        unsafe {
1691            std::env::remove_var("MOLD_LTX2_GEMMA_DEVICE");
1692            std::env::remove_var("MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER");
1693        }
1694
1695        let resolved = resolve_prompt_encoder_device(&Device::Cpu, 0);
1696        assert!(resolved.is_cpu());
1697
1698        unsafe {
1699            if let Some(v) = prior_main {
1700                std::env::set_var("MOLD_LTX2_GEMMA_DEVICE", v);
1701            }
1702            if let Some(v) = prior_legacy {
1703                std::env::set_var("MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER", v);
1704            }
1705        }
1706    }
1707
1708    /// `MOLD_LTX2_GEMMA_DEVICE=cpu` pins the encoder to CPU even when the
1709    /// transformer device is CUDA-shaped. We exercise this through the
1710    /// device-level resolver because the runtime path needs the same
1711    /// decision the load path will make and constructing a real CUDA
1712    /// device in CI isn't possible.
1713    #[test]
1714    fn resolver_picks_cpu_when_env_pins_cpu() {
1715        let prior_main = std::env::var_os("MOLD_LTX2_GEMMA_DEVICE");
1716        let prior_legacy = std::env::var_os("MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER");
1717        unsafe {
1718            std::env::remove_var("MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER");
1719            std::env::set_var("MOLD_LTX2_GEMMA_DEVICE", "cpu");
1720        }
1721        assert_eq!(
1722            crate::device::resolve_ltx2_gemma_placement(0),
1723            crate::device::LtxGemmaPlacement::Cpu,
1724        );
1725        unsafe {
1726            std::env::remove_var("MOLD_LTX2_GEMMA_DEVICE");
1727            if let Some(v) = prior_main {
1728                std::env::set_var("MOLD_LTX2_GEMMA_DEVICE", v);
1729            }
1730            if let Some(v) = prior_legacy {
1731                std::env::set_var("MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER", v);
1732            }
1733        }
1734    }
1735}