mold-ai-inference 0.13.1

Candle-based inference engine for mold — FLUX, SDXL, SD3.5, Z-Image diffusion models
Documentation
#![allow(dead_code)]

use std::env;
use std::fs;
use std::path::{Path, PathBuf};

use anyhow::{bail, Context, Result};
use candle_core::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use image::{imageops::FilterType, GenericImage, Rgb, RgbImage};
use mold_inference::device::create_device;
use mold_inference::ltx2::media::write_contact_sheet;
use mold_inference::progress::ProgressReporter;

#[path = "../ltx_video/video_enc.rs"]
mod video_enc_impl;
#[path = "../ltx2/model/video_vae.rs"]
mod video_vae_impl;

use video_enc_impl::encode_gif;
use video_vae_impl::{AutoencoderKLLtx2Video, AutoencoderKLLtx2VideoConfig};

fn main() -> Result<()> {
    let args = env::args().skip(1).collect::<Vec<_>>();
    if args.len() < 3 {
        bail!(
            "usage: cargo run -p mold-ai-inference --features cuda,mp4,dev-bins --bin ltx2_vae_probe -- <checkpoint.safetensors> <input-image> <output.mp4> [width] [height] [frames] [fps]"
        );
    }

    let checkpoint_path = PathBuf::from(&args[0]);
    let input_path = PathBuf::from(&args[1]);
    let output_path = PathBuf::from(&args[2]);
    let width = parse_or_default(args.get(3), 576u32)?;
    let height = parse_or_default(args.get(4), 320u32)?;
    let frames = parse_or_default(args.get(5), 9usize)?;
    let fps = parse_or_default(args.get(6), 12u32)?;

    roundtrip_vae(
        &checkpoint_path,
        &input_path,
        &output_path,
        width,
        height,
        frames,
        fps,
    )
}

/// Parses an optional CLI argument into `T`, returning `default` when the argument is absent.
///
/// # Errors
///
/// Returns an error naming the raw text when the argument is present but does not parse.
fn parse_or_default<T>(value: Option<&String>, default: T) -> Result<T>
where
    T: std::str::FromStr,
    T::Err: std::fmt::Display,
{
    let Some(value) = value else {
        return Ok(default);
    };
    value
        .parse::<T>()
        .map_err(|err| anyhow::anyhow!("failed to parse '{value}': {err}"))
}

/// Runs one encode/decode round trip of a still image through the native LTX-2 video VAE.
///
/// Steps:
/// - Resize the image at `input_path` to `width`x`height` and repeat it `frames` times to
///   form a static clip.
/// - Encode and decode the clip, printing summary statistics along the way.
/// - Write the reconstruction to `output_path` as MP4.
///
/// Next to the MP4, using its file stem, the probe also writes `<stem>.gif`,
/// `<stem>-contact-sheet.png`, `<stem>-input.png`, `<stem>-recon-first-frame.png` and
/// `<stem>-comparison.png`.
///
/// # Errors
///
/// Returns an error if any of the following hold:
/// - the checkpoint or input image does not exist;
/// - `output_path` does not end in `.mp4`;
/// - `width`/`height` are zero or not multiples of 32;
/// - `frames` or `fps` is zero;
/// - the selected device is Metal;
/// - the decoder yields no frames;
/// - any load, encode, decode or write step fails.
fn roundtrip_vae(
    checkpoint_path: &Path,
    input_path: &Path,
    output_path: &Path,
    width: u32,
    height: u32,
    frames: usize,
    fps: u32,
) -> Result<()> {
    if !checkpoint_path.is_file() {
        bail!("checkpoint not found: {}", checkpoint_path.display());
    }
    if !input_path.is_file() {
        bail!("input image not found: {}", input_path.display());
    }
    if output_path.extension().and_then(|ext| ext.to_str()) != Some("mp4") {
        bail!("output must be an .mp4 path: {}", output_path.display());
    }
    // 0 is a multiple of 32, so the divisibility check alone would accept an empty frame.
    if width == 0 || height == 0 {
        bail!("width and height must be positive, got {width}x{height}");
    }
    if !width.is_multiple_of(32) || !height.is_multiple_of(32) {
        bail!("width and height must be divisible by 32, got {width}x{height}");
    }
    if frames == 0 {
        bail!("frames must be positive");
    }
    // A zero frame rate is meaningless for the MP4/GIF encoders; reject it before the
    // expensive VAE round trip rather than after it.
    if fps == 0 {
        bail!("fps must be positive");
    }

    let device = create_device(&ProgressReporter::default())?;
    if device.is_metal() {
        bail!("Metal is not supported for native LTX-2 VAE probing");
    }
    // BF16 on CUDA, full precision everywhere else.
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };

    // A static clip: the same resized frame repeated `frames` times.
    let input_frame = load_and_resize_image(input_path, width, height)?;
    let input_frames = vec![input_frame.clone(); frames];
    let input_tensor = frames_to_video_tensor(&input_frames, &device, dtype)?;

    // SAFETY: the checkpoint is memory-mapped, so the file must not be modified or truncated
    // while `vb` (and any tensors loaded from it) are alive. This probe only reads the file
    // and does not write to it.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(std::slice::from_ref(&checkpoint_path), dtype, &device)?
    };
    let vae_config = infer_vae_config(checkpoint_path);
    let vae = AutoencoderKLLtx2Video::new(vae_config, vb.pp("vae")).with_context(|| {
        format!(
            "failed to load native LTX-2 VAE from '{}'",
            checkpoint_path.display()
        )
    })?;

    println!(
        "device={:?} dtype={:?} mean_of_means_rms={:.6} std_of_means_rms={:.6}",
        device,
        dtype,
        tensor_rms(vae.latents_mean())?,
        tensor_rms(vae.latents_std())?,
    );
    println!(
        "input_video: shape={:?} mean={:.6} rms={:.6}",
        input_tensor.dims(),
        tensor_mean(&input_tensor)?,
        tensor_rms(&input_tensor)?,
    );

    let latents = vae.encode(&input_tensor)?;
    println!(
        "encoded_latents: shape={:?} mean={:.6} rms={:.6}",
        latents.dims(),
        tensor_mean(&latents)?,
        tensor_rms(&latents)?,
    );

    let (_decoder_output, decoded) = vae.decode(&latents, None, false, false)?;
    println!(
        "decoded_video: shape={:?} mean={:.6} rms={:.6}",
        decoded.dims(),
        tensor_mean(&decoded)?,
        tensor_rms(&decoded)?,
    );

    let decoded_frames = tensor_to_frames(&decoded)?;
    // The first-frame artifacts below need at least one frame. Fail with an error here
    // instead of panicking on `decoded_frames[0]` after the MP4 has been written.
    let first_recon = decoded_frames
        .first()
        .context("VAE decode produced no frames")?;

    let mp4_bytes = encode_mp4_bytes(&decoded_frames, fps)?;
    fs::write(output_path, mp4_bytes)
        .with_context(|| format!("failed to write {}", output_path.display()))?;

    // Sibling artifacts are named after the MP4's stem and placed in the same directory.
    let stem = output_path
        .file_stem()
        .and_then(|stem| stem.to_str())
        .with_context(|| {
            format!(
                "failed to derive output stem from {}",
                output_path.display()
            )
        })?;
    let output_dir = output_path.parent().unwrap_or_else(|| Path::new("."));
    let gif_path = output_dir.join(format!("{stem}.gif"));
    let contact_sheet_path = output_dir.join(format!("{stem}-contact-sheet.png"));
    let input_png_path = output_dir.join(format!("{stem}-input.png"));
    let recon_png_path = output_dir.join(format!("{stem}-recon-first-frame.png"));
    let comparison_png_path = output_dir.join(format!("{stem}-comparison.png"));

    fs::write(&gif_path, encode_gif(&decoded_frames, fps)?)
        .with_context(|| format!("failed to write {}", gif_path.display()))?;
    // Built from the MP4 just written to `output_path`.
    write_contact_sheet(output_path, &contact_sheet_path)?;
    input_frame
        .save(&input_png_path)
        .with_context(|| format!("failed to write {}", input_png_path.display()))?;
    first_recon
        .save(&recon_png_path)
        .with_context(|| format!("failed to write {}", recon_png_path.display()))?;
    build_comparison(&input_frame, first_recon)?
        .save(&comparison_png_path)
        .with_context(|| format!("failed to write {}", comparison_png_path.display()))?;

    println!(
        "saved mp4={} gif={} contact_sheet={} input={} recon={} comparison={}",
        output_path.display(),
        gif_path.display(),
        contact_sheet_path.display(),
        input_png_path.display(),
        recon_png_path.display(),
        comparison_png_path.display(),
    );

    Ok(())
}

/// Chooses which VAE layout to instantiate, based on the checkpoint path.
///
/// The whole path is lowercased (directory components included) and searched for the
/// substring `ltx-2.3-22b`. A match selects the 22B layout; anything else falls back to
/// `AutoencoderKLLtx2VideoConfig::default()`.
fn infer_vae_config(checkpoint_path: &Path) -> AutoencoderKLLtx2VideoConfig {
    const LTX_2_3_22B_MARKER: &str = "ltx-2.3-22b";

    let is_22b = checkpoint_path
        .to_string_lossy()
        .to_ascii_lowercase()
        .contains(LTX_2_3_22B_MARKER);
    if !is_22b {
        return AutoencoderKLLtx2VideoConfig::default();
    }
    AutoencoderKLLtx2VideoConfig::ltx2_22b()
}

#[cfg(feature = "mp4")]
fn encode_mp4_bytes(frames: &[RgbImage], fps: u32) -> Result<Vec<u8>> {
    video_enc_impl::encode_mp4(frames, fps)
}

#[cfg(not(feature = "mp4"))]
fn encode_mp4_bytes(_frames: &[RgbImage], _fps: u32) -> Result<Vec<u8>> {
    bail!("ltx2_vae_probe requires the `mp4` feature")
}

/// Loads the image at `path`, converts it to 8-bit RGB, and resizes it to exactly
/// `width`x`height` with a Lanczos3 filter. Aspect ratio is not preserved.
///
/// # Errors
///
/// Returns an error (mentioning `path`) if the image cannot be opened or decoded.
fn load_and_resize_image(path: &Path, width: u32, height: u32) -> Result<RgbImage> {
    let decoded =
        image::open(path).with_context(|| format!("failed to load {}", path.display()))?;
    let rgb = decoded.to_rgb8();
    let resized = image::imageops::resize(&rgb, width, height, FilterType::Lanczos3);
    Ok(resized)
}

/// Packs RGB frames into a `(1, 3, frames, height, width)` video tensor normalised to `[-1, 1]`.
///
/// Each 8-bit channel value `p` maps to `p / 127.5 - 1.0`. Data is laid out channel-major:
/// the red plane of every frame, then green, then blue, each plane in row-major order.
/// The tensor is built as F32 on `device` and then cast to `dtype`.
///
/// # Errors
///
/// Returns an error if any of the following hold:
/// - `frames` is empty;
/// - any frame's dimensions differ from the first frame's;
/// - tensor construction or the dtype cast fails.
fn frames_to_video_tensor(frames: &[RgbImage], device: &Device, dtype: DType) -> Result<Tensor> {
    let first = frames
        .first()
        .ok_or_else(|| anyhow::anyhow!("no frames for tensor conversion"))?;
    let (first_width, first_height) = first.dimensions();
    // Validate once, up front. The check is loop-invariant, so there is no reason to repeat
    // it for every channel pass or to normalise pixels before knowing the input is valid.
    if frames
        .iter()
        .any(|frame| frame.dimensions() != (first_width, first_height))
    {
        bail!("all frames must have the same dimensions");
    }

    let width = first_width as usize;
    let height = first_height as usize;
    let frame_count = frames.len();
    let mut data = Vec::with_capacity(frame_count * width * height * 3);

    for channel in 0..3usize {
        for frame in frames {
            data.extend(
                frame
                    .pixels()
                    .map(|pixel| (pixel[channel] as f32 / 127.5) - 1.0),
            );
        }
    }

    Ok(Tensor::from_vec(data, (1, 3, frame_count, height, width), device)?.to_dtype(dtype)?)
}

fn tensor_to_frames(video: &Tensor) -> Result<Vec<RgbImage>> {
    let video = ((video
        .to_device(&Device::Cpu)?
        .to_dtype(DType::F32)?
        .clamp(-1f32, 1f32)?
        + 1.0)?
        * 127.5)?
        .to_dtype(DType::U8)?;
    let video = video.i(0)?;
    let (_, frames, height, width) = video.dims4()?;
    let mut out = Vec::with_capacity(frames);
    for index in 0..frames {
        let frame = video
            .i((.., index, .., ..))?
            .permute((1, 2, 0))?
            .contiguous()?;
        let data: Vec<u8> = frame.flatten_all()?.to_vec1()?;
        let image = RgbImage::from_raw(width as u32, height as u32, data)
            .ok_or_else(|| anyhow::anyhow!("failed to rebuild RGB frame {index}"))?;
        out.push(image);
    }
    Ok(out)
}

fn build_comparison(input: &RgbImage, recon: &RgbImage) -> Result<RgbImage> {
    if input.dimensions() != recon.dimensions() {
        bail!(
            "comparison images must have the same dimensions, got {:?} and {:?}",
            input.dimensions(),
            recon.dimensions()
        );
    }
    let gutter = 16u32;
    let label_band = 24u32;
    let width = input.width() * 2 + gutter;
    let height = input.height() + label_band;
    let mut canvas = RgbImage::from_pixel(width, height, Rgb([18, 18, 18]));
    canvas.copy_from(input, 0, label_band)?;
    canvas.copy_from(recon, input.width() + gutter, label_band)?;
    Ok(canvas)
}

/// Returns the arithmetic mean of every element of `tensor`.
///
/// The tensor is copied to the CPU and cast to F32 before reducing.
fn tensor_mean(tensor: &Tensor) -> Result<f32> {
    let on_host = tensor.to_device(&Device::Cpu)?;
    let as_f32 = on_host.to_dtype(DType::F32)?;
    let mean = as_f32.flatten_all()?.mean_all()?;
    Ok(mean.to_scalar::<f32>()?)
}

/// Returns the root-mean-square of every element of `tensor`, computed as
/// `sqrt(mean(x^2))`.
///
/// The tensor is copied to the CPU and cast to F32 before reducing.
fn tensor_rms(tensor: &Tensor) -> Result<f32> {
    let squared = tensor
        .to_device(&Device::Cpu)?
        .to_dtype(DType::F32)?
        .flatten_all()?
        .sqr()?;
    let mean_square = squared.mean_all()?.to_scalar::<f32>()?;
    Ok(mean_square.sqrt())
}

#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::infer_vae_config;

    // Paths exercised by the layout-selection tests; only the path text matters.
    const LTX_23_22B_CHECKPOINT: &str = "/tmp/ltx-2.3-22b-distilled-fp8.safetensors";
    const LTX_2_19B_CHECKPOINT: &str = "/tmp/ltx-2-19b-distilled-fp8.safetensors";

    #[test]
    fn infer_vae_config_selects_22b_layout_from_checkpoint_name() {
        let cfg = infer_vae_config(Path::new(LTX_23_22B_CHECKPOINT));
        let encoder = &cfg.encoder_blocks;
        let decoder = &cfg.decoder_blocks;

        assert_eq!(encoder[0].num_layers, 4);
        assert_eq!(encoder[1].name, "compress_space_res");
        assert_eq!(decoder[1].name, "compress_space");
    }

    #[test]
    fn infer_vae_config_defaults_for_non_22b_checkpoint() {
        let cfg = infer_vae_config(Path::new(LTX_2_19B_CHECKPOINT));
        let encoder = &cfg.encoder_blocks;
        let decoder = &cfg.decoder_blocks;

        assert_eq!(encoder[1].name, "compress_space_res");
        assert!(!encoder[1].residual);
        assert_eq!(decoder[1].name, "compress_all");
    }
}