mold-ai-inference 0.13.1

Candle-based inference engine for mold — FLUX, SDXL, SD3.5, Z-Image diffusion models
Documentation
mod perturbations;

use std::collections::BTreeMap;

use anyhow::Result;
use candle_core::Tensor;

#[allow(unused_imports)]
pub use perturbations::{
    BatchedPerturbationConfig, Perturbation, PerturbationConfig, PerturbationType,
};

#[derive(Debug, Clone, PartialEq)]
pub struct MultiModalGuiderParams {
    pub cfg_scale: f64,
    pub stg_scale: f64,
    pub stg_blocks: Vec<usize>,
    pub rescale_scale: f64,
    pub modality_scale: f64,
    pub skip_step: usize,
}

impl Default for MultiModalGuiderParams {
    fn default() -> Self {
        Self {
            cfg_scale: 1.0,
            stg_scale: 0.0,
            stg_blocks: Vec::new(),
            rescale_scale: 0.0,
            modality_scale: 1.0,
            skip_step: 0,
        }
    }
}

#[derive(Debug, Clone)]
pub struct MultiModalGuider {
    pub params: MultiModalGuiderParams,
    pub negative_context: Option<Tensor>,
}

impl MultiModalGuider {
    pub fn new(params: MultiModalGuiderParams, negative_context: Option<Tensor>) -> Self {
        Self {
            params,
            negative_context,
        }
    }

    pub fn calculate(
        &self,
        cond: &Tensor,
        uncond_text: &Tensor,
        uncond_perturbed: &Tensor,
        uncond_modality: &Tensor,
    ) -> Result<Tensor> {
        let cfg = (cond.broadcast_sub(uncond_text)? * (self.params.cfg_scale - 1.0))?;
        let stg = (cond.broadcast_sub(uncond_perturbed)? * self.params.stg_scale)?;
        let modality = (cond.broadcast_sub(uncond_modality)? * (self.params.modality_scale - 1.0))?;
        let mut pred = cond.broadcast_add(&cfg)?;
        pred = pred.broadcast_add(&stg)?;
        pred = pred.broadcast_add(&modality)?;

        if self.params.rescale_scale != 0.0 {
            let cond_std = tensor_std(cond)?;
            let pred_std = tensor_std(&pred)?;
            let factor = self.params.rescale_scale
                * (cond_std as f64 / pred_std.max(f32::EPSILON) as f64)
                + (1.0 - self.params.rescale_scale);
            pred = (pred * factor)?;
        }

        Ok(pred)
    }

    pub fn do_unconditional_generation(&self) -> bool {
        !approx_eq(self.params.cfg_scale, 1.0)
    }

    pub fn do_perturbed_generation(&self) -> bool {
        !approx_eq(self.params.stg_scale, 0.0)
    }

    pub fn do_isolated_modality_generation(&self) -> bool {
        !approx_eq(self.params.modality_scale, 1.0)
    }

    pub fn should_skip_step(&self, step: usize) -> bool {
        self.params.skip_step != 0 && !step.is_multiple_of(self.params.skip_step + 1)
    }
}

#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct MultiModalGuiderFactory {
    #[allow(dead_code)]
    negative_context: Option<Tensor>,
    params_by_sigma: Vec<(f32, MultiModalGuiderParams)>,
}

#[allow(dead_code)]
impl MultiModalGuiderFactory {
    pub fn constant(params: MultiModalGuiderParams, negative_context: Option<Tensor>) -> Self {
        Self {
            negative_context,
            params_by_sigma: vec![(f32::INFINITY, params)],
        }
    }

    pub fn from_dict(
        params_by_sigma: BTreeMap<OrderedSigma, MultiModalGuiderParams>,
        negative_context: Option<Tensor>,
    ) -> Self {
        let mut entries = params_by_sigma
            .into_iter()
            .map(|(sigma, params)| (sigma.0, params))
            .collect::<Vec<_>>();
        entries.sort_by(|lhs, rhs| rhs.0.total_cmp(&lhs.0));
        Self {
            negative_context,
            params_by_sigma: entries,
        }
    }

    pub fn params(&self, sigma: f32) -> &MultiModalGuiderParams {
        self.params_by_sigma
            .iter()
            .rfind(|(upper_bound, _)| *upper_bound >= sigma)
            .or_else(|| self.params_by_sigma.first())
            .map(|(_, params)| params)
            .expect("guider factory requires at least one sigma bin")
    }

    pub fn build_from_sigma(&self, sigma: f32) -> MultiModalGuider {
        MultiModalGuider::new(self.params(sigma).clone(), self.negative_context.clone())
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(dead_code)]
pub struct OrderedSigma(pub f32);

impl Eq for OrderedSigma {}

impl Ord for OrderedSigma {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}

impl PartialOrd for OrderedSigma {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

fn approx_eq(lhs: f64, rhs: f64) -> bool {
    (lhs - rhs).abs() < 1e-6
}

fn tensor_std(tensor: &Tensor) -> Result<f32> {
    let flat = tensor.flatten_all()?;
    let mean = flat.mean_all()?.to_scalar::<f32>()?;
    let variance = flat
        .affine(1.0, -(mean as f64))?
        .sqr()?
        .mean_all()?
        .to_scalar::<f32>()?;
    Ok(variance.max(0.0).sqrt())
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use candle_core::{Device, Tensor};

    use super::{
        BatchedPerturbationConfig, MultiModalGuider, MultiModalGuiderFactory,
        MultiModalGuiderParams, OrderedSigma, Perturbation, PerturbationConfig, PerturbationType,
    };

    #[test]
    fn simple_denoiser_guidance_reduces_to_cond_when_disabled() {
        let device = Device::Cpu;
        let cond = Tensor::new(&[1f32, 2.0], &device).unwrap();
        let guider = MultiModalGuider::new(MultiModalGuiderParams::default(), None);
        let out = guider.calculate(&cond, &cond, &cond, &cond).unwrap();
        assert_eq!(out.to_vec1::<f32>().unwrap(), vec![1.0, 2.0]);
    }

    #[test]
    fn guided_denoiser_applies_cfg_and_stg_deltas() {
        let device = Device::Cpu;
        let cond = Tensor::new(&[5f32, 10.0], &device).unwrap();
        let uncond = Tensor::new(&[3f32, 4.0], &device).unwrap();
        let perturbed = Tensor::new(&[4f32, 8.0], &device).unwrap();
        let modality = Tensor::new(&[2f32, 6.0], &device).unwrap();
        let guider = MultiModalGuider::new(
            MultiModalGuiderParams {
                cfg_scale: 2.0,
                stg_scale: 0.5,
                modality_scale: 1.5,
                ..Default::default()
            },
            None,
        );

        let out = guider
            .calculate(&cond, &uncond, &perturbed, &modality)
            .unwrap()
            .to_vec1::<f32>()
            .unwrap();
        assert_eq!(out, vec![9.0, 19.0]);
    }

    #[test]
    fn modality_isolation_detection_matches_scale() {
        let guider = MultiModalGuider::new(
            MultiModalGuiderParams {
                modality_scale: 1.5,
                ..Default::default()
            },
            None,
        );
        assert!(guider.do_isolated_modality_generation());
        assert!(!guider.do_perturbed_generation());
        assert!(!guider.do_unconditional_generation());
    }

    #[test]
    fn skip_step_logic_matches_upstream_contract() {
        let guider = MultiModalGuider::new(
            MultiModalGuiderParams {
                skip_step: 2,
                ..Default::default()
            },
            None,
        );
        assert!(!guider.should_skip_step(0));
        assert!(guider.should_skip_step(1));
        assert!(guider.should_skip_step(2));
        assert!(!guider.should_skip_step(3));
    }

    #[test]
    fn guider_factory_picks_sigma_bin() {
        let mut bins = BTreeMap::new();
        bins.insert(
            OrderedSigma(1.0),
            MultiModalGuiderParams {
                cfg_scale: 3.0,
                ..Default::default()
            },
        );
        bins.insert(
            OrderedSigma(0.5),
            MultiModalGuiderParams {
                cfg_scale: 2.0,
                ..Default::default()
            },
        );
        let factory = MultiModalGuiderFactory::from_dict(bins, None);
        assert_eq!(factory.params(0.75).cfg_scale, 3.0);
        assert_eq!(factory.params(0.49).cfg_scale, 2.0);
    }

    #[test]
    fn batched_perturbations_mark_expected_samples() {
        let config = BatchedPerturbationConfig::new(vec![
            PerturbationConfig::new(vec![Perturbation::new(
                PerturbationType::SkipVideoSelfAttention,
                Some(vec![4]),
            )]),
            PerturbationConfig::empty(),
        ]);
        assert!(config.any_in_batch(PerturbationType::SkipVideoSelfAttention, 4));
        assert!(!config.all_in_batch(PerturbationType::SkipVideoSelfAttention, 4));
        assert_eq!(
            config.mask_values(PerturbationType::SkipVideoSelfAttention, 4),
            vec![0.0, 1.0]
        );
    }
}