1use serde::{Deserialize, Serialize};
17
18use crate::error::{MoldError, Result};
19use crate::types::{DevicePlacement, OutputFormat, VideoData};
20
/// How a chain stage is joined onto the stage before it.
///
/// Serialized in `snake_case` (`"smooth"`, `"cut"`, `"fade"`) and defaults to
/// [`TransitionMode::Smooth`]. `ChainRequest::normalise` forces stage 0 to `Smooth`
/// because stage 0 has no predecessor to transition from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, utoipa::ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum TransitionMode {
    /// Continue from the previous stage's carried-over motion tail.
    /// `ChainRequest::estimated_total_frames` counts `frames - motion_tail_frames`
    /// new frames for such a stage.
    #[default]
    Smooth,
    /// Hard cut: `ChainRequest::estimated_total_frames` counts all of the stage's frames.
    Cut,
    /// Fade transition: `ChainRequest::estimated_total_frames` counts
    /// `frames - fade_frames` (8 when `fade_frames` is unset).
    Fade,
}
42
/// A LoRA adapter requested for a single chain stage.
///
/// Reserved: `ChainRequest::normalise` currently rejects any stage whose `loras`
/// list is non-empty.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct LoraSpec {
    /// Adapter file path (the tests use `./style.safetensors`).
    pub path: String,
    /// Scale applied to the adapter.
    pub scale: f64,
    /// Optional display name; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
54
/// A named reference image attached to a chain stage.
///
/// Reserved: `ChainRequest::normalise` currently rejects any stage whose
/// `references` list is non-empty.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct NamedRef {
    /// Name the reference is identified by.
    pub name: String,
    /// Raw image bytes, (de)serialized through `crate::types::base64_bytes`.
    #[serde(with = "crate::types::base64_bytes")]
    pub image: Vec<u8>,
}
63
/// One segment of a chained video render.
///
/// Only `prompt` and `frames` are required when deserializing. Every other field has a
/// serde default, so minimal stage payloads keep parsing.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainStage {
    /// Text prompt for this stage.
    #[schema(example = "a cat walking through autumn leaves")]
    pub prompt: String,

    /// Frames rendered for this stage. `ChainRequest::normalise` requires a non-zero
    /// `8k + 1` count that is strictly greater than `motion_tail_frames`.
    #[schema(example = 97)]
    pub frames: u32,

    /// Optional starting image bytes, (de)serialized through `crate::types::base64_opt`
    /// and omitted from output when `None`. Auto-expansion copies the request's
    /// top-level `source_image` into every generated stage.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "crate::types::base64_opt"
    )]
    pub source_image: Option<Vec<u8>>,

    /// Optional negative prompt for this stage.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub negative_prompt: Option<String>,

    /// Optional per-stage seed offset. How it combines with `ChainRequest::seed` is
    /// decided by the renderer and is not visible in this module.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub seed_offset: Option<u64>,

    /// How this stage joins the previous one; defaults to `Smooth`.
    /// `ChainRequest::normalise` forces stage 0 to `Smooth`.
    #[serde(default)]
    pub transition: TransitionMode,

    /// Fade length used with `TransitionMode::Fade`.
    /// `ChainRequest::estimated_total_frames` assumes 8 when this is unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fade_frames: Option<u32>,

    /// Reserved per-stage model override; `ChainRequest::normalise` rejects it when set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,

    /// Reserved per-stage LoRA list; `ChainRequest::normalise` rejects it when non-empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub loras: Vec<LoraSpec>,

    /// Reserved named reference images; `ChainRequest::normalise` rejects them when
    /// non-empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub references: Vec<NamedRef>,
}
131
/// A request to render a chain of video stages.
///
/// A request can take either of two forms:
///
/// - Canonical: an explicit `stages` list.
/// - Auto-expand: `stages` left empty, with `prompt` + `total_frames` (and optionally
///   `clip_frames` / `source_image`) supplied instead.
///
/// `ChainRequest::normalise` validates either form and rewrites it to the canonical one.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainRequest {
    /// Model identifier used for the chain.
    #[schema(example = "ltx-2-19b-distilled:fp8")]
    pub model: String,

    /// Canonical stage list. It may be empty on input to request auto-expansion.
    #[serde(default)]
    pub stages: Vec<ChainStage>,

    /// Frames carried over from one stage into the next. Defaults to 17.
    /// `normalise` requires it to be `0` or `8k + 1` and strictly less than every
    /// stage's `frames` (and less than `clip_frames` when auto-expanding).
    #[serde(default = "default_motion_tail_frames")]
    #[schema(example = 17)]
    pub motion_tail_frames: u32,

    /// Output frame width.
    #[schema(example = 1216)]
    pub width: u32,
    /// Output frame height.
    #[schema(example = 704)]
    pub height: u32,
    /// Output frame rate; defaults to 24.
    #[serde(default = "default_fps")]
    #[schema(example = 24)]
    pub fps: u32,

    /// Optional base seed; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[schema(example = 42)]
    pub seed: Option<u64>,

    /// Number of denoising steps.
    #[schema(example = 8)]
    pub steps: u32,

    /// Guidance scale.
    #[schema(example = 3.0)]
    pub guidance: f64,

    /// Strength; defaults to 1.0.
    #[serde(default = "default_strength")]
    #[schema(example = 1.0)]
    pub strength: f64,

    /// Output container format; defaults to `OutputFormat::Mp4`.
    #[serde(default = "default_output_format")]
    pub output_format: OutputFormat,

    /// Optional device placement; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub placement: Option<DevicePlacement>,

    /// Auto-expand prompt applied to every generated stage. Cleared by `normalise`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,

    /// Auto-expand target length in frames; must be > 0. Cleared by `normalise`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub total_frames: Option<u32>,

    /// Auto-expand per-stage length. `normalise` uses 97 when unset and requires an
    /// `8k + 1` value. Cleared by `normalise`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub clip_frames: Option<u32>,

    /// Auto-expand starting image, copied into every generated stage and
    /// (de)serialized through `crate::types::base64_opt`. Cleared by `normalise`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        with = "crate::types::base64_opt"
    )]
    pub source_image: Option<Vec<u8>>,

    /// Optional audio toggle; echoed into `ChainScriptChain::enable_audio`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub enable_audio: Option<bool>,
}
226
/// Portable, re-loadable description of a chain: the shared render settings plus the
/// stage list.
///
/// `From<&ChainRequest>` builds one with `schema` set to `"mold.chain.v1"`.
/// `Default` leaves `schema` empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainScript {
    /// Schema tag identifying the script format version.
    pub schema: String,
    /// Chain-wide render settings.
    pub chain: ChainScriptChain,
    /// Stages, serialized under the singular key `stage`.
    #[serde(rename = "stage")]
    pub stages: Vec<ChainStage>,
}
239
/// Chain-wide settings stored in a [`ChainScript`].
///
/// These fields mirror the identically named [`ChainRequest`] fields and are copied
/// from them by `From<&ChainRequest> for ChainScript`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainScriptChain {
    pub model: String,
    pub width: u32,
    pub height: u32,
    pub fps: u32,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    pub steps: u32,
    pub guidance: f64,
    pub strength: f64,
    pub motion_tail_frames: u32,
    pub output_format: OutputFormat,
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub enable_audio: Option<bool>,
}
258
259impl From<&ChainRequest> for ChainScript {
260 fn from(req: &ChainRequest) -> Self {
261 ChainScript {
262 schema: "mold.chain.v1".into(),
263 chain: ChainScriptChain {
264 model: req.model.clone(),
265 width: req.width,
266 height: req.height,
267 fps: req.fps,
268 seed: req.seed,
269 steps: req.steps,
270 guidance: req.guidance,
271 strength: req.strength,
272 motion_tail_frames: req.motion_tail_frames,
273 output_format: req.output_format,
274 enable_audio: req.enable_audio,
275 },
276 stages: req.stages.clone(),
277 }
278 }
279}
280
/// VRAM estimate attached to chain responses. It is computed elsewhere; only the shape
/// is defined here.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct VramEstimate {
    /// Worst-case VRAM requirement in bytes (per the field name).
    pub worst_case_bytes: u64,
    /// Whether the estimate fits the target device (presumably the assigned GPU;
    /// confirm with the producer).
    pub fits: bool,
}
288
/// Result of a completed chain render.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainResponse {
    /// The rendered output video.
    pub video: VideoData,
    /// Number of stages in the chain.
    #[schema(example = 5)]
    pub stage_count: u32,
    /// GPU the chain ran on, when reported (presumably a device ordinal); omitted when
    /// `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gpu: Option<usize>,

    /// Echo of the request as a [`ChainScript`] so it can be saved and reloaded.
    pub script: ChainScript,

    /// Optional VRAM estimate; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub vram_estimate: Option<VramEstimate>,
}
311
/// Completion payload for a streamed (SSE) chain render.
///
/// Optional fields are omitted from the serialized output when `None`. `has_audio` is
/// omitted when `false`.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct SseChainCompleteEvent {
    /// Encoded video payload (presumably base64; confirm against the SSE producer).
    pub video: String,
    /// Container format of `video`.
    pub format: OutputFormat,
    #[schema(example = 1216)]
    pub width: u32,
    #[schema(example = 704)]
    pub height: u32,
    /// Frame count of the stitched video.
    #[schema(example = 400)]
    pub frames: u32,
    #[schema(example = 24)]
    pub fps: u32,
    /// Optional thumbnail; its encoding is decided by the producer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thumbnail: Option<String>,
    /// Optional GIF preview; its encoding is decided by the producer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gif_preview: Option<String>,
    /// Whether the video has an audio track. Skipped when `false` and defaults to `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub has_audio: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub audio_sample_rate: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub audio_channels: Option<u32>,
    /// Number of stages in the chain.
    #[schema(example = 5)]
    pub stage_count: u32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gpu: Option<usize>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub generation_time_ms: Option<u64>,
    /// Echo of the request as a [`ChainScript`]. Falls back to `ChainScript::default()`
    /// when absent on deserialization.
    #[serde(default)]
    pub script: ChainScript,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub vram_estimate: Option<VramEstimate>,
}
364
/// Progress events emitted while a chain renders.
///
/// The enum is internally tagged: each event serializes as one JSON object whose
/// `"type"` field is the `snake_case` variant name, for example
/// `{"type":"stage_start","stage_idx":0}`.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema, PartialEq, Eq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChainProgressEvent {
    /// Chain start: number of stages and the estimated output length in frames.
    ChainStart {
        stage_count: u32,
        estimated_total_frames: u32,
    },
    /// A stage (zero-based `stage_idx`) has started.
    StageStart { stage_idx: u32 },
    /// Denoising progress within a stage: `step` out of `total`.
    DenoiseStep {
        stage_idx: u32,
        step: u32,
        total: u32,
    },
    /// A stage finished and contributed `frames_emitted` frames.
    StageDone { stage_idx: u32, frames_emitted: u32 },
    /// Stages are being stitched into a `total_frames`-frame output.
    Stitching { total_frames: u32 },
}
398
/// Failure report for a chain render.
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ChainFailure {
    /// Top-level error summary.
    #[schema(example = "stage render failed")]
    pub error: String,
    /// Zero-based index of the stage that failed.
    #[schema(example = 2)]
    pub failed_stage_idx: u32,
    /// Stages processed before the failure (presumably completed stages; the schema
    /// example equals `failed_stage_idx`, so confirm with the producer).
    #[schema(example = 2)]
    pub elapsed_stages: u32,
    /// Time spent before the failure, in milliseconds (per the `_ms` suffix).
    #[schema(example = 12_340)]
    pub elapsed_ms: u64,
    /// Underlying error message from the failing stage.
    #[schema(example = "simulated GPU OOM on stage 2")]
    pub stage_error: String,
}
420
/// Serde default for `ChainRequest::motion_tail_frames`.
///
/// The value is 17 = 2·8 + 1, so the default tail already lies on the `8k + 1` grid
/// that `ChainRequest::normalise` enforces.
fn default_motion_tail_frames() -> u32 {
    // k = 2 on the 8k + 1 grid.
    2 * 8 + 1
}
424
/// Serde default for `ChainRequest::fps`: 24 frames per second.
fn default_fps() -> u32 {
    const DEFAULT_FPS: u32 = 24;
    DEFAULT_FPS
}
428
/// Serde default for `ChainRequest::strength`: full strength (`1.0`).
fn default_strength() -> f64 {
    const FULL_STRENGTH: f64 = 1.0;
    FULL_STRENGTH
}
432
/// Serde default for `ChainRequest::output_format`: MP4.
fn default_output_format() -> OutputFormat {
    OutputFormat::Mp4
}
436
/// Upper bound on the number of stages in one chain.
///
/// Enforced by `ChainRequest::normalise` on the final stage list, and by
/// `build_auto_expand_stages` before it expands a single prompt.
pub const MAX_CHAIN_STAGES: usize = 16;
441
442impl ChainRequest {
443 pub fn normalise(mut self) -> Result<Self> {
453 if self.stages.is_empty() {
454 let prompt = self.prompt.take().ok_or_else(|| {
455 MoldError::Validation(
456 "chain request needs either stages[] or prompt + total_frames".into(),
457 )
458 })?;
459 let total_frames = self.total_frames.ok_or_else(|| {
460 MoldError::Validation("chain auto-expand requires total_frames".into())
461 })?;
462 if total_frames == 0 {
463 return Err(MoldError::Validation(
464 "chain total_frames must be > 0".into(),
465 ));
466 }
467 let clip_frames = self.clip_frames.unwrap_or(97);
468 if clip_frames == 0 {
469 return Err(MoldError::Validation(
470 "chain clip_frames must be > 0".into(),
471 ));
472 }
473 if !is_ltx2_frame_count(clip_frames) {
474 return Err(MoldError::Validation(format!(
475 "chain clip_frames ({clip_frames}) must be 8k+1 (9, 17, 25, …, 97)",
476 )));
477 }
478 let motion_tail = self.motion_tail_frames;
479 if motion_tail >= clip_frames {
480 return Err(MoldError::Validation(format!(
481 "motion_tail_frames ({motion_tail}) must be strictly less than clip_frames ({clip_frames})",
482 )));
483 }
484
485 let source_image = self.source_image.take();
486 self.stages = build_auto_expand_stages(
487 &prompt,
488 total_frames,
489 clip_frames,
490 motion_tail,
491 source_image,
492 )?;
493 }
494
495 if self.stages.is_empty() {
496 return Err(MoldError::Validation("chain request has no stages".into()));
497 }
498 if self.stages.len() > MAX_CHAIN_STAGES {
499 return Err(MoldError::Validation(format!(
500 "chain request has {} stages; maximum is {}",
501 self.stages.len(),
502 MAX_CHAIN_STAGES,
503 )));
504 }
505 if self.motion_tail_frames != 0 && !is_ltx2_frame_count(self.motion_tail_frames) {
506 return Err(MoldError::Validation(format!(
507 "motion_tail_frames ({}) must be 0 or 8k+1 (1, 9, 17, 25, …) so the carryover \
508 RGB frames re-encode cleanly through the LTX-2 video VAE's 8× causal grid",
509 self.motion_tail_frames,
510 )));
511 }
512 for (idx, stage) in self.stages.iter().enumerate() {
513 if stage.frames == 0 {
514 return Err(MoldError::Validation(format!("stage {idx} has 0 frames",)));
515 }
516 if !is_ltx2_frame_count(stage.frames) {
517 return Err(MoldError::Validation(format!(
518 "stage {idx} has {} frames; LTX-2 requires 8k+1 (9, 17, 25, …, 97)",
519 stage.frames,
520 )));
521 }
522 if self.motion_tail_frames >= stage.frames {
523 return Err(MoldError::Validation(format!(
524 "motion_tail_frames ({}) must be strictly less than stage {idx}'s frames ({})",
525 self.motion_tail_frames, stage.frames,
526 )));
527 }
528 }
529
530 for (idx, stage) in self.stages.iter().enumerate() {
532 if stage.model.is_some() {
533 return Err(MoldError::Validation(format!(
534 "stages[{idx}].model is reserved for sub-project C and not yet supported"
535 )));
536 }
537 if !stage.loras.is_empty() {
538 return Err(MoldError::Validation(format!(
539 "stages[{idx}].loras is reserved for sub-project B and not yet supported"
540 )));
541 }
542 if !stage.references.is_empty() {
543 return Err(MoldError::Validation(format!(
544 "stages[{idx}].references is reserved for sub-project B and not yet supported"
545 )));
546 }
547 }
548
549 if let Some(first) = self.stages.first_mut() {
552 if first.transition != TransitionMode::Smooth {
553 tracing::warn!(
554 coerced_from = ?first.transition,
555 "stage 0 transition is meaningless; coercing to Smooth"
556 );
557 first.transition = TransitionMode::Smooth;
558 }
559 }
560
561 self.prompt = None;
564 self.total_frames = None;
565 self.clip_frames = None;
566 self.source_image = None;
567
568 Ok(self)
569 }
570
571 pub fn estimated_total_frames(&self) -> u32 {
581 const DEFAULT_FADE_FRAMES: u32 = 8;
582 let mut total: u32 = 0;
583 for (idx, stage) in self.stages.iter().enumerate() {
584 if idx == 0 {
585 total += stage.frames;
586 continue;
587 }
588 match stage.transition {
589 TransitionMode::Smooth => {
590 total += stage.frames.saturating_sub(self.motion_tail_frames);
591 }
592 TransitionMode::Cut => {
593 total += stage.frames;
594 }
595 TransitionMode::Fade => {
596 let fade_len = stage.fade_frames.unwrap_or(DEFAULT_FADE_FRAMES);
597 total += stage.frames.saturating_sub(fade_len);
598 }
599 }
600 }
601 total
602 }
603}
604
/// Returns `true` when `n` lies on the LTX-2 `8k + 1` frame grid (1, 9, 17, …, 97, …).
///
/// `0` is not on the grid: its low three bits are `000`, not `001`.
fn is_ltx2_frame_count(n: u32) -> bool {
    // For unsigned integers, masking the low three bits equals `n % 8`.
    (n & 0b111) == 1
}
612
613fn build_auto_expand_stages(
625 prompt: &str,
626 total_frames: u32,
627 clip_frames: u32,
628 motion_tail_frames: u32,
629 source_image: Option<Vec<u8>>,
630) -> Result<Vec<ChainStage>> {
631 let (stage_count, per_stage_frames) = if total_frames <= clip_frames {
632 (1u32, total_frames)
636 } else {
637 let effective = clip_frames - motion_tail_frames;
638 let remainder = total_frames - clip_frames;
641 let count = 1 + remainder.div_ceil(effective);
642 (count, clip_frames)
643 };
644
645 let count_usize = stage_count as usize;
646 if count_usize > MAX_CHAIN_STAGES {
647 return Err(MoldError::Validation(format!(
648 "auto-expand would produce {stage_count} stages; maximum is {MAX_CHAIN_STAGES} \
649 (try reducing total_frames or increasing clip_frames)",
650 )));
651 }
652
653 let mut stages = Vec::with_capacity(count_usize);
654 for _ in 0..stage_count {
655 stages.push(ChainStage {
665 prompt: prompt.to_string(),
666 frames: per_stage_frames,
667 source_image: source_image.clone(),
668 negative_prompt: None,
669 seed_offset: None,
670 transition: TransitionMode::Smooth,
671 fade_frames: None,
672 model: None,
673 loras: vec![],
674 references: vec![],
675 });
676 }
677 Ok(stages)
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an auto-expand request (empty `stages`; prompt plus total/clip frame counts)
    // with fixed render settings.
    fn auto_expand_request(
        prompt: &str,
        total_frames: u32,
        clip_frames: u32,
        motion_tail_frames: u32,
        source_image: Option<Vec<u8>>,
    ) -> ChainRequest {
        ChainRequest {
            model: "ltx-2-19b-distilled:fp8".into(),
            stages: Vec::new(),
            motion_tail_frames,
            width: 1216,
            height: 704,
            fps: 24,
            seed: Some(42),
            steps: 8,
            guidance: 3.0,
            strength: 1.0,
            output_format: OutputFormat::Mp4,
            placement: None,
            prompt: Some(prompt.into()),
            total_frames: Some(total_frames),
            clip_frames: Some(clip_frames),
            source_image,
            enable_audio: None,
        }
    }

    // Builds a canonical request from an explicit stage list, with every auto-expand
    // field unset.
    fn canonical_request(stages: Vec<ChainStage>, motion_tail_frames: u32) -> ChainRequest {
        ChainRequest {
            model: "ltx-2-19b-distilled:fp8".into(),
            stages,
            motion_tail_frames,
            width: 1216,
            height: 704,
            fps: 24,
            seed: Some(42),
            steps: 8,
            guidance: 3.0,
            strength: 1.0,
            output_format: OutputFormat::Mp4,
            placement: None,
            prompt: None,
            total_frames: None,
            clip_frames: None,
            source_image: None,
            enable_audio: None,
        }
    }

    // A minimal Smooth stage with the given frame count and no optional fields.
    fn make_stage(frames: u32) -> ChainStage {
        ChainStage {
            prompt: "test".into(),
            frames,
            source_image: None,
            negative_prompt: None,
            seed_offset: None,
            transition: TransitionMode::Smooth,
            fade_frames: None,
            model: None,
            loras: vec![],
            references: vec![],
        }
    }

    #[test]
    fn normalise_splits_single_prompt_into_stages() {
        // 1 + ceil((400 - 97) / (97 - 9)) = 1 + ceil(303 / 88) = 5 stages.
        let normalised = auto_expand_request("a cat walking", 400, 97, 9, None)
            .normalise()
            .expect("normalise should succeed");

        assert_eq!(
            normalised.stages.len(),
            5,
            "400/97 with a 9-frame motion tail should expand to 5 stages",
        );
        for stage in &normalised.stages {
            assert_eq!(stage.frames, 97);
            assert_eq!(stage.prompt, "a cat walking");
            assert!(stage.seed_offset.is_none());
        }
        // The auto-expand inputs are cleared once they have been folded into `stages`.
        assert!(normalised.prompt.is_none());
        assert!(normalised.total_frames.is_none());
        assert!(normalised.clip_frames.is_none());
        assert!(normalised.source_image.is_none());
    }

    #[test]
    fn normalise_preserves_starting_image_across_all_stages() {
        // 200 frames expands to several stages; each must get a copy of the image.
        let png = vec![0x89, 0x50, 0x4e, 0x47, 0xde, 0xad, 0xbe, 0xef];
        let normalised = auto_expand_request("test", 200, 97, 9, Some(png.clone()))
            .normalise()
            .expect("normalise should succeed");

        assert!(normalised.stages.len() >= 2);
        for (idx, stage) in normalised.stages.iter().enumerate() {
            assert_eq!(
                stage.source_image.as_deref(),
                Some(png.as_slice()),
                "stage {idx} must carry the starting image for cross-stage identity anchoring",
            );
        }
    }

    #[test]
    fn normalise_rejects_empty() {
        // Neither stages nor a prompt: nothing to render.
        let mut req = canonical_request(Vec::new(), 9);
        req.prompt = None;
        req.total_frames = None;

        let err = req.normalise().expect_err("empty chain should fail");
        assert!(
            matches!(err, MoldError::Validation(_)),
            "empty chain should be a validation error, got {err:?}",
        );
    }

    #[test]
    fn normalise_rejects_non_8k1_frames() {
        // 50 is off the 8k+1 grid.
        let req = canonical_request(vec![make_stage(50)], 9);
        let err = req.normalise().expect_err("non-8k+1 frames should fail");
        assert!(
            matches!(err, MoldError::Validation(msg) if msg.contains("8k+1")),
            "error must mention the 8k+1 constraint",
        );
    }

    #[test]
    fn normalise_accepts_canonical_form_unchanged() {
        // A valid explicit stage list passes through with frames and prompts intact.
        let stages = vec![make_stage(97), make_stage(97), make_stage(97)];
        let normalised = canonical_request(stages.clone(), 9)
            .normalise()
            .expect("valid canonical form should pass");
        assert_eq!(normalised.stages.len(), 3);
        for (left, right) in normalised.stages.iter().zip(stages.iter()) {
            assert_eq!(left.frames, right.frames);
            assert_eq!(left.prompt, right.prompt);
        }
    }

    #[test]
    fn normalise_single_stage_when_total_leq_clip() {
        // 9 frames fit inside one 97-frame clip, so a single 9-frame stage results.
        let normalised = auto_expand_request("short", 9, 97, 1, None)
            .normalise()
            .expect("short single-clip chain should pass");
        assert_eq!(normalised.stages.len(), 1);
        assert_eq!(normalised.stages[0].frames, 9);
    }

    #[test]
    fn normalise_rejects_too_many_stages() {
        // 17 explicit stages exceed MAX_CHAIN_STAGES (16).
        let stages = (0..17).map(|_| make_stage(97)).collect();
        let err = canonical_request(stages, 9)
            .normalise()
            .expect_err("17-stage chain should fail");
        assert!(
            matches!(err, MoldError::Validation(msg) if msg.contains("maximum")),
            "error must mention the max-stages cap",
        );
    }

    #[test]
    fn normalise_rejects_auto_expand_too_long() {
        // 4000 frames at 97/9 would need 1 + ceil(3903 / 88) = 46 stages.
        let err = auto_expand_request("too long", 4000, 97, 9, None)
            .normalise()
            .expect_err("runaway auto-expand should fail");
        assert!(
            matches!(err, MoldError::Validation(msg) if msg.contains("stages")),
            "error must name the stage count guardrail",
        );
    }

    #[test]
    fn normalise_rejects_motion_tail_ge_clip() {
        // A tail equal to the clip length leaves no new frames per stage.
        let err = auto_expand_request("bad tail", 200, 97, 97, None)
            .normalise()
            .expect_err("motion_tail >= clip should fail");
        assert!(
            matches!(err, MoldError::Validation(msg) if msg.contains("motion_tail_frames")),
            "error must name motion_tail_frames",
        );
    }

    #[test]
    fn enable_audio_defaults_to_none_and_round_trips_when_set() {
        // `enable_audio` is absent, so it deserializes to None.
        let req: ChainRequest = serde_json::from_value(serde_json::json!({
            "model": "ltx-2.3-22b-distilled:fp8",
            "stages": [],
            "width": 704,
            "height": 416,
            "steps": 4,
            "guidance": 3.0,
        }))
        .expect("valid minimal chain request");
        assert_eq!(req.enable_audio, None);

        let req_with_audio: ChainRequest = serde_json::from_value(serde_json::json!({
            "model": "ltx-2.3-22b-distilled:fp8",
            "stages": [{"prompt": "a bird", "frames": 33}],
            "width": 704,
            "height": 416,
            "steps": 4,
            "guidance": 3.0,
            "enable_audio": true,
        }))
        .expect("valid chain request with audio");
        assert_eq!(req_with_audio.enable_audio, Some(true));

        // The value must also survive projection into the saved script.
        let script = ChainScript::from(&req_with_audio);
        assert_eq!(
            script.chain.enable_audio,
            Some(true),
            "ChainScript echo must preserve enable_audio for round-trip save/reload",
        );
    }

    #[test]
    fn motion_tail_default_lands_on_8k_plus_1_grid() {
        // The serde default for motion_tail_frames (17) must itself be on the grid.
        let req: ChainRequest = serde_json::from_value(serde_json::json!({
            "model": "ltx-2.3-22b-distilled:fp8",
            "stages": [],
            "width": 704,
            "height": 416,
            "steps": 4,
            "guidance": 3.0,
        }))
        .expect("valid minimal chain request");
        assert_eq!(req.motion_tail_frames, 17);
        assert!(is_ltx2_frame_count(req.motion_tail_frames));
    }

    #[test]
    fn normalise_rejects_motion_tail_off_grid() {
        // A tail of 4 is neither 0 nor 8k+1.
        let req = canonical_request(vec![make_stage(33)], 4);
        let err = req
            .normalise()
            .expect_err("motion_tail_frames=4 must be rejected");
        assert!(
            matches!(err, MoldError::Validation(msg) if msg.contains("8k+1")),
            "error must name the 8k+1 grid constraint",
        );
    }

    #[test]
    fn normalise_accepts_motion_tail_zero() {
        // A zero tail is explicitly allowed by the grid check.
        let mut second = make_stage(33);
        second.transition = TransitionMode::Cut;
        let req = canonical_request(vec![make_stage(33), second], 0);
        req.normalise().expect("motion_tail=0 must be accepted");
    }

    #[test]
    fn normalise_rejects_missing_total_frames_in_auto_expand() {
        // A prompt without total_frames is an incomplete auto-expand request.
        let mut req = canonical_request(Vec::new(), 4);
        req.prompt = Some("missing total".into());
        let err = req
            .normalise()
            .expect_err("missing total_frames should fail");
        assert!(
            matches!(err, MoldError::Validation(msg) if msg.contains("total_frames")),
            "error must name total_frames",
        );
    }

    #[test]
    fn is_ltx2_frame_count_matches_8k_plus_1() {
        for valid in [1u32, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97] {
            assert!(
                is_ltx2_frame_count(valid),
                "{valid} should be a valid LTX-2 frame count",
            );
        }
        for invalid in [0u32, 2, 8, 10, 16, 50, 96, 98, 100] {
            assert!(
                !is_ltx2_frame_count(invalid),
                "{invalid} must not pass the 8k+1 check",
            );
        }
    }

    #[test]
    fn chain_progress_event_roundtrips_json_with_snake_case_tags() {
        // Each event must carry its snake_case `"type"` tag and deserialize back unchanged.
        let cases = [
            (
                ChainProgressEvent::ChainStart {
                    stage_count: 5,
                    estimated_total_frames: 469,
                },
                r#""type":"chain_start""#,
            ),
            (
                ChainProgressEvent::StageStart { stage_idx: 0 },
                r#""type":"stage_start""#,
            ),
            (
                ChainProgressEvent::DenoiseStep {
                    stage_idx: 2,
                    step: 4,
                    total: 8,
                },
                r#""type":"denoise_step""#,
            ),
            (
                ChainProgressEvent::StageDone {
                    stage_idx: 3,
                    frames_emitted: 97,
                },
                r#""type":"stage_done""#,
            ),
            (
                ChainProgressEvent::Stitching { total_frames: 400 },
                r#""type":"stitching""#,
            ),
        ];
        for (event, expected_tag) in cases {
            let json = serde_json::to_string(&event).expect("serialize");
            assert!(
                json.contains(expected_tag),
                "missing snake_case tag {expected_tag} in {json}",
            );
            let roundtrip: ChainProgressEvent = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(roundtrip, event, "roundtrip must preserve payload");
        }
    }

    #[test]
    fn build_stages_math_matches_stitch_budget() {
        // Each case: (total_frames, clip_frames, motion_tail_frames, expected stages).
        let cases = [
            (400u32, 97u32, 9u32, 5u32),
            (200, 97, 9, 3),
            (97, 97, 9, 1),
            (300, 97, 0, 4),
        ];
        for (total, clip, tail, expected_n) in cases {
            let req = auto_expand_request("m", total, clip, tail, None)
                .normalise()
                .expect("valid auto-expand should normalise");
            assert_eq!(
                req.stages.len() as u32,
                expected_n,
                "expected {expected_n} stages for total={total}, clip={clip}, tail={tail}",
            );
            // First stage adds `clip` frames; each later stage adds `clip - tail`.
            let delivered = clip + (expected_n - 1) * (clip - tail);
            assert!(
                delivered >= total,
                "{expected_n} stages deliver {delivered} frames but {total} were requested",
            );
        }
    }

    #[test]
    fn transition_mode_serializes_snake_case() {
        assert_eq!(
            serde_json::to_value(TransitionMode::Smooth).unwrap(),
            serde_json::Value::String("smooth".into())
        );
        assert_eq!(
            serde_json::to_value(TransitionMode::Cut).unwrap(),
            serde_json::Value::String("cut".into())
        );
        assert_eq!(
            serde_json::to_value(TransitionMode::Fade).unwrap(),
            serde_json::Value::String("fade".into())
        );
    }

    #[test]
    fn transition_mode_defaults_to_smooth() {
        assert_eq!(TransitionMode::default(), TransitionMode::Smooth);
    }

    #[test]
    fn lora_spec_serializes_minimal() {
        // `name: None` must be omitted entirely.
        let spec = LoraSpec {
            path: "./style.safetensors".into(),
            scale: 0.8,
            name: None,
        };
        let json = serde_json::to_string(&spec).unwrap();
        assert!(json.contains(r#""path":"./style.safetensors""#));
        assert!(json.contains(r#""scale":0.8"#));
        assert!(!json.contains(r#""name""#));
    }

    #[test]
    fn named_ref_serializes_minimal() {
        // Only the key's presence is checked; the image encoding lives in base64_bytes.
        let r = NamedRef {
            name: "hero".into(),
            image: vec![0x89, 0x50],
        };
        let json = serde_json::to_string(&r).unwrap();
        assert!(json.contains(r#""name":"hero""#));
        assert!(json.contains(r#""image":"#));
    }

    #[test]
    fn chain_stage_defaults_are_backcompat() {
        // A stage with only prompt + frames must deserialize, with defaults elsewhere.
        let json = r#"{
            "prompt": "a cat",
            "frames": 97
        }"#;
        let stage: ChainStage = serde_json::from_str(json).unwrap();
        assert_eq!(stage.prompt, "a cat");
        assert_eq!(stage.frames, 97);
        assert_eq!(stage.transition, TransitionMode::Smooth);
        assert_eq!(stage.fade_frames, None);
        assert!(stage.model.is_none());
        assert!(stage.loras.is_empty());
        assert!(stage.references.is_empty());
    }

    #[test]
    fn chain_script_projects_from_request() {
        // Chain-wide settings and the stage list are copied into the script.
        let req = ChainRequest {
            model: "ltx-2-19b-distilled:fp8".into(),
            stages: vec![ChainStage {
                prompt: "a".into(),
                frames: 97,
                source_image: None,
                negative_prompt: None,
                seed_offset: None,
                transition: TransitionMode::Smooth,
                fade_frames: None,
                model: None,
                loras: vec![],
                references: vec![],
            }],
            motion_tail_frames: 25,
            width: 1216,
            height: 704,
            fps: 24,
            seed: Some(42),
            steps: 8,
            guidance: 3.0,
            strength: 1.0,
            output_format: OutputFormat::Mp4,
            placement: None,
            prompt: None,
            total_frames: None,
            clip_frames: None,
            source_image: None,
            enable_audio: None,
        };
        let script = ChainScript::from(&req);
        assert_eq!(script.chain.model, "ltx-2-19b-distilled:fp8");
        assert_eq!(script.chain.seed, Some(42));
        assert_eq!(script.stages.len(), 1);
        assert_eq!(script.stages[0].prompt, "a");
    }

    #[test]
    fn chain_stage_roundtrips_all_fields() {
        let stage = ChainStage {
            prompt: "scene".into(),
            frames: 49,
            source_image: None,
            negative_prompt: None,
            seed_offset: None,
            transition: TransitionMode::Cut,
            fade_frames: Some(12),
            model: None,
            loras: vec![],
            references: vec![],
        };
        let json = serde_json::to_string(&stage).unwrap();
        let back: ChainStage = serde_json::from_str(&json).unwrap();
        assert_eq!(back.frames, 49);
        assert_eq!(back.transition, TransitionMode::Cut);
        assert_eq!(back.fade_frames, Some(12));
    }

    #[test]
    fn normalise_coerces_stage_0_transition_to_smooth() {
        // Explicit stages take precedence over the auto-expand fields set by the fixture.
        let mut req = auto_expand_request("a", 97, 97, 25, None);
        req.stages = vec![
            ChainStage {
                prompt: "scene 0".into(),
                frames: 97,
                source_image: None,
                negative_prompt: None,
                seed_offset: None,
                // Stage 0 requests Cut; normalise must coerce it to Smooth.
                transition: TransitionMode::Cut,
                fade_frames: None,
                model: None,
                loras: vec![],
                references: vec![],
            },
            ChainStage {
                prompt: "scene 1".into(),
                frames: 97,
                source_image: None,
                negative_prompt: None,
                seed_offset: None,
                // Later stages keep their requested transition.
                transition: TransitionMode::Cut,
                fade_frames: None,
                model: None,
                loras: vec![],
                references: vec![],
            },
        ];
        let normalised = req.normalise().unwrap();
        assert_eq!(normalised.stages[0].transition, TransitionMode::Smooth);
        assert_eq!(normalised.stages[1].transition, TransitionMode::Cut);
    }

    #[test]
    fn normalise_rejects_reserved_model_field() {
        let mut req = auto_expand_request("a", 97, 97, 25, None);
        req.stages = vec![ChainStage {
            prompt: "x".into(),
            frames: 97,
            source_image: None,
            negative_prompt: None,
            seed_offset: None,
            transition: TransitionMode::Smooth,
            fade_frames: None,
            model: Some("flux-dev:q4".into()),
            loras: vec![],
            references: vec![],
        }];
        let err = req.normalise().unwrap_err().to_string();
        assert!(err.contains("reserved for sub-project C"), "got: {err}");
    }

    #[test]
    fn normalise_rejects_reserved_loras_field() {
        let mut req = auto_expand_request("a", 97, 97, 25, None);
        req.stages = vec![ChainStage {
            prompt: "x".into(),
            frames: 97,
            source_image: None,
            negative_prompt: None,
            seed_offset: None,
            transition: TransitionMode::Smooth,
            fade_frames: None,
            model: None,
            loras: vec![LoraSpec {
                path: "x.safetensors".into(),
                scale: 1.0,
                name: None,
            }],
            references: vec![],
        }];
        let err = req.normalise().unwrap_err().to_string();
        assert!(err.contains("reserved for sub-project B"), "got: {err}");
    }

    // Builds a request from (transition, frames, fade_frames) triples with a 25-frame
    // motion tail; used by the estimated_total_frames tests.
    fn stage_list_request(stages: Vec<(TransitionMode, u32, Option<u32>)>) -> ChainRequest {
        ChainRequest {
            model: "ltx-2-19b-distilled:fp8".into(),
            stages: stages
                .into_iter()
                .map(|(t, f, fl)| ChainStage {
                    prompt: "x".into(),
                    frames: f,
                    source_image: None,
                    negative_prompt: None,
                    seed_offset: None,
                    transition: t,
                    fade_frames: fl,
                    model: None,
                    loras: vec![],
                    references: vec![],
                })
                .collect(),
            motion_tail_frames: 25,
            width: 1216,
            height: 704,
            fps: 24,
            seed: None,
            steps: 8,
            guidance: 3.0,
            strength: 1.0,
            output_format: OutputFormat::Mp4,
            placement: None,
            prompt: None,
            total_frames: None,
            clip_frames: None,
            source_image: None,
            enable_audio: None,
        }
    }

    #[test]
    fn estimated_total_all_smooth() {
        // 97 + (97 - 25) + (97 - 25) = 241.
        let req = stage_list_request(vec![
            (TransitionMode::Smooth, 97, None),
            (TransitionMode::Smooth, 97, None),
            (TransitionMode::Smooth, 97, None),
        ]);
        assert_eq!(req.estimated_total_frames(), 241);
    }

    #[test]
    fn estimated_total_with_cut() {
        // 97 + 97 (cut: no overlap) + (97 - 25) = 266.
        let req = stage_list_request(vec![
            (TransitionMode::Smooth, 97, None),
            (TransitionMode::Cut, 97, None),
            (TransitionMode::Smooth, 97, None),
        ]);
        assert_eq!(req.estimated_total_frames(), 266);
    }

    #[test]
    fn estimated_total_with_fade() {
        // 97 + 97 (cut) + (97 - 8) (fade of 8 frames) = 283.
        let req = stage_list_request(vec![
            (TransitionMode::Smooth, 97, None),
            (TransitionMode::Cut, 97, None),
            (TransitionMode::Fade, 97, Some(8)),
        ]);
        assert_eq!(req.estimated_total_frames(), 283);
    }
}