mold-ai-inference 0.13.1

Candle-based inference engine for mold — FLUX, SDXL, SD3.5, Z-Image diffusion models
Documentation
use anyhow::Result;
use candle_core::{DType, Device, Module, Shape, Tensor, D};
use candle_nn::VarBuilder;
use candle_transformers::models::mmdit::blocks::{
    ContextQkvOnlyJointBlock, FinalLayer, JointBlock, MMDiTJointBlock, MMDiTXJointBlock,
};
use candle_transformers::models::mmdit::embedding::{
    PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
};
use candle_transformers::models::mmdit::model::Config;
use std::collections::HashMap;
use std::sync::Arc;

/// `VarBuilder` backend that serves weights from an in-memory map of tensors
/// keyed by tensor name.
///
/// Each lookup sends the stored tensor through `to_device` and `to_dtype` for
/// the requested target; this type keeps no per-device cache of its own.
///
/// NOTE(review): the name suggests the stored tensors are host-resident, but
/// nothing in this type enforces that — confirm against the code that fills
/// the map.
#[derive(Clone)]
struct CpuTensorBackend {
    /// Shared tensor store; the key is the name the `VarBuilder` asks for.
    tensors: Arc<HashMap<String, Tensor>>,
}

impl CpuTensorBackend {
    fn new(tensors: Arc<HashMap<String, Tensor>>) -> Self {
        Self { tensors }
    }

    fn load(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
        let tensor = self
            .tensors
            .get(name)
            .ok_or_else(|| candle_core::Error::msg(format!("missing SD3 MMDiT tensor {name}")))?;
        tensor.to_device(dev)?.to_dtype(dtype)
    }
}

impl candle_nn::var_builder::SimpleBackend for CpuTensorBackend {
    fn get(
        &self,
        shape: Shape,
        name: &str,
        _init: candle_nn::Init,
        dtype: DType,
        dev: &Device,
    ) -> candle_core::Result<Tensor> {
        let tensor = self.load(name, dtype, dev)?;
        if tensor.shape() != &shape {
            return Err(candle_core::Error::UnexpectedShape {
                msg: format!("shape mismatch for {name}"),
                expected: shape,
                got: tensor.shape().clone(),
            });
        }
        Ok(tensor)
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
        self.load(name, dtype, dev)
    }

    fn contains_tensor(&self, name: &str) -> bool {
        self.tensors.contains_key(name)
    }
}

/// One step of the block execution plan walked by `OffloadedMMDiT::forward`.
///
/// The payload of each variant is the block index that is formatted into the
/// `joint_blocks.{idx}` weight prefix when the block is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Sd3StreamingBlock {
    /// A block whose forward pass replaces both `context` and `x`; it is
    /// skipped when its index appears in the caller's `skip_layers`.
    Joint(usize),
    /// The closing block, built as a `ContextQkvOnlyJointBlock`; its forward
    /// pass only yields the new `x` and it is not subject to `skip_layers`.
    FinalJoint(usize),
}

/// Builds the ordered list of blocks to execute for a model of `cfg.depth`
/// blocks.
///
/// Every index below `depth - 1` becomes a [`Sd3StreamingBlock::Joint`]; the
/// last index becomes [`Sd3StreamingBlock::FinalJoint`]. A depth of zero
/// yields an empty plan.
pub(crate) fn sd3_streaming_block_plan(cfg: &Config) -> Vec<Sd3StreamingBlock> {
    let depth = cfg.depth;
    (0..depth)
        .map(|idx| {
            // `idx < depth`, so `idx + 1` cannot overflow.
            if idx + 1 == depth {
                Sd3StreamingBlock::FinalJoint(idx)
            } else {
                Sd3StreamingBlock::Joint(idx)
            }
        })
        .collect()
}

/// MMDiT model whose joint blocks are rebuilt from the shared tensor map each
/// time `forward` reaches them, instead of being stored on the struct.
///
/// The embedders, the final layer and the unpatchifier are constructed once in
/// `new` and kept for the lifetime of the value.
pub(crate) struct OffloadedMMDiT {
    /// Copy of the model configuration passed to `new`.
    cfg: Config,
    /// Execution order of the blocks, as produced by `sd3_streaming_block_plan`.
    block_plan: Vec<Sd3StreamingBlock>,
    /// Shared weight store from which per-block `VarBuilder`s are created.
    tensors: Arc<HashMap<String, Tensor>>,
    /// Dtype requested for every weight that is loaded.
    dtype: DType,
    /// Device that inputs and weights are moved to.
    device: Device,
    /// Loaded from the `x_embedder` prefix.
    patch_embedder: PatchEmbedder,
    /// Loaded from the model root prefix.
    pos_embedder: PositionEmbedder,
    /// Loaded from the `t_embedder` prefix.
    timestep_embedder: TimestepEmbedder,
    /// Loaded from the `y_embedder` prefix.
    vector_embedder: VectorEmbedder,
    /// Linear projection loaded from the `context_embedder` prefix.
    context_embedder: candle_nn::Linear,
    /// Loaded from the `final_layer` prefix.
    final_layer: FinalLayer,
    /// Built from `patch_size` and `out_channels`; takes no weights.
    unpatchifier: Unpatchifier,
}

impl OffloadedMMDiT {
    /// Builds the resident parts of the model (embedders, final layer,
    /// unpatchifier) and records the block execution plan.
    ///
    /// Joint blocks are not constructed here; `forward` builds them from
    /// `tensors` as it walks the plan.
    ///
    /// # Errors
    ///
    /// Forwards any error raised while constructing a resident component,
    /// e.g. a missing or mis-shaped tensor under `model.diffusion_model`.
    pub(crate) fn new(
        cfg: &Config,
        tensors: Arc<HashMap<String, Tensor>>,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        let width = Self::hidden_size(cfg);
        let root =
            Self::var_builder(Arc::clone(&tensors), dtype, device).pp("model.diffusion_model");

        let patch_embedder =
            PatchEmbedder::new(cfg.patch_size, cfg.in_channels, width, root.pp("x_embedder"))?;
        // The position embedder reads its tensor straight from the model root.
        let pos_embedder =
            PositionEmbedder::new(width, cfg.patch_size, cfg.pos_embed_max_size, root.clone())?;
        let timestep_embedder =
            TimestepEmbedder::new(width, cfg.frequency_embedding_size, root.pp("t_embedder"))?;
        let vector_embedder =
            VectorEmbedder::new(cfg.adm_in_channels, width, root.pp("y_embedder"))?;
        let context_embedder =
            candle_nn::linear(cfg.context_embed_size, width, root.pp("context_embedder"))?;
        let final_layer =
            FinalLayer::new(width, cfg.patch_size, cfg.out_channels, root.pp("final_layer"))?;
        let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels)?;

        let model = Self {
            cfg: cfg.clone(),
            block_plan: sd3_streaming_block_plan(cfg),
            tensors,
            dtype,
            device: device.clone(),
            patch_embedder,
            pos_embedder,
            timestep_embedder,
            vector_embedder,
            context_embedder,
            final_layer,
            unpatchifier,
        };
        Ok(model)
    }

    /// Hidden width shared by every component: `head_size * depth`.
    fn hidden_size(cfg: &Config) -> usize {
        cfg.head_size * cfg.depth
    }

    /// Creates a root `VarBuilder` that resolves tensors through a
    /// `CpuTensorBackend` over `tensors`.
    fn var_builder<'a>(
        tensors: Arc<HashMap<String, Tensor>>,
        dtype: DType,
        device: &Device,
    ) -> VarBuilder<'a> {
        let backend = Box::new(CpuTensorBackend::new(tensors));
        VarBuilder::from_backend(backend, dtype, device.clone())
    }

    /// Returns a `VarBuilder` scoped to
    /// `model.diffusion_model.joint_blocks.{idx}`.
    fn block_var_builder(&self, idx: usize) -> VarBuilder<'_> {
        let model_root = Self::var_builder(Arc::clone(&self.tensors), self.dtype, &self.device)
            .pp("model.diffusion_model");
        model_root.pp(format!("joint_blocks.{idx}"))
    }

    /// Builds the joint block at `idx` from the shared tensor map.
    ///
    /// An `MMDiTXJointBlock` is chosen when the block's weights include
    /// `x_block.attn2.qkv.weight`; otherwise an `MMDiTJointBlock` is built.
    fn joint_block(&self, idx: usize) -> Result<Box<dyn JointBlock>> {
        let width = Self::hidden_size(&self.cfg);
        let heads = self.cfg.depth;
        let vb = self.block_var_builder(idx);
        let has_second_attention = vb
            .pp("x_block")
            .pp("attn2")
            .contains_tensor("qkv.weight");

        if has_second_attention {
            let block = MMDiTXJointBlock::new(width, heads, false, vb)?;
            Ok(Box::new(block) as Box<dyn JointBlock>)
        } else {
            let block = MMDiTJointBlock::new(width, heads, false, vb)?;
            Ok(Box::new(block) as Box<dyn JointBlock>)
        }
    }

    /// Builds the closing `ContextQkvOnlyJointBlock` at `idx` from the shared
    /// tensor map.
    fn final_joint_block(&self, idx: usize) -> Result<ContextQkvOnlyJointBlock> {
        let width = Self::hidden_size(&self.cfg);
        let heads = self.cfg.depth;
        let block = ContextQkvOnlyJointBlock::new(width, heads, false, self.block_var_builder(idx))?;
        Ok(block)
    }

    /// Runs the model, building each joint block only for the duration of its
    /// own step.
    ///
    /// All four input tensors are moved to the model's device first. The last
    /// two dimensions of `x` are read as `h` and `w` and are used to crop the
    /// position embedding and to narrow the unpatchified output.
    ///
    /// `skip_layers`, when given, lists indices of `Joint` blocks to skip; the
    /// final block always runs.
    ///
    /// # Errors
    ///
    /// Forwards any tensor, device or weight-loading error encountered along
    /// the way.
    pub(crate) fn forward(
        &self,
        x: &Tensor,
        t: &Tensor,
        y: &Tensor,
        context: &Tensor,
        skip_layers: Option<&[usize]>,
    ) -> Result<Tensor> {
        let dev = &self.device;
        let x = x.to_device(dev)?;
        let t = t.to_device(dev)?;
        let y = y.to_device(dev)?;
        let context = context.to_device(dev)?;

        let (h, w) = (x.dim(D::Minus2)?, x.dim(D::Minus1)?);
        let pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?;
        let patches = self.patch_embedder.forward(&x)?;
        let mut x = patches.broadcast_add(&pos_embed)?;

        // Conditioning vector: timestep embedding plus the embedded `y`.
        let t_embed = self.timestep_embedder.forward(&t)?;
        let y_embed = self.vector_embedder.forward(&y)?;
        let c = (t_embed + y_embed)?;
        let mut context = self.context_embedder.forward(&context)?;

        let is_skipped = |idx: usize| skip_layers.is_some_and(|layers| layers.contains(&idx));
        for step in self.block_plan.iter().copied() {
            match step {
                Sd3StreamingBlock::Joint(idx) if is_skipped(idx) => {}
                Sd3StreamingBlock::Joint(idx) => {
                    // The block is dropped at the end of this arm.
                    let joint = self.joint_block(idx)?;
                    (context, x) = joint.forward(&context, &x, &c)?;
                }
                Sd3StreamingBlock::FinalJoint(idx) => {
                    let last = self.final_joint_block(idx)?;
                    x = last.forward(&context, &x, &c)?;
                }
            }
        }

        let projected = self.final_layer.forward(&x, &c)?;
        let image = self.unpatchifier.unpatchify(&projected, h, w)?;
        let cropped = image.narrow(2, 0, h)?.narrow(3, 0, w)?;
        Ok(cropped)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the SD3.5-large config with only `depth` overridden, so the
    /// tests can exercise small plans.
    fn cfg_with_depth(depth: usize) -> Config {
        let mut cfg = Config::sd3_5_large();
        cfg.depth = depth;
        cfg
    }

    #[test]
    fn sd3_streaming_block_plan_preserves_reference_order() {
        assert_eq!(
            sd3_streaming_block_plan(&cfg_with_depth(4)),
            vec![
                Sd3StreamingBlock::Joint(0),
                Sd3StreamingBlock::Joint(1),
                Sd3StreamingBlock::Joint(2),
                Sd3StreamingBlock::FinalJoint(3),
            ]
        );
    }

    /// Boundary: with a single block, that block must be the final one and no
    /// regular joint block may be planned.
    #[test]
    fn sd3_streaming_block_plan_single_block_is_final() {
        assert_eq!(
            sd3_streaming_block_plan(&cfg_with_depth(1)),
            vec![Sd3StreamingBlock::FinalJoint(0)]
        );
    }

    /// Boundary: a zero-depth config must yield an empty plan rather than
    /// underflowing on `depth - 1`.
    #[test]
    fn sd3_streaming_block_plan_zero_depth_is_empty() {
        assert!(sd3_streaming_block_plan(&cfg_with_depth(0)).is_empty());
    }
}