// mold-ai-inference 0.13.1
//
// Candle-based inference engine for mold — FLUX, SDXL, SD3.5, Z-Image diffusion models.
// Documentation
use anyhow::{bail, Context, Result};
use candle_core::{DType, Tensor};

use crate::ltx2::execution::SamplerMode;

/// Sigma schedule for the distilled sampler: nine values, i.e. eight denoising steps when
/// consumed pairwise by the loops below, ending at `0.0`.
// NOTE(review): not referenced by non-test code in this module, hence the `allow`.
#[allow(dead_code)]
pub const DISTILLED_SIGMA_VALUES: &[f32] = &[
    1.0,
    0.99375,
    0.9875,
    0.98125,
    0.975,
    0.909375,
    0.725,
    0.421875,
    0.0,
];

/// Second-stage distilled schedule. It is identical to the last four entries of
/// [`DISTILLED_SIGMA_VALUES`] (from `0.909375` down to `0.0`).
// NOTE(review): not referenced by non-test code in this module, hence the `allow`.
#[allow(dead_code)]
pub const STAGE_2_DISTILLED_SIGMA_VALUES: &[f32] = &[
    0.909375,
    0.725,
    0.421875,
    0.0,
];

/// Converts a `(sample, denoised_sample)` pair into a velocity: `(sample - denoised_sample) / sigma`.
///
/// Both inputs are cast to F32 before the subtraction, so the result is F32.
///
/// # Errors
/// Returns an error if `sigma` is exactly zero, or if the dtype casts or the broadcast
/// subtraction fail (e.g. incompatible shapes).
pub fn to_velocity(sample: &Tensor, sigma: f64, denoised_sample: &Tensor) -> Result<Tensor> {
    if sigma == 0.0 {
        bail!("sigma cannot be zero when converting to velocity");
    }
    let sample_f32 = sample.to_dtype(DType::F32)?;
    let denoised_f32 = denoised_sample.to_dtype(DType::F32)?;
    let residual = sample_f32.broadcast_sub(&denoised_f32)?;
    let inverse_sigma = 1.0 / sigma;
    Ok(residual.affine(inverse_sigma, 0.0)?)
}

/// Inverse of [`to_velocity`]: recovers `sample - sigma * velocity`.
///
/// Both inputs are cast to F32 first, so the result is F32. Unlike [`to_velocity`], a zero
/// `sigma` is allowed here (it simply returns the F32 sample).
// NOTE(review): not referenced by non-test code in this module, hence the `allow`.
#[allow(dead_code)]
pub fn to_denoised(sample: &Tensor, velocity: &Tensor, sigma: f64) -> Result<Tensor> {
    let sample_f32 = sample.to_dtype(DType::F32)?;
    let scaled_velocity = (velocity.to_dtype(DType::F32)? * sigma)?;
    let denoised = sample_f32.broadcast_sub(&scaled_velocity)?;
    Ok(denoised)
}

/// Performs one Euler step from `sigmas[step_index]` to `sigmas[step_index + 1]`.
///
/// The velocity is `(sample - denoised_sample) / sigma` (see [`to_velocity`]). The new sample is
/// `sample + (sigma_next - sigma) * velocity`, computed in F32.
///
/// # Errors
/// Fails if there is no sigma after `step_index`, if the current sigma is zero, or if a tensor
/// operation fails.
pub fn euler_step(
    sample: &Tensor,
    denoised_sample: &Tensor,
    sigmas: &[f32],
    step_index: usize,
) -> Result<Tensor> {
    let next_index = step_index + 1;
    if next_index >= sigmas.len() {
        bail!("euler step requires a sigma and next sigma");
    }
    let current_sigma = f64::from(sigmas[step_index]);
    let upcoming_sigma = f64::from(sigmas[next_index]);

    let velocity = to_velocity(sample, current_sigma, denoised_sample)?;
    let current = sample.to_dtype(DType::F32)?;
    // Sigmas normally decrease, so this step size is usually negative.
    let displacement = (velocity * (upcoming_sigma - current_sigma))?;
    Ok(current.broadcast_add(&displacement)?)
}

/// Advances `sample` by one step of the sampler selected by `sampler_mode`.
///
/// * `SamplerMode::Euler` delegates to [`euler_step`]; `noise` and `missing_noise_context`
///   are not used.
/// * `SamplerMode::Res2S` reads `sigmas[step_index]` and `sigmas[step_index + 1]` and requires
///   `noise`; it then calls [`res2s_step`] with a fixed `eta` of `0.5`. When `noise` is missing,
///   the error message is `missing_noise_context`.
///
/// # Errors
/// Fails when the schedule has no sigma after `step_index`, when Res2S noise is missing, or when
/// the underlying step fails.
pub(crate) fn sampler_step(
    sampler_mode: SamplerMode,
    sample: &Tensor,
    denoised_sample: &Tensor,
    sigmas: &[f32],
    step_index: usize,
    noise: Option<&Tensor>,
    missing_noise_context: &'static str,
) -> Result<Tensor> {
    match sampler_mode {
        SamplerMode::Euler => euler_step(sample, denoised_sample, sigmas, step_index),
        SamplerMode::Res2S => {
            let next_index = step_index + 1;
            if sigmas.len() <= next_index {
                bail!("Res2S sampler step requires a sigma and next sigma");
            }
            let sigma = f64::from(sigmas[step_index]);
            let sigma_next = f64::from(sigmas[next_index]);
            let res2s_noise = noise.context(missing_noise_context)?;
            // Stochasticity used for every Res2S step dispatched through this helper.
            let eta = 0.5;
            res2s_step(sample, denoised_sample, sigma, sigma_next, res2s_noise, eta)
        }
    }
}

/// Blends a model prediction with a clean latent: `denoised * mask + clean * (1 - mask)`.
///
/// Where the mask is `1` the model output is kept; where it is `0` the clean latent is kept.
/// If either `denoise_mask` or `clean_latent` is `None`, a clone of `denoised` is returned
/// unchanged. No dtype conversion is performed here.
// NOTE(review): not referenced by non-test code in this module, hence the `allow`.
#[allow(dead_code)]
pub fn apply_denoise_mask(
    denoised: &Tensor,
    denoise_mask: Option<&Tensor>,
    clean_latent: Option<&Tensor>,
) -> Result<Tensor> {
    if let (Some(mask), Some(clean)) = (denoise_mask, clean_latent) {
        let from_model = denoised.broadcast_mul(mask)?;
        let inverse_mask = mask.affine(-1.0, 1.0)?;
        let from_clean = clean.broadcast_mul(&inverse_mask)?;
        Ok(from_model.broadcast_add(&from_clean)?)
    } else {
        Ok(denoised.clone())
    }
}

/// Runs the Euler sampler over every consecutive pair of `sigmas`.
///
/// Each step index `i` in `0..sigmas.len() - 1` is processed in three stages:
/// 1. `denoiser(sample, i)` produces a prediction.
/// 2. [`apply_denoise_mask`] blends the prediction with `clean_latent` (only when both the mask
///    and the clean latent are given).
/// 3. [`sampler_step`] in Euler mode advances the sample.
///
/// # Errors
/// Fails if `sigmas` has fewer than two values, or if the denoiser or any step fails.
// NOTE(review): not referenced by non-test code in this module, hence the `allow`.
#[allow(dead_code)]
pub fn euler_denoising_loop<F>(
    initial_sample: &Tensor,
    sigmas: &[f32],
    denoise_mask: Option<&Tensor>,
    clean_latent: Option<&Tensor>,
    mut denoiser: F,
) -> Result<Tensor>
where
    F: FnMut(&Tensor, usize) -> Result<Tensor>,
{
    if sigmas.len() < 2 {
        bail!("euler denoising loop requires at least two sigma values");
    }

    let step_count = sigmas.len() - 1;
    (0..step_count).try_fold(initial_sample.clone(), |current, step_index| {
        let prediction = denoiser(&current, step_index)?;
        let masked = apply_denoise_mask(&prediction, denoise_mask, clean_latent)?;
        sampler_step(
            SamplerMode::Euler,
            &current,
            &masked,
            sigmas,
            step_index,
            None,
            "Euler sampler does not require noise",
        )
    })
}

/// Evaluates the exponential-integrator function `φ_j(z)` at `z = neg_h`.
///
/// The function can be written two equivalent ways:
/// * closed form: `φ_j(z) = (e^z - Σ_{k<j} z^k / k!) / z^j`, with `φ_0(z) = e^z`;
/// * power series: `φ_j(z) = Σ_{k≥0} z^k / (k + j)!`.
///
/// For `|z| < 1` the power series is summed directly. Near zero, the closed form subtracts
/// nearly equal numbers (e.g. `e^z - 1 - z` for `j = 2`), and in `f64` all precision is gone
/// once `|z|` is around `1e-8`. The series has no such cancellation and converges quickly there.
/// For `|z| >= 1` the closed form is used.
///
/// Factorials are accumulated in `f64`, so large `j` cannot overflow an integer factorial.
// NOTE(review): not referenced by non-test code in this module, hence the `allow`.
#[allow(dead_code)]
pub fn phi(j: usize, neg_h: f64) -> f64 {
    /// Below this magnitude the power series is used instead of the closed form.
    const SERIES_RADIUS: f64 = 1.0;
    /// With `|z| < 1` the k-th term is at most `1 / (k + j)!`, so 24 terms is far below f64 epsilon.
    const SERIES_TERMS: usize = 24;

    // 1 / j!, built in floating point.
    let inv_j_factorial = (1..=j).fold(1.0_f64, |acc, i| acc / i as f64);

    if neg_h.abs() < SERIES_RADIUS {
        // term_k = z^k / (k + j)!  =>  term_k = term_{k-1} * z / (k + j)
        let mut term = inv_j_factorial;
        let mut sum = term;
        for k in 1..=SERIES_TERMS {
            term *= neg_h / (k + j) as f64;
            sum += term;
        }
        return sum;
    }

    // Truncated Taylor polynomial Σ_{k<j} z^k / k!, built incrementally.
    let mut remainder = 0.0_f64;
    let mut term = 1.0_f64;
    for k in 0..j {
        remainder += term;
        term *= neg_h / (k + 1) as f64;
    }
    (neg_h.exp() - remainder) / neg_h.powi(j as i32)
}

/// Butcher-style coefficients `(a21, b1, b2)` for a two-stage exponential integrator.
///
/// `h` is the step size and `c2` the intermediate-stage fraction. The coefficients are:
/// * `a21 = c2 * φ_1(-h * c2)`
/// * `b2  = φ_2(-h) / c2`
/// * `b1  = φ_1(-h) - b2`
///
/// `c2` must be non-zero; `c2 == 0` yields non-finite coefficients.
// NOTE(review): not referenced by non-test code in this module, hence the `allow`.
#[allow(dead_code)]
pub fn res2s_coefficients(h: f64, c2: f64) -> (f64, f64, f64) {
    let stage_weight = phi(1, -(h * c2)) * c2;
    let final_weight = phi(2, -h) / c2;
    let first_weight = phi(1, -h) - final_weight;
    (stage_weight, first_weight, final_weight)
}

/// Splits the target noise level `sigma_next` into deterministic and stochastic parts.
///
/// Returns `(alpha_ratio, sigma_down, sigma_up)`:
/// * `sigma_up = min(sigma_next * eta, 0.9999 * sigma_next)` — the freshly injected noise;
/// * `sigma_residual = sqrt(max(sigma_next² - sigma_up², 0))`;
/// * `alpha_ratio = (1 - sigma_next) + sigma_residual`;
/// * `sigma_down = sigma_residual / alpha_ratio`, or `sigma_next` when `|alpha_ratio|` is below
///   `f64::EPSILON`.
pub fn res2s_sde_coefficients(sigma_next: f64, eta: f64) -> (f64, f64, f64) {
    // Cap the injected noise just below sigma_next so a residual always remains.
    let requested_up = sigma_next * eta;
    let ceiling_up = sigma_next * 0.9999;
    let sigma_up = requested_up.min(ceiling_up);

    let residual_variance = sigma_next.powi(2) - sigma_up.powi(2);
    let sigma_residual = residual_variance.max(0.0).sqrt();

    let alpha_ratio = (1.0 - sigma_next) + sigma_residual;
    let near_zero_alpha = alpha_ratio.abs() < f64::EPSILON;
    let sigma_down = if near_zero_alpha {
        sigma_next
    } else {
        sigma_residual / alpha_ratio
    };

    (alpha_ratio, sigma_down, sigma_up)
}

pub fn res2s_step(
    sample: &Tensor,
    denoised_sample: &Tensor,
    sigma: f64,
    sigma_next: f64,
    noise: &Tensor,
    eta: f64,
) -> Result<Tensor> {
    let (alpha_ratio, sigma_down, sigma_up) = res2s_sde_coefficients(sigma_next, eta);
    if sigma_up == 0.0 || sigma_next == 0.0 {
        return Ok(denoised_sample.clone());
    }

    let eps_next = sample
        .to_dtype(DType::F32)?
        .broadcast_sub(&denoised_sample.to_dtype(DType::F32)?)?
        .affine(1.0 / (sigma - sigma_next), 0.0)?;
    let denoised_next = sample
        .to_dtype(DType::F32)?
        .broadcast_sub(&eps_next.affine(sigma, 0.0)?)?;
    let drift = denoised_next.broadcast_add(&eps_next.affine(sigma_down, 0.0)?)?;
    let drift = drift.affine(alpha_ratio, 0.0)?;
    let noise_term = noise.to_dtype(DType::F32)?.affine(sigma_up, 0.0)?;
    drift.broadcast_add(&noise_term).map_err(Into::into)
}

/// Runs the Res2S sampler over every consecutive pair of `sigmas`.
///
/// Each step index `i` in `0..sigmas.len() - 1` is processed in three stages:
/// 1. `denoiser(sample, i)` produces a prediction.
/// 2. [`apply_denoise_mask`] blends the prediction with `clean_latent` (only when both the mask
///    and the clean latent are given).
/// 3. [`sampler_step`] in Res2S mode advances the sample.
///
/// NOTE(review): the same `noise` tensor is passed to every step, so the injected noise is
/// identical across steps. Confirm this is intended for the SDE variant.
///
/// # Errors
/// Fails if `sigmas` has fewer than two values, or if the denoiser or any step fails.
// NOTE(review): not referenced by non-test code in this module, hence the `allow`.
#[allow(dead_code)]
pub fn res2s_denoising_loop<F>(
    initial_sample: &Tensor,
    sigmas: &[f32],
    denoise_mask: Option<&Tensor>,
    clean_latent: Option<&Tensor>,
    noise: &Tensor,
    mut denoiser: F,
) -> Result<Tensor>
where
    F: FnMut(&Tensor, usize) -> Result<Tensor>,
{
    if sigmas.len() < 2 {
        bail!("res2s denoising loop requires at least two sigma values");
    }

    let step_count = sigmas.len() - 1;
    (0..step_count).try_fold(initial_sample.clone(), |current, step_index| {
        let prediction = denoiser(&current, step_index)?;
        let masked = apply_denoise_mask(&prediction, denoise_mask, clean_latent)?;
        sampler_step(
            SamplerMode::Res2S,
            &current,
            &masked,
            sigmas,
            step_index,
            Some(noise),
            "Res2S sampler noise missing",
        )
    })
}

/// Returns `n!` as a `usize` (`0! == 1! == 1`).
///
/// Uses ordinary integer multiplication, so `n > 20` overflows on 64-bit targets. Debug
/// builds panic on that overflow; release builds wrap. The final `max(1)` keeps the result
/// at least 1.
// NOTE(review): not referenced by non-test code in this module, hence the `allow`.
#[allow(dead_code)]
fn factorial(n: usize) -> usize {
    // Multiplying by 1 is a no-op, so the running product starts at 2.
    let product = (2..=n).fold(1_usize, |acc, factor| acc * factor);
    product.max(1)
}

// Unit tests for the sigma schedules, single-step samplers, step dispatcher, denoising
// loops and Res2S coefficient helpers in this module.
#[cfg(test)]
mod tests {
    use candle_core::{Device, Tensor};

    use crate::ltx2::execution::SamplerMode;

    use super::{
        euler_denoising_loop, euler_step, phi, res2s_coefficients, res2s_denoising_loop,
        res2s_sde_coefficients, sampler_step, DISTILLED_SIGMA_VALUES,
        STAGE_2_DISTILLED_SIGMA_VALUES,
    };

    // Pins the length, endpoints and a middle value of the distilled schedule, plus the full
    // stage-2 schedule.
    #[test]
    fn distilled_sigma_schedules_match_published_values() {
        assert_eq!(DISTILLED_SIGMA_VALUES.len(), 9);
        assert_eq!(DISTILLED_SIGMA_VALUES[0], 1.0);
        assert_eq!(DISTILLED_SIGMA_VALUES[7], 0.421875);
        assert_eq!(DISTILLED_SIGMA_VALUES[8], 0.0);

        assert_eq!(
            STAGE_2_DISTILLED_SIGMA_VALUES,
            &[0.909375, 0.725, 0.421875, 0.0]
        );
    }

    // velocity = (2 - 1) / 1 = 1, dt = 0.5 - 1 = -0.5, so 2 + 1 * -0.5 = 1.5.
    #[test]
    fn euler_step_advances_sample_by_velocity_dt() {
        let device = Device::Cpu;
        let sample = Tensor::new(&[2f32], &device).unwrap();
        let denoised = Tensor::new(&[1f32], &device).unwrap();
        let out = euler_step(&sample, &denoised, &[1.0, 0.5], 0)
            .unwrap()
            .to_vec1::<f32>()
            .unwrap();
        assert_eq!(out, vec![1.5]);
    }

    // The Euler arm of the dispatcher must agree exactly with calling `euler_step` directly.
    #[test]
    fn sampler_step_matches_euler_step_for_schedule_index() {
        let device = Device::Cpu;
        let sample = Tensor::new(&[2f32], &device).unwrap();
        let denoised = Tensor::new(&[1f32], &device).unwrap();

        let direct = euler_step(&sample, &denoised, &[1.0, 0.5], 0)
            .unwrap()
            .to_vec1::<f32>()
            .unwrap();
        let via_helper = sampler_step(
            SamplerMode::Euler,
            &sample,
            &denoised,
            &[1.0, 0.5],
            0,
            None,
            "test sampler noise missing",
        )
        .unwrap()
        .to_vec1::<f32>()
        .unwrap();

        assert_eq!(via_helper, direct);
    }

    // The Res2S arm of the dispatcher must agree with a direct `res2s_step` call using eta 0.5.
    // Note: sigma_next is 0 here, so this exercises only the early-return path.
    #[test]
    fn sampler_step_matches_res2s_step_for_schedule_index() {
        let device = Device::Cpu;
        let sample = Tensor::new(&[2f32], &device).unwrap();
        let denoised = Tensor::new(&[1f32], &device).unwrap();
        let noise = Tensor::zeros((1,), candle_core::DType::F32, &device).unwrap();

        let direct = super::res2s_step(&sample, &denoised, 0.5, 0.0, &noise, 0.5)
            .unwrap()
            .to_vec1::<f32>()
            .unwrap();
        let via_helper = sampler_step(
            SamplerMode::Res2S,
            &sample,
            &denoised,
            &[0.5, 0.0],
            0,
            Some(&noise),
            "test sampler noise missing",
        )
        .unwrap()
        .to_vec1::<f32>()
        .unwrap();

        assert_eq!(via_helper, direct);
    }

    // Missing noise in Res2S mode must surface the caller-supplied context message.
    #[test]
    fn sampler_step_requires_noise_for_res2s() {
        let device = Device::Cpu;
        let sample = Tensor::new(&[2f32], &device).unwrap();
        let denoised = Tensor::new(&[1f32], &device).unwrap();

        let err = sampler_step(
            SamplerMode::Res2S,
            &sample,
            &denoised,
            &[0.5, 0.0],
            0,
            None,
            "test sampler noise missing",
        )
        .unwrap_err();

        assert!(err.to_string().contains("test sampler noise missing"));
    }

    // An identity denoiser gives zero velocity, so the sample must stay unchanged across steps.
    #[test]
    fn euler_denoising_loop_supports_skip_step_like_identity_denoiser() {
        let device = Device::Cpu;
        let initial = Tensor::new(&[2f32], &device).unwrap();
        let out = euler_denoising_loop(&initial, &[1.0, 0.5, 0.0], None, None, |sample, _| {
            Ok(sample.clone())
        })
        .unwrap()
        .to_vec1::<f32>()
        .unwrap();
        assert_eq!(out, vec![2.0]);
    }

    // phi_2(z) -> 1/2! as z -> 0.
    #[test]
    fn phi_matches_taylor_limit_near_zero() {
        let value = phi(2, -1e-12);
        assert!((value - 0.5).abs() < 1e-6);
    }

    #[test]
    fn res2s_coefficients_are_finite_for_midpoint_scheme() {
        let (a21, b1, b2) = res2s_coefficients(0.5, 0.5);
        assert!(a21.is_finite());
        assert!(b1.is_finite());
        assert!(b2.is_finite());
    }

    // sigma_up must stay strictly below sigma_next (it is capped at 0.9999 * sigma_next).
    #[test]
    fn res2s_sde_coefficients_are_bounded() {
        let (alpha_ratio, sigma_down, sigma_up) = res2s_sde_coefficients(0.5, 0.5);
        assert!(alpha_ratio.is_finite());
        assert!(sigma_down >= 0.0);
        assert!(sigma_up >= 0.0);
        assert!(sigma_up < 0.5);
    }

    // With the terminal sigma 0.0, the last Res2S step returns the denoiser output.
    #[test]
    fn res2s_loop_returns_denoised_output_at_terminal_sigma() {
        let device = Device::Cpu;
        let initial = Tensor::new(&[2f32], &device).unwrap();
        let noise = Tensor::zeros((1,), candle_core::DType::F32, &device).unwrap();
        let out = res2s_denoising_loop(&initial, &[0.5, 0.0], None, None, &noise, |_sample, _| {
            Ok(Tensor::new(&[1f32], &device).unwrap())
        })
        .unwrap()
        .to_vec1::<f32>()
        .unwrap();
        assert_eq!(out, vec![1.0]);
    }
}