1#![allow(clippy::type_complexity)]
2
3use anyhow::{bail, Context, Result};
4use candle_core::Device;
5use mold_core::{
6 GenerateRequest, GenerateResponse, Ltx2PipelineMode, ModelPaths, OutputFormat, VideoData,
7};
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::time::Instant;
11
12use super::assets;
13use super::backend::Ltx2Backend;
14use super::chain::{ChainStageRenderer, ChainTail, StageOutcome, StageProgressEvent};
15use super::conditioning::{self, StagedLatent};
16use super::execution;
17use super::lora;
18use super::media::{self, ProbeMetadata};
19use super::plan::{Ltx2GeneratePlan, PipelineKind};
20use super::preset;
21use super::runtime::{Ltx2RuntimeSession, NativeRenderedVideo};
22use super::text::gemma::GemmaAssets;
23use super::text::prompt_encoder::NativePromptEncoder;
24use crate::engine::{gpu_dtype, rand_seed, InferenceEngine, LoadStrategy};
25use crate::ltx_video::video_enc;
26use crate::progress::ProgressCallback;
27
/// Conditioning strength assigned to a stage's frame-0 image when
/// `render_chain_stage` re-targets it as an anchor during a chained
/// continuation. It is lower than the `1.0` strength given to the carried
/// tail latent in the same function.
const CHAIN_SOFT_ANCHOR_STRENGTH: f32 = 0.4;
35
/// Native LTX-2 video generation engine.
///
/// Holds the model identity and asset paths together with an optionally
/// cached runtime session, which `generate_inner` and `render_chain_stage`
/// take out of `native_runtime` while they run.
pub struct Ltx2Engine {
    /// Model identifier; also searched for the `"distilled"` substring when a
    /// pipeline has to be inferred.
    model_name: String,
    /// Locations of the checkpoint and auxiliary assets.
    paths: ModelPaths,
    /// Set by `load` once its asset checks pass; cleared when runtime state is
    /// unloaded.
    loaded: bool,
    /// Cached runtime session, if one is currently held.
    native_runtime: Option<Ltx2RuntimeSession>,
    /// Optional sink for the progress events produced by `emit` / `info`.
    on_progress: Option<ProgressCallback>,
    /// Per-request device placement; populated only for the duration of a
    /// `generate` call.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// CUDA ordinal used when opening the device and when reclaiming memory.
    gpu_ordinal: usize,
    /// Model-version value read from a single-file checkpoint and forwarded to
    /// preset resolution; `None` for engines built with `new`.
    preset_hint: Option<String>,
}
55
56impl Ltx2Engine {
57 fn debug_timings_enabled() -> bool {
58 std::env::var_os("MOLD_LTX2_DEBUG_TIMINGS").is_some()
59 }
60
61 fn log_timing(label: &str, start: Instant) {
62 if Self::debug_timings_enabled() {
63 eprintln!(
64 "[ltx2-timing] {label} {:.3}s",
65 start.elapsed().as_secs_f64()
66 );
67 }
68 }
69
    /// Creates an engine for `model_name` backed by the given asset paths.
    ///
    /// The engine starts unloaded, with no cached runtime session, no progress
    /// callback, no pending placement and no preset hint. `_load_strategy` is
    /// currently ignored by this constructor.
    pub fn new(
        model_name: String,
        paths: ModelPaths,
        _load_strategy: LoadStrategy,
        gpu_ordinal: usize,
    ) -> Self {
        Self {
            model_name,
            paths,
            loaded: false,
            native_runtime: None,
            on_progress: None,
            pending_placement: None,
            gpu_ordinal,
            preset_hint: None,
        }
    }
87
88 pub fn from_single_file(
111 model_name: String,
112 checkpoint: PathBuf,
113 paths: ModelPaths,
114 load_strategy: LoadStrategy,
115 gpu_ordinal: usize,
116 ) -> anyhow::Result<Self> {
117 if !checkpoint.exists() {
118 anyhow::bail!(
119 "single-file LTX-2 checkpoint not found: {}",
120 checkpoint.display()
121 );
122 }
123
124 let bundle = super::single_file::load(&checkpoint).map_err(|e| {
125 anyhow::anyhow!(
126 "failed to parse single-file LTX-2 checkpoint {}: {e}",
127 checkpoint.display()
128 )
129 })?;
130
131 if !bundle.has_vae {
132 anyhow::bail!(
133 "LTX-2 checkpoint {} contains no VAE weights (`vae.*` keys). \
134 This appears to be a transformer-only fine-tune. \
135 The LTX-2 runtime requires a combined transformer+VAE checkpoint. \
136 Phase-5 does not yet support separate-VAE loading for LTX-2.",
137 checkpoint.display()
138 );
139 }
140
141 let paths = ModelPaths {
147 transformer: checkpoint,
148 transformer_shards: Vec::new(),
149 vae: PathBuf::default(),
150 ..paths
151 };
152
153 let mut engine = Self::new(model_name, paths, load_strategy, gpu_ordinal);
154 engine.preset_hint = bundle.model_version;
160 Ok(engine)
161 }
162
    /// Test-only constructor that starts with `runtime` already stored as the
    /// cached session. The engine is otherwise in the same initial state as
    /// one built by `new`, with GPU ordinal `0`.
    #[cfg(test)]
    fn with_runtime_session(
        model_name: String,
        paths: ModelPaths,
        runtime: Ltx2RuntimeSession,
    ) -> Self {
        Self {
            model_name,
            paths,
            loaded: false,
            native_runtime: Some(runtime),
            on_progress: None,
            pending_placement: None,
            gpu_ordinal: 0,
            preset_hint: None,
        }
    }
180
181 fn emit(&self, stage: &str) {
182 if let Some(callback) = &self.on_progress {
183 callback(crate::ProgressEvent::StageStart {
184 name: stage.to_string(),
185 });
186 }
187 }
188
189 fn info(&self, message: &str) {
190 if let Some(callback) = &self.on_progress {
191 callback(crate::ProgressEvent::Info {
192 message: message.to_string(),
193 });
194 }
195 }
196
197 fn is_oom_error(err: &impl std::fmt::Display) -> bool {
198 let msg = err.to_string().to_ascii_lowercase();
199 msg.contains("out of memory")
200 || msg.contains("out_of_memory")
201 || msg.contains("cudaerrormemoryallocation")
202 }
203
204 fn unload_runtime_state(&mut self) -> Option<usize> {
205 self.loaded = false;
206 let should_reclaim = self
207 .native_runtime
208 .as_ref()
209 .is_some_and(Ltx2RuntimeSession::needs_cuda_reclaim_on_unload);
210 self.native_runtime = None;
211 should_reclaim.then_some(self.gpu_ordinal)
212 }
213
    /// Resolves the Gemma asset root for this engine's paths by delegating to
    /// `assets::gemma_root`.
    fn gemma_root(&self) -> Result<PathBuf> {
        assets::gemma_root(&self.paths)
    }
217
218 fn select_pipeline(&self, req: &GenerateRequest) -> Result<PipelineKind> {
219 if let Some(mode) = req.pipeline {
220 return Ok(match mode {
221 Ltx2PipelineMode::OneStage => PipelineKind::OneStage,
222 Ltx2PipelineMode::TwoStage => PipelineKind::TwoStage,
223 Ltx2PipelineMode::TwoStageHq => PipelineKind::TwoStageHq,
224 Ltx2PipelineMode::Distilled => PipelineKind::Distilled,
225 Ltx2PipelineMode::IcLora => PipelineKind::IcLora,
226 Ltx2PipelineMode::Keyframe => PipelineKind::Keyframe,
227 Ltx2PipelineMode::A2Vid => PipelineKind::A2Vid,
228 Ltx2PipelineMode::Retake => PipelineKind::Retake,
229 });
230 }
231
232 if req.retake_range.is_some() {
233 return Ok(PipelineKind::Retake);
234 }
235 if req.audio_file.is_some() || req.audio_file_path.is_some() {
236 return Ok(PipelineKind::A2Vid);
237 }
238 if req.keyframes.as_ref().is_some_and(|items| items.len() > 1) {
239 return Ok(PipelineKind::Keyframe);
240 }
241 if req.source_video.is_some() || req.source_video_path.is_some() {
242 return Ok(PipelineKind::IcLora);
243 }
244 if self.model_name.contains("distilled") {
245 return Ok(if self.paths.spatial_upscaler.is_some() {
250 PipelineKind::Distilled
251 } else {
252 PipelineKind::OneStage
253 });
254 }
255 Ok(if self.paths.spatial_upscaler.is_some() {
261 PipelineKind::TwoStage
262 } else {
263 PipelineKind::OneStage
264 })
265 }
266
    /// Returns the quantization value that `assets::request_quantization`
    /// derives from the model name, if any.
    fn request_quantization(&self) -> Option<String> {
        assets::request_quantization(&self.model_name)
    }
270
    /// Looks up a camera-control preset by `name` via
    /// `lora::camera_control_preset`; `None` when the lookup yields nothing.
    #[allow(dead_code)]
    fn camera_control_preset(name: &str) -> Option<lora::CameraControlPreset> {
        lora::camera_control_preset(name)
    }
275
    /// Turns a generation request into a fully resolved [`Ltx2GeneratePlan`].
    ///
    /// Selects the pipeline, tokenizes the prompt pair with the Gemma assets,
    /// stages conditioning inputs under `work_dir`, resolves LoRAs, the model
    /// preset and the optional upscaler paths, and builds the execution graph.
    /// `output_path` is recorded in the plan as a string.
    ///
    /// Request fields that are absent get defaults here: a random seed,
    /// 97 frames and 24 fps.
    ///
    /// # Errors
    /// Propagates any failure from pipeline selection, asset discovery,
    /// prompt encoding, conditioning staging, LoRA/preset resolution or
    /// upscaler path resolution.
    pub(crate) fn materialize_request(
        &self,
        req: &GenerateRequest,
        work_dir: &Path,
        output_path: &Path,
    ) -> Result<Ltx2GeneratePlan> {
        let pipeline = self.select_pipeline(req)?;
        let gemma_root = self.gemma_root()?;
        let prompt_tokens = GemmaAssets::discover(&gemma_root)?
            .encode_prompt_pair(&req.prompt, req.negative_prompt.as_deref())?;
        let conditioning = conditioning::stage_conditioning(req, work_dir)?;
        let loras = lora::resolve_loras(&self.model_name, req)?;
        let preset =
            preset::preset_for_model_with_hint(&self.model_name, self.preset_hint.as_deref())?;
        let execution_graph =
            execution::build_execution_graph(req, pipeline, &conditioning, &preset, loras.len());
        let spatial_upsampler_path = assets::resolve_spatial_upscaler_path(
            &self.model_name,
            &self.paths,
            req.spatial_upscale,
        )?
        .map(|path| path.to_string_lossy().to_string());
        let temporal_upsampler_path =
            assets::resolve_temporal_upscaler_path(&self.paths, req.temporal_upscale)?
                .map(|path| path.to_string_lossy().to_string());

        Ok(Ltx2GeneratePlan {
            pipeline,
            preset,
            checkpoint_is_distilled: self.model_name.contains("distilled"),
            execution_graph,
            checkpoint_path: self.paths.transformer.to_string_lossy().to_string(),
            // Only pipelines that require a distilled checkpoint get this
            // path; it is the same transformer file as `checkpoint_path`.
            distilled_checkpoint_path: pipeline
                .requires_distilled_checkpoint()
                .then(|| self.paths.transformer.to_string_lossy().to_string()),
            distilled_lora_path: self
                .paths
                .distilled_lora
                .as_ref()
                .map(|path| path.to_string_lossy().to_string()),
            spatial_upsampler_path,
            temporal_upsampler_path,
            gemma_root: gemma_root.to_string_lossy().to_string(),
            output_path: output_path.to_string_lossy().to_string(),
            prompt: req.prompt.clone(),
            negative_prompt: req.negative_prompt.clone(),
            prompt_tokens,
            seed: req.seed.unwrap_or_else(rand_seed),
            width: req.width,
            height: req.height,
            num_frames: req.frames.unwrap_or(97),
            frame_rate: req.fps.unwrap_or(24),
            num_inference_steps: req.steps,
            guidance: req.guidance,
            quantization: self.request_quantization(),
            streaming_prefetch_count: Some(preset.streaming_prefetch_count),
            conditioning,
            loras,
            retake_range: req.retake_range.clone(),
            spatial_upscale: req.spatial_upscale,
            temporal_upscale: req.temporal_upscale,
        })
    }
339
    /// Probes `input_video` for metadata by delegating to
    /// `media::probe_video`.
    fn probe_video(&self, input_video: &Path) -> Result<ProbeMetadata> {
        media::probe_video(input_video)
    }
343
344 fn native_device_for_backend(&self, backend: Ltx2Backend) -> Result<Device> {
345 match backend {
346 Ltx2Backend::Cuda => {
347 self.info("CUDA detected, using native LTX-2 GPU path");
348 let device = Device::new_cuda(self.gpu_ordinal)?;
349 configure_native_ltx2_cuda_device(&device)?;
350 Ok(device)
351 }
352 Ltx2Backend::Cpu => {
353 let forced_cpu = std::env::var("MOLD_DEVICE")
354 .map(|value| value.eq_ignore_ascii_case("cpu"))
355 .unwrap_or(false);
356 if forced_cpu {
357 self.info("CPU forced via MOLD_DEVICE=cpu for native LTX-2");
358 } else {
359 self.info("No CUDA detected; using native LTX-2 CPU fallback");
360 }
361 Ok(Device::Cpu)
362 }
363 Ltx2Backend::Metal => unreachable!("unsupported Metal backend should have errored"),
364 }
365 }
366
367 fn load_runtime_session_on_device(
368 &self,
369 plan: &Ltx2GeneratePlan,
370 device: Device,
371 ) -> Result<Ltx2RuntimeSession> {
372 let load_start = Instant::now();
373 let prompt_device = resolve_prompt_encoder_device(&device, self.gpu_ordinal);
374 log_prompt_encoder_placement(&device, &prompt_device);
375 let dtype = gpu_dtype(&prompt_device);
376 self.emit("Loading native LTX-2 prompt encoder");
377 let prompt_encoder = NativePromptEncoder::load(
378 Path::new(&plan.gemma_root),
379 Path::new(&plan.checkpoint_path),
380 &plan.preset,
381 &prompt_device,
382 dtype,
383 )?;
384 Self::log_timing("pipeline.create_runtime.load_prompt_encoder", load_start);
385 let same_device = device.same_device(&prompt_device);
391 if prompt_device.is_cuda() && same_device {
392 Ok(Ltx2RuntimeSession::new_deferred_cuda(
393 prompt_encoder,
394 self.gpu_ordinal,
395 ))
396 } else {
397 Ok(Ltx2RuntimeSession::new(
398 device,
399 prompt_encoder,
400 self.gpu_ordinal,
401 ))
402 }
403 }
404
405 fn create_runtime_session(&self, plan: &Ltx2GeneratePlan) -> Result<Ltx2RuntimeSession> {
406 let backend = Ltx2Backend::detect();
407 backend.ensure_supported()?;
408
409 let tier1 = self.pending_placement.as_ref().map(|p| p.text_encoders);
413 let device =
414 crate::device::resolve_device(tier1, || self.native_device_for_backend(backend))?;
415 if device.is_cuda() {
416 configure_native_ltx2_cuda_device(&device)?;
417 }
418 let override_is_auto = matches!(tier1, None | Some(mold_core::types::DeviceRef::Auto));
422 match self.load_runtime_session_on_device(plan, device) {
423 Ok(runtime) => Ok(runtime),
424 Err(err)
425 if matches!(backend, Ltx2Backend::Cuda)
426 && override_is_auto
427 && Self::is_oom_error(&err) =>
428 {
429 self.info(
430 "Native LTX-2 prompt path ran out of CUDA memory; retrying with CPU fallback",
431 );
432 crate::device::reclaim_gpu_memory(self.gpu_ordinal);
433 self.load_runtime_session_on_device(plan, Device::Cpu)
434 }
435 Err(err) => Err(err),
436 }
437 }
438
    /// Encodes rendered frames (and audio, for MP4) into the requested output
    /// container.
    ///
    /// Returns a tuple of:
    /// 1. the encoded output bytes in the request's resolved output format,
    /// 2. a PNG of the first frame,
    /// 3. a GIF preview — empty unless `req.gif_preview` is set; when the
    ///    output itself is a GIF the output bytes are cloned instead of being
    ///    encoded a second time,
    /// 4. probe metadata for the written file — `Some` only for MP4 output.
    ///
    /// Intermediate files are written under `work_dir`.
    ///
    /// # Errors
    /// Fails on any encoding, muxing, probing or file-system error, when the
    /// requested format needs a cargo feature (`webp`, `mp4`) that is not
    /// compiled in, and for output formats other than APNG, GIF, WebP and MP4.
    fn encode_native_video(
        &self,
        req: &GenerateRequest,
        plan: &Ltx2GeneratePlan,
        rendered: &NativeRenderedVideo,
        work_dir: &Path,
    ) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>, Option<ProbeMetadata>)> {
        // NOTE(review): this WAV file is written but not read back anywhere in
        // this function; the MP4 branch muxes straight from the samples.
        // Confirm whether it is still needed.
        if let Some(audio_track) = rendered.audio_track.as_ref() {
            let wav_path = work_dir.join("native-audio.wav");
            fs::write(
                &wav_path,
                media::encode_wav_f32_interleaved(
                    &audio_track.interleaved_samples,
                    audio_track.sample_rate,
                    audio_track.channels,
                )?,
            )?;
        }

        let output_encode_start = Instant::now();
        let output_bytes = match req.resolved_output_format() {
            OutputFormat::Apng => {
                let metadata = video_enc::VideoMetadata {
                    prompt: req.prompt.clone(),
                    model: self.model_name.clone(),
                    seed: plan.seed,
                    steps: req.steps,
                    guidance: req.guidance,
                    width: plan.width,
                    height: plan.height,
                    frames: plan.num_frames,
                    fps: plan.frame_rate,
                };
                video_enc::encode_apng(&rendered.frames, plan.frame_rate, Some(&metadata))?
            }
            OutputFormat::Gif => video_enc::encode_gif(&rendered.frames, plan.frame_rate)?,
            #[cfg(feature = "webp")]
            OutputFormat::Webp => video_enc::encode_webp(&rendered.frames, plan.frame_rate)?,
            #[cfg(not(feature = "webp"))]
            OutputFormat::Webp => bail!("WebP output requires the 'webp' feature"),
            OutputFormat::Mp4 => {
                #[cfg(feature = "mp4")]
                {
                    // Encode a video-only MP4 first; if an audio track exists,
                    // mux it into a second file and return that instead.
                    let video_only = video_enc::encode_mp4(&rendered.frames, plan.frame_rate)?;
                    let mp4_path = work_dir.join("native-video.mp4");
                    fs::write(&mp4_path, &video_only)?;
                    if let Some(audio_track) = rendered.audio_track.as_ref() {
                        let muxed_path = work_dir.join("native-video-audio.mp4");
                        media::attach_aac_track_from_f32_interleaved(
                            &mp4_path,
                            &muxed_path,
                            &audio_track.interleaved_samples,
                            audio_track.sample_rate,
                            audio_track.channels,
                        )?;
                        fs::read(muxed_path)?
                    } else {
                        video_only
                    }
                }
                #[cfg(not(feature = "mp4"))]
                {
                    bail!("MP4 output requires the 'mp4' feature")
                }
            }
            other => bail!("{other:?} is not supported for LTX-2 video output"),
        };
        Self::log_timing("pipeline.encode_output", output_encode_start);

        let thumbnail_start = Instant::now();
        let thumbnail = video_enc::first_frame_png(&rendered.frames)?;
        Self::log_timing("pipeline.encode_thumbnail", thumbnail_start);
        let gif_preview_start = Instant::now();
        let gif_preview = if req.gif_preview {
            if req.resolved_output_format() == OutputFormat::Gif {
                output_bytes.clone()
            } else {
                video_enc::encode_gif(&rendered.frames, plan.frame_rate)?
            }
        } else {
            Vec::new()
        };
        Self::log_timing("pipeline.encode_gif_preview", gif_preview_start);

        // Probing works on a file, so the final MP4 bytes are written out once
        // more under a dedicated name.
        let probe_start = Instant::now();
        let probe = if req.resolved_output_format() == OutputFormat::Mp4 {
            let path = work_dir.join("probe.mp4");
            fs::write(&path, &output_bytes)?;
            Some(self.probe_video(&path)?)
        } else {
            None
        };
        Self::log_timing("pipeline.probe_output", probe_start);

        Ok((output_bytes, thumbnail, gif_preview, probe))
    }
535}
536
/// Turns off CUDA event tracking on `device` if it is currently enabled.
///
/// Returns `Ok(())` without doing anything for non-CUDA devices and in builds
/// without the `cuda` feature (where `device` is unused, hence the
/// conditional `allow`).
///
/// # Errors
/// Fails if the CUDA handle cannot be obtained from `device`.
#[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
fn configure_native_ltx2_cuda_device(device: &Device) -> Result<()> {
    #[cfg(feature = "cuda")]
    if device.is_cuda() {
        let cuda = device.as_cuda_device()?;
        if cuda.is_event_tracking() {
            // SAFETY(review): the invariant that makes disabling event
            // tracking sound is not visible in this file — presumably all work
            // on this device is issued in a way that does not rely on
            // event-based synchronization. TODO confirm and state it here.
            unsafe {
                cuda.disable_event_tracking();
            }
        }
    }
    Ok(())
}
553
554impl Ltx2Engine {
    /// Runs one complete generation: plan, prepare, render and encode.
    ///
    /// Loads the engine first if needed, materializes the request into a plan
    /// inside a fresh temporary directory, obtains a runtime session (the
    /// cached one when it can be reused for the plan, otherwise a new one),
    /// renders the video and encodes it to the requested format.
    ///
    /// Reported dimensions, frame count, fps and duration come from the
    /// probed output when available and fall back to the plan's values.
    /// Audio fields are only populated for MP4 output.
    ///
    /// # Errors
    /// Propagates failures from loading, planning, runtime creation,
    /// preparation, rendering and encoding.
    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
        if !self.loaded {
            self.load()?;
        }
        let start = Instant::now();
        self.emit("Preparing native LTX-2 request");

        let work_dir = tempfile::tempdir().context("failed to create LTX-2 temp directory")?;
        let native_output = work_dir.path().join("ltx2-native-output.mp4");
        let materialize_start = Instant::now();
        let plan = self.materialize_request(req, work_dir.path(), &native_output)?;
        Self::log_timing("pipeline.materialize_request", materialize_start);
        let planned_stage_count = plan.execution_graph.denoise_passes.len();
        self.emit(&format!(
            "Planned native LTX-2 graph: preset={}, denoise_stages={}, blocks={}, prompt_tokens={}/{}",
            plan.preset.name,
            planned_stage_count,
            plan.execution_graph.blocks.len(),
            plan.prompt_tokens.conditional.valid_len(),
            plan.prompt_tokens.unconditional.valid_len()
        ));
        let create_runtime_start = Instant::now();
        // NOTE(review): the session taken here is never stored back into
        // `self.native_runtime` in this function, so it is dropped on return,
        // whereas `render_chain_stage` restores it. Confirm this is the
        // intended lifetime for one-shot generation.
        let mut runtime = match self.native_runtime.take() {
            Some(runtime) if runtime.can_reuse_for(&plan) => runtime,
            _ => self.create_runtime_session(&plan)?,
        };
        Self::log_timing("pipeline.create_runtime", create_runtime_start);

        self.emit("Encoding prompt and preparing native LTX-2 runtime state");
        let prepare_start = Instant::now();
        let prepared = runtime.prepare(&plan)?;
        Self::log_timing("pipeline.prepare_runtime", prepare_start);
        self.emit("Executing native LTX-2 runtime");
        let render_start = Instant::now();
        let rendered = runtime.render_native_video(&plan, &prepared, self.on_progress.as_ref())?;
        Self::log_timing("pipeline.render_runtime", render_start);
        let encode_start = Instant::now();
        let (output_bytes, thumbnail_bytes, gif_preview, probe) =
            self.encode_native_video(req, &plan, &rendered, work_dir.path())?;
        Self::log_timing("pipeline.encode_native_video", encode_start);
        // Planned duration in milliseconds, rounded up; the divisor is clamped
        // to at least 1 so a zero frame rate cannot divide by zero.
        let duration_ms =
            Some((plan.num_frames as u64 * 1000).div_ceil(plan.frame_rate.max(1) as u64));
        let width = probe
            .as_ref()
            .map(|probe| probe.width)
            .unwrap_or(plan.width);
        let height = probe
            .as_ref()
            .map(|probe| probe.height)
            .unwrap_or(plan.height);
        let frames = probe
            .as_ref()
            .and_then(|probe| probe.frames)
            .unwrap_or(plan.num_frames);
        let fps = probe
            .as_ref()
            .map(|probe| probe.fps)
            .unwrap_or(plan.frame_rate);
        // Audio metadata is only reported for MP4; other formats always
        // report no audio.
        let has_audio = if req.resolved_output_format() == OutputFormat::Mp4 {
            probe
                .as_ref()
                .map(|probe| probe.has_audio)
                .unwrap_or(rendered.has_audio)
        } else {
            false
        };
        let audio_sample_rate = if req.resolved_output_format() == OutputFormat::Mp4 {
            probe
                .as_ref()
                .and_then(|probe| probe.audio_sample_rate)
                .or(rendered.audio_sample_rate)
        } else {
            None
        };
        let audio_channels = if req.resolved_output_format() == OutputFormat::Mp4 {
            probe
                .as_ref()
                .and_then(|probe| probe.audio_channels)
                .or(rendered.audio_channels)
        } else {
            None
        };

        Ok(GenerateResponse {
            images: vec![],
            video: Some(VideoData {
                data: output_bytes,
                format: req.resolved_output_format(),
                width,
                height,
                frames,
                fps,
                thumbnail: thumbnail_bytes,
                gif_preview,
                has_audio,
                duration_ms: probe
                    .as_ref()
                    .and_then(|probe| probe.duration_ms)
                    .or(duration_ms),
                audio_sample_rate,
                audio_channels,
            }),
            generation_time_ms: start.elapsed().as_millis() as u64,
            model: self.model_name.clone(),
            seed_used: plan.seed,
            gpu: None,
        })
    }
670
    /// Renders one stage of a chained generation and returns its frames plus
    /// the tail to carry into the next stage.
    ///
    /// * `req` — the request for this stage; it must resolve to the distilled
    ///   pipeline.
    /// * `carry` — the previous stage's tail, if any. When present, its RGB
    ///   frames are added as a latent conditioning entry at frame 0 with
    ///   strength `1.0`, and any staged conditioning image at frame 0 is moved
    ///   to frame `motion_tail_pixel_frames` with
    ///   `CHAIN_SOFT_ANCHOR_STRENGTH`.
    /// * `motion_tail_pixel_frames` — number of trailing frames to return as
    ///   the new tail; must be greater than zero.
    ///
    /// The runtime session is taken from the cache (or created) and is put
    /// back both when preparation fails and after rendering, whether or not
    /// rendering succeeded.
    ///
    /// # Errors
    /// Fails when `motion_tail_pixel_frames` is zero, the pipeline is not the
    /// distilled one, `carry` has no frames, the render produced fewer frames
    /// than the requested tail, or any underlying step fails.
    pub(crate) fn render_chain_stage(
        &mut self,
        req: &GenerateRequest,
        carry: Option<&ChainTail>,
        motion_tail_pixel_frames: u32,
    ) -> Result<StageOutcome> {
        if motion_tail_pixel_frames == 0 {
            bail!("render_chain_stage: motion_tail_pixel_frames must be > 0");
        }
        if !self.loaded {
            self.load()?;
        }
        let start = Instant::now();
        self.emit("Preparing native LTX-2 chain stage");

        let pipeline = self.select_pipeline(req)?;
        if !matches!(pipeline, PipelineKind::Distilled) {
            bail!(
                "render-chain v1 only supports the distilled LTX-2 pipeline, got {:?}",
                pipeline,
            );
        }

        let work_dir = tempfile::tempdir().context("failed to create LTX-2 temp directory")?;
        let native_output = work_dir.path().join("ltx2-native-output.mp4");
        let mut plan = self.materialize_request(req, work_dir.path(), &native_output)?;

        if let Some(tail) = carry {
            if req.source_image.is_some() {
                tracing::warn!(
                    "smooth continuation received source_image; it will be repurposed as a soft \
                     identity anchor. Use transition: cut|fade to seed the stage with a fresh i2v."
                );
            }
            if tail.tail_rgb_frames.is_empty() {
                bail!(
                    "render_chain_stage: carry.tail_rgb_frames is empty; caller must provide at least one frame"
                );
            }

            // Frame 0 is about to be claimed by the carried tail, so any image
            // staged there is moved past the tail and weakened.
            let anchor_frame = motion_tail_pixel_frames;
            for image in plan.conditioning.images.iter_mut() {
                if image.frame == 0 {
                    image.frame = anchor_frame;
                    image.strength = CHAIN_SOFT_ANCHOR_STRENGTH;
                }
            }

            plan.conditioning.latents.push(StagedLatent {
                tail_rgb_frames: tail.tail_rgb_frames.clone(),
                frame: 0,
                strength: 1.0,
            });
        }

        let mut runtime = match self.native_runtime.take() {
            Some(runtime) if runtime.can_reuse_for(&plan) => runtime,
            _ => self.create_runtime_session(&plan)?,
        };

        self.emit("Executing native LTX-2 chain stage runtime");
        let prepared = match runtime.prepare(&plan) {
            Ok(prepared) => prepared,
            Err(err) => {
                // Keep the session cached even though this stage failed.
                self.native_runtime = Some(runtime);
                return Err(err);
            }
        };
        let render_result =
            runtime.render_native_video(&plan, &prepared, self.on_progress.as_ref());
        // Restore the session before inspecting the result so it is kept on
        // both the success and the failure path.
        self.native_runtime = Some(runtime);
        let rendered = render_result?;

        let frames = rendered.frames;
        let audio = rendered.audio_track;
        let tail_pixel_frames = motion_tail_pixel_frames as usize;
        if frames.len() < tail_pixel_frames {
            bail!(
                "distilled render returned {} pixel frames but the chain caller requested a {}-frame tail; \
                 this is a pipeline wiring bug",
                frames.len(),
                motion_tail_pixel_frames,
            );
        }
        // The new tail is the last `tail_pixel_frames` rendered frames.
        let tail_start = frames.len() - tail_pixel_frames;
        let tail_rgb_frames = frames[tail_start..].to_vec();

        let generation_time_ms = start.elapsed().as_millis() as u64;
        Self::log_timing("pipeline.render_chain_stage", start);

        Ok(StageOutcome {
            frames,
            tail: ChainTail {
                frames: motion_tail_pixel_frames,
                tail_rgb_frames,
            },
            audio,
            generation_time_ms,
        })
    }
813}
814
/// Exposes the engine as a chain-stage renderer.
impl ChainStageRenderer for Ltx2Engine {
    /// Renders a single chain stage by forwarding to
    /// [`Ltx2Engine::render_chain_stage`].
    ///
    /// `_stage_progress` is accepted but not used by this implementation.
    fn render_stage(
        &mut self,
        stage_req: &GenerateRequest,
        carry: Option<&ChainTail>,
        motion_tail_pixel_frames: u32,
        _stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>,
    ) -> Result<StageOutcome> {
        self.render_chain_stage(stage_req, carry, motion_tail_pixel_frames)
    }
}
833
834impl InferenceEngine for Ltx2Engine {
835 fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
836 self.pending_placement = req.placement.clone();
837 let result = self.generate_inner(req);
838 self.pending_placement = None;
839 result
840 }
841
842 fn model_name(&self) -> &str {
843 &self.model_name
844 }
845
846 fn is_loaded(&self) -> bool {
847 self.loaded
848 }
849
850 fn load(&mut self) -> Result<()> {
851 self.emit("Preparing native LTX-2 runtime");
852 if !self.paths.transformer.exists() {
853 bail!(
854 "missing LTX-2 checkpoint: {}",
855 self.paths.transformer.display()
856 );
857 }
858 let gemma_root = self.gemma_root()?;
859 if !gemma_root.join("tokenizer.json").exists() {
860 bail!(
861 "missing Gemma tokenizer assets for LTX-2: {}",
862 gemma_root.display()
863 );
864 }
865 Ltx2Backend::detect().ensure_supported()?;
866 self.loaded = true;
867 Ok(())
868 }
869
870 fn unload(&mut self) {
871 if let Some(ordinal) = self.unload_runtime_state() {
872 crate::reclaim_gpu_memory(ordinal);
873 }
874 }
875
876 fn set_on_progress(&mut self, callback: ProgressCallback) {
877 self.on_progress = Some(callback);
878 }
879
880 fn clear_on_progress(&mut self) {
881 self.on_progress = None;
882 }
883
884 fn model_paths(&self) -> Option<&ModelPaths> {
885 Some(&self.paths)
886 }
887
888 fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> {
889 Some(self)
890 }
891}
892
893pub(crate) fn resolve_prompt_encoder_device(
904 transformer_device: &Device,
905 gpu_ordinal: usize,
906) -> Device {
907 if !transformer_device.is_cuda() {
908 return transformer_device.clone();
909 }
910 crate::device::resolve_ltx2_gemma_placement(gpu_ordinal).into_device()
911}
912
913fn log_prompt_encoder_placement(transformer_device: &Device, prompt_device: &Device) {
914 if transformer_device.same_device(prompt_device) {
915 return;
916 }
917 let label = if prompt_device.is_cpu() {
918 "CPU".to_string()
919 } else if prompt_device.is_cuda() {
920 "GPU (sibling ordinal)".to_string()
921 } else {
922 "non-CUDA device".to_string()
923 };
924 tracing::info!(
925 prompt_encoder_device = %label,
926 "LTX-2 Gemma encoder placed off the transformer device — \
927 encode-time tensor copy will move conditioning back to the transformer GPU"
928 );
929}
930
931#[cfg(test)]
932mod tests {
933 use super::*;
934 use std::collections::HashMap;
935 use std::fs;
936 use std::path::Path;
937 use std::path::PathBuf;
938
939 use candle_core::{DType, Device, Tensor};
940 use candle_nn::VarBuilder;
941
942 use crate::ltx2::text::connectors::PaddingSide;
943 use crate::ltx2::text::encoder::{GemmaConfig, GemmaHiddenStateEncoder};
944 use crate::ltx2::text::prompt_encoder::{
945 build_embeddings_processor, ConnectorSpec, NativePromptEncoder,
946 };
947
    /// Fixture `ModelPaths` pointing at fixed `/tmp` locations, with the
    /// spatial/temporal upscalers and distilled LoRA set and the Gemma
    /// tokenizer listed as the only text-encoder file.
    fn dummy_paths() -> ModelPaths {
        ModelPaths {
            transformer: PathBuf::from("/tmp/ltx2.safetensors"),
            transformer_shards: vec![],
            vae: PathBuf::from("/tmp/unused"),
            spatial_upscaler: Some(PathBuf::from("/tmp/spatial.safetensors")),
            temporal_upscaler: Some(PathBuf::from("/tmp/temporal.safetensors")),
            distilled_lora: Some(PathBuf::from("/tmp/distilled-lora.safetensors")),
            t5_encoder: None,
            clip_encoder: None,
            t5_tokenizer: None,
            clip_tokenizer: None,
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files: vec![PathBuf::from("/tmp/gemma/tokenizer.json")],
            text_tokenizer: None,
            decoder: None,
        }
    }
967
968 fn dummy_paths_with_gemma_root(root: &std::path::Path) -> ModelPaths {
969 let mut paths = dummy_paths();
970 paths.text_encoder_files = vec![root.join("tokenizer.json")];
971 paths
972 }
973
    /// Fixture `ModelPaths` whose checkpoint, VAE, upscaler and LoRA entries
    /// live under `root`, with the Gemma tokenizer located under `gemma_root`.
    fn dummy_paths_in(root: &Path, gemma_root: &Path) -> ModelPaths {
        ModelPaths {
            transformer: root.join("ltx2.safetensors"),
            transformer_shards: vec![],
            vae: root.join("unused"),
            spatial_upscaler: Some(root.join("spatial.safetensors")),
            temporal_upscaler: Some(root.join("temporal.safetensors")),
            distilled_lora: Some(root.join("distilled-lora.safetensors")),
            t5_encoder: None,
            clip_encoder: None,
            t5_tokenizer: None,
            clip_tokenizer: None,
            clip_encoder_2: None,
            clip_tokenizer_2: None,
            text_encoder_files: vec![gemma_root.join("tokenizer.json")],
            text_tokenizer: None,
            decoder: None,
        }
    }
993
    /// Writes a minimal word-level `tokenizer.json` (vocabulary: `<eos>` = 7,
    /// `test` = 11) and a `special_tokens_map.json` naming `<eos>` as the EOS
    /// token into `root`. Panics if either write fails.
    fn write_test_gemma_assets(root: &std::path::Path) {
        fs::write(
            root.join("tokenizer.json"),
            r#"{
 "version": "1.0",
 "truncation": null,
 "padding": null,
 "added_tokens": [],
 "normalizer": null,
 "pre_tokenizer": {
 "type": "WhitespaceSplit"
 },
 "post_processor": null,
 "decoder": null,
 "model": {
 "type": "WordLevel",
 "vocab": {
 "<eos>": 7,
 "test": 11
 },
 "unk_token": "<eos>"
 }
}"#,
        )
        .unwrap();
        fs::write(
            root.join("special_tokens_map.json"),
            r#"{"eos_token":"<eos>"}"#,
        )
        .unwrap();
    }
1025
    /// Smallest Gemma configuration used by the tests: two layers, hidden
    /// size 8, two attention heads of dimension 4 and a 16-entry vocabulary.
    fn tiny_gemma_config() -> GemmaConfig {
        GemmaConfig {
            attention_bias: false,
            head_dim: 4,
            hidden_activation: candle_nn::Activation::GeluPytorchTanh,
            hidden_size: 8,
            intermediate_size: 16,
            num_attention_heads: 2,
            num_hidden_layers: 2,
            num_key_value_heads: 1,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000.0,
            rope_local_base_freq: 10_000.0,
            vocab_size: 16,
            final_logit_softcapping: None,
            attn_logit_softcapping: None,
            query_pre_attn_scalar: 4,
            sliding_window: 4,
            sliding_window_pattern: 2,
            max_position_embeddings: 1024,
        }
    }
1048
    /// Builds a CPU `VarBuilder` holding all-zero F32 tensors for every weight
    /// name this function knows about for `cfg`: the token embedding, each
    /// layer's attention/MLP projections and norm weights, and the final norm.
    fn zero_gemma_var_builder(cfg: &GemmaConfig) -> VarBuilder<'static> {
        let mut tensors = HashMap::new();
        tensors.insert(
            "model.embed_tokens.weight".to_string(),
            Tensor::zeros((cfg.vocab_size, cfg.hidden_size), DType::F32, &Device::Cpu).unwrap(),
        );
        for layer in 0..cfg.num_hidden_layers {
            // Linear projections: shape is (output features, input features).
            for name in [
                "self_attn.q_proj",
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.o_proj",
                "mlp.gate_proj",
                "mlp.up_proj",
                "mlp.down_proj",
            ] {
                let (rows, cols) = match name {
                    "self_attn.q_proj" => (cfg.num_attention_heads * cfg.head_dim, cfg.hidden_size),
                    "self_attn.k_proj" | "self_attn.v_proj" => {
                        (cfg.num_key_value_heads * cfg.head_dim, cfg.hidden_size)
                    }
                    "self_attn.o_proj" => (cfg.hidden_size, cfg.num_attention_heads * cfg.head_dim),
                    "mlp.gate_proj" | "mlp.up_proj" => (cfg.intermediate_size, cfg.hidden_size),
                    "mlp.down_proj" => (cfg.hidden_size, cfg.intermediate_size),
                    _ => unreachable!(),
                };
                tensors.insert(
                    format!("model.layers.{layer}.{name}.weight"),
                    Tensor::zeros((rows, cols), DType::F32, &Device::Cpu).unwrap(),
                );
            }
            // Norm weights: per-head size for q/k norms, hidden size otherwise.
            for name in [
                "self_attn.q_norm",
                "self_attn.k_norm",
                "input_layernorm",
                "pre_feedforward_layernorm",
                "post_feedforward_layernorm",
                "post_attention_layernorm",
            ] {
                let dim = if name.contains("q_norm") || name.contains("k_norm") {
                    cfg.head_dim
                } else {
                    cfg.hidden_size
                };
                tensors.insert(
                    format!("model.layers.{layer}.{name}.weight"),
                    Tensor::zeros(dim, DType::F32, &Device::Cpu).unwrap(),
                );
            }
        }
        tensors.insert(
            "model.norm.weight".to_string(),
            Tensor::zeros(cfg.hidden_size, DType::F32, &Device::Cpu).unwrap(),
        );
        VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu)
    }
1105
    /// Builds a CPU `VarBuilder` with placeholder F32 tensors for the text
    /// embedding projection and for one transformer block of each of the
    /// video (dim 8) and audio (dim 4) embeddings connectors. All tensors are
    /// zeros except the attention q/k norm weights, which are ones.
    fn zero_connector_source_var_builder() -> VarBuilder<'static> {
        let mut tensors = HashMap::new();
        tensors.insert(
            "text_embedding_projection.video_aggregate_embed.weight".to_string(),
            Tensor::zeros((8, 24), DType::F32, &Device::Cpu).unwrap(),
        );
        tensors.insert(
            "text_embedding_projection.video_aggregate_embed.bias".to_string(),
            Tensor::zeros(8, DType::F32, &Device::Cpu).unwrap(),
        );
        tensors.insert(
            "text_embedding_projection.audio_aggregate_embed.weight".to_string(),
            Tensor::zeros((4, 24), DType::F32, &Device::Cpu).unwrap(),
        );
        tensors.insert(
            "text_embedding_projection.audio_aggregate_embed.bias".to_string(),
            Tensor::zeros(4, DType::F32, &Device::Cpu).unwrap(),
        );
        for (prefix, dim) in [
            ("model.diffusion_model.video_embeddings_connector", 8usize),
            ("model.diffusion_model.audio_embeddings_connector", 4usize),
        ] {
            // Attention projections, each with a weight and a bias.
            for linear_name in ["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0"] {
                tensors.insert(
                    format!("{prefix}.transformer_1d_blocks.0.{linear_name}.weight"),
                    Tensor::zeros((dim, dim), DType::F32, &Device::Cpu).unwrap(),
                );
                tensors.insert(
                    format!("{prefix}.transformer_1d_blocks.0.{linear_name}.bias"),
                    Tensor::zeros(dim, DType::F32, &Device::Cpu).unwrap(),
                );
            }
            for norm_name in ["attn1.q_norm", "attn1.k_norm"] {
                tensors.insert(
                    format!("{prefix}.transformer_1d_blocks.0.{norm_name}.weight"),
                    Tensor::ones(dim, DType::F32, &Device::Cpu).unwrap(),
                );
            }
            // Feed-forward: expand to 4 * dim, then project back to dim.
            tensors.insert(
                format!("{prefix}.transformer_1d_blocks.0.ff.net.0.proj.weight"),
                Tensor::zeros((dim * 4, dim), DType::F32, &Device::Cpu).unwrap(),
            );
            tensors.insert(
                format!("{prefix}.transformer_1d_blocks.0.ff.net.0.proj.bias"),
                Tensor::zeros(dim * 4, DType::F32, &Device::Cpu).unwrap(),
            );
            tensors.insert(
                format!("{prefix}.transformer_1d_blocks.0.ff.net.2.weight"),
                Tensor::zeros((dim, dim * 4), DType::F32, &Device::Cpu).unwrap(),
            );
            tensors.insert(
                format!("{prefix}.transformer_1d_blocks.0.ff.net.2.bias"),
                Tensor::zeros(dim, DType::F32, &Device::Cpu).unwrap(),
            );
            tensors.insert(
                format!("{prefix}.learnable_registers"),
                Tensor::zeros((128, dim), DType::F32, &Device::Cpu).unwrap(),
            );
        }
        VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu)
    }
1167
    /// Assembles a CPU-only `NativePromptEncoder` from the tiny Gemma config
    /// and the placeholder weight builders, with one-layer video and audio
    /// connectors and left-side padding. Panics if construction fails.
    fn runtime_prompt_encoder() -> NativePromptEncoder {
        let cfg = tiny_gemma_config();
        let gemma = GemmaHiddenStateEncoder::new(&cfg, zero_gemma_var_builder(&cfg)).unwrap();
        NativePromptEncoder::new(
            gemma,
            build_embeddings_processor(
                zero_connector_source_var_builder(),
                crate::ltx2::preset::GemmaFeatureExtractorKind::V2DualAv,
                cfg.hidden_size,
                cfg.num_hidden_layers,
                8,
                Some(4),
                // Video connector.
                ConnectorSpec {
                    prefix: "model.diffusion_model.video_embeddings_connector.",
                    num_attention_heads: 2,
                    attention_head_dim: 4,
                    num_layers: 1,
                    apply_gated_attention: false,
                    positional_embedding_theta: 10_000.0,
                    positional_embedding_max_pos: &[32],
                    rope_type: crate::ltx2::model::LtxRopeType::Split,
                    double_precision_rope: true,
                    num_learnable_registers: Some(128),
                },
                // Audio connector.
                Some(ConnectorSpec {
                    prefix: "model.diffusion_model.audio_embeddings_connector.",
                    num_attention_heads: 1,
                    attention_head_dim: 4,
                    num_layers: 1,
                    apply_gated_attention: false,
                    positional_embedding_theta: 10_000.0,
                    positional_embedding_max_pos: &[32],
                    rope_type: crate::ltx2::model::LtxRopeType::Split,
                    double_precision_rope: true,
                    num_learnable_registers: Some(128),
                }),
            )
            .unwrap(),
            PaddingSide::Left,
        )
    }
1209
1210 fn runtime_session() -> Ltx2RuntimeSession {
1211 let prompt_encoder = runtime_prompt_encoder();
1212 Ltx2RuntimeSession::new(Device::Cpu, prompt_encoder, 0)
1213 }
1214
1215 fn request(output_format: OutputFormat, enable_audio: Option<bool>) -> GenerateRequest {
1216 GenerateRequest {
1217 prompt: "test".to_string(),
1218 negative_prompt: None,
1219 model: "ltx-2-19b-distilled:fp8".to_string(),
1220 width: 960,
1221 height: 576,
1222 steps: 8,
1223 guidance: 3.0,
1224 seed: Some(42),
1225 batch_size: 1,
1226 output_format: Some(output_format),
1227 embed_metadata: None,
1228 scheduler: None,
1229 cfg_plus: None,
1230 source_image: None,
1231 edit_images: None,
1232 strength: 0.75,
1233 mask_image: None,
1234 control_image: None,
1235 control_model: None,
1236 control_scale: 1.0,
1237 expand: None,
1238 original_prompt: None,
1239 lora: None,
1240 frames: Some(17),
1241 fps: Some(12),
1242 upscale_model: None,
1243 gif_preview: true,
1244 enable_audio,
1245 audio_file: None,
1246 audio_file_path: None,
1247 source_video: None,
1248 source_video_path: None,
1249 keyframes: None,
1250 pipeline: None,
1251 loras: None,
1252 retake_range: None,
1253 spatial_upscale: None,
1254 temporal_upscale: None,
1255 placement: None,
1256 }
1257 }
1258
/// With `spatial_upscaler` unset, `select_pipeline` must return
/// `PipelineKind::OneStage` for both a `cv:*` model name and a distilled
/// model name.
#[test]
fn pipeline_falls_back_to_one_stage_when_spatial_upscaler_missing() {
    let gemma = tempfile::tempdir().unwrap();
    let paths = {
        let mut without_upscaler = dummy_paths_with_gemma_root(gemma.path());
        without_upscaler.spatial_upscaler = None;
        without_upscaler
    };

    // (model name, assertion message shown on failure)
    let cases = [
        (
            "cv:2752735",
            "no spatial upsampler → OneStage (catalog cv:* default)",
        ),
        (
            "ltx-2-19b-distilled:fp8",
            "distilled name + missing spatial upsampler → OneStage fallback",
        ),
    ];

    for (model, why) in cases {
        let engine = Ltx2Engine::new(
            model.to_string(),
            paths.clone(),
            LoadStrategy::Sequential,
            0,
        );
        let selected = engine.select_pipeline(&bare_t2v_req(model)).unwrap();
        assert_eq!(selected, PipelineKind::OneStage, "{why}");
    }
}
1297
1298 fn bare_t2v_req(model: &str) -> GenerateRequest {
1299 GenerateRequest {
1300 prompt: "test".to_string(),
1301 negative_prompt: None,
1302 model: model.to_string(),
1303 width: 768,
1304 height: 512,
1305 steps: 4,
1306 guidance: 3.5,
1307 seed: Some(42),
1308 batch_size: 1,
1309 output_format: Some(OutputFormat::Mp4),
1310 embed_metadata: None,
1311 scheduler: None,
1312 cfg_plus: None,
1313 source_image: None,
1314 edit_images: None,
1315 strength: 0.75,
1316 mask_image: None,
1317 control_image: None,
1318 control_model: None,
1319 control_scale: 1.0,
1320 expand: None,
1321 original_prompt: None,
1322 lora: None,
1323 frames: Some(25),
1324 fps: Some(24),
1325 upscale_model: None,
1326 gif_preview: false,
1327 enable_audio: None,
1328 audio_file: None,
1329 audio_file_path: None,
1330 source_video: None,
1331 source_video_path: None,
1332 keyframes: None,
1333 pipeline: None,
1334 loras: None,
1335 retake_range: None,
1336 spatial_upscale: None,
1337 temporal_upscale: None,
1338 placement: None,
1339 }
1340 }
1341
/// A distilled model name with no explicit `pipeline` in the request must
/// resolve to `PipelineKind::Distilled`.
#[test]
fn pipeline_defaults_to_distilled_for_distilled_models() {
    const MODEL: &str = "ltx-2.3-22b-distilled:fp8";

    let engine = Ltx2Engine::new(
        MODEL.to_string(),
        dummy_paths(),
        LoadStrategy::Sequential,
        0,
    );

    // Start from the shared bare text-to-video request and override only the
    // fields this scenario changes, so the test does not need touching every
    // time `GenerateRequest` gains a field.
    let req = GenerateRequest {
        width: 1216,
        height: 704,
        steps: 8,
        guidance: 1.0,
        seed: Some(1),
        frames: Some(97),
        enable_audio: Some(true),
        ..bare_t2v_req(MODEL)
    };

    assert_eq!(
        engine.select_pipeline(&req).unwrap(),
        PipelineKind::Distilled
    );
}
1396
/// `from_single_file` must point `transformer` at the checkpoint, reset `vae`
/// to the default path, and carry the text-encoder files, both upscalers and
/// the distilled LoRA through from the input paths unchanged.
#[test]
fn from_single_file_preserves_companion_paths() {
    let temp = tempfile::tempdir().unwrap();
    let checkpoint = temp.path().join("ltx2_combined.safetensors");
    write_minimal_combined_ltx2_checkpoint(&checkpoint);

    let gemma_root = temp.path().join("gemma");
    let mut input_paths = dummy_paths_with_gemma_root(&gemma_root);
    // Deliberately bogus values: the constructor is expected to replace both.
    input_paths.transformer = PathBuf::from("/wrong/path-should-be-overridden");
    input_paths.vae = PathBuf::from("/wrong/vae-should-be-cleared");

    // Snapshot the companion paths before `input_paths` is moved.
    let expected_text_encoder_files = input_paths.text_encoder_files.clone();
    let expected_spatial = input_paths.spatial_upscaler.clone();
    let expected_temporal = input_paths.temporal_upscaler.clone();
    let expected_distilled = input_paths.distilled_lora.clone();

    let engine = Ltx2Engine::from_single_file(
        "cv:2752735".to_string(),
        checkpoint.clone(),
        input_paths,
        LoadStrategy::Sequential,
        0,
    )
    .expect("from_single_file should succeed on a valid combined checkpoint");
    let rebuilt = &engine.paths;

    assert_eq!(
        rebuilt.transformer, checkpoint,
        "transformer must point at the single-file checkpoint"
    );
    assert_eq!(
        rebuilt.vae,
        PathBuf::default(),
        "vae must be cleared — runtime reads it from the same checkpoint via vb.pp(\"vae\")"
    );
    assert_eq!(
        rebuilt.text_encoder_files, expected_text_encoder_files,
        "text_encoder_files (Gemma TE) must survive the rebuild — \
         dropping it is the cv:* loading regression"
    );
    assert_eq!(rebuilt.spatial_upscaler, expected_spatial);
    assert_eq!(rebuilt.temporal_upscaler, expected_temporal);
    assert_eq!(rebuilt.distilled_lora, expected_distilled);
}
1449
1450 fn write_minimal_combined_ltx2_checkpoint(path: &std::path::Path) {
1451 use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
1452 use std::collections::HashMap;
1453 let zero = 0.0f32.to_le_bytes().to_vec();
1454 let mut tensors: HashMap<String, TensorView<'_>> = HashMap::new();
1455 tensors.insert(
1456 "transformer_blocks.0.attn1.to_q.weight".to_string(),
1457 TensorView::new(SafeDtype::F32, vec![1], &zero).unwrap(),
1458 );
1459 tensors.insert(
1460 "vae.encoder.conv_in.weight".to_string(),
1461 TensorView::new(SafeDtype::F32, vec![1], &zero).unwrap(),
1462 );
1463 serialize_to_file(&tensors, &None, path).unwrap();
1464 }
1465
/// `camera_control_preset` maps the `dolly-in` alias to its `.safetensors`
/// filename and returns `None` for an unrecognised alias.
#[test]
fn camera_control_preset_aliases_are_supported() {
    let dolly_in = Ltx2Engine::camera_control_preset("dolly-in").unwrap();
    let expected_filename = "ltx-2-19b-lora-camera-control-dolly-in.safetensors";
    assert_eq!(dolly_in.filename, expected_filename);

    let unknown = Ltx2Engine::camera_control_preset("unknown");
    assert!(unknown.is_none());
}
1475
/// An engine created for an `:fp8` model name reports `fp8-cast` as its
/// request quantization.
#[test]
fn fp8_models_use_fp8_cast_quantization() {
    let model_name = "ltx-2-19b-distilled:fp8".to_string();
    let engine = Ltx2Engine::new(model_name, dummy_paths(), LoadStrategy::Sequential, 0);

    let quantization = engine.request_quantization();
    assert_eq!(quantization.as_deref(), Some("fp8-cast"));
}
1486
/// `is_oom_error` accepts the two CUDA out-of-memory strings below and
/// rejects an unrelated message.
#[test]
fn oom_error_detection_matches_cuda_allocator_strings() {
    // (error text, whether it should be classified as out-of-memory)
    let cases = [
        ("CUDA out of memory", true),
        ("cudaErrorMemoryAllocation", true),
        ("some other error", false),
    ];

    for (message, expected) in cases {
        assert_eq!(Ltx2Engine::is_oom_error(&message), expected, "{message}");
    }
}
1493
/// `materialize_request` for the fp8 distilled model must report `fp8-cast`
/// quantization and a streaming prefetch count of 2, echo the request's
/// geometry and frame rate, and produce a 256-slot conditional prompt with a
/// single valid token and pad token id 7.
#[test]
fn materialized_request_uses_streaming_defaults_for_fp8_smoke_path() {
    let gemma_dir = tempfile::tempdir().unwrap();
    write_test_gemma_assets(gemma_dir.path());
    let engine = Ltx2Engine::new(
        "ltx-2-19b-distilled:fp8".to_string(),
        dummy_paths_with_gemma_root(gemma_dir.path()),
        LoadStrategy::Sequential,
        0,
    );

    // Reuse the shared request fixture instead of repeating the full struct
    // literal; the only field this scenario changes is `gif_preview`.
    let req = GenerateRequest {
        gif_preview: false,
        ..request(OutputFormat::Mp4, Some(true))
    };

    let temp_dir = tempfile::tempdir().unwrap();
    let output_path = temp_dir.path().join("out.mp4");
    let bridge = engine
        .materialize_request(&req, temp_dir.path(), &output_path)
        .unwrap();

    // Quantization and streaming defaults.
    assert_eq!(bridge.quantization.as_deref(), Some("fp8-cast"));
    assert_eq!(bridge.streaming_prefetch_count, Some(2));

    // Geometry and timing are taken from the request.
    assert_eq!(bridge.width, 960);
    assert_eq!(bridge.height, 576);
    assert_eq!(bridge.num_frames, 17);
    assert_eq!(bridge.frame_rate, 12);

    // Tokenized prompt.
    assert_eq!(bridge.prompt_tokens.conditional.len(), 256);
    assert_eq!(bridge.prompt_tokens.conditional.valid_len(), 1);
    assert_eq!(bridge.prompt_tokens.pad_token_id, 7);
}
1559
/// `load()` succeeds and marks the engine as loaded when the Gemma test
/// assets exist and the transformer path points at an (empty) file.
#[test]
fn load_uses_native_asset_checks_without_upstream_checkout() {
    let workspace = tempfile::tempdir().unwrap();
    let gemma_dir = workspace.path().join("gemma");
    fs::create_dir_all(&gemma_dir).unwrap();
    write_test_gemma_assets(&gemma_dir);

    let paths = dummy_paths_in(workspace.path(), &gemma_dir);
    // An empty placeholder file stands in for the transformer checkpoint.
    fs::write(&paths.transformer, b"").unwrap();

    let model_name = "ltx-2-19b-distilled:fp8".to_string();
    let mut engine = Ltx2Engine::new(model_name, paths, LoadStrategy::Sequential, 0);

    engine.load().unwrap();
    assert!(engine.is_loaded());
}
1579
/// Unloading an engine that holds a deferred-CUDA session for ordinal 3
/// returns `Some(3)`, clears `loaded`, and drops the runtime session.
#[test]
fn ltx2_unload_drops_runtime_and_requests_cuda_reclaim() {
    let session = Ltx2RuntimeSession::new_deferred_cuda(runtime_prompt_encoder(), 3);
    let mut engine = Ltx2Engine::with_runtime_session(
        "ltx-2-19b-distilled:fp8".to_string(),
        dummy_paths(),
        session,
    );
    engine.loaded = true;
    engine.gpu_ordinal = 3;

    let reclaim_ordinal = engine.unload_runtime_state();

    assert_eq!(reclaim_ordinal, Some(3));
    assert!(!engine.loaded);
    assert!(engine.native_runtime.is_none());
}
1594
/// Unloading an engine backed by the CPU `runtime_session()` returns `None`,
/// clears `loaded`, and drops the runtime session.
#[test]
fn ltx2_unload_cpu_runtime_skips_cuda_reclaim() {
    let session = runtime_session();
    let mut engine = Ltx2Engine::with_runtime_session(
        "ltx-2-19b-distilled:fp8".to_string(),
        dummy_paths(),
        session,
    );
    engine.loaded = true;

    let reclaim_ordinal = engine.unload_runtime_state();

    assert_eq!(reclaim_ordinal, None);
    assert!(!engine.loaded);
    assert!(engine.native_runtime.is_none());
}
1608
/// `generate` with a pre-attached CPU runtime session returns a GIF video
/// (with PNG thumbnail and GIF preview) matching the request's geometry, no
/// audio, and leaves `native_runtime` empty afterwards.
#[test]
fn generate_runs_native_runtime_without_bridge_process() {
    // File signatures the encoded outputs must start with.
    const GIF_MAGIC: &[u8] = b"GIF89a";
    const PNG_MAGIC: &[u8] = b"\x89PNG\r\n\x1a\n";

    let workspace = tempfile::tempdir().unwrap();
    let gemma_dir = workspace.path().join("gemma");
    fs::create_dir_all(&gemma_dir).unwrap();
    write_test_gemma_assets(&gemma_dir);
    let paths = dummy_paths_in(workspace.path(), &gemma_dir);
    // An empty placeholder file stands in for the transformer checkpoint.
    fs::write(&paths.transformer, b"").unwrap();

    let mut engine = Ltx2Engine::with_runtime_session(
        "ltx-2-19b-distilled:fp8".to_string(),
        paths,
        runtime_session(),
    );
    let req = request(OutputFormat::Gif, Some(false));
    let video = engine.generate(&req).unwrap().video.unwrap();

    // Encoded payloads.
    assert_eq!(&video.data[..GIF_MAGIC.len()], GIF_MAGIC);
    assert_eq!(&video.thumbnail[..PNG_MAGIC.len()], PNG_MAGIC);
    assert_eq!(&video.gif_preview[..GIF_MAGIC.len()], GIF_MAGIC);

    // Metadata mirrors the request.
    assert_eq!(video.width, 960);
    assert_eq!(video.height, 576);
    assert_eq!(video.frames, 17);
    assert_eq!(video.fps, 12);
    assert!(!video.has_audio);

    assert!(engine.native_runtime.is_none());
}
1638
/// `render_chain_stage` on an engine named `ltx-2-19b:fp8` must fail with an
/// error whose text mentions `distilled`.
#[test]
fn render_chain_stage_rejects_non_distilled_pipeline() {
    let mut engine = Ltx2Engine::with_runtime_session(
        "ltx-2-19b:fp8".to_string(),
        dummy_paths(),
        runtime_session(),
    );
    engine.loaded = true;

    let req = request(OutputFormat::Mp4, Some(false));
    let msg = engine
        .render_chain_stage(&req, None, 4)
        .expect_err("must fail on non-distilled pipeline")
        .to_string();

    assert!(
        msg.contains("distilled"),
        "error must name the pipeline constraint, got: {msg}",
    );
}
1660
/// `render_chain_stage` called with a motion tail of `0` must fail with an
/// error whose text mentions `motion_tail_pixel_frames`.
#[test]
fn render_chain_stage_rejects_zero_motion_tail() {
    let mut engine = Ltx2Engine::with_runtime_session(
        "ltx-2-19b-distilled:fp8".to_string(),
        dummy_paths(),
        runtime_session(),
    );
    engine.loaded = true;

    let req = request(OutputFormat::Mp4, Some(false));
    let msg = engine
        .render_chain_stage(&req, None, 0)
        .expect_err("must fail on zero motion tail")
        .to_string();

    assert!(
        msg.contains("motion_tail_pixel_frames"),
        "error must name the motion_tail constraint, got: {msg}",
    );
}
1681
/// Serializes the tests below. They rewrite process-global environment
/// variables, and the default test harness runs tests on parallel threads, so
/// without this lock one test could observe (or "restore") the other's
/// temporary values.
///
/// NOTE(review): tests elsewhere in the binary that read or write the same
/// variables are not covered by this lock.
static LTX2_GEMMA_ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Environment variables the two placement tests below depend on.
const LTX2_GEMMA_ENV_VARS: [&str; 2] = [
    "MOLD_LTX2_GEMMA_DEVICE",
    "MOLD_LTX2_DEBUG_FORCE_CPU_PROMPT_ENCODER",
];

/// Scope guard that clears [`LTX2_GEMMA_ENV_VARS`] on creation and puts the
/// previous values back when dropped — including when the test body panics on
/// a failed assertion, so a failure cannot leak state into later tests.
struct Ltx2GemmaEnvGuard {
    /// `(name, value before the guard was created)` for each variable.
    saved: [(&'static str, Option<std::ffi::OsString>); 2],
    /// Held for the guard's whole lifetime; released only after `Drop::drop`
    /// has restored the environment (fields drop after the `drop` body runs).
    _serialize: std::sync::MutexGuard<'static, ()>,
}

impl Ltx2GemmaEnvGuard {
    /// Takes the lock, records the current values, and removes both variables.
    fn clear_all() -> Self {
        // A poisoned lock only means an earlier test failed while holding it;
        // its guard has already restored the environment, so keep going.
        let serialize = LTX2_GEMMA_ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let saved = LTX2_GEMMA_ENV_VARS.map(|name| (name, std::env::var_os(name)));
        for name in LTX2_GEMMA_ENV_VARS {
            // SAFETY: `LTX2_GEMMA_ENV_LOCK` is held, so the env-mutating tests
            // in this block cannot interleave. This still relies on no other
            // thread touching the environment concurrently.
            unsafe { std::env::remove_var(name) };
        }
        Self {
            saved,
            _serialize: serialize,
        }
    }
}

impl Drop for Ltx2GemmaEnvGuard {
    fn drop(&mut self) {
        for (name, previous) in &self.saved {
            // SAFETY: same reasoning as in `clear_all`; the lock is still held
            // while this body runs.
            unsafe {
                match previous {
                    Some(value) => std::env::set_var(name, value),
                    None => std::env::remove_var(name),
                }
            }
        }
    }
}

/// With neither override variable set, a CPU transformer device resolves to a
/// CPU prompt-encoder device.
#[test]
fn resolve_prompt_encoder_device_keeps_cpu_when_transformer_is_cpu() {
    let _env = Ltx2GemmaEnvGuard::clear_all();

    let resolved = resolve_prompt_encoder_device(&Device::Cpu, 0);
    assert!(resolved.is_cpu());
}

/// With `MOLD_LTX2_GEMMA_DEVICE=cpu` (and the legacy override unset), the
/// placement resolver returns `LtxGemmaPlacement::Cpu`.
#[test]
fn resolver_picks_cpu_when_env_pins_cpu() {
    let _env = Ltx2GemmaEnvGuard::clear_all();
    // SAFETY: `_env` holds `LTX2_GEMMA_ENV_LOCK`, so the other env-mutating
    // test in this block cannot run concurrently; the guard restores the prior
    // value on drop.
    unsafe {
        std::env::set_var("MOLD_LTX2_GEMMA_DEVICE", "cpu");
    }

    assert_eq!(
        crate::device::resolve_ltx2_gemma_placement(0),
        crate::device::LtxGemmaPlacement::Cpu,
    );
}
1735}