mold-ai-inference 0.13.1

Candle-based inference engine for mold — FLUX, SDXL, SD3.5, Z-Image diffusion models
Documentation
use anyhow::Result;
use candle_core::{DType, Device, IndexOp, Tensor};

use super::patchifiers::{AudioPatchifier, VideoLatentPatchifier};
use super::shapes::{AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub enum LtxRopeType {
    Interleaved,
    Split,
}

pub fn midpoint_positions(bounds: &Tensor) -> Result<Tensor> {
    let starts = bounds.narrow(3, 0, 1)?;
    let ends = bounds.narrow(3, 1, 1)?;
    starts
        .broadcast_add(&ends)?
        .affine(0.5, 0.0)
        .map_err(Into::into)
}

pub fn video_token_positions(
    patchifier: VideoLatentPatchifier,
    shape: VideoLatentShape,
    device: &Device,
) -> Result<Tensor> {
    patchifier.get_patch_grid_bounds(shape, device)
}

pub fn get_pixel_coords(
    latent_coords: &Tensor,
    scale_factors: SpatioTemporalScaleFactors,
    causal_fix: bool,
) -> Result<Tensor> {
    let scale = Tensor::from_vec(
        vec![
            scale_factors.time as f32,
            scale_factors.height as f32,
            scale_factors.width as f32,
        ],
        (1, 3, 1, 1),
        latent_coords.device(),
    )?;
    let mut pixel_coords = latent_coords.to_dtype(DType::F32)?.broadcast_mul(&scale)?;
    if causal_fix {
        let temporal = pixel_coords.i((.., 0..1, .., ..))?;
        let corrected = temporal
            .affine(1.0, 1.0 - scale_factors.time as f64)?
            .clamp(0.0f32, f32::MAX)?;
        let height_width = pixel_coords.i((.., 1.., .., ..))?;
        pixel_coords = Tensor::cat(&[corrected, height_width], 1)?;
    }
    Ok(pixel_coords)
}

pub fn scale_video_time_to_seconds(pixel_coords: &Tensor, fps: f32) -> Result<Tensor> {
    let temporal = pixel_coords
        .i((.., 0..1, .., ..))?
        .affine(1.0 / fps as f64, 0.0)?;
    let height_width = pixel_coords.i((.., 1.., .., ..))?;
    Tensor::cat(&[temporal, height_width], 1).map_err(Into::into)
}

pub fn audio_temporal_positions(
    patchifier: AudioPatchifier,
    shape: AudioLatentShape,
    device: &Device,
) -> Result<Tensor> {
    patchifier.get_patch_grid_bounds(shape, device)
}

pub fn cross_modal_temporal_positions(
    video_bounds: &Tensor,
    audio_bounds: &Tensor,
) -> Result<(Tensor, Tensor)> {
    let video = midpoint_positions(&video_bounds.narrow(1, 0, 1)?.to_dtype(DType::F32)?)?;
    let audio = midpoint_positions(&audio_bounds.to_dtype(DType::F32)?)?;
    Ok((video, audio))
}

#[cfg(test)]
mod tests {
    use candle_core::{Device, IndexOp};

    use super::{
        audio_temporal_positions, cross_modal_temporal_positions, get_pixel_coords,
        midpoint_positions, scale_video_time_to_seconds, video_token_positions,
    };
    use crate::ltx2::model::{
        AudioLatentShape, AudioPatchifier, SpatioTemporalScaleFactors, VideoLatentPatchifier,
        VideoLatentShape,
    };

    #[test]
    fn video_rope_positions_cover_3d_grid() {
        let device = Device::Cpu;
        let bounds = video_token_positions(
            VideoLatentPatchifier::new(1),
            VideoLatentShape {
                batch: 1,
                channels: 128,
                frames: 2,
                height: 2,
                width: 2,
            },
            &device,
        )
        .unwrap();

        assert_eq!(bounds.dims4().unwrap(), (1, 3, 8, 2));
        let mids = midpoint_positions(&bounds).unwrap();
        assert_eq!(mids.dims4().unwrap(), (1, 3, 8, 1));
    }

    #[test]
    fn audio_rope_positions_are_temporal_only() {
        let device = Device::Cpu;
        let bounds = audio_temporal_positions(
            AudioPatchifier::new(16_000, 160, 4, true, 0),
            AudioLatentShape {
                batch: 1,
                channels: 8,
                frames: 4,
                mel_bins: 16,
            },
            &device,
        )
        .unwrap();

        assert_eq!(bounds.dims4().unwrap(), (1, 1, 4, 2));
    }

    #[test]
    fn cross_modal_temporal_positions_keep_video_and_audio_sequences() {
        let device = Device::Cpu;
        let video = video_token_positions(
            VideoLatentPatchifier::new(1),
            VideoLatentShape {
                batch: 1,
                channels: 128,
                frames: 2,
                height: 1,
                width: 1,
            },
            &device,
        )
        .unwrap();
        let audio = audio_temporal_positions(
            AudioPatchifier::new(16_000, 160, 4, true, 0),
            AudioLatentShape {
                batch: 1,
                channels: 8,
                frames: 3,
                mel_bins: 16,
            },
            &device,
        )
        .unwrap();

        let (video_temporal, audio_temporal) =
            cross_modal_temporal_positions(&video, &audio).unwrap();
        assert_eq!(video_temporal.dims4().unwrap(), (1, 1, 2, 1));
        assert_eq!(audio_temporal.dims4().unwrap(), (1, 1, 3, 1));
    }

    #[test]
    fn pixel_coords_apply_causal_first_frame_fix() {
        let device = Device::Cpu;
        let latent = video_token_positions(
            VideoLatentPatchifier::new(1),
            VideoLatentShape {
                batch: 1,
                channels: 128,
                frames: 2,
                height: 1,
                width: 1,
            },
            &device,
        )
        .unwrap();

        let pixel = get_pixel_coords(&latent, SpatioTemporalScaleFactors::default(), true).unwrap();
        let temporal = pixel.i((0, 0, .., ..)).unwrap().to_vec2::<f32>().unwrap();
        assert_eq!(temporal, vec![vec![0.0, 1.0], vec![1.0, 9.0]]);
    }

    #[test]
    fn video_time_positions_scale_to_seconds() {
        let device = Device::Cpu;
        let latent = video_token_positions(
            VideoLatentPatchifier::new(1),
            VideoLatentShape {
                batch: 1,
                channels: 128,
                frames: 2,
                height: 1,
                width: 1,
            },
            &device,
        )
        .unwrap();
        let pixel = get_pixel_coords(&latent, SpatioTemporalScaleFactors::default(), true).unwrap();
        let seconds = scale_video_time_to_seconds(&pixel, 4.0).unwrap();
        let temporal = seconds.i((0, 0, .., ..)).unwrap().to_vec2::<f32>().unwrap();
        assert_eq!(temporal, vec![vec![0.0, 0.25], vec![0.25, 2.25]]);
    }
}