use anyhow::{anyhow, bail, Context, Result};
use candle_core::Tensor;
use image::RgbImage;
use mold_core::chain::{ChainProgressEvent, ChainRequest, ChainStage, TransitionMode};
use mold_core::{GenerateRequest, OutputFormat};
use crate::ltx2::model::shapes::SpatioTemporalScaleFactors;
use crate::ltx2::runtime::NativeAudioTrack;
/// Hand-off state one chain stage passes to the next.
///
/// A renderer fills this from the end of its clip. The orchestrator forwards it
/// as the `carry` of the following stage only when that stage uses a smooth
/// transition.
#[derive(Clone, Debug)]
pub struct ChainTail {
    /// Number of pixel frames the tail describes.
    pub frames: u32,
    /// The trailing RGB frames of the previous stage's clip.
    pub tail_rgb_frames: Vec<RgbImage>,
}
/// Number of latent time steps that cover `pixel_frames` pixel frames.
///
/// The first pixel frame always occupies one latent frame. Each further
/// `SpatioTemporalScaleFactors::default().time` pixel frames add another one,
/// i.e. `1 + (pixel_frames - 1) / time`. The tests in this module expect a
/// factor of 8.
///
/// # Panics
/// Panics when `pixel_frames` is zero.
pub fn tail_latent_frame_count(pixel_frames: u32) -> usize {
    assert!(
        pixel_frames > 0,
        "tail_latent_frame_count: pixel_frames must be > 0",
    );
    let temporal_stride = SpatioTemporalScaleFactors::default().time;
    let frames_after_first = pixel_frames as usize - 1;
    1 + frames_after_first / temporal_stride
}
#[allow(dead_code)]
pub fn extract_tail_latents(final_latents: &Tensor, pixel_frames: u32) -> Result<Tensor> {
let dims = final_latents.dims();
if dims.len() != 5 {
return Err(anyhow!(
"extract_tail_latents: expected rank-5 tensor [B, C, T, H, W], got shape {:?}",
dims,
));
}
let time = dims[2];
let tail = tail_latent_frame_count(pixel_frames);
if tail > time {
return Err(anyhow!(
"extract_tail_latents: tail requests {} latent frames but the stage emitted only {} \
(pixel_frames={}, tensor shape={:?})",
tail,
time,
pixel_frames,
dims,
));
}
let start = time - tail;
final_latents
.narrow(2, start, tail)
.with_context(|| format!("narrow last {tail} latent frames off time axis"))
}
/// Failure modes of [`Ltx2ChainOrchestrator::run`].
#[derive(Debug)]
pub enum ChainOrchestratorError {
    /// The request was rejected before any stage was rendered.
    Invalid(anyhow::Error),
    /// A stage's renderer returned an error. Frames from earlier stages are discarded.
    StageFailed {
        /// Zero-based index of the failing stage.
        stage_idx: u32,
        /// Number of stages that completed before the failure.
        elapsed_stages: u32,
        /// Sum of `generation_time_ms` reported by the completed stages.
        elapsed_ms: u64,
        /// The renderer's error.
        inner: anyhow::Error,
    },
}
impl std::fmt::Display for ChainOrchestratorError {
    /// One-line description. The wrapped error uses alternate (`{:#}`)
    /// formatting, which for `anyhow::Error` includes its context chain.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Invalid(err) => write!(f, "chain validation error: {:#}", err),
            Self::StageFailed {
                stage_idx,
                elapsed_stages,
                inner,
                ..
            } => write!(
                f,
                "chain stage {} failed after {} completed stage(s): {:#}",
                stage_idx, elapsed_stages, inner
            ),
        }
    }
}
impl std::error::Error for ChainOrchestratorError {
    /// Both variants wrap an `anyhow::Error`; expose it as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let wrapped = match self {
            Self::Invalid(err) => err,
            Self::StageFailed { inner, .. } => inner,
        };
        Some(wrapped.as_ref())
    }
}
/// Progress reported by a renderer while it works on a single stage.
///
/// The orchestrator re-emits these as chain-level events tagged with the stage index.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StageProgressEvent {
    /// Denoising step `step` of `total` finished.
    DenoiseStep { step: u32, total: u32 },
}
/// Everything a renderer returns for one rendered stage.
#[derive(Debug)]
pub struct StageOutcome {
    /// Every frame of the stage's clip. The orchestrator stores these untrimmed.
    pub frames: Vec<RgbImage>,
    /// Hand-off state for a following smooth-transition stage.
    pub tail: ChainTail,
    /// Audio for the clip, when the renderer produced any.
    pub audio: Option<NativeAudioTrack>,
    /// Time spent rendering this stage, in milliseconds.
    pub generation_time_ms: u64,
}
/// Back end that renders one chain stage at a time.
///
/// The orchestrator is generic over this trait so tests can substitute a fake
/// renderer for the real model.
pub trait ChainStageRenderer {
    /// Render one stage.
    ///
    /// * `stage_req` - the fully built per-stage generation request.
    /// * `carry` - the previous stage's tail. The orchestrator passes `None`
    ///   for the first stage and for cut/fade transitions.
    /// * `motion_tail_pixel_frames` - the chain's `motion_tail_frames` setting.
    /// * `stage_progress` - optional sink for per-stage progress events.
    fn render_stage(
        &mut self,
        stage_req: &GenerateRequest,
        carry: Option<&ChainTail>,
        motion_tail_pixel_frames: u32,
        stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>,
    ) -> Result<StageOutcome>;
}
/// Result of a successful chain run. Stitching the clips together happens later.
#[derive(Debug)]
pub struct ChainRunOutput {
    /// Full, untrimmed frames of each stage, in stage order.
    pub stage_frames: Vec<Vec<RgbImage>>,
    /// Per-stage audio, index-aligned with `stage_frames`.
    pub stage_audio: Vec<Option<NativeAudioTrack>>,
    /// Number of stages rendered.
    pub stage_count: u32,
    /// Saturating sum of every stage's `generation_time_ms`.
    pub generation_time_ms: u64,
}
/// Drives a `ChainStageRenderer` through every stage of a `ChainRequest`.
///
/// The `ChainStageRenderer` bound lives on the `impl` block that uses it rather
/// than on the struct definition. The type can then be named without restating
/// the bound; this relaxes constraints only, so it stays backward-compatible.
pub struct Ltx2ChainOrchestrator<'a, R: ?Sized> {
    /// Renderer borrowed mutably for the orchestrator's whole lifetime.
    renderer: &'a mut R,
}
impl<'a, R: ChainStageRenderer + ?Sized> Ltx2ChainOrchestrator<'a, R> {
pub fn new(renderer: &'a mut R) -> Self {
Self { renderer }
}
pub fn run(
&mut self,
req: &ChainRequest,
mut chain_progress: Option<&mut dyn FnMut(ChainProgressEvent)>,
) -> std::result::Result<ChainRunOutput, ChainOrchestratorError> {
if req.stages.is_empty() {
return Err(ChainOrchestratorError::Invalid(anyhow::anyhow!(
"Ltx2ChainOrchestrator::run: chain request has no stages"
)));
}
validate_motion_tail(req).map_err(ChainOrchestratorError::Invalid)?;
let stage_count = req.stages.len() as u32;
let estimated_total_frames = estimate_stitched_frames(req);
if let Some(cb) = chain_progress.as_deref_mut() {
cb(ChainProgressEvent::ChainStart {
stage_count,
estimated_total_frames,
});
}
let base_seed = req.seed.unwrap_or(0);
let mut stage_frames: Vec<Vec<RgbImage>> = Vec::with_capacity(req.stages.len());
let mut stage_audio: Vec<Option<NativeAudioTrack>> = Vec::with_capacity(req.stages.len());
let mut total_generation_ms: u64 = 0;
let mut carry: Option<ChainTail> = None;
for (idx, stage) in req.stages.iter().enumerate() {
let stage_idx = idx as u32;
if let Some(cb) = chain_progress.as_deref_mut() {
cb(ChainProgressEvent::StageStart { stage_idx });
}
let stage_seed = derive_stage_seed(base_seed, idx, stage);
let stage_req = build_stage_generate_request(stage, req, stage_seed, idx);
let effective_carry = match stage.transition {
TransitionMode::Smooth => carry.as_ref(),
TransitionMode::Cut | TransitionMode::Fade => None,
};
let render_result = match chain_progress.as_deref_mut() {
Some(chain_cb) => {
let mut wrapping = |event: StageProgressEvent| match event {
StageProgressEvent::DenoiseStep { step, total } => {
chain_cb(ChainProgressEvent::DenoiseStep {
stage_idx,
step,
total,
});
}
};
self.renderer.render_stage(
&stage_req,
effective_carry,
req.motion_tail_frames,
Some(&mut wrapping),
)
}
None => self.renderer.render_stage(
&stage_req,
effective_carry,
req.motion_tail_frames,
None,
),
};
let outcome = render_result.map_err(|inner| ChainOrchestratorError::StageFailed {
stage_idx,
elapsed_stages: idx as u32,
elapsed_ms: total_generation_ms,
inner,
})?;
let frames_emitted = outcome.frames.len() as u32;
stage_frames.push(outcome.frames);
stage_audio.push(outcome.audio);
total_generation_ms = total_generation_ms.saturating_add(outcome.generation_time_ms);
carry = Some(outcome.tail);
if let Some(cb) = chain_progress.as_deref_mut() {
cb(ChainProgressEvent::StageDone {
stage_idx,
frames_emitted,
});
}
}
if let Some(cb) = chain_progress.as_mut() {
let total: u32 = stage_frames.iter().map(|s| s.len() as u32).sum();
cb(ChainProgressEvent::Stitching {
total_frames: total,
});
}
Ok(ChainRunOutput {
stage_frames,
stage_audio,
stage_count,
generation_time_ms: total_generation_ms,
})
}
}
/// Reject requests whose motion tail is not strictly shorter than every stage.
///
/// Every stage is checked, including stage 0. The error names the first
/// offending stage.
fn validate_motion_tail(req: &ChainRequest) -> Result<()> {
    let offending = req
        .stages
        .iter()
        .enumerate()
        .find(|(_, stage)| req.motion_tail_frames >= stage.frames);
    if let Some((idx, stage)) = offending {
        bail!(
            "motion_tail_frames ({}) must be strictly less than stage {idx}'s frames ({}) \
             so every continuation emits at least one new frame",
            req.motion_tail_frames,
            stage.frames,
        );
    }
    Ok(())
}
fn estimate_stitched_frames(req: &ChainRequest) -> u32 {
let tail = req.motion_tail_frames;
req.stages
.iter()
.enumerate()
.map(|(idx, stage)| {
if idx == 0 {
stage.frames
} else {
stage.frames.saturating_sub(tail)
}
})
.sum()
}
/// Seed for one stage.
///
/// Returns the chain's base seed, XOR-ed with the stage's `seed_offset` when
/// the stage sets one. The stage index is currently unused, so stages without
/// an offset all share the base seed.
fn derive_stage_seed(base_seed: u64, _idx: usize, stage: &ChainStage) -> u64 {
    stage
        .seed_offset
        .map_or(base_seed, |offset| base_seed ^ offset)
}
/// Build the single-clip `GenerateRequest` for stage `idx` of `chain`.
///
/// Only stage 0 uses the chain's `strength`; continuation stages always use
/// 1.0. Audio follows `chain.enable_audio`, defaulting to off. Output is always
/// requested as MP4, and features the chain does not drive are left unset.
fn build_stage_generate_request(
    stage: &ChainStage,
    chain: &ChainRequest,
    stage_seed: u64,
    idx: usize,
) -> GenerateRequest {
    let is_first_stage = idx == 0;
    let strength = if is_first_stage { chain.strength } else { 1.0 };
    let enable_audio = chain.enable_audio.unwrap_or(false);

    GenerateRequest {
        // Per-stage content.
        prompt: stage.prompt.clone(),
        negative_prompt: stage.negative_prompt.clone(),
        source_image: stage.source_image.clone(),
        frames: Some(stage.frames),
        seed: Some(stage_seed),
        strength,
        // Settings shared by the whole chain.
        model: chain.model.clone(),
        width: chain.width,
        height: chain.height,
        steps: chain.steps,
        guidance: chain.guidance,
        fps: Some(chain.fps),
        placement: chain.placement.clone(),
        enable_audio: Some(enable_audio),
        // Fixed per-stage settings.
        batch_size: 1,
        output_format: Some(OutputFormat::Mp4),
        control_scale: 1.0,
        gif_preview: false,
        // Features the chain pipeline does not use.
        embed_metadata: None,
        scheduler: None,
        cfg_plus: None,
        edit_images: None,
        mask_image: None,
        control_image: None,
        control_model: None,
        expand: None,
        original_prompt: None,
        lora: None,
        upscale_model: None,
        audio_file: None,
        audio_file_path: None,
        source_video: None,
        source_video_path: None,
        keyframes: None,
        pipeline: None,
        loras: None,
        retake_range: None,
        spatial_upscale: None,
        temporal_upscale: None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};
    use image::Rgb;
    use mold_core::chain::{ChainStage, TransitionMode};

    #[test]
    fn tail_latent_frame_count_matches_vae_formula() {
        // Tails of up to 8 pixel frames fit in the single leading latent frame.
        for px in [1u32, 2, 4, 8] {
            assert_eq!(tail_latent_frame_count(px), 1, "{px} pixel frames");
        }
        assert_eq!(tail_latent_frame_count(9), 2);
        assert_eq!(tail_latent_frame_count(16), 2);
        assert_eq!(tail_latent_frame_count(17), 3);
        assert_eq!(tail_latent_frame_count(24), 3);
        assert_eq!(tail_latent_frame_count(97), 13);
    }

    #[test]
    #[should_panic(expected = "pixel_frames must be > 0")]
    fn tail_latent_frame_count_rejects_zero() {
        tail_latent_frame_count(0);
    }

    #[test]
    fn extract_tail_narrows_last_latent_frame_for_4_pixel_frame_tail() {
        // Raw layout (B, T, C, H, W) = (1, 3, 2, 1, 1): at time t, channel 0
        // holds t and channel 1 holds 42 + t.
        let data = vec![0.0f32, 42.0, 1.0, 43.0, 2.0, 44.0];
        let raw = Tensor::from_vec(data, (1, 3, 2, 1, 1), &Device::Cpu).expect("build raw tensor");
        let latents = raw
            .permute([0, 2, 1, 3, 4])
            .expect("permute to [B, C, T, H, W]");
        assert_eq!(latents.dims(), &[1, 2, 3, 1, 1]);
        let tail = extract_tail_latents(&latents, 4).expect("extract");
        assert_eq!(tail.dims(), &[1, 2, 1, 1, 1]);
        let values = tail.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        assert_eq!(
            values,
            vec![2.0, 44.0],
            "tail must be the last latent frame (index 2) across all channels",
        );
    }

    #[test]
    fn extract_tail_narrows_two_frames_for_9_pixel_frame_tail() {
        let latents = Tensor::zeros((1, 1, 3, 2, 2), DType::F32, &Device::Cpu).unwrap();
        let tail = extract_tail_latents(&latents, 9).expect("extract");
        assert_eq!(tail.dims(), &[1, 1, 2, 2, 2]);
    }

    #[test]
    fn extract_tail_rejects_rank_4_tensor() {
        let bad = Tensor::zeros((1, 128, 3, 4), DType::F32, &Device::Cpu).unwrap();
        let err = extract_tail_latents(&bad, 4).expect_err("rank 4 must fail");
        let msg = format!("{err}");
        assert!(
            msg.contains("rank-5") && msg.contains("T, H, W"),
            "error must identify the rank mismatch, got: {msg}",
        );
    }

    #[test]
    fn extract_tail_rejects_oversize_request() {
        let latents = Tensor::zeros((1, 128, 1, 4, 4), DType::F32, &Device::Cpu).unwrap();
        let err = extract_tail_latents(&latents, 9).expect_err("oversize tail must fail");
        let msg = format!("{err}");
        assert!(
            msg.contains("requests 2") && msg.contains("only 1"),
            "error must name the latent-frame mismatch, got: {msg}",
        );
    }

    /// Deterministic stand-in for the real renderer. It records every call and
    /// synthesises solid-colour frames (plus optional audio) without a model.
    struct FakeRenderer {
        /// One record per `render_stage` call, in call order.
        calls: Vec<CallRecord>,
        /// `(call index, message)` pairs: the matching call fails with `message`.
        fail_on: Vec<(usize, String)>,
        /// Emit this many frames instead of `stage_req.frames`.
        frame_count_override: Option<u32>,
        /// When true, fire one `DenoiseStep { step: 1, total: 1 }` per call.
        emit_progress: bool,
        /// When true, attach a synthetic 48 kHz stereo track whose first sample is `idx * 1000`.
        synthesize_audio: bool,
    }

    /// What the fake observed on one `render_stage` call.
    #[derive(Debug, Clone)]
    struct CallRecord {
        seed: Option<u64>,
        has_source_image: bool,
        has_carry: bool,
        enable_audio: Option<bool>,
    }

    impl FakeRenderer {
        fn new() -> Self {
            Self {
                calls: Vec::new(),
                fail_on: Vec::new(),
                frame_count_override: None,
                emit_progress: false,
                synthesize_audio: false,
            }
        }
    }

    impl ChainStageRenderer for FakeRenderer {
        fn render_stage(
            &mut self,
            stage_req: &GenerateRequest,
            carry: Option<&ChainTail>,
            _motion_tail_pixel_frames: u32,
            mut stage_progress: Option<&mut dyn FnMut(StageProgressEvent)>,
        ) -> Result<StageOutcome> {
            let idx = self.calls.len();
            // Record before any injected failure so tests can count attempts.
            self.calls.push(CallRecord {
                seed: stage_req.seed,
                has_source_image: stage_req.source_image.is_some(),
                has_carry: carry.is_some(),
                enable_audio: stage_req.enable_audio,
            });
            if let Some((_, msg)) = self.fail_on.iter().find(|(stage_idx, _)| *stage_idx == idx) {
                bail!("{msg}");
            }
            if self.emit_progress {
                if let Some(cb) = stage_progress.as_mut() {
                    cb(StageProgressEvent::DenoiseStep { step: 1, total: 1 });
                }
            }
            let frame_count = self
                .frame_count_override
                .unwrap_or_else(|| stage_req.frames.expect("fake renderer: stage_req.frames"));
            let width = stage_req.width;
            let height = stage_req.height;
            let mut frames = Vec::with_capacity(frame_count as usize);
            for frame_num in 0..frame_count {
                // Distinct red channel per (stage, frame) so clips are distinguishable.
                let channel = (idx as u8).wrapping_mul(37).wrapping_add(frame_num as u8);
                frames.push(RgbImage::from_pixel(width, height, Rgb([channel, 0, 0])));
            }
            let tail_pixel_frames: u32 = 4;
            // `saturating_sub` already keeps the index within `0..=len`.
            let take_from = frames.len().saturating_sub(tail_pixel_frames as usize);
            let tail_rgb_frames = frames[take_from..].to_vec();
            let audio = if self.synthesize_audio {
                let samples_per_frame = 100usize;
                let interleaved_samples: Vec<f32> = (0..frame_count as usize * samples_per_frame)
                    .map(|n| ((idx as i32 * 1_000) + n as i32) as f32)
                    .collect();
                Some(NativeAudioTrack {
                    interleaved_samples,
                    sample_rate: 48_000,
                    channels: 2,
                })
            } else {
                None
            };
            Ok(StageOutcome {
                frames,
                tail: ChainTail {
                    frames: tail_pixel_frames,
                    tail_rgb_frames,
                },
                audio,
                generation_time_ms: 100,
            })
        }
    }

    /// Minimal smooth-transition stage with `frames` frames.
    fn stage(prompt: &str, frames: u32) -> ChainStage {
        ChainStage {
            prompt: prompt.into(),
            frames,
            source_image: None,
            negative_prompt: None,
            seed_offset: None,
            transition: TransitionMode::Smooth,
            fade_frames: None,
            model: None,
            loras: vec![],
            references: vec![],
        }
    }

    /// Chain request around `stages` with a fixed seed of 42 and the given motion tail.
    fn chain_req(stages: Vec<ChainStage>, motion_tail_frames: u32) -> ChainRequest {
        ChainRequest {
            model: "ltx-2-19b-distilled:fp8".into(),
            stages,
            motion_tail_frames,
            width: 1216,
            height: 704,
            fps: 24,
            seed: Some(42),
            steps: 8,
            guidance: 3.0,
            strength: 1.0,
            output_format: OutputFormat::Mp4,
            placement: None,
            prompt: None,
            total_frames: None,
            clip_frames: None,
            source_image: None,
            enable_audio: None,
        }
    }

    #[test]
    fn chain_orchestrator_emits_full_per_stage_clips() {
        let stages = vec![stage("a", 97), stage("a", 97), stage("a", 97)];
        let req = chain_req(stages, 4);
        let mut renderer = FakeRenderer::new();
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        let out = orch.run(&req, None).expect("chain runs");
        let total_frames: usize = out.stage_frames.iter().map(|s| s.len()).sum();
        assert_eq!(total_frames, 97 * 3);
        assert_eq!(out.stage_count, 3);
        assert_eq!(renderer.calls.len(), 3);
        assert!(!renderer.calls[0].has_carry);
        assert!(renderer.calls[1].has_carry);
        assert!(renderer.calls[2].has_carry);
    }

    #[test]
    fn chain_with_zero_tail_concats_full_clips_without_drop() {
        let stages = vec![stage("a", 97), stage("a", 97)];
        let req = chain_req(stages, 0);
        let mut renderer = FakeRenderer::new();
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        let out = orch.run(&req, None).expect("chain runs");
        let total_frames: usize = out.stage_frames.iter().map(|s| s.len()).sum();
        assert_eq!(
            total_frames,
            97 * 2,
            "zero motion tail must keep every frame in each stage's vector",
        );
    }

    #[test]
    fn chain_empty_stages_errors_without_calling_renderer() {
        let req = chain_req(vec![], 4);
        let mut renderer = FakeRenderer::new();
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        let err = orch.run(&req, None).expect_err("empty stages must fail");
        match err {
            ChainOrchestratorError::Invalid(inner) => {
                assert!(
                    format!("{inner}").contains("has no stages"),
                    "error must name the missing stages, got: {inner}",
                );
            }
            other => panic!("expected Invalid, got {other:?}"),
        }
        assert!(renderer.calls.is_empty());
    }

    #[test]
    fn chain_fails_closed_mid_chain_discarding_accumulated_frames() {
        let stages = vec![stage("a", 97), stage("a", 97), stage("a", 97)];
        let req = chain_req(stages, 4);
        let mut renderer = FakeRenderer::new();
        renderer.fail_on = vec![(1, "simulated GPU OOM on stage 1".into())];
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        let err = orch
            .run(&req, None)
            .expect_err("mid-chain failure must bubble up");
        match err {
            ChainOrchestratorError::StageFailed {
                stage_idx: 1,
                elapsed_stages: 1,
                inner,
                ..
            } => {
                assert!(
                    format!("{inner}").contains("simulated GPU OOM"),
                    "inner error must carry the renderer's message, got: {inner}",
                );
            }
            other => panic!("expected StageFailed at stage 1, got {other:?}"),
        }
        assert_eq!(renderer.calls.len(), 2);
    }

    #[test]
    fn chain_holds_seed_stable_across_stages_by_default() {
        let stages = vec![stage("a", 9), stage("a", 9), stage("a", 9)];
        let mut req = chain_req(stages, 0);
        req.seed = Some(42);
        let mut renderer = FakeRenderer::new();
        renderer.frame_count_override = Some(9);
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        orch.run(&req, None).expect("chain runs");
        assert_eq!(renderer.calls[0].seed, Some(42));
        assert_eq!(renderer.calls[1].seed, Some(42));
        assert_eq!(renderer.calls[2].seed, Some(42));
    }

    #[test]
    fn chain_propagates_source_image_to_every_stage() {
        // PNG magic bytes; the fake only checks for presence.
        let png_magic = vec![0x89, 0x50, 0x4e, 0x47];
        let mut stages = vec![stage("a", 9), stage("a", 9)];
        stages[0].source_image = Some(png_magic.clone());
        stages[1].source_image = Some(png_magic);
        let req = chain_req(stages, 0);
        let mut renderer = FakeRenderer::new();
        renderer.frame_count_override = Some(9);
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        orch.run(&req, None).expect("chain runs");
        assert!(
            renderer.calls[0].has_source_image,
            "stage 0 must carry source_image (frame-0 i2v replacement)",
        );
        assert!(
            renderer.calls[1].has_source_image,
            "continuation stage must also carry source_image (soft identity anchor)",
        );
    }

    #[test]
    fn chain_forwards_engine_events_with_stage_idx_wrapping() {
        let stages = vec![stage("a", 9), stage("a", 9)];
        let req = chain_req(stages, 0);
        let mut renderer = FakeRenderer::new();
        renderer.frame_count_override = Some(9);
        renderer.emit_progress = true;
        let mut events: Vec<ChainProgressEvent> = Vec::new();
        {
            let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
            let mut cb = |e: ChainProgressEvent| events.push(e);
            orch.run(&req, Some(&mut cb)).expect("chain runs");
        }
        assert!(matches!(
            events[0],
            ChainProgressEvent::ChainStart { stage_count: 2, .. }
        ));
        assert!(matches!(
            events[1],
            ChainProgressEvent::StageStart { stage_idx: 0 }
        ));
        assert!(matches!(
            events[2],
            ChainProgressEvent::DenoiseStep {
                stage_idx: 0,
                step: 1,
                total: 1
            }
        ));
        assert!(matches!(
            events[3],
            ChainProgressEvent::StageDone {
                stage_idx: 0,
                frames_emitted: 9
            }
        ));
        assert!(matches!(
            events[4],
            ChainProgressEvent::StageStart { stage_idx: 1 }
        ));
        assert!(matches!(
            events[5],
            ChainProgressEvent::DenoiseStep {
                stage_idx: 1,
                step: 1,
                total: 1
            }
        ));
        assert!(matches!(
            events[6],
            ChainProgressEvent::StageDone {
                stage_idx: 1,
                frames_emitted: 9
            }
        ));
        assert!(matches!(
            events[7],
            ChainProgressEvent::Stitching { total_frames: 18 }
        ));
        assert_eq!(events.len(), 8);
    }

    #[test]
    fn chain_rejects_motion_tail_ge_stage_frames_before_running() {
        let stages = vec![stage("a", 9), stage("a", 9)];
        let req = chain_req(stages, 9);
        let mut renderer = FakeRenderer::new();
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        let err = orch.run(&req, None).expect_err("must fail");
        match err {
            ChainOrchestratorError::Invalid(inner) => {
                assert!(
                    format!("{inner}").contains("motion_tail_frames"),
                    "error must name motion_tail_frames, got: {inner}",
                );
            }
            other => panic!("expected Invalid, got {other:?}"),
        }
        assert!(renderer.calls.is_empty());
    }

    #[test]
    fn chain_respects_seed_offset_override_when_stage_provides_one() {
        let mut stages = vec![stage("a", 9), stage("a", 9)];
        stages[1].seed_offset = Some(0xDEADBEEF);
        let mut req = chain_req(stages, 0);
        req.seed = Some(100);
        let mut renderer = FakeRenderer::new();
        renderer.frame_count_override = Some(9);
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        orch.run(&req, None).expect("runs");
        assert_eq!(renderer.calls[0].seed, Some(100));
        assert_eq!(
            renderer.calls[1].seed,
            Some(100 ^ 0xDEADBEEFu64),
            "seed_offset must XOR into the stable base seed when a stage opts in to variation",
        );
    }

    #[test]
    fn orchestrator_forwards_chain_enable_audio_to_each_stage_request() {
        let stages = vec![stage("a", 9), stage("a", 9), stage("a", 9)];
        let mut req = chain_req(stages, 0);
        req.enable_audio = Some(true);
        let mut renderer = FakeRenderer::new();
        renderer.frame_count_override = Some(9);
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        orch.run(&req, None).expect("chain runs");
        for (idx, call) in renderer.calls.iter().enumerate() {
            assert_eq!(
                call.enable_audio,
                Some(true),
                "stage {idx} must inherit chain.enable_audio",
            );
        }
    }

    #[test]
    fn orchestrator_default_enable_audio_resolves_to_false_for_each_stage() {
        let stages = vec![stage("a", 9), stage("a", 9)];
        let req = chain_req(stages, 0);
        assert_eq!(req.enable_audio, None);
        let mut renderer = FakeRenderer::new();
        renderer.frame_count_override = Some(9);
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        orch.run(&req, None).expect("chain runs");
        for call in renderer.calls.iter() {
            assert_eq!(call.enable_audio, Some(false));
        }
    }

    #[test]
    fn orchestrator_collects_per_stage_audio_into_chain_run_output() {
        let stages = vec![stage("a", 9), stage("a", 9), stage("a", 9)];
        let mut req = chain_req(stages, 0);
        req.enable_audio = Some(true);
        let mut renderer = FakeRenderer::new();
        renderer.frame_count_override = Some(9);
        renderer.synthesize_audio = true;
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        let out = orch.run(&req, None).expect("chain runs");
        assert_eq!(out.stage_audio.len(), 3, "one entry per stage");
        for (idx, audio) in out.stage_audio.iter().enumerate() {
            let track = audio
                .as_ref()
                .unwrap_or_else(|| panic!("stage {idx} must carry an audio track"));
            assert_eq!(track.sample_rate, 48_000);
            assert_eq!(track.channels, 2);
            assert_eq!(track.interleaved_samples.len(), 9 * 100);
            assert_eq!(track.interleaved_samples[0], (idx as i32 * 1_000) as f32);
        }
    }

    #[test]
    fn orchestrator_omits_audio_when_renderer_returns_none() {
        let stages = vec![stage("a", 9), stage("a", 9)];
        let req = chain_req(stages, 0);
        let mut renderer = FakeRenderer::new();
        renderer.frame_count_override = Some(9);
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        let out = orch.run(&req, None).expect("chain runs");
        assert_eq!(out.stage_audio.len(), 2);
        assert!(out.stage_audio.iter().all(Option::is_none));
    }

    #[test]
    fn chain_run_output_preserves_per_stage_frames() {
        let req = sample_chain_request(3, TransitionMode::Smooth);
        let mut renderer = FakeRenderer::new();
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        let out = orch.run(&req, None).unwrap();
        assert_eq!(out.stage_frames.len(), 3);
        assert_eq!(out.stage_frames[0].len() as u32, req.stages[0].frames);
        assert_eq!(out.stage_frames[1].len() as u32, req.stages[1].frames);
        assert_eq!(out.stage_frames[2].len() as u32, req.stages[2].frames);
    }

    #[test]
    fn orchestrator_passes_none_carry_for_cut_transition() {
        let mut renderer = FakeRenderer::new();
        let req = sample_chain_request(3, TransitionMode::Cut);
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        orch.run(&req, None).unwrap();
        let has_carry: Vec<bool> = renderer.calls.iter().map(|c| c.has_carry).collect();
        assert_eq!(has_carry, vec![false, false, false]);
    }

    #[test]
    fn orchestrator_passes_some_carry_for_smooth_transition() {
        let mut renderer = FakeRenderer::new();
        let req = sample_chain_request(3, TransitionMode::Smooth);
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        orch.run(&req, None).unwrap();
        let has_carry: Vec<bool> = renderer.calls.iter().map(|c| c.has_carry).collect();
        assert_eq!(has_carry, vec![false, true, true]);
    }

    #[test]
    fn mixed_transitions_end_to_end() {
        let mut renderer = FakeRenderer::new();
        let mut req = sample_chain_request(4, TransitionMode::Smooth);
        req.motion_tail_frames = 25;
        req.stages[1].transition = TransitionMode::Smooth;
        req.stages[2].transition = TransitionMode::Cut;
        req.stages[3].transition = TransitionMode::Fade;
        req.stages[3].fade_frames = Some(8);
        let mut orch = Ltx2ChainOrchestrator::new(&mut renderer);
        let out = orch.run(&req, None).unwrap();
        let boundaries: Vec<_> = req.stages.iter().skip(1).map(|s| s.transition).collect();
        let fade_lens: Vec<_> = req
            .stages
            .iter()
            .skip(1)
            .map(|s| s.fade_frames.unwrap_or(8))
            .collect();
        let plan = crate::ltx2::stitch::StitchPlan {
            clips: out.stage_frames,
            boundaries,
            fade_lens,
            motion_tail_frames: req.motion_tail_frames,
        };
        let frames = plan.assemble().unwrap();
        // 4 × 97 = 388 raw frames. 355 is consistent with dropping the 25-frame
        // smooth tail and overlapping 8 fade frames.
        assert_eq!(frames.len(), 355);
    }

    /// `count` clips of 97 frames produced via `ChainRequest::normalise`. Every
    /// boundary after stage 0 is set to `transition`.
    fn sample_chain_request(count: usize, transition: TransitionMode) -> ChainRequest {
        let req = ChainRequest {
            model: "ltx-2-19b-distilled:fp8".into(),
            stages: Vec::new(),
            motion_tail_frames: 0,
            width: 1216,
            height: 704,
            fps: 24,
            seed: Some(0),
            steps: 8,
            guidance: 3.0,
            strength: 1.0,
            output_format: mold_core::OutputFormat::Mp4,
            placement: None,
            prompt: Some("x".into()),
            total_frames: Some(97 * count as u32),
            clip_frames: Some(97),
            source_image: None,
            enable_audio: None,
        };
        let mut req = req.normalise().unwrap();
        for s in req.stages.iter_mut().skip(1) {
            s.transition = transition;
        }
        req
    }
}