mold-ai-inference 0.13.1

Candle-based inference engine for mold — FLUX, SDXL, SD3.5, Z-Image diffusion models
Documentation
use anyhow::{Context, Result};
use candle_core::{safetensors::MmapedSafetensors, DType, Device, Shape, Tensor};
use candle_nn::var_builder::SimpleBackend;
use std::collections::BTreeSet;
use std::path::Path;

/// Prefix that `remap_ltx2_transformer_key` prepends to every remapped tensor name.
/// `Ltx2Nvfp4Backend::source_key` also retries a lookup with this prefix stripped.
const DIFFUSION_PREFIX: &str = "model.diffusion_model.";

/// Identifies which part of an NVFP4 weight a synthetic `.nvfp4_*` key asks for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Nvfp4Component {
    /// Selected by the `.nvfp4_packed` suffix; served from the source `.weight` tensor,
    /// which must have dtype U8.
    Packed,
    /// Selected by the `.nvfp4_block_scales` suffix; served from `<base>.weight_scale`,
    /// which must have dtype F8E4M3.
    BlockScales,
    /// Selected by the `.nvfp4_tensor_scale` suffix; served from `<base>.weight_scale_2`
    /// and converted to F32.
    TensorScale,
}

/// Reports whether the safetensors file at `path` contains at least one tensor whose
/// name ends in `.weight_scale_2` or `.comfy_quant`.
///
/// Returns `false` when the file cannot be memory-mapped, so an unreadable or missing
/// checkpoint is reported the same way as one without those markers.
pub(super) fn checkpoint_is_nvfp4(path: &Path) -> bool {
    const MARKER_SUFFIXES: [&str; 2] = [".weight_scale_2", ".comfy_quant"];

    // SAFETY: `MmapedSafetensors::new` is unsafe because it memory-maps the file; this
    // assumes the checkpoint is not modified while the mapping is alive.
    match unsafe { MmapedSafetensors::new(path) } {
        Ok(checkpoint) => checkpoint.tensors().into_iter().any(|(tensor_name, _)| {
            MARKER_SUFFIXES
                .iter()
                .any(|&suffix| tensor_name.ends_with(suffix))
        }),
        Err(_) => false,
    }
}

/// Translates a logical transformer tensor name into the checkpoint's naming scheme.
///
/// The name is split on `.`; the components `proj_in`, `time_embed`, `norm_q` and
/// `norm_k` are renamed to `patchify_proj`, `adaln_single`, `q_norm` and `k_norm`
/// respectively, every other component is kept as-is, and the result is prefixed with
/// `DIFFUSION_PREFIX`.
pub(super) fn remap_ltx2_transformer_key(name: &str) -> String {
    let mut remapped = String::with_capacity(DIFFUSION_PREFIX.len() + name.len());
    remapped.push_str(DIFFUSION_PREFIX);

    for (index, segment) in name.split('.').enumerate() {
        // Re-insert the separator that `split` consumed between components.
        if index > 0 {
            remapped.push('.');
        }
        let renamed = match segment {
            "proj_in" => "patchify_proj",
            "time_embed" => "adaln_single",
            "norm_q" => "q_norm",
            "norm_k" => "k_norm",
            other => other,
        };
        remapped.push_str(renamed);
    }

    remapped
}

/// Tensor backend over a memory-mapped LTX-2 NVFP4 safetensors checkpoint.
pub(super) struct Ltx2Nvfp4Backend {
    /// Memory-mapped checkpoint that every tensor is loaded from.
    st: MmapedSafetensors,
    /// Names of all tensors present in the checkpoint.
    keys: BTreeSet<String>,
    /// Key bases that have both a `.weight_scale` and a `.weight_scale_2` entry.
    nvfp4_bases: BTreeSet<String>,
}

impl Ltx2Nvfp4Backend {
    pub(super) fn from_path(path: &Path) -> Result<Self> {
        let st = unsafe { MmapedSafetensors::new(path) }
            .with_context(|| format!("mmap LTX-2 NVFP4 checkpoint at {}", path.display()))?;
        let keys: BTreeSet<String> = st.tensors().into_iter().map(|(key, _)| key).collect();
        let nvfp4_bases = collect_nvfp4_bases(&keys);
        Ok(Self {
            st,
            keys,
            nvfp4_bases,
        })
    }

    fn source_key(&self, logical_name: &str) -> Option<String> {
        let prefixed = remap_ltx2_transformer_key(logical_name);
        if self.keys.contains(&prefixed) {
            return Some(prefixed);
        }
        let stripped = prefixed.strip_prefix(DIFFUSION_PREFIX)?;
        if self.keys.contains(stripped) {
            return Some(stripped.to_string());
        }
        None
    }

    fn source_key_or_default(&self, logical_name: &str) -> String {
        self.source_key(logical_name)
            .unwrap_or_else(|| remap_ltx2_transformer_key(logical_name))
    }

    fn is_nvfp4_weight_source(&self, source_key: &str) -> bool {
        source_key
            .strip_suffix(".weight")
            .is_some_and(|base| self.nvfp4_bases.contains(base))
    }

    fn nvfp4_component(name: &str) -> Option<(&str, Nvfp4Component)> {
        if let Some(weight_key) = name.strip_suffix(".nvfp4_packed") {
            Some((weight_key, Nvfp4Component::Packed))
        } else if let Some(weight_key) = name.strip_suffix(".nvfp4_block_scales") {
            Some((weight_key, Nvfp4Component::BlockScales))
        } else {
            name.strip_suffix(".nvfp4_tensor_scale")
                .map(|weight_key| (weight_key, Nvfp4Component::TensorScale))
        }
    }

    fn lookup_nvfp4_component(
        &self,
        logical_weight_key: &str,
        component: Nvfp4Component,
    ) -> candle_core::Result<Tensor> {
        let source_weight_key = self.source_key(logical_weight_key).ok_or_else(|| {
            candle_core::Error::Msg(format!(
                "LTX-2 NVFP4 backend: no source weight for logical key '{logical_weight_key}'",
            ))
        })?;
        let source_base = source_weight_key.strip_suffix(".weight").ok_or_else(|| {
            candle_core::Error::Msg(format!(
                "LTX-2 NVFP4 backend: synthetic key '{logical_weight_key}' does not target a .weight tensor",
            ))
        })?;
        if !self.nvfp4_bases.contains(source_base) {
            return Err(candle_core::Error::Msg(format!(
                "LTX-2 NVFP4 backend: source '{source_base}' does not have NVFP4 sidecars",
            )));
        }

        let cpu = Device::Cpu;
        match component {
            Nvfp4Component::Packed => {
                let tensor = self.st.load(&source_weight_key, &cpu)?;
                if tensor.dtype() != DType::U8 {
                    return Err(candle_core::Error::Msg(format!(
                        "LTX-2 NVFP4: expected '{source_weight_key}' to be U8 packed FP4, got {:?}",
                        tensor.dtype()
                    )));
                }
                Ok(tensor)
            }
            Nvfp4Component::BlockScales => {
                let scale_key = format!("{source_base}.weight_scale");
                let tensor = self.st.load(&scale_key, &cpu)?;
                if tensor.dtype() != DType::F8E4M3 {
                    return Err(candle_core::Error::Msg(format!(
                        "LTX-2 NVFP4: expected '{scale_key}' to be F8E4M3 block scales, got {:?}",
                        tensor.dtype()
                    )));
                }
                Ok(tensor)
            }
            Nvfp4Component::TensorScale => {
                let scale_key = format!("{source_base}.weight_scale_2");
                self.st.load(&scale_key, &cpu)?.to_dtype(DType::F32)
            }
        }
    }

    fn lookup(&self, name: &str, dev: &Device) -> candle_core::Result<Tensor> {
        if let Some((logical_weight_key, component)) = Self::nvfp4_component(name) {
            return self.lookup_nvfp4_component(logical_weight_key, component);
        }

        let source_key = self.source_key_or_default(name);
        if self.is_nvfp4_weight_source(&source_key) {
            return Err(candle_core::Error::Msg(format!(
                "LTX-2 NVFP4 backend: '{name}' is packed FP4; request weight.nvfp4_packed, weight.nvfp4_block_scales, and weight.nvfp4_tensor_scale instead",
            )));
        }
        self.st.load(&source_key, dev)
    }
}

/// Returns every key base `b` for which both `b.weight_scale` and `b.weight_scale_2`
/// are present in `keys`.
fn collect_nvfp4_bases(keys: &BTreeSet<String>) -> BTreeSet<String> {
    keys.iter()
        // A key ending in `.weight_scale_2` never matches this suffix, so only the
        // first-level scale keys nominate a base.
        .filter_map(|key| key.strip_suffix(".weight_scale"))
        // Keep the base only if its second-level scale is stored as well.
        .filter(|base| keys.contains(&format!("{base}.weight_scale_2")))
        .map(str::to_owned)
        .collect()
}

impl SimpleBackend for Ltx2Nvfp4Backend {
    fn get(
        &self,
        shape: Shape,
        name: &str,
        _hints: candle_nn::Init,
        dtype: DType,
        dev: &Device,
    ) -> candle_core::Result<Tensor> {
        let tensor = self.lookup(name, dev)?;
        if tensor.shape() != &shape {
            return Err(candle_core::Error::UnexpectedShape {
                msg: format!("LTX-2 NVFP4 backend: shape mismatch for {name}"),
                expected: shape,
                got: tensor.shape().clone(),
            }
            .bt());
        }
        if tensor.dtype() == dtype {
            Ok(tensor)
        } else {
            tensor.to_dtype(dtype)
        }
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
        let tensor = self.lookup(name, dev)?;
        if tensor.dtype() == dtype {
            Ok(tensor)
        } else {
            tensor.to_dtype(dtype)
        }
    }

    fn contains_tensor(&self, name: &str) -> bool {
        if let Some((logical_weight_key, _component)) = Self::nvfp4_component(name) {
            return self
                .source_key(logical_weight_key)
                .as_deref()
                .and_then(|source_key| source_key.strip_suffix(".weight"))
                .is_some_and(|source_base| self.nvfp4_bases.contains(source_base));
        }

        let Some(source_key) = self.source_key(name) else {
            return false;
        };
        !self.is_nvfp4_weight_source(&source_key)
    }
}

#[cfg(test)]
mod tests {
    use candle_core::{DType, Device};
    use candle_nn::VarBuilder;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::path::PathBuf;

    use super::{checkpoint_is_nvfp4, Ltx2Nvfp4Backend};

    /// Builds a per-run fixture path in the system temp directory from the tag, the
    /// process id and the current time in nanoseconds.
    fn temp_path(tag: &str) -> PathBuf {
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let file_name = format!("mold-ltx2-nvfp4-{}-{}-{}.safetensors", tag, pid, nanos);
        std::env::temp_dir().join(file_name)
    }

    /// Writes a checkpoint holding one `to_q` projection: a U8 weight, an F8E4M3
    /// `weight_scale`, a scalar F32 `weight_scale_2`, and an F32 bias.
    fn write_fixture(path: &std::path::Path) {
        let packed = vec![0x22u8; 16];
        let scales = vec![0x38u8; 512];
        let tensor_scale = 0.5f32.to_le_bytes().to_vec();
        let bias = [0.25f32.to_le_bytes(), (-0.5f32).to_le_bytes()].concat();

        let entries: [(&str, SafeDtype, Vec<usize>, &[u8]); 4] = [
            (
                "model.diffusion_model.transformer_blocks.0.attn1.to_q.weight",
                SafeDtype::U8,
                vec![2, 8],
                packed.as_slice(),
            ),
            (
                "model.diffusion_model.transformer_blocks.0.attn1.to_q.weight_scale",
                SafeDtype::F8_E4M3,
                vec![128, 4],
                scales.as_slice(),
            ),
            (
                "model.diffusion_model.transformer_blocks.0.attn1.to_q.weight_scale_2",
                SafeDtype::F32,
                vec![],
                tensor_scale.as_slice(),
            ),
            (
                "model.diffusion_model.transformer_blocks.0.attn1.to_q.bias",
                SafeDtype::F32,
                vec![2],
                bias.as_slice(),
            ),
        ];

        let mut tensors: HashMap<String, TensorView<'_>> = HashMap::new();
        for (key, dtype, shape, bytes) in entries {
            tensors.insert(
                key.to_string(),
                TensorView::new(dtype, shape, bytes).unwrap(),
            );
        }
        serialize_to_file(&tensors, &None, path).unwrap();
    }

    #[test]
    fn ltx2_nvfp4_backend_exposes_sidecar_subkeys_and_hides_packed_weight() {
        let fixture = temp_path("sidecars");
        write_fixture(&fixture);
        assert!(checkpoint_is_nvfp4(&fixture));

        let backend = Ltx2Nvfp4Backend::from_path(&fixture).unwrap();
        let device = Device::Cpu;
        let root = VarBuilder::from_backend(Box::new(backend), DType::F32, device.clone());
        let to_q = root.pp("transformer_blocks.0.attn1.to_q");

        // The synthetic sub-keys are visible; the raw packed weight is hidden.
        for synthetic in [
            "weight.nvfp4_packed",
            "weight.nvfp4_block_scales",
            "weight.nvfp4_tensor_scale",
        ] {
            assert!(to_q.contains_tensor(synthetic));
        }
        assert!(!to_q.contains_tensor("weight"));

        let packed = to_q
            .get_unchecked_dtype("weight.nvfp4_packed", DType::U8)
            .unwrap();
        let block_scales = to_q
            .get_unchecked_dtype("weight.nvfp4_block_scales", DType::F8E4M3)
            .unwrap();
        let tensor_scale = to_q
            .get_unchecked_dtype("weight.nvfp4_tensor_scale", DType::F32)
            .unwrap();
        let bias = to_q.get(2, "bias").unwrap();

        assert_eq!(packed.dims(), &[2, 8]);
        assert_eq!(packed.dtype(), DType::U8);
        assert_eq!(block_scales.dims(), &[128, 4]);
        assert_eq!(block_scales.dtype(), DType::F8E4M3);
        assert_eq!(tensor_scale.to_scalar::<f32>().unwrap(), 0.5);
        assert_eq!(bias.dims(), &[2]);

        // Best-effort cleanup; a failed removal must not fail the test.
        let _ = std::fs::remove_file(fixture);
    }
}