use serde::{Deserialize, Serialize};
use crate::error::{MoldError, Result};
use crate::types::{DevicePlacement, OutputFormat, VideoData};
/// How a chain stage is joined onto the stage before it.
///
/// Serialised in `snake_case` (`"smooth"`, `"cut"`, `"fade"`). Stage 0 has no
/// predecessor, so [`ChainRequest::normalise`] coerces it back to `Smooth`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, utoipa::ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum TransitionMode {
    /// The default. Counted as `frames - motion_tail_frames` by
    /// [`ChainRequest::estimated_total_frames`].
    #[default]
    Smooth,
    /// Counted at its full `frames` by [`ChainRequest::estimated_total_frames`].
    Cut,
    /// Counted as `frames - fade_frames` (8 when `fade_frames` is unset) by
    /// [`ChainRequest::estimated_total_frames`].
    Fade,
}
/// A LoRA adapter attached to one stage.
///
/// Reserved: [`ChainRequest::normalise`] currently rejects any stage whose
/// `loras` list is non-empty.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct LoraSpec {
    /// Location of the adapter weights (e.g. a `.safetensors` file).
    pub path: String,
    /// Adapter scale factor.
    pub scale: f64,
    /// Optional label; left out of the serialised form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
/// A named reference image attached to one stage.
///
/// Reserved: [`ChainRequest::normalise`] currently rejects any stage whose
/// `references` list is non-empty.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct NamedRef {
    /// Identifier for this reference.
    pub name: String,
    /// Image bytes, (de)serialised through `crate::types::base64_bytes`.
    #[serde(with = "crate::types::base64_bytes")]
    pub image: Vec<u8>,
}
/// One segment of a chained render.
///
/// Only `prompt` and `frames` are required on input; every other field has a
/// serde default so older payloads keep deserialising.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainStage {
    /// Prompt for this stage.
    #[schema(example = "a cat walking through autumn leaves")]
    pub prompt: String,

    /// Frames rendered by this stage.
    ///
    /// [`ChainRequest::normalise`] requires a non-zero `8k + 1` value that is
    /// strictly greater than the request's `motion_tail_frames`.
    #[schema(example = 97)]
    pub frames: u32,

    /// Optional starting image, (de)serialised via `crate::types::base64_opt`.
    /// Auto-expanded chains copy the request-level `source_image` here on
    /// every stage.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "crate::types::base64_opt"
    )]
    pub source_image: Option<Vec<u8>>,

    /// Optional negative prompt.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub negative_prompt: Option<String>,

    /// Optional per-stage seed offset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub seed_offset: Option<u64>,

    /// Join mode relative to the previous stage; `Smooth` when omitted.
    #[serde(default)]
    pub transition: TransitionMode,

    /// Fade length used by [`TransitionMode::Fade`] accounting; 8 is assumed
    /// when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fade_frames: Option<u32>,

    /// Reserved for sub-project C; rejected by [`ChainRequest::normalise`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,

    /// Reserved for sub-project B; rejected when non-empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub loras: Vec<LoraSpec>,

    /// Reserved for sub-project B; rejected when non-empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub references: Vec<NamedRef>,
}
/// A request to render a chain of video stages.
///
/// Two input forms are accepted and unified by [`ChainRequest::normalise`]:
///
/// * **canonical** — `stages` is non-empty and used as given;
/// * **auto-expand** — `stages` is empty and `prompt` + `total_frames`
///   (optionally `clip_frames` and `source_image`) are expanded into stages.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainRequest {
    /// Model identifier used for the whole chain.
    #[schema(example = "ltx-2-19b-distilled:fp8")]
    pub model: String,

    /// Explicit stage list (canonical form). Empty selects auto-expand.
    #[serde(default)]
    pub stages: Vec<ChainStage>,

    /// Frames carried over between stages; must be 0 or `8k + 1` and below
    /// every stage's `frames`. Defaults to 17.
    #[schema(example = 17)]
    #[serde(default = "default_motion_tail_frames")]
    pub motion_tail_frames: u32,

    /// Output width.
    #[schema(example = 1216)]
    pub width: u32,

    /// Output height.
    #[schema(example = 704)]
    pub height: u32,

    /// Output frame rate; defaults to 24.
    #[schema(example = 24)]
    #[serde(default = "default_fps")]
    pub fps: u32,

    /// Optional seed.
    #[schema(example = 42)]
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,

    /// Denoising step count.
    #[schema(example = 8)]
    pub steps: u32,

    /// Guidance value.
    #[schema(example = 3.0)]
    pub guidance: f64,

    /// Strength; defaults to 1.0.
    #[schema(example = 1.0)]
    #[serde(default = "default_strength")]
    pub strength: f64,

    /// Container/codec of the result; defaults to MP4.
    #[serde(default = "default_output_format")]
    pub output_format: OutputFormat,

    /// Optional device placement.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub placement: Option<DevicePlacement>,

    /// Auto-expand input: prompt shared by every generated stage.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,

    /// Auto-expand input: minimum number of frames to deliver.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub total_frames: Option<u32>,

    /// Auto-expand input: per-stage clip length (97 when unset, must be `8k + 1`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub clip_frames: Option<u32>,

    /// Auto-expand input: starting image copied onto every generated stage.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "crate::types::base64_opt"
    )]
    pub source_image: Option<Vec<u8>>,

    /// Optional audio toggle; echoed into [`ChainScriptChain::enable_audio`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub enable_audio: Option<bool>,
}
/// Serialisable echo of a chain request: chain-wide settings plus the stage list.
///
/// [`From<&ChainRequest>`](ChainScript::from) fills `schema` with `"mold.chain.v1"`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainScript {
    /// Schema tag identifying the script format.
    pub schema: String,

    /// Settings shared by every stage.
    pub chain: ChainScriptChain,

    /// Stage list, serialised under the singular key `stage`.
    #[serde(rename = "stage")]
    pub stages: Vec<ChainStage>,
}
/// Chain-wide settings copied out of a [`ChainRequest`] for the script echo.
#[derive(Debug, Clone, Default, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainScriptChain {
    /// Model identifier.
    pub model: String,
    /// Output width.
    pub width: u32,
    /// Output height.
    pub height: u32,
    /// Output frame rate.
    pub fps: u32,
    /// Optional seed; omitted from the serialised form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    /// Denoising step count.
    pub steps: u32,
    /// Guidance value.
    pub guidance: f64,
    /// Strength value.
    pub strength: f64,
    /// Frames carried over between stages.
    pub motion_tail_frames: u32,
    /// Output format.
    pub output_format: OutputFormat,
    /// Optional audio toggle; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub enable_audio: Option<bool>,
}
impl From<&ChainRequest> for ChainScript {
fn from(req: &ChainRequest) -> Self {
ChainScript {
schema: "mold.chain.v1".into(),
chain: ChainScriptChain {
model: req.model.clone(),
width: req.width,
height: req.height,
fps: req.fps,
seed: req.seed,
steps: req.steps,
guidance: req.guidance,
strength: req.strength,
motion_tail_frames: req.motion_tail_frames,
output_format: req.output_format,
enable_audio: req.enable_audio,
},
stages: req.stages.clone(),
}
}
}
/// VRAM estimate reported alongside a chain result.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct VramEstimate {
    /// Worst-case VRAM requirement in bytes.
    pub worst_case_bytes: u64,
    /// Presumably whether the worst case fits the target device — confirm with the producer.
    pub fits: bool,
}
/// Result of a completed chain render.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainResponse {
    /// The rendered video.
    pub video: VideoData,

    /// Number of stages rendered.
    #[schema(example = 5)]
    pub stage_count: u32,

    /// Optional GPU identifier; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gpu: Option<usize>,

    /// Script echo of the request that produced this result.
    pub script: ChainScript,

    /// Optional VRAM estimate; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub vram_estimate: Option<VramEstimate>,
}
/// Completion payload for a chain render delivered over SSE.
///
/// Optional fields are dropped from the serialised form when unset, and
/// `has_audio` is dropped when `false`.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct SseChainCompleteEvent {
    /// Encoded video payload (presumably base64 — confirm with the producer).
    pub video: String,
    /// Output format of `video`.
    pub format: OutputFormat,
    /// Output width.
    #[schema(example = 1216)]
    pub width: u32,
    /// Output height.
    #[schema(example = 704)]
    pub height: u32,
    /// Total frame count.
    #[schema(example = 400)]
    pub frames: u32,
    /// Frame rate.
    #[schema(example = 24)]
    pub fps: u32,
    /// Optional thumbnail payload.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thumbnail: Option<String>,
    /// Optional GIF preview payload.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gif_preview: Option<String>,
    /// Whether the video carries audio; serialised only when `true`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub has_audio: bool,
    /// Optional duration in milliseconds.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>,
    /// Optional audio sample rate.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub audio_sample_rate: Option<u32>,
    /// Optional audio channel count.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub audio_channels: Option<u32>,
    /// Number of stages rendered.
    #[schema(example = 5)]
    pub stage_count: u32,
    /// Optional GPU identifier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gpu: Option<usize>,
    /// Optional wall-clock generation time in milliseconds.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub generation_time_ms: Option<u64>,
    /// Script echo; falls back to `ChainScript::default()` when absent on input.
    #[serde(default)]
    pub script: ChainScript,
    /// Optional VRAM estimate.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub vram_estimate: Option<VramEstimate>,
}
/// Progress notifications emitted while a chain renders.
///
/// Internally tagged: each JSON object carries a `"type"` field holding the
/// `snake_case` variant name (`chain_start`, `stage_start`, `denoise_step`,
/// `stage_done`, `stitching`).
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChainProgressEvent {
    /// The chain has begun.
    ChainStart {
        /// Number of stages in the chain.
        stage_count: u32,
        /// Expected frame count after stitching.
        estimated_total_frames: u32,
    },
    /// A stage has begun.
    StageStart {
        /// Index of the stage.
        stage_idx: u32,
    },
    /// One denoising step of a stage finished.
    DenoiseStep {
        /// Index of the stage.
        stage_idx: u32,
        /// Current step.
        step: u32,
        /// Total steps for the stage.
        total: u32,
    },
    /// A stage finished.
    StageDone {
        /// Index of the stage.
        stage_idx: u32,
        /// Frames produced by the stage.
        frames_emitted: u32,
    },
    /// Stage outputs are being stitched together.
    Stitching {
        /// Frame count of the stitched result.
        total_frames: u32,
    },
}
/// Report describing a chain render that failed part-way.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainFailure {
    /// Top-level error summary.
    #[schema(example = "stage render failed")]
    pub error: String,

    /// Index of the stage that failed.
    #[schema(example = 2)]
    pub failed_stage_idx: u32,

    /// Number of stages that had elapsed.
    #[schema(example = 2)]
    pub elapsed_stages: u32,

    /// Elapsed time in milliseconds.
    #[schema(example = 12_340)]
    pub elapsed_ms: u64,

    /// Error reported by the failing stage.
    #[schema(example = "simulated GPU OOM on stage 2")]
    pub stage_error: String,
}
/// Serde default for `ChainRequest::motion_tail_frames`: 17 frames, i.e. `2·8 + 1`,
/// which lies on the LTX-2 `8k + 1` grid.
fn default_motion_tail_frames() -> u32 {
    const TAIL: u32 = 2 * 8 + 1;
    TAIL
}
/// Serde default for `ChainRequest::fps`.
fn default_fps() -> u32 {
    const FPS: u32 = 24;
    FPS
}
/// Serde default for `ChainRequest::strength`.
fn default_strength() -> f64 {
    const FULL_STRENGTH: f64 = 1.0;
    FULL_STRENGTH
}
/// Serde default for `ChainRequest::output_format`.
fn default_output_format() -> OutputFormat {
    const FORMAT: OutputFormat = OutputFormat::Mp4;
    FORMAT
}
/// Upper bound on the number of stages a chain may contain, whether the stages
/// are supplied explicitly or produced by auto-expand.
pub const MAX_CHAIN_STAGES: usize = 16;
impl ChainRequest {
    /// Validates the request and rewrites it into canonical form.
    ///
    /// * **Canonical form** — `stages` is non-empty and is validated as given.
    /// * **Auto-expand form** — `stages` is empty, so `prompt` + `total_frames`
    ///   (plus optional `clip_frames`, default 97, and `source_image`) are first
    ///   expanded into a stage list.
    ///
    /// After validation, stage 0's transition is coerced to
    /// [`TransitionMode::Smooth`] (with a warning). The auto-expand inputs are
    /// then cleared, so the returned request is described by `stages` alone.
    ///
    /// # Errors
    ///
    /// Returns [`MoldError::Validation`] when the checks fail, in this order:
    /// 1. missing or invalid auto-expand inputs;
    /// 2. no stages, or more than [`MAX_CHAIN_STAGES`];
    /// 3. `motion_tail_frames` neither 0 nor `8k + 1`;
    /// 4. a stage with 0 frames, off-grid frames, or frames not above `motion_tail_frames`;
    /// 5. a stage using a reserved field (`model`, `loras`, `references`).
    pub fn normalise(mut self) -> Result<Self> {
        if self.stages.is_empty() {
            self.expand_from_prompt()?;
        }
        self.validate_stage_count()?;
        self.validate_frame_grid()?;
        self.reject_reserved_stage_fields()?;
        self.coerce_first_transition();

        // The auto-expand inputs were either consumed above or ignored because
        // explicit stages were supplied; clear them so they cannot be re-applied.
        self.prompt = None;
        self.total_frames = None;
        self.clip_frames = None;
        self.source_image = None;
        Ok(self)
    }

    /// Builds `self.stages` from `prompt` + `total_frames` (auto-expand form).
    ///
    /// Takes `prompt` and `source_image` out of `self`.
    fn expand_from_prompt(&mut self) -> Result<()> {
        const DEFAULT_CLIP_FRAMES: u32 = 97;

        let prompt = self.prompt.take().ok_or_else(|| {
            MoldError::Validation(
                "chain request needs either stages[] or prompt + total_frames".into(),
            )
        })?;
        let total_frames = self.total_frames.ok_or_else(|| {
            MoldError::Validation("chain auto-expand requires total_frames".into())
        })?;
        if total_frames == 0 {
            return Err(MoldError::Validation(
                "chain total_frames must be > 0".into(),
            ));
        }

        let clip_frames = self.clip_frames.unwrap_or(DEFAULT_CLIP_FRAMES);
        if clip_frames == 0 {
            return Err(MoldError::Validation(
                "chain clip_frames must be > 0".into(),
            ));
        }
        if !is_ltx2_frame_count(clip_frames) {
            return Err(MoldError::Validation(format!(
                "chain clip_frames ({clip_frames}) must be 8k+1 (9, 17, 25, …, 97)",
            )));
        }

        let motion_tail = self.motion_tail_frames;
        if motion_tail >= clip_frames {
            return Err(MoldError::Validation(format!(
                "motion_tail_frames ({motion_tail}) must be strictly less than clip_frames ({clip_frames})",
            )));
        }

        let source_image = self.source_image.take();
        self.stages = build_auto_expand_stages(
            &prompt,
            total_frames,
            clip_frames,
            motion_tail,
            source_image,
        )?;
        Ok(())
    }

    /// Rejects an empty stage list or one longer than [`MAX_CHAIN_STAGES`].
    fn validate_stage_count(&self) -> Result<()> {
        match self.stages.len() {
            0 => Err(MoldError::Validation("chain request has no stages".into())),
            n if n > MAX_CHAIN_STAGES => Err(MoldError::Validation(format!(
                "chain request has {n} stages; maximum is {MAX_CHAIN_STAGES}",
            ))),
            _ => Ok(()),
        }
    }

    /// Checks `motion_tail_frames` and every stage's `frames` against the
    /// LTX-2 `8k + 1` grid, and checks that each stage is longer than the tail.
    fn validate_frame_grid(&self) -> Result<()> {
        let tail = self.motion_tail_frames;
        if tail != 0 && !is_ltx2_frame_count(tail) {
            return Err(MoldError::Validation(format!(
                "motion_tail_frames ({tail}) must be 0 or 8k+1 (1, 9, 17, 25, …) so the carryover \
                 RGB frames re-encode cleanly through the LTX-2 video VAE's 8× causal grid",
            )));
        }
        for (idx, stage) in self.stages.iter().enumerate() {
            let frames = stage.frames;
            if frames == 0 {
                return Err(MoldError::Validation(format!("stage {idx} has 0 frames")));
            }
            if !is_ltx2_frame_count(frames) {
                return Err(MoldError::Validation(format!(
                    "stage {idx} has {frames} frames; LTX-2 requires 8k+1 (9, 17, 25, …, 97)",
                )));
            }
            if tail >= frames {
                return Err(MoldError::Validation(format!(
                    "motion_tail_frames ({tail}) must be strictly less than stage {idx}'s frames ({frames})",
                )));
            }
        }
        Ok(())
    }

    /// Rejects per-stage fields that are reserved for future sub-projects.
    ///
    /// Runs as a separate pass after [`Self::validate_frame_grid`], so frame
    /// errors on any stage take precedence over reserved-field errors.
    fn reject_reserved_stage_fields(&self) -> Result<()> {
        for (idx, stage) in self.stages.iter().enumerate() {
            if stage.model.is_some() {
                return Err(MoldError::Validation(format!(
                    "stages[{idx}].model is reserved for sub-project C and not yet supported"
                )));
            }
            if !stage.loras.is_empty() {
                return Err(MoldError::Validation(format!(
                    "stages[{idx}].loras is reserved for sub-project B and not yet supported"
                )));
            }
            if !stage.references.is_empty() {
                return Err(MoldError::Validation(format!(
                    "stages[{idx}].references is reserved for sub-project B and not yet supported"
                )));
            }
        }
        Ok(())
    }

    /// Stage 0 has no predecessor, so any non-`Smooth` transition on it is
    /// replaced with `Smooth` and a warning is logged.
    fn coerce_first_transition(&mut self) {
        if let Some(first) = self.stages.first_mut() {
            if first.transition != TransitionMode::Smooth {
                tracing::warn!(
                    coerced_from = ?first.transition,
                    "stage 0 transition is meaningless; coercing to Smooth"
                );
                first.transition = TransitionMode::Smooth;
            }
        }
    }

    /// Estimates how many frames the stitched chain will contain.
    ///
    /// Stage 0 counts in full. Each later stage adds:
    /// * `Smooth` — `frames - motion_tail_frames`;
    /// * `Cut` — `frames`;
    /// * `Fade` — `frames - fade_frames` (`fade_frames` defaults to 8).
    ///
    /// Per-stage subtraction floors at 0. The running sum saturates at
    /// `u32::MAX` instead of overflowing, which would panic in debug builds
    /// and wrap in release. Returns 0 for an empty stage list.
    pub fn estimated_total_frames(&self) -> u32 {
        const DEFAULT_FADE_FRAMES: u32 = 8;

        let tail = self.motion_tail_frames;
        let mut stages = self.stages.iter();
        let Some(first) = stages.next() else {
            return 0;
        };
        stages.fold(first.frames, |total, stage| {
            let added = match stage.transition {
                TransitionMode::Smooth => stage.frames.saturating_sub(tail),
                TransitionMode::Cut => stage.frames,
                TransitionMode::Fade => stage
                    .frames
                    .saturating_sub(stage.fade_frames.unwrap_or(DEFAULT_FADE_FRAMES)),
            };
            total.saturating_add(added)
        })
    }
}
/// Returns `true` when `n` is an LTX-2 frame count, i.e. of the form `8k + 1`
/// (1, 9, 17, …). Zero is rejected.
fn is_ltx2_frame_count(n: u32) -> bool {
    // For an unsigned value, the low three bits are the remainder mod 8.
    (n & 0b111) == 1
}
/// Expands one prompt into identical [`ChainStage`]s that together deliver at
/// least `total_frames` frames.
///
/// Stage arithmetic:
///
/// * `total_frames <= clip_frames` — a single stage. Its length is
///   `total_frames` rounded **up** to the LTX-2 `8k + 1` grid and capped at
///   `clip_frames`. The rounding mirrors the multi-stage branch, which already
///   over-delivers for off-grid totals such as 400. Without it, a request such
///   as 50 frames would yield an off-grid stage that
///   [`ChainRequest::normalise`] rejects.
/// * otherwise — the first stage contributes `clip_frames` and each further
///   stage contributes `clip_frames - motion_tail_frames`. Hence
///   `count = 1 + ceil((total_frames - clip_frames) / (clip_frames - motion_tail_frames))`,
///   and every stage is `clip_frames` long.
///
/// Every stage gets `prompt`, a copy of `source_image`, a `Smooth` transition
/// and no per-stage overrides.
///
/// # Errors
///
/// Returns [`MoldError::Validation`] in two cases:
/// * more than one stage is needed but `motion_tail_frames >= clip_frames`
///   (re-checked here so a direct call cannot divide by zero or underflow);
/// * the stage count would exceed [`MAX_CHAIN_STAGES`].
fn build_auto_expand_stages(
    prompt: &str,
    total_frames: u32,
    clip_frames: u32,
    motion_tail_frames: u32,
    source_image: Option<Vec<u8>>,
) -> Result<Vec<ChainStage>> {
    let (stage_count, per_stage_frames) = if total_frames <= clip_frames {
        (1u32, snap_up_to_ltx2_grid(total_frames).min(clip_frames))
    } else {
        let Some(effective) = clip_frames
            .checked_sub(motion_tail_frames)
            .filter(|&fresh| fresh > 0)
        else {
            return Err(MoldError::Validation(format!(
                "motion_tail_frames ({motion_tail_frames}) must be strictly less than clip_frames ({clip_frames})",
            )));
        };
        // clip_frames >= 1 here, so remainder <= u32::MAX - 1 and `1 + ceil(..)` cannot overflow.
        let remainder = total_frames - clip_frames;
        (1 + remainder.div_ceil(effective), clip_frames)
    };

    let count = usize::try_from(stage_count).unwrap_or(usize::MAX);
    if count > MAX_CHAIN_STAGES {
        return Err(MoldError::Validation(format!(
            "auto-expand would produce {stage_count} stages; maximum is {MAX_CHAIN_STAGES} \
             (try reducing total_frames or increasing clip_frames)",
        )));
    }

    let template = ChainStage {
        prompt: prompt.to_string(),
        frames: per_stage_frames,
        source_image,
        negative_prompt: None,
        seed_offset: None,
        transition: TransitionMode::Smooth,
        fade_frames: None,
        model: None,
        loras: Vec::new(),
        references: Vec::new(),
    };
    // `count >= 1` on every path, so the template always ends up in the result.
    Ok(vec![template; count])
}

/// Rounds `n` up to the nearest LTX-2 frame count (`8k + 1`): 0 and 1 map to 1,
/// 2..=9 map to 9, 10..=17 map to 17, and so on.
///
/// Saturates at `u32::MAX` (which is itself off-grid) when the next grid value
/// does not fit in a `u32`.
fn snap_up_to_ltx2_grid(n: u32) -> u32 {
    match n.checked_sub(1) {
        None | Some(0) => 1,
        Some(offset) => offset.div_ceil(8).saturating_mul(8).saturating_add(1),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const MODEL: &str = "ltx-2-19b-distilled:fp8";

    /// Canonical-form request: explicit `stages`, no auto-expand inputs.
    fn canonical_request(stages: Vec<ChainStage>, motion_tail_frames: u32) -> ChainRequest {
        ChainRequest {
            model: MODEL.into(),
            stages,
            motion_tail_frames,
            width: 1216,
            height: 704,
            fps: 24,
            seed: Some(42),
            steps: 8,
            guidance: 3.0,
            strength: 1.0,
            output_format: OutputFormat::Mp4,
            placement: None,
            prompt: None,
            total_frames: None,
            clip_frames: None,
            source_image: None,
            enable_audio: None,
        }
    }

    /// Auto-expand request: no stages, a single prompt plus a frame budget.
    fn auto_expand_request(
        prompt: &str,
        total_frames: u32,
        clip_frames: u32,
        motion_tail_frames: u32,
        source_image: Option<Vec<u8>>,
    ) -> ChainRequest {
        ChainRequest {
            prompt: Some(prompt.into()),
            total_frames: Some(total_frames),
            clip_frames: Some(clip_frames),
            source_image,
            ..canonical_request(Vec::new(), motion_tail_frames)
        }
    }

    /// Plain `Smooth` stage with prompt `"test"` and no overrides.
    fn make_stage(frames: u32) -> ChainStage {
        ChainStage {
            prompt: "test".into(),
            frames,
            source_image: None,
            negative_prompt: None,
            seed_offset: None,
            transition: TransitionMode::Smooth,
            fade_frames: None,
            model: None,
            loras: Vec::new(),
            references: Vec::new(),
        }
    }

    /// Request built from `(transition, frames, fade_frames)` triples, tail 25, no seed.
    fn stage_list_request(stages: Vec<(TransitionMode, u32, Option<u32>)>) -> ChainRequest {
        let stages = stages
            .into_iter()
            .map(|(transition, frames, fade_frames)| ChainStage {
                prompt: "x".into(),
                transition,
                fade_frames,
                ..make_stage(frames)
            })
            .collect();
        ChainRequest {
            seed: None,
            ..canonical_request(stages, 25)
        }
    }

    /// Smallest JSON body that deserialises into a `ChainRequest`.
    fn minimal_request_json() -> serde_json::Value {
        serde_json::json!({
            "model": "ltx-2.3-22b-distilled:fp8",
            "stages": [],
            "width": 704,
            "height": 416,
            "steps": 4,
            "guidance": 3.0,
        })
    }

    /// Asserts `err` is a validation error whose message contains `needle`.
    #[track_caller]
    fn assert_validation_containing(err: MoldError, needle: &str) {
        match err {
            MoldError::Validation(msg) => assert!(
                msg.contains(needle),
                "validation message {msg:?} should mention {needle:?}",
            ),
            other => panic!("expected a validation error mentioning {needle:?}, got {other:?}"),
        }
    }

    #[test]
    fn normalise_splits_single_prompt_into_stages() {
        let out = auto_expand_request("a cat walking", 400, 97, 9, None)
            .normalise()
            .expect("normalise should succeed");
        assert_eq!(
            out.stages.len(),
            5,
            "400/97 with a 9-frame motion tail should expand to 5 stages",
        );
        for stage in &out.stages {
            assert_eq!(stage.frames, 97);
            assert_eq!(stage.prompt, "a cat walking");
            assert!(stage.seed_offset.is_none());
        }
        assert!(out.prompt.is_none());
        assert!(out.total_frames.is_none());
        assert!(out.clip_frames.is_none());
        assert!(out.source_image.is_none());
    }

    #[test]
    fn normalise_preserves_starting_image_across_all_stages() {
        let png = vec![0x89, 0x50, 0x4e, 0x47, 0xde, 0xad, 0xbe, 0xef];
        let out = auto_expand_request("test", 200, 97, 9, Some(png.clone()))
            .normalise()
            .expect("normalise should succeed");
        assert!(out.stages.len() >= 2);
        for (idx, stage) in out.stages.iter().enumerate() {
            assert_eq!(
                stage.source_image.as_deref(),
                Some(png.as_slice()),
                "stage {idx} must carry the starting image for cross-stage identity anchoring",
            );
        }
    }

    #[test]
    fn normalise_rejects_empty() {
        let err = canonical_request(Vec::new(), 9)
            .normalise()
            .expect_err("empty chain should fail");
        assert!(
            matches!(err, MoldError::Validation(_)),
            "empty chain should be a validation error, got {err:?}",
        );
    }

    #[test]
    fn normalise_rejects_non_8k1_frames() {
        let err = canonical_request(vec![make_stage(50)], 9)
            .normalise()
            .expect_err("non-8k+1 frames should fail");
        assert_validation_containing(err, "8k+1");
    }

    #[test]
    fn normalise_accepts_canonical_form_unchanged() {
        let stages = vec![make_stage(97), make_stage(97), make_stage(97)];
        let out = canonical_request(stages.clone(), 9)
            .normalise()
            .expect("valid canonical form should pass");
        assert_eq!(out.stages.len(), 3);
        for (got, want) in out.stages.iter().zip(&stages) {
            assert_eq!(got.frames, want.frames);
            assert_eq!(got.prompt, want.prompt);
        }
    }

    #[test]
    fn normalise_single_stage_when_total_leq_clip() {
        let out = auto_expand_request("short", 9, 97, 1, None)
            .normalise()
            .expect("short single-clip chain should pass");
        assert_eq!(out.stages.len(), 1);
        assert_eq!(out.stages[0].frames, 9);
    }

    #[test]
    fn normalise_rejects_too_many_stages() {
        let err = canonical_request(vec![make_stage(97); 17], 9)
            .normalise()
            .expect_err("17-stage chain should fail");
        assert_validation_containing(err, "maximum");
    }

    #[test]
    fn normalise_rejects_auto_expand_too_long() {
        let err = auto_expand_request("too long", 4000, 97, 9, None)
            .normalise()
            .expect_err("runaway auto-expand should fail");
        assert_validation_containing(err, "stages");
    }

    #[test]
    fn normalise_rejects_motion_tail_ge_clip() {
        let err = auto_expand_request("bad tail", 200, 97, 97, None)
            .normalise()
            .expect_err("motion_tail >= clip should fail");
        assert_validation_containing(err, "motion_tail_frames");
    }

    #[test]
    fn enable_audio_defaults_to_none_and_round_trips_when_set() {
        let req: ChainRequest = serde_json::from_value(minimal_request_json())
            .expect("valid minimal chain request");
        assert_eq!(req.enable_audio, None);

        let mut body = minimal_request_json();
        body["stages"] = serde_json::json!([{"prompt": "a bird", "frames": 33}]);
        body["enable_audio"] = serde_json::json!(true);
        let req_with_audio: ChainRequest =
            serde_json::from_value(body).expect("valid chain request with audio");
        assert_eq!(req_with_audio.enable_audio, Some(true));

        let script = ChainScript::from(&req_with_audio);
        assert_eq!(
            script.chain.enable_audio,
            Some(true),
            "ChainScript echo must preserve enable_audio for round-trip save/reload",
        );
    }

    #[test]
    fn motion_tail_default_lands_on_8k_plus_1_grid() {
        let req: ChainRequest = serde_json::from_value(minimal_request_json())
            .expect("valid minimal chain request");
        assert_eq!(req.motion_tail_frames, 17);
        assert!(is_ltx2_frame_count(req.motion_tail_frames));
    }

    #[test]
    fn normalise_rejects_motion_tail_off_grid() {
        let err = canonical_request(vec![make_stage(33)], 4)
            .normalise()
            .expect_err("motion_tail_frames=4 must be rejected");
        assert_validation_containing(err, "8k+1");
    }

    #[test]
    fn normalise_accepts_motion_tail_zero() {
        let cut = ChainStage {
            transition: TransitionMode::Cut,
            ..make_stage(33)
        };
        canonical_request(vec![make_stage(33), cut], 0)
            .normalise()
            .expect("motion_tail=0 must be accepted");
    }

    #[test]
    fn normalise_rejects_missing_total_frames_in_auto_expand() {
        let req = ChainRequest {
            prompt: Some("missing total".into()),
            ..canonical_request(Vec::new(), 4)
        };
        let err = req
            .normalise()
            .expect_err("missing total_frames should fail");
        assert_validation_containing(err, "total_frames");
    }

    #[test]
    fn is_ltx2_frame_count_matches_8k_plus_1() {
        for valid in (0u32..=12).map(|k| 8 * k + 1) {
            assert!(
                is_ltx2_frame_count(valid),
                "{valid} should be a valid LTX-2 frame count",
            );
        }
        for invalid in [0u32, 2, 8, 10, 16, 50, 96, 98, 100] {
            assert!(
                !is_ltx2_frame_count(invalid),
                "{invalid} must not pass the 8k+1 check",
            );
        }
    }

    #[test]
    fn chain_progress_event_roundtrips_json_with_snake_case_tags() {
        let cases = [
            (
                "chain_start",
                ChainProgressEvent::ChainStart {
                    stage_count: 5,
                    estimated_total_frames: 469,
                },
            ),
            ("stage_start", ChainProgressEvent::StageStart { stage_idx: 0 }),
            (
                "denoise_step",
                ChainProgressEvent::DenoiseStep {
                    stage_idx: 2,
                    step: 4,
                    total: 8,
                },
            ),
            (
                "stage_done",
                ChainProgressEvent::StageDone {
                    stage_idx: 3,
                    frames_emitted: 97,
                },
            ),
            ("stitching", ChainProgressEvent::Stitching { total_frames: 400 }),
        ];
        for (tag, event) in cases {
            let expected_tag = format!(r#""type":"{tag}""#);
            let json = serde_json::to_string(&event).expect("serialize");
            assert!(
                json.contains(&expected_tag),
                "missing snake_case tag {expected_tag} in {json}",
            );
            let roundtrip: ChainProgressEvent = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(roundtrip, event, "roundtrip must preserve payload");
        }
    }

    #[test]
    fn build_stages_math_matches_stitch_budget() {
        let cases = [
            (400u32, 97u32, 9u32, 5u32),
            (200, 97, 9, 3),
            (97, 97, 9, 1),
            (300, 97, 0, 4),
        ];
        for (total, clip, tail, expected_n) in cases {
            let req = auto_expand_request("m", total, clip, tail, None)
                .normalise()
                .expect("valid auto-expand should normalise");
            assert_eq!(
                req.stages.len() as u32,
                expected_n,
                "expected {expected_n} stages for total={total}, clip={clip}, tail={tail}",
            );
            let delivered = clip + (expected_n - 1) * (clip - tail);
            assert!(
                delivered >= total,
                "{expected_n} stages deliver {delivered} frames but {total} were requested",
            );
        }
    }

    #[test]
    fn transition_mode_serializes_snake_case() {
        for (mode, name) in [
            (TransitionMode::Smooth, "smooth"),
            (TransitionMode::Cut, "cut"),
            (TransitionMode::Fade, "fade"),
        ] {
            assert_eq!(
                serde_json::to_value(mode).unwrap(),
                serde_json::Value::String(name.into())
            );
        }
    }

    #[test]
    fn transition_mode_defaults_to_smooth() {
        assert_eq!(TransitionMode::default(), TransitionMode::Smooth);
    }

    #[test]
    fn lora_spec_serializes_minimal() {
        let json = serde_json::to_string(&LoraSpec {
            path: "./style.safetensors".into(),
            scale: 0.8,
            name: None,
        })
        .unwrap();
        assert!(json.contains(r#""path":"./style.safetensors""#));
        assert!(json.contains(r#""scale":0.8"#));
        assert!(!json.contains(r#""name""#));
    }

    #[test]
    fn named_ref_serializes_minimal() {
        let json = serde_json::to_string(&NamedRef {
            name: "hero".into(),
            image: vec![0x89, 0x50],
        })
        .unwrap();
        assert!(json.contains(r#""name":"hero""#));
        assert!(json.contains(r#""image":"#));
    }

    #[test]
    fn chain_stage_defaults_are_backcompat() {
        let stage: ChainStage = serde_json::from_str(r#"{ "prompt": "a cat", "frames": 97 }"#).unwrap();
        assert_eq!(stage.prompt, "a cat");
        assert_eq!(stage.frames, 97);
        assert_eq!(stage.transition, TransitionMode::Smooth);
        assert_eq!(stage.fade_frames, None);
        assert!(stage.model.is_none());
        assert!(stage.loras.is_empty());
        assert!(stage.references.is_empty());
    }

    #[test]
    fn chain_script_projects_from_request() {
        let only_stage = ChainStage {
            prompt: "a".into(),
            ..make_stage(97)
        };
        let req = canonical_request(vec![only_stage], 25);
        let script = ChainScript::from(&req);
        assert_eq!(script.chain.model, MODEL);
        assert_eq!(script.chain.seed, Some(42));
        assert_eq!(script.stages.len(), 1);
        assert_eq!(script.stages[0].prompt, "a");
    }

    #[test]
    fn chain_stage_roundtrips_all_fields() {
        let stage = ChainStage {
            prompt: "scene".into(),
            transition: TransitionMode::Cut,
            fade_frames: Some(12),
            ..make_stage(49)
        };
        let json = serde_json::to_string(&stage).unwrap();
        let back: ChainStage = serde_json::from_str(&json).unwrap();
        assert_eq!(back.frames, 49);
        assert_eq!(back.transition, TransitionMode::Cut);
        assert_eq!(back.fade_frames, Some(12));
    }

    #[test]
    fn normalise_coerces_stage_0_transition_to_smooth() {
        let mut req = auto_expand_request("a", 97, 97, 25, None);
        req.stages = ["scene 0", "scene 1"]
            .into_iter()
            .map(|prompt| ChainStage {
                prompt: prompt.into(),
                transition: TransitionMode::Cut,
                ..make_stage(97)
            })
            .collect();
        let out = req.normalise().unwrap();
        assert_eq!(out.stages[0].transition, TransitionMode::Smooth);
        assert_eq!(out.stages[1].transition, TransitionMode::Cut);
    }

    #[test]
    fn normalise_rejects_reserved_model_field() {
        let mut req = auto_expand_request("a", 97, 97, 25, None);
        req.stages = vec![ChainStage {
            prompt: "x".into(),
            model: Some("flux-dev:q4".into()),
            ..make_stage(97)
        }];
        let err = req.normalise().unwrap_err().to_string();
        assert!(err.contains("reserved for sub-project C"), "got: {err}");
    }

    #[test]
    fn normalise_rejects_reserved_loras_field() {
        let mut req = auto_expand_request("a", 97, 97, 25, None);
        req.stages = vec![ChainStage {
            prompt: "x".into(),
            loras: vec![LoraSpec {
                path: "x.safetensors".into(),
                scale: 1.0,
                name: None,
            }],
            ..make_stage(97)
        }];
        let err = req.normalise().unwrap_err().to_string();
        assert!(err.contains("reserved for sub-project B"), "got: {err}");
    }

    #[test]
    fn estimated_total_all_smooth() {
        let req = stage_list_request(vec![(TransitionMode::Smooth, 97, None); 3]);
        assert_eq!(req.estimated_total_frames(), 241);
    }

    #[test]
    fn estimated_total_with_cut() {
        let req = stage_list_request(vec![
            (TransitionMode::Smooth, 97, None),
            (TransitionMode::Cut, 97, None),
            (TransitionMode::Smooth, 97, None),
        ]);
        assert_eq!(req.estimated_total_frames(), 266);
    }

    #[test]
    fn estimated_total_with_fade() {
        let req = stage_list_request(vec![
            (TransitionMode::Smooth, 97, None),
            (TransitionMode::Cut, 97, None),
            (TransitionMode::Fade, 97, Some(8)),
        ]);
        assert_eq!(req.estimated_total_frames(), 283);
    }
}