Skip to main content

mold_core/
chain.rs

//! Wire types for server-side chained video generation.
//!
//! A *chain* is a sequence of per-clip render stages stitched into a single
//! output video. The v1 CLI UX is single-prompt + arbitrary length, but the
//! wire format is stages-based from day one so the eventual movie-maker
//! (multi-prompt, keyframes, selective regen) can author stages by hand
//! without a breaking change.
//!
//! The server only ever sees the canonical [`ChainRequest`] shape — a
//! `Vec<ChainStage>`. Callers can either build that directly or use the
//! auto-expand form (`prompt` + `total_frames` + `clip_frames`), which
//! [`ChainRequest::normalise`] collapses into stages.
//!
//! See `tasks/render-chain-v1-plan.md` for the full design rationale.
15
16use serde::{Deserialize, Serialize};
17
18use crate::error::{MoldError, Result};
19use crate::types::{DevicePlacement, OutputFormat, VideoData};
20
21/// How the boundary between the previous stage and this stage is rendered.
22///
23/// - `Smooth`: the engine honors the motion-tail latent carryover from the
24///   prior clip (v1 default behaviour). Produces a visual morph when the
25///   prompt changes.
26/// - `Cut`: fresh latent, no carryover. If the stage has a `source_image`
27///   the engine uses it as the i2v seed; otherwise pure t2v.
28/// - `Fade`: same engine path as `Cut`, plus a post-stitch alpha blend of
29///   the last `fade_frames` of the prior clip with the first `fade_frames`
30///   of this clip.
31///
32/// Stage 0's transition is meaningless (nothing to transition from) and is
33/// coerced to `Smooth` during `ChainRequest::normalise`.
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, utoipa::ToSchema)]
35#[serde(rename_all = "snake_case")]
36pub enum TransitionMode {
37    #[default]
38    Smooth,
39    Cut,
40    Fade,
41}
42
43/// Per-stage LoRA adapter spec. **Reserved for sub-project B** — populating
44/// this in a request before B lands causes `ChainRequest::normalise` to
45/// return 422. Defined now so scripts that round-trip through v1 clients
46/// don't drop fields silently.
47#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
48pub struct LoraSpec {
49    pub path: String,
50    pub scale: f64,
51    #[serde(default, skip_serializing_if = "Option::is_none")]
52    pub name: Option<String>,
53}
54
55/// Per-stage named reference character/style. **Reserved for sub-project
56/// B** — populating this causes `ChainRequest::normalise` to return 422.
57#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
58pub struct NamedRef {
59    pub name: String,
60    #[serde(with = "crate::types::base64_bytes")]
61    pub image: Vec<u8>,
62}
63
64/// A single rendered clip in a chain. Concatenated in order with motion-tail
65/// trimming on continuations (stages with `idx >= 1` drop the leading
66/// `motion_tail_frames` pixel frames of their output because those duplicate
67/// the tail of the previous stage that the engine carried across as
68/// latent-space conditioning).
69#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
70pub struct ChainStage {
71    /// Prompt used for this stage. In v1 all stages receive the same prompt
72    /// (auto-expand form replicates it); the movie-maker UI in v2 will let
73    /// users author per-stage prompts.
74    #[schema(example = "a cat walking through autumn leaves")]
75    pub prompt: String,
76
77    /// Frame count for this stage. Must be `8k+1` (LTX-2 pipeline constraint:
78    /// 9, 17, 25, …, 97).
79    #[schema(example = 97)]
80    pub frames: u32,
81
82    /// Optional starting image (raw PNG/JPEG bytes, base64 in JSON). In v1
83    /// this is only meaningful on `stages[0]`; later stages draw their
84    /// conditioning from the prior stage's motion-tail latents instead.
85    #[serde(
86        default,
87        skip_serializing_if = "Option::is_none",
88        with = "crate::types::base64_opt"
89    )]
90    pub source_image: Option<Vec<u8>>,
91
92    /// Optional negative prompt for CFG-based stages. v1 LTX-2 ignores this
93    /// (the distilled family doesn't use CFG); the field is reserved so the
94    /// movie-maker can round-trip it without re-migrating the wire format.
95    #[serde(default, skip_serializing_if = "Option::is_none")]
96    pub negative_prompt: Option<String>,
97
98    /// Optional per-stage seed offset. `None` in v1 — the orchestrator
99    /// derives each stage's seed from the chain's base seed. Reserved as the
100    /// v2 movie-maker override hook for "regenerate just this stage with a
101    /// different seed".
102    #[serde(default, skip_serializing_if = "Option::is_none")]
103    pub seed_offset: Option<u64>,
104
105    // NEW in multi-prompt v2 ───────────────────────────────────────────
106    /// Boundary style between the previous stage and this stage.
107    /// Stage 0's value is coerced to `Smooth` in `normalise`.
108    #[serde(default)]
109    pub transition: TransitionMode,
110
111    /// Length in pixel frames of the crossfade when `transition == Fade`.
112    /// `None` means use the server-announced default (8 frames). Capped
113    /// at `fade_frames_max` from `/api/capabilities/chain-limits`.
114    #[serde(default, skip_serializing_if = "Option::is_none")]
115    pub fade_frames: Option<u32>,
116
117    // RESERVED for B/C — populated values are rejected by normalise ───
118    /// **Reserved for sub-project C.** Populating this in a request
119    /// produces 422 in this release.
120    #[serde(default, skip_serializing_if = "Option::is_none")]
121    pub model: Option<String>,
122
123    /// **Reserved for sub-project B.** Non-empty values produce 422.
124    #[serde(default, skip_serializing_if = "Vec::is_empty")]
125    pub loras: Vec<LoraSpec>,
126
127    /// **Reserved for sub-project B.** Non-empty values produce 422.
128    #[serde(default, skip_serializing_if = "Vec::is_empty")]
129    pub references: Vec<NamedRef>,
130}
131
132/// Chained generation request. Server accepts either the canonical form
133/// (`stages` non-empty) or the auto-expand form (`prompt` + `total_frames` +
134/// `clip_frames`); [`ChainRequest::normalise`] collapses the latter into the
135/// former so downstream code only deals with `stages`.
136#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
137pub struct ChainRequest {
138    #[schema(example = "ltx-2-19b-distilled:fp8")]
139    pub model: String,
140
141    /// Canonical stages list. Empty triggers auto-expand from
142    /// `prompt`/`total_frames`/`clip_frames`.
143    #[serde(default)]
144    pub stages: Vec<ChainStage>,
145
146    /// Pixel frames of motion-tail overlap between consecutive stages.
147    /// `0` = no overlap (simple concat). `>0` = the final K pixel frames of
148    /// stage N's latents are threaded into stage N+1's conditioning, and
149    /// stage N+1's leading K output frames are dropped at stitch time.
150    ///
151    /// Defaults to `17` (matches the CLI `--motion-tail` and SPA defaults):
152    /// `1 + 16` lands on the LTX-2 VAE's `1 + 8k` causal-grid for a clean
153    /// re-encode of the carryover RGB frames. Values that do not satisfy
154    /// `1 + 8k` will fail the receiving stage's tail re-encode at the VAE.
155    /// Must be strictly less than each stage's `frames`.
156    #[serde(default = "default_motion_tail_frames")]
157    #[schema(example = 17)]
158    pub motion_tail_frames: u32,
159
160    #[schema(example = 1216)]
161    pub width: u32,
162    #[schema(example = 704)]
163    pub height: u32,
164    #[serde(default = "default_fps")]
165    #[schema(example = 24)]
166    pub fps: u32,
167
168    /// Chain base seed. Per-stage seeds are derived as
169    /// `base_seed ^ ((stage_idx as u64) << 32)` by the orchestrator so the
170    /// whole chain is reproducible from a single seed value.
171    #[serde(default, skip_serializing_if = "Option::is_none")]
172    #[schema(example = 42)]
173    pub seed: Option<u64>,
174
175    #[schema(example = 8)]
176    pub steps: u32,
177
178    #[schema(example = 3.0)]
179    pub guidance: f64,
180
181    /// Denoising strength for `stages[0].source_image`. Ignored when the
182    /// first stage has no source image. Continuation stages are always
183    /// full-strength conditioned via motion-tail latents.
184    #[serde(default = "default_strength")]
185    #[schema(example = 1.0)]
186    pub strength: f64,
187
188    #[serde(default = "default_output_format")]
189    pub output_format: OutputFormat,
190
191    #[serde(default, skip_serializing_if = "Option::is_none")]
192    pub placement: Option<DevicePlacement>,
193
194    // ── Auto-expand form ────────────────────────────────────────────────
195    // These are only read when `stages` is empty; `normalise` clears them
196    // after expansion so the canonical form only ever carries `stages`.
197    /// Auto-expand: single prompt replicated across all stages.
198    #[serde(default, skip_serializing_if = "Option::is_none")]
199    pub prompt: Option<String>,
200
201    /// Auto-expand: total pixel frames the stitched output should cover.
202    #[serde(default, skip_serializing_if = "Option::is_none")]
203    pub total_frames: Option<u32>,
204
205    /// Auto-expand: per-clip frame count. Defaults to `97` (LTX-2 19B/22B
206    /// distilled cap). Must be `8k+1`.
207    #[serde(default, skip_serializing_if = "Option::is_none")]
208    pub clip_frames: Option<u32>,
209
210    /// Auto-expand: starting image for `stages[0]`.
211    #[serde(
212        default,
213        skip_serializing_if = "Option::is_none",
214        with = "crate::types::base64_opt"
215    )]
216    pub source_image: Option<Vec<u8>>,
217
218    /// Generate per-stage audio and mux it into the final stitched output.
219    /// Only meaningful for AV-capable families (LTX-2 / LTX-2.3); the server
220    /// rejects `Some(true)` for non-AV models. `None` means "no preference"
221    /// and resolves to off — chains opt in to audio explicitly so existing
222    /// callers don't suddenly start producing audio they didn't ask for.
223    #[serde(default, skip_serializing_if = "Option::is_none")]
224    pub enable_audio: Option<bool>,
225}
226
227/// Canonical TOML-shaped projection of a normalised [`ChainRequest`].
228///
229/// Echoed back in [`ChainResponse::script`] so clients can save the exact
230/// form that was rendered without re-serialising the request body (which
231/// carries auto-expand sugar and other transport-only fields).
232#[derive(Debug, Clone, Default, Serialize, Deserialize, utoipa::ToSchema)]
233pub struct ChainScript {
234    pub schema: String, // always "mold.chain.v1"
235    pub chain: ChainScriptChain,
236    #[serde(rename = "stage")]
237    pub stages: Vec<ChainStage>,
238}
239
240#[derive(Debug, Clone, Default, Serialize, Deserialize, utoipa::ToSchema)]
241pub struct ChainScriptChain {
242    pub model: String,
243    pub width: u32,
244    pub height: u32,
245    pub fps: u32,
246    #[serde(default, skip_serializing_if = "Option::is_none")]
247    pub seed: Option<u64>,
248    pub steps: u32,
249    pub guidance: f64,
250    pub strength: f64,
251    pub motion_tail_frames: u32,
252    pub output_format: OutputFormat,
253    /// Echo of [`ChainRequest::enable_audio`]. Omitted from TOML when unset
254    /// so v1 scripts (no audio) deserialise unchanged.
255    #[serde(default, skip_serializing_if = "Option::is_none")]
256    pub enable_audio: Option<bool>,
257}
258
259impl From<&ChainRequest> for ChainScript {
260    fn from(req: &ChainRequest) -> Self {
261        ChainScript {
262            schema: "mold.chain.v1".into(),
263            chain: ChainScriptChain {
264                model: req.model.clone(),
265                width: req.width,
266                height: req.height,
267                fps: req.fps,
268                seed: req.seed,
269                steps: req.steps,
270                guidance: req.guidance,
271                strength: req.strength,
272                motion_tail_frames: req.motion_tail_frames,
273                output_format: req.output_format,
274                enable_audio: req.enable_audio,
275            },
276            stages: req.stages.clone(),
277        }
278    }
279}
280
281/// VRAM feasibility estimate — populated by sub-project D. `None` in this
282/// release.
283#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
284pub struct VramEstimate {
285    pub worst_case_bytes: u64,
286    pub fits: bool,
287}
288
289/// Response from a chained generation request. The `video` is the stitched
290/// output; individual per-stage clips are not returned.
291#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
292pub struct ChainResponse {
293    pub video: VideoData,
294    /// Number of stages that actually ran (matches `request.stages.len()`
295    /// after normalisation).
296    #[schema(example = 5)]
297    pub stage_count: u32,
298    /// GPU ordinal that handled the chain (multi-GPU servers only).
299    #[serde(default, skip_serializing_if = "Option::is_none")]
300    pub gpu: Option<usize>,
301
302    // NEW ──────────────────────────────────────────────────────────────
303    /// Canonical TOML-shaped echo of the rendered script. Clients can save
304    /// this directly as a `.toml` file.
305    pub script: ChainScript,
306
307    /// Reserved for sub-project D; `None` in this release.
308    #[serde(default, skip_serializing_if = "Option::is_none")]
309    pub vram_estimate: Option<VramEstimate>,
310}
311
312/// SSE completion event for a successful chain run. Streamed as the final
313/// `data:` frame under the `event: complete` SSE type. The payload is
314/// base64-encoded to stay JSON-safe; clients decode it into `VideoData`.
315///
316/// This is a sibling to [`crate::types::SseCompleteEvent`] rather than an
317/// extension so image/video vs. chain completion shapes stay independent
318/// and can evolve separately.
319#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
320pub struct SseChainCompleteEvent {
321    /// Base64-encoded stitched video bytes (format per `format` field).
322    pub video: String,
323    pub format: OutputFormat,
324    #[schema(example = 1216)]
325    pub width: u32,
326    #[schema(example = 704)]
327    pub height: u32,
328    #[schema(example = 400)]
329    pub frames: u32,
330    #[schema(example = 24)]
331    pub fps: u32,
332    /// Base64-encoded first-frame PNG thumbnail.
333    #[serde(default, skip_serializing_if = "Option::is_none")]
334    pub thumbnail: Option<String>,
335    /// Base64-encoded animated GIF preview (always emitted for gallery UI).
336    #[serde(default, skip_serializing_if = "Option::is_none")]
337    pub gif_preview: Option<String>,
338    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
339    pub has_audio: bool,
340    #[serde(default, skip_serializing_if = "Option::is_none")]
341    pub duration_ms: Option<u64>,
342    #[serde(default, skip_serializing_if = "Option::is_none")]
343    pub audio_sample_rate: Option<u32>,
344    #[serde(default, skip_serializing_if = "Option::is_none")]
345    pub audio_channels: Option<u32>,
346    /// Number of stages that ran end-to-end.
347    #[schema(example = 5)]
348    pub stage_count: u32,
349    /// GPU ordinal that handled the chain (multi-GPU only).
350    #[serde(default, skip_serializing_if = "Option::is_none")]
351    pub gpu: Option<usize>,
352    /// Wall-clock elapsed time across all stages + stitching.
353    #[serde(default, skip_serializing_if = "Option::is_none")]
354    pub generation_time_ms: Option<u64>,
355    /// Canonical echo of the normalised chain request, so streaming clients
356    /// can save/reload the rendered script without re-serialising the
357    /// transport-only fields in the submitted request body.
358    #[serde(default)]
359    pub script: ChainScript,
360    /// Reserved for sub-project D; `None` in this release.
361    #[serde(default, skip_serializing_if = "Option::is_none")]
362    pub vram_estimate: Option<VramEstimate>,
363}
364
365/// Chain-specific SSE progress event. Streamed as `data:` JSON frames from
366/// `POST /api/generate/chain/stream` under the `event: progress` SSE type.
367///
368/// Per-stage denoise steps are wrapped with `stage_idx` so consumers can
369/// render stacked progress bars (overall chain + per-stage) without a
370/// separate subscription. Non-denoise engine events (weight load, cache
371/// hits, etc.) are intentionally not forwarded through this enum in v1 —
372/// they're scoped to individual stages and the UX goal for v1 is per-stage
373/// progress, not per-component telemetry.
374#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema, PartialEq, Eq)]
375#[serde(tag = "type", rename_all = "snake_case")]
376pub enum ChainProgressEvent {
377    /// Emitted once at the start of the chain, after normalisation. Gives
378    /// consumers the final stage count and the target pre-trim frame total
379    /// so they can size progress bars up front.
380    ChainStart {
381        stage_count: u32,
382        estimated_total_frames: u32,
383    },
384    /// Stage `stage_idx` (0-indexed) has started its denoise loop.
385    StageStart { stage_idx: u32 },
386    /// Per-step denoise progress for the active stage.
387    DenoiseStep {
388        stage_idx: u32,
389        step: u32,
390        total: u32,
391    },
392    /// Stage finished generating; `frames_emitted` is the raw clip frame
393    /// count before motion-tail trim at stitch time.
394    StageDone { stage_idx: u32, frames_emitted: u32 },
395    /// All stages complete; stitching/encoding the final MP4.
396    Stitching { total_frames: u32 },
397}
398
399/// Structured error payload returned in the 502 response body when a chain
400/// stage fails mid-run. Allows UIs to show actionable retry hints (e.g.,
401/// "stage 2 of 5 failed — retry from here").
402#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
403pub struct ChainFailure {
404    /// Human-readable summary of where the failure landed.
405    #[schema(example = "stage render failed")]
406    pub error: String,
407    /// Zero-based index of the stage whose render returned Err.
408    #[schema(example = 2)]
409    pub failed_stage_idx: u32,
410    /// Number of stages that completed successfully before the failure.
411    #[schema(example = 2)]
412    pub elapsed_stages: u32,
413    /// Cumulative generation time across the completed stages, in ms.
414    #[schema(example = 12_340)]
415    pub elapsed_ms: u64,
416    /// Inner error message from the orchestrator (`format!("{e:#}")`).
417    #[schema(example = "simulated GPU OOM on stage 2")]
418    pub stage_error: String,
419}
420
/// Serde fallback for `ChainRequest::motion_tail_frames` when the request
/// body omits the field.
fn default_motion_tail_frames() -> u32 {
    // 17 = 1 + 8 * 2, i.e. a value of the `8k + 1` form.
    const DEFAULT_MOTION_TAIL: u32 = 17;
    DEFAULT_MOTION_TAIL
}
424
/// Serde fallback for `ChainRequest::fps` when the request body omits the
/// field.
fn default_fps() -> u32 {
    const DEFAULT_FPS: u32 = 24;
    DEFAULT_FPS
}
428
/// Serde fallback for `ChainRequest::strength` when the request body omits
/// the field.
fn default_strength() -> f64 {
    const DEFAULT_STRENGTH: f64 = 1.0;
    DEFAULT_STRENGTH
}
432
433fn default_output_format() -> OutputFormat {
434    OutputFormat::Mp4
435}
436
/// Upper bound on the number of stages the v1 orchestrator accepts in one
/// chain.
///
/// Sixteen 97-frame clips come to roughly 1552 frames, or about 64 s at
/// 24 fps: well beyond the 400-frame target, yet small enough to keep
/// runaway jobs off the server.
pub const MAX_CHAIN_STAGES: usize = 16;
441
442impl ChainRequest {
443    /// Collapse the auto-expand form into a canonical `Vec<ChainStage>` and
444    /// validate the result. Called once on the server side immediately after
445    /// JSON parsing, before any engine work kicks off.
446    ///
447    /// Post-conditions on a successful return:
448    /// - `self.stages` is non-empty.
449    /// - Each stage's `frames` is `8k+1` and `> 0`.
450    /// - `self.stages.len() <= MAX_CHAIN_STAGES`.
451    /// - All auto-expand fields are `None` (caller must use `self.stages`).
452    pub fn normalise(mut self) -> Result<Self> {
453        if self.stages.is_empty() {
454            let prompt = self.prompt.take().ok_or_else(|| {
455                MoldError::Validation(
456                    "chain request needs either stages[] or prompt + total_frames".into(),
457                )
458            })?;
459            let total_frames = self.total_frames.ok_or_else(|| {
460                MoldError::Validation("chain auto-expand requires total_frames".into())
461            })?;
462            if total_frames == 0 {
463                return Err(MoldError::Validation(
464                    "chain total_frames must be > 0".into(),
465                ));
466            }
467            let clip_frames = self.clip_frames.unwrap_or(97);
468            if clip_frames == 0 {
469                return Err(MoldError::Validation(
470                    "chain clip_frames must be > 0".into(),
471                ));
472            }
473            if !is_ltx2_frame_count(clip_frames) {
474                return Err(MoldError::Validation(format!(
475                    "chain clip_frames ({clip_frames}) must be 8k+1 (9, 17, 25, …, 97)",
476                )));
477            }
478            let motion_tail = self.motion_tail_frames;
479            if motion_tail >= clip_frames {
480                return Err(MoldError::Validation(format!(
481                    "motion_tail_frames ({motion_tail}) must be strictly less than clip_frames ({clip_frames})",
482                )));
483            }
484
485            let source_image = self.source_image.take();
486            self.stages = build_auto_expand_stages(
487                &prompt,
488                total_frames,
489                clip_frames,
490                motion_tail,
491                source_image,
492            )?;
493        }
494
495        if self.stages.is_empty() {
496            return Err(MoldError::Validation("chain request has no stages".into()));
497        }
498        if self.stages.len() > MAX_CHAIN_STAGES {
499            return Err(MoldError::Validation(format!(
500                "chain request has {} stages; maximum is {}",
501                self.stages.len(),
502                MAX_CHAIN_STAGES,
503            )));
504        }
505        if self.motion_tail_frames != 0 && !is_ltx2_frame_count(self.motion_tail_frames) {
506            return Err(MoldError::Validation(format!(
507                "motion_tail_frames ({}) must be 0 or 8k+1 (1, 9, 17, 25, …) so the carryover \
508                 RGB frames re-encode cleanly through the LTX-2 video VAE's 8× causal grid",
509                self.motion_tail_frames,
510            )));
511        }
512        for (idx, stage) in self.stages.iter().enumerate() {
513            if stage.frames == 0 {
514                return Err(MoldError::Validation(format!("stage {idx} has 0 frames",)));
515            }
516            if !is_ltx2_frame_count(stage.frames) {
517                return Err(MoldError::Validation(format!(
518                    "stage {idx} has {} frames; LTX-2 requires 8k+1 (9, 17, 25, …, 97)",
519                    stage.frames,
520                )));
521            }
522            if self.motion_tail_frames >= stage.frames {
523                return Err(MoldError::Validation(format!(
524                    "motion_tail_frames ({}) must be strictly less than stage {idx}'s frames ({})",
525                    self.motion_tail_frames, stage.frames,
526                )));
527            }
528        }
529
530        // Reserved-field rejection (sub-projects B/C).
531        for (idx, stage) in self.stages.iter().enumerate() {
532            if stage.model.is_some() {
533                return Err(MoldError::Validation(format!(
534                    "stages[{idx}].model is reserved for sub-project C and not yet supported"
535                )));
536            }
537            if !stage.loras.is_empty() {
538                return Err(MoldError::Validation(format!(
539                    "stages[{idx}].loras is reserved for sub-project B and not yet supported"
540                )));
541            }
542            if !stage.references.is_empty() {
543                return Err(MoldError::Validation(format!(
544                    "stages[{idx}].references is reserved for sub-project B and not yet supported"
545                )));
546            }
547        }
548
549        // Stage 0's transition is meaningless (nothing to transition from).
550        // Coerce to Smooth with a warn so scripts survive reorders.
551        if let Some(first) = self.stages.first_mut() {
552            if first.transition != TransitionMode::Smooth {
553                tracing::warn!(
554                    coerced_from = ?first.transition,
555                    "stage 0 transition is meaningless; coercing to Smooth"
556                );
557                first.transition = TransitionMode::Smooth;
558            }
559        }
560
561        // Canonicalise: clear auto-expand fields so downstream code only
562        // ever reads from `stages`.
563        self.prompt = None;
564        self.total_frames = None;
565        self.clip_frames = None;
566        self.source_image = None;
567
568        Ok(self)
569    }
570
571    /// Predicted stitched frame count *before* any top-level `total_frames`
572    /// trim. Used by UIs for the footer summary and by the server to size
573    /// the final buffer.
574    ///
575    /// Per-boundary rule:
576    /// - smooth: drop leading `motion_tail_frames` of the incoming clip
577    /// - cut: no trim
578    /// - fade: replace `2 * fade_len` frames (trailing of prior + leading of
579    ///   next) with `fade_len` blended frames → net `-fade_len`
580    pub fn estimated_total_frames(&self) -> u32 {
581        const DEFAULT_FADE_FRAMES: u32 = 8;
582        let mut total: u32 = 0;
583        for (idx, stage) in self.stages.iter().enumerate() {
584            if idx == 0 {
585                total += stage.frames;
586                continue;
587            }
588            match stage.transition {
589                TransitionMode::Smooth => {
590                    total += stage.frames.saturating_sub(self.motion_tail_frames);
591                }
592                TransitionMode::Cut => {
593                    total += stage.frames;
594                }
595                TransitionMode::Fade => {
596                    let fade_len = stage.fade_frames.unwrap_or(DEFAULT_FADE_FRAMES);
597                    total += stage.frames.saturating_sub(fade_len);
598                }
599            }
600        }
601        total
602    }
603}
604
/// Tests whether `n` can be written as `8k + 1` with `k >= 0`
/// (1, 9, 17, 25, …).
///
/// LTX-2 pixel frame counts must have this shape because the VAE compresses
/// time by a factor of 8 and treats the first frame causally. `0` does not
/// qualify.
fn is_ltx2_frame_count(n: u32) -> bool {
    const TEMPORAL_STRIDE: u32 = 8;
    matches!(n % TEMPORAL_STRIDE, 1)
}
612
613/// Compute the stage count and per-stage frame allocation for the auto-
614/// expand form, matching Phase 1.4's stitch math:
615///
616/// - Stage 0 contributes `clip_frames` pixel frames.
617/// - Each continuation contributes `clip_frames - motion_tail_frames` new
618///   frames (the leading `motion_tail_frames` are dropped at stitch time
619///   because they duplicate the prior stage's latent tail).
620///
621/// Returns enough stages so the stitched total reaches at least
622/// `total_frames`; over-production is trimmed from the tail at stitch time
623/// per the signed-off decision 2026-04-20.
624fn build_auto_expand_stages(
625    prompt: &str,
626    total_frames: u32,
627    clip_frames: u32,
628    motion_tail_frames: u32,
629    source_image: Option<Vec<u8>>,
630) -> Result<Vec<ChainStage>> {
631    let (stage_count, per_stage_frames) = if total_frames <= clip_frames {
632        // Single stage: match the user's requested length exactly so we
633        // don't render 97 frames and throw most of them away. The frame
634        // count will still be validated as 8k+1 by the caller.
635        (1u32, total_frames)
636    } else {
637        let effective = clip_frames - motion_tail_frames;
638        // effective > 0 because the caller has already ensured
639        // motion_tail_frames < clip_frames.
640        let remainder = total_frames - clip_frames;
641        let count = 1 + remainder.div_ceil(effective);
642        (count, clip_frames)
643    };
644
645    let count_usize = stage_count as usize;
646    if count_usize > MAX_CHAIN_STAGES {
647        return Err(MoldError::Validation(format!(
648            "auto-expand would produce {stage_count} stages; maximum is {MAX_CHAIN_STAGES} \
649             (try reducing total_frames or increasing clip_frames)",
650        )));
651    }
652
653    let mut stages = Vec::with_capacity(count_usize);
654    for _ in 0..stage_count {
655        // Every stage carries the starting image: stage 0 uses it as the
656        // i2v replacement at frame 0, and continuation stages use it as a
657        // soft identity anchor through the append path (see
658        // `Ltx2Engine::render_chain_stage`). Keeping a durable reference
659        // across stages is what stops scene/identity drift past the first
660        // clip, whose effects were traced in render-chain v1 as the
661        // dominant cause of "strange" continuations — the motion tail
662        // alone only carries ~0.7 s of pixel context, nowhere near enough
663        // for the model to remember the scene across an 8-stage chain.
664        stages.push(ChainStage {
665            prompt: prompt.to_string(),
666            frames: per_stage_frames,
667            source_image: source_image.clone(),
668            negative_prompt: None,
669            seed_offset: None,
670            transition: TransitionMode::Smooth,
671            fade_frames: None,
672            model: None,
673            loras: vec![],
674            references: vec![],
675        });
676    }
677    Ok(stages)
678}
679
680#[cfg(test)]
681mod tests {
682    use super::*;
683
684    /// Build a minimal auto-expand request with the given knobs. All other
685    /// fields use their v1 defaults so tests can focus on the logic under
686    /// exercise.
687    fn auto_expand_request(
688        prompt: &str,
689        total_frames: u32,
690        clip_frames: u32,
691        motion_tail_frames: u32,
692        source_image: Option<Vec<u8>>,
693    ) -> ChainRequest {
694        ChainRequest {
695            model: "ltx-2-19b-distilled:fp8".into(),
696            stages: Vec::new(),
697            motion_tail_frames,
698            width: 1216,
699            height: 704,
700            fps: 24,
701            seed: Some(42),
702            steps: 8,
703            guidance: 3.0,
704            strength: 1.0,
705            output_format: OutputFormat::Mp4,
706            placement: None,
707            prompt: Some(prompt.into()),
708            total_frames: Some(total_frames),
709            clip_frames: Some(clip_frames),
710            source_image,
711            enable_audio: None,
712        }
713    }
714
715    fn canonical_request(stages: Vec<ChainStage>, motion_tail_frames: u32) -> ChainRequest {
716        ChainRequest {
717            model: "ltx-2-19b-distilled:fp8".into(),
718            stages,
719            motion_tail_frames,
720            width: 1216,
721            height: 704,
722            fps: 24,
723            seed: Some(42),
724            steps: 8,
725            guidance: 3.0,
726            strength: 1.0,
727            output_format: OutputFormat::Mp4,
728            placement: None,
729            prompt: None,
730            total_frames: None,
731            clip_frames: None,
732            source_image: None,
733            enable_audio: None,
734        }
735    }
736
737    fn make_stage(frames: u32) -> ChainStage {
738        ChainStage {
739            prompt: "test".into(),
740            frames,
741            source_image: None,
742            negative_prompt: None,
743            seed_offset: None,
744            transition: TransitionMode::Smooth,
745            fade_frames: None,
746            model: None,
747            loras: vec![],
748            references: vec![],
749        }
750    }
751
752    #[test]
753    fn normalise_splits_single_prompt_into_stages() {
754        // total=400, clip=97, tail=9 → effective=88, remainder=303,
755        // N = 1 + ceil(303/88) = 1 + 4 = 5 stages of 97 frames each.
756        // Stitched = 97 + 4*88 = 449, which will be trimmed to 400 at
757        // stitch time (per the signed-off "trim from tail" decision).
758        let normalised = auto_expand_request("a cat walking", 400, 97, 9, None)
759            .normalise()
760            .expect("normalise should succeed");
761
762        assert_eq!(
763            normalised.stages.len(),
764            5,
765            "400/97 with a 9-frame motion tail should expand to 5 stages",
766        );
767        for stage in &normalised.stages {
768            assert_eq!(stage.frames, 97);
769            assert_eq!(stage.prompt, "a cat walking");
770            assert!(stage.seed_offset.is_none());
771        }
772        // Auto-expand fields are cleared post-normalisation.
773        assert!(normalised.prompt.is_none());
774        assert!(normalised.total_frames.is_none());
775        assert!(normalised.clip_frames.is_none());
776        assert!(normalised.source_image.is_none());
777    }
778
779    #[test]
780    fn normalise_preserves_starting_image_across_all_stages() {
781        let png = vec![0x89, 0x50, 0x4e, 0x47, 0xde, 0xad, 0xbe, 0xef];
782        let normalised = auto_expand_request("test", 200, 97, 9, Some(png.clone()))
783            .normalise()
784            .expect("normalise should succeed");
785
786        assert!(normalised.stages.len() >= 2);
787        for (idx, stage) in normalised.stages.iter().enumerate() {
788            // Every stage must carry the starting image. Stage 0 uses it
789            // as the i2v replacement at frame 0; continuations use it as a
790            // soft identity anchor through the append path so scene and
791            // subject identity stay coherent past the motion-tail window.
792            assert_eq!(
793                stage.source_image.as_deref(),
794                Some(png.as_slice()),
795                "stage {idx} must carry the starting image for cross-stage identity anchoring",
796            );
797        }
798    }
799
800    #[test]
801    fn normalise_rejects_empty() {
802        let mut req = canonical_request(Vec::new(), 9);
803        // No auto-expand fields either.
804        req.prompt = None;
805        req.total_frames = None;
806
807        let err = req.normalise().expect_err("empty chain should fail");
808        assert!(
809            matches!(err, MoldError::Validation(_)),
810            "empty chain should be a validation error, got {err:?}",
811        );
812    }
813
814    #[test]
815    fn normalise_rejects_non_8k1_frames() {
816        // Canonical form with a stage whose frames violates the 8k+1
817        // constraint.
818        let req = canonical_request(vec![make_stage(50)], 9);
819        let err = req.normalise().expect_err("non-8k+1 frames should fail");
820        assert!(
821            matches!(err, MoldError::Validation(msg) if msg.contains("8k+1")),
822            "error must mention the 8k+1 constraint",
823        );
824    }
825
826    #[test]
827    fn normalise_accepts_canonical_form_unchanged() {
828        // Caller already built stages; normalise should validate and clear
829        // the (already-empty) auto-expand fields without touching stages.
830        let stages = vec![make_stage(97), make_stage(97), make_stage(97)];
831        let normalised = canonical_request(stages.clone(), 9)
832            .normalise()
833            .expect("valid canonical form should pass");
834        assert_eq!(normalised.stages.len(), 3);
835        for (left, right) in normalised.stages.iter().zip(stages.iter()) {
836            assert_eq!(left.frames, right.frames);
837            assert_eq!(left.prompt, right.prompt);
838        }
839    }
840
841    #[test]
842    fn normalise_single_stage_when_total_leq_clip() {
843        // total=9 fits in one clip; don't render a full 97-frame stage and
844        // throw most of it away. Use motion_tail=1 (smallest valid 1+8k)
845        // so the strict-less-than-stage-frames invariant still holds for
846        // the lone 9-frame stage.
847        let normalised = auto_expand_request("short", 9, 97, 1, None)
848            .normalise()
849            .expect("short single-clip chain should pass");
850        assert_eq!(normalised.stages.len(), 1);
851        assert_eq!(normalised.stages[0].frames, 9);
852    }
853
854    #[test]
855    fn normalise_rejects_too_many_stages() {
856        // 17 canonical stages exceeds MAX_CHAIN_STAGES (16).
857        let stages = (0..17).map(|_| make_stage(97)).collect();
858        let err = canonical_request(stages, 9)
859            .normalise()
860            .expect_err("17-stage chain should fail");
861        assert!(
862            matches!(err, MoldError::Validation(msg) if msg.contains("maximum")),
863            "error must mention the max-stages cap",
864        );
865    }
866
867    #[test]
868    fn normalise_rejects_auto_expand_too_long() {
869        // 16 × 97 = 1552 max stitched frames before trim; asking for
870        // 4000 frames should blow the guardrail.
871        let err = auto_expand_request("too long", 4000, 97, 9, None)
872            .normalise()
873            .expect_err("runaway auto-expand should fail");
874        assert!(
875            matches!(err, MoldError::Validation(msg) if msg.contains("stages")),
876            "error must name the stage count guardrail",
877        );
878    }
879
880    #[test]
881    fn normalise_rejects_motion_tail_ge_clip() {
882        // motion_tail must leave at least one new frame per continuation.
883        let err = auto_expand_request("bad tail", 200, 97, 97, None)
884            .normalise()
885            .expect_err("motion_tail >= clip should fail");
886        assert!(
887            matches!(err, MoldError::Validation(msg) if msg.contains("motion_tail_frames")),
888            "error must name motion_tail_frames",
889        );
890    }
891
892    #[test]
893    fn enable_audio_defaults_to_none_and_round_trips_when_set() {
894        // Wire-conservative default: chains opt in to audio explicitly. A
895        // request that omits the field stays None (engine-side resolves to
896        // false), so existing chain callers don't suddenly get audio they
897        // didn't ask for. Setting `enable_audio: true` on the request must
898        // round-trip into the canonical script echo so clients can save and
899        // re-render the same chain with audio enabled.
900        let req: ChainRequest = serde_json::from_value(serde_json::json!({
901            "model": "ltx-2.3-22b-distilled:fp8",
902            "stages": [],
903            "width": 704,
904            "height": 416,
905            "steps": 4,
906            "guidance": 3.0,
907        }))
908        .expect("valid minimal chain request");
909        assert_eq!(req.enable_audio, None);
910
911        let req_with_audio: ChainRequest = serde_json::from_value(serde_json::json!({
912            "model": "ltx-2.3-22b-distilled:fp8",
913            "stages": [{"prompt": "a bird", "frames": 33}],
914            "width": 704,
915            "height": 416,
916            "steps": 4,
917            "guidance": 3.0,
918            "enable_audio": true,
919        }))
920        .expect("valid chain request with audio");
921        assert_eq!(req_with_audio.enable_audio, Some(true));
922
923        let script = ChainScript::from(&req_with_audio);
924        assert_eq!(
925            script.chain.enable_audio,
926            Some(true),
927            "ChainScript echo must preserve enable_audio for round-trip save/reload",
928        );
929    }
930
931    #[test]
932    fn motion_tail_default_lands_on_8k_plus_1_grid() {
933        // Server JSON default must satisfy `1 + 8k` so chain tail RGB frames
934        // re-encode cleanly through the LTX-2 video VAE. CLI and SPA already
935        // default to 17; pin the JSON deserialiser to the same value.
936        let req: ChainRequest = serde_json::from_value(serde_json::json!({
937            "model": "ltx-2.3-22b-distilled:fp8",
938            "stages": [],
939            "width": 704,
940            "height": 416,
941            "steps": 4,
942            "guidance": 3.0,
943        }))
944        .expect("valid minimal chain request");
945        assert_eq!(req.motion_tail_frames, 17);
946        assert!(is_ltx2_frame_count(req.motion_tail_frames));
947    }
948
949    #[test]
950    fn normalise_rejects_motion_tail_off_grid() {
951        // motion_tail_frames=4 is what the JSON default used to be — it does
952        // NOT satisfy `1 + 8k`, so the carryover VAE re-encode would fail
953        // deep in the engine with a shape mismatch. Reject with a clear
954        // message at the wire boundary instead.
955        let req = canonical_request(vec![make_stage(33)], 4);
956        let err = req
957            .normalise()
958            .expect_err("motion_tail_frames=4 must be rejected");
959        assert!(
960            matches!(err, MoldError::Validation(msg) if msg.contains("8k+1")),
961            "error must name the 8k+1 grid constraint",
962        );
963    }
964
965    #[test]
966    fn normalise_accepts_motion_tail_zero() {
967        // motion_tail=0 means hard concat, no overlap, no carryover encode.
968        // Must be valid so cut/fade chains can opt out of the grid entirely.
969        let mut second = make_stage(33);
970        second.transition = TransitionMode::Cut;
971        let req = canonical_request(vec![make_stage(33), second], 0);
972        req.normalise().expect("motion_tail=0 must be accepted");
973    }
974
975    #[test]
976    fn normalise_rejects_missing_total_frames_in_auto_expand() {
977        let mut req = canonical_request(Vec::new(), 4);
978        req.prompt = Some("missing total".into());
979        // total_frames omitted.
980        let err = req
981            .normalise()
982            .expect_err("missing total_frames should fail");
983        assert!(
984            matches!(err, MoldError::Validation(msg) if msg.contains("total_frames")),
985            "error must name total_frames",
986        );
987    }
988
989    #[test]
990    fn is_ltx2_frame_count_matches_8k_plus_1() {
991        for valid in [1u32, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97] {
992            assert!(
993                is_ltx2_frame_count(valid),
994                "{valid} should be a valid LTX-2 frame count",
995            );
996        }
997        for invalid in [0u32, 2, 8, 10, 16, 50, 96, 98, 100] {
998            assert!(
999                !is_ltx2_frame_count(invalid),
1000                "{invalid} must not pass the 8k+1 check",
1001            );
1002        }
1003    }
1004
1005    #[test]
1006    fn chain_progress_event_roundtrips_json_with_snake_case_tags() {
1007        let cases = [
1008            (
1009                ChainProgressEvent::ChainStart {
1010                    stage_count: 5,
1011                    estimated_total_frames: 469,
1012                },
1013                r#""type":"chain_start""#,
1014            ),
1015            (
1016                ChainProgressEvent::StageStart { stage_idx: 0 },
1017                r#""type":"stage_start""#,
1018            ),
1019            (
1020                ChainProgressEvent::DenoiseStep {
1021                    stage_idx: 2,
1022                    step: 4,
1023                    total: 8,
1024                },
1025                r#""type":"denoise_step""#,
1026            ),
1027            (
1028                ChainProgressEvent::StageDone {
1029                    stage_idx: 3,
1030                    frames_emitted: 97,
1031                },
1032                r#""type":"stage_done""#,
1033            ),
1034            (
1035                ChainProgressEvent::Stitching { total_frames: 400 },
1036                r#""type":"stitching""#,
1037            ),
1038        ];
1039        for (event, expected_tag) in cases {
1040            let json = serde_json::to_string(&event).expect("serialize");
1041            assert!(
1042                json.contains(expected_tag),
1043                "missing snake_case tag {expected_tag} in {json}",
1044            );
1045            let roundtrip: ChainProgressEvent = serde_json::from_str(&json).expect("deserialize");
1046            assert_eq!(roundtrip, event, "roundtrip must preserve payload");
1047        }
1048    }
1049
1050    #[test]
1051    fn build_stages_math_matches_stitch_budget() {
1052        // Auto-expand must produce enough stages that the stitch delivers
1053        // at least `total_frames` pixel frames. Stitch math:
1054        //   delivered = clip_frames + (N - 1) * (clip_frames - motion_tail)
1055        let cases = [
1056            (400u32, 97u32, 9u32, 5u32), // 97 + 4*88 = 449 ≥ 400
1057            (200, 97, 9, 3),             // 97 + 2*88 = 273 ≥ 200
1058            (97, 97, 9, 1),              // single clip hits 97 exactly
1059            (300, 97, 0, 4),             // zero tail, 4*97 = 388 ≥ 300
1060        ];
1061        for (total, clip, tail, expected_n) in cases {
1062            let req = auto_expand_request("m", total, clip, tail, None)
1063                .normalise()
1064                .expect("valid auto-expand should normalise");
1065            assert_eq!(
1066                req.stages.len() as u32,
1067                expected_n,
1068                "expected {expected_n} stages for total={total}, clip={clip}, tail={tail}",
1069            );
1070            let delivered = clip + (expected_n - 1) * (clip - tail);
1071            assert!(
1072                delivered >= total,
1073                "{expected_n} stages deliver {delivered} frames but {total} were requested",
1074            );
1075        }
1076    }
1077
1078    #[test]
1079    fn transition_mode_serializes_snake_case() {
1080        assert_eq!(
1081            serde_json::to_value(TransitionMode::Smooth).unwrap(),
1082            serde_json::Value::String("smooth".into())
1083        );
1084        assert_eq!(
1085            serde_json::to_value(TransitionMode::Cut).unwrap(),
1086            serde_json::Value::String("cut".into())
1087        );
1088        assert_eq!(
1089            serde_json::to_value(TransitionMode::Fade).unwrap(),
1090            serde_json::Value::String("fade".into())
1091        );
1092    }
1093
1094    #[test]
1095    fn transition_mode_defaults_to_smooth() {
1096        assert_eq!(TransitionMode::default(), TransitionMode::Smooth);
1097    }
1098
1099    #[test]
1100    fn lora_spec_serializes_minimal() {
1101        let spec = LoraSpec {
1102            path: "./style.safetensors".into(),
1103            scale: 0.8,
1104            name: None,
1105        };
1106        let json = serde_json::to_string(&spec).unwrap();
1107        assert!(json.contains(r#""path":"./style.safetensors""#));
1108        assert!(json.contains(r#""scale":0.8"#));
1109        // name omitted
1110        assert!(!json.contains(r#""name""#));
1111    }
1112
1113    #[test]
1114    fn named_ref_serializes_minimal() {
1115        let r = NamedRef {
1116            name: "hero".into(),
1117            image: vec![0x89, 0x50],
1118        };
1119        let json = serde_json::to_string(&r).unwrap();
1120        // base64-encoded image via the existing base64 helper
1121        assert!(json.contains(r#""name":"hero""#));
1122        assert!(json.contains(r#""image":"#));
1123    }
1124
1125    #[test]
1126    fn chain_stage_defaults_are_backcompat() {
1127        // Parsing a v1-shaped stage (no new fields) yields the same structure
1128        // with defaults applied.
1129        let json = r#"{
1130            "prompt": "a cat",
1131            "frames": 97
1132        }"#;
1133        let stage: ChainStage = serde_json::from_str(json).unwrap();
1134        assert_eq!(stage.prompt, "a cat");
1135        assert_eq!(stage.frames, 97);
1136        assert_eq!(stage.transition, TransitionMode::Smooth);
1137        assert_eq!(stage.fade_frames, None);
1138        assert!(stage.model.is_none());
1139        assert!(stage.loras.is_empty());
1140        assert!(stage.references.is_empty());
1141    }
1142
1143    #[test]
1144    fn chain_script_projects_from_request() {
1145        let req = ChainRequest {
1146            model: "ltx-2-19b-distilled:fp8".into(),
1147            stages: vec![ChainStage {
1148                prompt: "a".into(),
1149                frames: 97,
1150                source_image: None,
1151                negative_prompt: None,
1152                seed_offset: None,
1153                transition: TransitionMode::Smooth,
1154                fade_frames: None,
1155                model: None,
1156                loras: vec![],
1157                references: vec![],
1158            }],
1159            motion_tail_frames: 25,
1160            width: 1216,
1161            height: 704,
1162            fps: 24,
1163            seed: Some(42),
1164            steps: 8,
1165            guidance: 3.0,
1166            strength: 1.0,
1167            output_format: OutputFormat::Mp4,
1168            placement: None,
1169            prompt: None,
1170            total_frames: None,
1171            clip_frames: None,
1172            source_image: None,
1173            enable_audio: None,
1174        };
1175        let script = ChainScript::from(&req);
1176        assert_eq!(script.chain.model, "ltx-2-19b-distilled:fp8");
1177        assert_eq!(script.chain.seed, Some(42));
1178        assert_eq!(script.stages.len(), 1);
1179        assert_eq!(script.stages[0].prompt, "a");
1180    }
1181
1182    #[test]
1183    fn chain_stage_roundtrips_all_fields() {
1184        let stage = ChainStage {
1185            prompt: "scene".into(),
1186            frames: 49,
1187            source_image: None,
1188            negative_prompt: None,
1189            seed_offset: None,
1190            transition: TransitionMode::Cut,
1191            fade_frames: Some(12),
1192            model: None,
1193            loras: vec![],
1194            references: vec![],
1195        };
1196        let json = serde_json::to_string(&stage).unwrap();
1197        let back: ChainStage = serde_json::from_str(&json).unwrap();
1198        assert_eq!(back.frames, 49);
1199        assert_eq!(back.transition, TransitionMode::Cut);
1200        assert_eq!(back.fade_frames, Some(12));
1201    }
1202
1203    #[test]
1204    fn normalise_coerces_stage_0_transition_to_smooth() {
1205        let mut req = auto_expand_request("a", 97, 97, 25, None);
1206        req.stages = vec![
1207            ChainStage {
1208                prompt: "scene 0".into(),
1209                frames: 97,
1210                source_image: None,
1211                negative_prompt: None,
1212                seed_offset: None,
1213                transition: TransitionMode::Cut, // should coerce
1214                fade_frames: None,
1215                model: None,
1216                loras: vec![],
1217                references: vec![],
1218            },
1219            ChainStage {
1220                prompt: "scene 1".into(),
1221                frames: 97,
1222                source_image: None,
1223                negative_prompt: None,
1224                seed_offset: None,
1225                transition: TransitionMode::Cut, // preserved
1226                fade_frames: None,
1227                model: None,
1228                loras: vec![],
1229                references: vec![],
1230            },
1231        ];
1232        let normalised = req.normalise().unwrap();
1233        assert_eq!(normalised.stages[0].transition, TransitionMode::Smooth);
1234        assert_eq!(normalised.stages[1].transition, TransitionMode::Cut);
1235    }
1236
1237    #[test]
1238    fn normalise_rejects_reserved_model_field() {
1239        let mut req = auto_expand_request("a", 97, 97, 25, None);
1240        req.stages = vec![ChainStage {
1241            prompt: "x".into(),
1242            frames: 97,
1243            source_image: None,
1244            negative_prompt: None,
1245            seed_offset: None,
1246            transition: TransitionMode::Smooth,
1247            fade_frames: None,
1248            model: Some("flux-dev:q4".into()),
1249            loras: vec![],
1250            references: vec![],
1251        }];
1252        let err = req.normalise().unwrap_err().to_string();
1253        assert!(err.contains("reserved for sub-project C"), "got: {err}");
1254    }
1255
1256    #[test]
1257    fn normalise_rejects_reserved_loras_field() {
1258        let mut req = auto_expand_request("a", 97, 97, 25, None);
1259        req.stages = vec![ChainStage {
1260            prompt: "x".into(),
1261            frames: 97,
1262            source_image: None,
1263            negative_prompt: None,
1264            seed_offset: None,
1265            transition: TransitionMode::Smooth,
1266            fade_frames: None,
1267            model: None,
1268            loras: vec![LoraSpec {
1269                path: "x.safetensors".into(),
1270                scale: 1.0,
1271                name: None,
1272            }],
1273            references: vec![],
1274        }];
1275        let err = req.normalise().unwrap_err().to_string();
1276        assert!(err.contains("reserved for sub-project B"), "got: {err}");
1277    }
1278
1279    fn stage_list_request(stages: Vec<(TransitionMode, u32, Option<u32>)>) -> ChainRequest {
1280        ChainRequest {
1281            model: "ltx-2-19b-distilled:fp8".into(),
1282            stages: stages
1283                .into_iter()
1284                .map(|(t, f, fl)| ChainStage {
1285                    prompt: "x".into(),
1286                    frames: f,
1287                    source_image: None,
1288                    negative_prompt: None,
1289                    seed_offset: None,
1290                    transition: t,
1291                    fade_frames: fl,
1292                    model: None,
1293                    loras: vec![],
1294                    references: vec![],
1295                })
1296                .collect(),
1297            motion_tail_frames: 25,
1298            width: 1216,
1299            height: 704,
1300            fps: 24,
1301            seed: None,
1302            steps: 8,
1303            guidance: 3.0,
1304            strength: 1.0,
1305            output_format: OutputFormat::Mp4,
1306            placement: None,
1307            prompt: None,
1308            total_frames: None,
1309            clip_frames: None,
1310            source_image: None,
1311            enable_audio: None,
1312        }
1313    }
1314
1315    #[test]
1316    fn estimated_total_all_smooth() {
1317        // 3 × 97-frame smooth = 97 + (97-25) + (97-25) = 241
1318        let req = stage_list_request(vec![
1319            (TransitionMode::Smooth, 97, None),
1320            (TransitionMode::Smooth, 97, None),
1321            (TransitionMode::Smooth, 97, None),
1322        ]);
1323        assert_eq!(req.estimated_total_frames(), 241);
1324    }
1325
1326    #[test]
1327    fn estimated_total_with_cut() {
1328        // 97 + 97 (cut, no trim) + (97-25) (smooth after cut) = 266
1329        let req = stage_list_request(vec![
1330            (TransitionMode::Smooth, 97, None),
1331            (TransitionMode::Cut, 97, None),
1332            (TransitionMode::Smooth, 97, None),
1333        ]);
1334        assert_eq!(req.estimated_total_frames(), 266);
1335    }
1336
1337    #[test]
1338    fn estimated_total_with_fade() {
1339        // 97 + 97 + (97 - fade 8) fade consumes from both sides, net -fade_len
1340        // Actually: fade replaces the trailing fade_len of clip N + leading
1341        // fade_len of clip N+1 with fade_len blended frames.
1342        // Emission = sum - 2*fade_len + fade_len = sum - fade_len
1343        // = 97+97+97 - 8 = 283
1344        let req = stage_list_request(vec![
1345            (TransitionMode::Smooth, 97, None),
1346            (TransitionMode::Cut, 97, None),
1347            (TransitionMode::Fade, 97, Some(8)),
1348        ]);
1349        assert_eq!(req.estimated_total_frames(), 283);
1350    }
1351}